class ContentAwareOptimization(nn.Module):
    """Re-weight a patch-level GAN loss using the gradient of the loss.

    For every spatial position of the discriminator score map the gradient
    of the GAN loss is compared (cosine similarity) with the spatially
    averaged gradient. The ``eta_ratio`` fraction of positions whose gradient
    agrees least with the average receives a modified weight derived from
    ``lambda_inc``; all other positions keep weight 1.
    """

    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        """
        Args:
            lambda_inc: numerator of the weight assigned to selected patches.
            eta_ratio: fraction of patches (per sample) that get re-weighted.
        """
        super().__init__()
        self.lambda_inc = lambda_inc
        self.eta_ratio = eta_ratio
        # LSGAN criterion from the project's ``networks`` module, placed on the GPU.
        self.criterionGAN = networks.GANLoss('lsgan').cuda()

    def compute_cosine_similarity(self, grad_patch, grad_mean):
        """Cosine similarity between each patch gradient and the mean gradient.

        Args:
            grad_patch: [B, C, H, W] gradient of the loss w.r.t. the scores.
            grad_mean: [B, C] spatial mean of ``grad_patch``.

        Returns:
            Tensor [B, 1, H, W] with one similarity value per spatial position.

        Note:
            When C == 1 every per-patch vector is one-dimensional, so the
            similarity degenerates to the sign agreement with the mean
            (+1 / -1, or 0 for a zero gradient).
        """
        B, C, H, W = grad_patch.shape
        # reshape (not view) so non-contiguous gradients are accepted as well.
        patch_vectors = grad_patch.reshape(B, C, H * W)   # [B, C, H*W]
        mean_vector = grad_mean.reshape(B, C, 1)          # [B, C, 1], broadcast over patches
        cosine = F.cosine_similarity(patch_vectors, mean_vector, dim=1)  # [B, H*W]
        return cosine.reshape(B, 1, H, W)

    def generate_weight_map(self, cosine):
        """Turn the similarity map into a per-patch weight map.

        Args:
            cosine: [B, 1, H, W] cosine similarities.

        Returns:
            Tensor [B, 1, H, W]. The ``k = int(eta_ratio * H * W)`` positions
            with the lowest similarity in each sample get
            ``lambda_inc / (exp(|cosine|) + 1e-6)``; the rest stay at 1.
        """
        B, _, H, W = cosine.shape
        cosine_flat = cosine.reshape(B, -1)               # [B, H*W]
        num_patches = cosine_flat.size(1)
        # Clamp so an eta_ratio outside [0, 1] cannot make topk raise.
        k = max(0, min(num_patches, int(self.eta_ratio * num_patches)))

        weights = torch.ones_like(cosine_flat)
        if k > 0:
            # Largest values of -cosine == smallest cosine == strongest deviation.
            _, indices = torch.topk(-cosine_flat, k, dim=1)
            selected = cosine_flat.gather(1, indices)
            # Batched write; replaces the per-sample Python loop.
            weights.scatter_(
                1,
                indices,
                self.lambda_inc / (torch.exp(torch.abs(selected)) + 1e-6),
            )
        return weights.reshape(B, 1, H, W)

    def forward(self, scores, target):
        """Compute the weighted GAN loss for one discriminator score map.

        Args:
            scores: [B, C, H, W] discriminator predictions; must be part of an
                autograd graph, because the gradient w.r.t. it is taken here.
            target: target label passed through to ``criterionGAN``
                (True or False).

        Returns:
            (weighted_loss, weight): the weighted loss and the [B, 1, H, W]
            weight map.
        """
        loss = self.criterionGAN(scores, target)

        # Gradient of the loss w.r.t. the scores. The graph is retained so that
        # ``loss`` can still be back-propagated by the caller afterwards.
        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]

        grad_mean = torch.mean(grad_scores, dim=(2, 3))   # [B, C]
        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)
        weight = self.generate_weight_map(cosine)

        # Reuse the loss computed above instead of evaluating the criterion a
        # second time on the same inputs.
        # NOTE(review): ``loss`` is a scalar here (autograd.grad above needs a
        # scalar output), so this equals ``loss * weight.mean()`` rather than a
        # per-patch weighting - confirm that this is the intended behaviour.
        weighted_loss = torch.mean(weight * loss)

        return weighted_loss, weight
True) - loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False) + loss_real0, self.weight_real0 = self.cao( pred_real0, True) + loss_fake0, self.weight_fake0 = self.cao( pred_fake0, False) # 处理 real_B1 和 fake_B1 real_B1_tokens = self.mutil_real_B1_tokens[0] - pred_real1, real_features1 = self.netD_ViT(real_B1_tokens) + pred_real1 = self.netD_ViT(real_B1_tokens) fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach() - pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens) + pred_fake1 = self.netD_ViT(fake_B1_tokens) - loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True) - loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False) + loss_real1, self.weight_real1 = self.cao( pred_real1, True) + loss_fake1, self.weight_fake1 = self.cao( pred_fake1, False) # 综合损失 self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT