This commit is contained in:
Kunyu_Lee 2025-02-24 21:45:06 +08:00
commit 26b770a3c1

View File

@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module):
cosine_fake = self.compute_cosine_similarity(gradients_fake) cosine_fake = self.compute_cosine_similarity(gradients_fake)
# 生成权重图(与原代码相同) # 生成权重图(与原代码相同)
k = int(self.eta_ratio * N) k = int(self.eta_ratio * cosine_fake.shape[1])
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake) weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]):
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
# 重建空间维度 -------------------------------------------------- # 重建空间维度 --------------------------------------------------
# 将权重从[B, N]转换为[B, H, W] # 将权重从[B, N]转换为[B, H, W]
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W] weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
@ -119,7 +123,7 @@ class ContentAwareOptimization(nn.Module):
""" """
B, C, H, W = D_real.shape B, C, H, W = D_real.shape
N = H * W N = H * W
shape_hw = [h, w]
# 注册钩子获取梯度 # 注册钩子获取梯度
gradients_real = [] gradients_real = []
gradients_fake = [] gradients_fake = []
@ -146,7 +150,7 @@ class ContentAwareOptimization(nn.Module):
gradients_fake = gradients_fake[0] # [B, N, D] gradients_fake = gradients_fake[0] # [B, N, D]
# 生成权重图 # 生成权重图
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake) self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
# 应用权重到对抗损失 # 应用权重到对抗损失
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8)) loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
@ -492,12 +496,12 @@ class RomaUnsbModel(BaseModel):
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
# [[1,576,768],[1,576,768],[1,576,768]] # [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768] # [3,576,768]
shape_hw = list(self.real_A0_resize.shape[2:4])
# 生成图像的梯度 # 生成图像的梯度
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0] fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
# 梯度图 # 梯度图
self.weight_fake = self.cao.generate_weight_map(fake_gradient) self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw)
# 生成图像的CTN光流图 # 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake) self.f_content = self.ctn(self.weight_fake)