Merge branch 'main' of http://47.108.14.56:4000/123456/roma_unsb
This commit is contained in:
commit
26b770a3c1
@ -97,9 +97,13 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
||||||
|
|
||||||
# 生成权重图(与原代码相同)
|
# 生成权重图(与原代码相同)
|
||||||
k = int(self.eta_ratio * N)
|
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||||
|
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||||
weight_fake = torch.ones_like(cosine_fake)
|
weight_fake = torch.ones_like(cosine_fake)
|
||||||
|
|
||||||
|
for b in range(cosine_fake.shape[0]):
|
||||||
|
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||||
|
|
||||||
# 重建空间维度 --------------------------------------------------
|
# 重建空间维度 --------------------------------------------------
|
||||||
# 将权重从[B, N]转换为[B, H, W]
|
# 将权重从[B, N]转换为[B, H, W]
|
||||||
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
||||||
@ -119,7 +123,7 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
"""
|
"""
|
||||||
B, C, H, W = D_real.shape
|
B, C, H, W = D_real.shape
|
||||||
N = H * W
|
N = H * W
|
||||||
|
shape_hw = [h, w]
|
||||||
# 注册钩子获取梯度
|
# 注册钩子获取梯度
|
||||||
gradients_real = []
|
gradients_real = []
|
||||||
gradients_fake = []
|
gradients_fake = []
|
||||||
@ -146,7 +150,7 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
gradients_fake = gradients_fake[0] # [B, N, D]
|
gradients_fake = gradients_fake[0] # [B, N, D]
|
||||||
|
|
||||||
# 生成权重图
|
# 生成权重图
|
||||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
|
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
|
||||||
|
|
||||||
# 应用权重到对抗损失
|
# 应用权重到对抗损失
|
||||||
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
||||||
@ -492,12 +496,12 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||||
# [3,576,768]
|
# [3,576,768]
|
||||||
|
shape_hw = list(self.real_A0_resize.shape[2:4])
|
||||||
# 生成图像的梯度
|
# 生成图像的梯度
|
||||||
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||||
|
|
||||||
# 梯度图
|
# 梯度图
|
||||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw)
|
||||||
|
|
||||||
# 生成图像的CTN光流图
|
# 生成图像的CTN光流图
|
||||||
self.f_content = self.ctn(self.weight_fake)
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user