diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index e2b8c6e..741a618 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -33,3 +33,4 @@ ================ Training Loss (Sun Feb 23 18:59:52 2025) ================ ================ Training Loss (Sun Feb 23 19:03:05 2025) ================ ================ Training Loss (Sun Feb 23 19:03:57 2025) ================ +================ Training Loss (Sun Feb 23 21:11:47 2025) ================ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index cb9f345..0c9cd4b 100644 Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 45b220e..05cfaa7 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -527,26 +527,17 @@ class RomaUnsbModel(BaseModel): setattr(self, "fake_"+str(t+1), Xt_1) if self.opt.phase == 'train': - print(f'real_B0.shape = {real_B0.shape} fake_B0.shape = {self.fake_B0.shape}') - print(f"self.real_B0.requires_grad: {real_B0.requires_grad}") - # 真实图像的梯度 - real_gradient = torch.autograd.grad(real_B0.sum(), real_B0, create_graph=True)[0] # 生成图像的梯度 fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0] # 梯度图 - self.weight_real, self.weight_fake = self.cao.generate_weight_map(fake_gradient) - + self.weight_fake = self.cao.generate_weight_map(fake_gradient) # 生成图像的CTN光流图 self.f_content = self.ctn(self.weight_fake) - - # 把前面生成后的图片再加上noisy_map - self.fake_B_2 = self.fake_B + self.noisy_map - # 变换后的图片 - wapped_fake_B = warp(self.fake_B, self.f_content) - + self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content) + self.warped_fake_B0 = warp(self.fake_B0,self.f_content) # 经过第二次生成器 - self.fake_B_2 = self.netG(wapped_fake_B, self.time, z_in) + self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in) def compute_D_loss(self): """计算判别器的 GAN 损失"""