From 7af2de920c6be16a851e762562a97988ee4aa26c Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Mon, 24 Feb 2025 21:13:36 +0800 Subject: [PATCH 1/2] renew --- models/roma_unsb_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index a0b7682..e54b36f 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module): cosine_fake = self.compute_cosine_similarity(gradients_fake) # 生成权重图(与原代码相同) - k = int(self.eta_ratio * N) + k = int(self.eta_ratio * cosine_fake.shape[1]) + _, fake_indices = torch.topk(-cosine_fake, k, dim=1) weight_fake = torch.ones_like(cosine_fake) - + + for b in range(cosine_fake.shape[0]): + weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) + # 重建空间维度 -------------------------------------------------- # 将权重从[B, N]转换为[B, H, W] weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W] From e67b0f2511c0fa0e340f9cb7e58f6bc538844e85 Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Mon, 24 Feb 2025 21:28:21 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=9C=80=E6=96=B0=E7=9A=84=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/roma_unsb_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index e54b36f..3563ddf 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -123,7 +123,7 @@ class ContentAwareOptimization(nn.Module): """ B, C, H, W = D_real.shape N = H * W - + shape_hw = [h, w] # 注册钩子获取梯度 gradients_real = [] gradients_fake = [] @@ -150,7 +150,7 @@ class ContentAwareOptimization(nn.Module): gradients_fake = gradients_fake[0] # [B, N, D] # 生成权重图 - self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake) + self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw ) # 应用权重到对抗损失 loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8)) @@ -496,12 +496,12 @@ class RomaUnsbModel(BaseModel): self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) # [[1,576,768],[1,576,768],[1,576,768]] # [3,576,768] - + shape_hw = list(self.real_A0_resize.shape[2:4]) # 生成图像的梯度 fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0] # 梯度图 - self.weight_fake = self.cao.generate_weight_map(fake_gradient) + self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw) # 生成图像的CTN光流图 self.f_content = self.ctn(self.weight_fake)