running UNIV
This commit is contained in:
parent
e8e483fbf8
commit
7a6e856b4b
Binary file not shown.
@ -216,11 +216,11 @@ class RomaUnsbModel(BaseModel):
|
|||||||
"""配置 CTNx 模型的特定选项"""
|
"""配置 CTNx 模型的特定选项"""
|
||||||
|
|
||||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||||
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
|
||||||
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
||||||
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
||||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||||
|
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||||
|
|
||||||
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
||||||
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
||||||
@ -234,7 +234,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
type=util.str2bool, nargs='?', const=True, default=False,
|
type=util.str2bool, nargs='?', const=True, default=False,
|
||||||
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
||||||
|
|
||||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
|
||||||
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
||||||
|
|
||||||
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
||||||
@ -244,13 +243,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
||||||
|
|
||||||
parser.set_defaults(pool_size=0) # no image pooling
|
|
||||||
|
|
||||||
opt, _ = parser.parse_known_args()
|
opt, _ = parser.parse_known_args()
|
||||||
|
|
||||||
# 直接设置为 sb 模式
|
|
||||||
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
def __init__(self, opt):
|
def __init__(self, opt):
|
||||||
@ -258,7 +252,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
BaseModel.__init__(self, opt)
|
BaseModel.__init__(self, opt)
|
||||||
|
|
||||||
# 指定需要打印的训练损失
|
# 指定需要打印的训练损失
|
||||||
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB']
|
self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
|
||||||
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
|
self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
|
||||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||||
|
|
||||||
@ -542,7 +536,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
warped_fake_B0_2=self.warped_fake_B0_2
|
warped_fake_B0_2=self.warped_fake_B0_2
|
||||||
warped_fake_B0=self.warped_fake_B0
|
warped_fake_B0=self.warped_fake_B0
|
||||||
# 计算L2损失
|
# 计算L2损失
|
||||||
self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
|
self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
|
||||||
|
|
||||||
if self.opt.lambda_GAN > 0.0:
|
if self.opt.lambda_GAN > 0.0:
|
||||||
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
||||||
@ -563,15 +557,15 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
|
self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
|
||||||
|
|
||||||
if self.opt.lambda_global > 0.0:
|
if self.opt.lambda_global > 0.0:
|
||||||
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||||
loss_global *= 0.5
|
self.loss_global *= 0.5
|
||||||
else:
|
else:
|
||||||
loss_global = 0.0
|
self.loss_global = 0.0
|
||||||
|
|
||||||
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
|
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
|
||||||
self.opt.lambda_SB * self.loss_SB + \
|
self.opt.lambda_SB * self.loss_SB + \
|
||||||
self.opt.lambda_ctn * self.ctn_loss + \
|
self.opt.lambda_ctn * self.loss_ctn + \
|
||||||
loss_global * self.opt.lambda_global
|
self.loss_global * self.opt.lambda_global
|
||||||
return self.loss_G
|
return self.loss_G
|
||||||
|
|
||||||
def calculate_attention_loss(self):
|
def calculate_attention_loss(self):
|
||||||
|
|||||||
Binary file not shown.
@ -7,22 +7,26 @@
|
|||||||
|
|
||||||
python train.py \
|
python train.py \
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||||
--name ROMA_UNSB_003 \
|
--name UNIV_2 \
|
||||||
--dataset_mode unaligned_double \
|
--dataset_mode unaligned_double \
|
||||||
--no_flip \
|
--no_flip \
|
||||||
--display_env ROMA \
|
--display_env UNIV \
|
||||||
--model roma_unsb \
|
--model roma_unsb \
|
||||||
--lambda_GAN 1.0 \
|
--lambda_GAN 2.0 \
|
||||||
--lambda_NCE 8.0 \
|
|
||||||
--lambda_SB 1.0 \
|
--lambda_SB 1.0 \
|
||||||
--lambda_ctn 1.0 \
|
--lambda_ctn 1.0 \
|
||||||
--lambda_inc 1.0 \
|
--lambda_inc 1.0 \
|
||||||
--lr 0.00001 \
|
--lr 0.00001 \
|
||||||
--gpu_id 0 \
|
--gpu_id 1 \
|
||||||
--nce_idt False \
|
--nce_idt False \
|
||||||
--netF mlp_sample \
|
--netF mlp_sample \
|
||||||
--flip_equivariance True \
|
--flip_equivariance True \
|
||||||
--eta_ratio 0.4 \
|
--eta_ratio 0.4 \
|
||||||
--tau 0.01 \
|
--tau 0.01 \
|
||||||
--num_timesteps 4 \
|
--num_timesteps 5 \
|
||||||
--input_nc 3
|
--input_nc 3 \
|
||||||
|
--n_epochs 400 \
|
||||||
|
--n_epochs_decay 200 \
|
||||||
|
|
||||||
|
# exp1 num_timesteps=4
|
||||||
|
# exp2 num_timesteps=5
|
||||||
Loading…
x
Reference in New Issue
Block a user