更新 models/rome_unsb_model.py
This commit is contained in:
parent
e5accb1d4c
commit
515907f2e8
@ -27,7 +27,7 @@ def warp(image, flow): #warp操作
|
|||||||
grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W]
|
grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W]
|
||||||
grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W]
|
grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W]
|
||||||
|
|
||||||
# 应用光流位移(归一化到[-1,1])
|
# 应用光流位移(归一化到[-1,1])
|
||||||
new_grid = grid + flow
|
new_grid = grid + flow
|
||||||
new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向
|
new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向
|
||||||
new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向
|
new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向
|
||||||
@ -90,10 +90,10 @@ class ContentAwareOptimization(nn.Module):
|
|||||||
# 计算生成图像块的余弦相似度
|
# 计算生成图像块的余弦相似度
|
||||||
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
|
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
|
||||||
|
|
||||||
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
|
# 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
|
||||||
k = int(self.eta_ratio * cosine_fake.shape[1])
|
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||||
|
|
||||||
# 对生成图像生成权重图(同理)
|
# 对生成图像生成权重图(同理)
|
||||||
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||||
weight_fake = torch.ones_like(cosine_fake)
|
weight_fake = torch.ones_like(cosine_fake)
|
||||||
for b in range(cosine_fake.shape[0]):
|
for b in range(cosine_fake.shape[0]):
|
||||||
@ -162,7 +162,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
"""
|
"""
|
||||||
生成内容感知光流
|
生成内容感知光流
|
||||||
Args:
|
Args:
|
||||||
weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
|
weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
|
||||||
Returns:
|
Returns:
|
||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
@ -172,19 +172,19 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
# 保持区域相对强度,同时限制数值范围
|
# 保持区域相对强度,同时限制数值范围
|
||||||
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
|
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
|
||||||
|
|
||||||
# 2. 生成高斯噪声(与光流场同尺寸)
|
# 2. 生成高斯噪声(与光流场同尺寸)
|
||||||
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
|
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
|
||||||
|
|
||||||
# 3. 合成基础光流
|
# 3. 合成基础光流
|
||||||
# 将权重图扩展为2通道(x/y方向共享权重)
|
# 将权重图扩展为2通道(x/y方向共享权重)
|
||||||
weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W]
|
weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W]
|
||||||
F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9
|
F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9
|
||||||
|
|
||||||
# 4. 平滑处理(保持结构连续性)
|
# 4. 平滑处理(保持结构连续性)
|
||||||
# 对每个通道独立进行高斯模糊
|
# 对每个通道独立进行高斯模糊
|
||||||
F_smooth = self.smoother(F_raw) # [B,2,H,W]
|
F_smooth = self.smoother(F_raw) # [B,2,H,W]
|
||||||
|
|
||||||
# 5. 动态范围调整(可选)
|
# 5. 动态范围调整(可选)
|
||||||
# 限制光流幅值,避免极端位移
|
# 限制光流幅值,避免极端位移
|
||||||
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
|
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
|
||||||
|
|
||||||
@ -405,7 +405,7 @@ class CTNxModel(BaseModel):
|
|||||||
# 执行一次SB模块
|
# 执行一次SB模块
|
||||||
|
|
||||||
# ============ 第一步:初始化时间步与时间索引 ============
|
# ============ 第一步:初始化时间步与时间索引 ============
|
||||||
# 计算 times,并确定当前 time_idx(随机选取用来表示当前时间步)
|
# 计算 times,并确定当前 time_idx(随机选取用来表示当前时间步)
|
||||||
tau = self.opt.tau
|
tau = self.opt.tau
|
||||||
T = self.opt.num_timesteps
|
T = self.opt.num_timesteps
|
||||||
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user