尝试在每一步都给判别器看,但是速度太慢了

This commit is contained in:
bishe 2025-03-07 18:43:06 +08:00
parent 76fcec26e8
commit c6cb68e700
2 changed files with 155 additions and 186 deletions

View File

@ -2,6 +2,7 @@ import numpy as np
import math
import timm
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
@ -60,87 +61,69 @@ def compute_ctn_loss(G, x, F_content): #公式10
loss = F.mse_loss(warped_fake, y_fake_warped)
return loss
class ContentAwareOptimization(nn.Module):
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
super().__init__()
self.lambda_inc = lambda_inc # 权重增强系数
self.eta_ratio = eta_ratio # 选择内容区域的比例
# 改为类成员变量,确保钩子函数可访问
self.lambda_inc = lambda_inc
self.eta_ratio = eta_ratio
self.gradients_real = []
self.gradients_fake = []
def compute_cosine_similarity(self, gradients):
"""
计算每个patch梯度与平均梯度的余弦相似度
Args:
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
Returns:
cosine_sim: [B, N] 每个patch的余弦相似度
"""
mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
# 计算余弦相似度
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
mean_grad = torch.mean(gradients, dim=1, keepdim=True)
return F.cosine_similarity(gradients, mean_grad, dim=2)
def generate_weight_map(self, gradients_real, gradients_fake):
"""
生成内容感知权重图
Args:
gradients_real: [B, N, D] 真实图像判别器梯度
gradients_fake: [B, N, D] 生成图像判别器梯度
Returns:
weight_real: [B, N] 真实图像权重图
weight_fake: [B, N] 生成图像权重图
"""
# 计算真实图像块的余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5
# 计算生成图像块的余弦相似度
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
# 计算余弦相似度
cosine_real = self.compute_cosine_similarity(gradients_real)
cosine_fake = self.compute_cosine_similarity(gradients_fake)
# 选择内容丰富的区域余弦相似度最低的eta_ratio比例
k = int(self.eta_ratio * cosine_real.shape[1])
# 对真实图像生成权重图
_, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域
weight_real = torch.ones_like(cosine_real)
for b in range(cosine_real.shape[0]):
weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6
# 对生成图像生成权重图(同理)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
# 生成权重图(优化实现)
def _get_weights(cosine):
k = int(self.eta_ratio * cosine.shape[1])
_, indices = torch.topk(-cosine, k, dim=1)
weights = torch.ones_like(cosine)
weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
return weights
weight_real = _get_weights(cosine_real)
weight_fake = _get_weights(cosine_fake)
return weight_real, weight_fake
def forward(self, D_real, D_fake, real_scores, fake_scores):
# 清空梯度缓存
self.gradients_real.clear()
self.gradients_fake.clear()
# 注册钩子
self.criterionGAN=networks.GANLoss('lsgan').cuda()
# 注册钩子捕获梯度
hook_real = lambda grad: self.gradients_real.append(grad.detach())
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
D_real.register_hook(hook_real)
D_fake.register_hook(hook_fake)
# 触发梯度计算
# 触发梯度计算(保留计算图)
(real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)
# 获取梯度并调整维度
grad_real = self.gradients_real[0] # [B, N, D]
grad_fake = self.gradients_fake[0]
grad_real = self.gradients_real[0].flatten(1) # [B, N, D] → [B, N*D]
grad_fake = self.gradients_fake[0].flatten(1)
# 生成权重图
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
weight_real, weight_fake = self.generate_weight_map(
grad_real.view(*D_real.shape),
grad_fake.view(*D_fake.shape)
)
# 计算加权损失
loss_co_real = (weight_real * real_scores).mean()
loss_co_fake = (weight_fake * fake_scores).mean()
# 正确应用权重到对数概率论文公式7
loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores , True))
loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores , False))
return (loss_co_real + loss_co_fake), weight_real, weight_fake
# 总损失(注意符号:判别器需最大化该损失)
loss_co_adv = (loss_co_real + loss_co_fake)*0.5
return loss_co_adv, weight_real, weight_fake
class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
@ -219,8 +202,10 @@ class RomaUnsbModel(BaseModel):
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
parser.add_argument('--side_length', type=int, default=7)
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False,
@ -229,12 +214,8 @@ class RomaUnsbModel(BaseModel):
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
@ -251,10 +232,11 @@ class RomaUnsbModel(BaseModel):
BaseModel.__init__(self, opt)
# 指定需要打印的训练损失
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',]
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.opt.phase == 'test':
self.visual_names = ['real']
for NFE in range(self.opt.num_timesteps):
@ -262,12 +244,9 @@ class RomaUnsbModel(BaseModel):
self.visual_names.append(fake_name)
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
if opt.nce_idt and self.isTrain:
self.loss_names += ['NCE_Y']
self.visual_names += ['idt_B']
if self.isTrain:
self.model_names = ['G', 'D_ViT', 'E']
self.model_names = ['G', 'D_ViT']
else:
self.model_names = ['G']
@ -277,7 +256,6 @@ class RomaUnsbModel(BaseModel):
if self.isTrain:
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
self.resize = tfs.Resize(size=(384,384), antialias=True)
@ -289,11 +267,9 @@ class RomaUnsbModel(BaseModel):
# 定义损失函数
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionIdt = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
self.optimizers = [self.optimizer_G, self.optimizer_D]
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
self.ctn = ContentAwareTemporalNorm() #生成的伪光流
@ -312,7 +288,6 @@ class RomaUnsbModel(BaseModel):
self.forward()
self.netG.train()
self.netE.train()
self.netD_ViT.train()
# update D
@ -322,19 +297,9 @@ class RomaUnsbModel(BaseModel):
self.loss_D.backward()
self.optimizer_D.step()
# update E
self.set_requires_grad(self.netE, True)
self.optimizer_E.zero_grad()
self.loss_E = self.compute_E_loss()
self.loss_E.backward()
self.optimizer_E.step()
# update G
self.set_requires_grad(self.netD_ViT, False)
self.set_requires_grad(self.netE, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
@ -370,6 +335,8 @@ class RomaUnsbModel(BaseModel):
bs = self.real_A0.size(0)
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
self.fake_B0_list = []
self.fake_B1_list = []
with torch.no_grad():
self.netG.eval()
@ -387,36 +354,23 @@ class RomaUnsbModel(BaseModel):
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
self.time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z)
time = times[time_idx]
Xt_1 = self.netG(Xt.detach(), time, z)
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
Xt_12 = self.netG(Xt2, self.time, z)
# 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# ============ 第三步:拼接输入并执行网络推理 =============
bs = self.real_A0.size(0)
self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
self.real = self.real_A0
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in)
self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2)
Xt_12 = self.netG(Xt2.detach(), time, z)
self.fake_B0_list.append(Xt_1)
self.fake_B1_list.append(Xt_12)
self.fake_B0_1 = self.fake_B0_list[0]
self.fake_B1_1 = self.fake_B0_list[0]
self.fake_B0 = self.fake_B0_list[-1]
self.fake_B1 = self.fake_B1_list[-1]
self.z_in = z
self.z_in2 = z
if self.opt.phase == 'train':
real_A0 = self.real_A0
real_A1 = self.real_A1
@ -424,6 +378,16 @@ class RomaUnsbModel(BaseModel):
real_B1 = self.real_B1
fake_B0 = self.fake_B0
fake_B1 = self.fake_B1
self.mutil_fake_B0_tokens_list = []
self.mutil_fake_B1_tokens_list = []
for fake_B0_t in self.fake_B0_list:
fake_B0_t_resize = self.resize(fake_B0_t) # 调整到 ViT 输入尺寸
tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens_list.append(tokens)
for fake_B1_t in self.fake_B1_list:
fake_B1_t_resize = self.resize(fake_B1_t)
tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens_list.append(tokens)
self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1)
@ -436,96 +400,105 @@ class RomaUnsbModel(BaseModel):
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
# [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768]
def compute_D_loss(self): #判别器还是没有改
"""Calculate GAN loss for the discriminator"""
def compute_D_loss(self):
"""Calculate GAN loss with Content-Aware Optimization"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
loss_cao = 0.0
real_B0_tokens = self.mutil_real_B0_tokens[0]
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens) # scores, features
real_B1_tokens = self.mutil_real_B1_tokens[0]
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) # scores, features
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False)
for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list):
pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach())
pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach())
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
D_real=real_features0,
D_fake=fake_features0,
real_scores=pred_real0,
fake_scores=pre_fake0
)
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
D_real=real_features1,
D_fake=fake_features1,
real_scores=pred_real1,
fake_scores=pre_fake1
)
loss_cao += loss_cao0 + loss_cao1
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
self.loss_D_ViT = self.losscao* lambda_D_ViT
# ===== 综合损失 =====
total_steps = len(self.fake_B0_list)
self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT/ total_steps
# 记录损失值供可视化
# self.loss_D_real = loss_D_real.item()
# self.loss_D_fake = loss_D_fake.item()
# self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5
return self.loss_D_ViT
def compute_E_loss(self):
"""计算判别器 E 的损失"""
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
return self.loss_E
def compute_G_loss(self):
"""计算生成器的 GAN 损失"""
if self.opt.lambda_ctn > 0.0:
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
self.f_content0 = self.ctn(self.weight_fake0)
self.f_content1 = self.ctn(self.weight_fake1)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
self.warped_real_A1 = warp(self.real_A1, self.f_content1)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content0)
self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in)
self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in)
self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2)
warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B1_2=self.warped_fake_B1_2
warped_fake_B0=self.warped_fake_B0
warped_fake_B1=self.warped_fake_B1
# 计算L2损失
self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1)*0.5
if self.opt.lambda_GAN > 0.0:
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean()
pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0])
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0])
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
else:
self.loss_G_GAN = 0.0
self.loss_SB = 0
if self.opt.lambda_SB > 0.0:
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
bs = self.opt.batch_size
# eq.9
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean()
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
if self.opt.lambda_global > 0.0:
self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
self.loss_global *= 0.5
if self.opt.lambda_global or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
else:
self.loss_global = 0.0
self.loss_global, self.loss_spatial = 0.0, 0.0
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
self.opt.lambda_SB * self.loss_SB + \
self.opt.lambda_ctn * self.loss_ctn + \
self.loss_global * self.opt.lambda_global
self.loss_global * self.opt.lambda_global+\
self.loss_spatial * self.opt.lambda_spatial
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
if self.opt.lambda_global > 0.0:
@ -542,19 +515,18 @@ class RomaUnsbModel(BaseModel):
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
return loss_global , loss_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
@ -569,5 +541,3 @@ class RomaUnsbModel(BaseModel):
loss = loss / n_layers
return loss

View File

@ -7,20 +7,18 @@
python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name SBIV_4 \
--name SBIV_1 \
--dataset_mode unaligned_double \
--display_env SBIV \
--display_env SBIV2 \
--model roma_unsb \
--lambda_SB 1.0 \
--lambda_ctn 10 \
--lambda_inc 1.0 \
--lambda_global 6.0 \
--lambda_global 8.0 \
--lambda_spatial 8.0 \
--gamma_stride 20 \
--lr 0.000002 \
--gpu_id 2 \
--nce_idt False \
--netF mlp_sample \
--eta_ratio 0.4 \
--lr 0.000001 \
--gpu_id 0 \
--eta_ratio 0.3 \
--tau 0.01 \
--num_timesteps 3 \
--input_nc 3 \
@ -28,6 +26,7 @@ python train.py \
--n_epochs_decay 200 \
# exp6 num_timesteps=4 gpu_id 0 (基于 exp5,exp1 已停)(已停)
# exp7 num_timesteps=3 gpu_id 0 (基于 exp6)(停)
# # exp8 num_timesteps=4 gpu_id 1 ,修改了训练判别器的 loss 以及 ctn loss(基于 exp6)
# # exp9 num_timesteps=3 gpu_id 2 (基于 exp8)
# # # exp10 num_timesteps=4 gpu_id 0 , --name SBIV_1 ,让判别器看到每一个时间步的输出,修改了训练判别器的 loss 以及 ctn loss(基于 exp9)