更新 models/rome_unsb_model.py

This commit is contained in:
123456 2025-02-23 15:23:00 +08:00
parent e5accb1d4c
commit 515907f2e8

View File

@@ -27,7 +27,7 @@ def warp(image, flow): #warp操作
grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W] grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W]
grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W] grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W]
# 应用光流位移(归一化到[-1,1] # 应用光流位移(归一化到[-1,1])
new_grid = grid + flow new_grid = grid + flow
new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向 new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向
new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向 new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向
@@ -90,10 +90,10 @@ class ContentAwareOptimization(nn.Module):
# 计算生成图像块的余弦相似度 # 计算生成图像块的余弦相似度
cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N] cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
# 选择内容丰富的区域余弦相似度最低的eta_ratio比例 # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
k = int(self.eta_ratio * cosine_fake.shape[1]) k = int(self.eta_ratio * cosine_fake.shape[1])
# 对生成图像生成权重图(同理) # 对生成图像生成权重图(同理)
_, fake_indices = torch.topk(-cosine_fake, k, dim=1) _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
weight_fake = torch.ones_like(cosine_fake) weight_fake = torch.ones_like(cosine_fake)
for b in range(cosine_fake.shape[0]): for b in range(cosine_fake.shape[0]):
@@ -162,7 +162,7 @@ class ContentAwareTemporalNorm(nn.Module):
""" """
生成内容感知光流 生成内容感知光流
Args: Args:
weight_map: [B, 1, H, W] 权重图来自内容感知优化模块 weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
Returns: Returns:
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
""" """
@@ -172,19 +172,19 @@ class ContentAwareTemporalNorm(nn.Module):
# 保持区域相对强度,同时限制数值范围 # 保持区域相对强度,同时限制数值范围
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W] weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
# 2. 生成高斯噪声(与光流场同尺寸) # 2. 生成高斯噪声(与光流场同尺寸)
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W] z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
# 3. 合成基础光流 # 3. 合成基础光流
# 将权重图扩展为2通道x/y方向共享权重 # 将权重图扩展为2通道(x/y方向共享权重)
weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W] weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W]
F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9 F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9
# 4. 平滑处理(保持结构连续性) # 4. 平滑处理(保持结构连续性)
# 对每个通道独立进行高斯模糊 # 对每个通道独立进行高斯模糊
F_smooth = self.smoother(F_raw) # [B,2,H,W] F_smooth = self.smoother(F_raw) # [B,2,H,W]
# 5. 动态范围调整(可选) # 5. 动态范围调整(可选)
# 限制光流幅值,避免极端位移 # 限制光流幅值,避免极端位移
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围 F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
@@ -405,7 +405,7 @@ class CTNxModel(BaseModel):
# 执行一次SB模块 # 执行一次SB模块
# ============ 第一步:初始化时间步与时间索引 ============ # ============ 第一步:初始化时间步与时间索引 ============
# 计算 times并确定当前 time_idx(随机选取用来表示当前时间步) # 计算 times并确定当前 time_idx(随机选取用来表示当前时间步)
tau = self.opt.tau tau = self.opt.tau
T = self.opt.num_timesteps T = self.opt.num_timesteps
incs = np.array([0] + [1/(i+1) for i in range(T-1)]) incs = np.array([0] + [1/(i+1) for i in range(T-1)])