保存一个版本
This commit is contained in:
parent
f98c285950
commit
537cb050a5
@ -13,6 +13,7 @@ import util.util as util
|
||||
|
||||
from torchvision.transforms import transforms as tfs
|
||||
|
||||
|
||||
def warp(image, flow): #warp操作
|
||||
"""
|
||||
基于光流的图像变形函数
|
||||
@ -37,76 +38,77 @@ def warp(image, flow): #warp操作
|
||||
# 双线性插值
|
||||
return F.grid_sample(image, new_grid, align_corners=True)
|
||||
|
||||
# 时序归一化损失计算
|
||||
def compute_ctn_loss(G, x, F_content): #公式10
|
||||
"""
|
||||
计算内容感知时序归一化损失
|
||||
Args:
|
||||
G: 生成器
|
||||
x: 输入红外图像 [B,C,H,W]
|
||||
F_content: 生成的光流场 [B,2,H,W]
|
||||
"""
|
||||
|
||||
# 生成可见光图像
|
||||
y_fake = G(x) # [B,3,H,W]
|
||||
|
||||
# 对生成结果应用光流变形
|
||||
warped_fake = warp(y_fake, F_content) # [B,3,H,W]
|
||||
|
||||
# 对输入应用相同光流后生成图像
|
||||
warped_x = warp(x, F_content) # [B,C,H,W]
|
||||
y_fake_warped = G(warped_x) # [B,3,H,W]
|
||||
|
||||
# 计算L2损失
|
||||
loss = F.mse_loss(warped_fake, y_fake_warped)
|
||||
return loss
|
||||
|
||||
|
||||
class ContentAwareOptimization(nn.Module):
|
||||
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
|
||||
super().__init__()
|
||||
self.lambda_inc = lambda_inc
|
||||
self.eta_ratio = eta_ratio
|
||||
self.gradients = [] # 修改为单一梯度列表,通用性更强
|
||||
self.criterionGAN = networks.GANLoss('lsgan').cuda()
|
||||
self.lambda_inc = lambda_inc # 控制内容丰富区域的权重增量
|
||||
self.eta_ratio = eta_ratio # 选择内容丰富区域的比例
|
||||
self.criterionGAN = networks.GANLoss('lsgan').cuda() # 使用 LSGAN 损失
|
||||
|
||||
def compute_cosine_similarity(self, gradients):
|
||||
mean_grad = torch.mean(gradients, dim=1, keepdim=True)
|
||||
return F.cosine_similarity(gradients, mean_grad, dim=2)
|
||||
|
||||
def generate_weight_map(self, gradients):
|
||||
cosine = self.compute_cosine_similarity(gradients)
|
||||
k = int(self.eta_ratio * cosine.shape[1])
|
||||
_, indices = torch.topk(-cosine, k, dim=1)
|
||||
weights = torch.ones_like(cosine)
|
||||
weights.scatter_(1, indices, self.lambda_inc / (1e-6 + torch.abs(cosine.gather(1, indices))))
|
||||
return weights
|
||||
|
||||
def forward(self, features, scores, target):
|
||||
def compute_cosine_similarity(self, grad_patch, grad_mean):
|
||||
"""
|
||||
计算每个 patch 梯度与整体平均梯度的余弦相似度
|
||||
Args:
|
||||
features: 特征张量(可以是判别器的 real/fake 特征,或生成器的 fake 特征)
|
||||
scores: 判别器对特征的预测得分
|
||||
target: 目标标签(True 表示希望判为真,False 表示希望判为假)
|
||||
grad_patch: [B, 1, H, W],每个 patch 的梯度(基于 scores)
|
||||
grad_mean: [B, 1],整体平均梯度
|
||||
Returns:
|
||||
loss: 加权后的 GAN 损失
|
||||
weight: 生成的权重图
|
||||
cosine: [B, 1, H, W],余弦相似度 δ_i
|
||||
"""
|
||||
self.gradients.clear()
|
||||
# 注册梯度钩子
|
||||
hook = lambda grad: self.gradients.append(grad.detach())
|
||||
features.register_hook(hook)
|
||||
B, _, H, W = grad_patch.shape
|
||||
grad_patch = grad_patch.view(B, 1, -1) # [B, 1, H*W]
|
||||
grad_mean = grad_mean.unsqueeze(-1) # [B, 1, 1]
|
||||
# 计算余弦相似度
|
||||
cosine = F.cosine_similarity(grad_patch, grad_mean, dim=1) # [B, H*W]
|
||||
return cosine.view(B, 1, H, W)
|
||||
|
||||
# 触发梯度计算
|
||||
scores.mean().backward(retain_graph=True)
|
||||
def generate_weight_map(self, cosine):
|
||||
"""
|
||||
根据余弦相似度生成权重图
|
||||
Args:
|
||||
cosine: [B, 1, H, W],余弦相似度 δ_i
|
||||
Returns:
|
||||
weights: [B, 1, H, W],权重图 w_i
|
||||
"""
|
||||
B, _, H, W = cosine.shape
|
||||
cosine_flat = cosine.view(B, -1) # [B, H*W]
|
||||
k = int(self.eta_ratio * cosine_flat.size(1)) # 选择 eta_ratio 比例的 patch
|
||||
_, indices = torch.topk(-cosine_flat, k, dim=1) # 选择偏离最大的 k 个 patch
|
||||
weights = torch.ones_like(cosine_flat)
|
||||
for b in range(B):
|
||||
selected_cosine = cosine_flat[b, indices[b]]
|
||||
weights[b, indices[b]] = self.lambda_inc / (torch.exp(torch.abs(selected_cosine)) + 1e-6)
|
||||
return weights.view(B, 1, H, W)
|
||||
|
||||
# 获取梯度并调整维度
|
||||
grad = self.gradients[0].flatten(1) # [B, N, D] → [B, N*D]
|
||||
weight = self.generate_weight_map(grad.view(*features.shape))
|
||||
def forward(self, scores, target):
|
||||
"""
|
||||
前向传播,计算加权后的 GAN 损失
|
||||
Args:
|
||||
scores: [B, 1, H, W],判别器的预测得分
|
||||
target: 目标标签(True 或 False)
|
||||
Returns:
|
||||
weighted_loss: 加权后的 GAN 损失
|
||||
weight: 权重图 [B, 1, H, W]
|
||||
"""
|
||||
# 计算原始 GAN 损失
|
||||
loss = self.criterionGAN(scores, target)
|
||||
|
||||
# 计算加权 GAN 损失
|
||||
loss = torch.mean(weight * self.criterionGAN(scores, target))
|
||||
return loss, weight
|
||||
# 捕获特征的梯度
|
||||
grad_scores = torch.autograd.grad(loss, scores, retain_graph=True)[0] # [B, C, H, W]
|
||||
|
||||
# 计算整体平均梯度
|
||||
grad_mean = torch.mean(grad_scores, dim=(2, 3)) # [B, 1]
|
||||
|
||||
# 计算余弦相似度 δ_i(公式 5)
|
||||
cosine = self.compute_cosine_similarity(grad_scores, grad_mean) # [B, 1, H, W]
|
||||
|
||||
# 生成权重图 w_i(公式 6)
|
||||
weight = self.generate_weight_map(cosine)
|
||||
|
||||
# 应用权重到损失(公式 7 的部分实现)
|
||||
weighted_loss = torch.mean(weight * self.criterionGAN(scores, target))
|
||||
|
||||
return weighted_loss, weight
|
||||
|
||||
class ContentAwareTemporalNorm(nn.Module):
|
||||
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
||||
@ -115,31 +117,14 @@ class ContentAwareTemporalNorm(nn.Module):
|
||||
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
||||
|
||||
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
|
||||
"""
|
||||
将patch级别的权重图上采样到目标分辨率
|
||||
Args:
|
||||
weight_patch: [B, 1, 24, 24] 来自ViT的patch权重图
|
||||
target_size: 目标分辨率 (H, W)
|
||||
Returns:
|
||||
weight_full: [B, 1, 256, 256] 上采样后的全分辨率权重图
|
||||
"""
|
||||
# 使用双线性插值上采样
|
||||
B = weight_patch.shape[0]
|
||||
weight_patch = weight_patch.view(B, 1, 24, 24)
|
||||
|
||||
# weight_patch: [B, 1, 30, 30] 来自 PatchGAN
|
||||
weight_full = F.interpolate(
|
||||
weight_patch,
|
||||
size=target_size,
|
||||
mode='bilinear',
|
||||
mode='bilinear', # 或 'nearest',根据需求选择
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# 对每个16x16的patch内部保持权重一致(可选)
|
||||
# 通过平均池化再扩展,消除插值引入的渐变
|
||||
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
|
||||
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
|
||||
|
||||
return weight_full
|
||||
return weight_full # [B, 1, 256, 256]
|
||||
|
||||
def forward(self, weight_map):
|
||||
"""
|
||||
@ -175,6 +160,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
||||
|
||||
return F_content
|
||||
|
||||
|
||||
class RomaUnsbModel(BaseModel):
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train=True):
|
||||
@ -327,21 +313,22 @@ class RomaUnsbModel(BaseModel):
|
||||
|
||||
# 处理 real_B0 和 fake_B0
|
||||
real_B0_tokens = self.mutil_real_B0_tokens[0]
|
||||
pred_real0, real_features0 = self.netD_ViT(real_B0_tokens)
|
||||
pred_real0 = self.netD_ViT(real_B0_tokens)
|
||||
print(pred_real0.shape)
|
||||
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
|
||||
pred_fake0, fake_features0 = self.netD_ViT(fake_B0_tokens)
|
||||
pred_fake0 = self.netD_ViT(fake_B0_tokens)
|
||||
|
||||
loss_real0, self.weight_real0 = self.cao(real_features0, pred_real0, True)
|
||||
loss_fake0, self.weight_fake0 = self.cao(fake_features0, pred_fake0, False)
|
||||
loss_real0, self.weight_real0 = self.cao( pred_real0, True)
|
||||
loss_fake0, self.weight_fake0 = self.cao( pred_fake0, False)
|
||||
|
||||
# 处理 real_B1 和 fake_B1
|
||||
real_B1_tokens = self.mutil_real_B1_tokens[0]
|
||||
pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)
|
||||
pred_real1 = self.netD_ViT(real_B1_tokens)
|
||||
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
|
||||
pred_fake1, fake_features1 = self.netD_ViT(fake_B1_tokens)
|
||||
pred_fake1 = self.netD_ViT(fake_B1_tokens)
|
||||
|
||||
loss_real1, self.weight_real1 = self.cao(real_features1, pred_real1, True)
|
||||
loss_fake1, self.weight_fake1 = self.cao(fake_features1, pred_fake1, False)
|
||||
loss_real1, self.weight_real1 = self.cao( pred_real1, True)
|
||||
loss_fake1, self.weight_fake1 = self.cao( pred_fake1, False)
|
||||
|
||||
# 综合损失
|
||||
self.loss_D_ViT = (loss_real0 + loss_fake0 + loss_real1 + loss_fake1) * 0.25 * lambda_D_ViT
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user