renew
This commit is contained in:
parent
55b9db967a
commit
7af2de920c
@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module):
|
||||
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
||||
|
||||
# 生成权重图(与原代码相同)
|
||||
k = int(self.eta_ratio * N)
|
||||
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||
weight_fake = torch.ones_like(cosine_fake)
|
||||
|
||||
|
||||
for b in range(cosine_fake.shape[0]):
|
||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||
|
||||
# 重建空间维度 --------------------------------------------------
|
||||
# 将权重从[B, N]转换为[B, H, W]
|
||||
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user