import math

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from torchvision.transforms import transforms as tfs

import util.util as util

from . import networks
from .base_model import BaseModel
from .patchnce import PatchNCELoss


def warp(image, flow):
    """
    Flow-based image warping.

    Args:
        image: [B, C, H, W] input image
        flow: [B, 2, H, W] optical flow field (x/y displacements, in pixels)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape

    # Build the pixel-coordinate grid (channel 0 = x, channel 1 = y).
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 2, H, W]

    # Apply the flow displacement and normalize coordinates to [-1, 1].
    new_grid = grid + flow
    new_grid[:, 0, :, :] = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
    new_grid[:, 1, :, :] = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
    new_grid = new_grid.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Bilinear sampling.
    return F.grid_sample(image, new_grid, align_corners=True)

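
# A minimal sanity-check sketch (not part of the original model): with an all-zero
# flow, `warp` should reproduce its input up to interpolation precision. Shapes and
# tolerance below are illustrative assumptions.
def _demo_warp_identity():
    img = torch.rand(2, 3, 64, 64)
    zero_flow = torch.zeros(2, 2, 64, 64)
    out = warp(img, zero_flow)
    assert torch.allclose(out, img, atol=1e-4)
    return out
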
# Content-aware temporal normalization loss (Eq. 10).
def compute_ctn_loss(G, x, F_content):
    """
    Compute the content-aware temporal normalization loss.

    Args:
        G: generator
        x: input infrared image [B, C, H, W]
        F_content: generated optical flow field [B, 2, H, W]
    """
    # Generate the visible-light image.
    y_fake = G(x)  # [B, 3, H, W]

    # Warp the generated result with the flow.
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]

    # Warp the input with the same flow, then generate from the warped input.
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)  # [B, 3, H, W]

    # L2 loss between "generate then warp" and "warp then generate".
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss

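
# A hedged usage sketch (not part of the original pipeline): `compute_ctn_loss` only
# needs a callable that maps an image to an image, so a single conv layer stands in
# for the real generator here, purely for illustration.
def _demo_ctn_loss():
    G = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # stand-in generator (assumption)
    x = torch.rand(1, 3, 64, 64)
    flow = torch.randn(1, 2, 64, 64)  # ~1-pixel random displacements
    return compute_ctn_loss(G, x, flow)
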
class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight-amplification coefficient
        self.eta_ratio = eta_ratio    # fraction of patches treated as content-rich

    def compute_cosine_similarity(self, gradients):
        """
        Cosine similarity between each patch gradient and the mean gradient.

        Args:
            gradients: [B, N, D] per-patch gradients from the discriminator (N = w*h)
        Returns:
            cosine_sim: [B, N] cosine similarity of each patch
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
        return cosine_sim

    def generate_weight_map(self, gradients_fake):
        """
        Generate the content-aware weight map.

        Args:
            gradients_fake: [B, N, D] discriminator gradients for the generated image
        Returns:
            weight_fake: [B, N] weight map for the generated image
        """
        # Cosine similarity of the generated-image patches.
        cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]

        # Select the content-rich regions (the eta_ratio fraction with the lowest similarity).
        k = int(self.eta_ratio * cosine_fake.shape[1])

        # Up-weight the selected patches; all other patches keep weight 1.
        _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
        weight_fake = torch.ones_like(cosine_fake)
        for b in range(cosine_fake.shape[0]):
            weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))

        return weight_fake

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """
        Compute the content-aware adversarial loss.

        Args:
            D_real: discriminator feature output for real images [B, C, H, W]
            D_fake: discriminator feature output for generated images [B, C, H, W]
            real_scores: discriminator predictions for real images [B, N] (N = H*W)
            fake_scores: discriminator predictions for generated images [B, N]
        Returns:
            loss_co_adv: content-aware adversarial loss
        """
        B, C, H, W = D_real.shape
        N = H * W

        # Register hooks to capture the gradients.
        gradients_real = []
        gradients_fake = []

        def hook_real(grad):
            # [B, C, H, W] -> [B, N, C]: one gradient vector per spatial location.
            gradients_real.append(grad.detach().permute(0, 2, 3, 1).reshape(B, N, -1))

        def hook_fake(grad):
            gradients_fake.append(grad.detach().permute(0, 2, 3, 1).reshape(B, N, -1))

        D_real.register_hook(hook_real)
        D_fake.register_hook(hook_fake)

        # Compute the plain adversarial loss to trigger gradient computation.
        loss_real = torch.mean(torch.log(real_scores + 1e-8))
        loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
        # Add a dummy term tied to D_real / D_fake so gradients reach the hooked tensors.
        loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
        total_loss = loss_real + loss_fake + loss_dummy
        total_loss.backward(retain_graph=True)

        # Retrieve the captured gradients.
        gradients_real = gradients_real[0]  # [B, N, D]
        gradients_fake = gradients_fake[0]  # [B, N, D]

        # Build one weight map per input (generate_weight_map handles a single tensor).
        self.weight_real = self.generate_weight_map(gradients_real)
        self.weight_fake = self.generate_weight_map(gradients_fake)

        # Apply the weights to the adversarial loss.
        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))

        # Final content-aware adversarial loss.
        loss_co_adv = -(loss_co_real + loss_co_fake)

        return loss_co_adv

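
# A small toy sketch (illustration only): feeding random per-patch "gradients" through
# generate_weight_map shows the intended behavior -- roughly eta_ratio of the patches
# are up-weighted by lambda_inc / |cos|, while the rest keep weight 1. The shapes are
# assumptions chosen for the example.
def _demo_weight_map():
    cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
    toy_grads = torch.randn(2, 16, 8)             # [B, N, D] toy per-patch gradients
    weights = cao.generate_weight_map(toy_grads)  # [B, N]
    return weights
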
class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def forward(self, weight_map):
        """
        Generate a content-aware optical flow.

        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module)
        Returns:
            F_content: [B, 2, H, W] generated optical flow field (x/y displacements)
        """
        print(weight_map.shape)
        B, _, H, W = weight_map.shape

        # 1. Normalize the weight map.
        # Preserves the relative strength of regions while bounding the value range.
        weight_norm = F.normalize(weight_map, p=1, dim=(2, 3))  # L1 normalization [B, 1, H, W]

        # 2. Sample Gaussian noise with the same size as the flow field.
        z = torch.randn(B, 2, H, W, device=weight_map.device)  # [B, 2, H, W]

        # 3. Compose the raw flow (Eq. 9).
        # Expand the weight map to 2 channels (x/y directions share the same weights).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B, 2, H, W]

        # 4. Smooth the flow to keep it structurally continuous.
        # The Gaussian blur is applied to each channel independently.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]

        # 5. Optional dynamic-range adjustment.
        # Bound the flow magnitude to avoid extreme displacements.
        F_content = torch.tanh(F_smooth)  # scaled into [-1, 1]

        return F_content

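
# An end-to-end sketch under stated assumptions: a random weight map stands in for the
# output of the content-aware optimization module, the resulting pseudo flow is then
# applied with `warp`. This only illustrates tensor shapes and the intended data flow
# (it assumes a PyTorch version where F.normalize accepts a tuple dim, as used above).
def _demo_content_flow():
    ctn = ContentAwareTemporalNorm(gamma_stride=0.1, kernel_size=21, sigma=5.0)
    weight_map = torch.rand(1, 1, 64, 64)  # assumed [B, 1, H, W] weight map
    flow = ctn(weight_map)                 # [B, 2, H, W], values in (-1, 1)
    img = torch.rand(1, 3, 64, 64)
    return warp(img, flow)
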
class RomaUnsbModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure options specific to the CTNx model."""

        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
        parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')

        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y)')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')

        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')

        parser.add_argument('--lmda_1', type=float, default=0.1)
        parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")

        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')

        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')

        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='number of timesteps in the multi-step generation process')

        parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')

        parser.set_defaults(pool_size=0)  # no image pooling

        opt, _ = parser.parse_known_args()

        # Default to SB mode.
        parser.set_defaults(nce_idt=True, lambda_NCE=1.0)

        return parser

    def __init__(self, opt):
        """Initialize the CTNx model."""
        BaseModel.__init__(self, opt)

        # Specify the training losses to print.
        self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
                           'G_2']
        self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

        if self.opt.phase == 'test':
            self.visual_names = ['real']
            for NFE in range(self.opt.num_timesteps):
                fake_name = 'fake_' + str(NFE + 1)
                self.visual_names.append(fake_name)
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if opt.nce_idt and self.isTrain:
            self.loss_names += ['NCE_Y']
            self.visual_names += ['idt_B']

        if self.isTrain:
            self.model_names = ['G', 'D_ViT', 'E']
        else:
            self.model_names = ['G']

        print(f'input_nc = {self.opt.input_nc}')
        # Create the networks.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            self.resize = tfs.Resize(size=(384, 384), antialias=True)

            self.netD_ViT = networks.MLPDiscriminator().to(self.device)

            # Pretrained ViT used as the token extractor.
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

            # Define the loss functions.
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = []
            for nce_layer in self.nce_layers:
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]

            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
            self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical flow

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        #bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
        #self.set_input(data)
        #self.real_A = self.real_A[:bs_per_gpu]
        #self.real_B = self.real_B[:bs_per_gpu]
        #self.forward()  # compute fake images: G(A)
        #if self.opt.isTrain:
        #    self.compute_G_loss().backward()
        #    self.compute_D_loss().backward()
        #    self.compute_E_loss().backward()
        #    if self.opt.lambda_NCE > 0.0:
        #        self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
        #        self.optimizers.append(self.optimizer_F)
        pass

    def optimize_parameters(self):
        # forward
        self.forward()

        self.netG.train()
        self.netE.train()
        self.netD_ViT.train()

        # update D
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update E
        self.set_requires_grad(self.netE, True)
        self.optimizer_E.zero_grad()
        self.loss_E = self.compute_E_loss()
        self.loss_E.backward()
        self.optimizer_E.step()

        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.set_requires_grad(self.netE, False)

        self.optimizer_G.zero_grad()

        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def tokens_concat(self, origin_tokens, adjacent_size):
        adj_size = adjacent_size
        B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
        S = int(math.sqrt(token_num))
        if S * S != token_num:
            print('Error! Not a square!')
        token_map = origin_tokens.clone().reshape(B, S, S, C)
        cut_patch_list = []
        for i in range(0, S, adj_size):
            for j in range(0, S, adj_size):
                # Take non-overlapping adj_size x adj_size blocks (clamped at the border)
                # and average the tokens within each block.
                i_left = i
                i_right = i + adj_size if i + adj_size <= S else S
                j_left = j
                j_right = j + adj_size if j + adj_size <= S else S

                cut_patch = token_map[:, i_left:i_right, j_left:j_right, :]
                cut_patch = cut_patch.reshape(B, -1, C)
                cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
                cut_patch_list.append(cut_patch)

        result = torch.cat(cut_patch_list, dim=1)
        return result

    def cat_results(self, origin_tokens, adj_size_list):
        res_list = [origin_tokens]
        for ad_s in adj_size_list:
            cat_result = self.tokens_concat(origin_tokens, ad_s)
            res_list.append(cat_result)

        result = torch.cat(res_list, dim=1)

        return result

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""

        # ============ Step 1: build the timestep schedule and sample a timestep ============
        tau = self.opt.tau
        T = self.opt.num_timesteps
        incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
        times = np.cumsum(incs)
        times = times / times[-1]
        times = 0.5 * times[-1] + 0.5 * times  # squeeze the schedule into [0.5, 1]
        times = np.concatenate([np.zeros(1), times])
        times = torch.tensor(times).float().cuda()
        self.times = times
        bs = self.real_A0.size(0)
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
        self.time_idx = time_idx

        with torch.no_grad():
            self.netG.eval()
            # ============ Step 2: multi-step stochastic generation for real_A0 / real_A1 ============
            for t in range(self.time_idx.int().item() + 1):
                # Compute the increment delta and the inter/scale factors used for the
                # interpolation at each timestep.
                if t > 0:
                    delta = times[t] - times[t - 1]
                    denom = times[-1] - times[t - 1]
                    inter = (delta / denom).reshape(-1, 1, 1, 1)
                    scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)

                # Stochastic noise update of Xt and Xt2.
                Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
                time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
                z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
                self.time = times[time_idx]
                Xt_1 = self.netG(Xt, self.time, z)

                Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
                time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
                z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
                Xt_12 = self.netG(Xt2, self.time, z)

            # Keep the noisy intermediate results (real_A_noisy etc.) for the next stage.
            self.real_A_noisy = Xt.detach()
            self.real_A_noisy2 = Xt2.detach()

        # ============ Step 3: assemble the inputs and run the networks ============
        bs = self.real_A0.size(0)
        z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
        z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
        # Concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way.
        self.real = self.real_A0
        self.realt = self.real_A_noisy

        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])
                self.realt = torch.flip(self.realt, [3])

        print(f'real_A0: {self.real_A0.shape}, real_A1: {self.real_A1.shape}')
        self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
        self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
        print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')

        if self.opt.phase == 'train':
            real_A0 = self.real_A0
            real_A1 = self.real_A1
            real_B0 = self.real_B0
            real_B1 = self.real_B1
            fake_B0 = self.fake_B0
            fake_B1 = self.fake_B1

            self.real_A0_resize = self.resize(real_A0)
            self.real_A1_resize = self.resize(real_A1)
            real_B0 = self.resize(real_B0)
            real_B1 = self.resize(real_B1)
            self.fake_B0_resize = self.resize(fake_B0)
            self.fake_B1_resize = self.resize(fake_B1)

            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
            # Each result is a list of per-layer token tensors, e.g. [[1, 576, 768], [1, 576, 768], ...]

            ## Gradients of the generated image
            #fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
            #
            ## Weight map
            #self.weight_fake = self.cao.generate_weight_map(fake_gradient)
            #
            ## CTN optical flow for the generated image
            #self.f_content = self.ctn(self.weight_fake)
            #
            ## Warped images
            #self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
            #self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
            #
            ## Second pass through the generator
            #self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)

            #warped_fake_B0_2 = self.warped_fake_B0_2
            #warped_fake_B0 = self.warped_fake_B0
            #self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
            #self.warped_fake_B0_resize = self.resize(warped_fake_B0)
            #self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
            #self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)

    def compute_D_loss(self):  # the discriminator itself is unchanged
        """Calculate GAN loss for the discriminator."""

        lambda_D_ViT = self.opt.lambda_D_ViT
        fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
        fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()

        real_B0_tokens = self.mutil_real_B0_tokens[0]
        real_B1_tokens = self.mutil_real_B1_tokens[0]

        pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
        pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)

        self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT

        pred_real0_ViT = self.netD_ViT(real_B0_tokens)
        pred_real1_ViT = self.netD_ViT(real_B1_tokens)
        self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT

        self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5

        return self.loss_D_ViT

    def compute_E_loss(self):
        """Compute the loss for the discriminator E."""

        print(f'real_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
        XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
        XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
        temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
        self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2

        return self.loss_E

    def compute_G_loss(self):
        """Compute the generator loss (GAN + SB + global structural consistency)."""

        if self.opt.lambda_GAN > 0.0:
            pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
        else:
            self.loss_G_GAN = 0.0
        self.loss_SB = 0
        if self.opt.lambda_SB > 0.0:
            XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
            XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)

            bs = self.opt.batch_size

            # eq.9
            ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
            self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)

        if self.opt.lambda_global > 0.0:
            # Global structural consistency is computed on the ViT tokens
            # (calculate_similarity expects per-layer token lists, as in calculate_attention_loss).
            loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        self.l2_loss = 0.0
        #if self.opt.lambda_ctn > 0.0:
        #    warped_fake_B = warp(self.fake_B, self.f_content)         # use the updated self.f_content
        #    self.l2_loss = F.mse_loss(self.fake_B_2, warped_fake_B)   # complete the loss calculation

        self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
        return self.loss_G

    def calculate_attention_loss(self):
        n_layers = len(self.atten_layers)
        mutil_real_A0_tokens = self.mutil_real_A0_tokens
        mutil_real_A1_tokens = self.mutil_real_A1_tokens
        mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
        mutil_fake_B1_tokens = self.mutil_fake_B1_tokens

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            loss_spatial = 0.0
            local_nums = self.opt.local_nums
            tokens_cnt = 576
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:int(min(local_nums, tokens_cnt))]

            mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0

        return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial

    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)

        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()

        loss = loss / n_layers
        return loss