import math
import sys
import time
from functools import partial

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import transforms as tfs

import util.util as util
from . import networks
from .base_model import BaseModel
from .patchnce import PatchNCELoss
from .cnt import *  # expected to provide ContentAwareOptimization, ContentAwareTemporalNorm and warp


class ROMAModel(BaseModel):

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Configures options specific for the ROMA model."""
        parser.add_argument('--adj_size_list', nargs='+', type=int, default=[2, 4, 6, 8, 12], help='different scales of perception field')
        parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
        parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
        parser.add_argument('--lambda_inc', type=float, default=2.0, help='weight for Content Aware Optimization')
        parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio for selecting content region')
        parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
        parser.add_argument('--local_nums', type=int, default=256)
        parser.add_argument('--which_D_layer', type=int, default=-1)
        parser.add_argument('--side_length', type=int, default=7)
        parser.set_defaults(pool_size=0)
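        # pool_size=0: no image buffer is used for discriminator updates.
        # Note: --tau, --num_timesteps, --lambda_SB and --lambda_ctn are read
        # later (forward / compute_G_loss) and are assumed to be registered by
        # the base or UNSB option sets.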

        opt, _ = parser.parse_known_args()

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial']
        self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
        self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]

        if self.isTrain:
            self.model_names = ['G', 'D_ViT', 'G_2']
        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

        if self.isTrain:
            self.netD_ViT = networks.MLPDiscriminator().to(self.device)
            self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

            # From UNSB
            self.netE = networks.define_D(opt.output_nc * 4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            # Define another generator
            self.netG_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)

            self.norm = F.softmax
            self.resize = tfs.Resize(size=(384, 384))

            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = []
            for atten_layer in self.atten_layers:
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

            self.criterionL1 = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
            self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D_ViT)
            self.optimizers.append(self.optimizer_E)

            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware loss
            self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical-flow field
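
            # lambda_mlp scales the learning rate of netD_ViT and netE relative
            # to the generator's learning rate (see the optimizers above).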

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        pass

    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD_ViT, True)
        self.optimizer_D_ViT.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D_ViT.step()

        # update E
        self.set_requires_grad(self.netE, True)
        self.optimizer_E.zero_grad()
        self.loss_E = self.compute_E_loss()
        self.loss_E.backward()
        self.optimizer_E.step()

        # update G
        self.set_requires_grad(self.netD_ViT, False)
        self.optimizer_G.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
        self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
        self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
        self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""

        # ============ Step 1: set up the time schedule for the multi-step stochastic generation of real_A0 / real_A1 ============
        tau = self.opt.tau
        T = self.opt.num_timesteps
        incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
        times = np.cumsum(incs)
        times = times / times[-1]
        times = 0.5 * times[-1] + 0.5 * times  # rescale into [0.5, 1]
        times = np.concatenate([np.zeros(1), times])
        times = torch.tensor(times).float().cuda()
        self.times = times
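        # The harmonic increments 1/(i+1) make the normalized cumulative times
        # cluster near 1; the affine map 0.5 * t_max + 0.5 * t then squeezes
        # them into [0.5, 1], and the prepended 0 marks the clean image. For
        # T = 4 this gives times of roughly [0.0, 0.5, 0.77, 0.91, 1.0].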

        bs = self.real_A0.size(0)
        # random stopping step, shared across the batch
        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
        self.time_idx = time_idx

        with torch.no_grad():
            self.netG.eval()
            # ============ Step 2: run the multi-step stochastic generation for real_A0 / real_A1 ============
            for t in range(self.time_idx.int().item() + 1):
                # compute the increment delta and the inter/scale coefficients used for per-step interpolation
                if t > 0:
                    delta = times[t] - times[t - 1]
                    denom = times[-1] - times[t - 1]
                    inter = (delta / denom).reshape(-1, 1, 1, 1)
                    scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
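                # inter / scale implement a Brownian-bridge style step: the
                # state is pulled toward the previous prediction by delta/denom
                # while Gaussian noise of variance tau * delta * (1 - delta/denom)
                # is injected (see the Xt update below).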

                # stochastic noise update for Xt and Xt2
                Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
                time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
                z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
                self.time = times[time_idx]
                Xt_1 = self.netG(Xt, self.time, z)

                Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
                    (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
                time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
                z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
                Xt_12 = self.netG(Xt2, self.time, z)
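
                # Xt and Xt2 push the two consecutive frames (real_A0, real_A1)
                # through the same schedule, so the real_A_noisy / real_A_noisy2
                # saved below are aligned in diffusion time.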

            # save the denoised intermediate results (real_A_noisy etc.) for the assembly step below
            self.real_A_noisy = Xt.detach()
            self.real_A_noisy2 = Xt2.detach()
            # save the noisy map
            self.noisy_map = self.real_A_noisy - self.real_A0

        # ============ Step 3: assemble the inputs and run the networks ============
        bs = self.real_A0.size(0)
        z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)  # batch matches the single (non-concatenated) input below
        z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)

        # real_A and real_B would be concatenated here when nce_idt=True (with
        # real_A_noisy / XtB handled the same way); here only real_A0 is used
        self.real = self.real_A0
        self.realt = self.real_A_noisy

        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])
                self.realt = torch.flip(self.realt, [3])

        self.fake_B0 = self.netG(self.real_A0)
        self.fake_B1 = self.netG(self.real_A1)

        if self.opt.isTrain:
            real_A0 = self.real_A0
            real_A1 = self.real_A1
            real_B0 = self.real_B0
            real_B1 = self.real_B1
            fake_B0 = self.fake_B0
            fake_B1 = self.fake_B1
            self.real_A0_resize = self.resize(real_A0)
            self.real_A1_resize = self.resize(real_A1)
            real_B0 = self.resize(real_B0)
            real_B1 = self.resize(real_B1)
            self.fake_B0_resize = self.resize(fake_B0)
            self.fake_B1_resize = self.resize(fake_B1)
            # netPreViT is assumed to be a patched timm ViT whose forward
            # accepts the layer list and returns the tokens of those layers
            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
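
            # Resizing to 384 x 384 matches vit_base_patch16_384: 16 x 16
            # patches give a 24 x 24 grid, i.e. 576 patch tokens per image,
            # the token count assumed by tokens_concat() and the local
            # sampling in calculate_attention_loss().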

            if self.opt.phase == 'train':
                # gradient of the real image (inputs must require grad for autograd.grad)
                self.real_B0.requires_grad_(True)
                real_gradient = torch.autograd.grad(self.real_B0.sum(), self.real_B0, create_graph=True)[0]
                # gradient of the generated image
                fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
                # weight maps derived from the gradients
                self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)

                # CTN pseudo optical-flow map for the generated image
                self.f_content = self.ctn(self.weight_fake)

                # add the noisy map back onto the generated image
                self.fake_B0_2 = self.fake_B0 + self.noisy_map

                # warped image
                warped_fake_B0_2 = warp(self.fake_B0_2, self.f_content)

                # second pass through the second generator
                self.fake_B0_2 = self.netG_2(warped_fake_B0_2, self.time, z_in)
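
            # fake_B0_2 is the CTN branch: fake_B0 is re-noised, warped by the
            # pseudo flow f_content, and regenerated by netG_2; compute_G_loss
            # later compares it against warp(fake_B1, f_content) to enforce
            # temporal consistency between the two frames.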

    def tokens_concat(self, origin_tokens, adjacent_size):
        """Average-pool ViT tokens over adjacent_size x adjacent_size neighborhoods of the token grid."""
        adj_size = adjacent_size
        B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
        S = int(math.sqrt(token_num))
        if S * S != token_num:
            raise ValueError('tokens_concat expects a square token grid, got %d tokens' % token_num)
        token_map = origin_tokens.clone().reshape(B, S, S, C)
        cut_patch_list = []
        for i in range(0, S, adj_size):
            for j in range(0, S, adj_size):
                i_left = i
                i_right = min(i + adj_size, S)
                j_left = j
                j_right = min(j + adj_size, S)

                cut_patch = token_map[:, i_left:i_right, j_left:j_right, :]
                cut_patch = cut_patch.reshape(B, -1, C)
                cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
                cut_patch_list.append(cut_patch)

        result = torch.cat(cut_patch_list, dim=1)
        return result
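
    # For a 24 x 24 token grid (576 tokens) and adj_size_list = [2, 4, 6, 8, 12],
    # tokens_concat yields 144, 36, 16, 9 and 4 pooled tokens respectively, so
    # cat_results returns 576 + 144 + 36 + 16 + 9 + 4 = 785 tokens per image.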

    def cat_results(self, origin_tokens, adj_size_list):
        res_list = [origin_tokens]
        for ad_s in adj_size_list:
            cat_result = self.tokens_concat(origin_tokens, ad_s)
            res_list.append(cat_result)

        result = torch.cat(res_list, dim=1)
        return result

    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""
        lambda_D_ViT = self.opt.lambda_D_ViT
        fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
        fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()

        real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
        real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]

        fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
        fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
        real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
        real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)

        pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
        pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
        self.loss_D_fake_ViT = (self.criterionGAN(pred_fake0_ViT, False).mean() + self.criterionGAN(pred_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT

        pred_real0_ViT = self.netD_ViT(real_B0_tokens)
        pred_real1_ViT = self.netD_ViT(real_B1_tokens)
        self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT

        self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
        return self.loss_D_ViT
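
    # The token discriminator sees the original tokens plus the multi-scale
    # pooled tokens from cat_results, so real/fake decisions are made at
    # several receptive-field sizes at once.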

    def compute_E_loss(self):
        """Compute the loss for the discriminator E."""
        XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
        XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
        temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
        self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp ** 2
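        # Contrastive estimator: netE scores the matched pair (XtXt_1 against
        # itself) versus the mismatched pair (XtXt_1 against XtXt_2) through a
        # log-sum-exp; the temp ** 2 term keeps the energy outputs bounded.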

        return self.loss_E

    def compute_G_loss(self):
        if self.opt.lambda_GAN > 0.0:
            fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
            fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
            fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
            fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
            pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
            pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
            self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
        else:
            self.loss_G_GAN_ViT = 0.0

        self.loss_SB = 0
        if self.opt.lambda_SB > 0.0:
            XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
            XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)

            bs = self.opt.batch_size

            # eq. 9
            ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
            self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
            self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy2 - self.fake_B1) ** 2)
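
            # ET_XY estimates the entropy term of the Schrodinger-bridge style
            # objective, weighted by the remaining time and tau; the two MSE
            # terms are the transport cost keeping the translations close to
            # their noisy sources.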

        if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
            self.loss_global, self.loss_spatial = self.calculate_attention_loss()
        else:
            self.loss_global, self.loss_spatial = 0.0, 0.0

        if self.opt.lambda_ctn > 0.0:
            warped_fake_B1 = warp(self.fake_B1, self.f_content)  # use the updated self.f_content
            self.l2_loss = F.mse_loss(self.fake_B0_2, warped_fake_B1) * self.opt.lambda_ctn
        else:
            self.l2_loss = 0.0

        self.loss_G = self.loss_G_GAN_ViT + self.loss_SB + self.loss_global + self.loss_spatial + self.l2_loss  # total loss, including the SB and l2 terms
        return self.loss_G

    def calculate_attention_loss(self):
        n_layers = len(self.atten_layers)
        mutil_real_A0_tokens = self.mutil_real_A0_tokens
        mutil_real_A1_tokens = self.mutil_real_A1_tokens
        mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
        mutil_fake_B1_tokens = self.mutil_fake_B1_tokens

        if self.opt.lambda_global > 0.0:
            loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
            loss_global *= 0.5
        else:
            loss_global = 0.0

        if self.opt.lambda_spatial > 0.0:
            loss_spatial = 0.0
            local_nums = self.opt.local_nums
            tokens_cnt = 576  # 24 x 24 patch tokens for vit_base_patch16_384
            local_id = np.random.permutation(tokens_cnt)
            local_id = local_id[:int(min(local_nums, tokens_cnt))]

            mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
            mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)

            loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
            loss_spatial *= 0.5
        else:
            loss_spatial = 0.0

        return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
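
    # loss_global compares the full token sets of source and fake images;
    # loss_spatial re-extracts local_nums randomly chosen token positions
    # (shared across the four images), with side_length presumably setting
    # the size of the local window used by the patched ViT.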

    def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
        loss = 0.0
        n_layers = len(self.atten_layers)

        for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
            src_tgt = src_tokens.bmm(tgt_tokens.permute(0, 2, 1))
            tgt_src = tgt_tokens.bmm(src_tokens.permute(0, 2, 1))
            cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
            loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()

        loss = loss / n_layers
        return loss
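
    # src_tgt and tgt_src are the two cross-similarity matrices between source
    # and target tokens (transposes of each other); pushing their row-wise
    # cosine similarity toward 1 drives the cross-similarity matrix toward
    # symmetry, so each token relates to the other image the same way in both
    # directions.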