running UNIV

bishe 2025-02-26 22:24:17 +08:00
parent e8e483fbf8
commit 7a6e856b4b
5 changed files with 20 additions and 22 deletions

View File

@@ -216,11 +216,11 @@ class RomaUnsbModel(BaseModel):
         """Configure options specific to the CTNx model"""
         parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
-        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
         parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
         parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
         parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
         parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
+        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
         parser.add_argument('--nce_includes_all_negatives_from_minibatch',
@@ -234,7 +234,6 @@ class RomaUnsbModel(BaseModel):
                             type=util.str2bool, nargs='?', const=True, default=False,
                             help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
-        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
         parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
@@ -244,13 +243,8 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
-        parser.set_defaults(pool_size=0) # no image pooling
         opt, _ = parser.parse_known_args()
-        # set SB mode directly
-        parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
         return parser

     def __init__(self, opt):
@@ -258,7 +252,7 @@ class RomaUnsbModel(BaseModel):
         BaseModel.__init__(self, opt)
         # specify the training losses to print
-        self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB']
+        self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
         self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
         self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
@@ -542,7 +536,7 @@ class RomaUnsbModel(BaseModel):
         warped_fake_B0_2=self.warped_fake_B0_2
         warped_fake_B0=self.warped_fake_B0
         # compute the L2 loss
-        self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
+        self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0)

         if self.opt.lambda_GAN > 0.0:
             pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
@@ -563,15 +557,15 @@ class RomaUnsbModel(BaseModel):
             self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)

         if self.opt.lambda_global > 0.0:
-            loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
-            loss_global *= 0.5
+            self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
+            self.loss_global *= 0.5
         else:
-            loss_global = 0.0
+            self.loss_global = 0.0

         self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
                       self.opt.lambda_SB * self.loss_SB + \
-                      self.opt.lambda_ctn * self.ctn_loss + \
-                      loss_global * self.opt.lambda_global
+                      self.opt.lambda_ctn * self.loss_ctn + \
+                      self.loss_global * self.opt.lambda_global
         return self.loss_G

     def calculate_attention_loss(self):
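
Note on the renames in this file: ctn_loss becomes loss_ctn, and loss_global is promoted from a local variable to an attribute, which is what makes the new 'ctn' and 'global' entries added to self.loss_names work. In CycleGAN/CUT-style codebases, BaseModel collects losses for logging by attribute lookup, so every name X in loss_names needs a matching self.loss_X. A minimal sketch of that lookup, assuming this repo follows the pytorch-CycleGAN-and-pix2pix convention (its BaseModel is not shown in this diff):

    def get_current_losses(self):
        # Return current losses as a name -> float dict for the logger/visualizer.
        errors_ret = {}
        for name in self.loss_names:
            # 'ctn' resolves to self.loss_ctn, 'global' to self.loss_global,
            # so both attributes must be set before logging.
            errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret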

View File

@@ -7,22 +7,26 @@
 python train.py \
     --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
-    --name ROMA_UNSB_003 \
+    --name UNIV_2 \
     --dataset_mode unaligned_double \
     --no_flip \
-    --display_env ROMA \
+    --display_env UNIV \
     --model roma_unsb \
-    --lambda_GAN 1.0 \
-    --lambda_NCE 8.0 \
+    --lambda_GAN 2.0 \
     --lambda_SB 1.0 \
     --lambda_ctn 1.0 \
     --lambda_inc 1.0 \
     --lr 0.00001 \
-    --gpu_id 0 \
+    --gpu_id 1 \
     --nce_idt False \
     --netF mlp_sample \
     --flip_equivariance True \
     --eta_ratio 0.4 \
     --tau 0.01 \
-    --num_timesteps 4 \
-    --input_nc 3
+    --num_timesteps 5 \
+    --input_nc 3 \
+    --n_epochs 400 \
+    --n_epochs_decay 200 \
+
+# exp1 num_timesteps=4
+# exp2 num_timesteps=5
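
For context on the two new flags: in CycleGAN/CUT-style training scripts, --n_epochs is the number of epochs run at the initial learning rate and --n_epochs_decay the number of additional epochs over which the rate decays linearly to zero, so this configuration trains for 600 epochs in total. A sketch of that schedule under the standard linear rule (an assumption; this repo's scheduler is not part of the diff):

    from torch.optim.lr_scheduler import LambdaLR

    def linear_decay_rule(n_epochs=400, n_epochs_decay=200, epoch_count=1):
        # Full rate for the first n_epochs, then linear decay to zero over
        # the following n_epochs_decay epochs.
        def lr_lambda(epoch):
            return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)
        return lr_lambda

    # usage (optimizer is hypothetical here):
    # scheduler = LambdaLR(optimizer, lr_lambda=linear_decay_rule())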