From 7a6e856b4bc350097341b1aa00ec701e98067e15 Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Wed, 26 Feb 2025 22:24:17 +0800 Subject: [PATCH] running UNIV --- .../roma_unsb_model.cpython-39.pyc | Bin 19028 -> 18903 bytes models/roma_unsb_model.py | 22 +++++++----------- .../__pycache__/train_options.cpython-39.pyc | Bin 3071 -> 3071 bytes options/train_options.py | 2 +- scripts/train.sh | 18 ++++++++------ 5 files changed, 20 insertions(+), 22 deletions(-) diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index 5f5b809df6c87ae575dc5db0bf146ffaaaabffdf..728de33f792daffe23a976b3dd95e459a933e128 100644 GIT binary patch delta 1743 zcmZ`(Yitx%6rMBtYInB_rA7KM+m`NjciXabD?%68(k(5t5%56+(r6~rouO{o7k8%e zNGpj6F;R)&DWFoLVEjQKBxLww{9!OeBQZgZ(l{C|K0|^AOoWI6-ZNXFF*?b7Irlq{ zd(XM&+|MUq3c)9rmAM7}Eoyvkcrtt1_j48W;q@wig;Q}U_I-8`#)J_;DbYa@F;x9d zs!<8d;)A+fX;MP7m@v1DD9y8&7Tuw=DvM?@QO+z@+WukQbRHF+5XZ%F+nBIXh{P~o zvmMK8&*CzF()g@a6>+Ekju>&_vAQUJQzt44@l z6vm-5Ih|u?$?5|l$~9G&n7Vy5%~VYqX0pOAVs}fse~PreR6bi6)m5oi%H%c8bP`Iv z>RJsl6j1YYQf!5M^j9sGt3@j$sj7Irx?}7F86770h2R`c$9tR~k$4=Nm%mhXn~Nq44BFGVl*XLE zSK^Sy-1zD8&UPwMahrP)C_Jm=cK>dO&0KNNNUqopVJi)9M$Qf?veYAON)NCq@>oZ{ z3rM1?WEap}*#uuAB>b=t>l2STzUB0<#*RcSbW&m2K>}+>%XKcf(%xoiCO@o=X4RdA zVr$GH&0exzN$?$h(%I&?#6>{Is)VnEOHTf;=D#*=JhEy8;5B@=%U?4lNWn3|E{p^f zS6mSJD(6&ySiIa7na_C|_15+HcF3%tDpK*5+8DJuvvMfMsboXF)wMj@-NUaUGx}zbhLibAI9+LIcQ<9e6trlx?l7ikAFxwqofA=O$j}^N(?Cb|_5-BNmZ-h!1 z<}g@VU^8c90V0D0wWM+M4oHu0!K;K3W+ODY$_T+u0d!Fxv(y|+c1_m}oyVQ``qn!rUkNw&4J8hcs-55~f*%R45?C+KpCr+! zF`5;2iQs#J2Wh`nyx!Lgjo8p1^jHPDMye0-(f$fJk81xe&+|mM2;Rrb{oDQqZ2;Ik delta 1859 zcmZ`(du&rx7{BMXT_1FXZMa7_*E0ILdvxuX?2)lh*3nIIMqLCJ&9&>j+wRu2<@7ed zvV}hk(GZ8?$76~l42>F)7*a9thatfb42Fb1h?onS5FhY|L8AmsMDTpKWSfa@xxaqr z_dU*czVn^$8j!2|i6dTIY~$G93D>#NS5l`OwNwayaU z{iu1Kdy${vXACTpq7_o9#7j@i@~C>A6V^hHb6d#X)Rc@*4oh)ypr?0g)pl7K9aG&S zY3gQccQT#LcA^xB1-1x{?kxi3o#*qPRm(i&sv}Tc`>=r$Rs?}))=WB`6tl{dEby8+ zm5@b6Qip7d(FF_CxPry@tgMQXJQB|()hzV8UM6edxGS)Nl=AFfM>}sfIPy1Kmv~}= z59DyU4!$neYiy)&G@i9;Vv{8jmsO0a}c4zGE=RuF^Kj_Eysqaw;oCsD2iM z8qd*_aIftdasW2AhskN!(e5I5VXoa%@+tGvjH)sz(>oaFAY5*DlcR8_y_S3e#T|{c zhtR@25}%O84Pp8*8qFd6f^Y#gcSMUWp?DnTJKm_|SW9VSaLA~n5?N}&6+fJ%cIfNu zc47ew+bl-F@U$F`c2j?Zp0&=j{)8gDrnn+Vv$6w6H8aK9jdz5V6zW?vc0;{86aORvxGsQ(fHCz57S b(=;&_Or_{oaJz5kve!{#M)(4f{S*HHZdmTm diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 6e4859d..d0aee8d 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -216,11 +216,11 @@ class RomaUnsbModel(BaseModel): """配置 CTNx 模型的特定选项""" parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') - parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)') parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss') parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm') parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator') parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') + parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization') parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') parser.add_argument('--nce_includes_all_negatives_from_minibatch', @@ -234,7 +234,6 @@ class RomaUnsbModel(BaseModel): type=util.str2bool, nargs='?', const=True, default=False, help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT") - parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization') parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions') parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers') @@ -244,12 +243,7 @@ class RomaUnsbModel(BaseModel): parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers') - parser.set_defaults(pool_size=0) # no image pooling - opt, _ = parser.parse_known_args() - - # 直接设置为 sb 模式 - parser.set_defaults(nce_idt=True, lambda_NCE=1.0) return parser @@ -258,7 +252,7 @@ class RomaUnsbModel(BaseModel): BaseModel.__init__(self, opt) # 指定需要打印的训练损失 - self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB'] + self.loss_names = ['G_GAN', 'D_real_ViT', 'D_fake_ViT', 'G', 'SB', 'global', 'ctn'] self.visual_names = ['real_A0', 'real_A_noisy', 'fake_B0', 'real_B0'] self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')] @@ -542,7 +536,7 @@ class RomaUnsbModel(BaseModel): warped_fake_B0_2=self.warped_fake_B0_2 warped_fake_B0=self.warped_fake_B0 # 计算L2损失 - self.ctn_loss = F.mse_loss(warped_fake_B0_2, warped_fake_B0) + self.loss_ctn = F.mse_loss(warped_fake_B0_2, warped_fake_B0) if self.opt.lambda_GAN > 0.0: pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0]) @@ -563,15 +557,15 @@ class RomaUnsbModel(BaseModel): self.loss_SB += torch.mean((self.real_A_noisy - self.fake_B0) ** 2) if self.opt.lambda_global > 0.0: - loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1) - loss_global *= 0.5 + self.loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1) + self.loss_global *= 0.5 else: - loss_global = 0.0 + self.loss_global = 0.0 self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \ self.opt.lambda_SB * self.loss_SB + \ - self.opt.lambda_ctn * self.ctn_loss + \ - loss_global * self.opt.lambda_global + self.opt.lambda_ctn * self.loss_ctn + \ + self.loss_global * self.opt.lambda_global return self.loss_G def calculate_attention_loss(self): diff --git a/options/__pycache__/train_options.cpython-39.pyc b/options/__pycache__/train_options.cpython-39.pyc index 495487e07ba7d02e4c922b3714d9da5a92d20cfe..76d2a2fc686afffb771fd4e7b0e6d28e428d3c89 100644 GIT binary patch delta 22 ccmew_{$HFok(ZZ?0SK6s_NPzV$oq*K07r`kS^xk5 delta 22 ccmew_{$HFok(ZZ?0SM%e?n>|5$oq*K08L#74FCWD diff --git a/options/train_options.py b/options/train_options.py index 5df79aa..c7ba288 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -31,7 +31,7 @@ class TrainOptions(BaseOptions): parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint') - + # training parameters parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate') parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero') diff --git a/scripts/train.sh b/scripts/train.sh index 0c016bf..6b117d9 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -7,22 +7,26 @@ python train.py \ --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \ - --name ROMA_UNSB_003 \ + --name UNIV_2 \ --dataset_mode unaligned_double \ --no_flip \ - --display_env ROMA \ + --display_env UNIV \ --model roma_unsb \ - --lambda_GAN 1.0 \ - --lambda_NCE 8.0 \ + --lambda_GAN 2.0 \ --lambda_SB 1.0 \ --lambda_ctn 1.0 \ --lambda_inc 1.0 \ --lr 0.00001 \ - --gpu_id 0 \ + --gpu_id 1 \ --nce_idt False \ --netF mlp_sample \ --flip_equivariance True \ --eta_ratio 0.4 \ --tau 0.01 \ - --num_timesteps 4 \ - --input_nc 3 + --num_timesteps 5 \ + --input_nc 3 \ + --n_epochs 400 \ + --n_epochs_decay 200 \ + +# exp1 num_timesteps=4 +# exp2 num_timesteps=5 \ No newline at end of file