diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index 19fcafc..e2b8c6e 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -25,3 +25,11 @@ ================ Training Loss (Sun Feb 23 18:39:21 2025) ================ ================ Training Loss (Sun Feb 23 18:40:15 2025) ================ ================ Training Loss (Sun Feb 23 18:41:15 2025) ================ +================ Training Loss (Sun Feb 23 18:47:46 2025) ================ +================ Training Loss (Sun Feb 23 18:48:36 2025) ================ +================ Training Loss (Sun Feb 23 18:50:20 2025) ================ +================ Training Loss (Sun Feb 23 18:51:50 2025) ================ +================ Training Loss (Sun Feb 23 18:58:45 2025) ================ +================ Training Loss (Sun Feb 23 18:59:52 2025) ================ +================ Training Loss (Sun Feb 23 19:03:05 2025) ================ +================ Training Loss (Sun Feb 23 19:03:57 2025) ================ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index fe4d91c..cb9f345 100644 Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index b9c2ad7..45b220e 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -166,6 +166,7 @@ class ContentAwareTemporalNorm(nn.Module): Returns: F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) """ + print(weight_map.shape) B, _, H, W = weight_map.shape # 1. 归一化权重图 @@ -403,8 +404,8 @@ class RomaUnsbModel(BaseModel): print(f'before resize: {self.real_A0.shape}') real_A0 = self.resize(self.real_A0) real_A1 = self.resize(self.real_A1) - real_B0 = self.resize(self.real_B0) - real_B1 = self.resize(self.real_B1) + real_B0 = self.resize(self.real_B0).requires_grad_(True) + real_B1 = self.resize(self.real_B1).requires_grad_(True) # 使用VIT print(f'before vit: {real_A0.shape}') @@ -526,12 +527,14 @@ class RomaUnsbModel(BaseModel): setattr(self, "fake_"+str(t+1), Xt_1) if self.opt.phase == 'train': + print(f'real_B0.shape = {real_B0.shape} fake_B0.shape = {self.fake_B0.shape}') + print(f"self.real_B0.requires_grad: {real_B0.requires_grad}") # 真实图像的梯度 - real_gradient = torch.autograd.grad(self.real_B0.sum(), self.real_B0, create_graph=True)[0] + real_gradient = torch.autograd.grad(real_B0.sum(), real_B0, create_graph=True)[0] # 生成图像的梯度 - fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0] + fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0] # 梯度图 - self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient) + self.weight_real, self.weight_fake = self.cao.generate_weight_map(fake_gradient) # 生成图像的CTN光流图 self.f_content = self.ctn(self.weight_fake)