From 55b9db967ac5af59d240ac5dfd313b2f5d04350e Mon Sep 17 00:00:00 2001
From: Kunyu_Lee <202109100607@stumail.xsyu.edu.cn>
Date: Mon, 24 Feb 2025 20:39:59 +0800
Subject: [PATCH] Latest changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 models/roma_unsb_model.py | 81 +++++++++++++++++++++------------------
 1 file changed, 43 insertions(+), 38 deletions(-)

diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index 8dbc273..a0b7682 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -60,7 +60,7 @@ def compute_ctn_loss(G, x, F_content):  # Eq. 10
     loss = F.mse_loss(warped_fake, y_fake_warped)
     return loss

-class ContentAwareOptimization(nn.Module):
+class ContentAwareOptimization(nn.Module):
     def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
         super().__init__()
         self.lambda_inc = lambda_inc  # weight amplification coefficient
@@ -79,25 +79,30 @@ class ContentAwareOptimization(nn.Module):
         cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
         return cosine_sim

-    def generate_weight_map(self, gradients_fake):
+    def generate_weight_map(self, gradients_fake, feature_shape):
         """
-        Generate the content-aware weight map
+        Generate the content-aware weight map (corrected spatial dimensions)
         Args:
-            gradients_fake: [B, N, D] discriminator gradients for the generated image [2,3,256,256]
+            gradients_real: [B, N, D] discriminator gradients for the real image
+            gradients_fake: [B, N, D] discriminator gradients for the generated image
+            feature_shape: tuple [H, W] spatial size of the discriminator feature map
         Returns:
-            weight_fake: [B, N] weight map for the generated image [2,3,256]
+            weight_real: [B, 1, H, W] weight map for the real image
+            weight_fake: [B, 1, H, W] weight map for the generated image
         """
-        # Compute the cosine similarity of the generated-image patches
-        cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
+        H, W = feature_shape
+        N = H * W

-        # Select the content-rich regions (the eta_ratio fraction with the lowest cosine similarity)
-        k = int(self.eta_ratio * cosine_fake.shape[1])
-
-        # Build the weight map for the generated image (same procedure)
-        _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
+        # Compute the cosine similarity (same as the original code)
+        cosine_fake = self.compute_cosine_similarity(gradients_fake)
+
+        # Build the weight map (same as the original code)
+        k = int(self.eta_ratio * N)
         weight_fake = torch.ones_like(cosine_fake)
-        for b in range(cosine_fake.shape[0]):
-            weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
+
+        # Rebuild the spatial dimensions --------------------------------------------------
+        # Reshape the weights from [B, N] to [B, 1, H, W]
+        weight_fake = weight_fake.view(-1, H, W).unsqueeze(1)  # [B,1,H,W]

         return weight_fake

@@ -488,28 +493,28 @@ class RomaUnsbModel(BaseModel):
         # [[1,576,768],[1,576,768],[1,576,768]]
         # [3,576,768]

-        ## Gradients of the generated image
-        #fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
-        #
-        ## Weight map from the gradients
-        #self.weight_fake = self.cao.generate_weight_map(fake_gradient)
-        #
-        ## CTN optical-flow map for the generated image
-        #self.f_content = self.ctn(self.weight_fake)
-        #
-        ## Warped images
-        #self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
-        #self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
-        #
-        ## Second pass through the generator
-        #self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
+        # Gradients of the generated image
+        fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]

-        #warped_fake_B0_2=self.warped_fake_B0_2
-        #warped_fake_B0=self.warped_fake_B0
-        #self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
-        #self.warped_fake_B0_resize = self.resize(warped_fake_B0)
-        #self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
-        #self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
+        # Weight map from the gradients
+        self.weight_fake = self.cao.generate_weight_map(fake_gradient)
+
+        # CTN optical-flow map for the generated image
+        self.f_content = self.ctn(self.weight_fake)
+
+        # Warped images
+        self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
+        self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
+
+        # Second pass through the generator
+        self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
+
+        # warped_fake_B0_2=self.warped_fake_B0_2
+        # warped_fake_B0=self.warped_fake_B0
+        # self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
+        # self.warped_fake_B0_resize = self.resize(warped_fake_B0)
+        # self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
+        # self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)

     def compute_D_loss(self):
         # The discriminator is still unchanged
@@ -575,9 +580,9 @@ class RomaUnsbModel(BaseModel):
             loss_global = 0.0
         self.l2_loss = 0.0

-        #if self.opt.lambda_ctn > 0.0:
-        #    wapped_fake_B = warp(self.fake_B, self.f_content)  # use updated self.f_content
-        #    self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B)  # complete the loss calculation
+        if self.opt.lambda_l2 > 0.0:
+            wapped_fake_B = warp(self.fake_B0, self.f_content)  # use updated self.f_content
+            self.l2_loss = F.mse_loss(self.warped_fake_B0_2, wapped_fake_B)  # complete the loss calculation

         self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
         return self.loss_G
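
Note on the new signature: generate_weight_map now takes (gradients_fake, feature_shape), but the re-enabled call in the model still passes only fake_gradient, so it would raise a TypeError as written. Below is a minimal sketch of how the call site could supply feature_shape, assuming the ViT tokens form a square patch grid (the [1,576,768] shape comment suggests N = 576 = 24 x 24); the token_grid_hw helper is hypothetical and not part of the repository.

import math

import torch


def token_grid_hw(tokens: torch.Tensor) -> tuple:
    # Hypothetical helper: infer the (H, W) patch grid from [B, N, D] ViT tokens,
    # assuming the grid is square (e.g. N = 576 -> 24 x 24).
    n = tokens.shape[1]
    side = math.isqrt(n)
    if side * side != n:
        raise ValueError(f"token count {n} is not a square grid")
    return side, side


# Example with the token shape quoted in the patch comments ([1, 576, 768]):
tokens = torch.randn(1, 576, 768)
H, W = token_grid_hw(tokens)
print(H, W)  # 24 24
# The reconciled call inside the model would then read, for example:
#   self.weight_fake = self.cao.generate_weight_map(fake_gradient, (H, W))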
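
A second observation on the rewritten generate_weight_map body: k = int(self.eta_ratio * N) is computed but never used, so weight_fake stays all ones before the reshape and the lambda_inc boosting from the deleted loop is effectively dropped. If that content-aware boosting is still intended, a vectorized sketch reconstructed from the removed lines (same lambda_inc / |cos| rule, returned in the new [B, 1, H, W] layout) could look like the following; it is an illustration reconstructed from the deleted code, not necessarily the author's final intent.

import torch


def content_aware_weights(cosine_fake: torch.Tensor, H: int, W: int,
                          lambda_inc: float = 2.0, eta_ratio: float = 0.4) -> torch.Tensor:
    # Boost the eta_ratio fraction of patches with the lowest cosine similarity
    # (treated as the content-rich regions), leaving the remaining weights at 1.
    B, N = cosine_fake.shape
    assert N == H * W, "cosine map does not match the feature grid"
    k = int(eta_ratio * N)
    weight_fake = torch.ones_like(cosine_fake)             # [B, N]
    _, fake_indices = torch.topk(-cosine_fake, k, dim=1)   # k least-aligned patches per sample
    boosted = lambda_inc / (1e-6 + cosine_fake.gather(1, fake_indices).abs())
    weight_fake.scatter_(1, fake_indices, boosted)
    return weight_fake.view(B, 1, H, W)                    # [B, 1, H, W]


# Example with dummy similarities on a 24 x 24 grid:
cos = torch.rand(2, 576) * 2 - 1
print(content_aware_weights(cos, 24, 24).shape)  # torch.Size([2, 1, 24, 24])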