最新的修改
This commit is contained in:
parent
7af2de920c
commit
e67b0f2511
@ -123,7 +123,7 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
"""
|
"""
|
||||||
B, C, H, W = D_real.shape
|
B, C, H, W = D_real.shape
|
||||||
N = H * W
|
N = H * W
|
||||||
|
shape_hw = [h, w]
|
||||||
# 注册钩子获取梯度
|
# 注册钩子获取梯度
|
||||||
gradients_real = []
|
gradients_real = []
|
||||||
gradients_fake = []
|
gradients_fake = []
|
||||||
@ -150,7 +150,7 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
gradients_fake = gradients_fake[0] # [B, N, D]
|
gradients_fake = gradients_fake[0] # [B, N, D]
|
||||||
|
|
||||||
# 生成权重图
|
# 生成权重图
|
||||||
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
|
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
|
||||||
|
|
||||||
# 应用权重到对抗损失
|
# 应用权重到对抗损失
|
||||||
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
||||||
@ -496,12 +496,12 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||||
# [[1,576,768],[1,576,768],[1,576,768]]
|
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||||
# [3,576,768]
|
# [3,576,768]
|
||||||
|
shape_hw = list(self.real_A0_resize.shape[2:4])
|
||||||
# 生成图像的梯度
|
# 生成图像的梯度
|
||||||
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||||
|
|
||||||
# 梯度图
|
# 梯度图
|
||||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
self.weight_fake = self.cao.generate_weight_map(fake_gradient,shape_hw)
|
||||||
|
|
||||||
# 生成图像的CTN光流图
|
# 生成图像的CTN光流图
|
||||||
self.f_content = self.ctn(self.weight_fake)
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user