From 515907f2e8fd8301a2c26e3fd4ef7b1bb8e1a256 Mon Sep 17 00:00:00 2001 From: 123456 <3351416005@qq.com> Date: Sun, 23 Feb 2025 15:23:00 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20models/rome=5Funsb=5Fmodel?= =?UTF-8?q?.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- models/rome_unsb_model.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/models/rome_unsb_model.py b/models/rome_unsb_model.py index fb62536..64a67f1 100644 --- a/models/rome_unsb_model.py +++ b/models/rome_unsb_model.py @@ -27,7 +27,7 @@ def warp(image, flow): #warp操作 grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W] grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W] - # 应用光流位移(归一化到[-1,1]) + # 应用光流位移(归一化到[-1,1]) new_grid = grid + flow new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x方向 new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y方向 @@ -90,10 +90,10 @@ class ContentAwareOptimization(nn.Module): # 计算生成图像块的余弦相似度 cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N] - # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例) + # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例) k = int(self.eta_ratio * cosine_fake.shape[1]) - # 对生成图像生成权重图(同理) + # 对生成图像生成权重图(同理) _, fake_indices = torch.topk(-cosine_fake, k, dim=1) weight_fake = torch.ones_like(cosine_fake) for b in range(cosine_fake.shape[0]): @@ -162,7 +162,7 @@ class ContentAwareTemporalNorm(nn.Module): """ 生成内容感知光流 Args: - weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块) + weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块) Returns: F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) """ @@ -172,19 +172,19 @@ class ContentAwareTemporalNorm(nn.Module): # 保持区域相对强度,同时限制数值范围 weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W] - # 2. 生成高斯噪声(与光流场同尺寸) + # 2. 生成高斯噪声(与光流场同尺寸) z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W] # 3. 合成基础光流 - # 将权重图扩展为2通道(x/y方向共享权重) + # 将权重图扩展为2通道(x/y方向共享权重) weight_expanded = weight_norm.expand(-1, 2, -1, -1) # [B,2,H,W] F_raw = self.gamma_stride * weight_expanded * z # [B,2,H,W] #公式9 - # 4. 平滑处理(保持结构连续性) + # 4. 平滑处理(保持结构连续性) # 对每个通道独立进行高斯模糊 F_smooth = self.smoother(F_raw) # [B,2,H,W] - # 5. 动态范围调整(可选) + # 5. 动态范围调整(可选) # 限制光流幅值,避免极端位移 F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围 @@ -405,7 +405,7 @@ class CTNxModel(BaseModel): # 执行一次SB模块 # ============ 第一步:初始化时间步与时间索引 ============ - # 计算 times,并确定当前 time_idx(随机选取用来表示当前时间步) + # 计算 times,并确定当前 time_idx(随机选取用来表示当前时间步) tau = self.opt.tau T = self.opt.num_timesteps incs = np.array([0] + [1/(i+1) for i in range(T-1)])