cptrans复现
This commit is contained in:
parent
997fdd3770
commit
e0dc08030c
@ -227,7 +227,7 @@ class RomaUnsbModel(BaseModel):
|
||||
|
||||
# 指定需要打印的训练损失
|
||||
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
|
||||
self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
|
||||
self.visual_names = ['real_A0', 'fake_B0', 'real_B0','real_A1', 'fake_B1', 'real_B1']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
|
||||
@ -380,8 +380,8 @@ class RomaUnsbModel(BaseModel):
|
||||
"""计算生成器的 GAN 损失"""
|
||||
if self.opt.lambda_ctn > 0.0:
|
||||
# 生成图像的CTN光流图
|
||||
self.f_content0 = self.ctn(self.weight_fake0)
|
||||
self.f_content1 = self.ctn(self.weight_fake1)
|
||||
self.f_content0 = self.ctn(self.weight_fake0.detach())
|
||||
self.f_content1 = self.ctn(self.weight_fake1.detach())
|
||||
|
||||
# 变换后的图片
|
||||
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
|
||||
@ -429,8 +429,8 @@ class RomaUnsbModel(BaseModel):
|
||||
n_layers = len(self.atten_layers)
|
||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
|
||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
|
||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
|
||||
@ -17,7 +17,7 @@ python train.py \
|
||||
--lambda_global 6.0 \
|
||||
--gamma_stride 20 \
|
||||
--lr 0.000002 \
|
||||
--gpu_id 1 \
|
||||
--gpu_id 0 \
|
||||
--nce_idt False \
|
||||
--netF mlp_sample \
|
||||
--eta_ratio 0.4 \
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
python train.py \
|
||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||
--name cp_1 \
|
||||
--name cp_2 \
|
||||
--dataset_mode unaligned_double \
|
||||
--display_env CP \
|
||||
--model roma_unsb \
|
||||
@ -9,9 +9,10 @@ python train.py \
|
||||
--lambda_global 6.0 \
|
||||
--lambda_spatial 6.0 \
|
||||
--gamma_stride 20 \
|
||||
--lr 0.000001 \
|
||||
--gpu_id 2 \
|
||||
--lr 0.000002 \
|
||||
--gpu_id 0 \
|
||||
--eta_ratio 0.4 \
|
||||
--n_epochs 100 \
|
||||
--n_epochs_decay 100 \
|
||||
# cp1 复现cptrans的效果
|
||||
# cp1 复现cptrans的效果 --lr 0.000001
|
||||
# cp2 修了一下cp1的代码,--lr 0.000002
|
||||
Loading…
x
Reference in New Issue
Block a user