Compare commits
6 Commits
cpwithatte
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
133f609e79 | ||
|
|
26b770a3c1 | ||
|
|
9850183607 | ||
|
|
e67b0f2511 | ||
|
|
7af2de920c | ||
|
|
55b9db967a |
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,5 +0,0 @@
|
|||||||
checkpoints/
|
|
||||||
*.log
|
|
||||||
*.pth
|
|
||||||
*.ckpt
|
|
||||||
__pycache__/
|
|
||||||
80
checkpoints/ROMA_UNSB_001/loss_log.txt
Normal file
80
checkpoints/ROMA_UNSB_001/loss_log.txt
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
================ Training Loss (Sun Feb 23 15:46:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 15:52:29 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:00:07 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:02:40 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:05:19 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:06:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:09:38 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:44:56 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:49:46 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:51:03 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:51:23 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:04:02 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:04:39 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:05:17 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:06:40 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:11:48 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:13:31 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:14:11 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:14:29 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:16:27 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:16:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:20:39 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:21:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:35:27 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:39:21 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:40:15 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:41:15 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:47:46 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:48:36 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:50:20 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:51:50 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:58:45 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:59:52 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 19:03:05 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 19:03:57 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:11:47 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:17:10 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:20:14 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:29:03 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:34:57 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 21:35:26 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:28:43 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:29:04 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:29:52 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:30:40 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:33:48 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:39:16 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:39:48 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:41:34 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:42:01 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:44:17 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:45:53 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:46:48 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:47:42 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:49:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:50:29 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:51:47 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:55:56 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:56:19 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:57:58 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 22:59:09 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:02:36 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:03:56 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:09:21 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:10:05 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:11:43 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:12:41 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:13:05 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:13:59 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 23:14:59 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 21:53:50 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 21:54:16 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 21:54:50 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 21:55:31 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 21:56:10 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 22:09:38 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 22:10:16 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 22:12:46 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 22:13:04 2025) ================
|
||||||
|
================ Training Loss (Mon Feb 24 22:14:04 2025) ================
|
||||||
88
checkpoints/ROMA_UNSB_001/train_opt.txt
Normal file
88
checkpoints/ROMA_UNSB_001/train_opt.txt
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
----------------- Options ---------------
|
||||||
|
adj_size_list: [2, 4, 6, 8, 12]
|
||||||
|
atten_layers: 1,3,5
|
||||||
|
batch_size: 1
|
||||||
|
beta1: 0.5
|
||||||
|
beta2: 0.999
|
||||||
|
checkpoints_dir: ./checkpoints
|
||||||
|
continue_train: False
|
||||||
|
crop_size: 256
|
||||||
|
dataroot: /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor [default: placeholder]
|
||||||
|
dataset_mode: unaligned_double [default: unaligned]
|
||||||
|
direction: AtoB
|
||||||
|
display_env: ROMA [default: main]
|
||||||
|
display_freq: 50
|
||||||
|
display_id: None
|
||||||
|
display_ncols: 4
|
||||||
|
display_port: 8097
|
||||||
|
display_server: http://localhost
|
||||||
|
display_winsize: 256
|
||||||
|
easy_label: experiment_name
|
||||||
|
epoch: latest
|
||||||
|
epoch_count: 1
|
||||||
|
eta_ratio: 0.1
|
||||||
|
evaluation_freq: 5000
|
||||||
|
flip_equivariance: False
|
||||||
|
gan_mode: lsgan
|
||||||
|
gpu_ids: 0
|
||||||
|
init_gain: 0.02
|
||||||
|
init_type: xavier
|
||||||
|
input_nc: 3
|
||||||
|
isTrain: True [default: None]
|
||||||
|
lambda_D_ViT: 1.0
|
||||||
|
lambda_GAN: 8.0 [default: 1.0]
|
||||||
|
lambda_NCE: 8.0 [default: 1.0]
|
||||||
|
lambda_SB: 0.1
|
||||||
|
lambda_ctn: 1.0
|
||||||
|
lambda_global: 1.0
|
||||||
|
lambda_inc: 1.0
|
||||||
|
lmda_1: 0.1
|
||||||
|
load_size: 286
|
||||||
|
lr: 1e-05 [default: 0.0002]
|
||||||
|
lr_decay_iters: 50
|
||||||
|
lr_policy: linear
|
||||||
|
max_dataset_size: inf
|
||||||
|
model: roma_unsb [default: cut]
|
||||||
|
n_epochs: 100
|
||||||
|
n_epochs_decay: 100
|
||||||
|
n_layers_D: 3
|
||||||
|
n_mlp: 3
|
||||||
|
name: ROMA_UNSB_001 [default: experiment_name]
|
||||||
|
nce_T: 0.07
|
||||||
|
nce_idt: False [default: True]
|
||||||
|
nce_includes_all_negatives_from_minibatch: False
|
||||||
|
nce_layers: 0,4,8,12,16
|
||||||
|
ndf: 64
|
||||||
|
netD: basic_cond
|
||||||
|
netF: mlp_sample
|
||||||
|
netF_nc: 256
|
||||||
|
netG: resnet_9blocks_cond
|
||||||
|
ngf: 64
|
||||||
|
no_antialias: False
|
||||||
|
no_antialias_up: False
|
||||||
|
no_dropout: True
|
||||||
|
no_flip: True [default: False]
|
||||||
|
no_html: False
|
||||||
|
normD: instance
|
||||||
|
normG: instance
|
||||||
|
num_patches: 256
|
||||||
|
num_threads: 4
|
||||||
|
num_timesteps: 10 [default: 5]
|
||||||
|
output_nc: 3
|
||||||
|
phase: train
|
||||||
|
pool_size: 0
|
||||||
|
preprocess: resize_and_crop
|
||||||
|
pretrained_name: None
|
||||||
|
print_freq: 100
|
||||||
|
random_scale_max: 3.0
|
||||||
|
save_by_iter: False
|
||||||
|
save_epoch_freq: 5
|
||||||
|
save_latest_freq: 5000
|
||||||
|
serial_batches: False
|
||||||
|
stylegan2_G_num_downsampling: 1
|
||||||
|
suffix:
|
||||||
|
tau: 0.01
|
||||||
|
update_html_freq: 1000
|
||||||
|
use_idt: False
|
||||||
|
verbose: False
|
||||||
|
----------------- End -------------------
|
||||||
Binary file not shown.
Binary file not shown.
@ -1401,31 +1401,23 @@ class UnetSkipConnectionBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class MLPDiscriminator(nn.Module):
|
class MLPDiscriminator(nn.Module):
|
||||||
def __init__(self, in_feat=768, hid_feat=512, out_feat=768, num_heads=1):
|
def __init__(self, in_feat=768, hid_feat = 768, out_feat = 768, dropout = 0.):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 自注意力层,加入Dropout
|
if not hid_feat:
|
||||||
self.attention = nn.MultiheadAttention(embed_dim=in_feat, num_heads=num_heads, dropout=0.1)
|
hid_feat = in_feat
|
||||||
# 加深加宽的MLP,加入Dropout
|
if not out_feat:
|
||||||
self.mlp = nn.Sequential(
|
out_feat = in_feat
|
||||||
nn.Linear(in_feat, hid_feat), # 768 -> 512
|
self.linear1 = nn.Linear(in_feat, hid_feat)
|
||||||
nn.ReLU(),
|
self.activation = nn.GELU()
|
||||||
nn.Dropout(0.3),
|
self.linear2 = nn.Linear(hid_feat, out_feat)
|
||||||
nn.Linear(hid_feat, hid_feat * 2), # 512 -> 1024
|
self.dropout = nn.Dropout(dropout)
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(0.3),
|
|
||||||
nn.Linear(hid_feat * 2, hid_feat), # 1024 -> 512
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(0.3),
|
|
||||||
nn.Linear(hid_feat, out_feat), # 512 -> 768
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(0.3),
|
|
||||||
nn.Linear(out_feat, 1) # 768 -> 1
|
|
||||||
)
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
attn_output, attn_weights = self.attention(x, x, x) # [B, N, D], [B, N, N]
|
x = self.linear1(x)
|
||||||
attn_weights = attn_weights.mean(dim=1) # [B, N]
|
x = self.activation(x)
|
||||||
pred = self.mlp(attn_output.mean(dim=1)) # [B, 1]
|
x = self.dropout(x)
|
||||||
return pred, attn_weights
|
x = self.linear2(x)
|
||||||
|
return self.dropout(x)
|
||||||
|
|
||||||
|
|
||||||
class NLayerDiscriminator(nn.Module):
|
class NLayerDiscriminator(nn.Module):
|
||||||
|
|||||||
@ -2,7 +2,6 @@ import numpy as np
|
|||||||
import math
|
import math
|
||||||
import timm
|
import timm
|
||||||
import torch
|
import torch
|
||||||
import torchvision.models as models
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torchvision.transforms import GaussianBlur
|
from torchvision.transforms import GaussianBlur
|
||||||
@ -61,46 +60,156 @@ def compute_ctn_loss(G, x, F_content): #公式10
|
|||||||
loss = F.mse_loss(warped_fake, y_fake_warped)
|
loss = F.mse_loss(warped_fake, y_fake_warped)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
class ContentAwareOptimization(nn.Module):
|
||||||
class ContentAwareOptimization(nn.Module):
|
|
||||||
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
|
def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.lambda_inc = lambda_inc
|
self.lambda_inc = lambda_inc # 权重增强系数
|
||||||
self.eta_ratio = eta_ratio
|
self.eta_ratio = eta_ratio # 选择内容区域的比例
|
||||||
self.criterionGAN=networks.GANLoss('lsgan').cuda()
|
|
||||||
|
def compute_cosine_similarity(self, gradients):
|
||||||
|
"""
|
||||||
|
计算每个patch梯度与平均梯度的余弦相似度
|
||||||
|
Args:
|
||||||
|
gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
|
||||||
|
Returns:
|
||||||
|
cosine_sim: [B, N] 每个patch的余弦相似度
|
||||||
|
"""
|
||||||
|
mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
|
||||||
|
# 计算余弦相似度
|
||||||
|
cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
|
||||||
|
return cosine_sim
|
||||||
|
|
||||||
|
def generate_weight_map(self, gradients_fake, feature_shape):
|
||||||
|
"""
|
||||||
|
生成内容感知权重图(修正空间维度)
|
||||||
|
Args:
|
||||||
|
gradients_real: [B, N, D] 真实图像判别器梯度
|
||||||
|
gradients_fake: [B, N, D] 生成图像判别器梯度
|
||||||
|
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
|
||||||
|
Returns:
|
||||||
|
weight_real: [B, 1, H, W] 真实图像权重图
|
||||||
|
weight_fake: [B, 1, H, W] 生成图像权重图
|
||||||
|
"""
|
||||||
|
H, W = feature_shape
|
||||||
|
N = H * W
|
||||||
|
|
||||||
def generate_weight_map(self, attn_real, attn_fake):
|
# 计算余弦相似度(与原代码相同)
|
||||||
# attn_real, attn_fake: [B, N],自注意力权重
|
cosine_fake = self.compute_cosine_similarity(gradients_fake)
|
||||||
# 归一化注意力权重
|
|
||||||
weight_real = F.normalize(attn_real, p=1, dim=1) # [B, N]
|
|
||||||
weight_fake = F.normalize(attn_fake, p=1, dim=1) # [B, N]
|
|
||||||
|
|
||||||
# 对真实图像权重处理
|
# 生成权重图(与原代码相同)
|
||||||
k = int(self.eta_ratio * weight_real.shape[1])
|
k = int(self.eta_ratio * cosine_fake.shape[1])
|
||||||
values_real, indices_real = torch.topk(weight_real, k, dim=1)
|
_, fake_indices = torch.topk(-cosine_fake, k, dim=1)
|
||||||
weight_real_enhanced = torch.ones_like(weight_real)
|
weight_fake = torch.ones_like(cosine_fake)
|
||||||
weight_real_enhanced.scatter_(1, indices_real, self.lambda_inc / (values_real + 1e-6))
|
|
||||||
# 对生成图像权重处理
|
for b in range(cosine_fake.shape[0]):
|
||||||
values_fake, indices_fake = torch.topk(weight_fake, k, dim=1)
|
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
|
||||||
weight_fake_enhanced = torch.ones_like(weight_fake)
|
|
||||||
weight_fake_enhanced.scatter_(1, indices_fake, self.lambda_inc / (values_fake + 1e-6))
|
# 重建空间维度 --------------------------------------------------
|
||||||
|
# 将权重从[B, N]转换为[B, H, W]
|
||||||
|
#print(f"Shape of weight_fake before view: {weight_fake.shape}")
|
||||||
|
#print(f"Shape of cosine_fake: {cosine_fake.shape}")
|
||||||
|
#print(f"H: {H}, W: {W}, N: {N}")
|
||||||
|
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # [B,1,H,W]
|
||||||
|
|
||||||
return weight_real_enhanced, weight_fake_enhanced
|
return weight_fake
|
||||||
|
|
||||||
def forward(self,real_scores, fake_scores, attn_real, attn_fake):
|
def compute_cosine_similarity_image(self, gradients):
|
||||||
# real_scores, fake_scores: 判别器预测得分 [B, 1]
|
"""
|
||||||
# attn_real, attn_fake: 自注意力权重 [B, N]
|
计算每个空间位置梯度与平均梯度的余弦相似度 (图像版本)
|
||||||
|
Args:
|
||||||
|
gradients: [B, C, H, W] 判别器输出的梯度
|
||||||
|
Returns:
|
||||||
|
cosine_sim: [B, H, W] 每个空间位置的余弦相似度
|
||||||
|
"""
|
||||||
|
# 将空间维度展平,以便计算所有空间位置的平均梯度
|
||||||
|
B, C, H, W = gradients.shape
|
||||||
|
gradients_reshaped = gradients.view(B, C, H * W) # [B, C, N] where N = H*W
|
||||||
|
gradients_transposed = gradients_reshaped.transpose(1, 2) # [B, N, C] 将C放到最后一维,方便计算空间位置的平均梯度
|
||||||
|
|
||||||
|
mean_grad = torch.mean(gradients_transposed, dim=1, keepdim=True) # [B, 1, C] 在空间位置维度上求平均,得到平均梯度 [B, 1, C]
|
||||||
|
# mean_grad 现在是所有空间位置的平均梯度,形状为 [B, 1, C]
|
||||||
|
|
||||||
|
# 为了计算余弦相似度,我们需要将 mean_grad 扩展到与 gradients_transposed 相同的空间维度
|
||||||
|
mean_grad_expanded = mean_grad.expand(-1, H * W, -1) # [B, N, C]
|
||||||
|
|
||||||
|
# 计算余弦相似度,dim=2 表示在特征维度 (C) 上计算
|
||||||
|
cosine_sim = F.cosine_similarity(gradients_transposed, mean_grad_expanded, dim=2) # [B, N]
|
||||||
|
|
||||||
|
# 将 cosine_sim 重新reshape回 [B, H, W]
|
||||||
|
cosine_sim = cosine_sim.view(B, H, W)
|
||||||
|
return cosine_sim
|
||||||
|
|
||||||
|
def generate_weight_map_image(self, gradients_fake, feature_shape):
|
||||||
|
"""
|
||||||
|
生成内容感知权重图(修正空间维度 - 图像版本)
|
||||||
|
Args:
|
||||||
|
gradients_fake: [B, C, H, W] 生成图像判别器梯度
|
||||||
|
feature_shape: tuple [H, W] 判别器输出的特征图尺寸
|
||||||
|
Returns:
|
||||||
|
weight_fake: [B, 1, H, W] 生成图像权重图
|
||||||
|
"""
|
||||||
|
H, W = feature_shape
|
||||||
|
# 计算余弦相似度(图像版本)
|
||||||
|
cosine_fake = self.compute_cosine_similarity_image(gradients_fake) # [B, H, W]
|
||||||
|
# 生成权重图(与原代码相同,但现在cosine_fake是[B, H, W])
|
||||||
|
k = int(self.eta_ratio * H * W) # k 仍然是基于总的空间位置数量计算
|
||||||
|
_, fake_indices = torch.topk(-cosine_fake.view(cosine_fake.shape[0], -1), k, dim=1) # 将 cosine_fake 展平为 [B, N] 以使用 topk
|
||||||
|
weight_fake = torch.ones_like(cosine_fake).view(cosine_fake.shape[0], -1) # 初始化权重图,并展平为 [B, N]
|
||||||
|
for b in range(cosine_fake.shape[0]):
|
||||||
|
weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake.view(cosine_fake.shape[0], -1)[b, fake_indices[b]]))
|
||||||
|
weight_fake = weight_fake.view(-1, H, W).unsqueeze(1) # 重新 reshape 为 [B, H, W],并添加通道维度变为 [B, 1, H, W]
|
||||||
|
return weight_fake
|
||||||
|
|
||||||
|
def forward(self, D_real, D_fake, real_scores, fake_scores):
|
||||||
|
"""
|
||||||
|
计算内容感知对抗损失
|
||||||
|
Args:
|
||||||
|
D_real: 判别器对真实图像的特征输出 [B, C, H, W]
|
||||||
|
D_fake: 判别器对生成图像的特征输出 [B, C, H, W]
|
||||||
|
real_scores: 真实图像的判别器预测 [B, N] (N=H*W)
|
||||||
|
fake_scores: 生成图像的判别器预测 [B, N]
|
||||||
|
Returns:
|
||||||
|
loss_co_adv: 内容感知对抗损失
|
||||||
|
"""
|
||||||
|
B, C, H, W = D_real.shape
|
||||||
|
N = H * W
|
||||||
|
shape_hw = [H, W]
|
||||||
|
# 注册钩子获取梯度
|
||||||
|
gradients_real = []
|
||||||
|
gradients_fake = []
|
||||||
|
|
||||||
|
def hook_real(grad):
|
||||||
|
gradients_real.append(grad.detach().view(B, N, -1))
|
||||||
|
|
||||||
|
def hook_fake(grad):
|
||||||
|
gradients_fake.append(grad.detach().view(B, N, -1))
|
||||||
|
|
||||||
|
D_real.register_hook(hook_real)
|
||||||
|
D_fake.register_hook(hook_fake)
|
||||||
|
|
||||||
|
# 计算原始对抗损失以触发梯度计算
|
||||||
|
loss_real = torch.mean(torch.log(real_scores + 1e-8))
|
||||||
|
loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
|
||||||
|
# 添加与 D_real、D_fake 相关的 dummy 项,确保梯度传递
|
||||||
|
loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
|
||||||
|
total_loss = loss_real + loss_fake + loss_dummy
|
||||||
|
total_loss.backward(retain_graph=True)
|
||||||
|
|
||||||
|
# 获取梯度数据
|
||||||
|
gradients_real = gradients_real[1] # [B, N, D]
|
||||||
|
gradients_fake = gradients_fake[1] # [B, N, D]
|
||||||
|
|
||||||
# 生成权重图
|
# 生成权重图
|
||||||
weight_real, weight_fake = self.generate_weight_map(attn_real, attn_fake)
|
self.weight_real, self.weight_fake = self.generate_weight_map(gradients_fake, shape_hw )
|
||||||
|
|
||||||
# 应用权重到 GAN 损失
|
# 应用权重到对抗损失
|
||||||
loss_co_real = torch.mean(weight_real * self.criterionGAN(real_scores, True))
|
loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
|
||||||
loss_co_fake = torch.mean(weight_fake * self.criterionGAN(fake_scores, False))
|
loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
|
||||||
|
|
||||||
# 总损失
|
# 计算并返回最终内容感知对抗损失
|
||||||
loss_co_adv = (loss_co_real + loss_co_fake) * 0.5
|
loss_co_adv = -(loss_co_real + loss_co_fake)
|
||||||
return loss_co_adv, weight_real, weight_fake
|
|
||||||
|
return loss_co_adv
|
||||||
|
|
||||||
class ContentAwareTemporalNorm(nn.Module):
|
class ContentAwareTemporalNorm(nn.Module):
|
||||||
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
|
||||||
@ -108,33 +217,6 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
self.gamma_stride = gamma_stride # 控制整体运动幅度
|
self.gamma_stride = gamma_stride # 控制整体运动幅度
|
||||||
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
self.smoother = GaussianBlur(kernel_size, sigma=sigma) # 高斯平滑层
|
||||||
|
|
||||||
def upsample_weight_map(self, weight_patch, target_size=(256, 256)):
|
|
||||||
# 如果 weight_patch 是 [N, 1] 形状(例如 [576, 1]),添加批次维度
|
|
||||||
if weight_patch.dim() == 2 and weight_patch.shape[1] == 1:
|
|
||||||
weight_patch = weight_patch.unsqueeze(0) # 变为 [1, 576, 1]
|
|
||||||
|
|
||||||
# 获取调整后的形状
|
|
||||||
B, N, _ = weight_patch.shape # 例如 B=1, N=576
|
|
||||||
if N != 576:
|
|
||||||
raise ValueError(f"预期 patch 数量 N=576 (24x24),但实际得到 N={N}")
|
|
||||||
|
|
||||||
# 重塑为 [B, 1, 24, 24]
|
|
||||||
weight_patch = weight_patch.view(B, 1, 24, 24) # [1, 1, 24, 24]
|
|
||||||
|
|
||||||
# 使用双线性插值上采样到目标大小
|
|
||||||
weight_full = F.interpolate(
|
|
||||||
weight_patch,
|
|
||||||
size=target_size,
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=False
|
|
||||||
)
|
|
||||||
|
|
||||||
# 可选:保持每个 16x16 patch 内部权重一致
|
|
||||||
weight_full = F.avg_pool2d(weight_full, kernel_size=16, stride=16)
|
|
||||||
weight_full = F.interpolate(weight_full, scale_factor=16, mode='nearest')
|
|
||||||
|
|
||||||
return weight_full
|
|
||||||
|
|
||||||
def forward(self, weight_map):
|
def forward(self, weight_map):
|
||||||
"""
|
"""
|
||||||
生成内容感知光流
|
生成内容感知光流
|
||||||
@ -143,17 +225,15 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
|
||||||
"""
|
"""
|
||||||
# 上采样权重图到全分辨率
|
print(weight_map.shape)
|
||||||
|
B, _, H, W = weight_map.shape
|
||||||
weight_full = self.upsample_weight_map(weight_map) # [B,1,384,384]
|
|
||||||
|
|
||||||
# 1. 归一化权重图
|
# 1. 归一化权重图
|
||||||
# 保持区域相对强度,同时限制数值范围
|
# 保持区域相对强度,同时限制数值范围
|
||||||
weight_norm = F.normalize(weight_full, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
|
weight_norm = F.normalize(weight_map, p=1, dim=(2,3)) # L1归一化 [B,1,H,W]
|
||||||
|
|
||||||
# 2. 生成高斯噪声
|
# 2. 生成高斯噪声(与光流场同尺寸)
|
||||||
B, _, H, W = weight_norm.shape
|
z = torch.randn(B, 2, H, W, device=weight_map.device) # [B,2,H,W]
|
||||||
z = torch.randn(B, 2, H, W, device=weight_norm.device) # [B,2,H,W]
|
|
||||||
|
|
||||||
# 3. 合成基础光流
|
# 3. 合成基础光流
|
||||||
# 将权重图扩展为2通道(x/y方向共享权重)
|
# 将权重图扩展为2通道(x/y方向共享权重)
|
||||||
@ -168,7 +248,7 @@ class ContentAwareTemporalNorm(nn.Module):
|
|||||||
# 限制光流幅值,避免极端位移
|
# 限制光流幅值,避免极端位移
|
||||||
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
|
F_content = torch.tanh(F_smooth) # 缩放到[-1,1]范围
|
||||||
|
|
||||||
return F_content
|
return F_content
|
||||||
|
|
||||||
class RomaUnsbModel(BaseModel):
|
class RomaUnsbModel(BaseModel):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -176,26 +256,44 @@ class RomaUnsbModel(BaseModel):
|
|||||||
"""配置 CTNx 模型的特定选项"""
|
"""配置 CTNx 模型的特定选项"""
|
||||||
|
|
||||||
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||||
|
parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
|
||||||
|
parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
|
||||||
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
|
||||||
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||||
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||||
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
|
|
||||||
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
|
||||||
parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
|
|
||||||
parser.add_argument('--side_length', type=int, default=7)
|
|
||||||
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
|
||||||
|
parser.add_argument('--nce_includes_all_negatives_from_minibatch',
|
||||||
parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
|
type=util.str2bool, nargs='?', const=True, default=False,
|
||||||
parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
|
help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
|
||||||
|
|
||||||
|
parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
|
||||||
|
parser.add_argument('--netF_nc', type=int, default=256)
|
||||||
|
parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
|
||||||
|
|
||||||
|
parser.add_argument('--lmda_1', type=float, default=0.1)
|
||||||
|
parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
|
||||||
|
parser.add_argument('--flip_equivariance',
|
||||||
|
type=util.str2bool, nargs='?', const=True, default=False,
|
||||||
|
help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
|
||||||
|
|
||||||
|
parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
|
||||||
|
parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
|
||||||
|
|
||||||
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
|
||||||
|
|
||||||
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
|
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
|
||||||
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
|
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
|
||||||
|
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
|
||||||
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
||||||
|
|
||||||
|
parser.set_defaults(pool_size=0) # no image pooling
|
||||||
|
|
||||||
opt, _ = parser.parse_known_args()
|
opt, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
# 直接设置为 sb 模式
|
||||||
|
parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@ -204,11 +302,11 @@ class RomaUnsbModel(BaseModel):
|
|||||||
BaseModel.__init__(self, opt)
|
BaseModel.__init__(self, opt)
|
||||||
|
|
||||||
# 指定需要打印的训练损失
|
# 指定需要打印的训练损失
|
||||||
self.loss_names = ['G_GAN', 'D_ViT', 'G', 'global', 'spatial','ctn']
|
self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
|
||||||
self.visual_names = ['real_A0', 'fake_B0', 'real_B0','real_A1', 'fake_B1', 'real_B1']
|
'G_2']
|
||||||
|
self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
|
||||||
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||||
|
|
||||||
|
|
||||||
if self.opt.phase == 'test':
|
if self.opt.phase == 'test':
|
||||||
self.visual_names = ['real']
|
self.visual_names = ['real']
|
||||||
for NFE in range(self.opt.num_timesteps):
|
for NFE in range(self.opt.num_timesteps):
|
||||||
@ -216,18 +314,24 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.visual_names.append(fake_name)
|
self.visual_names.append(fake_name)
|
||||||
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
|
self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
|
||||||
|
|
||||||
|
if opt.nce_idt and self.isTrain:
|
||||||
|
self.loss_names += ['NCE_Y']
|
||||||
|
self.visual_names += ['idt_B']
|
||||||
|
|
||||||
if self.isTrain:
|
if self.isTrain:
|
||||||
self.model_names = ['G', 'D_ViT']
|
self.model_names = ['G', 'D_ViT', 'E']
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.model_names = ['G']
|
self.model_names = ['G']
|
||||||
|
|
||||||
|
print(f'input_nc = {self.opt.input_nc}')
|
||||||
# 创建网络
|
# 创建网络
|
||||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||||
|
|
||||||
|
|
||||||
if self.isTrain:
|
if self.isTrain:
|
||||||
|
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||||
|
|
||||||
self.resize = tfs.Resize(size=(384,384), antialias=True)
|
self.resize = tfs.Resize(size=(384,384), antialias=True)
|
||||||
|
|
||||||
@ -239,9 +343,14 @@ class RomaUnsbModel(BaseModel):
|
|||||||
# 定义损失函数
|
# 定义损失函数
|
||||||
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
||||||
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||||
|
self.criterionNCE = []
|
||||||
|
for nce_layer in self.nce_layers:
|
||||||
|
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||||
|
self.criterionIdt = torch.nn.L1Loss().to(self.device)
|
||||||
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||||
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||||
self.optimizers = [self.optimizer_G, self.optimizer_D]
|
self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||||
|
self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
|
||||||
|
|
||||||
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
|
self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio) #损失函数
|
||||||
self.ctn = ContentAwareTemporalNorm() #生成的伪光流
|
self.ctn = ContentAwareTemporalNorm() #生成的伪光流
|
||||||
@ -253,6 +362,19 @@ class RomaUnsbModel(BaseModel):
|
|||||||
initialized at the first feedforward pass with some input images.
|
initialized at the first feedforward pass with some input images.
|
||||||
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||||
"""
|
"""
|
||||||
|
#bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
|
||||||
|
#self.set_input(data)
|
||||||
|
#self.real_A = self.real_A[:bs_per_gpu]
|
||||||
|
#self.real_B = self.real_B[:bs_per_gpu]
|
||||||
|
#self.forward() # compute fake images: G(A)
|
||||||
|
#if self.opt.isTrain:
|
||||||
|
#
|
||||||
|
# self.compute_G_loss().backward()
|
||||||
|
# self.compute_D_loss().backward()
|
||||||
|
# self.compute_E_loss().backward()
|
||||||
|
# if self.opt.lambda_NCE > 0.0:
|
||||||
|
# self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
|
||||||
|
# self.optimizers.append(self.optimizer_F)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def optimize_parameters(self):
|
def optimize_parameters(self):
|
||||||
@ -260,6 +382,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.forward()
|
self.forward()
|
||||||
|
|
||||||
self.netG.train()
|
self.netG.train()
|
||||||
|
self.netE.train()
|
||||||
self.netD_ViT.train()
|
self.netD_ViT.train()
|
||||||
|
|
||||||
# update D
|
# update D
|
||||||
@ -269,9 +392,19 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.loss_D.backward()
|
self.loss_D.backward()
|
||||||
self.optimizer_D.step()
|
self.optimizer_D.step()
|
||||||
|
|
||||||
|
# update E
|
||||||
|
self.set_requires_grad(self.netE, True)
|
||||||
|
self.optimizer_E.zero_grad()
|
||||||
|
self.loss_E = self.compute_E_loss()
|
||||||
|
self.loss_E.backward()
|
||||||
|
self.optimizer_E.step()
|
||||||
|
|
||||||
# update G
|
# update G
|
||||||
self.set_requires_grad(self.netD_ViT, False)
|
self.set_requires_grad(self.netD_ViT, False)
|
||||||
|
self.set_requires_grad(self.netE, False)
|
||||||
|
|
||||||
self.optimizer_G.zero_grad()
|
self.optimizer_G.zero_grad()
|
||||||
|
|
||||||
self.loss_G = self.compute_G_loss()
|
self.loss_G = self.compute_G_loss()
|
||||||
self.loss_G.backward()
|
self.loss_G.backward()
|
||||||
self.optimizer_G.step()
|
self.optimizer_G.step()
|
||||||
@ -290,113 +423,222 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
||||||
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||||
|
|
||||||
|
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||||
|
adj_size = adjacent_size
|
||||||
|
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||||
|
S = int(math.sqrt(token_num))
|
||||||
|
if S * S != token_num:
|
||||||
|
print('Error! Not a square!')
|
||||||
|
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||||
|
cut_patch_list = []
|
||||||
|
for i in range(0, S, adj_size):
|
||||||
|
for j in range(0, S, adj_size):
|
||||||
|
i_left = i
|
||||||
|
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||||
|
j_left = j
|
||||||
|
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||||
|
|
||||||
|
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||||
|
cut_patch= cut_patch.reshape(B,-1,C)
|
||||||
|
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||||
|
cut_patch_list.append(cut_patch)
|
||||||
|
|
||||||
|
result = torch.cat(cut_patch_list,dim=1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def cat_results(self, origin_tokens, adj_size_list):
|
||||||
|
res_list = [origin_tokens]
|
||||||
|
for ad_s in adj_size_list:
|
||||||
|
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||||
|
res_list.append(cat_result)
|
||||||
|
|
||||||
|
result = torch.cat(res_list, dim=1)
|
||||||
|
return result
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||||
self.fake_B0 = self.netG(self.real_A0)
|
|
||||||
self.fake_B1 = self.netG(self.real_A1)
|
|
||||||
|
|
||||||
if self.opt.isTrain:
|
# ============ 第一步:对 real_A / real_A2 进行多步随机生成过程 ============
|
||||||
|
tau = self.opt.tau
|
||||||
|
T = self.opt.num_timesteps
|
||||||
|
incs = np.array([0] + [1/(i+1) for i in range(T-1)])
|
||||||
|
times = np.cumsum(incs)
|
||||||
|
times = times / times[-1]
|
||||||
|
times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
|
||||||
|
times = np.concatenate([np.zeros(1), times])
|
||||||
|
times = torch.tensor(times).float().cuda()
|
||||||
|
self.times = times
|
||||||
|
bs = self.real_A0.size(0)
|
||||||
|
time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
|
||||||
|
self.time_idx = time_idx
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
self.netG.eval()
|
||||||
|
# ============ 第二步:对 real_A / real_A2 进行多步随机生成过程 ============
|
||||||
|
for t in range(self.time_idx.int().item() + 1):
|
||||||
|
# 计算增量 delta 与 inter/scale,用于每个时间步的插值等
|
||||||
|
if t > 0:
|
||||||
|
delta = times[t] - times[t - 1]
|
||||||
|
denom = times[-1] - times[t - 1]
|
||||||
|
inter = (delta / denom).reshape(-1, 1, 1, 1)
|
||||||
|
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
|
||||||
|
|
||||||
|
# 对 Xt、Xt2 进行随机噪声更新
|
||||||
|
Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
|
||||||
|
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
|
||||||
|
time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
|
||||||
|
z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
|
||||||
|
self.time = times[time_idx]
|
||||||
|
Xt_1 = self.netG(Xt, self.time, z)
|
||||||
|
|
||||||
|
Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
|
||||||
|
(scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
|
||||||
|
time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
|
||||||
|
z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
|
||||||
|
Xt_12 = self.netG(Xt2, self.time, z)
|
||||||
|
|
||||||
|
# 保存去噪后的中间结果 (real_A_noisy 等),供下一步做拼接
|
||||||
|
self.real_A_noisy = Xt.detach()
|
||||||
|
self.real_A_noisy2 = Xt2.detach()
|
||||||
|
|
||||||
|
# ============ 第三步:拼接输入并执行网络推理 =============
|
||||||
|
bs = self.real_A0.size(0)
|
||||||
|
z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
|
||||||
|
z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
|
||||||
|
# 将 real_A, real_B 拼接 (如 nce_idt=True),并同样处理 real_A_noisy 与 XtB
|
||||||
|
self.real = self.real_A0
|
||||||
|
self.realt = self.real_A_noisy
|
||||||
|
|
||||||
|
if self.opt.flip_equivariance:
|
||||||
|
self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
|
||||||
|
if self.flipped_for_equivariance:
|
||||||
|
self.real = torch.flip(self.real, [3])
|
||||||
|
self.realt = torch.flip(self.realt, [3])
|
||||||
|
|
||||||
|
self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
|
||||||
|
self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
|
||||||
|
|
||||||
|
if self.opt.phase == 'train':
|
||||||
real_A0 = self.real_A0
|
real_A0 = self.real_A0
|
||||||
real_A1 = self.real_A1
|
real_A1 = self.real_A1
|
||||||
real_B0 = self.real_B0
|
real_B0 = self.real_B0
|
||||||
real_B1 = self.real_B1
|
real_B1 = self.real_B1
|
||||||
fake_B0 = self.fake_B0
|
fake_B0 = self.fake_B0
|
||||||
fake_B1 = self.fake_B1
|
fake_B1 = self.fake_B1
|
||||||
|
|
||||||
self.real_A0_resize = self.resize(real_A0)
|
self.real_A0_resize = self.resize(real_A0)
|
||||||
self.real_A1_resize = self.resize(real_A1)
|
self.real_A1_resize = self.resize(real_A1)
|
||||||
real_B0 = self.resize(real_B0)
|
real_B0 = self.resize(real_B0)
|
||||||
real_B1 = self.resize(real_B1)
|
real_B1 = self.resize(real_B1)
|
||||||
self.fake_B0_resize = self.resize(fake_B0)
|
self.fake_B0_resize = self.resize(fake_B0)
|
||||||
self.fake_B1_resize = self.resize(fake_B1)
|
self.fake_B1_resize = self.resize(fake_B1)
|
||||||
|
|
||||||
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||||
|
# [[1,576,768],[1,576,768],[1,576,768]]
|
||||||
|
# [3,576,768]
|
||||||
|
#self.mutil_real_A0_tokens = self.cat_results(self.mutil_real_A0_tokens[0], self.opt.adj_size_list)
|
||||||
|
#print(f'self.mutil_real_A0_tokens[0]:{self.mutil_real_A0_tokens[0].shape}')
|
||||||
|
|
||||||
|
shape_hw = list(self.real_A0_resize.shape[2:4])
|
||||||
|
# 生成图像的梯度
|
||||||
|
fake_gradient = torch.autograd.grad(self.mutil_fake_B0_tokens[0].sum(), self.mutil_fake_B0_tokens, create_graph=True)[0]
|
||||||
|
|
||||||
|
# 梯度图
|
||||||
|
self.weight_fake = self.cao.generate_weight_map_image(fake_gradient, shape_hw)
|
||||||
|
|
||||||
|
# 生成图像的CTN光流图
|
||||||
|
self.f_content = self.ctn(self.weight_fake)
|
||||||
|
|
||||||
|
# 变换后的图片
|
||||||
|
self.warped_real_A_noisy2 = warp(self.real_A_noisy, self.f_content)
|
||||||
|
self.warped_fake_B0 = warp(self.fake_B0,self.f_content)
|
||||||
|
|
||||||
|
# 经过第二次生成器
|
||||||
|
self.warped_fake_B0_2 = self.netG(self.warped_real_A_noisy2, self.time, z_in)
|
||||||
|
|
||||||
|
# warped_fake_B0_2=self.warped_fake_B0_2
|
||||||
|
# warped_fake_B0=self.warped_fake_B0
|
||||||
|
# self.warped_fake_B0_2_resize = self.resize(warped_fake_B0_2)
|
||||||
|
# self.warped_fake_B0_resize = self.resize(warped_fake_B0)
|
||||||
|
# self.mutil_warped_fake_B0_tokens = self.netPreViT(self.warped_fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||||
|
# self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_D_loss(self): #判别器还是没有改
|
||||||
|
"""Calculate GAN loss for the discriminator"""
|
||||||
|
|
||||||
def compute_D_loss(self):
|
|
||||||
"""Calculate GAN loss with Content-Aware Optimization"""
|
|
||||||
lambda_D_ViT = self.opt.lambda_D_ViT
|
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||||
|
fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
|
||||||
pred_real0, attn_real0 = self.netD_ViT(self.mutil_real_B0_tokens[0]) # scores, features
|
fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
|
||||||
pred_real1, attn_real1 = self.netD_ViT(self.mutil_real_B1_tokens[0]) # scores, features
|
|
||||||
|
real_B0_tokens = self.mutil_real_B0_tokens[0]
|
||||||
pred_fake0, attn_fake0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
|
real_B1_tokens = self.mutil_real_B1_tokens[0]
|
||||||
pred_fake1, attn_fake1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
|
|
||||||
loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
|
|
||||||
real_scores=pred_real0,
|
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||||
fake_scores=pred_fake0,
|
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||||
attn_real=attn_real0,
|
|
||||||
attn_fake=attn_fake0
|
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
|
||||||
)
|
|
||||||
loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
|
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
|
||||||
real_scores=pred_real1,
|
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
|
||||||
fake_scores=pred_fake1,
|
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
|
||||||
attn_real=attn_real1,
|
|
||||||
attn_fake=attn_fake1
|
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
|
||||||
)
|
|
||||||
|
|
||||||
self.loss_D_ViT = (loss_cao0 + loss_cao1) * 0.5 * lambda_D_ViT
|
|
||||||
|
|
||||||
|
|
||||||
# 记录损失值供可视化
|
|
||||||
# self.loss_D_real = loss_D_real.item()
|
|
||||||
# self.loss_D_fake = loss_D_fake.item()
|
|
||||||
# self.loss_cao = (loss_cao0 + loss_cao1).item() * 0.5
|
|
||||||
|
|
||||||
return self.loss_D_ViT
|
return self.loss_D_ViT
|
||||||
|
|
||||||
|
def compute_E_loss(self):
|
||||||
|
"""计算判别器 E 的损失"""
|
||||||
|
|
||||||
|
print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
|
||||||
|
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
|
||||||
|
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
|
||||||
|
temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
|
||||||
|
self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
|
||||||
|
|
||||||
|
return self.loss_E
|
||||||
|
|
||||||
def compute_G_loss(self):
|
def compute_G_loss(self):
|
||||||
"""计算生成器的 GAN 损失"""
|
"""计算生成器的 GAN 损失"""
|
||||||
if self.opt.lambda_ctn > 0.0:
|
|
||||||
# 生成图像的CTN光流图
|
|
||||||
self.f_content0 = self.ctn(self.weight_fake0.detach())
|
|
||||||
self.f_content1 = self.ctn(self.weight_fake1.detach())
|
|
||||||
|
|
||||||
# 变换后的图片
|
|
||||||
self.warped_real_A0 = warp(self.real_A0, self.f_content0)
|
|
||||||
self.warped_real_A1 = warp(self.real_A1, self.f_content1)
|
|
||||||
self.warped_fake_B0 = warp(self.fake_B0,self.f_content0)
|
|
||||||
self.warped_fake_B1 = warp(self.fake_B1,self.f_content1)
|
|
||||||
|
|
||||||
# 经过第二次生成器
|
|
||||||
self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
|
|
||||||
self.warped_fake_B1_2 = self.netG(self.warped_real_A1)
|
|
||||||
|
|
||||||
warped_fake_B0_2=self.warped_fake_B0_2
|
|
||||||
warped_fake_B1_2=self.warped_fake_B1_2
|
|
||||||
warped_fake_B0=self.warped_fake_B0
|
|
||||||
warped_fake_B1=self.warped_fake_B1
|
|
||||||
# 计算L2损失
|
|
||||||
self.loss_ctn0 = F.mse_loss(warped_fake_B0_2, warped_fake_B0)
|
|
||||||
self.loss_ctn1 = F.mse_loss(warped_fake_B1_2, warped_fake_B1)
|
|
||||||
self.loss_ctn = (self.loss_ctn0 + self.loss_ctn1)*0.5
|
|
||||||
|
|
||||||
if self.opt.lambda_GAN > 0.0:
|
if self.opt.lambda_GAN > 0.0:
|
||||||
|
pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
||||||
pred_fake0,_ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
|
self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
|
||||||
pred_fake1,_ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
|
|
||||||
self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
|
|
||||||
self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
|
|
||||||
self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1)*0.5
|
|
||||||
else:
|
else:
|
||||||
self.loss_G_GAN = 0.0
|
self.loss_G_GAN = 0.0
|
||||||
|
self.loss_SB = 0
|
||||||
|
if self.opt.lambda_SB > 0.0:
|
||||||
|
XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0], dim=1)
|
||||||
|
XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1], dim=1)
|
||||||
|
|
||||||
|
bs = self.opt.batch_size
|
||||||
|
|
||||||
if self.opt.lambda_global or self.opt.lambda_spatial > 0.0:
|
# eq.9
|
||||||
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
|
ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
|
||||||
else:
|
self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
|
||||||
self.loss_global, self.loss_spatial = 0.0, 0.0
|
self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)
|
||||||
|
|
||||||
self.loss_G = self.opt.lambda_GAN * self.loss_G_GAN + \
|
if self.opt.lambda_global > 0.0:
|
||||||
self.opt.lambda_ctn * self.loss_ctn + \
|
loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
|
||||||
self.loss_global * self.opt.lambda_global+\
|
loss_global *= 0.5
|
||||||
self.loss_spatial * self.opt.lambda_spatial
|
else:
|
||||||
|
loss_global = 0.0
|
||||||
|
|
||||||
|
self.l2_loss = 0.0
|
||||||
|
if self.opt.lambda_l2 > 0.0:
|
||||||
|
wapped_fake_B = warp(self.fake_B0, self.f_content) # use updated self.f_content
|
||||||
|
self.l2_loss = F.mse_loss(self.warped_fake_B0_2, wapped_fake_B) # complete the loss calculation
|
||||||
|
|
||||||
|
self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
|
||||||
return self.loss_G
|
return self.loss_G
|
||||||
|
|
||||||
def calculate_attention_loss(self):
|
def calculate_attention_loss(self):
|
||||||
n_layers = len(self.atten_layers)
|
n_layers = len(self.atten_layers)
|
||||||
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||||
@ -419,19 +661,20 @@ class RomaUnsbModel(BaseModel):
|
|||||||
local_id = np.random.permutation(tokens_cnt)
|
local_id = np.random.permutation(tokens_cnt)
|
||||||
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
||||||
|
|
||||||
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||||
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||||
|
|
||||||
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||||
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
|
||||||
|
|
||||||
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
|
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
|
||||||
loss_spatial *= 0.5
|
loss_spatial *= 0.5
|
||||||
|
|
||||||
else:
|
else:
|
||||||
loss_spatial = 0.0
|
loss_spatial = 0.0
|
||||||
return loss_global , loss_spatial
|
|
||||||
|
|
||||||
|
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
|
||||||
|
|
||||||
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
||||||
loss = 0.0
|
loss = 0.0
|
||||||
n_layers = len(self.atten_layers)
|
n_layers = len(self.atten_layers)
|
||||||
@ -445,3 +688,5 @@ class RomaUnsbModel(BaseModel):
|
|||||||
loss = loss / n_layers
|
loss = loss / n_layers
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Binary file not shown.
Binary file not shown.
@ -36,7 +36,7 @@ class BaseOptions():
|
|||||||
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
||||||
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
||||||
parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
||||||
parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
|
parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
|
||||||
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
||||||
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
||||||
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class TrainOptions(BaseOptions):
|
|||||||
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
|
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
|
||||||
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
|
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
|
||||||
parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
|
parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
|
||||||
|
|
||||||
# training parameters
|
# training parameters
|
||||||
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
|
parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
|
||||||
parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
|
parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
|
||||||
|
|||||||
301
roma.py
Normal file
301
roma.py
Normal file
@ -0,0 +1,301 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from . import networks
|
||||||
|
from .patchnce import PatchNCELoss
|
||||||
|
import util.util as util
|
||||||
|
import timm
|
||||||
|
import time
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import sys
|
||||||
|
from functools import partial
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
from torchvision.transforms import transforms as tfs
|
||||||
|
|
||||||
|
class ROMAModel(BaseModel):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def modify_commandline_options(parser, is_train=True):
|
||||||
|
""" Configures options specific for CUT model
|
||||||
|
"""
|
||||||
|
parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
|
||||||
|
parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
|
||||||
|
parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
|
||||||
|
parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
|
||||||
|
parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
|
||||||
|
parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
|
||||||
|
parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
|
||||||
|
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
|
||||||
|
parser.add_argument('--local_nums', type=int, default=256)
|
||||||
|
parser.add_argument('--which_D_layer', type=int, default=-1)
|
||||||
|
parser.add_argument('--side_length', type=int, default=7)
|
||||||
|
|
||||||
|
parser.set_defaults(pool_size=0)
|
||||||
|
|
||||||
|
opt, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
def __init__(self, opt):
|
||||||
|
BaseModel.__init__(self, opt)
|
||||||
|
|
||||||
|
|
||||||
|
self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
|
||||||
|
self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
|
||||||
|
self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
|
||||||
|
|
||||||
|
|
||||||
|
if self.isTrain:
|
||||||
|
self.model_names = ['G', 'D_ViT']
|
||||||
|
else: # during test time, only load G
|
||||||
|
self.model_names = ['G']
|
||||||
|
|
||||||
|
|
||||||
|
# define networks (both generator and discriminator)
|
||||||
|
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||||
|
|
||||||
|
|
||||||
|
if self.isTrain:
|
||||||
|
|
||||||
|
self.netD_ViT = networks.MLPDiscriminator().to(self.device)
|
||||||
|
self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
|
||||||
|
|
||||||
|
|
||||||
|
self.norm = F.softmax
|
||||||
|
|
||||||
|
self.resize = tfs.Resize(size=(384,384))
|
||||||
|
|
||||||
|
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
||||||
|
self.criterionNCE = []
|
||||||
|
|
||||||
|
for atten_layer in self.atten_layers:
|
||||||
|
self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
|
||||||
|
|
||||||
|
self.criterionL1 = torch.nn.L1Loss().to(self.device)
|
||||||
|
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
|
||||||
|
self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
|
||||||
|
self.optimizers.append(self.optimizer_G)
|
||||||
|
self.optimizers.append(self.optimizer_D_ViT)
|
||||||
|
|
||||||
|
def data_dependent_initialize(self, data):
|
||||||
|
"""
|
||||||
|
The feature network netF is defined in terms of the shape of the intermediate, extracted
|
||||||
|
features of the encoder portion of netG. Because of this, the weights of netF are
|
||||||
|
initialized at the first feedforward pass with some input images.
|
||||||
|
Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_parameters(self):
|
||||||
|
# forward
|
||||||
|
self.forward()
|
||||||
|
|
||||||
|
# update D
|
||||||
|
self.set_requires_grad(self.netD_ViT, True)
|
||||||
|
self.optimizer_D_ViT.zero_grad()
|
||||||
|
self.loss_D = self.compute_D_loss()
|
||||||
|
self.loss_D.backward()
|
||||||
|
self.optimizer_D_ViT.step()
|
||||||
|
|
||||||
|
# update G
|
||||||
|
self.set_requires_grad(self.netD_ViT, False)
|
||||||
|
self.optimizer_G.zero_grad()
|
||||||
|
self.loss_G = self.compute_G_loss()
|
||||||
|
self.loss_G.backward()
|
||||||
|
self.optimizer_G.step()
|
||||||
|
|
||||||
|
def set_input(self, input):
|
||||||
|
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
||||||
|
Parameters:
|
||||||
|
input (dict): include the data itself and its metadata information.
|
||||||
|
The option 'direction' can be used to swap domain A and domain B.
|
||||||
|
"""
|
||||||
|
AtoB = self.opt.direction == 'AtoB'
|
||||||
|
self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
|
||||||
|
self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
|
||||||
|
self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
|
||||||
|
self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
|
||||||
|
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
||||||
|
|
||||||
|
def forward(self):
|
||||||
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
||||||
|
self.fake_B0 = self.netG(self.real_A0)
|
||||||
|
self.fake_B1 = self.netG(self.real_A1)
|
||||||
|
|
||||||
|
if self.opt.isTrain:
|
||||||
|
real_A0 = self.real_A0
|
||||||
|
real_A1 = self.real_A1
|
||||||
|
real_B0 = self.real_B0
|
||||||
|
real_B1 = self.real_B1
|
||||||
|
fake_B0 = self.fake_B0
|
||||||
|
fake_B1 = self.fake_B1
|
||||||
|
self.real_A0_resize = self.resize(real_A0)
|
||||||
|
self.real_A1_resize = self.resize(real_A1)
|
||||||
|
real_B0 = self.resize(real_B0)
|
||||||
|
real_B1 = self.resize(real_B1)
|
||||||
|
self.fake_B0_resize = self.resize(fake_B0)
|
||||||
|
self.fake_B1_resize = self.resize(fake_B1)
|
||||||
|
self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
|
||||||
|
self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
|
||||||
|
self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
|
||||||
|
self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
|
||||||
|
self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
|
||||||
|
self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
|
||||||
|
|
||||||
|
def tokens_concat(self, origin_tokens, adjacent_size):
|
||||||
|
adj_size = adjacent_size
|
||||||
|
B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
|
||||||
|
S = int(math.sqrt(token_num))
|
||||||
|
if S * S != token_num:
|
||||||
|
print('Error! Not a square!')
|
||||||
|
token_map = origin_tokens.clone().reshape(B,S,S,C)
|
||||||
|
cut_patch_list = []
|
||||||
|
for i in range(0, S, adj_size):
|
||||||
|
for j in range(0, S, adj_size):
|
||||||
|
i_left = i
|
||||||
|
i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
|
||||||
|
j_left = j
|
||||||
|
j_right = j + adj_size if j + adj_size <= S else S + 1
|
||||||
|
|
||||||
|
cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
|
||||||
|
cut_patch= cut_patch.reshape(B,-1,C)
|
||||||
|
cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
|
||||||
|
cut_patch_list.append(cut_patch)
|
||||||
|
|
||||||
|
|
||||||
|
result = torch.cat(cut_patch_list,dim=1)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def cat_results(self, origin_tokens, adj_size_list):
|
||||||
|
res_list = [origin_tokens]
|
||||||
|
for ad_s in adj_size_list:
|
||||||
|
cat_result = self.tokens_concat(origin_tokens, ad_s)
|
||||||
|
res_list.append(cat_result)
|
||||||
|
|
||||||
|
result = torch.cat(res_list, dim=1)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def compute_D_loss(self):
|
||||||
|
"""Calculate GAN loss for the discriminator"""
|
||||||
|
|
||||||
|
|
||||||
|
lambda_D_ViT = self.opt.lambda_D_ViT
|
||||||
|
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
|
||||||
|
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
|
||||||
|
|
||||||
|
real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
|
||||||
|
real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
|
||||||
|
|
||||||
|
|
||||||
|
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
|
||||||
|
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
|
||||||
|
real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
|
||||||
|
|
||||||
|
pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||||
|
pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||||
|
|
||||||
|
self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
|
||||||
|
|
||||||
|
pred_real0_ViT = self.netD_ViT(real_B0_tokens)
|
||||||
|
pred_real1_ViT = self.netD_ViT(real_B1_tokens)
|
||||||
|
self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
|
||||||
|
|
||||||
|
self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
|
||||||
|
|
||||||
|
|
||||||
|
return self.loss_D_ViT
|
||||||
|
|
||||||
|
def compute_G_loss(self):
|
||||||
|
|
||||||
|
if self.opt.lambda_GAN > 0.0:
|
||||||
|
|
||||||
|
fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
|
||||||
|
fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
|
||||||
|
fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
|
||||||
|
fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
|
||||||
|
pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
|
||||||
|
pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
|
||||||
|
self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
|
||||||
|
else:
|
||||||
|
self.loss_G_GAN_ViT = 0.0
|
||||||
|
|
||||||
|
if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
|
||||||
|
self.loss_global, self.loss_spatial = self.calculate_attention_loss()
|
||||||
|
else:
|
||||||
|
self.loss_global, self.loss_spatial = 0.0, 0.0
|
||||||
|
|
||||||
|
if self.opt.lambda_motion > 0.0:
|
||||||
|
self.loss_motion = 0.0
|
||||||
|
for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
|
||||||
|
A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
|
||||||
|
B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
|
||||||
|
cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
|
||||||
|
self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||||
|
else:
|
||||||
|
self.loss_motion = 0.0
|
||||||
|
|
||||||
|
self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
|
||||||
|
return self.loss_G
|
||||||
|
|
||||||
|
def calculate_attention_loss(self):
|
||||||
|
n_layers = len(self.atten_layers)
|
||||||
|
mutil_real_A0_tokens = self.mutil_real_A0_tokens
|
||||||
|
mutil_real_A1_tokens = self.mutil_real_A1_tokens
|
||||||
|
mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
|
||||||
|
mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
|
||||||
|
|
||||||
|
|
||||||
|
if self.opt.lambda_global > 0.0:
|
||||||
|
loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
|
||||||
|
loss_global *= 0.5
|
||||||
|
|
||||||
|
else:
|
||||||
|
loss_global = 0.0
|
||||||
|
|
||||||
|
if self.opt.lambda_spatial > 0.0:
|
||||||
|
loss_spatial = 0.0
|
||||||
|
local_nums = self.opt.local_nums
|
||||||
|
tokens_cnt = 576
|
||||||
|
local_id = np.random.permutation(tokens_cnt)
|
||||||
|
local_id = local_id[:int(min(local_nums, tokens_cnt))]
|
||||||
|
|
||||||
|
mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||||
|
mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||||
|
|
||||||
|
mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||||
|
mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
|
||||||
|
|
||||||
|
loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
|
||||||
|
loss_spatial *= 0.5
|
||||||
|
|
||||||
|
else:
|
||||||
|
loss_spatial = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
|
||||||
|
|
||||||
|
def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
|
||||||
|
loss = 0.0
|
||||||
|
n_layers = len(self.atten_layers)
|
||||||
|
|
||||||
|
for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
|
||||||
|
|
||||||
|
src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
|
||||||
|
tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
|
||||||
|
cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
|
||||||
|
loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
|
||||||
|
|
||||||
|
loss = loss / n_layers
|
||||||
|
return loss
|
||||||
@ -7,29 +7,27 @@
|
|||||||
|
|
||||||
python train.py \
|
python train.py \
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
||||||
--name UNIV_5 \
|
--name ROMA_UNSB_001 \
|
||||||
--dataset_mode unaligned_double \
|
--dataset_mode unaligned_double \
|
||||||
--display_env UNIV \
|
--no_flip \
|
||||||
|
--display_env ROMA \
|
||||||
--model roma_unsb \
|
--model roma_unsb \
|
||||||
--lambda_SB 1.0 \
|
--lambda_GAN 8.0 \
|
||||||
--lambda_ctn 10 \
|
--lambda_NCE 8.0 \
|
||||||
|
--lambda_SB 0.1 \
|
||||||
|
--lambda_ctn 1.0 \
|
||||||
--lambda_inc 1.0 \
|
--lambda_inc 1.0 \
|
||||||
--lambda_global 6.0 \
|
--lr 0.00001 \
|
||||||
--gamma_stride 20 \
|
|
||||||
--lr 0.000002 \
|
|
||||||
--gpu_id 0 \
|
--gpu_id 0 \
|
||||||
--nce_idt False \
|
--nce_idt False \
|
||||||
|
--nce_layers 0,4,8,12,16 \
|
||||||
--netF mlp_sample \
|
--netF mlp_sample \
|
||||||
--eta_ratio 0.4 \
|
--netF_nc 256 \
|
||||||
|
--nce_T 0.07 \
|
||||||
|
--lmda_1 0.1 \
|
||||||
|
--num_patches 256 \
|
||||||
|
--flip_equivariance False \
|
||||||
|
--eta_ratio 0.1 \
|
||||||
--tau 0.01 \
|
--tau 0.01 \
|
||||||
--num_timesteps 5 \
|
--num_timesteps 10 \
|
||||||
--input_nc 3 \
|
--input_nc 3
|
||||||
--n_epochs 400 \
|
|
||||||
--n_epochs_decay 200 \
|
|
||||||
|
|
||||||
# exp1 num_timesteps=4 (已停)
|
|
||||||
# exp2 num_timesteps=5 (已停)
|
|
||||||
# exp3 --num_timesteps 5,--lambda_inc 8 ,--gamma_stride 20,--lambda_global 6.0,--lambda_ctn 10, --lr 0.000002 (已停)
|
|
||||||
# exp4 --num_timesteps 5,--lambda_inc 8 ,--gamma_stride 20,--lambda_global 6.0,--lambda_ctn 10, --lr 0.000002, ET_XY=self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time_idx, XtXt_2).reshape(-1), dim=0) ,并把GAN,CTN loss考虑到了A1和B1 (已停)
|
|
||||||
# exp5 基于 exp4 ,修改了 self.loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, self.mutil_fake_B1_tokens) ,gpu_id 1 (已停)
|
|
||||||
# 上面几个实验效果都不好,实验结果都已经删除了,开的新的train_sbiv 对代码进行了调整,效果变得更好了。
|
|
||||||
|
|||||||
@ -1,32 +0,0 @@
|
|||||||
#!/bin/sh
|
|
||||||
# Train for video mode
|
|
||||||
#CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned_double --no_flip --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --lambda_motion 1.0 --atten_layers 1,3,5 --lr 0.00001
|
|
||||||
|
|
||||||
# Train for image mode
|
|
||||||
#CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --atten_layers 1,3,5 --lr 0.00001
|
|
||||||
|
|
||||||
python train.py \
|
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
|
||||||
--name SBIV_1 \
|
|
||||||
--dataset_mode unaligned_double \
|
|
||||||
--display_env SBIV2 \
|
|
||||||
--model roma_unsb \
|
|
||||||
--lambda_ctn 10 \
|
|
||||||
--lambda_inc 1.0 \
|
|
||||||
--lambda_global 8.0 \
|
|
||||||
--lambda_spatial 8.0 \
|
|
||||||
--gamma_stride 20 \
|
|
||||||
--lr 0.000001 \
|
|
||||||
--gpu_id 0 \
|
|
||||||
--eta_ratio 0.3 \
|
|
||||||
--tau 0.01 \
|
|
||||||
--num_timesteps 3 \
|
|
||||||
--input_nc 3 \
|
|
||||||
--n_epochs 400 \
|
|
||||||
--n_epochs_decay 200 \
|
|
||||||
|
|
||||||
# exp6 num_timesteps=4 ,gpu_id 0(基于 exp5 ,exp1 已停) (已停)
|
|
||||||
# exp7 num_timesteps=3 ,gpu_id 0 基于 exp6 (已停)
|
|
||||||
# # exp8 num_timesteps=4 ,gpu_id 1 ,修改了训练判别器的loss,以及ctnloss(基于,exp6)
|
|
||||||
# # exp9 num_timesteps=3 ,gpu_id 2 ,(基于 exp8)
|
|
||||||
# # # exp10 num_timesteps=4 ,gpu_id 0 , --name SBIV_1 ,让判别器看到了每一个时间步的输出,修改了训练判别器的loss,以及ctnloss(基于,exp9)
|
|
||||||
@ -1,20 +0,0 @@
|
|||||||
python train.py \
|
|
||||||
--dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
|
|
||||||
--name cp_3 \
|
|
||||||
--dataset_mode unaligned_double \
|
|
||||||
--display_env CP \
|
|
||||||
--model roma_unsb \
|
|
||||||
--lambda_ctn 10 \
|
|
||||||
--lambda_inc 8.0 \
|
|
||||||
--eta_ratio 0.4 \
|
|
||||||
--lambda_global 6.0 \
|
|
||||||
--lambda_spatial 6.0 \
|
|
||||||
--gamma_stride 20 \
|
|
||||||
--lr 0.00002 \
|
|
||||||
--gpu_id 3 \
|
|
||||||
--eta_ratio 0.4 \
|
|
||||||
--n_epochs 100 \
|
|
||||||
--n_epochs_decay 100 \
|
|
||||||
# cp1 复现cptrans的效果 --lr 0.000001
|
|
||||||
# cp2 修了一下cp1的代码,--lr 0.000002
|
|
||||||
## cp3 将梯度加强修改为attention加强,--lr 0.000005,--lambda_inc 8.0,--gpu_id 3(基于cp2的sh)
|
|
||||||
1
train.py
1
train.py
@ -44,7 +44,6 @@ if __name__ == '__main__':
|
|||||||
model.setup(opt) # regular setup: load and print networks; create schedulers
|
model.setup(opt) # regular setup: load and print networks; create schedulers
|
||||||
model.parallelize()
|
model.parallelize()
|
||||||
model.set_input(data) # unpack data from dataset and apply preprocessing
|
model.set_input(data) # unpack data from dataset and apply preprocessing
|
||||||
#print('Call opt paras')
|
|
||||||
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
|
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
|
||||||
if len(opt.gpu_ids) > 0:
|
if len(opt.gpu_ids) > 0:
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user