roma_unsb/models/roma_unsb_model.py

640 lines
29 KiB
Python
Raw Normal View History

2025-02-22 14:21:54 +08:00
import numpy as np
import math
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
from torchvision.transforms import transforms as tfs
def warp(image, flow):
    """Warp an image batch with a dense optical-flow field.

    Output pixel (y, x) is bilinearly sampled from ``image`` at
    ``(x + flow[:, 0, y, x], y + flow[:, 1, y, x])``; samples falling
    outside the image read as zero.

    Args:
        image: [B, C, H, W] input images.
        flow: [B, 2, H, W] per-pixel displacement in pixels
            (channel 0 = x displacement, channel 1 = y displacement).

    Returns:
        [B, C, H, W] warped images.
    """
    B, _, H, W = image.shape
    # 'ij' indexing yields [H, W] maps: ys varies along rows, xs along columns.
    # (The default-indexed meshgrid(arange(W), arange(H)) produced a [W, H]
    # grid with x/y swapped, transposing the output.)
    ys, xs = torch.meshgrid(
        torch.arange(H, device=image.device, dtype=image.dtype),
        torch.arange(W, device=image.device, dtype=image.dtype),
        indexing='ij',
    )
    base_grid = torch.stack((xs, ys), dim=0).unsqueeze(0)  # [1, 2, H, W]
    coords = base_grid.expand(B, -1, -1, -1) + flow  # [B, 2, H, W] absolute positions in pixels
    # Normalize to [-1, 1] as grid_sample expects with align_corners=True;
    # max(..., 1) avoids dividing by zero on a 1-pixel axis.
    x_norm = 2.0 * coords[:, 0] / max(W - 1, 1) - 1.0
    y_norm = 2.0 * coords[:, 1] / max(H - 1, 1) - 1.0
    grid = torch.stack((x_norm, y_norm), dim=-1)  # [B, H, W, 2], last dim = (x, y)
    # Bilinear interpolation.
    return F.grid_sample(image, grid, align_corners=True)
# Content-aware temporal normalization loss.
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """Content-aware temporal normalization loss (Eq. 10).

    Measures whether "generate, then warp" agrees with "warp, then generate":
    ``MSE(warp(G(x), F_content), G(warp(x, F_content)))``.

    Args:
        G: generator, called as ``G(image)``.
        x: input infrared batch [B, C, H, W].
        F_content: flow field [B, 2, H, W] produced by the CTN module.

    Returns:
        Scalar mean-squared-error tensor.
    """
    # Generate the visible image first, then move it along the flow.
    generate_then_warp = warp(G(x), F_content)
    # Move the input along the same flow first, then generate.
    warp_then_generate = G(warp(x, F_content))
    return F.mse_loss(generate_then_warp, warp_then_generate)
class ContentAwareOptimization(nn.Module):
    """Content-aware optimization (CAO): builds per-patch weight maps from
    discriminator gradients and uses them to re-weight the adversarial loss.
    """

    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight boost for selected (content-rich) patches
        self.eta_ratio = eta_ratio  # fraction of patches to treat as content-rich

    def compute_cosine_similarity(self, gradients):
        """Cosine similarity of each patch gradient with the mean patch gradient.

        Args:
            gradients: [B, N, D] per-patch gradients of the discriminator (N = h * w).

        Returns:
            cosine_sim: [B, N] cosine similarity of every patch.
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        return F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]

    def generate_weight_map(self, gradients_fake, feature_shape=None):
        """Build the content-aware weight map for one set of patch gradients.

        Args:
            gradients_fake: [B, N, D] per-patch discriminator gradients.
            feature_shape: (H, W) of the patch grid, with H * W == N. When
                omitted, a square grid (H = W = sqrt(N)) is assumed.

        Returns:
            weight: [B, 1, H, W] weight map.

        Raises:
            ValueError: if ``feature_shape`` does not contain exactly N patches.

        NOTE(review): the weights are currently uniform (all ones). Neither
        ``lambda_inc`` nor ``eta_ratio`` is applied yet, so no patch selection
        based on ``compute_cosine_similarity`` takes place.
        """
        B, N = gradients_fake.shape[0], gradients_fake.shape[1]
        if feature_shape is None:
            side = int(math.sqrt(N))
            feature_shape = (side, side)
        H, W = feature_shape
        if H * W != N:
            raise ValueError(f'feature_shape {tuple(feature_shape)} does not match {N} patches')
        return torch.ones(B, 1, H, W, dtype=gradients_fake.dtype, device=gradients_fake.device)

    @staticmethod
    def _to_patch_vectors(grad):
        """Reshape a [B, C, H, W] gradient into [B, H*W, C] per-patch vectors."""
        # flatten+transpose keeps each patch's C channel values together;
        # a plain view(B, H*W, -1) would interleave channels across patches.
        return grad.flatten(2).transpose(1, 2)

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """Compute the content-aware adversarial loss.

        Args:
            D_real: discriminator feature map for real images [B, C, H, W].
            D_fake: discriminator feature map for generated images [B, C, H, W].
            real_scores: discriminator predictions for real images [B, N] (N = H * W).
            fake_scores: discriminator predictions for generated images [B, N].

        Returns:
            loss_co_adv: scalar content-aware adversarial loss. The weight maps
            used are stored as ``self.weight_real`` / ``self.weight_fake``
            ([B, 1, H, W]).
        """
        eps = 1e-8
        log_real = torch.log(real_scores + eps)
        log_fake = torch.log(1 - fake_scores + eps)
        # The tiny dummy term keeps both feature maps in the graph, so their
        # gradients are always defined.
        total_loss = log_real.mean() + log_fake.mean() + eps * (D_real.sum() + D_fake.sum())
        # autograd.grad only reads the feature-map gradients; unlike backward()
        # it does not accumulate anything into parameter .grad fields.
        grad_real, grad_fake = torch.autograd.grad(total_loss, (D_real, D_fake), retain_graph=True)
        gradients_real = self._to_patch_vectors(grad_real.detach())  # [B, N, C]
        gradients_fake = self._to_patch_vectors(grad_fake.detach())  # [B, N, C]

        self.weight_real = self.generate_weight_map(gradients_real, tuple(D_real.shape[-2:]))
        self.weight_fake = self.generate_weight_map(gradients_fake, tuple(D_fake.shape[-2:]))

        # Flatten [B, 1, H, W] -> [B, N] to align with the per-patch scores.
        loss_co_real = torch.mean(self.weight_real.flatten(1) * log_real)
        loss_co_fake = torch.mean(self.weight_fake.flatten(1) * log_fake)
        return -(loss_co_real + loss_co_fake)
class ContentAwareTemporalNorm(nn.Module):
    """Content-aware temporal normalization (CTN): turns a content weight map
    into a smooth, bounded random flow field (Eq. 9).
    """

    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # scales the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def forward(self, weight_map):
        """Generate a content-aware flow field.

        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module).

        Returns:
            F_content: [B, 2, H, W] flow field (x/y displacement), squashed into
            (-1, 1) by tanh.
        """
        B, _, H, W = weight_map.shape
        # 1. L1-normalize the weights over the spatial dims: keeps the relative
        #    strength of regions while bounding the values.
        weight_norm = F.normalize(weight_map, p=1, dim=(2, 3))  # [B, 1, H, W]
        # 2. Gaussian noise with the same size as the flow field.
        z = torch.randn(B, 2, H, W, device=weight_map.device, dtype=weight_map.dtype)  # [B, 2, H, W]
        # 3. Raw flow: x and y channels share the same weight map (Eq. 9).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B, 2, H, W]
        # 4. Smooth each channel independently to keep the flow spatially coherent.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]
        # 5. Bound the flow magnitude to avoid extreme displacements.
        F_content = torch.tanh(F_smooth)
        return F_content
class RomaUnsbModel(BaseModel):
2025-02-22 14:21:54 +08:00
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""配置 CTNx 模型的特定选项"""
2025-02-23 15:22:14 +08:00
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
2025-02-22 14:21:54 +08:00
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
2025-02-23 23:15:25 +08:00
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
2025-02-22 14:21:54 +08:00
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
type=util.str2bool, nargs='?', const=True, default=False,
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
parser.add_argument('--netF_nc', type=int, default=256)
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
parser.add_argument('--lmda_1', type=float, default=0.1)
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
parser.add_argument('--flip_equivariance',
type=util.str2bool, nargs='?', const=True, default=False,
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
2025-02-23 15:46:18 +08:00
2025-02-23 23:15:25 +08:00
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
2025-02-22 14:21:54 +08:00
2025-02-23 18:42:21 +08:00
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
2025-02-22 14:21:54 +08:00
parser.set_defaults(pool_size=0) # no image pooling
opt, _ = parser.parse_known_args()
# 直接设置为 sb 模式
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
return parser
def __init__(self, opt):
"""初始化 CTNx 模型"""
BaseModel.__init__(self, opt)
# 指定需要打印的训练损失
self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
'G_2']
self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.opt.phase == 'test':
self.visual_names = ['real']
for NFE in range(self.opt.num_timesteps):
fake_name = 'fake_' + str(NFE+1)
self.visual_names.append(fake_name)
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
if opt.nce_idt and self.isTrain:
self.loss_names += ['NCE_Y']
self.visual_names += ['idt_B']
if self.isTrain:
2025-02-23 23:15:25 +08:00
self.model_names = ['G', 'D_ViT', 'E']
2025-02-22 14:21:54 +08:00
else:
2025-02-23 15:51:57 +08:00
self.model_names = ['G']
2025-02-23 18:42:21 +08:00
print(f'input_nc = {self.opt.input_nc}')
2025-02-22 14:21:54 +08:00
# 创建网络
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
2025-02-23 23:15:25 +08:00
if self.isTrain:
2025-02-22 14:21:54 +08:00
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
2025-02-23 18:42:21 +08:00
self.resize = tfs.Resize(size=(384,384), antialias=True)
2025-02-22 14:21:54 +08:00
2025-02-23 23:15:25 +08:00
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
2025-02-22 14:21:54 +08:00
# 加入预训练VIT
self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
# 定义损失函数
2025-02-23 23:15:25 +08:00
self.criterionL1 = torch.nn.L1Loss().to(self.device)
2025-02-22 14:21:54 +08:00
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for nce_layer in self.nce_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionIdt = torch.nn.L1Loss().to(self.device)
2025-02-23 22:40:36 +08:00
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
2025-02-23 23:15:25 +08:00
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
2025-02-23 22:40:36 +08:00
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
2025-02-22 14:21:54 +08:00
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
self.ctn = ContentAwareTemporalNorm() #生成的伪光流
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
#bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
#self.set_input(data)
#self.real_A = self.real_A[:bs_per_gpu]
#self.real_B = self.real_B[:bs_per_gpu]
#self.forward() # compute fake images: G(A)
#if self.opt.isTrain:
#
# self.compute_G_loss().backward()
# self.compute_D_loss().backward()
# self.compute_E_loss().backward()
# if self.opt.lambda_NCE > 0.0:
# self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
# self.optimizers.append(self.optimizer_F)
pass
def optimize_parameters(self):
# forward
self.forward()
self.netG.train()
self.netE.train()
2025-02-23 23:15:25 +08:00
self.netD_ViT.train()
2025-02-22 14:21:54 +08:00
# update D
2025-02-23 23:15:25 +08:00
self.set_requires_grad(self.netD_ViT, True)
2025-02-22 14:21:54 +08:00
self.optimizer_D.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D.step()
2025-02-22 15:23:52 +08:00
# update E
2025-02-22 14:21:54 +08:00
self.set_requires_grad(self.netE, True)
self.optimizer_E.zero_grad()
self.loss_E = self.compute_E_loss()
self.loss_E.backward()
self.optimizer_E.step()
# update G
2025-02-23 23:15:25 +08:00
self.set_requires_grad(self.netD_ViT, False)
2025-02-22 14:21:54 +08:00
self.set_requires_grad(self.netE, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def forward(self):
2025-02-23 22:26:04 +08:00
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
2025-02-22 14:21:54 +08:00
2025-02-23 22:26:04 +08:00
# ============ 第一步:对 real_A / real_A2 进行多步随机生成过程 ============
2025-02-22 14:21:54 +08:00
tau = self.opt.tau
T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
times = np.cumsum(incs)
times = times / times[-1]
times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
times = np.concatenate([np.zeros(1), times])
times = torch.tensor(times).float().cuda()
self.times = times
2025-02-23 22:26:04 +08:00
bs = self.real_A0.size(0)
2025-02-22 14:21:54 +08:00
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
self.time_idx = time_idx
with torch.no_grad():
self.netG.eval()
# ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============
for t in range(self.time_idx.int().item() + 1):
# 计算增量 delta 与 inter/scale用于每个时间步的插值等
if t > 0:
delta = times[t] - times[t - 1]
denom = times[-1] - times[t - 1]
inter = (delta / denom).reshape(-1, 1, 1, 1)
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
# 对 Xt、Xt2 进行随机噪声更新
2025-02-23 22:26:04 +08:00
Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
2025-02-22 14:21:54 +08:00
self.time = times[time_idx]
Xt_1 = self.netG(Xt, self.time, z)
2025-02-23 22:26:04 +08:00
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
2025-02-22 14:21:54 +08:00
Xt_12 = self.netG(Xt2, self.time, z)
# 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接
self.real_A_noisy = Xt.detach()
self.real_A_noisy2 = Xt2.detach()
# ============ 第三步:拼接输入并执行网络推理 =============
2025-02-23 22:26:04 +08:00
bs = self.real_A0.size(0)
2025-02-23 23:15:25 +08:00
z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
2025-02-23 22:26:04 +08:00
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
2025-02-22 14:21:54 +08:00
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
2025-02-23 22:26:04 +08:00
self.real = self.real_A0
2025-02-22 14:21:54 +08:00
self.realt = self.real_A_noisy
if self.opt.flip_equivariance:
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
if self.flipped_for_equivariance:
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
2025-02-23 23:15:25 +08:00
print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
2025-02-23 22:26:04 +08:00
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
2025-02-23 23:15:25 +08:00
print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
2025-02-22 14:21:54 +08:00
if self.opt.phase == 'train':
2025-02-23 22:26:04 +08:00
real_A0 = self.real_A0
real_A1 = self.real_A1
real_B0 = self.real_B0
real_B1 = self.real_B1
fake_B0 = self.fake_B0
fake_B1 = self.fake_B1
self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1)
real_B0 = self.resize(real_B0)
real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
2025-02-23 22:40:34 +08:00
# [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768]
2025-02-24 20:39:59 +08:00
# 生成图像的梯度
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
# 梯度图
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
2025-02-23 22:40:34 +08:00
2025-02-24 20:39:59 +08:00
# warped_fake_B0_2=self.warped_fake_B0_2
# warped_fake_B0=self.warped_fake_B0
# self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
# self.warped_fake_B0_resize = self.resize(warped_fake_B0)
# self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
# self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
2025-02-23 22:40:34 +08:00
2025-02-22 14:21:54 +08:00
2025-02-23 23:15:25 +08:00
def compute_D_loss(self): #判别器还是没有改
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
real_B0_tokens = self.mutil_real_B0_tokens[0]
real_B1_tokens = self.mutil_real_B1_tokens[0]
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
return self.loss_D_ViT
2025-02-22 14:21:54 +08:00
def compute_E_loss(self):
"""计算判别器 E 的损失"""
2025-02-23 23:15:25 +08:00
print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
2025-02-23 22:40:34 +08:00
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
2025-02-22 14:21:54 +08:00
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
return self.loss_E
def compute_G_loss(self):
"""计算生成器的 GAN 损失"""
if self.opt.lambda_GAN > 0.0:
2025-02-23 23:15:25 +08:00
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
2025-02-22 14:21:54 +08:00
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
else:
self.loss_G_GAN = 0.0
self.loss_SB = 0
if self.opt.lambda_SB > 0.0:
2025-02-23 22:40:34 +08:00
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
2025-02-22 14:21:54 +08:00
bs = self.opt.batch_size
# eq.9
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
2025-02-23 23:15:25 +08:00
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
2025-02-22 14:21:54 +08:00
if self.opt.lambda_global > 0.0:
2025-02-23 22:40:34 +08:00
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
2025-02-22 14:21:54 +08:00
loss_global *= 0.5
else:
loss_global = 0.0
2025-02-23 22:40:36 +08:00
self.l2_loss = 0.0
2025-02-24 20:39:59 +08:00
if self.opt.lambda_l2 > 0.0:
wapped_fake_B = warp(self.fake_B0, self.f_content) # use updated self.f_content
self.l2_loss = F.mse_loss(self.warped_fake_B0_2, wapped_fake_B) # complete the loss calculation
2025-02-22 14:21:54 +08:00
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
loss_global *= 0.5
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss