fixing CTN
This commit is contained in:
parent
1caa5f0625
commit
6705075876
@ -25,3 +25,11 @@
|
|||||||
================ Training Loss (Sun Feb 23 18:39:21 2025) ================
|
================ Training Loss (Sun Feb 23 18:39:21 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 18:40:15 2025) ================
|
================ Training Loss (Sun Feb 23 18:40:15 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 18:41:15 2025) ================
|
================ Training Loss (Sun Feb 23 18:41:15 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:47:46 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:48:36 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:50:20 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:51:50 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:58:45 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:59:52 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
|
||||||
|
|||||||
Binary file not shown.
@ -166,6 +166,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
|
print(weight_map.shape)
|
||||||
B, _, H, W = weight_map.shape
|
B, _, H, W = weight_map.shape
|
||||||
|
|
||||||
# 1. 归一化权重图
|
# 1. 归一化权重图
|
||||||
@ -403,8 +404,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
print(f'before resize: {self.real_A0.shape}')
|
print(f'before resize: {self.real_A0.shape}')
|
||||||
real_A0 = self.resize(self.real_A0)
|
real_A0 = self.resize(self.real_A0)
|
||||||
real_A1 = self.resize(self.real_A1)
|
real_A1 = self.resize(self.real_A1)
|
||||||
real_B0 = self.resize(self.real_B0)
|
real_B0 = self.resize(self.real_B0).requires_grad_(True)
|
||||||
real_B1 = self.resize(self.real_B1)
|
real_B1 = self.resize(self.real_B1).requires_grad_(True)
|
||||||
# 使用VIT
|
# 使用VIT
|
||||||
|
|
||||||
print(f'before vit: {real_A0.shape}')
|
print(f'before vit: {real_A0.shape}')
|
||||||
@ -526,12 +527,14 @@ class RomaUnsbModel(BaseModel):
|
|||||||
setattr(self, "fake_"+str(t+1), Xt_1)
|
setattr(self, "fake_"+str(t+1), Xt_1)
|
||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
|
print(f'real_B0.shape = {real_B0.shape} fake_B0.shape = {self.fake_B0.shape}')
|
||||||
|
print(f"self.real_B0.requires_grad: {real_B0.requires_grad}")
|
||||||
# 真实图像的梯度
|
# 真实图像的梯度
|
||||||
real_gradient = torch.autograd.grad(self.real_B0.sum(), self.real_B0, create_graph=True)[0]
|
real_gradient = torch.autograd.grad(real_B0.sum(), real_B0, create_graph=True)[0]
|
||||||
# 生成图像的梯度
|
# 生成图像的梯度
|
||||||
fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
|
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
|
||||||
# 梯度图
|
# 梯度图
|
||||||
self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)
|
self.weight_real, self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||||
|
|
||||||
# 生成图像的CTN光流图
|
# 生成图像的CTN光流图
|
||||||
self.f_content = self.ctn(self.weight_fake)
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user