use kun's forward method
parent 0639032b6c
commit 687559866d
@@ -34,3 +34,8 @@
 ================ Training Loss (Sun Feb 23 19:03:05 2025) ================
 ================ Training Loss (Sun Feb 23 19:03:57 2025) ================
 ================ Training Loss (Sun Feb 23 21:11:47 2025) ================
+================ Training Loss (Sun Feb 23 21:17:10 2025) ================
+================ Training Loss (Sun Feb 23 21:20:14 2025) ================
+================ Training Loss (Sun Feb 23 21:29:03 2025) ================
+================ Training Loss (Sun Feb 23 21:34:57 2025) ================
+================ Training Loss (Sun Feb 23 21:35:26 2025) ================
Binary file not shown.
@@ -83,9 +83,9 @@ class ContentAwareOptimization(nn.Module):
         """
         Generate the content-aware weight map
         Args:
-            gradients_fake: [B, N, D] discriminator gradients of the generated image
+            gradients_fake: [B, N, D] discriminator gradients of the generated image [2,3,256,256]
         Returns:
-            weight_fake: [B, N] weight map of the generated image
+            weight_fake: [B, N] weight map of the generated image [2,3,256]
         """
         # Compute the cosine similarity of the generated-image patches
         cosine_fake = self.compute_cosine_similarity(gradients_fake)  # [B, N]
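
Editor's note: this hunk only annotates tensor shapes in the docstring; the body of compute_cosine_similarity is not part of the commit. A minimal standalone sketch of a weight map with this signature, assuming cosine similarity of each token gradient against the mean gradient direction followed by a softmax over tokens (the repository's actual module may differ):

import torch
import torch.nn.functional as F

def generate_weight_map_sketch(gradients_fake: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: [B, N, D] token gradients -> [B, N] weights."""
    mean_grad = gradients_fake.mean(dim=1, keepdim=True)               # [B, 1, D]
    cosine = F.cosine_similarity(gradients_fake, mean_grad, dim=-1)    # [B, N]
    return torch.softmax(cosine, dim=1)                                # [B, N], sums to 1 per sample
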
@@ -398,28 +398,9 @@ class RomaUnsbModel(BaseModel):
 
 
     def forward(self):
-        """Run the forward pass to generate the output images."""
-
-        if self.opt.isTrain:
-            print(f'before resize: {self.real_A0.shape}')
-            real_A0 = self.resize(self.real_A0)
-            real_A1 = self.resize(self.real_A1)
-            real_B0 = self.resize(self.real_B0).requires_grad_(True)
-            real_B1 = self.resize(self.real_B1).requires_grad_(True)
-            # Use ViT
-
-            print(f'before vit: {real_A0.shape}')
-            self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
-            self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
-
-            print(f'before cat: len = {len(self.mutil_real_A0_tokens)}\n{self.mutil_real_A0_tokens[0].shape}')
-            self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=0).unsqueeze(0).to(self.device)
-            self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=0).unsqueeze(0).to(self.device)
-
-            # Run the SB module once
-
-        # ============ Step 1: initialize the timesteps and time index ============
-        # Compute times and pick the current time_idx (randomly chosen to represent the current timestep)
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+
+        # ============ Step 1: multi-step stochastic generation on real_A / real_A2 ============
         tau = self.opt.tau
         T = self.opt.num_timesteps
         incs = np.array([0] + [1/(i+1) for i in range(T-1)])
@@ -429,7 +410,7 @@ class RomaUnsbModel(BaseModel):
         times = np.concatenate([np.zeros(1), times])
         times = torch.tensor(times).float().cuda()
         self.times = times
-        bs = self.mutil_real_A0_tokens.size(0)
+        bs = self.real_A0.size(0)
         time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
         self.time_idx = time_idx
 
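
For reference, the time schedule assembled across the two hunks above can be checked in isolation. The arithmetic below is copied from the diff; T is chosen arbitrarily for illustration:

import numpy as np

T = 5
incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])  # [0, 1, 1/2, 1/3, 1/4]
times = np.cumsum(incs)
times = times / times[-1]               # normalize to [0, 1]
times = 0.5 * times[-1] + 0.5 * times   # compress into [0.5, 1.0]
times = np.concatenate([np.zeros(1), times])
print(times)  # -> [0.  0.5  0.72  0.86  0.94  1. ] (approx.)
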
@@ -444,34 +425,30 @@ class RomaUnsbModel(BaseModel):
                 inter = (delta / denom).reshape(-1, 1, 1, 1)
                 scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
 
-            print(f'before noisy: {self.mutil_real_A0_tokens.shape}')
             # Apply the random noise update to Xt and Xt2
-            Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
-                (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
-            time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
-            z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+            Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
+                (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
+            time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
+            z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
             self.time = times[time_idx]
             Xt_1 = self.netG(Xt, self.time, z)
 
-            Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
-                (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
-            time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
-            z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+            Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
+                (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
+            time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
+            z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
             Xt_12 = self.netG(Xt2, self.time, z)
 
         # Save the denoised intermediate results (real_A_noisy etc.) for the concatenation in the next step
         self.real_A_noisy = Xt.detach()
         self.real_A_noisy2 = Xt2.detach()
-        # Save the noisy map
-        print(f'after noisy map: {self.real_A_noisy.shape}')
-        self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens
 
         # ============ Step 3: concatenate the inputs and run the network =============
-        bs = self.mutil_real_A0_tokens.size(0)
-        z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
-        z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+        bs = self.real_A0.size(0)
+        z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device)
+        z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
         # Concatenate real_A and real_B (when nce_idt=True) and handle real_A_noisy and XtB the same way
-        self.real = self.mutil_real_A0_tokens
+        self.real = self.real_A0
         self.realt = self.real_A_noisy
 
         if self.opt.flip_equivariance:
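
The noise update above interpolates between the current state and the previous generator output, then adds Gaussian noise whose variance is scaled by tau. A standalone sketch of one such step (same arithmetic as the loop body; the function name and argument layout are illustrative only):

import torch

def sb_noise_update(Xt, Xt_1, times, t, tau):
    """One interpolation/noise step, as in the loop above (sketch).

    Xt:   current state [B, C, H, W]
    Xt_1: generator output from the previous step (detached here)
    """
    delta = times[t] - times[t - 1]
    denom = times[-1] - times[t - 1]
    inter = (delta / denom).reshape(-1, 1, 1, 1)
    scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
    noise = (scale * tau).sqrt() * torch.randn_like(Xt)
    return (1 - inter) * Xt + inter * Xt_1.detach() + noise
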
@@ -480,65 +457,58 @@ class RomaUnsbModel(BaseModel):
             self.real = torch.flip(self.real, [3])
             self.realt = torch.flip(self.realt, [3])
 
-        # Use netG to produce the final fake, fake_B2, etc.
-        self.fake_B = self.netG(self.realt, self.time, z_in)
-        self.fake_B2 = self.netG(self.real, self.time, z_in2)
-
-        self.fake_B = self.resize(self.fake_B)
-        self.fake_B2 = self.resize(self.fake_B2)
-
-        self.fake_B0 = self.fake_B
-        self.fake_B1 = self.fake_B2
-
-        # Use ViT
-        self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B, self.atten_layers, get_tokens=True)
-        self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B2, self.atten_layers, get_tokens=True)
-
-        # ============ Step 4: multiple sampling rounds in inference mode ============
-        if self.opt.phase == 'test':
-            tau = self.opt.tau
-            T = self.opt.num_timesteps
-            incs = np.array([0] + [1/(i+1) for i in range(T-1)])
-            times = np.cumsum(incs)
-            times = times / times[-1]
-            times = 0.5 * times[-1] + 0.5 * times
-            times = np.concatenate([np.zeros(1),times])
-            times = torch.tensor(times).float().cuda()
-            self.times = times
-            bs = self.real.size(0)
-            time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
-            self.time_idx = time_idx
-            visuals = []
-            with torch.no_grad():
-                self.netG.eval()
-                for t in range(self.opt.num_timesteps):
-
-                    if t > 0:
-                        delta = times[t] - times[t-1]
-                        denom = times[-1] - times[t-1]
-                        inter = (delta / denom).reshape(-1,1,1,1)
-                        scale = (delta * (1 - delta / denom)).reshape(-1,1,1,1)
-                    Xt = self.mutil_real_A0_tokens if (t == 0) else (1-inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
-                    time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
-                    time = times[time_idx]
-                    z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
-                    Xt_1 = self.netG(Xt, time_idx, z)
-
-                    setattr(self, "fake_"+str(t+1), Xt_1)
+        self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
+        self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
 
         if self.opt.phase == 'train':
             # Gradients of the generated image
+            print(f'self.fake_B0: {self.fake_B0.shape}')
             fake_gradient = torch.autograd.grad(self.fake_B0.sum(), self.fake_B0, create_graph=True)[0]
 
             # Gradient map
+            print(f'fake_gradient: {fake_gradient.shape}')
             self.weight_fake = self.cao.generate_weight_map(fake_gradient)
 
             # CTN optical-flow map of the generated image
+            print(f'weight_fake: {self.weight_fake.shape}')
             self.f_content = self.ctn(self.weight_fake)
 
             # Warped images
             self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
             self.warped_fake_B0 = warp(self.fake_B0, self.f_content)
 
             # Pass through the generator a second time
             self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
 
+        if self.opt.isTrain:
+            real_A0 = self.real_A0
+            real_A1 = self.real_A1
+            real_B0 = self.real_B0
+            real_B1 = self.real_B1
+            fake_B0 = self.fake_B0
+            fake_B1 = self.fake_B1
+            warped_fake_B0_2 = self.warped_fake_B0_2
+            warped_fake_B0 = self.warped_fake_B0
+
+            self.real_A0_resize = self.resize(real_A0)
+            self.real_A1_resize = self.resize(real_A1)
+            real_B0 = self.resize(real_B0)
+            real_B1 = self.resize(real_B1)
+            self.fake_B0_resize = self.resize(fake_B0)
+            self.fake_B1_resize = self.resize(fake_B1)
+            self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
+            self.warped_fake_B0_resize = self.resize(warped_fake_B0)
+
+            self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
+            self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
+            self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
+            self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
 
     def compute_D_loss(self):
         """Compute the GAN loss of the discriminator"""
 
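
The warp helper called in the train branch above is not defined in this diff. Flow-based warping of this kind is commonly implemented with torch.nn.functional.grid_sample; a minimal sketch under that assumption (a flow layout of [B, 2, H, W] in pixel units is assumed, and the repository's helper may normalize differently):

import torch
import torch.nn.functional as F

def warp_sketch(img: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
    """Warp img [B, C, H, W] by a per-pixel flow [B, 2, H, W] (assumed layout)."""
    B, _, H, W = img.shape
    # Base sampling grid in pixel coordinates.
    ys, xs = torch.meshgrid(torch.arange(H, device=img.device),
                            torch.arange(W, device=img.device), indexing='ij')
    grid = torch.stack((xs, ys), dim=0).float()            # [2, H, W], channel 0 = x
    coords = grid.unsqueeze(0) + flow                      # displaced sample positions
    # Normalize to [-1, 1] as grid_sample expects.
    coords_x = 2.0 * coords[:, 0] / max(W - 1, 1) - 1.0
    coords_y = 2.0 * coords[:, 1] / max(H - 1, 1) - 1.0
    norm_grid = torch.stack((coords_x, coords_y), dim=-1)  # [B, H, W, 2]
    return F.grid_sample(img, norm_grid, align_corners=True)
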