Merge branch 'main' of http://47.108.14.56:4000/123456/roma_unsb
This commit is contained in:
commit
26b770a3c1
@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module):
|
||||
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
||||
|
||||
# 生成权重图(与原代码相同)
|
||||
k = int(self.eta_ratio * N)
|
||||
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||
weight_fake = torch.ones_like(cosine_fake)
|
||||
|
||||
|
||||
for b in range(cosine_fake.shape[0]):
|
||||
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||
|
||||
# 重建空间维度 --------------------------------------------------
|
||||
# 将权重从[B, N]转换为[B, H, W]
|
||||
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
||||
@ -119,7 +123,7 @@ class ContentAwareOptimization(nn.Module):
|
||||
"""
|
||||
B, C, H, W = D_real.shape
|
||||
N = H * W
|
||||
|
||||
shape_hw = [h, w]
|
||||
# 注册钩子获取梯度
|
||||
gradients_real = []
|
||||
gradients_fake = []
|
||||
@ -146,7 +150,7 @@ class ContentAwareOptimization(nn.Module):
|
||||
gradients_fake = gradients_fake[0] # [B, N, D]
|
||||
|
||||
# 生成权重图
|
||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
|
||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
|
||||
|
||||
# 应用权重到对抗损失
|
||||
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
||||
@ -492,12 +496,12 @@ class RomaUnsbModel(BaseModel):
|
||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||
# [3,576,768]
|
||||
|
||||
shape_hw = list(self.real_A0_resize.shape[2:4])
|
||||
# 生成图像的梯度
|
||||
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||
|
||||
# 梯度图
|
||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw)
|
||||
|
||||
# 生成图像的CTN光流图
|
||||
self.f_content = self.ctn(self.weight_fake)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user