cptrans复现

This commit is contained in:
bishe 2025-03-09 21:41:52 +08:00
parent 997fdd3770
commit e0dc08030c
3 changed files with 11 additions and 10 deletions

View File

@ -227,7 +227,7 @@ class RomaUnsbModel(BaseModel):
# 指定需要打印的训练损失
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
self.visual_names = ['real_A0', 'fake_B0_1','fake_B0', 'real_B0','real_A1', 'fake_B1_1', 'fake_B1', 'real_B1']
self.visual_names = ['real_A0', 'fake_B0', 'real_B0','real_A1', 'fake_B1', 'real_B1']
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
@ -380,8 +380,8 @@ class RomaUnsbModel(BaseModel):
"""计算生成器的 GAN 损失"""
if self.opt.lambda_ctn > 0.0:
# 生成图像的CTN光流图
self.f_content0 = self.ctn(self.weight_fake0)
self.f_content1 = self.ctn(self.weight_fake1)
self.f_content0 = self.ctn(self.weight_fake0.detach())
self.f_content1 = self.ctn(self.weight_fake1.detach())
# 变换后的图片
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
@ -429,8 +429,8 @@ class RomaUnsbModel(BaseModel):
n_layers = len(self.atten_layers)
mutil_real_A0_tokens = self.mutil_real_A0_tokens
mutil_real_A1_tokens = self.mutil_real_A1_tokens
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens_list[-1]
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens_list[-1]
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
if self.opt.lambda_global > 0.0:

View File

@ -17,7 +17,7 @@ python train.py \
--lambda_global 6.0 \
--gamma_stride 20 \
--lr 0.000002 \
--gpu_id 1 \
--gpu_id 0 \
--nce_idt False \
--netF mlp_sample \
--eta_ratio 0.4 \

View File

@ -1,6 +1,6 @@
python train.py \
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
--name cp_1 \
--name cp_2 \
--dataset_mode unaligned_double \
--display_env CP \
--model roma_unsb \
@ -9,9 +9,10 @@ python train.py \
--lambda_global 6.0 \
--lambda_spatial 6.0 \
--gamma_stride 20 \
--lr 0.000001 \
--gpu_id 2 \
--lr 0.000002 \
--gpu_id 0 \
--eta_ratio 0.4 \
--n_epochs 100 \
--n_epochs_decay 100 \
# cp1 复现cptrans的效果
# cp1 复现cptrans的效果 --lr 0.000001
# cp2 修了一下cp1的代码--lr 0.000002