import numpy as np
import math
import timm
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
from torchvision.transforms import transforms as tfs

def warp(image, flow):
    """Warp an image with an optical-flow field via bilinear sampling.

    Args:
        image: [B, C, H, W] input image.
        flow: [B, 2, H, W] flow field (x/y displacements in pixels).
    Returns:
        warped: [B, C, H, W] warped image.
    """
    B, C, H, W = image.shape
    # Build the pixel coordinate grid. With indexing='ij' over (H, W),
    # grid_x holds column (x) indices and grid_y holds row (y) indices,
    # matching the (x, y) channel order that grid_sample expects.
    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device)  # [2, H, W]
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, 2, H, W]

    # Apply the flow displacement and normalize coordinates to [-1, 1].
    new_grid = grid + flow
    new_grid[:, 0, :, :] = 2.0 * new_grid[:, 0, :, :] / (W - 1) - 1.0  # x direction
    new_grid[:, 1, :, :] = 2.0 * new_grid[:, 1, :, :] / (H - 1) - 1.0  # y direction
    new_grid = new_grid.permute(0, 2, 3, 1)  # [B, H, W, 2]

    # Bilinear interpolation.
    return F.grid_sample(image, new_grid, align_corners=True)
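
# Quick sanity check (hypothetical shapes): a zero flow field should return
# the input unchanged up to interpolation error, e.g.
#   img = torch.randn(1, 3, 64, 64)
#   assert torch.allclose(warp(img, torch.zeros(1, 2, 64, 64)), img, atol=1e-5)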

# Temporal-normalization loss computation.
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """Compute the content-aware temporal normalization (CTN) loss.

    Args:
        G: generator.
        x: input infrared image [B, C, H, W].
        F_content: generated flow field [B, 2, H, W].
    """
    # Generate a visible-light image.
    y_fake = G(x)  # [B, 3, H, W]
    # Warp the generated result with the flow.
    warped_fake = warp(y_fake, F_content)  # [B, 3, H, W]
    # Apply the same flow to the input, then generate.
    warped_x = warp(x, F_content)  # [B, C, H, W]
    y_fake_warped = G(warped_x)  # [B, 3, H, W]
    # L2 loss between the two paths.
    loss = F.mse_loss(warped_fake, y_fake_warped)
    return loss
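
# In other words, Eq. 10 enforces an equivariance constraint between warping
# and generation:
#   L_ctn = || warp(G(x), F) - G(warp(x, F)) ||_2^2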


class ContentAwareOptimization(nn.Module):
    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        super().__init__()
        self.lambda_inc = lambda_inc
        self.eta_ratio = eta_ratio
        self.criterionGAN = networks.GANLoss('lsgan').cuda()
        self.gradients_real = []
        self.gradients_fake = []

    def compute_cosine_similarity(self, gradients):
        # Cosine similarity of each token's gradient against the mean gradient.
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        return F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]

    def generate_weight_map(self, gradients_real, gradients_fake):
        # Compute cosine similarities.
        cosine_real = self.compute_cosine_similarity(gradients_real)
        cosine_fake = self.compute_cosine_similarity(gradients_fake)

        # Build the weight maps: the eta_ratio fraction of tokens whose
        # gradients deviate most from the mean direction (lowest cosine
        # similarity) are up-weighted; all other tokens keep weight 1.
        def _get_weights(cosine):
            k = int(self.eta_ratio * cosine.shape[1])
            _, indices = torch.topk(-cosine, k, dim=1)
            weights = torch.ones_like(cosine)
            weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
            return weights

        weight_real = _get_weights(cosine_real)
        weight_fake = _get_weights(cosine_fake)
        return weight_real, weight_fake
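
    # For example, with a 24x24 ViT patch grid (N = 576 tokens) and
    # eta_ratio = 0.4, the int(0.4 * 576) = 230 lowest-similarity tokens get
    # weight lambda_inc / |cos|, while the remaining 346 keep weight 1.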

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        # Clear the gradient caches.
        self.gradients_real.clear()
        self.gradients_fake.clear()

        # Register hooks to capture gradients on the feature tensors.
        hook_real = lambda grad: self.gradients_real.append(grad.detach())
        hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
        D_real.register_hook(hook_real)
        D_fake.register_hook(hook_fake)

        # Trigger gradient computation (keep the graph for the main backward).
        (real_scores.mean() + fake_scores.mean()).backward(retain_graph=True)

        # The captured gradients have the same shape as the features: [B, N, D].
        grad_real = self.gradients_real[0]
        grad_fake = self.gradients_fake[0]

        # Generate the weight maps.
        weight_real, weight_fake = self.generate_weight_map(
            grad_real.view(*D_real.shape),
            grad_fake.view(*D_fake.shape)
        )

        # Apply the weights to the adversarial losses (Eq. 7 in the paper).
        loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
        loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))

        # Total loss (note the sign: the discriminator maximizes this).
        loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
        return loss_co_adv, weight_real, weight_fake


class ContentAwareTemporalNorm(nn.Module):
    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        super().__init__()
        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

    def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
        """Upsample a patch-level weight map to the target resolution.

        Args:
            weight_patch: [B, 1, 24, 24] patch weight map from the ViT.
            target_size: target resolution (H, W).
        Returns:
            weight_full: [B, 1, 256, 256] upsampled full-resolution weight map.
        """
        # Bilinear upsampling.
        B = weight_patch.shape[0]
        weight_patch = weight_patch.view(B, 1, 24, 24)
        weight_full = F.interpolate(
            weight_patch,
            size=target_size,
            mode='bilinear',
            align_corners=False
        )
        # Keep the weight constant within each 16x16 patch (optional):
        # average-pool and re-expand to remove the gradient introduced by
        # the interpolation.
        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
        return weight_full
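
    # E.g., a 256x256 bilinear map pooled with kernel 16 / stride 16 gives a
    # 16x16 grid, and nearest-neighbor re-expansion by 16 restores 256x256
    # with one constant weight per 16x16 block.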

    def forward(self, weight_map):
        """Generate a content-aware flow field.

        Args:
            weight_map: [B, 1, H, W] weight map (from the content-aware optimization module).
        Returns:
            F_content: [B, 2, H, W] generated flow field (x/y displacements).
        """
        # Upsample the weight map to full resolution.
        weight_full = self.upsample_weight_map(weight_map)  # [B,1,256,256]

        # 1. Normalize the weight map: keep the relative strength of regions
        # while bounding the value range.
        weight_norm = F.normalize(weight_full, p=1, dim=(2, 3))  # L1 normalization [B,1,H,W]

        # 2. Sample Gaussian noise.
        B, _, H, W = weight_norm.shape
        z = torch.randn(B, 2, H, W, device=weight_norm.device)  # [B,2,H,W]

        # 3. Synthesize the raw flow: expand the weight map to two channels
        # (x/y directions share the same weights).
        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B,2,H,W]
        F_raw = self.gamma_stride * weight_expanded * z  # [B,2,H,W], Eq. 9

        # 4. Smooth each channel with a Gaussian blur to keep the flow
        # structurally continuous.
        F_smooth = self.smoother(F_raw)  # [B,2,H,W]

        # 5. Dynamic range adjustment (optional): bound the flow magnitude to
        # [-1, 1] to avoid extreme displacements.
        F_content = torch.tanh(F_smooth)
        return F_content
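
# Minimal usage sketch (assumed shapes): given a [B, 1, 24, 24] weight map
# derived from the discriminator, sample a smooth pseudo-flow and warp:
#   ctn = ContentAwareTemporalNorm()
#   flow = ctn(weights)          # [B, 2, 256, 256]
#   warped = warp(image, flow)   # image: [B, C, 256, 256]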


class RomaUnsbModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configure options specific to the CTNx model."""
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
        parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
        parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
        parser.add_argument('--side_length', type=int, default=7)
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y)')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
        parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
        parser.add_argument('--num_timesteps', type=int, default=5, help='number of timesteps in the multi-step generation')
        parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')

        opt, _ = parser.parse_known_args()
        return parser

    def __init__(self, opt):
        """Initialize the CTNx model."""
        BaseModel.__init__(self, opt)
        # Training losses to print/log.
        self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial', 'ctn']
        self.visual_names = ['real_A0', 'fake_B0_1', 'fake_B0', 'real_B0', 'real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

        if self.opt.phase == 'test':
            self.visual_names = ['real']
            for NFE in range(self.opt.num_timesteps):
                fake_name = 'fake_' + str(NFE + 1)
                self.visual_names.append(fake_name)

        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
        if self.isTrain:
            self.model_names = ['G', 'D_ViT']
        else:
            self.model_names = ['G']

        # Create the networks.
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.resize = tfs.Resize(size=(384, 384), antialias=True)
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            # Pretrained ViT backbone.
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
            # Loss functions.
            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers = [self.optimizer_G, self.optimizer_D]
            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware adversarial loss
            self.ctn = ContentAwareTemporalNorm()  # generator of the pseudo optical flow

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        pass

    def optimize_parameters(self):
        # Forward pass.
        self.forward()
        self.netG.train()
        self.netD_ViT.train()

        # Update D.
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # Update G.
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        # ============ Step 1: build the time schedule for the multi-step
        # stochastic generation over real_A0 / real_A1 ============
        tau = self.opt.tau
        T = self.opt.num_timesteps
        incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
        times = np.cumsum(incs)
        times = times / times[-1]
        times = 0.5 * times[-1] + 0.5 * times  # map into [0.5, 1]
        times = np.concatenate([np.zeros(1), times])
        times = torch.tensor(times).float().cuda()
        self.times = times
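        # For example, with T = 5: incs = [0, 1, 1/2, 1/3, 1/4], whose
        # normalized cumulative sum is [0, 0.48, 0.72, 0.88, 1.0]; after the
        # affine map and the prepended zero, times ≈ [0, 0.5, 0.74, 0.86, 0.94, 1.0].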
        bs = self.real_A0.size(0)
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
        self.time_idx = time_idx
        self.fake_B0_list = []
        self.fake_B1_list = []

        with torch.no_grad():
            self.netG.eval()
            # ============ Step 2: run the multi-step stochastic generation
            # for real_A0 / real_A1 ============
            for t in range(self.time_idx.int().item() + 1):
                # Compute the increment delta and the inter/scale coefficients
                # used for the per-step interpolation.
                if t > 0:
                    delta = times[t] - times[t - 1]
                    denom = times[-1] - times[t - 1]
                    inter = (delta / denom).reshape(-1, 1, 1, 1)
                    scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)

                # Stochastic noise update of Xt / Xt2.
                Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
                time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
                z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
                time = times[time_idx]
                Xt_1 = self.netG(Xt.detach(), time, z)

                Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
                time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
                z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
                Xt_12 = self.netG(Xt2.detach(), time, z)

                self.fake_B0_list.append(Xt_1)
                self.fake_B1_list.append(Xt_12)

        self.fake_B0_1 = self.fake_B0_list[0]
        self.fake_B1_1 = self.fake_B1_list[0]
        self.fake_B0 = self.fake_B0_list[-1]
        self.fake_B1 = self.fake_B1_list[-1]
        self.z_in = z
        self.z_in2 = z

        if self.opt.phase == 'train':
            real_A0 = self.real_A0
            real_A1 = self.real_A1
            real_B0 = self.real_B0
            real_B1 = self.real_B1
            fake_B0 = self.fake_B0
            fake_B1 = self.fake_B1

            # ViT tokens for every intermediate fake frame.
            self.mutil_fake_B0_tokens_list = []
            self.mutil_fake_B1_tokens_list = []
            for fake_B0_t in self.fake_B0_list:
                fake_B0_t_resize = self.resize(fake_B0_t)  # resize to the ViT input size
                tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True)
                self.mutil_fake_B0_tokens_list.append(tokens)
            for fake_B1_t in self.fake_B1_list:
                fake_B1_t_resize = self.resize(fake_B1_t)
                tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True)
                self.mutil_fake_B1_tokens_list.append(tokens)

            self.real_A0_resize = self.resize(real_A0)
            self.real_A1_resize = self.resize(real_A1)
            real_B0 = self.resize(real_B0)
            real_B1 = self.resize(real_B1)
            self.fake_B0_resize = self.resize(fake_B0)
            self.fake_B1_resize = self.resize(fake_B1)

            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            # Token shapes: a list with one [1, 576, 768] tensor per attention layer.

    def compute_D_loss(self):
        """Calculate the GAN loss with content-aware optimization."""
        lambda_D_ViT = self.opt.lambda_D_ViT

        loss_cao = 0.0
        real_B0_tokens = self.mutil_real_B0_tokens[0]
        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)  # scores, features
        real_B1_tokens = self.mutil_real_B1_tokens[0]
        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)  # scores, features

        for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list):
            pred_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach())
            pred_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach())
            loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
                D_real=real_features0,
                D_fake=fake_features0,
                real_scores=pred_real0,
                fake_scores=pred_fake0
            )
            loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
                D_real=real_features1,
                D_fake=fake_features1,
                real_scores=pred_real1,
                fake_scores=pred_fake1
            )
            loss_cao += loss_cao0 + loss_cao1

        # ===== Combined loss, averaged over the generation steps =====
        total_steps = len(self.fake_B0_list)
        self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT / total_steps
        # Record loss values for visualization:
        # self.loss_D_real = loss_D_real.item()
        # self.loss_D_fake = loss_D_fake.item()
        # self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5
        return self.loss_D_ViT

    def compute_G_loss(self):
        """Compute the generator losses."""
        if self.opt.lambda_ctn > 0.0:
            # CTN pseudo-flow fields derived from the fake weight maps.
            self.f_content0 = self.ctn(self.weight_fake0)
            self.f_content1 = self.ctn(self.weight_fake1)

            # Warped images.
            self.warped_real_A0 = warp(self.real_A0, self.f_content0)
            self.warped_real_A1 = warp(self.real_A1, self.f_content1)
            self.warped_fake_B0 = warp(self.fake_B0, self.f_content0)
            self.warped_fake_B1 = warp(self.fake_B1, self.f_content1)

            # Second pass through the generator on the warped inputs.
            self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in)
            self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2)

            warped_fake_B0_2 = self.warped_fake_B0_2
            warped_fake_B1_2 = self.warped_fake_B1_2
            warped_fake_B0 = self.warped_fake_B0
            warped_fake_B1 = self.warped_fake_B1

            # L2 loss between G(warp(x)) and warp(G(x)) (Eq. 10).
            self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
            self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
            self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1) * 0.5
        else:
            self.loss_ctn = 0.0

        if self.opt.lambda_GAN > 0.0:
            pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0])
            pred_fake1, _ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0])
            self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
            self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
            self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1) * 0.5
        else:
            self.loss_G_GAN = 0.0

        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
        else:
            self.loss_global, self.loss_spatial = 0.0, 0.0

        self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
            self.opt.lambda_ctn * self.loss_ctn + \
            self.loss_global * self.opt.lambda_global + \
            self.loss_spatial * self.opt.lambda_spatial
        return self.loss_G

    def calculate_attention_loss(self):
        mutil_real_A0_tokens = self.mutil_real_A0_tokens
        mutil_real_A1_tokens = self.mutil_real_A1_tokens
        mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
        mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            loss_spatial = 0.0
            local_nums = self.opt.local_nums
            tokens_cnt = 576  # 24x24 patch tokens for ViT-B/16 at 384x384
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:int(min(local_nums, tokens_cnt))]

            mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0

        return loss_global, loss_spatial

    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)
        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            # Cross-similarity matrices between source and target tokens; the
            # loss pushes their row-wise cosine similarity toward 1.
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
        loss = loss / n_layers
        return loss
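

if __name__ == "__main__":
    # Minimal smoke test for the flow utilities; a sketch under assumed shapes
    # (256x256 images, a 24x24 ViT patch grid). Run as a module so the relative
    # imports resolve, e.g. `python -m models.roma_unsb_model` (the module path
    # is an assumption about this repo's layout).
    img = torch.randn(2, 3, 256, 256)

    # Zero flow should reproduce the input up to interpolation error.
    identity = warp(img, torch.zeros(2, 2, 256, 256))
    assert torch.allclose(identity, img, atol=1e-5)

    # A uniform patch weight map should yield a smooth, tanh-bounded flow
    # that warps the image without changing its shape.
    ctn = ContentAwareTemporalNorm()
    flow = ctn(torch.ones(2, 1, 24, 24))
    assert flow.shape == (2, 2, 256, 256)
    assert flow.abs().max() <= 1.0
    assert warp(img, flow).shape == img.shape
    print("warp / ContentAwareTemporalNorm smoke test passed")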