修改后的最新

This commit is contained in:
bishe 2025-02-27 18:00:41 +08:00
parent 7a6e856b4b
commit 2a0a56ac26
2 changed files with 13 additions and 47 deletions

View File

@ -119,7 +119,6 @@ class ContentAwareOptimization(nn.Module):
# 清空梯度缓存 # 清空梯度缓存
self.gradients_real.clear() self.gradients_real.clear()
self.gradients_fake.clear() self.gradients_fake.clear()
# 注册钩子 # 注册钩子
hook_real = lambda grad: self.gradients_real.append(grad.detach()) hook_real = lambda grad: self.gradients_real.append(grad.detach())
hook_fake = lambda grad: self.gradients_fake.append(grad.detach()) hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
@ -138,10 +137,10 @@ class ContentAwareOptimization(nn.Module):
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake) weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
# 计算加权损失 # 计算加权损失
loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean() loss_co_real = (weight_real * real_scores).mean()
loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean() loss_co_fake = (weight_fake * fake_scores).mean()
return -(loss_co_real + loss_co_fake), weight_real, weight_fake return (loss_co_real + loss_co_fake), weight_real, weight_fake
class ContentAwareTemporalNorm(nn.Module): class ContentAwareTemporalNorm(nn.Module):
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0): def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
@ -252,7 +251,7 @@ class RomaUnsbModel(BaseModel):
BaseModel.__init__(self, opt) BaseModel.__init__(self, opt)
# 指定需要打印的训练损失 # 指定需要打印的训练损失
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn'] self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',]
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0'] self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
@ -368,40 +367,6 @@ class RomaUnsbModel(BaseModel):
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
def tokens_concat(self, origin_tokens, adjacent_size):
adj_size = adjacent_size
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
S = int(math.sqrt(token_num))
if S * S != token_num:
print('Error! Not a square!')
token_map = origin_tokens.clone().reshape(B,S,S,C)
cut_patch_list = []
for i in range(0, S, adj_size):
for j in range(0, S, adj_size):
i_left = i
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
j_left = j
j_right = j + adj_size if j + adj_size <= S else S + 1
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
cut_patch= cut_patch.reshape(B,-1,C)
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
cut_patch_list.append(cut_patch)
result = torch.cat(cut_patch_list,dim=1)
return result
def cat_results(self, origin_tokens, adj_size_list):
res_list = [origin_tokens]
for ad_s in adj_size_list:
cat_result = self.tokens_concat(origin_tokens, ad_s)
res_list.append(cat_result)
result = torch.cat(res_list, dim=1)
return result
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
@ -462,8 +427,8 @@ class RomaUnsbModel(BaseModel):
self.real = torch.flip(self.real, [3]) self.real = torch.flip(self.real, [3])
self.realt = torch.flip(self.realt, [3]) self.realt = torch.flip(self.realt, [3])
self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in) self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in)
self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2) self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2)
if self.opt.phase == 'train': if self.opt.phase == 'train':
real_A0 = self.real_A0 real_A0 = self.real_A0
@ -507,8 +472,8 @@ class RomaUnsbModel(BaseModel):
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True) self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT) self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
self.loss_D_ViT = self.losscao* lambda_D_ViT
return self.losscao* lambda_D_ViT return self.loss_D_ViT
def compute_E_loss(self): def compute_E_loss(self):
"""计算判别器 E 的损失""" """计算判别器 E 的损失"""

View File

@ -7,7 +7,7 @@
python train.py \ python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name UNIV_2 \ --name UNIV_1 \
--dataset_mode unaligned_double \ --dataset_mode unaligned_double \
--no_flip \ --no_flip \
--display_env UNIV \ --display_env UNIV \
@ -17,13 +17,14 @@ python train.py \
--lambda_ctn 1.0 \ --lambda_ctn 1.0 \
--lambda_inc 1.0 \ --lambda_inc 1.0 \
--lr 0.00001 \ --lr 0.00001 \
--gpu_id 1 \ --gpu_id 0 \
--lambda_D_ViT 1 \
--nce_idt False \ --nce_idt False \
--netF mlp_sample \ --netF mlp_sample \
--flip_equivariance True \ --flip_equivariance True \
--eta_ratio 0.4 \ --eta_ratio 0.4 \
--tau 0.01 \ --tau 0.01 \
--num_timesteps 5 \ --num_timesteps 4 \
--input_nc 3 \ --input_nc 3 \
--n_epochs 400 \ --n_epochs 400 \
--n_epochs_decay 200 \ --n_epochs_decay 200 \