diff --git a/models/networks.py b/models/networks.py
index 3c29522..a5f6b38 100644
--- a/models/networks.py
+++ b/models/networks.py
@@ -1411,13 +1411,12 @@ class MLPDiscriminator(nn.Module):
         self.activation = nn.GELU()
         self.linear2 = nn.Linear(hid_feat, out_feat)
         self.dropout = nn.Dropout(dropout)
-
     def forward(self, x):
-        features = self.linear1(x)  # intermediate features, i.e. D_real or D_fake
-        x = self.activation(features)
+        x = self.linear1(x)
+        x = self.activation(x)
         x = self.dropout(x)
-        scores = self.linear2(x)  # final scores, i.e. real_scores or fake_scores
-        return scores, features
+        x = self.linear2(x)
+        return x  # scores only; dropout is not applied to the final scores
 
 class NLayerDiscriminator(nn.Module):
     """Defines a PatchGAN discriminator"""
diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index 9194daf..5f9db77 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -48,68 +48,65 @@ class ContentAwareOptimization(nn.Module):
 
     def compute_cosine_similarity(self, grad_patch, grad_mean):
         """
-        Compute the cosine similarity between each patch gradient and the overall mean gradient
+        Compute the cosine similarity between each token gradient and the overall mean gradient
         Args:
-            grad_patch: [B, 1, H, W], per-patch gradients (from scores)
-            grad_mean: [B, 1], overall mean gradient
+            grad_patch: [B, N, D], per-token gradients (from scores)
+            grad_mean: [B, D], overall mean gradient
         Returns:
-            cosine: [B, 1, H, W], cosine similarity δ_i
+            cosine: [B, N], cosine similarity δ_i
         """
-        B, _, H, W = grad_patch.shape
-        grad_patch = grad_patch.view(B, 1, -1)  # [B, 1, H*W]
-        grad_mean = grad_mean.unsqueeze(-1)  # [B, 1, 1]
-        # compute the cosine similarity
-        cosine = F.cosine_similarity(grad_patch, grad_mean, dim=1)  # [B, H*W]
-        return cosine.view(B, 1, H, W)
+        # cosine similarity per token
+        cosine = F.cosine_similarity(grad_patch, grad_mean.unsqueeze(1), dim=2)  # [B, N]
+        return cosine
 
     def generate_weight_map(self, cosine):
        """
        Build the weight map from the cosine similarities
        Args:
-            cosine: [B, 1, H, W], cosine similarity δ_i
+            cosine: [B, N], cosine similarity δ_i
        Returns:
-            weights: [B, 1, H, W], weight map w_i
+            weights: [B, N], weight map w_i
        """
-        B, _, H, W = cosine.shape
-        cosine_flat = cosine.view(B, -1)  # [B, H*W]
-        k = int(self.eta_ratio * cosine_flat.size(1))  # take an eta_ratio fraction of the patches
-        _, indices = torch.topk(-cosine_flat, k, dim=1)  # the k most-deviating patches
-        weights = torch.ones_like(cosine_flat)
+        B, N = cosine.shape
+        k = int(self.eta_ratio * N)  # take an eta_ratio fraction of the tokens
+        _, indices = torch.topk(-cosine, k, dim=1)  # the k most-deviating tokens
+        weights = torch.ones_like(cosine)
         for b in range(B):
-            selected_cosine = cosine_flat[b, indices[b]]
+            selected_cosine = cosine[b, indices[b]]
             weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
-        return weights.view(B, 1, H, W)
+        return weights
 
     def forward(self, scores, target):
        """
        Forward pass: compute the weighted GAN loss
        Args:
-            scores: [B, 1, H, W], discriminator predictions
+            scores: [B, N, D], discriminator predictions
            target: target label (True or False)
        Returns:
            weighted_loss: weighted GAN loss
-            weight: weight map [B, 1, H, W]
+            weight: weight map [B, N]
        """
-        # plain GAN loss
+        # plain GAN loss (a scalar, as torch.autograd.grad below requires)
         loss = self.criterionGAN(scores, target)
 
-        # gradient of the loss w.r.t. the features
-        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]  # [B, C, H, W]
+        # gradient of the loss w.r.t. the scores, shape [B, N, D]
+        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]
 
-        # overall mean gradient
-        grad_mean = torch.mean(grad_scores, dim=(2, 3))  # [B, 1]
+        # overall mean gradient (averaged over the token dimension)
+        grad_mean = torch.mean(grad_scores, dim=1)  # [B, D]
 
-        # cosine similarity δ_i (Eq. 5)
-        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, 1, H, W]
+        # cosine similarity δ_i
+        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, N]
 
-        # weight map w_i (Eq. 6)
-        weight = self.generate_weight_map(cosine)
+        # weight map w_i
+        weight = self.generate_weight_map(cosine)  # [B, N]
 
-        # apply the weights to the loss (partial implementation of Eq. 7)
+        # weighted GAN loss
         weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
 
         return weighted_loss, weight
 
+
 class ContentAwareTemporalNorm(nn.Module):
     def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
         super().__init__()
@@ -117,48 +114,50 @@ class ContentAwareTemporalNorm(nn.Module):
         self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer
 
     def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
-        # weight_patch: [B, 1, 30, 30] from PatchGAN
+        # weight_patch: [B, 1, H, W], the reshaped weight map
         weight_full = F.interpolate(
             weight_patch,
             size=target_size,
             mode='bilinear',  # or 'nearest', as needed
             align_corners=False
         )
-        return weight_full  # [B, 1, 256, 256]
-
+        return weight_full
+
     def forward(self, weight_map):
        """
        Generate the content-aware optical flow
        Args:
-            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module)
+            weight_map: [B, N] weight map (from ContentAwareOptimization), with N=576
        Returns:
            F_content: [B, 2, H, W] generated flow field (x/y displacements)
        """
+        B = weight_map.shape[0]
+        N = weight_map.shape[1]
+        # N is assumed to be a perfect square (e.g. 576 -> 24x24)
+        side = int(math.sqrt(N))
+        weight_map_2d = weight_map.view(B, 1, side, side)  # [B, 1, side, side]
+
         # upsample the weight map to full resolution
-        weight_full = self.upsample_weight_map(weight_map)  # [B,1,384,384]
+        weight_full = self.upsample_weight_map(weight_map_2d)  # e.g. [B, 1, 256, 256]
 
-        # 1. normalize the weight map
-        #    keep the relative region strength while bounding the value range
-        weight_norm = F.normalize(weight_full, p=1, dim=(2,3))  # L1 normalization [B,1,H,W]
+        # L1-normalize the weight map
+        weight_norm = F.normalize(weight_full, p=1, dim=(2, 3))
 
-        # 2. sample Gaussian noise
+        # sample Gaussian noise
         B, _, H, W = weight_norm.shape
-        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B,2,H,W]
+        z = torch.randn(B, 2, H, W, device=weight_norm.device)
 
-        # 3. compose the base flow
-        #    expand the weights to 2 channels (x/y directions share them)
-        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B,2,H,W]
-        F_raw = self.gamma_stride * weight_expanded * z  # [B,2,H,W]  Eq. 9
+        # compose the base flow (x/y directions share the weights)
+        weight_expanded = weight_norm.expand(-1, 2, -1, -1)
+        F_raw = self.gamma_stride * weight_expanded * z  # Eq. 9
 
-        # 4. smoothing (preserve structural continuity)
-        #    Gaussian-blur each channel independently
-        F_smooth = self.smoother(F_raw)  # [B,2,H,W]
+        # smooth to keep the flow spatially continuous
+        F_smooth = self.smoother(F_raw)
 
-        # 5. dynamic range adjustment (optional)
-        #    bound the flow magnitude to avoid extreme displacements
-        F_content = torch.tanh(F_smooth)  # scale to the [-1, 1] range
+        # dynamic range adjustment: bound displacements to [-1, 1]
+        F_content = torch.tanh(F_smooth)
 
-        return F_content
+        return F_content
 
 
 class RomaUnsbModel(BaseModel):
@@ -314,7 +313,6 @@ class RomaUnsbModel(BaseModel):
         # process real_B0 and fake_B0
         real_B0_tokens = self.mutil_real_B0_tokens[0]
         pred_real0 = self.netD_ViT(real_B0_tokens)
-        print(pred_real0.shape)
         fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
         pred_fake0 = self.netD_ViT(fake_B0_tokens)
@@ -367,8 +365,8 @@ class RomaUnsbModel(BaseModel):
 
         # compute the GAN loss (with ContentAwareOptimization)
         if self.opt.lambda_GAN > 0.0:
-            pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
-            pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
+            pred_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0])
+            pred_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0])
             self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
             self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
             self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
diff --git a/models/roma_unsb_single_model.py b/models/roma_unsb_single_model.py
new file mode 100644
index 0000000..1bff4ed
--- /dev/null
+++ b/models/roma_unsb_single_model.py
@@ -0,0 +1,391 @@
+import numpy as np
+import math
+import timm
+import torch
+import torchvision.models as models
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import GaussianBlur
+from .base_model import BaseModel
+from . import networks
+from .patchnce import PatchNCELoss
+import util.util as util
+
+from torchvision.transforms import transforms as tfs
+
+
+def warp(image, flow):
+    """
+    Flow-based image warping
+    Args:
+        image: [B, C, H, W] input image
+        flow: [B, 2, H, W] flow field (x/y displacements)
+    Returns:
+        warped: [B, C, H, W] warped image
+    """
+    B, C, H, W = image.shape
+    # build the sampling grid; indexing='xy' (PyTorch >= 1.10) yields
+    # grid_x/grid_y of shape [H, W] with grid_x[h, w] = w and grid_y[h, w] = h
+    grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing='xy')
+    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
+    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 2, H, W]
+
+    # apply the flow displacement and normalize coordinates to [-1, 1]
+    new_grid = grid + flow
+    new_grid[:, 0, :, :] = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
+    new_grid[:, 1, :, :] = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
+    new_grid = new_grid.permute(0, 2, 3, 1)  # [B, H, W, 2]
+
+    # bilinear sampling
+    return F.grid_sample(image, new_grid, align_corners=True)
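+
+
+# A minimal sanity check for warp() (not called in training): with an all-zero
+# flow the sampling grid is the identity, so the output should match the input
+# up to interpolation error. Shapes below are illustrative assumptions.
+def _check_warp_identity():
+    image = torch.rand(1, 3, 64, 64)
+    flow = torch.zeros(1, 2, 64, 64)
+    warped = warp(image, flow)
+    assert torch.allclose(warped, image, atol=1e-4)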
+
+
+class ContentAwareOptimization(nn.Module):
+    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
+        super().__init__()
+        self.lambda_inc = lambda_inc  # weight increment for content-rich regions
+        self.eta_ratio = eta_ratio    # fraction of tokens treated as content-rich
+        self.criterionGAN = networks.GANLoss('lsgan').cuda()  # LSGAN loss
+
+    def compute_cosine_similarity(self, grad_patch, grad_mean):
+        """
+        Cosine similarity between each token gradient and the overall mean gradient
+        Args:
+            grad_patch: [B, N, D], per-token gradients (from scores)
+            grad_mean: [B, D], overall mean gradient
+        Returns:
+            cosine: [B, N], cosine similarity δ_i
+        """
+        cosine = F.cosine_similarity(grad_patch, grad_mean.unsqueeze(1), dim=2)  # [B, N]
+        return cosine
+
+    def generate_weight_map(self, cosine):
+        """
+        Build the weight map from the cosine similarities
+        Args:
+            cosine: [B, N], cosine similarity δ_i
+        Returns:
+            weights: [B, N], weight map w_i
+        """
+        B, N = cosine.shape
+        k = int(self.eta_ratio * N)  # take an eta_ratio fraction of the tokens
+        _, indices = torch.topk(-cosine, k, dim=1)  # the k most-deviating tokens
+        weights = torch.ones_like(cosine)
+        for b in range(B):
+            selected_cosine = cosine[b, indices[b]]
+            weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
+        return weights
+
+    def forward(self, scores, target):
+        """
+        Forward pass: compute the weighted GAN loss
+        Args:
+            scores: [B, N, D], discriminator predictions
+            target: target label (True or False)
+        Returns:
+            weighted_loss: weighted GAN loss
+            weight: weight map [B, N]
+        """
+        # plain GAN loss (a scalar, as torch.autograd.grad below requires)
+        loss = self.criterionGAN(scores, target)
+
+        # gradient of the loss w.r.t. the scores, shape [B, N, D]
+        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]
+
+        # overall mean gradient (averaged over the token dimension)
+        grad_mean = torch.mean(grad_scores, dim=1)  # [B, D]
+
+        # cosine similarity δ_i (Eq. 5)
+        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, N]
+
+        # weight map w_i (Eq. 6)
+        weight = self.generate_weight_map(cosine)  # [B, N]
+
+        # weighted GAN loss (Eq. 7); with a scalar criterionGAN this equals
+        # loss * weight.mean()
+        weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
+
+        return weighted_loss, weight
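+
+
+# A minimal sketch of the Eq. 6 weighting rule on toy cosine values, replicated
+# inline so it runs on CPU (the class itself builds a CUDA GANLoss). Values and
+# shapes are illustrative assumptions; this helper is not used in training.
+def _sketch_weight_map(lambda_inc=2.0, eta_ratio=0.5):
+    cosine = torch.tensor([[0.9, 0.1, -0.8, 0.5]])  # [B=1, N=4]
+    k = int(eta_ratio * cosine.size(1))              # k = 2 tokens
+    _, idx = torch.topk(-cosine, k, dim=1)           # the most-deviating tokens
+    weights = torch.ones_like(cosine)
+    weights[0, idx[0]] = lambda_inc / (torch.exp(torch.abs(cosine[0, idx[0]])) + 1e-6)
+    return weights  # tokens with cosine -0.8 and 0.1 are reweighted; the rest stay 1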
+
+
+class ContentAwareTemporalNorm(nn.Module):
+    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
+        super().__init__()
+        self.gamma_stride = gamma_stride  # overall motion amplitude
+        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer
+
+    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
+        # weight_patch: [B, 1, H, W], the reshaped weight map
+        weight_full = F.interpolate(
+            weight_patch,
+            size=target_size,
+            mode='bilinear',  # or 'nearest', as needed
+            align_corners=False
+        )
+        return weight_full
+
+    def forward(self, weight_map):
+        """
+        Generate the content-aware optical flow
+        Args:
+            weight_map: [B, N] weight map (from ContentAwareOptimization), with N=576
+        Returns:
+            F_content: [B, 2, H, W] generated flow field (x/y displacements)
+        """
+        B = weight_map.shape[0]
+        N = weight_map.shape[1]
+        # N is assumed to be a perfect square (e.g. 576 -> 24x24)
+        side = int(math.sqrt(N))
+        weight_map_2d = weight_map.view(B, 1, side, side)  # [B, 1, side, side]
+
+        # upsample the weight map to full resolution
+        weight_full = self.upsample_weight_map(weight_map_2d)  # e.g. [B, 1, 256, 256]
+
+        # L1-normalize the weight map
+        weight_norm = F.normalize(weight_full, p=1, dim=(2, 3))
+
+        # sample Gaussian noise
+        B, _, H, W = weight_norm.shape
+        z = torch.randn(B, 2, H, W, device=weight_norm.device)
+
+        # compose the base flow (x/y directions share the weights), Eq. 9
+        weight_expanded = weight_norm.expand(-1, 2, -1, -1)
+        F_raw = self.gamma_stride * weight_expanded * z
+
+        # smooth to keep the flow spatially continuous
+        F_smooth = self.smoother(F_raw)
+
+        # dynamic range adjustment: bound displacements to [-1, 1]
+        F_content = torch.tanh(F_smooth)
+
+        return F_content
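+
+
+# A minimal shape/range check for ContentAwareTemporalNorm (not called in
+# training): a uniform 24x24 weight map (N=576) should produce a
+# [B, 2, 256, 256] flow field bounded by tanh. Sizes are illustrative assumptions.
+def _check_content_flow():
+    ctn = ContentAwareTemporalNorm(gamma_stride=20.0)
+    weight_map = torch.ones(2, 576)  # [B, N]
+    flow = ctn(weight_map)           # [B, 2, 256, 256]
+    assert flow.shape == (2, 2, 256, 256)
+    assert flow.abs().max() < 1.0    # tanh keeps values in (-1, 1)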
+
+
+class RomaUnsbSingleModel(BaseModel):
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        """Configure the CTNx model-specific options"""
+
+        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
+        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
+        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
+        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
+        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
+        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
+        parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
+        parser.add_argument('--side_length', type=int, default=7)
+        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
+
+        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
+        parser.add_argument('--gamma_stride', type=float, default=20, help='motion amplitude of the content-aware flow')
+        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
+
+        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
+        parser.add_argument('--num_timesteps', type=int, default=5, help='number of timesteps')
+        parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
+
+        opt, _ = parser.parse_known_args()
+
+        return parser
+
+    def __init__(self, opt):
+        BaseModel.__init__(self, opt)
+
+        self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial', 'ctn']
+        self.visual_names = ['real_A', 'fake_B', 'real_B']
+        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
+
+        if self.isTrain:
+            self.model_names = ['G', 'D_ViT']
+        else:  # during test time, only load G
+            self.model_names = ['G']
+
+        # define networks (both generator and discriminator)
+        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
+
+        if self.isTrain:
+            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
+            # self.netPreViT = timm.create_model("vit_base_patch32_384", pretrained=True).to(self.device)
+            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
+
+            self.resize = tfs.Resize(size=(384, 384))
+            # self.resize = tfs.Resize(size=(224, 224))
+
+            # define loss functions
+            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+            self.criterionL1 = torch.nn.L1Loss().to(self.device)
+            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+            self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D_ViT)
+
+            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)   # content-aware GAN loss
+            self.ctn = ContentAwareTemporalNorm(gamma_stride=opt.gamma_stride)   # pseudo optical-flow generator
+
+    def data_dependent_initialize(self, data):
+        """No data-dependent initialization is needed for this model; kept for the BaseModel interface."""
+        pass
+
+    def optimize_parameters(self):
+        # forward
+        self.forward()
+
+        # update D
+        self.set_requires_grad(self.netD_ViT, True)
+        self.optimizer_D_ViT.zero_grad()
+        self.loss_D = self.compute_D_loss()
+        self.loss_D.backward()
+        self.optimizer_D_ViT.step()
+
+        # update G
+        self.set_requires_grad(self.netD_ViT, False)
+        self.optimizer_G.zero_grad()
+        self.loss_G = self.compute_G_loss()
+        self.loss_G.backward()
+        self.optimizer_G.step()
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+        Parameters:
+            input (dict): include the data itself and its metadata information.
+        The option 'direction' can be used to swap domain A and domain B.
+        """
+        AtoB = self.opt.direction == 'AtoB'
+        self.real_A = input['A' if AtoB else 'B'].to(self.device)
+        self.real_B = input['B' if AtoB else 'A'].to(self.device)
+        self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+    def forward(self):
+        """Run forward pass; called by both <optimize_parameters> and <test>."""
+        self.fake_B = self.netG(self.real_A)
+
+        if self.opt.isTrain:
+            real_A = self.real_A
+            real_B = self.real_B
+            fake_B = self.fake_B
+            self.real_A_resize = self.resize(real_A)
+            real_B = self.resize(real_B)
+            self.fake_B_resize = self.resize(fake_B)
+            self.mutil_real_A_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True)
+            self.mutil_real_B_tokens = self.netPreViT(real_B, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True)
+
+    def compute_D_loss(self):
+        """Calculate GAN loss for the discriminator"""
+        lambda_D_ViT = self.opt.lambda_D_ViT
+        fake_B_tokens = self.mutil_fake_B_tokens[0].detach()
+        real_B_tokens = self.mutil_real_B_tokens[0]
+        pred_fake_ViT = self.netD_ViT(fake_B_tokens)
+        pred_real_ViT = self.netD_ViT(real_B_tokens)
+
+        self.loss_D_real_ViT, self.weight_real = self.cao(pred_real_ViT, True)
+        self.loss_D_fake_ViT, self.weight_fake = self.cao(pred_fake_ViT, False)
+
+        self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5 * lambda_D_ViT
+
+        return self.loss_D_ViT
+
+    def compute_G_loss(self):
+        if self.opt.lambda_ctn > 0.0:
+            # generate the pseudo flow field from the discriminator's weight map
+            self.f_content = self.ctn(self.weight_fake.detach())
+
+            # warp the input and the first translation with the same flow
+            self.warped_real_A = warp(self.real_A, self.f_content)
+            self.warped_fake_B = warp(self.fake_B, self.f_content)
+            # second generator pass on the warped input
+            self.warped_fake_B2 = self.netG(self.warped_real_A)
+
+            # temporal-consistency loss between the two paths
+            self.loss_ctn = self.criterionL1(self.warped_fake_B, self.warped_fake_B2) * self.opt.lambda_ctn
+        else:
+            self.loss_ctn = 0.0
+
+        if self.opt.lambda_GAN > 0.0:
+            fake_B_tokens = self.mutil_fake_B_tokens[0]
+            pred_fake_ViT = self.netD_ViT(fake_B_tokens)
+            # the gradient-based weighting is applied to the generator's GAN loss as well
+            self.loss_G_fake_ViT, _ = self.cao(pred_fake_ViT, True)
+            self.loss_G_GAN = self.loss_G_fake_ViT * self.opt.lambda_GAN
+        else:
+            self.loss_G_GAN = 0.0
+
+        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
+            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
+        else:
+            self.loss_global, self.loss_spatial = 0.0, 0.0
+
+        self.loss_G = self.loss_G_GAN + self.loss_global + self.loss_spatial + self.loss_ctn
+        return self.loss_G
+
+    def calculate_attention_loss(self):
+        mutil_real_A_tokens = self.mutil_real_A_tokens
+        mutil_fake_B_tokens = self.mutil_fake_B_tokens
+
+        if self.opt.lambda_global > 0.0:
+            loss_global = self.calculate_similarity(mutil_real_A_tokens, mutil_fake_B_tokens)
+        else:
+            loss_global = 0.0
+
+        if self.opt.lambda_spatial > 0.0:
+            local_nums = self.opt.local_nums
+            tokens_cnt = 576
+            local_id = np.random.permutation(tokens_cnt)
+            local_id = local_id[:int(min(local_nums, tokens_cnt))]
+
+            mutil_real_A_local_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
+            mutil_fake_B_local_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
+
+            loss_spatial = self.calculate_similarity(mutil_real_A_local_tokens, mutil_fake_B_local_tokens)
+        else:
+            loss_spatial = 0.0
+
+        return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
+
+    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
+        loss = 0.0
+        n_layers = len(self.atten_layers)
+
+        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
+            # cross-Gram matrices between source and target tokens
+            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
+            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
+            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
+            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
+
+        loss = loss / n_layers
+        return loss
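+
+
+# A minimal sketch of the cross-similarity term (not used in training):
+# identical source/target tokens give identical cross-Gram matrices, cosine 1
+# everywhere, and hence zero structural loss. Shapes are illustrative assumptions.
+def _sketch_cross_similarity():
+    src = torch.rand(1, 8, 16)  # [B, N, D] toy tokens
+    tgt = src.clone()           # identical target tokens
+    src_tgt = src.bmm(tgt.permute(0, 2, 1))
+    tgt_src = tgt.bmm(src.permute(0, 2, 1))
+    cos = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
+    loss = torch.nn.L1Loss()(torch.ones_like(cos), cos)
+    assert loss.item() < 1e-6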
diff --git a/scripts/traincp.sh b/scripts/traincp.sh
index 0125243..6ebb804 100644
--- a/scripts/traincp.sh
+++ b/scripts/traincp.sh
@@ -1,15 +1,15 @@
 python train.py \
-  --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
-  --name cp_5 \
-  --dataset_mode unaligned_double \
-  --display_env CP \
-  --model roma_unsb \
-  --lambda_ctn 0 \
+  --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Single/Monitor \
+  --name cp_1 \
+  --dataset_mode unaligned \
+  --display_env NEWCP \
+  --model roma_unsb_single \
+  --lambda_ctn 10 \
   --lambda_inc 8.0 \
   --lambda_global 6.0 \
   --lambda_spatial 6.0 \
   --gamma_stride 20 \
-  --lr 0.000005 \
+  --lr 0.000002 \
   --gpu_id 0 \
   --eta_ratio 0.4 \
   --n_epochs 100 \
@@ -18,4 +18,6 @@ python train.py \
 # cp2 fixed up the cp1 code, --lr 0.000002
 # cp3 added --lambda_inc 8.0, --gpu_id 2
 # cp4 on top of cp3, fed the gradient boosting into the generator's GAN loss, --gpu_id 1
-# cp5 on top of cp3, --lambda_ctn 0, --gpu_id 0, --lr 0.000005
\ No newline at end of file
+# cp5 on top of cp3, --lambda_ctn 0, --gpu_id 0, --lr 0.000005
+# newcp1 reworked the optical-flow algorithm and switched to a single-frame script; this is the final reproduction run. --gpu_id 0
+# newcp2 applied the gradient weight map to G_GAN as well.
\ No newline at end of file
diff --git a/train.py b/train.py
index 8cd245a..47e446e 100644
--- a/train.py
+++ b/train.py
@@ -33,7 +33,7 @@ if __name__ == '__main__':
 
             if total_iters % opt.print_freq == 0:
                 t_data = iter_start_time - iter_data_time
-            batch_size = data["A0"].size(0)
+            batch_size = data["A"].size(0)
             total_iters += batch_size
             epoch_iter += batch_size
             if len(opt.gpu_ids) > 0: