Save a version
This commit is contained in:
parent f98c285950
commit 537cb050a5
@@ -13,6 +13,7 @@ import util.util as util
 from torchvision.transforms import transforms as tfs


 def warp(image, flow):  # warp operation
     """
     Optical-flow-based image warping function
@@ -37,76 +38,77 @@ def warp(image, flow):  # warp operation
     # Bilinear interpolation
     return F.grid_sample(image, new_grid, align_corners=True)

-# Temporal-normalization loss
-def compute_ctn_loss(G, x, F_content):  # Eq. 10
-    """
-    Compute the content-aware temporal normalization loss.
-
-    Args:
-        G: generator
-        x: input infrared image [B,C,H,W]
-        F_content: generated optical-flow field [B,2,H,W]
-    """
-
-    # Generate the visible-light image
-    y_fake = G(x)  # [B,3,H,W]
-
-    # Warp the generated result with the optical flow
-    warped_fake = warp(y_fake, F_content)  # [B,3,H,W]
-
-    # Apply the same flow to the input, then generate
-    warped_x = warp(x, F_content)  # [B,C,H,W]
-    y_fake_warped = G(warped_x)  # [B,3,H,W]
-
-    # L2 loss between the two results
-    loss = F.mse_loss(warped_fake, y_fake_warped)
-    return loss


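The hunk above removes the standalone `compute_ctn_loss` helper. For reference, a minimal restatement of the temporal-consistency idea it encoded (the "Eq. 10" of the old comments) is sketched below; it assumes the `warp` function defined earlier in this file and any generator `G`, and is not part of the commit.

```python
import torch.nn.functional as F

def ctn_loss_sketch(G, x, F_content):
    """Temporal consistency: the generator should commute with the warp,
    i.e. warp(G(x), F) should stay close to G(warp(x, F))."""
    warped_fake = warp(G(x), F_content)        # warp the generated visible image
    fake_of_warped = G(warp(x, F_content))     # generate from the warped infrared input
    return F.mse_loss(warped_fake, fake_of_warped)
```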
 class ContentAwareOptimization(nn.Module):
     def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
         super().__init__()
-        self.lambda_inc = lambda_inc
-        self.eta_ratio = eta_ratio
-        self.gradients = []  # single gradient list, for generality
-        self.criterionGAN = networks.GANLoss('lsgan').cuda()
+        self.lambda_inc = lambda_inc  # weight increment for content-rich regions
+        self.eta_ratio = eta_ratio  # fraction of patches selected as content-rich
+        self.criterionGAN = networks.GANLoss('lsgan').cuda()  # LSGAN loss

-    def compute_cosine_similarity(self, gradients):
-        mean_grad = torch.mean(gradients, dim=1, keepdim=True)
-        return F.cosine_similarity(gradients, mean_grad, dim=2)
+    def compute_cosine_similarity(self, grad_patch, grad_mean):

-    def generate_weight_map(self, gradients):
-        cosine = self.compute_cosine_similarity(gradients)
-        k = int(self.eta_ratio * cosine.shape[1])
-        _, indices = torch.topk(-cosine, k, dim=1)
-        weights = torch.ones_like(cosine)
-        weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
-        return weights

-    def forward(self, features, scores, target):
         """
+        Compute the cosine similarity between each patch gradient and the overall mean gradient.
         Args:
-            features: feature tensor (the discriminator's real/fake features, or the generator's fake features)
-            scores: the discriminator's prediction scores for the features
-            target: target label (True = should be classified as real, False = as fake)
+            grad_patch: [B, 1, H, W], per-patch gradients (taken w.r.t. the scores)
+            grad_mean: [B, 1], overall mean gradient
         Returns:
-            loss: weighted GAN loss
-            weight: generated weight map
+            cosine: [B, 1, H, W], cosine similarity δ_i
         """
-        self.gradients.clear()
-        # Register a gradient hook
-        hook = lambda grad: self.gradients.append(grad.detach())
-        features.register_hook(hook)
+        B, _, H, W = grad_patch.shape
+        grad_patch = grad_patch.view(B, 1, -1)  # [B, 1, H*W]
+        grad_mean = grad_mean.unsqueeze(-1)  # [B, 1, 1]
+        # Cosine similarity
+        cosine = F.cosine_similarity(grad_patch, grad_mean, dim=1)  # [B, H*W]
+        return cosine.view(B, 1, H, W)

-        # Trigger the gradient computation
-        scores.mean().backward(retain_graph=True)
+    def generate_weight_map(self, cosine):
+        """
+        Generate the weight map from the cosine similarity.
+        Args:
+            cosine: [B, 1, H, W], cosine similarity δ_i
+        Returns:
+            weights: [B, 1, H, W], weight map w_i
+        """
+        B, _, H, W = cosine.shape
+        cosine_flat = cosine.view(B, -1)  # [B, H*W]
+        k = int(self.eta_ratio * cosine_flat.size(1))  # select an eta_ratio fraction of the patches
+        _, indices = torch.topk(-cosine_flat, k, dim=1)  # the k patches that deviate most
+        weights = torch.ones_like(cosine_flat)
+        for b in range(B):
+            selected_cosine = cosine_flat[b, indices[b]]
+            weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
+        return weights.view(B, 1, H, W)

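The per-sample loop in the new `generate_weight_map` can also be written without a Python loop, using the same `gather`/`scatter_` pattern as the deleted version. A loop-free sketch (it should produce the same weights, up to floating-point ordering; not part of the commit):

```python
import torch

def generate_weight_map_vectorized(cosine, eta_ratio=0.4, lambda_inc=2.0):
    B, _, H, W = cosine.shape
    cosine_flat = cosine.view(B, -1)                  # [B, H*W]
    k = int(eta_ratio * cosine_flat.size(1))
    _, indices = torch.topk(-cosine_flat, k, dim=1)   # the k patches deviating most from the mean gradient
    selected = cosine_flat.gather(1, indices)         # [B, k]
    weights = torch.ones_like(cosine_flat)
    weights.scatter_(1, indices, lambda_inc / (torch.exp(torch.abs(selected)) + 1e-6))
    return weights.view(B, 1, H, W)
```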
-        # Fetch the gradient and reshape it
-        grad = self.gradients[0].flatten(1)  # [B, N, D] → [B, N*D]
-        weight = self.generate_weight_map(grad.view(*features.shape))
+    def forward(self, scores, target):
+        """
+        Forward pass: compute the weighted GAN loss.
+        Args:
+            scores: [B, 1, H, W], discriminator prediction scores
+            target: target label (True or False)
+        Returns:
+            weighted_loss: weighted GAN loss
+            weight: weight map [B, 1, H, W]
+        """
+        # Plain GAN loss
+        loss = self.criterionGAN(scores, target)

-        # Weighted GAN loss
-        loss = torch.mean(weight * self.criterionGAN(scores, target))
-        return loss, weight
+        # Gradients of the loss w.r.t. the score map
+        grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0]  # [B, C, H, W]
+
+        # Overall mean gradient
+        grad_mean = torch.mean(grad_scores, dim=(2, 3))  # [B, 1]
+
+        # Cosine similarity δ_i (Eq. 5)
+        cosine = self.compute_cosine_similarity(grad_scores, grad_mean)  # [B, 1, H, W]
+
+        # Weight map w_i (Eq. 6)
+        weight = self.generate_weight_map(cosine)
+
+        # Apply the weights to the loss (partial implementation of Eq. 7)
+        weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
+
+        return weighted_loss, weight


 class ContentAwareTemporalNorm(nn.Module):
     def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
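A minimal usage sketch of the rewritten `ContentAwareOptimization`: because `forward` now differentiates the LSGAN loss with respect to the score map itself (`torch.autograd.grad` with `retain_graph=True`), the scores passed in must still be attached to a computation graph. `netD`, `fake_B`, `real_B` and the [B, 1, 30, 30] score-map shape below are placeholders and assumptions, not code from this commit.

```python
import torch

cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)  # builds its GANLoss on CUDA

pred_fake = netD(fake_B.detach())                # detach the images, not the scores
loss_fake, weight_fake = cao(pred_fake, False)   # Eq. 5-7: gradients -> cosine -> weights -> weighted loss

pred_real = netD(real_B)
loss_real, weight_real = cao(pred_real, True)

loss_D = 0.5 * (loss_real + loss_fake)
loss_D.backward()                                # works because forward kept the graph alive
```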
@@ -115,31 +117,14 @@ class ContentAwareTemporalNorm(nn.Module):
         self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer

     def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
-        """
-        Upsample the patch-level weight map to the target resolution.
-        Args:
-            weight_patch: [B, 1, 24, 24] patch weight map from the ViT
-            target_size: target resolution (H, W)
-        Returns:
-            weight_full: [B, 1, 256, 256] upsampled full-resolution weight map
-        """
-        # Upsample with bilinear interpolation
-        B = weight_patch.shape[0]
-        weight_patch = weight_patch.view(B, 1, 24, 24)
-
+        # weight_patch: [B, 1, 30, 30], from the PatchGAN
         weight_full = F.interpolate(
             weight_patch,
             size=target_size,
-            mode='bilinear',
+            mode='bilinear',  # or 'nearest', depending on the need
             align_corners=False
         )
-        # Keep the weight constant within each 16x16 patch (optional):
-        # average-pool, then re-expand to remove the gradient introduced by interpolation
-        weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
-        weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
-
-        return weight_full
+        return weight_full  # [B, 1, 256, 256]

     def forward(self, weight_map):
         """
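The simplified `upsample_weight_map` now just interpolates the patch-level map to full resolution. A small shape-check sketch (the 30x30 input follows the new comment; `mode='nearest'` keeps the weights piecewise-constant per patch, roughly what the removed `avg_pool2d` trick achieved):

```python
import torch
import torch.nn.functional as F

weight_patch = torch.ones(2, 1, 30, 30)  # [B, 1, 30, 30] PatchGAN weight map
weight_full = F.interpolate(weight_patch, size=(256, 256), mode='bilinear', align_corners=False)
print(weight_full.shape)  # torch.Size([2, 1, 256, 256])

# Piecewise-constant alternative:
weight_full_nn = F.interpolate(weight_patch, size=(256, 256), mode='nearest')
```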
@@ -175,6 +160,7 @@ class ContentAwareTemporalNorm(nn.Module):

         return F_content


 class RomaUnsbModel(BaseModel):
     @staticmethod
     def modify_commandline_options(parser, is_train=True):
@@ -327,21 +313,22 @@ class RomaUnsbModel(BaseModel):

         # Process real_B0 and fake_B0
         real_B0_tokens = self.mutil_real_B0_tokens[0]
-        pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
+        pred_real0 = self.netD_ViT(real_B0_tokens)
+        print(pred_real0.shape)
         fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
-        pred_fake0, fake_features0 = self.netD_ViT(fake_B0_tokens)
+        pred_fake0 = self.netD_ViT(fake_B0_tokens)

-        loss_real0, self.weight_real0 = self.cao(real_features0, pred_real0, True)
-        loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False)
+        loss_real0, self.weight_real0 = self.cao(pred_real0, True)
+        loss_fake0, self.weight_fake0 = self.cao(pred_fake0, False)

         # Process real_B1 and fake_B1
         real_B1_tokens = self.mutil_real_B1_tokens[0]
-        pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
+        pred_real1 = self.netD_ViT(real_B1_tokens)
         fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
-        pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens)
+        pred_fake1 = self.netD_ViT(fake_B1_tokens)

-        loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True)
-        loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False)
+        loss_real1, self.weight_real1 = self.cao(pred_real1, True)
+        loss_fake1, self.weight_fake1 = self.cao(pred_fake1, False)

         # Combined loss
         self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT
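For orientation, a hedged sketch of how the weight maps computed here might feed the temporal branch; the wiring is inferred from the class interfaces in this file (ContentAwareOptimization returns a per-patch weight map, ContentAwareTemporalNorm turns a weight map into F_content) and is not part of the commit.

```python
cao = ContentAwareOptimization(lambda_inc=2.0, eta_ratio=0.4)
catn = ContentAwareTemporalNorm(gamma_stride=0.1, kernel_size=21, sigma=5.0)

pred_real0 = netD_ViT(real_B0_tokens)              # assumed [B, 1, 30, 30] score map
loss_real0, weight_real0 = cao(pred_real0, True)   # weighted LSGAN loss + per-patch weights

F_content = catn(weight_real0)                     # content-aware flow field (see forward above)
```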