From 4af0d7463def305782529704d0d4ac3706f1a19e Mon Sep 17 00:00:00 2001 From: bishe <123456789@163.com> Date: Mon, 24 Feb 2025 23:00:25 +0800 Subject: [PATCH] withoutCNT --- checkpoints/ROMA_UNSB_001/loss_log.txt | 1 + checkpoints/ROMA_UNSB_001/train_opt.txt | 12 ++++-------- .../roma_unsb_model.cpython-39.pyc | Bin 19104 -> 18789 bytes models/roma_unsb_model.py | 18 +++--------------- scripts/train.sh | 13 ++++--------- 5 files changed, 12 insertions(+), 32 deletions(-) diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index fd8dd2f..c7fad7f 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -68,3 +68,4 @@ ================ Training Loss (Sun Feb 23 23:13:05 2025) ================ ================ Training Loss (Sun Feb 23 23:13:59 2025) ================ ================ Training Loss (Sun Feb 23 23:14:59 2025) ================ +================ Training Loss (Mon Feb 24 22:59:41 2025) ================ diff --git a/checkpoints/ROMA_UNSB_001/train_opt.txt b/checkpoints/ROMA_UNSB_001/train_opt.txt index 4d2cd07..8f10014 100644 --- a/checkpoints/ROMA_UNSB_001/train_opt.txt +++ b/checkpoints/ROMA_UNSB_001/train_opt.txt @@ -19,9 +19,9 @@ easy_label: experiment_name epoch: latest epoch_count: 1 - eta_ratio: 0.1 + eta_ratio: 0.4 evaluation_freq: 5000 - flip_equivariance: False + flip_equivariance: True [default: False] gan_mode: lsgan gpu_ids: 0 init_gain: 0.02 @@ -31,11 +31,10 @@ lambda_D_ViT: 1.0 lambda_GAN: 8.0 [default: 1.0] lambda_NCE: 8.0 [default: 1.0] - lambda_SB: 0.1 + lambda_SB: 1.0 [default: 0.1] lambda_ctn: 1.0 lambda_global: 1.0 lambda_inc: 1.0 - lmda_1: 0.1 load_size: 286 lr: 1e-05 [default: 0.0002] lr_decay_iters: 50 @@ -47,14 +46,12 @@ n_layers_D: 3 n_mlp: 3 name: ROMA_UNSB_001 [default: experiment_name] - nce_T: 0.07 nce_idt: False [default: True] nce_includes_all_negatives_from_minibatch: False nce_layers: 0,4,8,12,16 ndf: 64 netD: basic_cond netF: mlp_sample - netF_nc: 256 netG: resnet_9blocks_cond ngf: 64 no_antialias: False @@ -64,9 +61,8 @@ nce_includes_all_negatives_from_minibatch: False no_html: False normD: instance normG: instance - num_patches: 256 num_threads: 4 - num_timesteps: 10 [default: 5] + num_timesteps: 4 [default: 5] output_nc: 3 phase: train pool_size: 0 diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index ac7f924b59dc6de3192c8da5e0c2f482b13a1437..15c98680c8a2991c3c8302bf82e01a75e6f06846 100644 GIT binary patch delta 2053 zcmZ`(Yiv_x818p^S?O-wt?e$|rE533jcpyn^2hLFDM3^)At3>bBt}J(5^oVyqH_6Rf(D82+p?`Hw0WO==lP!Z z^1bhO&itF~+8oRBXJ#5y_Cy6u#jF?1Oo6@T{S`u3E#*kLQl4a%3TAaoHKhux zB#WYv3Z)`RGpm7SOU2BJsv<>YQ6n2D_jKx9!6Lw7*_JE~YO>J0Xi*7_UGN^=aT4unvP|S?V%bE zrdfy9f#w;d$eqf1jMRm;VGkX<^f$e5LYL91!0};8=z# zTD^{Nmlm_oVXw}*P7`$s*LO3d0H*DOtOqXI9c(Y$v^Se~lX4S6+#m;nV#L2Q$fJA; z+8sp}Ezx=@I1-YBVqaVzk|!n=Q4YkOceF7$Q)Oju!Ii2&i-tE-l)VT(_2%9l->MP+ zj-V*}$AbSY&V`!a(W0F&U2V%%>(%XQJ!=+>Lay4P&S$3Bht+>FW{aJ0x>@cqRK~Rm z_lIO2!!*hLtqEn(+jq_&s2W5;V-1w7K#HBYmz!BcN#TVv1E4+`ux zIM%q8jn26iOMasWbX0C6_!+^OmMr}jxa4YJzrx?HGBy*lHeJzHKCHC|sE@PgyURJ( zQ2>qZ8r>lZHwW9?wu(!*9CEU_VKk_S%J`0;9CdajIulnb{z-lWj=CS$eM1h1;EubF z9Ro*O8M_8OZFa}INXIow2;258DPE!HOdj+{#0{c69*XX4;s;@WTPd4|<85B!J0$v! z;Cq6TbL|^-AJOVC47MMz+`?s?4f!-7IS}P{=t&rk^1E=meN%CwghZB=09rbljzNye z#P-GHb>~?74fGmvW+qRN5%G@=OMY>q_*7_!r<3|oN|!;aRRjitQ^cKxh^LbkzGyqYALnRgv-0=;T=NrHSO^8ST7Cer*I_z)$NN zwV&fMFIL=DB2-eLPHXZG8t&KGw_$U#m_(^(!Pi~Co~E8?Y%>2!Lc_3M32Yamp64^xUX+SC&ys!sHl-Yi~1oE4^f%4crVvl^e{=@UUdHmUfP z(74P!NXOM<__8E!!AVjjYAq+%#IWy=7<-kSxcPCty!6gLC~+w|`ZDN*VsrwN_#@oL zRM{$@A$hmOfUNK`sOx>3bsfFf%UBV7(pRRdpnQknW}lUP44M62=725zk7=z)OK`M* z*z+=#M&APdDZy6+KN2L~QTjmetps#)I8B*T`*;<_wL`(mWE; delta 2493 zcmZ`)du&r>6u;l?V_o+iqwCf^`q---V-Ln&u))~a1}KKOs4SS9uJ^8V^u>4YkToOI z5FQfB>yv~;gjExXF)@;&fIv)$Bxs_EijR~;z=%2p6#p>61miikqZ9*M&M)70zVkch zeCM3o@8%&gbC~41b8?Iv{CSt0^LEjbxm$|K8MdQ%pG_+ki)K+9)e>%s8|TClg&)_5 z1!AE%Pb?MZNA;}LY@61JHqjo{%-J{;B36mjX{;L17}6|i6=c+}AWSiKYrs)7pr`VO!=*1=C(kQ{eGZ`^% z1HX;mrh!ZA;j+Bg9p&-lS|*I1cH82c z&TME1O3Hu`^vrUEoo{f?mvA||jqSnxubxvgk4_3c>QlgQ0?D?|u zadY`8Qr80!Ic*_H{qts6>xq%@|LV17l1~bG5W;R%tjpDbZ7xV3i2Np7Q+Z=r%!3MO z#|vKP-r%=GZIr^9c~fM(0Hg?HK8OjV7{r{-0hWN2f>=PxfLlIAV%bxiV1jk-0iGTM z4;64$B<2^PeII93>`^#VJb|L}|s+(S9@x7v=tO| zC%}ki?ZC>B%rW;SNeH?F(nw}125YV<(%%Ied4R2{v6pHMd?DYZF_2b1pD)yyHATdv zj(b^ry5=4s_3_EN4&piijD$|1ZeNg|!Z2r1Q6Ni-K&8zSJ}SwAm%2szDZA=yA)m2a zR~fm?>Rkh*H%?t`JY4&o1us;7kK9MAn8-vg>TO5k)`rzQ*~>OHt|wu3sqtjlZ6siw zXgRR7OomdA~G$rFA>-W~t3>8h^cN$oq5Y~R(ke*Ize#8`E^ z-S!16hn*y0^g?P%Xp?@fP};BLi+-8g8kNBNq%O=j;j8* zL7mVlzJNrpp~Y_8=TOP*7kY(Y$R|&>(1V~G(VUC%-3%TsB>@{_iCqY$XDAv*JMaLPyJ33hTwyrm&Zk~#F*LmcnWh7 zsRLEUaBe%39+#K1C>>F42L&LXd=FFLNgi&&=O3h6rl^@WA@cR zAE{xc!KZb2u>jjRINH4nkD3uaNB9omCj|9Fzk{pw2sly{+fT8yv=-x4u#6#>roqs~ rT8A8!YAtT!;YSEKXbFQ7qMnTlgF$+MZ5f)(n#Pk1gb&z1L!15saQc+A diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 8dbc273..497ec9c 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -166,7 +166,7 @@ class ContentAwareTemporalNorm(nn.Module): Returns: F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) """ - print(weight_map.shape) + #print(weight_map.shape) B, _, H, W = weight_map.shape # 1. 归一化权重图 @@ -204,23 +204,19 @@ class RomaUnsbModel(BaseModel): parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency') parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))') - parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers') parser.add_argument('--nce_includes_all_negatives_from_minibatch', type=util.str2bool, nargs='?', const=True, default=False, help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.') + parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers') parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map') - parser.add_argument('--netF_nc', type=int, default=256) - parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss') - parser.add_argument('--lmda_1', type=float, default=0.1) - parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer') parser.add_argument('--flip_equivariance', type=util.str2bool, nargs='?', const=True, default=False, help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT") parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization') - parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions') + parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions') parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers') @@ -261,12 +257,10 @@ class RomaUnsbModel(BaseModel): if self.isTrain: self.model_names = ['G', 'D_ViT', 'E'] - else: self.model_names = ['G'] - print(f'input_nc = {self.opt.input_nc}') # 创建网络 self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) @@ -284,9 +278,6 @@ class RomaUnsbModel(BaseModel): # 定义损失函数 self.criterionL1 = torch.nn.L1Loss().to(self.device) self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) - self.criterionNCE = [] - for nce_layer in self.nce_layers: - self.criterionNCE.append(PatchNCELoss(opt).to(self.device)) self.criterionIdt = torch.nn.L1Loss().to(self.device) self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2)) @@ -459,10 +450,8 @@ class RomaUnsbModel(BaseModel): self.real = torch.flip(self.real, [3]) self.realt = torch.flip(self.realt, [3]) - print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}') self.fake_B0 = self.netG(self.real_A0, self.time, z_in) self.fake_B1 = self.netG(self.real_A1, self.time, z_in2) - print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}') if self.opt.phase == 'train': real_A0 = self.real_A0 @@ -540,7 +529,6 @@ class RomaUnsbModel(BaseModel): def compute_E_loss(self): """计算判别器 E 的损失""" - print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}') XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1) XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1) temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean() diff --git a/scripts/train.sh b/scripts/train.sh index dea6a51..77b67e2 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -14,20 +14,15 @@ python train.py \ --model roma_unsb \ --lambda_GAN 8.0 \ --lambda_NCE 8.0 \ - --lambda_SB 0.1 \ + --lambda_SB 1.0 \ --lambda_ctn 1.0 \ --lambda_inc 1.0 \ --lr 0.00001 \ --gpu_id 0 \ --nce_idt False \ - --nce_layers 0,4,8,12,16 \ --netF mlp_sample \ - --netF_nc 256 \ - --nce_T 0.07 \ - --lmda_1 0.1 \ - --num_patches 256 \ - --flip_equivariance False \ - --eta_ratio 0.1 \ + --flip_equivariance True \ + --eta_ratio 0.4 \ --tau 0.01 \ - --num_timesteps 10 \ + --num_timesteps 4 \ --input_nc 3