This commit is contained in:
bishe 2025-02-23 22:40:34 +08:00
parent 687559866d
commit 8a081af0a3
3 changed files with 36 additions and 33 deletions

View File

@ -39,3 +39,8 @@
================ Training Loss (Sun Feb 23 21:29:03 2025) ================ ================ Training Loss (Sun Feb 23 21:29:03 2025) ================
================ Training Loss (Sun Feb 23 21:34:57 2025) ================ ================ Training Loss (Sun Feb 23 21:34:57 2025) ================
================ Training Loss (Sun Feb 23 21:35:26 2025) ================ ================ Training Loss (Sun Feb 23 21:35:26 2025) ================
================ Training Loss (Sun Feb 23 22:28:43 2025) ================
================ Training Loss (Sun Feb 23 22:29:04 2025) ================
================ Training Loss (Sun Feb 23 22:29:52 2025) ================
================ Training Loss (Sun Feb 23 22:30:40 2025) ================
================ Training Loss (Sun Feb 23 22:33:48 2025) ================

View File

@ -395,8 +395,6 @@ class RomaUnsbModel(BaseModel):
return result return result
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
@ -462,34 +460,12 @@ class RomaUnsbModel(BaseModel):
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2) self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
if self.opt.phase == 'train': if self.opt.phase == 'train':
# 生成图像的梯度
print(f'self.fake_B0: {self.fake_B0.shape}')
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
# 梯度图
print(f'fake_gradient: {fake_gradient.shape}')
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
# 生成图像的CTN光流图
print(f'weight_fake: {self.weight_fake.shape}')
self.f_content = self.ctn(self.weight_fake)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
if self.opt.isTrain:
real_A0 = self.real_A0 real_A0 = self.real_A0
real_A1 = self.real_A1 real_A1 = self.real_A1
real_B0 = self.real_B0 real_B0 = self.real_B0
real_B1 = self.real_B1 real_B1 = self.real_B1
fake_B0 = self.fake_B0 fake_B0 = self.fake_B0
fake_B1 = self.fake_B1 fake_B1 = self.fake_B1
warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B0=self.warped_fake_B0
self.real_A0_resize = self.resize(real_A0) self.real_A0_resize = self.resize(real_A0)
self.real_A1_resize = self.resize(real_A1) self.real_A1_resize = self.resize(real_A1)
@ -497,8 +473,6 @@ class RomaUnsbModel(BaseModel):
real_B1 = self.resize(real_B1) real_B1 = self.resize(real_B1)
self.fake_B0_resize = self.resize(fake_B0) self.fake_B0_resize = self.resize(fake_B0)
self.fake_B1_resize = self.resize(fake_B1) self.fake_B1_resize = self.resize(fake_B1)
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True) self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
@ -506,9 +480,33 @@ class RomaUnsbModel(BaseModel):
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True) self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
# [[1,576,768],[1,576,768],[1,576,768]]
# [3,576,768]
# 生成图像的梯度
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
# 梯度图
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
# 生成图像的CTN光流图
self.f_content = self.ctn(self.weight_fake)
# 变换后的图片
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
# 经过第二次生成器
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
warped_fake_B0_2=self.warped_fake_B0_2
warped_fake_B0=self.warped_fake_B0
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True) self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True) self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
def compute_D_loss(self): def compute_D_loss(self):
"""计算判别器的 GAN 损失""" """计算判别器的 GAN 损失"""
@ -526,8 +524,8 @@ class RomaUnsbModel(BaseModel):
def compute_E_loss(self): def compute_E_loss(self):
"""计算判别器 E 的损失""" """计算判别器 E 的损失"""
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1) XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1) XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean() temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2 self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
@ -536,10 +534,10 @@ class RomaUnsbModel(BaseModel):
def compute_G_loss(self): def compute_G_loss(self):
"""计算生成器的 GAN 损失""" """计算生成器的 GAN 损失"""
bs = self.mutil_real_A0_tokens.size(0) bs = self.real_A0.size(0)
tau = self.opt.tau tau = self.opt.tau
fake = self.fake_B fake = self.fake_B0
std = torch.rand(size=[1]).item() * self.opt.std std = torch.rand(size=[1]).item() * self.opt.std
if self.opt.lambda_GAN > 0.0: if self.opt.lambda_GAN > 0.0:
@ -549,8 +547,8 @@ class RomaUnsbModel(BaseModel):
self.loss_G_GAN = 0.0 self.loss_G_GAN = 0.0
self.loss_SB = 0 self.loss_SB = 0
if self.opt.lambda_SB > 0.0: if self.opt.lambda_SB > 0.0:
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1) XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1) XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
bs = self.opt.batch_size bs = self.opt.batch_size
@ -560,7 +558,7 @@ class RomaUnsbModel(BaseModel):
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2) self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
if self.opt.lambda_global > 0.0: if self.opt.lambda_global > 0.0:
loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens) loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
loss_global *= 0.5 loss_global *= 0.5
else: else:
loss_global = 0.0 loss_global = 0.0