From 133f609e79f2749fbd56b323fa25b009271cc291 Mon Sep 17 00:00:00 2001
From: bishe <123456789@163.com>
Date: Mon, 24 Feb 2025 22:49:38 +0800
Subject: [PATCH] Add image optical flow
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 checkpoints/ROMA_UNSB_001/loss_log.txt  | 10 +++
 checkpoints/ROMA_UNSB_001/train_opt.txt |  3 +-
 .../roma_unsb_model.cpython-39.pyc      | Bin 19104 -> 20190 bytes
 models/roma_unsb_model.py               | 76 ++++++++++++++----
 4 files changed, 74 insertions(+), 15 deletions(-)

diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt
index fd8dd2f..11f952c 100644
--- a/checkpoints/ROMA_UNSB_001/loss_log.txt
+++ b/checkpoints/ROMA_UNSB_001/loss_log.txt
@@ -68,3 +68,13 @@
 ================ Training Loss (Sun Feb 23 23:13:05 2025) ================
 ================ Training Loss (Sun Feb 23 23:13:59 2025) ================
 ================ Training Loss (Sun Feb 23 23:14:59 2025) ================
+================ Training Loss (Mon Feb 24 21:53:50 2025) ================
+================ Training Loss (Mon Feb 24 21:54:16 2025) ================
+================ Training Loss (Mon Feb 24 21:54:50 2025) ================
+================ Training Loss (Mon Feb 24 21:55:31 2025) ================
+================ Training Loss (Mon Feb 24 21:56:10 2025) ================
+================ Training Loss (Mon Feb 24 22:09:38 2025) ================
+================ Training Loss (Mon Feb 24 22:10:16 2025) ================
+================ Training Loss (Mon Feb 24 22:12:46 2025) ================
+================ Training Loss (Mon Feb 24 22:13:04 2025) ================
+================ Training Loss (Mon Feb 24 22:14:04 2025) ================
diff --git a/checkpoints/ROMA_UNSB_001/train_opt.txt b/checkpoints/ROMA_UNSB_001/train_opt.txt
index 4d2cd07..638f42f 100644
--- a/checkpoints/ROMA_UNSB_001/train_opt.txt
+++ b/checkpoints/ROMA_UNSB_001/train_opt.txt
@@ -1,5 +1,6 @@
 ----------------- Options ---------------
-              atten_layers: 5
+             adj_size_list: [2, 4, 6, 8, 12]
+              atten_layers: 1,3,5
                 batch_size: 1
                      beta1: 0.5
                      beta2: 0.999
diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc
index ac7f924b59dc6de3192c8da5e0c2f482b13a1437..090787dd4fc119728a0f623d6c65c9e8a6d4410a 100644
Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ
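For reference, a minimal sketch (not taken from the patch) of how the two new options recorded in train_opt.txt above are typically consumed: `atten_layers` as a comma-separated string of ViT layer indices, and `adj_size_list` as the neighbourhood sizes used for multi-scale token pooling (see `cat_results`/`tokens_concat` in the model diff below). The parsing and variable names here are assumptions for illustration only.

```python
# Hypothetical option handling; the real parsing lives in the options/model code.
atten_layers = "1,3,5"              # value as written to train_opt.txt
adj_size_list = [2, 4, 6, 8, 12]    # perception-field scales

layer_ids = [int(i) for i in atten_layers.split(',')]  # -> [1, 3, 5]
# ViT tokens are extracted at these layers, then pooled over each
# neighbourhood size in adj_size_list (cf. cat_results in the diff below).
print(layer_ids, adj_size_list)
```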
zDxEM6h3SLrjgA^J%l_KYu&R*sj`=<1=ZQ7+NmPCt;m-&!vbN66_OnPHXNNndYSmyN z!LI--Px>OZL@I93kFX~W81&x)dwOHz=abNnNPEn1MBZaY*MuXsiD)7fkrBg`@13vm zZICNWWXC`_5cu7;aKZ8F{SfetY1ZBqbySgZ_^V>ayHm!oulbNV>}7BSfQ4|T_F_0%B#w6*vp$* zrVXtphCl^8+ag2teqiMrfm=`Vp`a4`tp+rWU8LDgiElNHM7r?SwtOElA|tkGCBsWlPW=Iv~*vf_?Ca zdXJgg>y{q}7FBT2iNL+U69bPdtwvrAJF`7B`9AjH830d}j82o|X3R{aQ-w1sCEvnbuf@M1 zwHE=;6U8*A7&KasYpY?m@2vbn3i_CTMiqcUMB-^u+7-W2tN4^k#ifiXLB*}qDa$EJ zs#Wl)8c${pQ688%V9_%OxQo1ErfGbrIhCR>vv+qMa!=w29w+(&8|z)-*pFmNV3Rxj S?AhKT5@Ij)hNOL^_5T4V{QPeK diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 3563ddf..70f8b15 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -78,7 +78,7 @@ class ContentAwareOptimization(nn.Module): # 计算余弦相似度 cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N] return cosine_sim - + def generate_weight_map(self, gradients_fake, feature_shape): """ 生成内容感知权重图(修正空间维度) @@ -100,16 +100,66 @@ class ContentAwareOptimization(nn.Module): k = int(self.eta_ratio * cosine_fake.shape[1]) _, fake_indices = torch.topk(-cosine_fake, k, dim=1) weight_fake = torch.ones_like(cosine_fake) - + for b in range(cosine_fake.shape[0]): weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]])) - + # 重建空间维度 -------------------------------------------------- # 将权重从[B, N]转换为[B, H, W] + #print(f"Shape of weight_fake before view: {weight_fake.shape}") + #print(f"Shape of cosine_fake: {cosine_fake.shape}") + #print(f"H: {H}, W: {W}, N: {N}") weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W] return weight_fake + def compute_cosine_similarity_image(self, gradients): + """ + 计算每个空间位置梯度与平均梯度的余弦相似度 (图像版本) + Args: + gradients: [B, C, H, W] 判别器输出的梯度 + Returns: + cosine_sim: [B, H, W] 每个空间位置的余弦相似度 + """ + # 将空间维度展平,以便计算所有空间位置的平均梯度 + B, C, H, W = gradients.shape + gradients_reshaped = gradients.view(B, C, H * W) # [B, C, N] where N = H*W + gradients_transposed = gradients_reshaped.transpose(1, 2) # [B, N, C] 将C放到最后一维,方便计算空间位置的平均梯度 + + mean_grad = torch.mean(gradients_transposed, dim=1, keepdim=True) # [B, 1, C] 在空间位置维度上求平均,得到平均梯度 [B, 1, C] + # mean_grad 现在是所有空间位置的平均梯度,形状为 [B, 1, C] + + # 为了计算余弦相似度,我们需要将 mean_grad 扩展到与 gradients_transposed 相同的空间维度 + mean_grad_expanded = mean_grad.expand(-1, H * W, -1) # [B, N, C] + + # 计算余弦相似度,dim=2 表示在特征维度 (C) 上计算 + cosine_sim = F.cosine_similarity(gradients_transposed, mean_grad_expanded, dim=2) # [B, N] + + # 将 cosine_sim 重新reshape回 [B, H, W] + cosine_sim = cosine_sim.view(B, H, W) + return cosine_sim + + def generate_weight_map_image(self, gradients_fake, feature_shape): + """ + 生成内容感知权重图(修正空间维度 - 图像版本) + Args: + gradients_fake: [B, C, H, W] 生成图像判别器梯度 + feature_shape: tuple [H, W] 判别器输出的特征图尺寸 + Returns: + weight_fake: [B, 1, H, W] 生成图像权重图 + """ + H, W = feature_shape + # 计算余弦相似度(图像版本) + cosine_fake = self.compute_cosine_similarity_image(gradients_fake) # [B, H, W] + # 生成权重图(与原代码相同,但现在cosine_fake是[B, H, W]) + k = int(self.eta_ratio * H * W) # k 仍然是基于总的空间位置数量计算 + _, fake_indices = torch.topk(-cosine_fake.view(cosine_fake.shape[0], -1), k, dim=1) # 将 cosine_fake 展平为 [B, N] 以使用 topk + weight_fake = torch.ones_like(cosine_fake).view(cosine_fake.shape[0], -1) # 初始化权重图,并展平为 [B, N] + for b in range(cosine_fake.shape[0]): + weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake.view(cosine_fake.shape[0], -1)[b, fake_indices[b]])) + weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # 重新 reshape 为 [B, H, W],并添加通道维度变为 [B, 1, H, W] + return weight_fake + def 
         """
         Compute the content-aware adversarial loss
@@ -123,7 +173,7 @@ class ContentAwareOptimization(nn.Module):
         """
         B, C, H, W = D_real.shape
         N = H * W
-        shape_hw = [h, w]
+        shape_hw = [H, W]
         # Register hooks to capture the gradients
         gradients_real = []
         gradients_fake = []
@@ -146,8 +196,8 @@ class ContentAwareOptimization(nn.Module):
         total_loss.backward(retain_graph=True)
 
         # Fetch the gradient data
-        gradients_real = gradients_real[0]  # [B, N, D]
-        gradients_fake = gradients_fake[0]  # [B, N, D]
+        gradients_real = gradients_real[1]  # [B, N, D]
+        gradients_fake = gradients_fake[1]  # [B, N, D]
 
         # Build the weight maps
         self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
@@ -235,7 +285,7 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
         parser.add_argument('--num_timesteps', type=int, default=5,
                             help='# of discrim filters in the first conv layer')
-        
+        parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
         parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
 
         parser.set_defaults(pool_size=0)  # no image pooling
@@ -373,7 +423,6 @@ class RomaUnsbModel(BaseModel):
         self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
         self.image_paths = input['A_paths' if AtoB else 'B_paths']
 
-
     def tokens_concat(self, origin_tokens, adjacent_size):
         adj_size = adjacent_size
         B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
@@ -394,10 +443,9 @@ class RomaUnsbModel(BaseModel):
                 cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
                 cut_patch_list.append(cut_patch)
-
         result = torch.cat(cut_patch_list,dim=1)
         return result
-    
+
     def cat_results(self, origin_tokens, adj_size_list):
         res_list = [origin_tokens]
         for ad_s in adj_size_list:
@@ -405,7 +453,6 @@ class RomaUnsbModel(BaseModel):
             cat_result = self.tokens_concat(origin_tokens, ad_s)
             res_list.append(cat_result)
         result = torch.cat(res_list, dim=1)
-
         return result
 
     def forward(self):
@@ -468,10 +515,8 @@ class RomaUnsbModel(BaseModel):
             self.real = torch.flip(self.real, [3])
             self.realt = torch.flip(self.realt, [3])
 
-        print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
         self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
         self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
-        print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
 
         if self.opt.phase == 'train':
             real_A0 = self.real_A0
@@ -496,12 +541,15 @@ class RomaUnsbModel(BaseModel):
             self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
             # [[1,576,768],[1,576,768],[1,576,768]]
             # [3,576,768]
+            #self.mutil_real_A0_tokens = self.cat_results(self.mutil_real_A0_tokens[0], self.opt.adj_size_list)
+            #print(f'self.mutil_real_A0_tokens[0]:{self.mutil_real_A0_tokens[0].shape}')
+
             shape_hw = list(self.real_A0_resize.shape[2:4])
             # Gradient with respect to the generated-image tokens
             fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
             # Gradient map
-            self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw)
+            self.weight_fake = self.cao.generate_weight_map_image(fake_gradient, shape_hw)
             # CTN optical-flow map for the generated image
             self.f_content = self.ctn(self.weight_fake)
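For reference, a minimal, self-contained sketch (not taken from the patch) of the weight-map computation that `generate_weight_map_image` introduces: spatial positions whose gradient direction disagrees most with the mean gradient get an inflated weight, and the result is reshaped to a one-channel map that can be fed to the CTN. The `eta_ratio` and `lambda_inc` defaults below are placeholder assumptions standing in for the `ContentAwareOptimization` hyperparameters.

```python
import torch
import torch.nn.functional as F

def weight_map_from_gradients(gradients, eta_ratio=0.3, lambda_inc=2.0):
    """Sketch of the content-aware weight map over an image-shaped gradient.

    eta_ratio / lambda_inc are illustrative values, not the trained settings.
    """
    B, C, H, W = gradients.shape
    g = gradients.view(B, C, H * W).transpose(1, 2)            # [B, N, C]
    mean_g = g.mean(dim=1, keepdim=True)                       # [B, 1, C]
    cos = F.cosine_similarity(g, mean_g.expand_as(g), dim=2)   # [B, N]
    k = int(eta_ratio * H * W)                                 # number of least-aligned positions to boost
    _, idx = torch.topk(-cos, k, dim=1)
    weight = torch.ones_like(cos)
    for b in range(B):
        weight[b, idx[b]] = lambda_inc / (1e-6 + cos[b, idx[b]].abs())
    return weight.view(B, 1, H, W)                             # one-channel map, same H x W

# Shape check with dummy gradients (e.g. a 24x24 feature map with 768 channels).
dummy = torch.randn(2, 768, 24, 24)
print(weight_map_from_gradients(dummy).shape)                  # torch.Size([2, 1, 24, 24])
```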