Compare commits

...

6 Commits
exp ... main

Author SHA1 Message Date
bishe
133f609e79 Add image optical flow 2025-02-24 22:49:38 +08:00
Kunyu_Lee
26b770a3c1 Merge branch 'main' of http://47.108.14.56:4000/123456/roma_unsb 2025-02-24 21:45:06 +08:00
Kunyu_Lee
9850183607 Original ROMA 2025-02-24 21:44:52 +08:00
bishe
e67b0f2511 Latest changes 2025-02-24 21:28:21 +08:00
bishe
7af2de920c renew 2025-02-24 21:13:36 +08:00
Kunyu_Lee
55b9db967a Latest changes 2025-02-24 20:39:59 +08:00
5 changed files with 417 additions and 48 deletions

View File

@@ -68,3 +68,13 @@
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
================ Training Loss (Sun Feb 23 23:14:59 2025) ================
================ Training Loss (Mon Feb 24 21:53:50 2025) ================
================ Training Loss (Mon Feb 24 21:54:16 2025) ================
================ Training Loss (Mon Feb 24 21:54:50 2025) ================
================ Training Loss (Mon Feb 24 21:55:31 2025) ================
================ Training Loss (Mon Feb 24 21:56:10 2025) ================
================ Training Loss (Mon Feb 24 22:09:38 2025) ================
================ Training Loss (Mon Feb 24 22:10:16 2025) ================
================ Training Loss (Mon Feb 24 22:12:46 2025) ================
================ Training Loss (Mon Feb 24 22:13:04 2025) ================
================ Training Loss (Mon Feb 24 22:14:04 2025) ================

View File

@@ -1,5 +1,6 @@
----------------- Options ---------------
atten_layers: 5
adj_size_list: [2, 4, 6, 8, 12]
atten_layers: 1,3,5
batch_size: 1
beta1: 0.5
beta2: 0.999
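The switch from a single attention layer to atten_layers: 1,3,5, together with the new adj_size_list, drives the model changes below. A minimal sketch of how these values are consumed (values copied from the options above; the split-and-int parsing matches ROMAModel.__init__ later on this page):

import argparse

atten_layers = '1,3,5'            # comma-separated ViT layer indices
adj_size_list = [2, 4, 6, 8, 12]  # multi-scale perception-field window sizes

layers = [int(i) for i in atten_layers.split(',')]
print(layers)  # [1, 3, 5]: cross-similarity is computed on these layers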

View File

@@ -79,26 +79,85 @@ class ContentAwareOptimization(nn.Module):
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
return cosine_sim
def generate_weight_map(self, gradients_fake):
def generate_weight_map(self, gradients_fake, feature_shape):
"""
Generate the content-aware weight map
Generate the content-aware weight map (corrected spatial dimensions)
Args:
gradients_fake: [B, N, D] discriminator gradients for the generated image [2,3,256,256]
gradients_real: [B, N, D] discriminator gradients for the real image
gradients_fake: [B, N, D] discriminator gradients for the generated image
feature_shape: tuple [H, W] spatial size of the discriminator feature map
Returns:
weight_fake: [B, N] weight map for the generated image [2,3,256]
weight_real: [B, 1, H, W] weight map for the real image
weight_fake: [B, 1, H, W] weight map for the generated image
"""
# cosine similarity of the generated image patches
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
H, W = feature_shape
N = H * W
# select content-rich regions (the eta_ratio fraction with the lowest cosine similarity)
# compute the cosine similarity (same as the original code)
cosine_fake = self.compute_cosine_similarity(gradients_fake)
# build the weight map (same as the original code)
k = int(self.eta_ratio * cosine_fake.shape[1])
# weight map for the generated image (same idea)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
# restore the spatial dimensions --------------------------------------------------
# reshape the weights from [B, N] to [B, H, W]
#print(f"Shape of weight_fake before view: {weight_fake.shape}")
#print(f"Shape of cosine_fake: {cosine_fake.shape}")
#print(f"H: {H}, W: {W}, N: {N}")
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
return weight_fake
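# Selection rule above: topk over -cosine_fake picks the k = eta_ratio * N patch
# positions whose gradients agree least with the mean gradient (the most
# content-rich patches); those positions are boosted to lambda_inc / (1e-6 + |cos|),
# while every other position keeps weight 1.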
def compute_cosine_similarity_image(self, gradients):
"""
计算每个空间位置梯度与平均梯度的余弦相似度 (图像版本)
Args:
gradients: [B, C, H, W] 判别器输出的梯度
Returns:
cosine_sim: [B, H, W] 每个空间位置的余弦相似度
"""
# 将空间维度展平,以便计算所有空间位置的平均梯度
B, C, H, W = gradients.shape
gradients_reshaped = gradients.view(B, C, H * W) # [B, C, N] where N = H*W
gradients_transposed = gradients_reshaped.transpose(1, 2) # [B, N, C] 将C放到最后一维方便计算空间位置的平均梯度
mean_grad = torch.mean(gradients_transposed, dim=1, keepdim=True) # [B, 1, C] 在空间位置维度上求平均,得到平均梯度 [B, 1, C]
# mean_grad 现在是所有空间位置的平均梯度,形状为 [B, 1, C]
# 为了计算余弦相似度,我们需要将 mean_grad 扩展到与 gradients_transposed 相同的空间维度
mean_grad_expanded = mean_grad.expand(-1, H * W, -1) # [B, N, C]
# 计算余弦相似度dim=2 表示在特征维度 (C) 上计算
cosine_sim = F.cosine_similarity(gradients_transposed, mean_grad_expanded, dim=2) # [B, N]
# 将 cosine_sim 重新reshape回 [B, H, W]
cosine_sim = cosine_sim.view(B, H, W)
return cosine_sim
def generate_weight_map_image(self, gradients_fake, feature_shape):
"""
生成内容感知权重图修正空间维度 - 图像版本
Args:
gradients_fake: [B, C, H, W] 生成图像判别器梯度
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
Returns:
weight_fake: [B, 1, H, W] 生成图像权重图
"""
H, W = feature_shape
# 计算余弦相似度(图像版本)
cosine_fake = self.compute_cosine_similarity_image(gradients_fake) # [B, H, W]
# 生成权重图与原代码相同但现在cosine_fake是[B, H, W]
k = int(self.eta_ratio * H * W) # k 仍然是基于总的空间位置数量计算
_, fake_indices = torch.topk(-cosine_fake.view(cosine_fake.shape[0], -1), k, dim=1) # 将 cosine_fake 展平为 [B, N] 以使用 topk
weight_fake = torch.ones_like(cosine_fake).view(cosine_fake.shape[0], -1) # 初始化权重图,并展平为 [B, N]
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake.view(cosine_fake.shape[0], -1)[b, fake_indices[b]]))
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # 重新 reshape 为 [B, H, W],并添加通道维度变为 [B, 1, H, W]
return weight_fake
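# Shape flow for the image variant: gradients_fake [B, C, H, W] -> cosine map
# [B, H, W] -> flattened to [B, H*W] for topk -> weight map reshaped back to
# [B, 1, H, W], ready to be used as a per-pixel mask (e.g. by the CTN in forward).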
def forward(self, D_real, D_fake, real_scores, fake_scores):
@@ -114,7 +173,7 @@ class ContentAwareOptimization(nn.Module):
"""
B, C, H, W = D_real.shape
N = H * W
shape_hw = [H, W]
# register hooks to capture the gradients
gradients_real = []
gradients_fake = []
@@ -137,11 +196,11 @@ class ContentAwareOptimization(nn.Module):
total_loss.backward(retain_graph=True)
# fetch the captured gradients
gradients_real = gradients_real[0] # [B, N, D]
gradients_fake = gradients_fake[0] # [B, N, D]
gradients_real = gradients_real[1] # [B, N, D]
gradients_fake = gradients_fake[1] # [B, N, D]
# build the weight map
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw)
# generate_weight_map now returns only the fake-branch map, so keep the real
# branch at uniform weight to keep the loss below well-defined
self.weight_real = torch.ones_like(real_scores)
# apply the weights to the adversarial loss
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
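For context, the gradient capture above uses the standard PyTorch tensor-hook pattern; a self-contained sketch (names here are illustrative, not from the repository):

import torch

x = torch.randn(2, 3, requires_grad=True)
feat = x * 2                                   # stand-in for a discriminator output
grads = []
feat.register_hook(lambda g: grads.append(g))  # fires during backward()
feat.sum().backward(retain_graph=True)         # retain_graph as in the code above
print(grads[0].shape)                          # torch.Size([2, 3])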
@@ -226,7 +285,7 @@ class RomaUnsbModel(BaseModel):
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
parser.set_defaults(pool_size=0) # no image pooling
@@ -364,7 +423,6 @@ class RomaUnsbModel(BaseModel):
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
@@ -385,7 +443,6 @@ class RomaUnsbModel(BaseModel):
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
@@ -396,7 +453,6 @@ class RomaUnsbModel(BaseModel):
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def forward(self):
@@ -459,10 +515,8 @@ class RomaUnsbModel(BaseModel):
self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3])
print(f'real_A0: {self.real_A0.shape}, real_A1: {self.real_A1.shape}')
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
if self.opt.phase == 'train':
real_A0 = self.real_A0
@@ -487,22 +541,25 @@ class RomaUnsbModel(BaseModel):
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
# [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768]
#self.mutil_real_A0_tokens = self.cat_results(self.mutil_real_A0_tokens[0], self.opt.adj_size_list)
#print(f'self.mutil_real_A0_tokens[0]:{self.mutil_real_A0_tokens[0].shape}')
## gradient of the generated image
#fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
#
## weight map from the gradients
#self.weight_fake = self.cao.generate_weight_map(fake_gradient)
#
## CTN optical-flow field for the generated image
#self.f_content = self.ctn(self.weight_fake)
#
## warped images
#self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
#self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
#
## second pass through the generator
#self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
shape_hw = list(self.real_A0_resize.shape[2:4])
# gradient of the generated image
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
# weight map from the gradients
self.weight_fake = self.cao.generate_weight_map_image(fake_gradient, shape_hw)
# CTN optical-flow field for the generated image
self.f_content = self.ctn(self.weight_fake)
# warped images
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
# second pass through the generator
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
# warped_fake_B0_2=self.warped_fake_B0_2
# warped_fake_B0=self.warped_fake_B0
@@ -575,9 +632,9 @@ class RomaUnsbModel(BaseModel):
loss_global = 0.0
self.l2_loss = 0.0
#if self.opt.lambda_ctn > 0.0:
# warped_fake_B = warp(self.fake_B, self.f_content) # use the updated self.f_content
# self.l2_loss = F.mse_loss(self.fake_B_2, warped_fake_B) # complete the loss calculation
if self.opt.lambda_l2 > 0.0:
warped_fake_B = warp(self.fake_B0, self.f_content) # use the updated self.f_content
self.l2_loss = F.mse_loss(self.warped_fake_B0_2, warped_fake_B) # consistency between the two generator passes
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
return self.loss_G
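The consistency term above relies on a flow-warping helper warp whose definition is not part of this diff. A minimal bilinear sketch, assuming the flow is given as pixel offsets in a [B, 2, H, W] tensor (both assumptions):

import torch
import torch.nn.functional as F

def warp(img, flow):
    # Sample img [B, C, H, W] at positions shifted by flow [B, 2, H, W].
    B, C, H, W = img.shape
    ys, xs = torch.meshgrid(torch.arange(H, device=img.device),
                            torch.arange(W, device=img.device), indexing='ij')
    base = torch.stack((xs, ys), dim=0).float()   # [2, H, W], (x, y) order
    coords = base.unsqueeze(0) + flow             # [B, 2, H, W]
    gx = 2.0 * coords[:, 0] / (W - 1) - 1.0       # normalize to [-1, 1] for grid_sample
    gy = 2.0 * coords[:, 1] / (H - 1) - 1.0
    grid = torch.stack((gx, gy), dim=-1)          # [B, H, W, 2]
    return F.grid_sample(img, grid, align_corners=True)

With such a helper, the l2 term checks that warping after generating (warp(fake_B0, f_content)) agrees with generating after warping (warped_fake_B0_2).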

roma.py Normal file (301 additions)
View File

@@ -0,0 +1,301 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util
import timm
import time
import torch.nn.functional as F
import sys
from functools import partial
import torch.nn as nn
import math
from torchvision.transforms import transforms as tfs
class ROMAModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
""" Configures options specific for CUT model
"""
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
parser.add_argument('--local_nums', type=int, default=256)
parser.add_argument('--which_D_layer', type=int, default=-1)
parser.add_argument('--side_length', type=int, default=7)
parser.set_defaults(pool_size=0)
opt, _ = parser.parse_known_args()
return parser
def __init__(self, opt):
BaseModel.__init__(self, opt)
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
if self.isTrain:
self.model_names = ['G', 'D_ViT']
else: # during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
if self.isTrain:
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
self.norm = F.softmax
self.resize = tfs.Resize(size=(384,384))
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
self.criterionNCE = []
for atten_layer in self.atten_layers:
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
self.criterionL1 = torch.nn.L1Loss().to(self.device)
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D_ViT)
def data_dependent_initialize(self, data):
"""
The feature network netF is defined in terms of the shape of the intermediate, extracted
features of the encoder portion of netG. Because of this, the weights of netF are
initialized at the first feedforward pass with some input images.
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
"""
pass
def optimize_parameters(self):
# forward
self.forward()
# update D
self.set_requires_grad(self.netD_ViT, True)
self.optimizer_D_ViT.zero_grad()
self.loss_D = self.compute_D_loss()
self.loss_D.backward()
self.optimizer_D_ViT.step()
# update G
self.set_requires_grad(self.netD_ViT, False)
self.optimizer_G.zero_grad()
self.loss_G = self.compute_G_loss()
self.loss_G.backward()
self.optimizer_G.step()
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B0 = self.netG(self.real_A0)
self.fake_B1 = self.netG(self.real_A1)
if self.opt.isTrain:
real_A0 = self.real_A0
real_A1 = self.real_A1
real_B0 = self.real_B0
real_B1 = self.real_B1
fake_B0 = self.fake_B0
fake_B1 = self.fake_B1
self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1)
real_B0 = self.resize(real_B0)
real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size if i + adj_size <= S else S
j_left = j
j_right = j + adj_size if j + adj_size <= S else S
cut_patch = token_map[:, i_left:i_right, j_left:j_right, :]
cut_patch = cut_patch.reshape(B, -1, C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list, dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
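# Token-count check for cat_results with the defaults above: ViT-base/16 at
# 384x384 yields 576 tokens, i.e. a 24x24 grid. Pooling that grid with window
# sizes adj_size_list = [2, 4, 6, 8, 12] produces 144, 36, 16, 9 and 4 averaged
# tokens respectively, so cat_results returns
# [B, 576 + 144 + 36 + 16 + 9 + 4, C] = [B, 785, C].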
def compute_D_loss(self):
"""Calculate GAN loss for the discriminator"""
lambda_D_ViT = self.opt.lambda_D_ViT
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_D_fake_ViT = (self.criterionGAN(pred_fake0_ViT, False).mean() + self.criterionGAN(pred_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
return self.loss_D_ViT
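# Note: the fake-token branches are detach()ed above so this discriminator update
# does not backpropagate into the generator; the per-frame losses (frames 0 and 1)
# are averaged, then the real and fake halves are averaged again.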
def compute_G_loss(self):
if self.opt.lambda_GAN > 0.0:
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
else:
self.loss_G_GAN_ViT = 0.0
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
else:
self.loss_global, self.loss_spatial = 0.0, 0.0
if self.opt.lambda_motion > 0.0:
self.loss_motion = 0.0
for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
else:
self.loss_motion = 0.0
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
return self.loss_G
def calculate_attention_loss(self):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
loss_global *= 0.5
else:
loss_global = 0.0
if self.opt.lambda_spatial > 0.0:
loss_spatial = 0.0
local_nums = self.opt.local_nums
tokens_cnt = 576
local_id = np.random.permutation(tokens_cnt)
local_id = local_id[:int(min(local_nums, tokens_cnt))]
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
loss_spatial *= 0.5
else:
loss_spatial = 0.0
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
loss = 0.0
n_layers = len(self.atten_layers)
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
loss = loss / n_layers
return loss
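To make the motion term concrete, here is a toy check of the cross-frame consistency it measures, using dummy tensors with the 576-token, 768-dim ViT shapes from this file:

import torch
import torch.nn.functional as F

B, N, C = 1, 576, 768
real_A0, real_A1 = torch.randn(B, N, C), torch.randn(B, N, C)  # input frames t, t+1
fake_B0, fake_B1 = torch.randn(B, N, C), torch.randn(B, N, C)  # translated frames

# Cross-frame affinities: input frame t against translated frame t+1, and vice versa.
A0_B1 = real_A0.bmm(fake_B1.permute(0, 2, 1))  # [B, N, N]
B0_A1 = fake_B0.bmm(real_A1.permute(0, 2, 1))  # [B, N, N]

# The loss pushes the two affinity structures to agree (cosine similarity -> 1).
cos = F.cosine_similarity(A0_B1, B0_A1, dim=-1)      # [B, N]
loss_motion = F.l1_loss(torch.ones_like(cos), cos)
print(loss_motion.item())  # near 1.0 for random tokens, 0.0 when affinities match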