CPTrans reproduction: last version in which the problems were found

bishe 2025-03-15 15:00:00 +08:00
parent e0dc08030c
commit f98c285950
10000 changed files with 81 additions and 102 deletions


@@ -67,63 +67,46 @@ class ContentAwareOptimization(nn.Module):
         super().__init__()
         self.lambda_inc = lambda_inc
         self.eta_ratio = eta_ratio
-        self.gradients_real = []
-        self.gradients_fake = []
+        self.gradients = []  # switched to a single gradient list, which is more general
+        self.criterionGAN = networks.GANLoss('lsgan').cuda()
 
     def compute_cosine_similarity(self, gradients):
         mean_grad = torch.mean(gradients, dim=1, keepdim=True)
         return F.cosine_similarity(gradients, mean_grad, dim=2)
 
-    def generate_weight_map(self, gradients_real, gradients_fake):
-        # compute the cosine similarities
-        cosine_real = self.compute_cosine_similarity(gradients_real)
-        cosine_fake = self.compute_cosine_similarity(gradients_fake)
-        # build the weight maps (optimized implementation)
-        def _get_weights(cosine):
-            k = int(self.eta_ratio * cosine.shape[1])
-            _, indices = torch.topk(-cosine, k, dim=1)
-            weights = torch.ones_like(cosine)
-            weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
-            return weights
-        weight_real = _get_weights(cosine_real)
-        weight_fake = _get_weights(cosine_fake)
-        return weight_real, weight_fake
+    def generate_weight_map(self, gradients):
+        cosine = self.compute_cosine_similarity(gradients)
+        k = int(self.eta_ratio * cosine.shape[1])
+        _, indices = torch.topk(-cosine, k, dim=1)
+        weights = torch.ones_like(cosine)
+        weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
+        return weights
 
-    def forward(self, D_real, D_fake, real_scores, fake_scores):
-        # clear the gradient caches
-        self.gradients_real.clear()
-        self.gradients_fake.clear()
-        self.criterionGAN = networks.GANLoss('lsgan').cuda()
-        # register hooks to capture the gradients
-        hook_real = lambda grad: self.gradients_real.append(grad.detach())
-        hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
-        D_real.register_hook(hook_real)
-        D_fake.register_hook(hook_fake)
-        # trigger the gradient computation (keep the graph alive)
-        (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
-        # fetch the gradients and reshape them
-        grad_real = self.gradients_real[0].flatten(1)  # [B, N, D] -> [B, N*D]
-        grad_fake = self.gradients_fake[0].flatten(1)
-        # build the weight maps
-        weight_real, weight_fake = self.generate_weight_map(
-            grad_real.view(*D_real.shape),
-            grad_fake.view(*D_fake.shape)
-        )
-        # apply the weights to the log-probabilities (Eq. 7 of the paper)
-        loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
-        loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))
-        # total loss (mind the sign: the discriminator must maximize it)
-        loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
-        return loss_co_adv, weight_real, weight_fake
+    def forward(self, features, scores, target):
+        """
+        Args:
+            features: feature tensor; either the discriminator's real/fake features or the generator's fake features
+            scores: the discriminator's predictions for these features
+            target: target label; True means the features should be judged real, False fake
+        Returns:
+            loss: the weighted GAN loss
+            weight: the generated weight map
+        """
+        self.gradients.clear()
+        # register a hook to capture the gradient
+        hook = lambda grad: self.gradients.append(grad.detach())
+        features.register_hook(hook)
+        # trigger the gradient computation
+        scores.mean().backward(retain_graph=True)
+        # fetch the gradient and reshape it
+        grad = self.gradients[0].flatten(1)  # [B, N, D] -> [B, N*D]
+        weight = self.generate_weight_map(grad.view(*features.shape))
+        # weighted GAN loss
+        loss = torch.mean(weight * self.criterionGAN(scores, target))
+        return loss, weight
 
 class ContentAwareTemporalNorm(nn.Module):
     def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
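
For reference, a minimal, self-contained sketch of what the refactored generate_weight_map computes on a toy [B, N, D] gradient tensor: each token's gradient is compared with the batch-mean gradient by cosine similarity, and the eta_ratio fraction of tokens that deviate most get their weight raised to lambda_inc / |cos|, while every other token keeps weight 1. The shapes and hyper-parameter values below are illustrative, not the repository's defaults.

import torch
import torch.nn.functional as F

lambda_inc, eta_ratio = 2.0, 0.25      # illustrative values only
grads = torch.randn(2, 8, 16)          # toy gradients: [B=2, N=8 tokens, D=16 dims]

# cosine similarity of each token's gradient to the batch-mean gradient
mean_grad = torch.mean(grads, dim=1, keepdim=True)       # [B, 1, D]
cosine = F.cosine_similarity(grads, mean_grad, dim=2)    # [B, N]

# select the eta_ratio fraction of tokens that agree least with the mean
k = int(eta_ratio * cosine.shape[1])                     # here k = 2 tokens per sample
_, indices = torch.topk(-cosine, k, dim=1)

# raise those tokens to lambda_inc / |cos|; leave the rest at weight 1
weights = torch.ones_like(cosine)
weights.scatter_(1, indices,
                 lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
print(weights)   # exactly k entries per row differ from 1.0

These weights are what the refactored forward multiplies into the criterionGAN term, so the tokens whose gradients deviate most from the consensus dominate the adversarial signal.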
@@ -342,66 +325,59 @@ class RomaUnsbModel(BaseModel):
         """Calculate GAN loss with Content-Aware Optimization"""
         lambda_D_ViT = self.opt.lambda_D_ViT
-        loss_cao = 0.0
         # handle real_B0 and fake_B0
         real_B0_tokens = self.mutil_real_B0_tokens[0]
-        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)  # scores, features
+        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
+        fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
+        pred_fake0, fake_features0 = self.netD_ViT(fake_B0_tokens)
+        loss_real0, self.weight_real0 = self.cao(real_features0, pred_real0, True)
+        loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False)
         # handle real_B1 and fake_B1
         real_B1_tokens = self.mutil_real_B1_tokens[0]
-        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)  # scores, features
+        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
+        fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
+        pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens)
-        pre_fake0, fake_features0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
-        pre_fake1, fake_features1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
-        loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
-            D_real=real_features0,
-            D_fake=fake_features0,
-            real_scores=pred_real0,
-            fake_scores=pre_fake0
-        )
-        loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
-            D_real=real_features1,
-            D_fake=fake_features1,
-            real_scores=pred_real1,
-            fake_scores=pre_fake1
-        )
-        loss_cao += loss_cao0 + loss_cao1
-        # ===== combined loss =====
-        self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT
-        # record the loss values for visualization
-        # self.loss_D_real = loss_D_real.item()
-        # self.loss_D_fake = loss_D_fake.item()
-        # self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5
+        loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True)
+        loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False)
+        # combined loss
+        self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT
         return self.loss_D_ViT
 
     def compute_G_loss(self):
-        """Compute the generator's GAN loss"""
+        """Compute the generator losses"""
         # initialize the totals
         self.loss_G_GAN = 0.0
         self.loss_ctn = 0.0
         self.loss_global = 0.0
         self.loss_spatial = 0.0
         # CTN loss
         if self.opt.lambda_ctn > 0.0:
-            # CTN flow maps for the generated images
+            # generate the flow maps (from the discriminator's weight maps)
             self.f_content0 = self.ctn(self.weight_fake0.detach())
             self.f_content1 = self.ctn(self.weight_fake1.detach())
             # warped images
             self.warped_real_A0 = warp(self.real_A0, self.f_content0)
             self.warped_real_A1 = warp(self.real_A1, self.f_content1)
-            self.warped_fake_B0 = warp(self.fake_B0,self.f_content0)
-            self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
+            self.warped_fake_B0 = warp(self.fake_B0, self.f_content0)
+            self.warped_fake_B1 = warp(self.fake_B1, self.f_content1)
-            # pass through the generator a second time
+            # second generation pass
             self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
             self.warped_fake_B1_2 = self.netG(self.warped_real_A1)
-            warped_fake_B0_2 = self.warped_fake_B0_2
-            warped_fake_B1_2 = self.warped_fake_B1_2
-            warped_fake_B0 = self.warped_fake_B0
-            warped_fake_B1 = self.warped_fake_B1
-            # compute the L2 losses
-            self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
-            self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
-            self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1)*0.5
+            # compute the L2 losses
+            self.loss_ctn0 = F.mse_loss(self.warped_fake_B0_2, self.warped_fake_B0)
+            self.loss_ctn1 = F.mse_loss(self.warped_fake_B1_2, self.warped_fake_B1)
+            self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1) * 0.5
         # GAN loss (via ContentAwareOptimization)
         if self.opt.lambda_GAN > 0.0:
             pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
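
compute_G_loss relies on a warp(image, flow) helper that is not part of this diff. A plausible reading, assumed here purely for illustration, is dense bilinear resampling of a [B, C, H, W] image by a per-pixel displacement field [B, 2, H, W] via F.grid_sample; the repository's actual helper may differ.

import torch
import torch.nn.functional as F

def warp(img: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Warp img [B, C, H, W] by a pixel displacement field [B, 2, H, W], assumed (x, y) order."""
    B, _, H, W = img.shape
    # base sampling grid in pixel coordinates
    ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    base = torch.stack((xs, ys), dim=0).float().to(img.device)   # [2, H, W]
    pos = base.unsqueeze(0) + flow                               # displaced sampling positions
    # normalize to [-1, 1] as required by grid_sample
    gx = 2.0 * pos[:, 0] / max(W - 1, 1) - 1.0
    gy = 2.0 * pos[:, 1] / max(H - 1, 1) - 1.0
    grid = torch.stack((gx, gy), dim=3)                          # [B, H, W, 2]
    return F.grid_sample(img, grid, mode='bilinear',
                         padding_mode='border', align_corners=True)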
@@ -415,13 +391,12 @@ class RomaUnsbModel(BaseModel):
         if self.opt.lambda_global or self.opt.lambda_spatial > 0.0:
             self.loss_global, self.loss_spatial = self.calculate_attention_loss()
         else:
             self.loss_global, self.loss_spatial = 0.0, 0.0
         # total loss
         self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
-                      self.opt.lambda_ctn * self.loss_ctn + \
-                      self.loss_global * self.opt.lambda_global+\
-                      self.loss_spatial * self.opt.lambda_spatial
+                      self.opt.lambda_ctn * self.loss_ctn + \
+                      self.opt.lambda_global * self.loss_global + \
+                      self.opt.lambda_spatial * self.loss_spatial
         return self.loss_G
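
What loss_ctn enforces can be sanity-checked on a toy case: if the generator commuted with the warp, "warp then generate" and "generate then warp" would match and the loss would vanish. The stand-in generator below is a hypothetical pointwise map (the real netG is a full network); it reuses the warp sketch above with a zero flow field.

import torch
import torch.nn.functional as F   # warp() as sketched above

G = lambda x: 0.5 * x + 0.1            # hypothetical stand-in for netG
A = torch.rand(1, 3, 32, 32)
f = torch.zeros(1, 2, 32, 32)          # zero flow: warping is (numerically) the identity
# loss_ctn compares G(warp(A, f)) against warp(G(A), f)
print(F.mse_loss(G(warp(A, f)), warp(G(A), f)).item())   # ~0.0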

(Binary image files omitted: the rest of the commit adds dozens of images, each roughly 129-138 KiB, which the diff viewer does not render. Some files were not shown because too many files have changed in this diff.)