Compare commits
6 Commits
cpwithatte
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
133f609e79 | ||
|
|
26b770a3c1 | ||
|
|
9850183607 | ||
|
|
e67b0f2511 | ||
|
|
7af2de920c | ||
|
|
55b9db967a |
@ -68,3 +68,13 @@
|
||||
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
|
||||
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
|
||||
================ Training Loss (Sun Feb 23 23:14:59 2025) ================
|
||||
================ Training Loss (Mon Feb 24 21:53:50 2025) ================
|
||||
================ Training Loss (Mon Feb 24 21:54:16 2025) ================
|
||||
================ Training Loss (Mon Feb 24 21:54:50 2025) ================
|
||||
================ Training Loss (Mon Feb 24 21:55:31 2025) ================
|
||||
================ Training Loss (Mon Feb 24 21:56:10 2025) ================
|
||||
================ Training Loss (Mon Feb 24 22:09:38 2025) ================
|
||||
================ Training Loss (Mon Feb 24 22:10:16 2025) ================
|
||||
================ Training Loss (Mon Feb 24 22:12:46 2025) ================
|
||||
================ Training Loss (Mon Feb 24 22:13:04 2025) ================
|
||||
================ Training Loss (Mon Feb 24 22:14:04 2025) ================
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
----------------- Options ---------------
|
||||
atten_layers: 5
|
||||
adj_size_list: [2, 4, 6, 8, 12]
|
||||
atten_layers: 1,3,5
|
||||
batch_size: 1
|
||||
beta1: 0.5
|
||||
beta2: 0.999
|
||||
|
||||
Binary file not shown.
@ -79,26 +79,85 @@ class ContentAwareOptimization(nn.Module):
|
||||
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
|
||||
return cosine_sim
|
||||
|
||||
def generate_weight_map(self, gradients_fake):
|
||||
def generate_weight_map(self, gradients_fake, feature_shape):
|
||||
"""
|
||||
生成内容感知权重图
|
||||
生成内容感知权重图(修正空间维度)
|
||||
Args:
|
||||
gradients_fake: [B, N, D] 生成图像判别器梯度 [2,3,256,256]
|
||||
gradients_real: [B, N, D] 真实图像判别器梯度
|
||||
gradients_fake: [B, N, D] 生成图像判别器梯度
|
||||
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
|
||||
Returns:
|
||||
weight_fake: [B, N] 生成图像权重图 [2,3,256]
|
||||
weight_real: [B, 1, H, W] 真实图像权重图
|
||||
weight_fake: [B, 1, H, W] 生成图像权重图
|
||||
"""
|
||||
# 计算生成图像块的余弦相似度
|
||||
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
|
||||
H, W = feature_shape
|
||||
N = H * W
|
||||
|
||||
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
|
||||
# 计算余弦相似度(与原代码相同)
|
||||
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
||||
|
||||
# 生成权重图(与原代码相同)
|
||||
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||
|
||||
# 对生成图像生成权重图(同理)
|
||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||
weight_fake = torch.ones_like(cosine_fake)
|
||||
|
||||
for b in range(cosine_fake.shape[0]):
|
||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||
|
||||
# 重建空间维度 --------------------------------------------------
|
||||
# 将权重从[B, N]转换为[B, H, W]
|
||||
#print(f"Shape of weight_fake before view: {weight_fake.shape}")
|
||||
#print(f"Shape of cosine_fake: {cosine_fake.shape}")
|
||||
#print(f"H: {H}, W: {W}, N: {N}")
|
||||
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
||||
|
||||
return weight_fake
|
||||
|
||||
def compute_cosine_similarity_image(self, gradients):
|
||||
"""
|
||||
计算每个空间位置梯度与平均梯度的余弦相似度 (图像版本)
|
||||
Args:
|
||||
gradients: [B, C, H, W] 判别器输出的梯度
|
||||
Returns:
|
||||
cosine_sim: [B, H, W] 每个空间位置的余弦相似度
|
||||
"""
|
||||
# 将空间维度展平,以便计算所有空间位置的平均梯度
|
||||
B, C, H, W = gradients.shape
|
||||
gradients_reshaped = gradients.view(B, C, H * W) # [B, C, N] where N = H*W
|
||||
gradients_transposed = gradients_reshaped.transpose(1, 2) # [B, N, C] 将C放到最后一维,方便计算空间位置的平均梯度
|
||||
|
||||
mean_grad = torch.mean(gradients_transposed, dim=1, keepdim=True) # [B, 1, C] 在空间位置维度上求平均,得到平均梯度 [B, 1, C]
|
||||
# mean_grad 现在是所有空间位置的平均梯度,形状为 [B, 1, C]
|
||||
|
||||
# 为了计算余弦相似度,我们需要将 mean_grad 扩展到与 gradients_transposed 相同的空间维度
|
||||
mean_grad_expanded = mean_grad.expand(-1, H * W, -1) # [B, N, C]
|
||||
|
||||
# 计算余弦相似度,dim=2 表示在特征维度 (C) 上计算
|
||||
cosine_sim = F.cosine_similarity(gradients_transposed, mean_grad_expanded, dim=2) # [B, N]
|
||||
|
||||
# 将 cosine_sim 重新reshape回 [B, H, W]
|
||||
cosine_sim = cosine_sim.view(B, H, W)
|
||||
return cosine_sim
|
||||
|
||||
def generate_weight_map_image(self, gradients_fake, feature_shape):
|
||||
"""
|
||||
生成内容感知权重图(修正空间维度 - 图像版本)
|
||||
Args:
|
||||
gradients_fake: [B, C, H, W] 生成图像判别器梯度
|
||||
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
|
||||
Returns:
|
||||
weight_fake: [B, 1, H, W] 生成图像权重图
|
||||
"""
|
||||
H, W = feature_shape
|
||||
# 计算余弦相似度(图像版本)
|
||||
cosine_fake = self.compute_cosine_similarity_image(gradients_fake) # [B, H, W]
|
||||
# 生成权重图(与原代码相同,但现在cosine_fake是[B, H, W])
|
||||
k = int(self.eta_ratio * H * W) # k 仍然是基于总的空间位置数量计算
|
||||
_, fake_indices = torch.topk(-cosine_fake.view(cosine_fake.shape[0], -1), k, dim=1) # 将 cosine_fake 展平为 [B, N] 以使用 topk
|
||||
weight_fake = torch.ones_like(cosine_fake).view(cosine_fake.shape[0], -1) # 初始化权重图,并展平为 [B, N]
|
||||
for b in range(cosine_fake.shape[0]):
|
||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake.view(cosine_fake.shape[0], -1)[b, fake_indices[b]]))
|
||||
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # 重新 reshape 为 [B, H, W],并添加通道维度变为 [B, 1, H, W]
|
||||
return weight_fake
|
||||
|
||||
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
||||
@ -114,7 +173,7 @@ class ContentAwareOptimization(nn.Module):
|
||||
"""
|
||||
B, C, H, W = D_real.shape
|
||||
N = H * W
|
||||
|
||||
shape_hw = [H, W]
|
||||
# 注册钩子获取梯度
|
||||
gradients_real = []
|
||||
gradients_fake = []
|
||||
@ -137,11 +196,11 @@ class ContentAwareOptimization(nn.Module):
|
||||
total_loss.backward(retain_graph=True)
|
||||
|
||||
# 获取梯度数据
|
||||
gradients_real = gradients_real[0] # [B, N, D]
|
||||
gradients_fake = gradients_fake[0] # [B, N, D]
|
||||
gradients_real = gradients_real[1] # [B, N, D]
|
||||
gradients_fake = gradients_fake[1] # [B, N, D]
|
||||
|
||||
# 生成权重图
|
||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
|
||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
|
||||
|
||||
# 应用权重到对抗损失
|
||||
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
||||
@ -226,7 +285,7 @@ class RomaUnsbModel(BaseModel):
|
||||
|
||||
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
|
||||
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
|
||||
|
||||
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
|
||||
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
||||
|
||||
parser.set_defaults(pool_size=0) # no image pooling
|
||||
@ -364,7 +423,6 @@ class RomaUnsbModel(BaseModel):
|
||||
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
|
||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||
adj_size = adjacent_size
|
||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||
@ -385,7 +443,6 @@ class RomaUnsbModel(BaseModel):
|
||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||
cut_patch_list.append(cut_patch)
|
||||
|
||||
|
||||
result = torch.cat(cut_patch_list,dim=1)
|
||||
return result
|
||||
|
||||
@ -396,7 +453,6 @@ class RomaUnsbModel(BaseModel):
|
||||
res_list.append(cat_result)
|
||||
|
||||
result = torch.cat(res_list, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
def forward(self):
|
||||
@ -459,10 +515,8 @@ class RomaUnsbModel(BaseModel):
|
||||
self.real = torch.flip(self.real, [3])
|
||||
self.realt = torch.flip(self.realt, [3])
|
||||
|
||||
print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
|
||||
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
|
||||
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
||||
print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
|
||||
|
||||
if self.opt.phase == 'train':
|
||||
real_A0 = self.real_A0
|
||||
@ -487,29 +541,32 @@ class RomaUnsbModel(BaseModel):
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||
# [3,576,768]
|
||||
#self.mutil_real_A0_tokens = self.cat_results(self.mutil_real_A0_tokens[0], self.opt.adj_size_list)
|
||||
#print(f'self.mutil_real_A0_tokens[0]:{self.mutil_real_A0_tokens[0].shape}')
|
||||
|
||||
## 生成图像的梯度
|
||||
#fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||
#
|
||||
## 梯度图
|
||||
#self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||
#
|
||||
## 生成图像的CTN光流图
|
||||
#self.f_content = self.ctn(self.weight_fake)
|
||||
#
|
||||
## 变换后的图片
|
||||
#self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||
#self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||
#
|
||||
## 经过第二次生成器
|
||||
#self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||
shape_hw = list(self.real_A0_resize.shape[2:4])
|
||||
# 生成图像的梯度
|
||||
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||
|
||||
#warped_fake_B0_2=self.warped_fake_B0_2
|
||||
#warped_fake_B0=self.warped_fake_B0
|
||||
#self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||
#self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||
#self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
#self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||
# 梯度图
|
||||
self.weight_fake = self.cao.generate_weight_map_image(fake_gradient, shape_hw)
|
||||
|
||||
# 生成图像的CTN光流图
|
||||
self.f_content = self.ctn(self.weight_fake)
|
||||
|
||||
# 变换后的图片
|
||||
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||
|
||||
# 经过第二次生成器
|
||||
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||
|
||||
# warped_fake_B0_2=self.warped_fake_B0_2
|
||||
# warped_fake_B0=self.warped_fake_B0
|
||||
# self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||
# self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||
# self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
# self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||
|
||||
|
||||
def compute_D_loss(self): #判别器还是没有改
|
||||
@ -575,9 +632,9 @@ class RomaUnsbModel(BaseModel):
|
||||
loss_global = 0.0
|
||||
|
||||
self.l2_loss = 0.0
|
||||
#if self.opt.lambda_ctn > 0.0:
|
||||
# wapped_fake_B = warp(self.fake_B, self.f_content) # use updated self.f_content
|
||||
# self.l2_loss = F.mse_loss(self.fake_B_2, wapped_fake_B) # complete the loss calculation
|
||||
if self.opt.lambda_l2 > 0.0:
|
||||
wapped_fake_B = warp(self.fake_B0, self.f_content) # use updated self.f_content
|
||||
self.l2_loss = F.mse_loss(self.warped_fake_B0_2, wapped_fake_B) # complete the loss calculation
|
||||
|
||||
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
|
||||
return self.loss_G
|
||||
|
||||
301
roma.py
Normal file
301
roma.py
Normal file
@ -0,0 +1,301 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from .base_model import BaseModel
|
||||
from . import networks
|
||||
from .patchnce import PatchNCELoss
|
||||
import util.util as util
|
||||
import timm
|
||||
import time
|
||||
import torch.nn.functional as F
|
||||
import sys
|
||||
from functools import partial
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
from torchvision.transforms import transforms as tfs
|
||||
|
||||
class ROMAModel(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
""" Configures options specific for CUT model
|
||||
"""
|
||||
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
|
||||
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
|
||||
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
|
||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
|
||||
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
|
||||
parser.add_argument('--local_nums', type=int, default=256)
|
||||
parser.add_argument('--which_D_layer', type=int, default=-1)
|
||||
parser.add_argument('--side_length', type=int, default=7)
|
||||
|
||||
parser.set_defaults(pool_size=0)
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
|
||||
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
|
||||
self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
self.model_names = ['G', 'D_ViT']
|
||||
else: # during test time, only load G
|
||||
self.model_names = ['G']
|
||||
|
||||
|
||||
# define networks (both generator and discriminator)
|
||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||
|
||||
|
||||
if self.isTrain:
|
||||
|
||||
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
|
||||
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
|
||||
|
||||
|
||||
self.norm = F.softmax
|
||||
|
||||
self.resize = tfs.Resize(size=(384,384))
|
||||
|
||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||
self.criterionNCE = []
|
||||
|
||||
for atten_layer in self.atten_layers:
|
||||
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||
|
||||
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
|
||||
self.optimizers.append(self.optimizer_G)
|
||||
self.optimizers.append(self.optimizer_D_ViT)
|
||||
|
||||
def data_dependent_initialize(self, data):
|
||||
"""
|
||||
The feature network netF is defined in terms of the shape of the intermediate, extracted
|
||||
features of the encoder portion of netG. Because of this, the weights of netF are
|
||||
initialized at the first feedforward pass with some input images.
|
||||
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def optimize_parameters(self):
|
||||
# forward
|
||||
self.forward()
|
||||
|
||||
# update D
|
||||
self.set_requires_grad(self.netD_ViT, True)
|
||||
self.optimizer_D_ViT.zero_grad()
|
||||
self.loss_D = self.compute_D_loss()
|
||||
self.loss_D.backward()
|
||||
self.optimizer_D_ViT.step()
|
||||
|
||||
# update G
|
||||
self.set_requires_grad(self.netD_ViT, False)
|
||||
self.optimizer_G.zero_grad()
|
||||
self.loss_G = self.compute_G_loss()
|
||||
self.loss_G.backward()
|
||||
self.optimizer_G.step()
|
||||
|
||||
def set_input(self, input):
|
||||
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||
Parameters:
|
||||
input (dict): include the data itself and its metadata information.
|
||||
The option 'direction' can be used to swap domain A and domain B.
|
||||
"""
|
||||
AtoB = self.opt.direction == 'AtoB'
|
||||
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
|
||||
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
|
||||
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
|
||||
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
self.fake_B0 = self.netG(self.real_A0)
|
||||
self.fake_B1 = self.netG(self.real_A1)
|
||||
|
||||
if self.opt.isTrain:
|
||||
real_A0 = self.real_A0
|
||||
real_A1 = self.real_A1
|
||||
real_B0 = self.real_B0
|
||||
real_B1 = self.real_B1
|
||||
fake_B0 = self.fake_B0
|
||||
fake_B1 = self.fake_B1
|
||||
self.real_A0_resize = self.resize(real_A0)
|
||||
self.real_A1_resize = self.resize(real_A1)
|
||||
real_B0 = self.resize(real_B0)
|
||||
real_B1 = self.resize(real_B1)
|
||||
self.fake_B0_resize = self.resize(fake_B0)
|
||||
self.fake_B1_resize = self.resize(fake_B1)
|
||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||
|
||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||
adj_size = adjacent_size
|
||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||
S = int(math.sqrt(token_num))
|
||||
if S * S != token_num:
|
||||
print('Error! Not a square!')
|
||||
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||
cut_patch_list = []
|
||||
for i in range(0, S, adj_size):
|
||||
for j in range(0, S, adj_size):
|
||||
i_left = i
|
||||
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||
j_left = j
|
||||
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||
|
||||
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||
cut_patch= cut_patch.reshape(B,-1,C)
|
||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||
cut_patch_list.append(cut_patch)
|
||||
|
||||
|
||||
result = torch.cat(cut_patch_list,dim=1)
|
||||
return result
|
||||
|
||||
|
||||
def cat_results(self, origin_tokens, adj_size_list):
|
||||
res_list = [origin_tokens]
|
||||
for ad_s in adj_size_list:
|
||||
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||
res_list.append(cat_result)
|
||||
|
||||
result = torch.cat(res_list, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def compute_D_loss(self):
|
||||
"""Calculate GAN loss for the discriminator"""
|
||||
|
||||
|
||||
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
|
||||
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
|
||||
|
||||
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
|
||||
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
|
||||
|
||||
|
||||
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
|
||||
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
|
||||
|
||||
|
||||
|
||||
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
|
||||
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
|
||||
|
||||
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||
|
||||
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
|
||||
|
||||
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
|
||||
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
|
||||
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
|
||||
|
||||
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
|
||||
|
||||
|
||||
return self.loss_D_ViT
|
||||
|
||||
def compute_G_loss(self):
|
||||
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
|
||||
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
|
||||
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
|
||||
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
|
||||
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
|
||||
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
|
||||
else:
|
||||
self.loss_G_GAN_ViT = 0.0
|
||||
|
||||
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
|
||||
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
|
||||
else:
|
||||
self.loss_global, self.loss_spatial = 0.0, 0.0
|
||||
|
||||
if self.opt.lambda_motion > 0.0:
|
||||
self.loss_motion = 0.0
|
||||
for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
|
||||
A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
|
||||
B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
|
||||
cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
|
||||
self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||
else:
|
||||
self.loss_motion = 0.0
|
||||
|
||||
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
|
||||
return self.loss_G
|
||||
|
||||
def calculate_attention_loss(self):
|
||||
n_layers = len(self.atten_layers)
|
||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
|
||||
loss_global *= 0.5
|
||||
|
||||
else:
|
||||
loss_global = 0.0
|
||||
|
||||
if self.opt.lambda_spatial > 0.0:
|
||||
loss_spatial = 0.0
|
||||
local_nums = self.opt.local_nums
|
||||
tokens_cnt = 576
|
||||
local_id = np.random.permutation(tokens_cnt)
|
||||
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
||||
|
||||
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
|
||||
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||
|
||||
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
|
||||
loss_spatial *= 0.5
|
||||
|
||||
else:
|
||||
loss_spatial = 0.0
|
||||
|
||||
|
||||
|
||||
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
|
||||
|
||||
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
||||
loss = 0.0
|
||||
n_layers = len(self.atten_layers)
|
||||
|
||||
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
|
||||
|
||||
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
|
||||
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
|
||||
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
|
||||
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||
|
||||
loss = loss / n_layers
|
||||
return loss
|
||||
Loading…
x
Reference in New Issue
Block a user