import numpy as np
import math
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
from torchvision.transforms import transforms as tfs


def warp(image, flow):
    """
    Optical-flow-based image warping.
    Args:
        image: [B, C, H, W] input image
        flow:  [B, 2, H, W] flow field (x / y displacement)
    Returns:
        warped: [B, C, H, W] warped image
    """
    B, C, H, W = image.shape
    # Build the grid of pixel coordinates; channel 0 holds x (width) and
    # channel 1 holds y (height), matching grid_sample's convention.
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)                           # [B, 2, H, W]

    # Add the flow displacement and normalize coordinates to [-1, 1].
    new_grid = grid + flow
    new_grid[:, 0, :, :] = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
    new_grid[:, 1, :, :] = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
    new_grid = new_grid.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Bilinear sampling.
    return F.grid_sample(image, new_grid, align_corners=True)
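
# Minimal sanity-check sketch (illustration only, shapes assumed): with an
# all-zero flow field, `warp` should return the input image up to interpolation
# error, since the sampling grid reduces to the identity mapping.
#
#   img  = torch.rand(1, 3, 64, 64)
#   flow = torch.zeros(1, 2, 64, 64)
#   out  = warp(img, flow)
#   assert torch.allclose(out, img, atol=1e-5)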


# Temporal-normalization loss
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """
    Compute the content-aware temporal normalization loss.
    Args:
        G: generator
        x: input infrared image [B, C, H, W]
        F_content: generated flow field [B, 2, H, W]
    """
    # Generate the visible-light image.
    y_fake = G(x)  # [B, 3, H, W]
    # Warp the generated result with the flow.
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]
    # Warp the input with the same flow, then generate again.
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)    # [B, 3, H, W]
    # L2 loss between "generate-then-warp" and "warp-then-generate".
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss


class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # weight amplification coefficient
        self.eta_ratio = eta_ratio    # fraction of patches selected as content-rich regions

        # Kept as member variables so the backward hooks can access them.
        self.gradients_real = []
        self.gradients_fake = []

    def compute_cosine_similarity(self, gradients):
        """
        Cosine similarity between each patch gradient and the mean gradient.
        Args:
            gradients: [B, N, D] per-patch gradients of the discriminator output (N = w * h)
        Returns:
            cosine_sim: [B, N] cosine similarity of each patch
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)         # [B, 1, D]
        cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
        return cosine_sim

    def generate_weight_map(self, gradients_real, gradients_fake):
        """
        Generate the content-aware weight maps.
        Args:
            gradients_real: [B, N, D] discriminator gradients for real images
            gradients_fake: [B, N, D] discriminator gradients for generated images
        Returns:
            weight_real: [B, N] weight map for real images
            weight_fake: [B, N] weight map for generated images
        """
        # Cosine similarity of the real patches (Eq. 5).
        cosine_real = self.compute_cosine_similarity(gradients_real)  # [B, N]
        # Cosine similarity of the generated patches.
        cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]

        # Select content-rich regions: the eta_ratio fraction with the lowest cosine similarity.
        k = int(self.eta_ratio * cosine_real.shape[1])

        # Weight map for real images.
        _, real_indices = torch.topk(-cosine_real, k, dim=1)  # least similar patches
        weight_real = torch.ones_like(cosine_real)
        for b in range(cosine_real.shape[0]):
            weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]]))  # Eq. 6

        # Weight map for generated images (same procedure).
        _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
        weight_fake = torch.ones_like(cosine_fake)
        for b in range(cosine_fake.shape[0]):
            weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))

        return weight_real, weight_fake
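
    # Note: the per-sample loop above can also be written in a vectorized form.
    # A hedged sketch (equivalent under the same shapes, not the reference
    # implementation):
    #
    #   _, idx = torch.topk(-cosine_real, k, dim=1)               # [B, k]
    #   amplified = self.lambda_inc / (1e-6 + cosine_real.abs())  # [B, N]
    #   weight_real = torch.ones_like(cosine_real)
    #   weight_real.scatter_(1, idx, amplified.gather(1, idx))    # amplify selected patches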

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        # Clear the gradient caches.
        self.gradients_real.clear()
        self.gradients_fake.clear()

        # Register backward hooks on the discriminator outputs.
        hook_real = lambda grad: self.gradients_real.append(grad.detach())
        hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
        D_real.register_hook(hook_real)
        D_fake.register_hook(hook_fake)

        # Trigger the gradient computation.
        (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)

        # Fetch the captured gradients.
        grad_real = self.gradients_real[0]  # [B, N, D]
        grad_fake = self.gradients_fake[0]

        # Generate the weight maps.
        weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)

        # Weighted adversarial loss.
        loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean()
        loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean()

        return -(loss_co_real + loss_co_fake), weight_real, weight_fake
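
# Hedged usage sketch (names are illustrative): in this model the module is
# invoked from compute_D_loss with per-token discriminator predictions and the
# corresponding GAN-loss scores.
#
#   cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
#   loss_cao, w_real, w_fake = cao(pred_real, pred_fake, score_real, score_fake)
#   # w_fake is later consumed by ContentAwareTemporalNorm to synthesize a flow field.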


class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
        """
        Upsample the patch-level weight map to the target resolution.
        Args:
            weight_patch: [B, 1, 24, 24] patch weight map from the ViT
            target_size: target resolution (H, W)
        Returns:
            weight_full: [B, 1, 256, 256] full-resolution weight map
        """
        # Bilinear upsampling.
        B = weight_patch.shape[0]
        weight_patch = weight_patch.view(B, 1, 24, 24)
        weight_full = F.interpolate(
            weight_patch,
            size=target_size,
            mode='bilinear',
            align_corners=False
        )
        # Keep the weight constant inside each 16x16 patch (optional):
        # average-pool and re-expand to remove the gradient introduced by interpolation.
        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
        return weight_full
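
    # Shape walk-through (assumed defaults, for illustration): a [B, 576] token
    # weight map is viewed as [B, 1, 24, 24], bilinearly resized to
    # [B, 1, 256, 256], then average-pooled with a 16x16 window back to
    # [B, 1, 16, 16] and nearest-upsampled by 16 to [B, 1, 256, 256], so each
    # 16x16 block carries a single constant weight.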

    def forward(self, weight_map):
        """
        Generate the content-aware optical flow.
        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module)
        Returns:
            F_content: [B, 2, H, W] generated flow field (x / y displacement)
        """
        # Upsample the weight map to full resolution.
        weight_full = self.upsample_weight_map(weight_map)  # [B, 1, 256, 256]

        # 1. Normalize the weight map:
        #    keep the relative strength of each region while bounding the value range.
        weight_norm = F.normalize(weight_full, p=1, dim=(2, 3))  # L1 normalization, [B, 1, H, W]

        # 2. Sample Gaussian noise.
        B, _, H, W = weight_norm.shape
        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B, 2, H, W]

        # 3. Synthesize the raw flow:
        #    expand the weight map to two channels (x / y directions share the weights).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B, 2, H, W]
        F_raw = self.gamma_stride * weight_expanded * z      # [B, 2, H, W], Eq. 9

        # 4. Smooth the flow (preserves structural continuity):
        #    Gaussian blur applied to each channel independently.
        F_smooth = self.smoother(F_raw)  # [B, 2, H, W]

        # 5. Optional dynamic-range adjustment:
        #    bound the flow magnitude to avoid extreme displacements.
        F_content = torch.tanh(F_smooth)  # squashed to [-1, 1]
        return F_content
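
# Hedged end-to-end sketch (illustrative shapes): the per-token weight map from
# ContentAwareOptimization is turned into a smooth pseudo optical flow, which is
# then used by `warp` to deform images for the temporal-consistency loss.
#
#   ctn = ContentAwareTemporalNorm()
#   w_fake = torch.rand(1, 576)              # [B, N] token weights (assumed)
#   flow = ctn(w_fake)                       # [B, 2, 256, 256]
#   img = torch.rand(1, 3, 256, 256)
#   warped = warp(img, flow)                 # [B, 3, 256, 256]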


class RomaUnsbModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure options specific to the CTNx model."""
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='number of timesteps in the multi-step generation process')
        parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')

        opt, _ = parser.parse_known_args()
        return parser

    def __init__(self, opt):
        """Initialize the CTNx model."""
        BaseModel.__init__(self, opt)
        # Training losses to print.
        self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
        self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
        if self.opt.phase == 'test':
            self.visual_names = ['real']
            for NFE in range(self.opt.num_timesteps):
                fake_name = 'fake_' + str(NFE + 1)
                self.visual_names.append(fake_name)

        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
        if opt.nce_idt and self.isTrain:
            self.loss_names += ['NCE_Y']
            self.visual_names += ['idt_B']

        if self.isTrain:
            self.model_names = ['G', 'D_ViT', 'E']
        else:
            self.model_names = ['G']

        # Create the networks.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.netE = networks.define_D(opt.output_nc * 4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
            self.resize = tfs.Resize(size=(384, 384), antialias=True)
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            # Pretrained ViT used as a token extractor.
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

            # Loss functions.
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionIdt = torch.nn.L1Loss().to(self.device)

            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]

            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
            self.ctn = ContentAwareTemporalNorm()  # generator of the pseudo optical flow

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        # bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
        # self.set_input(data)
        # self.real_A = self.real_A[:bs_per_gpu]
        # self.real_B = self.real_B[:bs_per_gpu]
        # self.forward()                       # compute fake images: G(A)
        # if self.opt.isTrain:
        #     self.compute_G_loss().backward()
        #     self.compute_D_loss().backward()
        #     self.compute_E_loss().backward()
        #     if self.opt.lambda_NCE > 0.0:
        #         self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
        #         self.optimizers.append(self.optimizer_F)
        pass

    def optimize_parameters(self):
        # forward
        self.forward()
        self.netG.train()
        self.netE.train()
        self.netD_ViT.train()

        # update D
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update E
        self.set_requires_grad(self.netE, True)
        self.optimizer_E.zero_grad()
        self.loss_E = self.compute_E_loss()
        self.loss_E.backward()
        self.optimizer_E.step()

        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.set_requires_grad(self.netE, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def tokens_concat(self, origin_tokens, adjacent_size):
        adj_size = adjacent_size
        B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
        S = int(math.sqrt(token_num))
        if S * S != token_num:
            print('Error! Not a square!')
        token_map = origin_tokens.clone().reshape(B, S, S, C)
        cut_patch_list = []
        # Average the tokens inside each adj_size x adj_size window.
        for i in range(0, S, adj_size):
            for j in range(0, S, adj_size):
                i_left = i
                i_right = i + adj_size if i + adj_size <= S else S
                j_left = j
                j_right = j + adj_size if j + adj_size <= S else S
                cut_patch = token_map[:, i_left:i_right, j_left:j_right, :]
                cut_patch = cut_patch.reshape(B, -1, C)
                cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
                cut_patch_list.append(cut_patch)
        result = torch.cat(cut_patch_list, dim=1)
        return result
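
    # Illustrative shape note (hedged): for ViT-B/16 at 384x384 the token grid is
    # 24x24 (576 tokens), so tokens_concat with adjacent_size=2 pools a
    # [B, 576, C] token tensor into a [B, 144, C] summary, one averaged token
    # per 2x2 window.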

    def cat_results(self, origin_tokens, adj_size_list):
        res_list = [origin_tokens]
        for ad_s in adj_size_list:
            cat_result = self.tokens_concat(origin_tokens, ad_s)
            res_list.append(cat_result)
        result = torch.cat(res_list, dim=1)
        return result

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # ============ Step 1: set up the time schedule and sample a time step ============
        tau = self.opt.tau
        T = self.opt.num_timesteps
        incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
        times = np.cumsum(incs)
        times = times / times[-1]
        times = 0.5 * times[-1] + 0.5 * times  # rescale to [0.5, 1]
        times = np.concatenate([np.zeros(1), times])
        times = torch.tensor(times).float().cuda()
        self.times = times

        bs = self.real_A0.size(0)
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
        self.time_idx = time_idx

        with torch.no_grad():
            self.netG.eval()
            # ============ Step 2: multi-step stochastic generation for real_A0 / real_A1 ============
            for t in range(self.time_idx.int().item() + 1):
                # Compute the increment delta and the inter / scale factors used for
                # the interpolation at each time step.
                if t > 0:
                    delta = times[t] - times[t - 1]
                    denom = times[-1] - times[t - 1]
                    inter = (delta / denom).reshape(-1, 1, 1, 1)
                    scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)

                # Stochastic noise update of Xt and Xt2.
                Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
                time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
                z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
                self.time = times[time_idx]
                Xt_1 = self.netG(Xt, self.time, z)

                Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
                time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
                z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
                Xt_12 = self.netG(Xt2, self.time, z)

            # Keep the noised intermediate results (real_A_noisy etc.) for the next stage.
            self.real_A_noisy = Xt.detach()
            self.real_A_noisy2 = Xt2.detach()

        # ============ Step 3: assemble the inputs and run the generator ============
        bs = self.real_A0.size(0)
        self.z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
        self.z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)

        # Concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way.
        self.real = self.real_A0
        self.realt = self.real_A_noisy

        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])
                self.realt = torch.flip(self.realt, [3])

        self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in)
        self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2)

        if self.opt.phase == 'train':
            real_A0 = self.real_A0
            real_A1 = self.real_A1
            real_B0 = self.real_B0
            real_B1 = self.real_B1
            fake_B0 = self.fake_B0
            fake_B1 = self.fake_B1

            self.real_A0_resize = self.resize(real_A0)
            self.real_A1_resize = self.resize(real_A1)
            real_B0 = self.resize(real_B0)
            real_B1 = self.resize(real_B1)
            self.fake_B0_resize = self.resize(fake_B0)
            self.fake_B1_resize = self.resize(fake_B1)

            # Multi-layer ViT tokens: each call returns a list with one
            # [1, 576, 768] tensor per attention layer.
            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)

    def compute_D_loss(self):  # note: the discriminator itself is unchanged
        """Calculate GAN loss for the discriminator."""
        lambda_D_ViT = self.opt.lambda_D_ViT
        fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
        real_B0_tokens = self.mutil_real_B0_tokens[0]

        pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
        self.loss_D_fake_ViT = self.criterionGAN(pre_fake0_ViT, False)

        pred_real0_ViT = self.netD_ViT(real_B0_tokens)
        self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)

        # Content-aware weighted adversarial loss; the weight maps are reused in compute_G_loss.
        self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
        return self.losscao * lambda_D_ViT

    def compute_E_loss(self):
        """Compute the loss for the discriminator E."""
        XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
        XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
        temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
        self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp ** 2
        return self.loss_E

    def compute_G_loss(self):
        """Calculate the losses for the generator."""
        if self.opt.lambda_ctn > 0.0:
            # CTN flow field generated from the fake-image weight map.
            self.f_content = self.ctn(self.weight_fake)
            # Warped images.
            self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
            self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
            # Second pass through the generator on the warped input.
            self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, self.z_in)
            warped_fake_B0_2 = self.warped_fake_B0_2
            warped_fake_B0 = self.warped_fake_B0
            # L2 loss between "warp-then-generate" and "generate-then-warp".
            self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
        else:
            self.loss_ctn = 0.0

        if self.opt.lambda_GAN > 0.0:
            pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean()
        else:
            self.loss_G_GAN = 0.0

        self.loss_SB = 0
        if self.opt.lambda_SB > 0.0:
            XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
            XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
            bs = self.opt.batch_size
            # eq. 9
            ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - self.netE(XtXt_1, self.time, XtXt_2).mean()
            self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
            self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)

        if self.opt.lambda_global > 0.0:
            self.loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
            self.loss_global *= 0.5
        else:
            self.loss_global = 0.0

        self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
            self.opt.lambda_SB * self.loss_SB + \
            self.opt.lambda_ctn * self.loss_ctn + \
            self.loss_global * self.opt.lambda_global
        return self.loss_G

    def calculate_attention_loss(self):
        n_layers = len(self.atten_layers)
        mutil_real_A0_tokens = self.mutil_real_A0_tokens
        mutil_real_A1_tokens = self.mutil_real_A1_tokens
        mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
        mutil_fake_B1_tokens = self.mutil_fake_B1_tokens

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            loss_spatial = 0.0
            local_nums = self.opt.local_nums
            tokens_cnt = 576
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:int(min(local_nums, tokens_cnt))]
            mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0

        return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial

    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        """Global structural-consistency loss between two multi-layer token lists."""
        loss = 0.0
        n_layers = len(self.atten_layers)
        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            # Cross-similarity matrices between source and target tokens.
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            # Encourage the two cross-similarity maps to agree (cosine similarity -> 1).
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
        loss = loss / n_layers
        return loss
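
    # Hedged usage sketch (illustrative shapes only): with a single attention
    # layer, each token list holds one [B, 576, 768] tensor; the loss compares
    # the 576x576 source->target and target->source similarity maps.
    #
    #   src = [torch.randn(1, 576, 768)]
    #   tgt = [torch.randn(1, 576, 768)]
    #   # loss = model.calculate_similarity(src, tgt)   # scalar tensor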