import math

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.transforms import GaussianBlur
from torchvision.transforms import transforms as tfs

import util.util as util
from . import networks
from .base_model import BaseModel
from .patchnce import PatchNCELoss


def warp(image, flow):
    """Optical-flow based image warping.

    Args:
        image: [B, C, H, W] input image
        flow:  [B, 2, H, W] optical flow field (x/y displacements in pixels)

    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape

    # Build the pixel-coordinate grid. 'xy' indexing makes channel 0 the
    # x/width coordinate and channel 1 the y/height coordinate, matching the
    # convention expected by F.grid_sample.
    grid_x, grid_y = torch.meshgrid(torch.arange(W), torch.arange(H), indexing='xy')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 2, H, W]

    # Apply the flow displacement and normalize the coordinates to [-1, 1].
    new_grid = grid + flow
    new_grid[:, 0, :, :] = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
    new_grid[:, 1, :, :] = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
    new_grid = new_grid.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Bilinear sampling.
    return F.grid_sample(image, new_grid, align_corners=True)
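

# Minimal sanity-check sketch (not used by the training pipeline): with a zero
# flow field, warp() should reproduce its input exactly, because grid_sample
# with align_corners=True samples at the original pixel centers. The helper
# name `_demo_warp_identity` is illustrative only.
def _demo_warp_identity():
    image = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
    zero_flow = torch.zeros(2, 2, 8, 8)
    warped = warp(image, zero_flow)
    # Identity up to interpolation round-off.
    assert torch.allclose(warped, image, atol=1e-4)
    return warped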


# Content-aware temporal-normalization loss (Eq. 10).
def compute_ctn_loss(G, x, F_content):
    """Compute the content-aware temporal normalization loss.

    Args:
        G: generator
        x: input infrared image [B, C, H, W]
        F_content: generated optical flow field [B, 2, H, W]
    """
    # Generate the visible image.
    y_fake = G(x)  # [B, 3, H, W]

    # Warp the generated result with the flow.
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]

    # Warp the input with the same flow, then generate again.
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)  # [B, 3, H, W]

    # L2 loss between the two results.
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss
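

# Illustrative sketch only: for a generator that commutes with the warp (the
# identity map is the simplest such case), warp(G(x)) equals G(warp(x)) and the
# CTN loss is zero. The helper name `_demo_ctn_equivariance` is hypothetical.
def _demo_ctn_equivariance():
    x = torch.rand(1, 3, 32, 32)
    flow = 2.0 * torch.randn(1, 2, 32, 32)
    loss = compute_ctn_loss(nn.Identity(), x, flow)
    # Zero (up to floating point) because the identity generator is perfectly
    # equivariant to the warp.
    assert loss.item() < 1e-8
    return loss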


class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc
        self.eta_ratio = eta_ratio
        self.criterionGAN = networks.GANLoss('lsgan').cuda()

    def generate_weight_map(self, attn_real, attn_fake):
        # attn_real, attn_fake: [B, N] self-attention weights.
        # L1-normalize the attention weights.
        weight_real = F.normalize(attn_real, p=1, dim=1)  # [B, N]
        weight_fake = F.normalize(attn_fake, p=1, dim=1)  # [B, N]

        # Re-weight the real image: boost the top eta_ratio (content-rich) patches.
        k = int(self.eta_ratio * weight_real.shape[1])
        values_real, indices_real = torch.topk(weight_real, k, dim=1)
        weight_real_enhanced = torch.ones_like(weight_real)
        weight_real_enhanced.scatter_(1, indices_real, self.lambda_inc / (values_real + 1e-6))

        # Same re-weighting for the generated image.
        values_fake, indices_fake = torch.topk(weight_fake, k, dim=1)
        weight_fake_enhanced = torch.ones_like(weight_fake)
        weight_fake_enhanced.scatter_(1, indices_fake, self.lambda_inc / (values_fake + 1e-6))

        return weight_real_enhanced, weight_fake_enhanced

    def forward(self, real_scores, fake_scores, attn_real, attn_fake):
        # real_scores, fake_scores: discriminator scores [B, 1]
        # attn_real, attn_fake: self-attention weights [B, N]

        # Build the weight maps.
        weight_real, weight_fake = self.generate_weight_map(attn_real, attn_fake)

        # Apply the weights to the GAN loss.
        loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
        loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))

        # Total loss.
        loss_co_adv = (loss_co_real + loss_co_fake) * 0.5

        return loss_co_adv, weight_real, weight_fake
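

# Standalone sketch of the top-k re-weighting performed in generate_weight_map,
# written without the GAN criterion so it runs on CPU with plain torch. All
# names here are illustrative; lambda_inc=2.0 and eta_ratio=0.4 mirror the
# defaults above.
def _demo_topk_reweighting(lambda_inc=2.0, eta_ratio=0.4):
    attn = torch.rand(1, 10)                      # dummy attention weights for 10 patches
    weight = F.normalize(attn, p=1, dim=1)        # L1-normalize
    k = int(eta_ratio * weight.shape[1])          # number of "content-rich" patches
    values, indices = torch.topk(weight, k, dim=1)
    enhanced = torch.ones_like(weight)            # default weight 1 everywhere
    enhanced.scatter_(1, indices, lambda_inc / (values + 1e-6))
    # The k most-attended patches now carry weights > 1; the rest stay at 1.
    return enhanced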


class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
        # If weight_patch has shape [N, 1] (e.g. [576, 1]), add a batch dimension.
        if weight_patch.dim() == 2 and weight_patch.shape[1] == 1:
            weight_patch = weight_patch.unsqueeze(0)  # -> [1, 576, 1]

        # Shape after the adjustment.
        B, N, _ = weight_patch.shape  # e.g. B=1, N=576
        if N != 576:
            raise ValueError(f"Expected N=576 patches (24x24), but got N={N}")

        # Reshape to [B, 1, 24, 24].
        weight_patch = weight_patch.view(B, 1, 24, 24)

        # Bilinearly upsample to the target size.
        weight_full = F.interpolate(
            weight_patch,
            size=target_size,
            mode='bilinear',
            align_corners=False
        )

        # Optional: keep the weight constant inside each 16x16 patch.
        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')

        return weight_full

    def forward(self, weight_map):
        """Generate a content-aware optical flow field.

        Args:
            weight_map: patch-level weight map, [N, 1] or [B, N, 1] with N = 576
                (from the content-aware optimization module)

        Returns:
            F_content: [B, 2, H, W] generated optical flow field (x/y displacements)
        """
        # Upsample the weight map to full resolution.
        weight_full = self.upsample_weight_map(weight_map)  # [B, 1, 256, 256] with the default target_size

        # 1. Normalize the weight map: keeps the relative strength of regions
        #    while bounding the numerical range.
        weight_norm = F.normalize(weight_full, p=1, dim=(2, 3))  # L1 normalization [B, 1, H, W]

        # 2. Sample Gaussian noise.
        B, _, H, W = weight_norm.shape
        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B, 2, H, W]

        # 3. Compose the raw flow: broadcast the weight map to two channels
        #    (shared by the x/y directions).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B, 2, H, W]  (Eq. 9)

        # 4. Smoothing (keeps the structure continuous): Gaussian blur applied
        #    to each channel independently.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]

        # 5. Dynamic range adjustment (optional): bound the flow magnitude to
        #    avoid extreme displacements.
        F_content = torch.tanh(F_smooth)  # scaled to [-1, 1]

        return F_content
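

# Minimal usage sketch (CPU, no GAN components): build a pseudo flow field from
# a dummy patch-level weight map and warp a random image with it. The helper
# name `_demo_content_aware_flow` and the dummy shapes are illustrative only.
def _demo_content_aware_flow():
    ctn = ContentAwareTemporalNorm(gamma_stride=20.0)
    weight_map = torch.rand(576, 1)            # 24x24 = 576 patch weights
    flow = ctn(weight_map)                     # [1, 2, 256, 256], values in [-1, 1]
    image = torch.rand(1, 3, 256, 256)
    warped = warp(image, flow)                 # [1, 3, 256, 256]
    return flow, warped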


class RomaUnsbModel(BaseModel):

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure the options specific to the CTNx model."""
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
        parser.add_argument('--side_length', type=int, default=7)
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
        parser.add_argument('--gamma_stride', type=float, default=20, help='stride factor controlling the magnitude of the content-aware flow')
        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='number of timesteps (NFE) used at test time')
        parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')

        opt, _ = parser.parse_known_args()

        return parser

    def __init__(self, opt):
        """Initialize the CTNx model."""
        BaseModel.__init__(self, opt)

        # Training losses to print/log.
        self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial', 'ctn']
        self.visual_names = ['real_A0', 'fake_B0', 'real_B0', 'real_A1', 'fake_B1', 'real_B1']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

        if self.opt.phase == 'test':
            self.visual_names = ['real']
            for NFE in range(self.opt.num_timesteps):
                fake_name = 'fake_' + str(NFE + 1)
                self.visual_names.append(fake_name)

        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if self.isTrain:
            self.model_names = ['G', 'D_ViT']
        else:
            self.model_names = ['G']

        # Create the networks.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.resize = tfs.Resize(size=(384, 384), antialias=True)
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)

            # Pretrained ViT used as a feature extractor.
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

            # Loss functions and optimizers.
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D]

            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
            self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical flow

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        pass

    def optimize_parameters(self):
        # forward
        self.forward()

        self.netG.train()
        self.netD_ViT.train()

        # update D
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B0 = self.netG(self.real_A0)
        self.fake_B1 = self.netG(self.real_A1)

        if self.opt.isTrain:
            real_A0 = self.real_A0
            real_A1 = self.real_A1
            real_B0 = self.real_B0
            real_B1 = self.real_B1
            fake_B0 = self.fake_B0
            fake_B1 = self.fake_B1

            self.real_A0_resize = self.resize(real_A0)
            self.real_A1_resize = self.resize(real_A1)
            real_B0 = self.resize(real_B0)
            real_B1 = self.resize(real_B1)
            self.fake_B0_resize = self.resize(fake_B0)
            self.fake_B1_resize = self.resize(fake_B1)

            # ViT tokens used by the adversarial and similarity losses.
            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)

    def compute_D_loss(self):
        """Calculate GAN loss with Content-Aware Optimization."""
        lambda_D_ViT = self.opt.lambda_D_ViT

        pred_real0, attn_real0 = self.netD_ViT(self.mutil_real_B0_tokens[0])  # scores, attention weights
        pred_real1, attn_real1 = self.netD_ViT(self.mutil_real_B1_tokens[0])  # scores, attention weights

        pred_fake0, attn_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
        pred_fake1, attn_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())

        loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
            real_scores=pred_real0,
            fake_scores=pred_fake0,
            attn_real=attn_real0,
            attn_fake=attn_fake0
        )
        loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
            real_scores=pred_real1,
            fake_scores=pred_fake1,
            attn_real=attn_real1,
            attn_fake=attn_fake1
        )

        self.loss_D_ViT = (loss_cao0 + loss_cao1) * 0.5 * lambda_D_ViT

        # Loss values kept for visualization/debugging:
        # self.loss_D_real = loss_D_real.item()
        # self.loss_D_fake = loss_D_fake.item()
        # self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5

        return self.loss_D_ViT

    def compute_G_loss(self):
        """Compute the total generator loss (GAN, CTN, and structural consistency terms)."""
        if self.opt.lambda_ctn > 0.0:
            # CTN pseudo flow fields derived from the generated-image weight maps.
            self.f_content0 = self.ctn(self.weight_fake0.detach())
            self.f_content1 = self.ctn(self.weight_fake1.detach())

            # Warped images.
            self.warped_real_A0 = warp(self.real_A0, self.f_content0)
            self.warped_real_A1 = warp(self.real_A1, self.f_content1)
            self.warped_fake_B0 = warp(self.fake_B0, self.f_content0)
            self.warped_fake_B1 = warp(self.fake_B1, self.f_content1)

            # Second pass through the generator on the warped inputs.
            self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
            self.warped_fake_B1_2 = self.netG(self.warped_real_A1)

            warped_fake_B0_2 = self.warped_fake_B0_2
            warped_fake_B1_2 = self.warped_fake_B1_2
            warped_fake_B0 = self.warped_fake_B0
            warped_fake_B1 = self.warped_fake_B1

            # L2 loss between "warp then generate" and "generate then warp".
            self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
            self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
            self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1) * 0.5
        else:
            self.loss_ctn = 0.0

        if self.opt.lambda_GAN > 0.0:
            pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
            pred_fake1, _ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
            self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
            self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
            self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1) * 0.5
        else:
            self.loss_G_GAN = 0.0

        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
        else:
            self.loss_global, self.loss_spatial = 0.0, 0.0

        self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
                      self.opt.lambda_ctn * self.loss_ctn + \
                      self.loss_global * self.opt.lambda_global + \
                      self.loss_spatial * self.opt.lambda_spatial

        return self.loss_G

    def calculate_attention_loss(self):
        n_layers = len(self.atten_layers)
        mutil_real_A0_tokens = self.mutil_real_A0_tokens
        mutil_real_A1_tokens = self.mutil_real_A1_tokens
        mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
        mutil_fake_B1_tokens = self.mutil_fake_B1_tokens

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            loss_spatial = 0.0
            local_nums = self.opt.local_nums
            tokens_cnt = 576
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:int(min(local_nums, tokens_cnt))]

            mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0

        return loss_global, loss_spatial

    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)

        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()

        loss = loss / n_layers
        return loss
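

# Standalone sketch of the cross-similarity term used in calculate_similarity:
# for a single layer of random source/target tokens, the loss is the L1
# distance between 1 and the cosine similarity of the two cross-attention
# maps. Names and shapes here are illustrative only.
def _demo_cross_similarity(B=1, N=8, C=16):
    src_tokens = torch.rand(B, N, C)
    tgt_tokens = torch.rand(B, N, C)
    src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))   # [B, N, N]
    tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))   # [B, N, N]
    cos = F.cosine_similarity(src_tgt, tgt_src, dim=-1)     # [B, N]
    return F.l1_loss(torch.ones_like(cos), cos)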