尝试在每一步都给判别器看,但是速度太慢了

This commit is contained in:
bishe 2025-03-07 18:43:06 +08:00
parent 76fcec26e8
commit c6cb68e700
2 changed files with 155 additions and 186 deletions

View File

@ -2,6 +2,7 @@ import numpy as np
import math import math
import timm import timm
import torch import torch
import torchvision.models as models
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision.transforms import GaussianBlur from torchvision.transforms import GaussianBlur
@ -60,87 +61,69 @@ def compute_ctn_loss(G, x, F_content): #公式10
loss = F.mse_loss(warped_fake, y_fake_warped) loss = F.mse_loss(warped_fake, y_fake_warped)
return loss return loss
class ContentAwareOptimization(nn.Module): class ContentAwareOptimization(nn.Module):
def __init__(self, lambda_inc=2.0, eta_ratio=0.4): def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
super().__init__() super().__init__()
self.lambda_inc = lambda_inc # 权重增强系数 self.lambda_inc = lambda_inc
self.eta_ratio = eta_ratio # 选择内容区域的比例 self.eta_ratio = eta_ratio
# 改为类成员变量,确保钩子函数可访问
self.gradients_real = [] self.gradients_real = []
self.gradients_fake = [] self.gradients_fake = []
def compute_cosine_similarity(self, gradients): def compute_cosine_similarity(self, gradients):
""" mean_grad = torch.mean(gradients, dim=1, keepdim=True)
计算每个patch梯度与平均梯度的余弦相似度 return F.cosine_similarity(gradients, mean_grad, dim=2)
Args:
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
Returns:
cosine_sim: [B, N] 每个patch的余弦相似度
"""
mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
# 计算余弦相似度
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
def generate_weight_map(self, gradients_real, gradients_fake): def generate_weight_map(self, gradients_real, gradients_fake):
""" # 计算余弦相似度
生成内容感知权重图 cosine_real = self.compute_cosine_similarity(gradients_real)
Args: cosine_fake = self.compute_cosine_similarity(gradients_fake)
gradients_real: [B, N, D] 真实图像判别器梯度
gradients_fake: [B, N, D] 生成图像判别器梯度
Returns:
weight_real: [B, N] 真实图像权重图
weight_fake: [B, N] 生成图像权重图
"""
# 计算真实图像块的余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5
# 计算生成图像块的余弦相似度
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
# 选择内容丰富的区域余弦相似度最低的eta_ratio比例 # 生成权重图(优化实现)
k = int(self.eta_ratio * cosine_real.shape[1]) def _get_weights(cosine):
k = int(self.eta_ratio * cosine.shape[1])
# 对真实图像生成权重图 _, indices = torch.topk(-cosine, k, dim=1)
_, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域 weights = torch.ones_like(cosine)
weight_real = torch.ones_like(cosine_real) weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
for b in range(cosine_real.shape[0]): return weights
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6
# 对生成图像生成权重图(同理)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
weight_real = _get_weights(cosine_real)
weight_fake = _get_weights(cosine_fake)
return weight_real, weight_fake return weight_real, weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores): def forward(self, D_real, D_fake, real_scores, fake_scores):
# 清空梯度缓存 # 清空梯度缓存
self.gradients_real.clear() self.gradients_real.clear()
self.gradients_fake.clear() self.gradients_fake.clear()
# 注册钩子 self.criterionGAN=networks.GANLoss('lsgan').cuda()
# 注册钩子捕获梯度
hook_real = lambda grad: self.gradients_real.append(grad.detach()) hook_real = lambda grad: self.gradients_real.append(grad.detach())
hook_fake = lambda grad: self.gradients_fake.append(grad.detach()) hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
D_real.register_hook(hook_real) D_real.register_hook(hook_real)
D_fake.register_hook(hook_fake) D_fake.register_hook(hook_fake)
# 触发梯度计算 # 触发梯度计算(保留计算图)
(real_scores.mean() + fake_scores.mean()).backward(retain_graph=True) (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
# 获取梯度并调整维度 # 获取梯度并调整维度
grad_real = self.gradients_real[0] # [B, N, D] grad_real = self.gradients_real[0].flatten(1) # [B, N, D] → [B, N*D]
grad_fake = self.gradients_fake[0] grad_fake = self.gradients_fake[0].flatten(1)
# 生成权重图 # 生成权重图
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake) weight_real, weight_fake = self.generate_weight_map(
grad_real.view(*D_real.shape),
grad_fake.view(*D_fake.shape)
)
# 计算加权损失 # 正确应用权重到对数概率论文公式7
loss_co_real = (weight_real * real_scores).mean() loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True))
loss_co_fake = (weight_fake * fake_scores).mean() loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False))
return (loss_co_real + loss_co_fake), weight_real, weight_fake # 总损失(注意符号:判别器需最大化该损失)
loss_co_adv = (loss_co_real + loss_co_fake)*0.5
return loss_co_adv, weight_real, weight_fake
class ContentAwareTemporalNorm(nn.Module): class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
@ -219,8 +202,10 @@ class RomaUnsbModel(BaseModel):
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm') parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization') parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
parser.add_argument('--side_length', type=int, default=7)
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
parser.add_argument('--nce_includes_all_negatives_from_minibatch', parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False, type=util.str2bool, nargs='?', const=True, default=False,
@ -229,12 +214,8 @@ class RomaUnsbModel(BaseModel):
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions') parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers') parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter') parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
@ -251,10 +232,11 @@ class RomaUnsbModel(BaseModel):
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# 指定需要打印的训练损失 # 指定需要打印的训练损失
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',] self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0'] self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.opt.phase == 'test': if self.opt.phase == 'test':
self.visual_names = ['real'] self.visual_names = ['real']
for NFE in range(self.opt.num_timesteps): for NFE in range(self.opt.num_timesteps):
@ -262,12 +244,9 @@ class RomaUnsbModel(BaseModel):
self.visual_names.append(fake_name) self.visual_names.append(fake_name)
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')] self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
if opt.nce_idt and self.isTrain:
self.loss_names += ['NCE_Y']
self.visual_names += ['idt_B']
if self.isTrain: if self.isTrain:
self.model_names = ['G', 'D_ViT', 'E'] self.model_names = ['G', 'D_ViT']
else: else:
self.model_names = ['G'] self.model_names = ['G']
@ -277,7 +256,6 @@ class RomaUnsbModel(BaseModel):
if self.isTrain: if self.isTrain:
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
self.resize = tfs.Resize(size=(384,384), antialias=True) self.resize = tfs.Resize(size=(384,384), antialias=True)
@ -289,11 +267,9 @@ class RomaUnsbModel(BaseModel):
# 定义损失函数 # 定义损失函数
self.criterionL1 = torch.nn.L1Loss().to(self.device) self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizers = [self.optimizer_G, self.optimizer_D]
self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数 self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
self.ctn = ContentAwareTemporalNorm() #生成的伪光流 self.ctn = ContentAwareTemporalNorm() #生成的伪光流
@ -312,7 +288,6 @@ class RomaUnsbModel(BaseModel):
self.forward() self.forward()
self.netG.train() self.netG.train()
self.netE.train()
self.netD_ViT.train() self.netD_ViT.train()
# update D # update D
@ -322,19 +297,9 @@ class RomaUnsbModel(BaseModel):
self.loss_D.backward() self.loss_D.backward()
self.optimizer_D.step() self.optimizer_D.step()
# update E
self.set_requires_grad(self.netE, True)
self.optimizer_E.zero_grad()
self.loss_E = self.compute_E_loss()
self.loss_E.backward()
self.optimizer_E.step()
# update G # update G
self.set_requires_grad(self.netD_ViT, False) self.set_requires_grad(self.netD_ViT, False)
self.set_requires_grad(self.netE, False)
self.optimizer_G.zero_grad() self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss() self.loss_G = self.compute_G_loss()
self.loss_G.backward() self.loss_G.backward()
self.optimizer_G.step() self.optimizer_G.step()
@ -370,6 +335,8 @@ class RomaUnsbModel(BaseModel):
bs = self.real_A0.size(0) bs = self.real_A0.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long() time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx self.time_idx = time_idx
self.fake_B0_list = []
self.fake_B1_list = []
with torch.no_grad(): with torch.no_grad():
self.netG.eval() self.netG.eval()
@ -387,36 +354,23 @@ class RomaUnsbModel(BaseModel):
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device) (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long() time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device) z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
self.time = times[time_idx] time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z) Xt_1 = self.netG(Xt.detach(), time, z)
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \ Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device) (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long() time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device) z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
Xt_12 = self.netG(Xt2, self.time, z) Xt_12 = self.netG(Xt2.detach(), time, z)
self.fake_B0_list.append(Xt_1)
# 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接 self.fake_B1_list.append(Xt_12)
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# ============ 第三步:拼接输入并执行网络推理 =============
bs = self.real_A0.size(0)
self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
self.real = self.real_A0
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in)
self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2)
self.fake_B0_1 = self.fake_B0_list[0]
self.fake_B1_1 = self.fake_B0_list[0]
self.fake_B0 = self.fake_B0_list[-1]
self.fake_B1 = self.fake_B1_list[-1]
self.z_in = z
self.z_in2 = z
if self.opt.phase == 'train': if self.opt.phase == 'train':
real_A0 = self.real_A0 real_A0 = self.real_A0
real_A1 = self.real_A1 real_A1 = self.real_A1
@ -424,6 +378,16 @@ class RomaUnsbModel(BaseModel):
real_B1 = self.real_B1 real_B1 = self.real_B1
fake_B0 = self.fake_B0 fake_B0 = self.fake_B0
fake_B1 = self.fake_B1 fake_B1 = self.fake_B1
self.mutil_fake_B0_tokens_list = []
self.mutil_fake_B1_tokens_list = []
for fake_B0_t in self.fake_B0_list:
fake_B0_t_resize = self.resize(fake_B0_t) # 调整到 ViT 输入尺寸
tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens_list.append(tokens)
for fake_B1_t in self.fake_B1_list:
fake_B1_t_resize = self.resize(fake_B1_t)
tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens_list.append(tokens)
self.real_A0_resize = self.resize(real_A0) self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1) self.real_A1_resize = self.resize(real_A1)
@ -436,96 +400,105 @@ class RomaUnsbModel(BaseModel):
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True) self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
# [[1,576,768],[1,576,768],[1,576,768]] # [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768] # [3,576,768]
def compute_D_loss(self):
def compute_D_loss(self): #判别器还是没有改 """Calculate GAN loss with Content-Aware Optimization"""
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
loss_cao = 0.0
real_B0_tokens = self.mutil_real_B0_tokens[0] real_B0_tokens = self.mutil_real_B0_tokens[0]
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens) # scores, features
real_B1_tokens = self.mutil_real_B1_tokens[0]
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens) for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list):
self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False) pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach())
pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach())
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
D_real=real_features0,
D_fake=fake_features0,
real_scores=pred_real0,
fake_scores=pre_fake0
)
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
D_real=real_features1,
D_fake=fake_features1,
real_scores=pred_real1,
fake_scores=pre_fake1
)
loss_cao += loss_cao0 + loss_cao1
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT) # ===== 综合损失 =====
self.loss_D_ViT = self.losscao* lambda_D_ViT total_steps = len(self.fake_B0_list)
self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT/ total_steps
# 记录损失值供可视化
# self.loss_D_real = loss_D_real.item()
# self.loss_D_fake = loss_D_fake.item()
# self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5
return self.loss_D_ViT return self.loss_D_ViT
def compute_E_loss(self):
"""计算判别器 E 的损失"""
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
return self.loss_E
def compute_G_loss(self): def compute_G_loss(self):
"""计算生成器的 GAN 损失""" """计算生成器的 GAN 损失"""
if self.opt.lambda_ctn > 0.0: if self.opt.lambda_ctn > 0.0:
# 生成图像的CTN光流图 # 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake) self.f_content0 = self.ctn(self.weight_fake0)
self.f_content1 = self.ctn(self.weight_fake1)
# 变换后的图片 # 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) self.warped_real_A0 = warp(self.real_A0, self.f_content0)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content) self.warped_real_A1 = warp(self.real_A1, self.f_content1)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content0)
self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
# 经过第二次生成器 # 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in) self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in)
self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2)
warped_fake_B0_2=self.warped_fake_B0_2 warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B1_2=self.warped_fake_B1_2
warped_fake_B0=self.warped_fake_B0 warped_fake_B0=self.warped_fake_B0
warped_fake_B1=self.warped_fake_B1
# 计算L2损失 # 计算L2损失
self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0) self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1)*0.5
if self.opt.lambda_GAN > 0.0: if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0])
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0])
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
else: else:
self.loss_G_GAN = 0.0 self.loss_G_GAN = 0.0
self.loss_SB = 0
if self.opt.lambda_SB > 0.0:
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
bs = self.opt.batch_size if self.opt.lambda_global or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
# eq.9
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean()
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
if self.opt.lambda_global > 0.0:
self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
self.loss_global *= 0.5
else: else:
self.loss_global = 0.0 self.loss_global, self.loss_spatial = 0.0, 0.0
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \ self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
self.opt.lambda_SB * self.loss_SB + \
self.opt.lambda_ctn * self.loss_ctn + \ self.opt.lambda_ctn * self.loss_ctn + \
self.loss_global * self.opt.lambda_global self.loss_global * self.opt.lambda_global+\
self.loss_spatial * self.opt.lambda_spatial
return self.loss_G return self.loss_G
def calculate_attention_loss(self): def calculate_attention_loss(self):
n_layers = len(self.atten_layers) n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
if self.opt.lambda_global > 0.0: if self.opt.lambda_global > 0.0:
@ -542,19 +515,18 @@ class RomaUnsbModel(BaseModel):
local_id = np.random.permutation(tokens_cnt) local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))] local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens) loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5 loss_spatial *= 0.5
else: else:
loss_spatial = 0.0 loss_spatial = 0.0
return loss_global , loss_spatial
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens): def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0 loss = 0.0
@ -569,5 +541,3 @@ class RomaUnsbModel(BaseModel):
loss = loss / n_layers loss = loss / n_layers
return loss return loss

View File

@ -7,20 +7,18 @@
python train.py \ python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name SBIV_4 \ --name SBIV_1 \
--dataset_mode unaligned_double \ --dataset_mode unaligned_double \
--display_env SBIV \ --display_env SBIV2 \
--model roma_unsb \ --model roma_unsb \
--lambda_SB 1.0 \
--lambda_ctn 10 \ --lambda_ctn 10 \
--lambda_inc 1.0 \ --lambda_inc 1.0 \
--lambda_global 6.0 \ --lambda_global 8.0 \
--lambda_spatial 8.0 \
--gamma_stride 20 \ --gamma_stride 20 \
--lr 0.000002 \ --lr 0.000001 \
--gpu_id 2 \ --gpu_id 0 \
--nce_idt False \ --eta_ratio 0.3 \
--netF mlp_sample \
--eta_ratio 0.4 \
--tau 0.01 \ --tau 0.01 \
--num_timesteps 3 \ --num_timesteps 3 \
--input_nc 3 \ --input_nc 3 \
@ -28,6 +26,7 @@ python train.py \
--n_epochs_decay 200 \ --n_epochs_decay 200 \
# exp6 num_timesteps=4 gpu_id 0(基于 exp5,exp1 已停)(已停) # exp6 num_timesteps=4 gpu_id 0(基于 exp5,exp1 已停)(已停)
# exp7 num_timesteps=3 gpu_id 0(基于 exp6)(停) # exp7 num_timesteps=3 gpu_id 0(基于 exp6)(停)
# # exp8 num_timesteps=4 gpu_id 1,修改了训练判别器的loss以及ctnloss(基于exp6) # # exp8 num_timesteps=4 gpu_id 1,修改了训练判别器的loss以及ctnloss(基于exp6)
# # exp9 num_timesteps=3 gpu_id 2(基于 exp8) # # exp9 num_timesteps=3 gpu_id 2(基于 exp8)
# # # exp10 num_timesteps=4 gpu_id 0,--name SBIV_1,让判别器看到了每一个时间步的输出,修改了训练判别器的loss以及ctnloss(基于exp9)