diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc
index 5f5b809..728de33 100644
Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ
diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index 6e4859d..d0aee8d 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -216,11 +216,11 @@ class RomaUnsbModel(BaseModel):
         """Configure options specific to the CTNx model"""
         parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
-        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
         parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
         parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
         parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
         parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
+        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
         parser.add_argument('--nce_includes_all_negatives_from_minibatch',
@@ -234,7 +234,6 @@ class RomaUnsbModel(BaseModel):
                             type=util.str2bool, nargs='?', const=True, default=False,
                             help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
-        parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
         parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
@@ -244,12 +243,7 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--n_mlp', type=int, default=3,
                             help='only used if netD==n_layers')
 
-        parser.set_defaults(pool_size=0)  # no image pooling
-        opt, _ = parser.parse_known_args()
-
-        # Directly set SB mode
-        parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
 
         return parser
@@ -258,7 +252,7 @@ class RomaUnsbModel(BaseModel):
         BaseModel.__init__(self, opt)
 
         # Specify the training losses to print
-        self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB']
+        self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']
         self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0']
         self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
@@ -542,7 +536,7 @@ class RomaUnsbModel(BaseModel):
         warped_fake_B0_2=self.warped_fake_B0_2
         warped_fake_B0=self.warped_fake_B0
         # Compute the L2 loss
-        self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
+        self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
 
         if self.opt.lambda_GAN > 0.0:
             pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
@@ -563,15 +557,15 @@
             self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
 
         if self.opt.lambda_global > 0.0:
-            loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
-            loss_global *= 0.5
+            self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
+            self.loss_global *= 0.5
         else:
-            loss_global = 0.0
+            self.loss_global = 0.0
 
         self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
                       self.opt.lambda_SB * self.loss_SB + \
-                      self.opt.lambda_ctn * self.ctn_loss + \
-                      loss_global * self.opt.lambda_global
+                      self.opt.lambda_ctn * self.loss_ctn + \
+                      self.loss_global * self.opt.lambda_global
         return self.loss_G
 
     def calculate_attention_loss(self):
diff --git a/options/__pycache__/train_options.cpython-39.pyc b/options/__pycache__/train_options.cpython-39.pyc
index 495487e..76d2a2f 100644
Binary files a/options/__pycache__/train_options.cpython-39.pyc and b/options/__pycache__/train_options.cpython-39.pyc differ
diff --git a/options/train_options.py b/options/train_options.py
index 5df79aa..c7ba288 100644
--- a/options/train_options.py
+++ b/options/train_options.py
@@ -31,7 +31,7 @@ class TrainOptions(BaseOptions):
         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
         parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
         parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
-
+
         # training parameters
         parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
         parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
diff --git a/scripts/train.sh b/scripts/train.sh
index 0c016bf..6b117d9 100755
--- a/scripts/train.sh
+++ b/scripts/train.sh
@@ -7,22 +7,26 @@
 python train.py \
     --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
-    --name ROMA_UNSB_003 \
+    --name UNIV_2 \
     --dataset_mode unaligned_double \
     --no_flip \
-    --display_env ROMA \
+    --display_env UNIV \
     --model roma_unsb \
-    --lambda_GAN 1.0 \
-    --lambda_NCE 8.0 \
+    --lambda_GAN 2.0 \
     --lambda_SB 1.0 \
     --lambda_ctn 1.0 \
     --lambda_inc 1.0 \
     --lr 0.00001 \
-    --gpu_id 0 \
+    --gpu_id 1 \
     --nce_idt False \
     --netF mlp_sample \
     --flip_equivariance True \
     --eta_ratio 0.4 \
     --tau 0.01 \
-    --num_timesteps 4 \
-    --input_nc 3
+    --num_timesteps 5 \
+    --input_nc 3 \
+    --n_epochs 400 \
+    --n_epochs_decay 200 \
+
+# exp1 num_timesteps=4
+# exp2 num_timesteps=5
\ No newline at end of file
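
Note on the ctn_loss -> loss_ctn rename and on promoting loss_global to an attribute: in pytorch-CycleGAN-and-pix2pix-style frameworks, which the BaseModel/TrainOptions layout here appears to follow, a loss listed in self.loss_names is only collected for logging if a matching attribute named loss_<name> exists on the model. Since this commit adds 'global' and 'ctn' to loss_names, both values must live on self under exactly those names. Below is a minimal sketch of that assumed convention; BaseModelSketch is hypothetical and stands in for the framework's actual BaseModel:

from collections import OrderedDict

class BaseModelSketch:
    """Hypothetical stand-in for the framework's BaseModel, assuming the
    pytorch-CycleGAN-and-pix2pix convention: each entry in self.loss_names
    must have a matching attribute named 'loss_<name>'."""

    def __init__(self):
        # Mirrors the list this commit extends with 'global' and 'ctn'.
        self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn']

    def get_current_losses(self):
        # The framework looks up 'loss_' + name, so 'ctn' is only reported
        # once the attribute is called self.loss_ctn (hence the rename above).
        return OrderedDict((name, float(getattr(self, 'loss_' + name)))
                           for name in self.loss_names)

# Usage: assign placeholder values the way compute_G_loss would, then collect.
model = BaseModelSketch()
for name in model.loss_names:
    setattr(model, 'loss_' + name, 0.0)
print(model.get_current_losses())

Under this convention, keeping the old names (self.ctn_loss, a local loss_global) would raise an AttributeError at logging time, since no loss_ctn or loss_global attribute would exist when the new loss_names entries are looked up.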