From c6cb68e700519d14a4ca360b921666e51095025d Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Fri, 7 Mar 2025 18:43:06 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E5=9C=A8=E6=AF=8F=E4=B8=80?= =?UTF-8?q?=E6=AD=A5=E9=83=BD=E7=BB=99=E5=88=A4=E5=88=AB=E5=99=A8=E7=9C=8B?= =?UTF-8?q?=EF=BC=8C=E4=BD=86=E6=98=AF=E9=80=9F=E5=BA=A6=E5=A4=AA=E6=85=A2?= =?UTF-8?q?=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/roma_unsb_model.py | 320 +++++++++++++++++--------------------- scripts/train_sbiv.sh | 21 ++- 2 files changed, 155 insertions(+), 186 deletions(-) diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 81421b8..9de01c6 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -2,6 +2,7 @@ import numpy as np import math import timm import torch +import torchvision.models as models import torch.nn as nn import torch.nn.functional as F from torchvision.transforms import GaussianBlur @@ -60,87 +61,69 @@ def compute_ctn_loss(G, x, F_content): #公式10 loss = F.mse_loss(warped_fake, y_fake_warped) return loss + class ContentAwareOptimization(nn.Module): def __init__(self, lambda_inc=2.0, eta_ratio=0.4): super().__init__() - self.lambda_inc = lambda_inc # 权重增强系数 - self.eta_ratio = eta_ratio # 选择内容区域的比例 - - # 改为类成员变量,确保钩子函数可访问 + self.lambda_inc = lambda_inc + self.eta_ratio = eta_ratio self.gradients_real = [] self.gradients_fake = [] - + + def compute_cosine_similarity(self, gradients): - """ - 计算每个patch梯度与平均梯度的余弦相似度 - Args: - gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h) - Returns: - cosine_sim: [B, N] 每个patch的余弦相似度 - """ - mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D] - # 计算余弦相似度 - cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N] - return cosine_sim + mean_grad = torch.mean(gradients, dim=1, keepdim=True) + return F.cosine_similarity(gradients, mean_grad, dim=2) def generate_weight_map(self, gradients_real, gradients_fake): - """ - 生成内容感知权重图 - Args: - gradients_real: [B, N, D] 真实图像判别器梯度 - gradients_fake: [B, N, D] 生成图像判别器梯度 - Returns: - weight_real: [B, N] 真实图像权重图 - weight_fake: [B, N] 生成图像权重图 - """ - # 计算真实图像块的余弦相似度 - cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5 - # 计算生成图像块的余弦相似度 - cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N] + # 计算余弦相似度 + cosine_real = self.compute_cosine_similarity(gradients_real) + cosine_fake = self.compute_cosine_similarity(gradients_fake) - # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例) - k = int(self.eta_ratio * cosine_real.shape[1]) - - # 对真实图像生成权重图 - _, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域 - weight_real = torch.ones_like(cosine_real) - for b in range(cosine_real.shape[0]): - weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6 - - # 对生成图像生成权重图(同理) - _, fake_indices = torch.topk(-cosine_fake, k, dim=1) - weight_fake = torch.ones_like(cosine_fake) - for b in range(cosine_fake.shape[0]): - weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) + # 生成权重图(优化实现) + def _get_weights(cosine): + k = int(self.eta_ratio * cosine.shape[1]) + _, indices = torch.topk(-cosine, k, dim=1) + weights = torch.ones_like(cosine) + weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices)))) + return weights + weight_real = _get_weights(cosine_real) + weight_fake = _get_weights(cosine_fake) return weight_real, weight_fake def forward(self, D_real, D_fake, real_scores, fake_scores): # 清空梯度缓存 self.gradients_real.clear() - self.gradients_fake.clear() - # 注册钩子 - hook_real = lambda grad: self.gradients_real.append(grad.detach()) - hook_fake = lambda grad: self.gradients_fake.append(grad.detach()) - + self.gradients_fake.clear() + self.criterionGAN=networks.GANLoss('lsgan').cuda() + # 注册钩子捕获梯度 + hook_real = lambda grad: self.gradients_real.append(grad.detach()) + hook_fake = lambda grad: self.gradients_fake.append(grad.detach()) D_real.register_hook(hook_real) D_fake.register_hook(hook_fake) - # 触发梯度计算 + # 触发梯度计算(保留计算图) (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True) - + # 获取梯度并调整维度 - grad_real = self.gradients_real[0] # [B, N, D] - grad_fake = self.gradients_fake[0] + grad_real = self.gradients_real[0].flatten(1) # [B, N, D] → [B, N*D] + grad_fake = self.gradients_fake[0].flatten(1) # 生成权重图 - weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake) + weight_real, weight_fake = self.generate_weight_map( + grad_real.view(*D_real.shape), + grad_fake.view(*D_fake.shape) + ) - # 计算加权损失 - loss_co_real = (weight_real * real_scores).mean() - loss_co_fake = (weight_fake * fake_scores).mean() + # 正确应用权重到对数概率(论文公式7) + loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True)) + loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False)) - return (loss_co_real + loss_co_fake), weight_real, weight_fake + # 总损失(注意符号:判别器需最大化该损失) + loss_co_adv = (loss_co_real + loss_co_fake)*0.5 + + return loss_co_adv, weight_real, weight_fake class ContentAwareTemporalNorm(nn.Module): def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): @@ -207,7 +190,7 @@ class ContentAwareTemporalNorm(nn.Module): # 限制光流幅值,避免极端位移 F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围 - return F_content + return F_content class RomaUnsbModel(BaseModel): @staticmethod @@ -219,22 +202,20 @@ class RomaUnsbModel(BaseModel): parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') + parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency') parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization') - + parser.add_argument('--local_nums', type=int, default=64, help='number of local patches') + parser.add_argument('--side_length', type=int, default=7) parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') parser.add_argument('--nce_includes_all_negatives_from_minibatch', type=util.str2bool, nargs='?', const=True, default=False, help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.') parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers') - + parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') - - parser.add_argument('--flip_equivariance', - type=util.str2bool, nargs='?', const=True, default=False, - help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT") - + parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions') - + parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix') parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers') parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter') @@ -251,10 +232,11 @@ class RomaUnsbModel(BaseModel): BaseModel.__init__(self, opt) # 指定需要打印的训练损失 - self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',] - self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0'] + self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn'] + self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1'] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] + if self.opt.phase == 'test': self.visual_names = ['real'] for NFE in range(self.opt.num_timesteps): @@ -262,12 +244,9 @@ class RomaUnsbModel(BaseModel): self.visual_names.append(fake_name) self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')] - if opt.nce_idt and self.isTrain: - self.loss_names += ['NCE_Y'] - self.visual_names += ['idt_B'] if self.isTrain: - self.model_names = ['G', 'D_ViT', 'E'] + self.model_names = ['G', 'D_ViT'] else: self.model_names = ['G'] @@ -277,7 +256,6 @@ class RomaUnsbModel(BaseModel): if self.isTrain: - self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) self.resize = tfs.Resize(size=(384,384), antialias=True) @@ -289,11 +267,9 @@ class RomaUnsbModel(BaseModel): # 定义损失函数 self.criterionL1 = torch.nn.L1Loss().to(self.device) self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) - self.criterionIdt = torch.nn.L1Loss().to(self.device) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) - self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) - self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E] + self.optimizers = [self.optimizer_G, self.optimizer_D] self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数 self.ctn = ContentAwareTemporalNorm() #生成的伪光流 @@ -312,7 +288,6 @@ class RomaUnsbModel(BaseModel): self.forward() self.netG.train() - self.netE.train() self.netD_ViT.train() # update D @@ -322,19 +297,9 @@ class RomaUnsbModel(BaseModel): self.loss_D.backward() self.optimizer_D.step() - # update E - self.set_requires_grad(self.netE, True) - self.optimizer_E.zero_grad() - self.loss_E = self.compute_E_loss() - self.loss_E.backward() - self.optimizer_E.step() - # update G - self.set_requires_grad(self.netD_ViT, False) - self.set_requires_grad(self.netE, False) - + self.set_requires_grad(self.netD_ViT, False) self.optimizer_G.zero_grad() - self.loss_G = self.compute_G_loss() self.loss_G.backward() self.optimizer_G.step() @@ -370,7 +335,9 @@ class RomaUnsbModel(BaseModel): bs = self.real_A0.size(0) time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long() self.time_idx = time_idx - + self.fake_B0_list = [] + self.fake_B1_list = [] + with torch.no_grad(): self.netG.eval() # ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============ @@ -387,36 +354,23 @@ class RomaUnsbModel(BaseModel): (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device) time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long() z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device) - self.time = times[time_idx] - Xt_1 = self.netG(Xt, self.time, z) + time = times[time_idx] + Xt_1 = self.netG(Xt.detach(), time, z) Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \ (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device) time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long() z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device) - Xt_12 = self.netG(Xt2, self.time, z) - - # 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接 - self.real_A_noisy = Xt.detach() - self.real_A_noisy2 = Xt2.detach() - - # ============ 第三步:拼接输入并执行网络推理 ============= - bs = self.real_A0.size(0) - self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device) - self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device) - # 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB - self.real = self.real_A0 - self.realt = self.real_A_noisy - - if self.opt.flip_equivariance: - self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5) - if self.flipped_for_equivariance: - self.real = torch.flip(self.real, [3]) - self.realt = torch.flip(self.realt, [3]) + Xt_12 = self.netG(Xt2.detach(), time, z) + self.fake_B0_list.append(Xt_1) + self.fake_B1_list.append(Xt_12) - self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in) - self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2) - + self.fake_B0_1 = self.fake_B0_list[0] + self.fake_B1_1 = self.fake_B0_list[0] + self.fake_B0 = self.fake_B0_list[-1] + self.fake_B1 = self.fake_B1_list[-1] + self.z_in = z + self.z_in2 = z if self.opt.phase == 'train': real_A0 = self.real_A0 real_A1 = self.real_A1 @@ -424,6 +378,16 @@ class RomaUnsbModel(BaseModel): real_B1 = self.real_B1 fake_B0 = self.fake_B0 fake_B1 = self.fake_B1 + self.mutil_fake_B0_tokens_list = [] + self.mutil_fake_B1_tokens_list = [] + for fake_B0_t in self.fake_B0_list: + fake_B0_t_resize = self.resize(fake_B0_t) # 调整到 ViT 输入尺寸 + tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True) + self.mutil_fake_B0_tokens_list.append(tokens) + for fake_B1_t in self.fake_B1_list: + fake_B1_t_resize = self.resize(fake_B1_t) + tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True) + self.mutil_fake_B1_tokens_list.append(tokens) self.real_A0_resize = self.resize(real_A0) self.real_A1_resize = self.resize(real_A1) @@ -431,101 +395,110 @@ class RomaUnsbModel(BaseModel): real_B1 = self.resize(real_B1) self.fake_B0_resize = self.resize(fake_B0) self.fake_B1_resize = self.resize(fake_B1) - + self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True) self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) - self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) - self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) # [[1,576,768],[1,576,768],[1,576,768]] # [3,576,768] - - def compute_D_loss(self): #判别器还是没有改 - """Calculate GAN loss for the discriminator""" - + def compute_D_loss(self): + """Calculate GAN loss with Content-Aware Optimization""" lambda_D_ViT = self.opt.lambda_D_ViT - fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach() - - + + loss_cao = 0.0 real_B0_tokens = self.mutil_real_B0_tokens[0] + pred_real0, real_features0 = self.netD_ViT(real_B0_tokens) # scores, features + real_B1_tokens = self.mutil_real_B1_tokens[0] + pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features + + for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list): + pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach()) + pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach()) + loss_cao0, self.weight_real0, self.weight_fake0 = self.cao( + D_real=real_features0, + D_fake=fake_features0, + real_scores=pred_real0, + fake_scores=pre_fake0 + ) + loss_cao1, self.weight_real1, self.weight_fake1 = self.cao( + D_real=real_features1, + D_fake=fake_features1, + real_scores=pred_real1, + fake_scores=pre_fake1 + ) + loss_cao += loss_cao0 + loss_cao1 - pre_fake0_ViT = self.netD_ViT(fake_B0_tokens) - self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False) - pred_real0_ViT = self.netD_ViT(real_B0_tokens) - self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True) - - self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT) - self.loss_D_ViT = self.losscao* lambda_D_ViT + # ===== 综合损失 ===== + total_steps = len(self.fake_B0_list) + self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT/ total_steps + + + # 记录损失值供可视化 + # self.loss_D_real = loss_D_real.item() + # self.loss_D_fake = loss_D_fake.item() + # self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5 + return self.loss_D_ViT - - def compute_E_loss(self): - """计算判别器 E 的损失""" - - XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1) - XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1) - temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean() - self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2 - - return self.loss_E def compute_G_loss(self): """计算生成器的 GAN 损失""" if self.opt.lambda_ctn > 0.0: # 生成图像的CTN光流图 - self.f_content = self.ctn(self.weight_fake) + self.f_content0 = self.ctn(self.weight_fake0) + self.f_content1 = self.ctn(self.weight_fake1) # 变换后的图片 - self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) - self.warped_fake_B0 = warp(self.fake_B0,self.f_content) + self.warped_real_A0 = warp(self.real_A0, self.f_content0) + self.warped_real_A1 = warp(self.real_A1, self.f_content1) + self.warped_fake_B0 = warp(self.fake_B0,self.f_content0) + self.warped_fake_B1 = warp(self.fake_B1,self.f_content1) # 经过第二次生成器 - self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in) + self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in) + self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2) warped_fake_B0_2=self.warped_fake_B0_2 + warped_fake_B1_2=self.warped_fake_B1_2 warped_fake_B0=self.warped_fake_B0 + warped_fake_B1=self.warped_fake_B1 # 计算L2损失 - self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0) + self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0) + self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1) + self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1)*0.5 if self.opt.lambda_GAN > 0.0: - pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0]) - self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() + + pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0]) + pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0]) + self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean() + self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean() + self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5 else: self.loss_G_GAN = 0.0 - self.loss_SB = 0 - if self.opt.lambda_SB > 0.0: - XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1) - XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1) - - bs = self.opt.batch_size - # eq.9 - ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean() - self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY - self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2) - - if self.opt.lambda_global > 0.0: - self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1) - self.loss_global *= 0.5 + if self.opt.lambda_global or self.opt.lambda_spatial > 0.0: + self.loss_global, self.loss_spatial = self.calculate_attention_loss() else: - self.loss_global = 0.0 + self.loss_global, self.loss_spatial = 0.0, 0.0 self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \ - self.opt.lambda_SB * self.loss_SB + \ self.opt.lambda_ctn * self.loss_ctn + \ - self.loss_global * self.opt.lambda_global + self.loss_global * self.opt.lambda_global+\ + self.loss_spatial * self.opt.lambda_spatial + return self.loss_G - + def calculate_attention_loss(self): n_layers = len(self.atten_layers) mutil_real_A0_tokens = self.mutil_real_A0_tokens mutil_real_A1_tokens = self.mutil_real_A1_tokens - mutil_fake_B0_tokens = self.mutil_fake_B0_tokens - mutil_fake_B1_tokens = self.mutil_fake_B1_tokens + mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1] + mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1] if self.opt.lambda_global > 0.0: @@ -542,20 +515,19 @@ class RomaUnsbModel(BaseModel): local_id = np.random.permutation(tokens_cnt) local_id = local_id[:int(min(local_nums, tokens_cnt))] - mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) - mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) + mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length) + mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length) - mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) - mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length) + mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length) + mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length) loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens) loss_spatial *= 0.5 else: loss_spatial = 0.0 + return loss_global , loss_spatial - return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial - def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens): loss = 0.0 n_layers = len(self.atten_layers) @@ -569,5 +541,3 @@ class RomaUnsbModel(BaseModel): loss = loss / n_layers return loss - - \ No newline at end of file diff --git a/scripts/train_sbiv.sh b/scripts/train_sbiv.sh index 8aa431a..5124c55 100755 --- a/scripts/train_sbiv.sh +++ b/scripts/train_sbiv.sh @@ -7,20 +7,18 @@ python train.py \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ - --name SBIV_4 \ + --name SBIV_1 \ --dataset_mode unaligned_double \ - --display_env SBIV \ + --display_env SBIV2 \ --model roma_unsb \ - --lambda_SB 1.0 \ --lambda_ctn 10 \ --lambda_inc 1.0 \ - --lambda_global 6.0 \ + --lambda_global 8.0 \ + --lambda_spatial 8.0 \ --gamma_stride 20 \ - --lr 0.000002 \ - --gpu_id 2 \ - --nce_idt False \ - --netF mlp_sample \ - --eta_ratio 0.4 \ + --lr 0.000001 \ + --gpu_id 0 \ + --eta_ratio 0.3 \ --tau 0.01 \ --num_timesteps 3 \ --input_nc 3 \ @@ -28,6 +26,7 @@ python train.py \ --n_epochs_decay 200 \ # exp6 num_timesteps=4 ,gpu_id 0(基于 exp5 ,exp1 已停) (已停) -# exp7 num_timesteps=3 ,gpu_id 0 基于 exp6 (将停) +# exp7 num_timesteps=3 ,gpu_id 0 基于 exp6 (已停) # # exp8 num_timesteps=4 ,gpu_id 1 ,修改了训练判别器的loss,以及ctnloss(基于,exp6) -# # exp9 num_timesteps=3 ,gpu_id 2 ,(基于 exp8) \ No newline at end of file +# # exp9 num_timesteps=3 ,gpu_id 2 ,(基于 exp8) +# # # exp10 num_timesteps=4 ,gpu_id 0 , --name SBIV_1 ,让判别器看到了每一个时间步的输出,修改了训练判别器的loss,以及ctnloss(基于,exp9) \ No newline at end of file