diff --git a/models/rome_unsb_model.py b/models/rome_unsb_model.py index 952a097..fb62536 100644 --- a/models/rome_unsb_model.py +++ b/models/rome_unsb_model.py @@ -17,7 +17,7 @@ def warp(image, flow): #warp操作 基于光流的图像变形函数 Args: image: [B, C, H, W] 输入图像 - flow: [B, 2, H, W] 光流场(x/y方向位移) + flow: [B, 2, H, W] 光流场(x/y方向位移) Returns: warped: [B, C, H, W] 变形后的图像 """ @@ -70,7 +70,7 @@ class ContentAwareOptimization(nn.Module): """ 计算每个patch梯度与平均梯度的余弦相似度 Args: - gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h) + gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h) Returns: cosine_sim: [B, N] 每个patch的余弦相似度 """ @@ -164,7 +164,7 @@ class ContentAwareTemporalNorm(nn.Module): Args: weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块) Returns: - F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) + F_content: [B, 2, H, W] 生成的光流场(x/y方向位移) """ B, _, H, W = weight_map.shape @@ -195,7 +195,7 @@ class CTNxModel(BaseModel): def modify_commandline_options(parser, is_train=True): """配置 CTNx 模型的特定选项""" - parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))') + parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))') parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)') parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss') parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')