192 lines
7.4 KiB
Python
192 lines
7.4 KiB
Python
|
|
import torch
|
|||
|
|
import torch.nn as nn
|
|||
|
|
import torch.nn.functional as F
|
|||
|
|
from torchvision.transforms import GaussianBlur
|
|||
|
|
|
|||
|
|
def warp(image, flow):
    """Warp an image with a dense flow field using bilinear sampling.

    For every output pixel (x, y) the image is sampled at
    (x + flow_x, y + flow_y), where the flow is expressed in pixels.

    Args:
        image: [B, C, H, W] input image.
        flow: [B, 2, H, W] displacement field; channel 0 is the x (width)
            displacement and channel 1 is the y (height) displacement.

    Returns:
        [B, C, H, W] warped image. Samples that fall outside the image use
        the default padding of ``F.grid_sample``.
    """
    B, _, H, W = image.shape

    # Pixel coordinates laid out so that x varies along the last (width)
    # axis and y along the height axis. Built by broadcasting rather than
    # torch.meshgrid, whose default 'ij' indexing would yield [W, H] grids.
    xs = torch.arange(W, device=image.device).float().view(1, 1, W)
    ys = torch.arange(H, device=image.device).float().view(1, H, 1)

    # Displaced sampling positions, still in pixel units: [B, H, W] each.
    sample_x = flow[:, 0] + xs
    sample_y = flow[:, 1] + ys

    # Normalise to [-1, 1] as required by grid_sample(align_corners=True).
    # max(..., 1) avoids a division by zero for one-pixel-wide/high inputs.
    sample_x = 2.0 * sample_x / max(W - 1, 1) - 1.0
    sample_y = 2.0 * sample_y / max(H - 1, 1) - 1.0

    # grid_sample expects [B, H, W, 2] with (x, y) in the last dimension.
    grid = torch.stack((sample_x, sample_y), dim=-1)
    # Keep supporting a flow with batch size 1 shared across the batch.
    grid = grid.expand(B, -1, -1, -1)

    return F.grid_sample(image, grid, align_corners=True)
|
|||
|
|
|
|||
|
|
# 时序归一化损失计算
|
|||
|
|
def compute_ctn_loss(G, x, F_content):  # Eq. 10
    """Compute the content-aware temporal normalization loss.

    The loss compares "generate, then warp" against "warp, then generate"
    for the same flow field, using a mean squared error.

    Args:
        G: generator, called as ``G(tensor)``.
        x: input infrared image [B, C, H, W].
        F_content: flow field [B, 2, H, W] passed to ``warp``.

    Returns:
        The MSE between the two results.
    """
    # Branch 1: translate the input, then warp the generated image.
    generated = G(x)
    generated_then_warped = warp(generated, F_content)

    # Branch 2: warp the input with the same flow, then translate it.
    warped_then_generated = G(warp(x, F_content))

    # L2 discrepancy between the two branches.
    return F.mse_loss(generated_then_warped, warped_then_generated)
|
|||
|
|
|
|||
|
|
class ContentAwareOptimization(nn.Module):
    """Content-aware adversarial loss.

    Patches whose discriminator-feature gradient is least aligned with the
    mean gradient are treated as content-rich and receive a larger weight in
    the adversarial loss.
    """

    def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
        """
        Args:
            lambda_inc: numerator of the boosted weight given to selected
                patches.
            eta_ratio: fraction of patches selected as content-rich.
        """
        super().__init__()
        self.lambda_inc = lambda_inc
        self.eta_ratio = eta_ratio

    def compute_cosine_similarity(self, gradients):
        """Cosine similarity between each patch gradient and the mean gradient.

        Args:
            gradients: [B, N, D] per-patch gradient vectors.

        Returns:
            [B, N] cosine similarity of every patch to the per-sample mean.
        """
        mean_grad = torch.mean(gradients, dim=1, keepdim=True)  # [B, 1, D]
        return F.cosine_similarity(gradients, mean_grad, dim=2)  # [B, N]

    def _weights_from_similarity(self, cosine, k):
        """Build a [B, N] weight map boosting the k least-similar patches."""
        weights = torch.ones_like(cosine)
        if k == 0:
            return weights
        # topk on the negated similarity picks the lowest similarities.
        _, selected = torch.topk(-cosine, k, dim=1)
        # Eq. 6: lambda / |cos|, with a small epsilon against division by 0.
        boosted = self.lambda_inc / (1e-6 + cosine.gather(1, selected).abs())
        return weights.scatter(1, selected, boosted)

    @staticmethod
    def _patch_gradients(grad):
        """Rearrange a [B, C, H, W] gradient into [B, H*W, C] patch vectors."""
        # Channels must be moved last before flattening; a plain view would
        # mix values from different spatial positions into one vector.
        return grad.detach().flatten(2).transpose(1, 2)

    def generate_weight_map(self, gradients_real, gradients_fake):
        """Generate content-aware weight maps for real and generated images.

        Args:
            gradients_real: [B, N, D] patch gradients for real images.
            gradients_fake: [B, N, D] patch gradients for generated images.

        Returns:
            (weight_real, weight_fake): two [B, N] weight maps. Unselected
            patches have weight 1; selected patches have
            ``lambda_inc / (1e-6 + |cosine|)``.
        """
        cosine_real = self.compute_cosine_similarity(gradients_real)  # Eq. 5
        cosine_fake = self.compute_cosine_similarity(gradients_fake)

        # Number of content-rich patches, clamped so that out-of-range
        # eta_ratio values cannot make topk fail.
        num_patches = cosine_real.shape[1]
        k = min(max(int(self.eta_ratio * num_patches), 0), num_patches)

        weight_real = self._weights_from_similarity(cosine_real, k)
        weight_fake = self._weights_from_similarity(cosine_fake, k)
        return weight_real, weight_fake

    def forward(self, D_real, D_fake, real_scores, fake_scores):
        """Compute the content-aware adversarial loss.

        Args:
            D_real: discriminator features for real images [B, C, H, W];
                must require grad.
            D_fake: discriminator features for generated images
                [B, C, H, W]; must require grad.
            real_scores: discriminator predictions for real images [B, N]
                (N = H * W).
            fake_scores: discriminator predictions for generated images
                [B, N].

        Returns:
            Scalar content-aware adversarial loss. The weight maps used are
            also stored on ``self.weight_real`` and ``self.weight_fake``.
        """
        # Unweighted adversarial objective, used only to probe the gradient
        # with respect to the discriminator features.
        loss_real = torch.mean(torch.log(real_scores + 1e-8))
        loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
        # Tiny term tying the objective to D_real / D_fake so a gradient
        # exists even when the scores do not depend on them.
        loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
        probe_loss = loss_real + loss_fake + loss_dummy

        # autograd.grad returns the feature gradients without accumulating
        # anything into parameter .grad buffers (unlike .backward()), so the
        # probe does not leak into the optimizer step.
        grad_real, grad_fake = torch.autograd.grad(
            probe_loss, (D_real, D_fake), retain_graph=True
        )

        gradients_real = self._patch_gradients(grad_real)  # [B, N, C]
        gradients_fake = self._patch_gradients(grad_fake)  # [B, N, C]

        self.weight_real, self.weight_fake = self.generate_weight_map(
            gradients_real, gradients_fake
        )

        # Apply the (constant) weights to the adversarial loss terms.
        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))

        return -(loss_co_real + loss_co_fake)
|
|||
|
|
|
|||
|
|
class ContentAwareTemporalNorm(nn.Module):
    """Generate a random, smoothed flow field modulated by a weight map."""

    def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
        """
        Args:
            gamma_stride: scale factor applied to the raw flow.
            kernel_size: kernel size passed to ``GaussianBlur``.
            sigma: sigma passed to ``GaussianBlur``.
        """
        super().__init__()
        # Gaussian smoothing applied to the raw flow.
        self.smoother = GaussianBlur(kernel_size, sigma=sigma)
        # Controls the overall motion magnitude.
        self.gamma_stride = gamma_stride

    def forward(self, weight_map):
        """Generate the content-aware flow.

        Args:
            weight_map: [B, 1, H, W] weight map.

        Returns:
            [B, 2, H, W] flow field (x / y displacement) with values in
            [-1, 1].
        """
        batch, _, height, width = weight_map.shape
        flow_shape = (batch, 2, height, width)

        # 1. L1-normalise the weight map over its spatial dimensions.
        normalized = F.normalize(weight_map, p=1, dim=(2, 3))

        # 2. Gaussian noise with the same shape as the flow field.
        noise = torch.randn(*flow_shape, device=weight_map.device)

        # 3. Raw flow: both directions share the same weights (Eq. 9).
        shared_weights = normalized.expand(*flow_shape)
        raw_flow = self.gamma_stride * shared_weights * noise

        # 4. Smooth the flow, then 5. squash it into [-1, 1] with tanh.
        return torch.tanh(self.smoother(raw_flow))
|