import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from torchvision.transforms import transforms as tfs

from .base_model import BaseModel
from . import networks


def warp(image, flow):
    """Warp an image with an optical-flow field.

    Args:
        image: [B, C, H, W] input image
        flow:  [B, 2, H, W] flow field (x/y displacement in pixels)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape
    # Build the base sampling grid. meshgrid indexes (row, col) = (y, x),
    # so the y coordinate is produced first and stacked second.
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 2, H, W]

    # Displace the grid by the flow and normalize coordinates to [-1, 1].
    new_grid = grid + flow
    new_grid = torch.stack(
        (2.0 * new_grid[:, 0] / (W - 1) - 1.0,   # x direction
         2.0 * new_grid[:, 1] / (H - 1) - 1.0),  # y direction
        dim=1)
    new_grid = new_grid.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Bilinear sampling.
    return F.grid_sample(image, new_grid, align_corners=True)


class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight increment for content-rich regions
        self.eta_ratio = eta_ratio    # fraction of patches treated as content-rich
        self.criterionGAN = networks.GANLoss('lsgan').cuda()  # LSGAN loss

    def compute_cosine_similarity(self, grad_patch, grad_mean):
        """Cosine similarity between each patch gradient and the mean gradient.

        With a single score channel this reduces to a sign agreement between
        the per-patch gradient and the mean gradient.

        Args:
            grad_patch: [B, 1, H, W] per-patch gradients of the scores
            grad_mean:  [B, 1] mean gradient
        Returns:
            cosine: [B, 1, H, W] cosine similarity delta_i
        """
        B, _, H, W = grad_patch.shape
        grad_patch = grad_patch.view(B, 1, -1)  # [B, 1, H*W]
        grad_mean = grad_mean.unsqueeze(-1)     # [B, 1, 1]
        cosine = F.cosine_similarity(grad_patch, grad_mean, dim=1)  # [B, H*W]
        return cosine.view(B, 1, H, W)

    def generate_weight_map(self, cosine):
        """Build the weight map from the cosine similarities.

        Args:
            cosine: [B, 1, H, W] cosine similarity delta_i
        Returns:
            weights: [B, 1, H, W] weight map w_i
        """
        B, _, H, W = cosine.shape
        cosine_flat = cosine.view(B, -1)  # [B, H*W]
        k = int(self.eta_ratio * cosine_flat.size(1))    # top eta_ratio of patches
        _, indices = torch.topk(-cosine_flat, k, dim=1)  # k patches deviating most from the mean
        weights = torch.ones_like(cosine_flat)
        for b in range(B):
            selected_cosine = cosine_flat[b, indices[b]]
            weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
        return weights.view(B, 1, H, W)

    def forward(self, scores, target):
        """Compute the weighted GAN loss.

        Args:
            scores: [B, 1, H, W] discriminator scores
            target: target label (True or False)
        Returns:
            weighted_loss: weighted GAN loss
            weight: weight map [B, 1, H, W]
        """
        # Plain GAN loss, used only to probe the gradient of the scores.
        loss = self.criterionGAN(scores, target)
        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]  # [B, 1, H, W]
        # Mean gradient over the spatial dimensions.
        grad_mean = torch.mean(grad_scores, dim=(2, 3))  # [B, 1]
        # Cosine similarity delta_i (Eq. 5).
        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, 1, H, W]
        # Weight map w_i (Eq. 6).
        weight = self.generate_weight_map(cosine)
        # Apply the weights to a per-patch LSGAN loss (part of Eq. 7). The
        # GANLoss module reduces to a scalar, so the elementwise loss is
        # recomputed here to keep the weighting per patch.
        target_tensor = torch.ones_like(scores) if target else torch.zeros_like(scores)
        weighted_loss = torch.mean(weight * F.mse_loss(scores, target_tensor, reduction='none'))
        return weighted_loss, weight
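
# Minimal smoke test for ContentAwareOptimization (a sketch, not part of the
# training pipeline). The shapes are assumptions: a 30x30 PatchGAN-style score
# map with batch size 2. The module pins its GAN loss to the GPU in __init__,
# so this sketch assumes a CUDA device is available.
def _demo_content_aware_optimization():
    cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
    # Stand-in discriminator scores; requires_grad so the gradient probe works.
    scores = torch.randn(2, 1, 30, 30, device='cuda', requires_grad=True)
    loss_fake, weight_fake = cao(scores, False)  # scalar loss, [2, 1, 30, 30] weights
    assert weight_fake.shape == (2, 1, 30, 30)
    return loss_fake, weight_fake
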
class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
        # weight_patch: [B, 1, h, w] from the PatchGAN-style discriminator.
        # target_size must match the resolution of the images passed to warp().
        weight_full = F.interpolate(
            weight_patch,
            size=target_size,
            mode='bilinear',  # or 'nearest', depending on the desired sharpness
            align_corners=False
        )
        return weight_full  # [B, 1, 256, 256]

    def forward(self, weight_map):
        """Generate a content-aware flow field.

        Args:
            weight_map: [B, 1, h, w] weight map (from the content-aware
                optimization module)
        Returns:
            F_content: [B, 2, H, W] flow field (x/y displacement)
        """
        # Upsample the weight map to full resolution.
        weight_full = self.upsample_weight_map(weight_map)  # [B, 1, 256, 256]

        # 1. L1-normalize over the spatial dimensions: keeps the relative
        #    strength of each region while bounding the overall magnitude.
        weight_norm = weight_full / (weight_full.abs().sum(dim=(2, 3), keepdim=True) + 1e-8)

        # 2. Sample Gaussian noise.
        B, _, H, W = weight_norm.shape
        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B, 2, H, W]

        # 3. Synthesize the raw flow: expand the weight map to two channels
        #    (x/y directions share the same weights).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z      # [B, 2, H, W], Eq. 9

        # 4. Gaussian smoothing for structural continuity, applied per channel.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]

        # 5. Dynamic-range adjustment: bound the flow magnitude to avoid
        #    extreme displacements.
        F_content = torch.tanh(F_smooth)  # squashed to [-1, 1]
        return F_content
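
# Sketch of the CTN round trip (shapes are assumptions: a 30x30 weight map and
# a 256x256 RGB image). A content-aware flow is generated from the weight map
# and used to warp an image; with the default small gamma_stride the resulting
# displacements are tiny, so the warp stays close to the identity.
def _demo_content_aware_flow():
    ctn = ContentAwareTemporalNorm(gamma_stride=0.1)
    weight_map = torch.rand(1, 1, 30, 30)
    flow = ctn(weight_map)              # [1, 2, 256, 256]
    image = torch.rand(1, 3, 256, 256)
    warped = warp(image, flow)          # [1, 3, 256, 256]
    return warped
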
class RomaUnsbModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure model-specific options."""
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for global structural consistency')
        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for local structural consistency')
        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
        parser.add_argument('--side_length', type=int, default=7)
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
        parser.add_argument('--gamma_stride', type=float, default=20, help='overall motion magnitude of the content-aware flow')
        parser.add_argument('--atten_layers', type=str, default='5', help='compute cross-similarity on which layers')
        parser.add_argument('--tau', type=float, default=0.01, help='entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='number of sampling timesteps (NFE)')
        parser.add_argument('--n_mlp', type=int, default=3, help='number of MLP layers')
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        # Training losses to print.
        self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial', 'ctn']
        self.visual_names = ['real_A0', 'fake_B0', 'real_B0', 'real_A1', 'fake_B1', 'real_B1']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

        if self.opt.phase == 'test':
            self.visual_names = ['real']
            for NFE in range(self.opt.num_timesteps):
                fake_name = 'fake_' + str(NFE + 1)
                self.visual_names.append(fake_name)

        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if self.isTrain:
            self.model_names = ['G', 'D_ViT']
        else:
            self.model_names = ['G']

        # Create the networks.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
                                      not opt.no_dropout, opt.init_type, opt.init_gain,
                                      opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.resize = tfs.Resize(size=(384, 384), antialias=True)
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            # Pretrained ViT. Note: the get_tokens/local_id/side_length calls
            # below assume a ViT implementation extended to return intermediate
            # tokens; they are not stock timm arguments.
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

            # Loss functions and optimizers.
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)      # weighted GAN loss
            self.ctn = ContentAwareTemporalNorm(gamma_stride=opt.gamma_stride)      # synthetic content-aware flow

    def data_dependent_initialize(self, data):
        """No data-dependent initialization is needed for this model."""
        pass

    def optimize_parameters(self):
        # forward
        self.forward()
        self.netG.train()
        self.netD_ViT.train()

        # update D
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run the forward pass; called by both optimize_parameters and test."""
        self.fake_B0 = self.netG(self.real_A0)
        self.fake_B1 = self.netG(self.real_A1)

        if self.opt.isTrain:
            self.real_A0_resize = self.resize(self.real_A0)
            self.real_A1_resize = self.resize(self.real_A1)
            real_B0 = self.resize(self.real_B0)
            real_B1 = self.resize(self.real_B1)
            self.fake_B0_resize = self.resize(self.fake_B0)
            self.fake_B1_resize = self.resize(self.fake_B1)

            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)

    def compute_D_loss(self):
        """GAN loss for the discriminator, with content-aware optimization."""
        lambda_D_ViT = self.opt.lambda_D_ViT

        # Frame 0: real_B0 and fake_B0.
        real_B0_tokens = self.mutil_real_B0_tokens[0]
        pred_real0 = self.netD_ViT(real_B0_tokens)
        fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
        pred_fake0 = self.netD_ViT(fake_B0_tokens)
        loss_real0, self.weight_real0 = self.cao(pred_real0, True)
        loss_fake0, self.weight_fake0 = self.cao(pred_fake0, False)

        # Frame 1: real_B1 and fake_B1.
        real_B1_tokens = self.mutil_real_B1_tokens[0]
        pred_real1 = self.netD_ViT(real_B1_tokens)
        fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
        pred_fake1 = self.netD_ViT(fake_B1_tokens)
        loss_real1, self.weight_real1 = self.cao(pred_real1, True)
        loss_fake1, self.weight_fake1 = self.cao(pred_fake1, False)

        # Combined loss.
        self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT
        return self.loss_D_ViT
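
    # CTN consistency (a sketch of the objective computed in compute_G_loss
    # below): for a synthetic content-aware flow F_c, translation and warping
    # should commute,
    #     L_ctn = || G(warp(x, F_c)) - warp(G(x), F_c) ||_2^2,
    # so that content-rich regions stay stable under small synthetic motion.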
    def compute_G_loss(self):
        """Compute the generator loss."""
        self.loss_G_GAN = 0.0
        self.loss_ctn = 0.0
        self.loss_global = 0.0
        self.loss_spatial = 0.0

        # CTN loss.
        if self.opt.lambda_ctn > 0.0:
            # Generate flow fields from the discriminator's weight maps.
            self.f_content0 = self.ctn(self.weight_fake0.detach())
            self.f_content1 = self.ctn(self.weight_fake1.detach())
            # Warped inputs and outputs.
            self.warped_real_A0 = warp(self.real_A0, self.f_content0)
            self.warped_real_A1 = warp(self.real_A1, self.f_content1)
            self.warped_fake_B0 = warp(self.fake_B0, self.f_content0)
            self.warped_fake_B1 = warp(self.fake_B1, self.f_content1)
            # Second generator pass on the warped inputs.
            self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
            self.warped_fake_B1_2 = self.netG(self.warped_real_A1)
            # L2 consistency between warp-then-translate and translate-then-warp.
            self.loss_ctn0 = F.mse_loss(self.warped_fake_B0_2, self.warped_fake_B0)
            self.loss_ctn1 = F.mse_loss(self.warped_fake_B1_2, self.warped_fake_B1)
            self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1) * 0.5

        # GAN loss for the generator.
        if self.opt.lambda_GAN > 0.0:
            pred_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0])
            pred_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0])
            self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
            self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
            self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1) * 0.5

        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()

        # Total loss.
        self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
                      self.opt.lambda_ctn * self.loss_ctn + \
                      self.opt.lambda_global * self.loss_global + \
                      self.opt.lambda_spatial * self.loss_spatial
        return self.loss_G
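
    # Token bookkeeping for the structural-consistency losses below: with
    # vit_base_patch16_384, a 384x384 input split into 16x16 patches yields
    # 24 * 24 = 576 spatial tokens, which is why tokens_cnt is fixed to 576
    # in calculate_attention_loss.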
    def calculate_attention_loss(self):
        mutil_real_A0_tokens = self.mutil_real_A0_tokens
        mutil_real_A1_tokens = self.mutil_real_A1_tokens
        mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
        mutil_fake_B1_tokens = self.mutil_fake_B1_tokens

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) \
                        + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            local_nums = self.opt.local_nums
            tokens_cnt = 576  # 24 * 24 spatial tokens for a 384x384 input
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:min(local_nums, tokens_cnt)]
            mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True,
                                                        local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True,
                                                        local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True,
                                                        local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True,
                                                        local_id=local_id, side_length=self.opt.side_length)
            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) \
                         + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0

        return loss_global, loss_spatial

    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)
        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            # Cross-similarity matrices between source and target tokens.
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            # Each row of src_tgt should match the corresponding row of
            # tgt_src, i.e. the cross-similarity should be symmetric.
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global)
        return loss / n_layers
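
# Standalone sketch of the cross-similarity consistency measured by
# calculate_similarity, with made-up shapes (1 image, 5 tokens of dimension 8).
# The rows of src @ tgt^T are compared against the rows of tgt @ src^T; for
# identical token sets both products coincide and the loss is zero.
def _demo_cross_similarity():
    src = torch.randn(1, 5, 8)
    tgt = src.clone()
    src_tgt = src.bmm(tgt.permute(0, 2, 1))  # [1, 5, 5]
    tgt_src = tgt.bmm(src.permute(0, 2, 1))  # [1, 5, 5]
    cos = F.cosine_similarity(src_tgt, tgt_src, dim=-1)  # [1, 5], all ones here
    return torch.nn.L1Loss()(torch.ones_like(cos), cos)  # 0.0 for identical tokens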