diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt
index fd8dd2f..11f952c 100644
--- a/checkpoints/ROMA_UNSB_001/loss_log.txt
+++ b/checkpoints/ROMA_UNSB_001/loss_log.txt
@@ -68,3 +68,13 @@
 ================ Training Loss (Sun Feb 23 23:13:05 2025) ================
 ================ Training Loss (Sun Feb 23 23:13:59 2025) ================
 ================ Training Loss (Sun Feb 23 23:14:59 2025) ================
+================ Training Loss (Mon Feb 24 21:53:50 2025) ================
+================ Training Loss (Mon Feb 24 21:54:16 2025) ================
+================ Training Loss (Mon Feb 24 21:54:50 2025) ================
+================ Training Loss (Mon Feb 24 21:55:31 2025) ================
+================ Training Loss (Mon Feb 24 21:56:10 2025) ================
+================ Training Loss (Mon Feb 24 22:09:38 2025) ================
+================ Training Loss (Mon Feb 24 22:10:16 2025) ================
+================ Training Loss (Mon Feb 24 22:12:46 2025) ================
+================ Training Loss (Mon Feb 24 22:13:04 2025) ================
+================ Training Loss (Mon Feb 24 22:14:04 2025) ================
diff --git a/checkpoints/ROMA_UNSB_001/train_opt.txt b/checkpoints/ROMA_UNSB_001/train_opt.txt
index 4d2cd07..638f42f 100644
--- a/checkpoints/ROMA_UNSB_001/train_opt.txt
+++ b/checkpoints/ROMA_UNSB_001/train_opt.txt
@@ -1,5 +1,6 @@
 ----------------- Options ---------------
-             atten_layers: 5
+            adj_size_list: [2, 4, 6, 8, 12]
+             atten_layers: 1,3,5
               batch_size: 1
                    beta1: 0.5
                    beta2: 0.999
diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc
index ac7f924..090787d 100644
Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ
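A note on the two option changes recorded above before the model diff below: argparse applies `type` only to values that arrive as command-line strings, so the `type=list` used for `--adj_size_list` in the model diff works for the hard-coded default but would split a CLI value such as "2,4,6" into single characters. A minimal CLI-safe sketch; the `int_csv` helper is illustrative, not repo code, and the repo may well parse `atten_layers` elsewhere:

import argparse

# Sketch: register the two options so both the defaults and CLI strings work.
# With type=list, argparse would turn the string "2,4,6" into
# ['2', ',', '4', ',', '6']; list defaults skip type conversion entirely.
def int_csv(value):
    """Parse a comma-separated string such as '2,4,6' into [2, 4, 6]."""
    return [int(v) for v in value.split(',')]

parser = argparse.ArgumentParser()
parser.add_argument('--adj_size_list', type=int_csv, default=[2, 4, 6, 8, 12],
                    help='different scales of perception field')
parser.add_argument('--atten_layers', type=int_csv, default=[1, 3, 5],
                    help='ViT layers to extract attention/tokens from')

opt = parser.parse_args(['--adj_size_list', '2,4,6'])
print(opt.adj_size_list)   # [2, 4, 6]
print(opt.atten_layers)    # [1, 3, 5]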
diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py
index 3563ddf..70f8b15 100644
--- a/models/roma_unsb_model.py
+++ b/models/roma_unsb_model.py
@@ -78,7 +78,7 @@ class ContentAwareOptimization(nn.Module):
         # Compute the cosine similarity
         cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]
         return cosine_sim
-
+
     def generate_weight_map(self, gradients_fake, feature_shape):
         """
         Generate the content-aware weight map (spatial dimensions corrected)
@@ -100,16 +100,66 @@ class ContentAwareOptimization(nn.Module):
         k = int(self.eta_ratio * cosine_fake.shape[1])
         _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
         weight_fake = torch.ones_like(cosine_fake)
-
+
         for b in range(cosine_fake.shape[0]):
             weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
-
+
         # Rebuild the spatial dimensions ---------------------------------
         # Convert the weights from [B, N] to [B, H, W]
+        #print(f"Shape of weight_fake before view: {weight_fake.shape}")
+        #print(f"Shape of cosine_fake: {cosine_fake.shape}")
+        #print(f"H: {H}, W: {W}, N: {N}")
         weight_fake = weight_fake.view(-1, H, W).unsqueeze(1)  # [B,1,H,W]
         return weight_fake
 
+    def compute_cosine_similarity_image(self, gradients):
+        """
+        Compute the cosine similarity between each spatial location's gradient
+        and the mean gradient (image version)
+        Args:
+            gradients: [B, C, H, W] gradients of the discriminator output
+        Returns:
+            cosine_sim: [B, H, W] cosine similarity at each spatial location
+        """
+        # Flatten the spatial dimensions so the mean gradient over all locations can be computed
+        B, C, H, W = gradients.shape
+        gradients_reshaped = gradients.view(B, C, H * W)  # [B, C, N] where N = H*W
+        gradients_transposed = gradients_reshaped.transpose(1, 2)  # [B, N, C]; C moved last to ease averaging over spatial locations
+
+        mean_grad = torch.mean(gradients_transposed, dim=1, keepdim=True)  # average over the spatial dimension
+        # mean_grad is now the mean gradient over all spatial locations, shape [B, 1, C]
+
+        # To compute the cosine similarity, expand mean_grad to the same spatial size as gradients_transposed
+        mean_grad_expanded = mean_grad.expand(-1, H * W, -1)  # [B, N, C]
+
+        # Compute the cosine similarity; dim=2 computes it over the feature dimension (C)
+        cosine_sim = F.cosine_similarity(gradients_transposed, mean_grad_expanded, dim=2)  # [B, N]
+
+        # Reshape cosine_sim back to [B, H, W]
+        cosine_sim = cosine_sim.view(B, H, W)
+        return cosine_sim
+
+    def generate_weight_map_image(self, gradients_fake, feature_shape):
+        """
+        Generate the content-aware weight map (spatial dimensions corrected; image version)
+        Args:
+            gradients_fake: [B, C, H, W] discriminator gradients for the generated image
+            feature_shape: tuple [H, W], spatial size of the discriminator output feature map
+        Returns:
+            weight_fake: [B, 1, H, W] weight map for the generated image
+        """
+        H, W = feature_shape
+        # Compute the cosine similarity (image version)
+        cosine_fake = self.compute_cosine_similarity_image(gradients_fake)  # [B, H, W]
+        # Generate the weight map (same as the original code, but cosine_fake is now [B, H, W])
+        k = int(self.eta_ratio * H * W)  # k is still derived from the total number of spatial locations
+        _, fake_indices = torch.topk(-cosine_fake.view(cosine_fake.shape[0], -1), k, dim=1)  # flatten cosine_fake to [B, N] for topk
+        weight_fake = torch.ones_like(cosine_fake).view(cosine_fake.shape[0], -1)  # initialize the weight map, flattened to [B, N]
+        for b in range(cosine_fake.shape[0]):
+            weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake.view(cosine_fake.shape[0], -1)[b, fake_indices[b]]))
+        weight_fake = weight_fake.view(-1, H, W).unsqueeze(1)  # reshape back to [B, H, W], then add a channel dim: [B, 1, H, W]
+        return weight_fake
+
     def forward(self, D_real, D_fake, real_scores, fake_scores):
         """
         Compute the content-aware adversarial loss
@@ -123,7 +173,7 @@ class ContentAwareOptimization(nn.Module):
         """
         B, C, H, W = D_real.shape
         N = H * W
-        shape_hw = [h, w]
+        shape_hw = [H, W]
         # Register hooks to capture the gradients
         gradients_real = []
         gradients_fake = []
@@ -146,8 +196,8 @@ class ContentAwareOptimization(nn.Module):
         total_loss.backward(retain_graph=True)
 
         # Fetch the gradient data
-        gradients_real = gradients_real[0]  # [B, N, D]
-        gradients_fake = gradients_fake[0]  # [B, N, D]
+        gradients_real = gradients_real[1]  # [B, N, D]
+        gradients_fake = gradients_fake[1]  # [B, N, D]
 
         # Generate the weight maps
         self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw)
@@ -235,7 +285,7 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
         parser.add_argument('--num_timesteps', type=int, default=5, help='# of diffusion timesteps')
-
+        parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
         parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
         parser.set_defaults(pool_size=0)  # no image pooling
@@ -373,7 +423,6 @@ class RomaUnsbModel(BaseModel):
         self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
         self.image_paths = input['A_paths' if AtoB else 'B_paths']
 
-
     def tokens_concat(self, origin_tokens, adjacent_size):
         adj_size = adjacent_size
         B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
@@ -394,10 +443,9 @@ class RomaUnsbModel(BaseModel):
             cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
             cut_patch_list.append(cut_patch)
-
         result = torch.cat(cut_patch_list,dim=1)
         return result
-
+
     def cat_results(self, origin_tokens, adj_size_list):
         res_list = [origin_tokens]
         for ad_s in adj_size_list:
@@ -405,7 +453,6 @@ class RomaUnsbModel(BaseModel):
             res_list.append(cat_result)
         result = torch.cat(res_list, dim=1)
-
         return result
 
     def forward(self):
@@ -468,10 +515,8 @@ class RomaUnsbModel(BaseModel):
             self.real = torch.flip(self.real, [3])
             self.realt = torch.flip(self.realt, [3])
 
-        print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
         self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
         self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
-        print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')
 
         if self.opt.phase == 'train':
             real_A0 = self.real_A0
@@ -496,12 +541,15 @@ class RomaUnsbModel(BaseModel):
             self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
             # [[1,576,768],[1,576,768],[1,576,768]]
             # [3,576,768]
+            #self.mutil_real_A0_tokens = self.cat_results(self.mutil_real_A0_tokens[0], self.opt.adj_size_list)
+            #print(f'self.mutil_real_A0_tokens[0]:{self.mutil_real_A0_tokens[0].shape}')
+
             shape_hw = list(self.real_A0_resize.shape[2:4])
 
             # Gradients of the generated image
             fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
             # Gradient map
-            self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw)
+            self.weight_fake = self.cao.generate_weight_map_image(fake_gradient, shape_hw)
 
             # CTN optical-flow map for the generated image
             self.f_content = self.ctn(self.weight_fake)
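For reference, the image-space weighting path introduced above can be exercised in isolation. Below is a minimal self-contained sketch of the same scheme: each location's gradient is compared against the mean gradient by cosine similarity, and the least-aligned eta_ratio fraction of locations is up-weighted by lambda_inc. The function name and parameter defaults here are illustrative, not the repo's API, and the patch's per-batch loop is replaced by an equivalent vectorized scatter_/gather:

import torch
import torch.nn.functional as F

def content_aware_weight_map(gradients, eta_ratio=0.3, lambda_inc=2.0):
    """Sketch of the generate_weight_map_image logic: gradients is
    [B, C, H, W]; the eta_ratio fraction of spatial locations whose
    gradient is least aligned with the mean gradient get weight
    lambda_inc / |cos|, every other location keeps weight 1."""
    B, C, H, W = gradients.shape
    g = gradients.view(B, C, H * W).transpose(1, 2)              # [B, N, C]
    mean_g = g.mean(dim=1, keepdim=True)                         # [B, 1, C]
    cosine = F.cosine_similarity(g, mean_g.expand_as(g), dim=2)  # [B, N]

    k = int(eta_ratio * H * W)
    _, idx = torch.topk(-cosine, k, dim=1)                       # least-aligned locations
    weight = torch.ones_like(cosine)
    weight.scatter_(1, idx, lambda_inc / (1e-6 + cosine.gather(1, idx).abs()))
    return weight.view(B, 1, H, W)

grads = torch.randn(1, 768, 24, 24)  # 24*24 = 576, matching the ViT token count noted above
w = content_aware_weight_map(grads)
print(w.shape)                       # torch.Size([1, 1, 24, 24])

One caution: compute_cosine_similarity_image unpacks a 4-D [B, C, H, W] tensor, while the token gradient taken in forward() appears to be [1, 576, 768] with shape_hw taken from the resized image, so a reshape of the token grid to [B, C, H, W] is presumably needed before the generate_weight_map_image call; a shape assert there would make the mismatch obvious during training.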