debug
This commit is contained in:
parent
687559866d
commit
8a081af0a3
@ -39,3 +39,8 @@
|
|||||||
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
|
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
|
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 21:35:26 2025) ================
|
================ Training Loss (Sun Feb 23 21:35:26 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:28:43 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:29:04 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:29:52 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:30:40 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:33:48 2025) ================
|
||||||
|
|||||||
Binary file not shown.
@ -395,8 +395,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||||
|
|
||||||
@ -462,34 +460,12 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
# 生成图像的梯度
|
|
||||||
print(f'self.fake_B0: {self.fake_B0.shape}')
|
|
||||||
fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
|
|
||||||
|
|
||||||
# 梯度图
|
|
||||||
print(f'fake_gradient: {fake_gradient.shape}')
|
|
||||||
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
|
||||||
|
|
||||||
# 生成图像的CTN光流图
|
|
||||||
print(f'weight_fake: {self.weight_fake.shape}')
|
|
||||||
self.f_content = self.ctn(self.weight_fake)
|
|
||||||
|
|
||||||
# 变换后的图片
|
|
||||||
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
|
||||||
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
|
||||||
|
|
||||||
# 经过第二次生成器
|
|
||||||
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
|
||||||
|
|
||||||
if self.opt.isTrain:
|
|
||||||
real_A0 = self.real_A0
|
real_A0 = self.real_A0
|
||||||
real_A1 = self.real_A1
|
real_A1 = self.real_A1
|
||||||
real_B0 = self.real_B0
|
real_B0 = self.real_B0
|
||||||
real_B1 = self.real_B1
|
real_B1 = self.real_B1
|
||||||
fake_B0 = self.fake_B0
|
fake_B0 = self.fake_B0
|
||||||
fake_B1 = self.fake_B1
|
fake_B1 = self.fake_B1
|
||||||
warped_fake_B0_2=self.warped_fake_B0_2
|
|
||||||
warped_fake_B0=self.warped_fake_B0
|
|
||||||
|
|
||||||
self.real_A0_resize = self.resize(real_A0)
|
self.real_A0_resize = self.resize(real_A0)
|
||||||
self.real_A1_resize = self.resize(real_A1)
|
self.real_A1_resize = self.resize(real_A1)
|
||||||
@ -497,8 +473,6 @@ class RomaUnsbModel(BaseModel):
|
|||||||
real_B1 = self.resize(real_B1)
|
real_B1 = self.resize(real_B1)
|
||||||
self.fake_B0_resize = self.resize(fake_B0)
|
self.fake_B0_resize = self.resize(fake_B0)
|
||||||
self.fake_B1_resize = self.resize(fake_B1)
|
self.fake_B1_resize = self.resize(fake_B1)
|
||||||
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
|
||||||
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
|
||||||
|
|
||||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||||
@ -506,8 +480,32 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||||
|
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||||
|
# [3,576,768]
|
||||||
|
|
||||||
|
# 生成图像的梯度
|
||||||
|
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens.sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||||
|
|
||||||
|
# 梯度图
|
||||||
|
self.weight_fake = self.cao.generate_weight_map(fake_gradient)
|
||||||
|
|
||||||
|
# 生成图像的CTN光流图
|
||||||
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|
||||||
|
# 变换后的图片
|
||||||
|
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||||
|
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||||
|
|
||||||
|
# 经过第二次生成器
|
||||||
|
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||||
|
|
||||||
|
warped_fake_B0_2=self.warped_fake_B0_2
|
||||||
|
warped_fake_B0=self.warped_fake_B0
|
||||||
|
self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||||
|
self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||||
self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
def compute_D_loss(self):
|
def compute_D_loss(self):
|
||||||
"""计算判别器的 GAN 损失"""
|
"""计算判别器的 GAN 损失"""
|
||||||
@ -526,8 +524,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
def compute_E_loss(self):
|
def compute_E_loss(self):
|
||||||
"""计算判别器 E 的损失"""
|
"""计算判别器 E 的损失"""
|
||||||
|
|
||||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1)
|
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
|
||||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1)
|
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
|
||||||
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
||||||
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
|
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
|
||||||
|
|
||||||
@ -536,10 +534,10 @@ class RomaUnsbModel(BaseModel):
|
|||||||
def compute_G_loss(self):
|
def compute_G_loss(self):
|
||||||
"""计算生成器的 GAN 损失"""
|
"""计算生成器的 GAN 损失"""
|
||||||
|
|
||||||
bs = self.mutil_real_A0_tokens.size(0)
|
bs = self.real_A0.size(0)
|
||||||
tau = self.opt.tau
|
tau = self.opt.tau
|
||||||
|
|
||||||
fake = self.fake_B
|
fake = self.fake_B0
|
||||||
std = torch.rand(size=[1]).item() * self.opt.std
|
std = torch.rand(size=[1]).item() * self.opt.std
|
||||||
|
|
||||||
if self.opt.lambda_GAN > 0.0:
|
if self.opt.lambda_GAN > 0.0:
|
||||||
@ -549,8 +547,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.loss_G_GAN = 0.0
|
self.loss_G_GAN = 0.0
|
||||||
self.loss_SB = 0
|
self.loss_SB = 0
|
||||||
if self.opt.lambda_SB > 0.0:
|
if self.opt.lambda_SB > 0.0:
|
||||||
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
|
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
|
||||||
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
|
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
|
||||||
|
|
||||||
bs = self.opt.batch_size
|
bs = self.opt.batch_size
|
||||||
|
|
||||||
@ -560,7 +558,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
|
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
|
||||||
|
|
||||||
if self.opt.lambda_global > 0.0:
|
if self.opt.lambda_global > 0.0:
|
||||||
loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
|
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||||
loss_global *= 0.5
|
loss_global *= 0.5
|
||||||
else:
|
else:
|
||||||
loss_global = 0.0
|
loss_global = 0.0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user