修改后的最新
This commit is contained in:
parent
7a6e856b4b
commit
2a0a56ac26
@ -118,8 +118,7 @@ class ContentAwareOptimization(nn.Module):
|
||||
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
||||
# 清空梯度缓存
|
||||
self.gradients_real.clear()
|
||||
self.gradients_fake.clear()
|
||||
|
||||
self.gradients_fake.clear()
|
||||
# 注册钩子
|
||||
hook_real = lambda grad: self.gradients_real.append(grad.detach())
|
||||
hook_fake = lambda grad: self.gradients_fake.append(grad.detach())
|
||||
@ -138,10 +137,10 @@ class ContentAwareOptimization(nn.Module):
|
||||
weight_real, weight_fake = self.generate_weight_map(grad_real, grad_fake)
|
||||
|
||||
# 计算加权损失
|
||||
loss_co_real = (weight_real * torch.log(real_scores + 1e-8)).mean()
|
||||
loss_co_fake = (weight_fake * torch.log(1 - fake_scores + 1e-8)).mean()
|
||||
loss_co_real = (weight_real * real_scores).mean()
|
||||
loss_co_fake = (weight_fake * fake_scores).mean()
|
||||
|
||||
return -(loss_co_real + loss_co_fake), weight_real, weight_fake
|
||||
return (loss_co_real + loss_co_fake), weight_real, weight_fake
|
||||
|
||||
class ContentAwareTemporalNorm(nn.Module):
|
||||
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
||||
@ -252,7 +251,7 @@ class RomaUnsbModel(BaseModel):
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
# 指定需要打印的训练损失
|
||||
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
|
||||
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn',]
|
||||
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
@ -368,40 +367,6 @@ class RomaUnsbModel(BaseModel):
|
||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||
|
||||
|
||||
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||
adj_size = adjacent_size
|
||||
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||
S = int(math.sqrt(token_num))
|
||||
if S * S != token_num:
|
||||
print('Error! Not a square!')
|
||||
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||
cut_patch_list = []
|
||||
for i in range(0, S, adj_size):
|
||||
for j in range(0, S, adj_size):
|
||||
i_left = i
|
||||
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||
j_left = j
|
||||
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||
|
||||
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||
cut_patch= cut_patch.reshape(B,-1,C)
|
||||
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||
cut_patch_list.append(cut_patch)
|
||||
|
||||
|
||||
result = torch.cat(cut_patch_list,dim=1)
|
||||
return result
|
||||
|
||||
def cat_results(self, origin_tokens, adj_size_list):
|
||||
res_list = [origin_tokens]
|
||||
for ad_s in adj_size_list:
|
||||
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||
res_list.append(cat_result)
|
||||
|
||||
result = torch.cat(res_list, dim=1)
|
||||
|
||||
return result
|
||||
|
||||
def forward(self):
|
||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||
|
||||
@ -462,8 +427,8 @@ class RomaUnsbModel(BaseModel):
|
||||
self.real = torch.flip(self.real, [3])
|
||||
self.realt = torch.flip(self.realt, [3])
|
||||
|
||||
self.fake_B0 = self.netG(self.real_A0, self.time, self.z_in)
|
||||
self.fake_B1 = self.netG(self.real_A1, self.time, self.z_in2)
|
||||
self.fake_B0 = self.netG(self.real_A_noisy, self.time, self.z_in)
|
||||
self.fake_B1 = self.netG(self.real_A_noisy2, self.time, self.z_in2)
|
||||
|
||||
if self.opt.phase == 'train':
|
||||
real_A0 = self.real_A0
|
||||
@ -507,8 +472,8 @@ class RomaUnsbModel(BaseModel):
|
||||
self.loss_D_real_ViT = self.criterionGAN(pred_real0_ViT, True)
|
||||
|
||||
self.losscao, self.weight_real, self.weight_fake = self.cao(pred_real0_ViT, pre_fake0_ViT, self.loss_D_real_ViT, self.loss_D_fake_ViT)
|
||||
|
||||
return self.losscao* lambda_D_ViT
|
||||
self.loss_D_ViT = self.losscao* lambda_D_ViT
|
||||
return self.loss_D_ViT
|
||||
|
||||
def compute_E_loss(self):
|
||||
"""计算判别器 E 的损失"""
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
python train.py \
|
||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||
--name UNIV_2 \
|
||||
--name UNIV_1 \
|
||||
--dataset_mode unaligned_double \
|
||||
--no_flip \
|
||||
--display_env UNIV \
|
||||
@ -17,13 +17,14 @@ python train.py \
|
||||
--lambda_ctn 1.0 \
|
||||
--lambda_inc 1.0 \
|
||||
--lr 0.00001 \
|
||||
--gpu_id 1 \
|
||||
--gpu_id 0 \
|
||||
--lambda_D_ViT 1 \
|
||||
--nce_idt False \
|
||||
--netF mlp_sample \
|
||||
--flip_equivariance True \
|
||||
--eta_ratio 0.4 \
|
||||
--tau 0.01 \
|
||||
--num_timesteps 5 \
|
||||
--num_timesteps 4 \
|
||||
--input_nc 3 \
|
||||
--n_epochs 400 \
|
||||
--n_epochs_decay 200 \
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user