2025-02-22 14:21:54 +08:00
import numpy as np
import math
import timm
import torch
import torch . nn as nn
import torch . nn . functional as F
from torchvision . transforms import GaussianBlur
from . base_model import BaseModel
from . import networks
from . patchnce import PatchNCELoss
import util . util as util
from torchvision . transforms import transforms as tfs
def warp(image, flow):
    """Warp an image batch with a dense displacement (optical-flow) field.

    Each output pixel (x, y) is bilinearly sampled from ``image`` at
    ``(x + flow_x, y + flow_y)``.

    Args:
        image: Tensor of shape [B, C, H, W].
        flow: Tensor of shape [B, 2, H, W]. Channel 0 is the displacement
            along x (width), channel 1 along y (height), both in pixels.

    Returns:
        Tensor of shape [B, C, H, W] with the warped image. Samples that fall
        outside the input use ``grid_sample``'s default padding.
    """
    B, C, H, W = image.shape

    # Pixel coordinates, built by broadcasting so that x varies along the last
    # (width) axis and y along the height axis. The previous meshgrid call
    # produced a [W, H] grid, which only type-checked for square images and
    # transposed the axes even then.
    xs = torch.arange(W, device=image.device, dtype=torch.float32).view(1, 1, W)
    ys = torch.arange(H, device=image.device, dtype=torch.float32).view(1, H, 1)

    # Displaced sampling positions, each of shape [B, H, W].
    x_new = xs + flow[:, 0]
    y_new = ys + flow[:, 1]

    # Normalise to [-1, 1] as required by grid_sample(align_corners=True).
    # max(..., 1) avoids a division by zero for single-pixel axes.
    x_norm = 2.0 * x_new / max(W - 1, 1) - 1.0
    y_norm = 2.0 * y_new / max(H - 1, 1) - 1.0

    # grid_sample expects [B, H, W, 2] with (x, y) in the last dimension.
    grid = torch.stack((x_norm, y_norm), dim=-1)
    return F.grid_sample(image, grid, align_corners=True)
# Temporal-normalisation (CTN) loss computation
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """Content-aware temporal-normalisation consistency loss.

    Compares "translate, then warp" against "warp, then translate": both
    paths use the same flow field and their results are matched with an MSE.

    Args:
        G: Generator, invoked as ``G(batch)``.
        x: Input batch handed to ``G`` and to ``warp``.
        F_content: Flow field handed to ``warp`` for both paths.

    Returns:
        The mean-squared error between the two paths.
    """
    # Path 1: run the generator, then warp its output.
    translated_then_warped = warp(G(x), F_content)
    # Path 2: warp the input, then run the generator on the warped input.
    warped_then_translated = G(warp(x, F_content))
    return F.mse_loss(translated_then_warped, warped_then_translated)
class ContentAwareOptimization(nn.Module):
    """Content-aware re-weighting of a patch-wise adversarial loss.

    Patches whose gradient direction deviates most from the mean gradient
    direction are treated as content-rich and receive a larger loss weight.
    """

    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc  # boost factor for selected patches
        self.eta_ratio = eta_ratio  # fraction of patches that get boosted

    def compute_cosine_similarity(self, gradients):
        """Cosine similarity between every patch gradient and the mean gradient.

        Args:
            gradients: Tensor [B, N, D] of per-patch gradient vectors.

        Returns:
            Tensor [B, N] of cosine similarities.
        """
        mean_grad = gradients.mean(dim=1, keepdim=True)  # [B, 1, D]
        return F.cosine_similarity(gradients, mean_grad, dim=2)

    def generate_weight_map(self, gradients_fake):
        """Build the content-aware weight map for one set of patch gradients.

        The ``eta_ratio`` fraction of patches with the lowest cosine
        similarity get weight ``lambda_inc / (1e-6 + |cos|)``; all other
        patches keep weight 1.

        Args:
            gradients_fake: Tensor [B, N, D] of per-patch gradients.

        Returns:
            Tensor [B, N] of weights.
        """
        cosine = self.compute_cosine_similarity(gradients_fake)
        num_patches = cosine.shape[1]
        # Clamp so that out-of-range ratios cannot make topk fail.
        k = min(max(int(self.eta_ratio * num_patches), 0), num_patches)

        weights = torch.ones_like(cosine)
        if k == 0:
            return weights

        # Lowest similarity == largest negated similarity.
        _, selected = torch.topk(-cosine, k, dim=1)
        boosted = self.lambda_inc / (1e-6 + cosine.gather(1, selected).abs())
        # Batched replacement for the former per-sample Python loop.
        return weights.scatter(1, selected, boosted)

    @staticmethod
    def _to_patch_vectors(grad):
        """Reshape a [B, C, H, W] gradient into per-patch vectors [B, H*W, C]."""
        # A plain view(B, H*W, -1) would reinterpret memory without moving
        # the channel axis last, mixing channels and positions.
        return grad.detach().flatten(2).transpose(1, 2)

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """Compute the content-aware adversarial loss.

        Args:
            D_real: Discriminator features for real images [B, C, H, W];
                must be part of the autograd graph.
            D_fake: Discriminator features for generated images [B, C, H, W];
                must be part of the autograd graph.
            real_scores: Discriminator predictions for real images [B, N],
                N = H * W.
            fake_scores: Discriminator predictions for generated images [B, N].

        Returns:
            Scalar loss ``-(E[w_real * log(real)] + E[w_fake * log(1 - fake)])``.
            The weight maps are also stored on ``self.weight_real`` and
            ``self.weight_fake``.
        """
        eps = 1e-8
        # Unweighted adversarial objective, used only to probe gradients
        # with respect to the discriminator features.
        probe_loss = (
            torch.mean(torch.log(real_scores + eps))
            + torch.mean(torch.log(1 - fake_scores + eps))
            # Tiny term that keeps both feature maps connected to the probe
            # even if the scores were computed from something else.
            + 1e-8 * (D_real.sum() + D_fake.sum())
        )
        # autograd.grad returns the wanted gradients directly and, unlike
        # backward(), does not accumulate into any parameter's .grad.
        grad_real, grad_fake = torch.autograd.grad(
            probe_loss, (D_real, D_fake), retain_graph=True
        )

        # generate_weight_map handles one gradient set per call.
        self.weight_real = self.generate_weight_map(self._to_patch_vectors(grad_real))
        self.weight_fake = self.generate_weight_map(self._to_patch_vectors(grad_fake))

        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + eps))
        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + eps))
        return -(loss_co_real + loss_co_fake)
class ContentAwareTemporalNorm(nn.Module):
    """Turn a content weight map into a smooth, random pseudo optical flow."""

    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        # Gaussian blur applied to the raw flow.
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)
        # Global scale of the generated motion.
        self.gamma_stride = gamma_stride

    def forward(self, weight_map):
        """Generate a content-aware flow field.

        Args:
            weight_map: Tensor [B, 1, H, W] of content weights.

        Returns:
            Tensor [B, 2, H, W]: flow (x / y displacement) squashed by tanh
            into [-1, 1].
        """
        batch, _, height, width = weight_map.shape

        # L1-normalise over the spatial axes so relative strength is kept.
        normalised = F.normalize(weight_map, p=1, dim=(2, 3))

        # Gaussian noise with the shape of the flow field.
        noise = torch.randn(batch, 2, height, width, device=weight_map.device)

        # Eq. 9: the same weights modulate both flow channels.
        raw_flow = self.gamma_stride * normalised.expand(-1, 2, -1, -1) * noise

        # Blur the raw flow, then bound the magnitude.
        return torch.tanh(self.smoother(raw_flow))
2025-02-23 15:27:15 +08:00
class RomaUnsbModel ( BaseModel ) :
2025-02-22 14:21:54 +08:00
@staticmethod
def modify_commandline_options(parser, is_train=True):
    """Register the command-line options specific to this model.

    Args:
        parser: The argparse parser to extend.
        is_train: Unused here; kept for interface compatibility.

    Returns:
        The same parser, with the model options added and the defaults
        ``pool_size=0``, ``nce_idt=True`` and ``lambda_NCE=1.0`` applied.
    """
    import argparse  # local: only needed for the conflict guard below

    parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
    parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
    parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
    parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
    parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False,
                        help='use NCE loss for identity mapping: NCE(G(Y), Y))')
    parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
    parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                        type=util.str2bool, nargs='?', const=True, default=False,
                        help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
    parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'],
                        help='how to downsample the feature map')
    parser.add_argument('--netF_nc', type=int, default=256)
    parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
    parser.add_argument('--lmda_1', type=float, default=0.1)
    parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
    parser.add_argument('--flip_equivariance',
                        type=util.str2bool, nargs='?', const=True, default=False,
                        help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
    parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
    parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
    parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
    parser.add_argument('--tau', type=float, default=0.1, help='used in unsb')
    parser.add_argument('--num_timesteps', type=int, default=10, help='used in unsb')

    # The loss methods of this model read opt.lambda_global, opt.lambda_spatial,
    # opt.local_nums and opt.side_length, none of which were registered here.
    # Register them unless some other option group already did.
    # NOTE(review): the default values below are placeholders chosen during
    # review -- confirm them against the intended training configuration.
    optional_extras = (
        ('--lambda_global', dict(type=float, default=1.0, help='weight for the global token-similarity loss')),
        ('--lambda_spatial', dict(type=float, default=1.0, help='weight for the local (spatial) token-similarity loss')),
        ('--local_nums', type(None) and dict(type=int, default=256, help='number of local token positions sampled')),
        ('--side_length', dict(type=int, default=7, help='side length passed to the ViT for local tokens')),
    )
    for flag, kwargs in optional_extras:
        try:
            parser.add_argument(flag, **kwargs)
        except argparse.ArgumentError:
            # Already defined elsewhere: keep that definition untouched.
            pass

    # No image pooling; always run in SB mode with the identity NCE term.
    # (The former parse_known_args() call was dropped: its result was unused.)
    parser.set_defaults(pool_size=0)
    parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
    return parser
def __init__(self, opt):
    """Build the networks, loss criteria and optimizers of the model.

    Args:
        opt: Parsed options object; the fields read here include
            ``atten_layers``, ``nce_layers``, ``phase``, ``num_timesteps``,
            ``nce_idt`` and the network/optimizer settings passed on to
            ``networks.define_G`` / ``networks.define_D``.
    """
    BaseModel.__init__(self, opt)

    # Training losses to print.
    # NOTE(review): none of the loss methods visible in this file assign
    # attributes with the `_1` / `_2` suffixes listed here -- confirm that
    # these names match what the logging code looks up.
    self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
                       'G_2']
    self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
    self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

    if self.opt.phase == 'test':
        # At test time show the input and one output per sampling step.
        self.visual_names = ['real']
        for NFE in range(self.opt.num_timesteps):
            self.visual_names.append('fake_' + str(NFE + 1))

    self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

    if opt.nce_idt and self.isTrain:
        self.loss_names += ['NCE_Y']
        self.visual_names += ['idt_B']

    if self.isTrain:
        self.model_names = ['G', 'D', 'E']
    else:
        self.model_names = ['G']

    # Networks.
    self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
                                  not opt.no_dropout, opt.init_type, opt.init_gain,
                                  opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

    if self.isTrain:
        self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD,
                                      opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
        self.netE = networks.define_D(opt.output_nc * 4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD,
                                      opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

        self.resize = tfs.Resize(size=(384, 384))

        # Pretrained ViT.
        self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

        # Loss criteria.
        self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
        self.criterionNCE = [PatchNCELoss(opt).to(self.device) for _ in self.nce_layers]
        self.criterionIdt = torch.nn.L1Loss().to(self.device)
        # calculate_similarity() uses self.criterionL1, which was never created.
        self.criterionL1 = torch.nn.L1Loss().to(self.device)

        # Optimizers.
        self.optimizer_G1 = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
        self.optimizer_D1 = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
        self.optimizer_E1 = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
        self.optimizers = [self.optimizer_G1, self.optimizer_D1, self.optimizer_E1]

        self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware loss weighting
        self.ctn = ContentAwareTemporalNorm()  # pseudo optical-flow generator
def data_dependent_initialize(self, data):
    """Hook for first-batch, data-dependent setup; this model performs none.

    Args:
        data: First batch from the dataloader (ignored).

    Returns:
        None.
    """
    return None
def optimize_parameters(self):
    """Run one training iteration: forward pass, then update D, E and G.

    The optimizers are the ones created in ``__init__`` under the names
    ``optimizer_D1``, ``optimizer_E1`` and ``optimizer_G1``; the previous
    code referenced ``optimizer_D`` / ``optimizer_E`` / ``optimizer_G``,
    which are never defined.
    """
    # Forward pass.
    self.forward()
    self.netG.train()
    self.netE.train()
    self.netD.train()

    # Update D.
    self.set_requires_grad(self.netD, True)
    self.optimizer_D1.zero_grad()
    self.loss_D = self.compute_D_loss()
    self.loss_D.backward()
    self.optimizer_D1.step()

    # Update E.
    self.set_requires_grad(self.netE, True)
    self.optimizer_E1.zero_grad()
    self.loss_E = self.compute_E_loss()
    self.loss_E.backward()
    self.optimizer_E1.step()

    # Update G with D and E frozen.
    self.set_requires_grad(self.netD, False)
    self.set_requires_grad(self.netE, False)
    self.optimizer_G1.zero_grad()
    self.loss_G = self.compute_G_loss()
    self.loss_G.backward()
    self.optimizer_G1.step()
def set_input(self, input):
    """Unpack a dataloader batch and move its tensors to the model device.

    Parameters:
        input (dict): the batch; the keys read are 'A0', 'A1', 'B0', 'B1'
            and 'A_paths' / 'B_paths'.

    The option 'direction' swaps domain A and domain B. Besides the four
    frame attributes, ``real_A`` and ``real_B`` are set as aliases of
    ``real_A0`` / ``real_B0``: they are listed in ``visual_names`` and read
    by ``forward`` but were never assigned.
    """
    AtoB = self.opt.direction == 'AtoB'
    src, tgt = ('A', 'B') if AtoB else ('B', 'A')

    self.real_A0 = input[src + '0'].to(self.device)
    self.real_A1 = input[src + '1'].to(self.device)
    self.real_B0 = input[tgt + '0'].to(self.device)
    self.real_B1 = input[tgt + '1'].to(self.device)

    # Aliases for code that refers to the un-numbered names.
    self.real_A = self.real_A0
    self.real_B = self.real_B0

    self.image_paths = input[src + '_paths']
def tokens_concat(self, origin_tokens, adjacent_size):
    """Average-pool a square token grid over adjacent windows.

    The ``token_num`` tokens are laid out row-major on an S x S grid and
    averaged over non-overlapping ``adjacent_size`` x ``adjacent_size``
    windows; windows at the bottom/right edge are truncated to the grid.

    The previous code took ``adjacent_size + 1`` rows but ``adjacent_size``
    columns per window; both axes now use the same window size.

    Args:
        origin_tokens: Tensor [B, token_num, C]; token_num must be a
            perfect square.
        adjacent_size: Window side length (positive int).

    Returns:
        Tensor [B, ceil(S / adjacent_size) ** 2, C] of pooled tokens.

    Raises:
        ValueError: If ``token_num`` is not a perfect square.
    """
    B, token_num, C = origin_tokens.shape
    S = math.isqrt(token_num)
    if S * S != token_num:
        raise ValueError(
            f"token count {token_num} is not a perfect square; cannot form a square grid"
        )

    token_map = origin_tokens.reshape(B, S, S, C)
    pooled = []
    for i in range(0, S, adjacent_size):
        for j in range(0, S, adjacent_size):
            # Slicing clamps at the grid border, so edge windows shrink.
            window = token_map[:, i:i + adjacent_size, j:j + adjacent_size, :]
            pooled.append(window.mean(dim=(1, 2)))
    return torch.stack(pooled, dim=1)
def cat_results(self, origin_tokens, adj_size_list):
    """Concatenate the tokens with their pooled versions along dim 1.

    Args:
        origin_tokens: Token tensor; passed unchanged to ``tokens_concat``.
        adj_size_list: Iterable of window sizes, one pooled set per entry.

    Returns:
        ``origin_tokens`` followed by each ``tokens_concat`` result, joined
        on dimension 1 in the order of ``adj_size_list``.
    """
    pooled_sets = [self.tokens_concat(origin_tokens, size) for size in adj_size_list]
    return torch.cat([origin_tokens, *pooled_sets], dim=1)
def forward ( self ) :
""" Run the forward pass and store the generated outputs on the instance.

Visible stages: resize the four input frames and extract ViT tokens for the
A frames; build the time schedule and roll the generator forward to a random
time index without gradients; generate fake_B / fake_B2 and extract their ViT
tokens; in the 'test' phase sample all time steps; in the 'train' phase build
the content weight map, the pseudo flow and the second generator output.

NOTE(review): the indentation of this method was lost in this copy of the
file, so the exact nesting of the statements below cannot be confirmed here.
"""
if self . opt . isTrain :
real_A0 = self . resize ( self . real_A0 )
real_A1 = self . resize ( self . real_A1 )
real_B0 = self . resize ( self . real_B0 )
real_B1 = self . resize ( self . real_B1 )
# Run the pretrained ViT on the resized A frames, requesting tokens
# NOTE(review): real_B0 / real_B1 are resized above but not used again in this method.
self . mutil_real_A0_tokens = self . netPreViT ( real_A0 , self . atten_layers , get_tokens = True )
self . mutil_real_A1_tokens = self . netPreViT ( real_A1 , self . atten_layers , get_tokens = True )
# Run the SB module once
# ============ Step 1: initialise the time steps and the time index ============
2025-02-23 15:23:00 +08:00
# Compute times and pick the current time_idx (chosen at random to represent the current time step)
2025-02-22 14:21:54 +08:00
tau = self . opt . tau
T = self . opt . num_timesteps
# Schedule: cumulative sum of [0, 1, 1/2, ..., 1/(T-1)], divided by its last
# value, mapped to 0.5 + 0.5 * t, then prefixed with a 0 (T + 1 entries).
incs = np . array ( [ 0 ] + [ 1 / ( i + 1 ) for i in range ( T - 1 ) ] )
times = np . cumsum ( incs )
times = times / times [ - 1 ]
times = 0.5 * times [ - 1 ] + 0.5 * times #[0.5,1]
times = np . concatenate ( [ np . zeros ( 1 ) , times ] )
# NOTE(review): .cuda() hard-codes the device here, while other tensors below follow the tokens' device.
times = torch . tensor ( times ) . float ( ) . cuda ( )
self . times = times
# NOTE(review): assumes the netPreViT output is a single tensor with .size(); calculate_similarity iterates over these tokens -- confirm the type.
bs = self . mutil_real_A0_tokens . size ( 0 )
time_idx = ( torch . randint ( T , size = [ 1 ] ) . cuda ( ) * torch . ones ( size = [ 1 ] ) . cuda ( ) ) . long ( )
self . time_idx = time_idx
with torch . no_grad ( ) :
self . netG . eval ( )
# ============ Step 2: multi-step stochastic generation for both A inputs ============
for t in range ( self . time_idx . int ( ) . item ( ) + 1 ) :
# Compute the increment delta and inter / scale, used for the interpolation at each time step
if t > 0 :
delta = times [ t ] - times [ t - 1 ]
denom = times [ - 1 ] - times [ t - 1 ]
inter = ( delta / denom ) . reshape ( - 1 , 1 , 1 , 1 )
scale = ( delta * ( 1 - delta / denom ) ) . reshape ( - 1 , 1 , 1 , 1 )
# Update Xt and Xt2 with random noise
Xt = self . mutil_real_A0_tokens if ( t == 0 ) else ( 1 - inter ) * Xt + inter * Xt_1 . detach ( ) + \
( scale * tau ) . sqrt ( ) * torch . randn_like ( Xt ) . to ( self . mutil_real_A0_tokens . device )
time_idx = ( t * torch . ones ( size = [ self . mutil_real_A0_tokens . shape [ 0 ] ] ) . to ( self . mutil_real_A0_tokens . device ) ) . long ( )
z = torch . randn ( size = [ self . mutil_real_A0_tokens . shape [ 0 ] , 4 * self . opt . ngf ] ) . to ( self . mutil_real_A0_tokens . device )
self . time = times [ time_idx ]
Xt_1 = self . netG ( Xt , self . time , z )
Xt2 = self . mutil_real_A1_tokens if ( t == 0 ) else ( 1 - inter ) * Xt2 + inter * Xt_12 . detach ( ) + \
( scale * tau ) . sqrt ( ) * torch . randn_like ( Xt2 ) . to ( self . mutil_real_A1_tokens . device )
# The second branch recomputes time_idx and z but reuses self.time from the first branch.
time_idx = ( t * torch . ones ( size = [ self . mutil_real_A1_tokens . shape [ 0 ] ] ) . to ( self . mutil_real_A1_tokens . device ) ) . long ( )
z = torch . randn ( size = [ self . mutil_real_A1_tokens . shape [ 0 ] , 4 * self . opt . ngf ] ) . to ( self . mutil_real_A1_tokens . device )
Xt_12 = self . netG ( Xt2 , self . time , z )
# Keep the intermediate states (real_A_noisy etc.) for the next step
self . real_A_noisy = Xt . detach ( )
self . real_A_noisy2 = Xt2 . detach ( )
# Store the noisy_map
2025-02-23 14:37:14 +08:00
# NOTE(review): real_A_noisy derives from the ViT tokens while real_A0 is the raw input from set_input -- confirm the two are shape-compatible.
self . noisy_map = self . real_A_noisy - self . real_A0
2025-02-22 14:21:54 +08:00
# ============ Step 3: assemble the inputs and run the generator =============
bs = self . mutil_real_A0_tokens . size ( 0 )
# NOTE(review): z_in is sized 2 * bs while z_in2 is sized bs; both are paired below with single inputs -- confirm.
z_in = torch . randn ( size = [ 2 * bs , 4 * self . opt . ngf ] ) . to ( self . mutil_real_A0_tokens . device )
z_in2 = torch . randn ( size = [ bs , 4 * self . opt . ngf ] ) . to ( self . mutil_real_A1_tokens . device )
# Select the clean and the noisy generator inputs
self . real = self . mutil_real_A0_tokens
self . realt = self . real_A_noisy
if self . opt . flip_equivariance :
self . flipped_for_equivariance = self . opt . isTrain and ( np . random . random ( ) < 0.5 )
if self . flipped_for_equivariance :
self . real = torch . flip ( self . real , [ 3 ] )
self . realt = torch . flip ( self . realt , [ 3 ] )
# Use netG to produce the final fake_B and fake_B2
self . fake_B = self . netG ( self . realt , self . time , z_in )
self . fake_B2 = self . netG ( self . real , self . time , z_in2 )
self . fake_B = self . resize ( self . fake_B )
self . fake_B2 = self . resize ( self . fake_B2 )
self . fake_B0 = self . fake_B
self . fake_B1 = self . fake_B2
# Run the pretrained ViT on the generated images
self . mutil_fake_B0_tokens = self . netPreViT ( self . fake_B , self . atten_layers , get_tokens = True )
self . mutil_fake_B1_tokens = self . netPreViT ( self . fake_B2 , self . atten_layers , get_tokens = True )
# ============ Step 4: multi-step sampling in inference mode ============
if self . opt . phase == ' test ' :
tau = self . opt . tau
T = self . opt . num_timesteps
incs = np . array ( [ 0 ] + [ 1 / ( i + 1 ) for i in range ( T - 1 ) ] )
times = np . cumsum ( incs )
times = times / times [ - 1 ]
times = 0.5 * times [ - 1 ] + 0.5 * times
times = np . concatenate ( [ np . zeros ( 1 ) , times ] )
times = torch . tensor ( times ) . float ( ) . cuda ( )
self . times = times
bs = self . real . size ( 0 )
# This random time_idx is stored on self but the local is overwritten inside the loop below.
time_idx = ( torch . randint ( T , size = [ 1 ] ) . cuda ( ) * torch . ones ( size = [ 1 ] ) . cuda ( ) ) . long ( )
self . time_idx = time_idx
visuals = [ ]
with torch . no_grad ( ) :
self . netG . eval ( )
for t in range ( self . opt . num_timesteps ) :
if t > 0 :
delta = times [ t ] - times [ t - 1 ]
denom = times [ - 1 ] - times [ t - 1 ]
inter = ( delta / denom ) . reshape ( - 1 , 1 , 1 , 1 )
scale = ( delta * ( 1 - delta / denom ) ) . reshape ( - 1 , 1 , 1 , 1 )
Xt = self . mutil_real_A0_tokens if ( t == 0 ) else ( 1 - inter ) * Xt + inter * Xt_1 . detach ( ) + ( scale * tau ) . sqrt ( ) * torch . randn_like ( Xt ) . to ( self . mutil_real_A0_tokens . device )
time_idx = ( t * torch . ones ( size = [ self . mutil_real_A0_tokens . shape [ 0 ] ] ) . to ( self . mutil_real_A0_tokens . device ) ) . long ( )
time = times [ time_idx ]
z = torch . randn ( size = [ self . mutil_real_A0_tokens . shape [ 0 ] , 4 * self . opt . ngf ] ) . to ( self . mutil_real_A0_tokens . device )
# NOTE(review): `time` is computed above but netG receives time_idx here, whereas the earlier steps pass self.time -- confirm which one netG expects.
Xt_1 = self . netG ( Xt , time_idx , z )
# Expose each sampling step as fake_1 ... fake_T
setattr ( self , " fake_ " + str ( t + 1 ) , Xt_1 )
if self . opt . phase == ' train ' :
# Gradient of the real image
# NOTE(review): set_input in this file assigns real_B0 / real_B1 but not real_B; the gradient of x.sum() with respect to x is all ones -- confirm the intent.
real_gradient = torch . autograd . grad ( self . real_B . sum ( ) , self . real_B , create_graph = True ) [ 0 ]
# Gradient of the generated image
fake_gradient = torch . autograd . grad ( self . fake_B . sum ( ) , self . fake_B , create_graph = True ) [ 0 ]
# Weight maps
# NOTE(review): ContentAwareOptimization.generate_weight_map in this file takes one gradient tensor and returns one weight map; this call passes two and unpacks two.
self . weight_real , self . weight_fake = self . cao . generate_weight_map ( real_gradient , fake_gradient )
# CTN flow field for the generated image
self . f_content = self . ctn ( self . weight_fake )
# Add the noisy_map to the previously generated image
# NOTE(review): this value of fake_B_2 is overwritten by the netG call three statements below.
self . fake_B_2 = self . fake_B + self . noisy_map
# Warped image
wapped_fake_B = warp ( self . fake_B , self . f_content )
# Second pass through the generator
self . fake_B_2 = self . netG ( wapped_fake_B , self . time , z_in )
def compute_D_loss(self):
    """Compute the GAN loss for the discriminator D.

    The generated image is detached so that no gradient reaches the
    generator. It is fed to ``netD`` directly, like the real image here and
    like the fake image in ``compute_G_loss``; the previous
    ``cat_results(...)`` call lacked its required ``adj_size_list`` argument
    and could not run.

    Returns:
        ``0.5 * (loss_D_fake + loss_D_real)``, also stored as ``self.loss_D``.
    """
    fake = self.fake_B.detach()
    pred_fake = self.netD(fake, self.time)
    self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()

    self.pred_real = self.netD(self.real_B0, self.time)
    self.loss_D_real = self.criterionGAN(self.pred_real, True).mean()

    self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
    return self.loss_D
def compute_E_loss(self):
    """Compute the loss for the network E.

    Builds two (noisy input, detached generated output) pairs by
    concatenating along dim 1, scores the cross pair and the matched pair
    with ``netE``, and combines them as
    ``-mean(E(p1, p1)) + lse + lse ** 2`` where ``lse`` is the logsumexp of
    the flattened cross scores.

    Returns:
        The loss, also stored as ``self.loss_E``.
    """
    pair_a = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1)
    pair_b = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1)

    cross_scores = self.netE(pair_a, self.time, pair_b).reshape(-1)
    log_partition = torch.logsumexp(cross_scores, dim=0).mean()

    matched_score = self.netE(pair_a, self.time, pair_a).mean()
    self.loss_E = -matched_score + log_partition + log_partition ** 2
    return self.loss_E
def compute_G_loss(self):
    """Compute the total generator loss.

    Terms (each skipped when its weight is not positive):
      * GAN loss on ``fake_B``, already scaled by ``lambda_GAN``;
      * SB loss built from ``netE`` scores plus a ``tau``-weighted squared
        distance between ``real_A_noisy`` and ``fake_B``;
      * global token-similarity loss via ``calculate_similarity``;
      * CTN consistency loss between ``fake_B_2`` and the warped ``fake_B``.

    Returns:
        The weighted sum, also stored as ``self.loss_G``.
    """
    opt = self.opt
    fake = self.fake_B

    if opt.lambda_GAN > 0.0:
        pred_fake = self.netD(fake, self.time)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * opt.lambda_GAN
    else:
        self.loss_G_GAN = 0.0

    self.loss_SB = 0
    if opt.lambda_SB > 0.0:
        XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
        XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
        # eq. 9
        ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(
            self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0
        )
        self.loss_SB = -(opt.num_timesteps - self.time[0]) / opt.num_timesteps * opt.tau * ET_XY
        self.loss_SB += opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)

    # lambda_global is not registered by this model's own option setup, so
    # treat a missing value as "disabled" instead of failing.
    lambda_global = getattr(opt, 'lambda_global', 0.0)
    if lambda_global > 0.0:
        loss_global = 0.5 * (
            self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens)
            + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
        )
    else:
        loss_global = 0.0

    # Always define l2_loss: it is read below even when the CTN term is off.
    self.l2_loss = 0.0
    if opt.lambda_ctn > 0.0:
        wapped_fake_B = warp(self.fake_B, self.f_content)
        self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B)

    self.loss_G = (
        self.loss_G_GAN
        + opt.lambda_SB * self.loss_SB
        + opt.lambda_ctn * self.l2_loss
        + loss_global * lambda_global
    )
    return self.loss_G
def calculate_attention_loss(self):
    """Compute the weighted global and spatial token-similarity losses.

    Returns:
        Tuple ``(loss_global * lambda_global, loss_spatial * lambda_spatial)``;
        a term is 0.0 when its weight is not positive.
    """
    real0, real1, fake0, fake1 = (
        self.mutil_real_A0_tokens,
        self.mutil_real_A1_tokens,
        self.mutil_fake_B0_tokens,
        self.mutil_fake_B1_tokens,
    )

    # Global term: average similarity loss over the two frame pairs.
    loss_global = 0.0
    if self.opt.lambda_global > 0.0:
        loss_global = (
            self.calculate_similarity(real0, fake0)
            + self.calculate_similarity(real1, fake1)
        ) * 0.5

    # Spatial term: same comparison on a random subset of token positions.
    loss_spatial = 0.0
    if self.opt.lambda_spatial > 0.0:
        token_total = 576
        keep = int(min(self.opt.local_nums, token_total))
        local_id = np.random.permutation(token_total)[:keep]

        def local_tokens(image):
            # Re-run the ViT restricted to the sampled token ids.
            return self.netPreViT(
                self.resize(image),
                self.atten_layers,
                get_tokens=True,
                local_id=local_id,
                side_length=self.opt.side_length,
            )

        real0_local = local_tokens(self.real_A0)
        real1_local = local_tokens(self.real_A1)
        fake0_local = local_tokens(self.fake_B0)
        fake1_local = local_tokens(self.fake_B1)

        loss_spatial = (
            self.calculate_similarity(real0_local, fake0_local)
            + self.calculate_similarity(real1_local, fake1_local)
        ) * 0.5

    return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
    """Cross-similarity consistency loss between two multi-layer token sets.

    For every layer pair, the source-to-target and target-to-source
    token-product matrices are compared row-wise by cosine similarity, and
    the L1 distance of that similarity from 1 is accumulated.

    Args:
        mutil_src_tokens: Iterable of source token tensors, one per layer;
            each must be 3-D (``bmm`` is applied).
        mutil_tgt_tokens: Iterable of target token tensors, paired with the
            source tensors by position.

    Returns:
        The accumulated loss divided by ``len(self.atten_layers)``.
    """
    layer_count = len(self.atten_layers)
    total = 0.0
    for src, tgt in zip(mutil_src_tokens, mutil_tgt_tokens):
        cross_st = torch.bmm(src, tgt.transpose(1, 2))
        cross_ts = torch.bmm(tgt, src.transpose(1, 2))
        cosine = F.cosine_similarity(cross_st, cross_ts, dim=-1)
        # Perfect agreement means a cosine of 1 everywhere.
        total = total + self.criterionL1(torch.ones_like(cosine), cosine).mean()
    return total / layer_count