This commit is contained in:
areszz 2025-02-23 14:37:14 +08:00
parent 8cd61d0503
commit 09d363ced6
2 changed files with 42 additions and 99 deletions

View File

@ -68,9 +68,6 @@ class ROMAModel(BaseModel):
# From UNSB # From UNSB
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
# Deine another generator
self.netG_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
self.norm = F.softmax self.norm = F.softmax
@ -186,8 +183,6 @@ class ROMAModel(BaseModel):
# 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接 # 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接
self.real_A_noisy = Xt.detach() self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach() self.real_A_noisy2 = Xt2.detach()
# 保存noisy_map
self.noisy_map = self.real_A_noisy - self.real_A
# ============ 第三步:拼接输入并执行网络推理 ============= # ============ 第三步:拼接输入并执行网络推理 =============
bs = self.real_A0.size(0) bs = self.real_A0.size(0)
@ -206,7 +201,23 @@ class ROMAModel(BaseModel):
self.fake_B0 = self.netG(self.real_A0) self.fake_B0 = self.netG(self.real_A0)
self.fake_B1 = self.netG(self.real_A1) self.fake_B1 = self.netG(self.real_A1)
if self.opt.phase == 'train':
# 生成图像的梯度
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
# 梯度图
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
if self.opt.isTrain: if self.opt.isTrain:
real_A0 = self.real_A0 real_A0 = self.real_A0
real_A1 = self.real_A1 real_A1 = self.real_A1
@ -214,98 +225,41 @@ class ROMAModel(BaseModel):
real_B1 = self.real_B1 real_B1 = self.real_B1
fake_B0 = self.fake_B0 fake_B0 = self.fake_B0
fake_B1 = self.fake_B1 fake_B1 = self.fake_B1
warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B0=self.warped_fake_B0
self.real_A0_resize = self.resize(real_A0) self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1) self.real_A1_resize = self.resize(real_A1)
real_B0 = self.resize(real_B0) real_B0 = self.resize(real_B0)
real_B1 = self.resize(real_B1) real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0) self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1) self.fake_B1_resize = self.resize(fake_B1)
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True) self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
if self.opt.phase == 'train':
# 真实图像的梯度
real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0]
# 生成图像的梯度
fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
# 梯度图
self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
# 把前面生成后的图片再加上noisy_map
self.fake_B0_2 = self.fake_B0 + self.noisy_map
# 变换后的图片
wapped_fake_B0_2 = warp(self.fake_B0_2, self.f_content)
# 经过第二次生成器
self.fake_B0_2 = self.netG_2(wapped_fake_B0_2, self.time, z_in)
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def compute_D_loss(self): def compute_D_loss(self): #判别器还是没有改
"""Calculate GAN loss for the discriminator""" """Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach() fake_B0_tokens = self.mutil_fake_B0_tokens.detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach() fake_B1_tokens = self.mutil_fake_B1_tokens.detach()
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer] real_B0_tokens = self.mutil_real_B0_tokens
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer] real_B1_tokens = self.mutil_real_B1_tokens
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens) pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens) pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
@ -336,10 +290,9 @@ class ROMAModel(BaseModel):
if self.opt.lambda_GAN > 0.0: if self.opt.lambda_GAN > 0.0:
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer] fake_B0_tokens = self.mutil_fake_B0_tokens
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer] fake_B1_tokens = self.mutil_fake_B1_tokens
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens) pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens) pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
@ -357,8 +310,8 @@ class ROMAModel(BaseModel):
# eq.9 # eq.9
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0) ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2) self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy2 - self.fake_B1) ** 2)
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0: if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
@ -368,8 +321,8 @@ class ROMAModel(BaseModel):
if self.opt.lambda_ctn > 0.0: if self.opt.lambda_ctn > 0.0:
wapped_fake_B1 = warp(self.fake_B1, self.f_content) # use updated self.f_content warped_fake_B1 = warp(self.fake_B0, self.f_content) # use updated self.f_content
self.l2_loss = F.mse_loss(self.fake_B0_2, wapped_fake_B1) * self.opt.lambda_ctn self.l2_loss = F.mse_loss(self.warped_fake_B0_2, warped_fake_B1) * self.opt.lambda_ctn
else: else:
self.l2_loss = 0.0 self.l2_loss = 0.0

View File

@ -79,37 +79,27 @@ class ContentAwareOptimization(nn.Module):
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N] cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim return cosine_sim
def generate_weight_map(self, gradients_real, gradients_fake): def generate_weight_map(self, gradients_fake):
""" """
生成内容感知权重图 生成内容感知权重图
Args: Args:
gradients_real: [B, N, D] 真实图像判别器梯度
gradients_fake: [B, N, D] 生成图像判别器梯度 gradients_fake: [B, N, D] 生成图像判别器梯度
Returns: Returns:
weight_real: [B, N] 真实图像权重图
weight_fake: [B, N] 生成图像权重图 weight_fake: [B, N] 生成图像权重图
""" """
# 计算真实图像块的余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5
# 计算生成图像块的余弦相似度 # 计算生成图像块的余弦相似度
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N] cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
# 选择内容丰富的区域余弦相似度最低的eta_ratio比例 # 选择内容丰富的区域余弦相似度最低的eta_ratio比例
k = int(self.eta_ratio * cosine_real.shape[1]) k = int(self.eta_ratio * cosine_fake.shape[1])
# 对真实图像生成权重图
_, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域
weight_real = torch.ones_like(cosine_real)
for b in range(cosine_real.shape[0]):
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6
# 对生成图像生成权重图(同理) # 对生成图像生成权重图(同理)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1) _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake) weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]): for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
return weight_real, weight_fake return weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores): def forward(self, D_real, D_fake, real_scores, fake_scores):
""" """
@ -458,7 +448,7 @@ class CTNxModel(BaseModel):
self.real_A_noisy = Xt.detach() self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach() self.real_A_noisy2 = Xt2.detach()
# 保存noisy_map # 保存noisy_map
self.noisy_map = self.real_A_noisy - self.real_A self.noisy_map = self.real_A_noisy - self.real_A0
# ============ 第三步:拼接输入并执行网络推理 ============= # ============ 第三步:拼接输入并执行网络推理 =============
bs = self.mutil_real_A0_tokens.size(0) bs = self.mutil_real_A0_tokens.size(0)