running UNIV
This commit is contained in:
parent
e8e483fbf8
commit
7a6e856b4b
Binary file not shown.
@ -216,11 +216,11 @@ class RomaUnsbModel(BaseModel):
|
||||
"""配置 CTNx 模型的特定选项"""
|
||||
|
||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
||||
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
||||
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||
|
||||
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
||||
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
||||
@ -234,7 +234,6 @@ class RomaUnsbModel(BaseModel):
|
||||
type=util.str2bool, nargs='?', const=True, default=False,
|
||||
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
||||
|
||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
||||
|
||||
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
||||
@ -244,13 +243,8 @@ class RomaUnsbModel(BaseModel):
|
||||
|
||||
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
||||
|
||||
parser.set_defaults(pool_size=0) # no image pooling
|
||||
|
||||
opt, _ = parser.parse_known_args()
|
||||
|
||||
# 直接设置为 sb 模式
|
||||
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
|
||||
|
||||
return parser
|
||||
|
||||
def __init__(self, opt):
|
||||
@ -258,7 +252,7 @@ class RomaUnsbModel(BaseModel):
|
||||
BaseModel.__init__(self, opt)
|
||||
|
||||
# 指定需要打印的训练损失
|
||||
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB']
|
||||
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
|
||||
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
|
||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||
|
||||
@ -542,7 +536,7 @@ class RomaUnsbModel(BaseModel):
|
||||
warped_fake_B0_2=self.warped_fake_B0_2
|
||||
warped_fake_B0=self.warped_fake_B0
|
||||
# 计算L2损失
|
||||
self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
|
||||
self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
|
||||
|
||||
if self.opt.lambda_GAN > 0.0:
|
||||
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
||||
@ -563,15 +557,15 @@ class RomaUnsbModel(BaseModel):
|
||||
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
|
||||
|
||||
if self.opt.lambda_global > 0.0:
|
||||
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||
loss_global *= 0.5
|
||||
self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||
self.loss_global *= 0.5
|
||||
else:
|
||||
loss_global = 0.0
|
||||
self.loss_global = 0.0
|
||||
|
||||
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
|
||||
self.opt.lambda_SB * self.loss_SB + \
|
||||
self.opt.lambda_ctn * self.ctn_loss + \
|
||||
loss_global * self.opt.lambda_global
|
||||
self.opt.lambda_ctn * self.loss_ctn + \
|
||||
self.loss_global * self.opt.lambda_global
|
||||
return self.loss_G
|
||||
|
||||
def calculate_attention_loss(self):
|
||||
|
||||
Binary file not shown.
@ -7,22 +7,26 @@
|
||||
|
||||
python train.py \
|
||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||
--name ROMA_UNSB_003 \
|
||||
--name UNIV_2 \
|
||||
--dataset_mode unaligned_double \
|
||||
--no_flip \
|
||||
--display_env ROMA \
|
||||
--display_env UNIV \
|
||||
--model roma_unsb \
|
||||
--lambda_GAN 1.0 \
|
||||
--lambda_NCE 8.0 \
|
||||
--lambda_GAN 2.0 \
|
||||
--lambda_SB 1.0 \
|
||||
--lambda_ctn 1.0 \
|
||||
--lambda_inc 1.0 \
|
||||
--lr 0.00001 \
|
||||
--gpu_id 0 \
|
||||
--gpu_id 1 \
|
||||
--nce_idt False \
|
||||
--netF mlp_sample \
|
||||
--flip_equivariance True \
|
||||
--eta_ratio 0.4 \
|
||||
--tau 0.01 \
|
||||
--num_timesteps 4 \
|
||||
--input_nc 3
|
||||
--num_timesteps 5 \
|
||||
--input_nc 3 \
|
||||
--n_epochs 400 \
|
||||
--n_epochs_decay 200 \
|
||||
|
||||
# exp1 num_timesteps=4
|
||||
# exp2 num_timesteps=5
|
||||
Loading…
x
Reference in New Issue
Block a user