cptrans复现
This commit is contained in:
parent
997fdd3770
commit
e0dc08030c
@ -227,7 +227,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
# 指定需要打印的训练损失
|
# 指定需要打印的训练损失
|
||||||
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
|
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
|
||||||
self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
|
self.visual_names = ['real_A0', 'fake_B0', 'real_B0','real_A1', 'fake_B1', 'real_B1']
|
||||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||||
|
|
||||||
|
|
||||||
@ -380,8 +380,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
"""计算生成器的 GAN 损失"""
|
"""计算生成器的 GAN 损失"""
|
||||||
if self.opt.lambda_ctn > 0.0:
|
if self.opt.lambda_ctn > 0.0:
|
||||||
# 生成图像的CTN光流图
|
# 生成图像的CTN光流图
|
||||||
self.f_content0 = self.ctn(self.weight_fake0)
|
self.f_content0 = self.ctn(self.weight_fake0.detach())
|
||||||
self.f_content1 = self.ctn(self.weight_fake1)
|
self.f_content1 = self.ctn(self.weight_fake1.detach())
|
||||||
|
|
||||||
# 变换后的图片
|
# 变换后的图片
|
||||||
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
|
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
|
||||||
@ -429,8 +429,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
n_layers = len(self.atten_layers)
|
n_layers = len(self.atten_layers)
|
||||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||||
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||||
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
|
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||||
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
|
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||||
|
|
||||||
|
|
||||||
if self.opt.lambda_global > 0.0:
|
if self.opt.lambda_global > 0.0:
|
||||||
|
|||||||
@ -17,7 +17,7 @@ python train.py \
|
|||||||
--lambda_global 6.0 \
|
--lambda_global 6.0 \
|
||||||
--gamma_stride 20 \
|
--gamma_stride 20 \
|
||||||
--lr 0.000002 \
|
--lr 0.000002 \
|
||||||
--gpu_id 1 \
|
--gpu_id 0 \
|
||||||
--nce_idt False \
|
--nce_idt False \
|
||||||
--netF mlp_sample \
|
--netF mlp_sample \
|
||||||
--eta_ratio 0.4 \
|
--eta_ratio 0.4 \
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
python train.py \
|
python train.py \
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||||
--name cp_1 \
|
--name cp_2 \
|
||||||
--dataset_mode unaligned_double \
|
--dataset_mode unaligned_double \
|
||||||
--display_env CP \
|
--display_env CP \
|
||||||
--model roma_unsb \
|
--model roma_unsb \
|
||||||
@ -9,9 +9,10 @@ python train.py \
|
|||||||
--lambda_global 6.0 \
|
--lambda_global 6.0 \
|
||||||
--lambda_spatial 6.0 \
|
--lambda_spatial 6.0 \
|
||||||
--gamma_stride 20 \
|
--gamma_stride 20 \
|
||||||
--lr 0.000001 \
|
--lr 0.000002 \
|
||||||
--gpu_id 2 \
|
--gpu_id 0 \
|
||||||
--eta_ratio 0.4 \
|
--eta_ratio 0.4 \
|
||||||
--n_epochs 100 \
|
--n_epochs 100 \
|
||||||
--n_epochs_decay 100 \
|
--n_epochs_decay 100 \
|
||||||
# cp1 复现cptrans的效果
|
# cp1 复现cptrans的效果 --lr 0.000001
|
||||||
|
# cp2 修了一下cp1的代码,--lr 0.000002
|
||||||
Loading…
x
Reference in New Issue
Block a user