bishe 2025-03-07 19:20:37 +08:00
parent c6cb68e700
commit 14ba81514f
7 changed files with 46 additions and 97 deletions


@@ -198,7 +198,7 @@ class RomaUnsbModel(BaseModel):
         """Configure CTNx model-specific options"""
         parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
         parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
         parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
         parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
         parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
@@ -206,14 +206,8 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--local_nums', type=int, default=64, help='number of local patches')
         parser.add_argument('--side_length', type=int, default=7)
-        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
-        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
-                            type=util.str2bool, nargs='?', const=True, default=False,
-                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
         parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
-        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
         parser.add_argument('--eta_ratio', type=float, default=0.4, help='ratio of content-rich regions')
         parser.add_argument('--gamma_stride', type=float, default=20, help='ratio of stride for computing the similarity matrix')
         parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
@@ -253,7 +247,7 @@ class RomaUnsbModel(BaseModel):
         # build the networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
         if self.isTrain:
@@ -321,88 +315,28 @@ class RomaUnsbModel(BaseModel):
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        self.fake_B0 = self.netG(self.real_A0)
+        self.fake_B1 = self.netG(self.real_A1)
-        # ============ Step 1: multi-step stochastic generation for real_A / real_A2 ============
-        tau = self.opt.tau
-        T = self.opt.num_timesteps
-        incs = np.array([0] + [1/(i+1) for i in range(T-1)])
-        times = np.cumsum(incs)
-        times = times / times[-1]
-        times = 0.5 * times[-1] + 0.5 * times  # [0.5, 1]
-        times = np.concatenate([np.zeros(1), times])
-        times = torch.tensor(times).float().cuda()
-        self.times = times
-        bs = self.real_A0.size(0)
-        time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
-        self.time_idx = time_idx
-        self.fake_B0_list = []
-        self.fake_B1_list = []
-        with torch.no_grad():
-            self.netG.eval()
-            # ============ Step 2: multi-step stochastic generation for real_A / real_A2 ============
-            for t in range(self.time_idx.int().item() + 1):
-                # compute the increment delta and the inter/scale terms used for per-timestep interpolation
-                if t > 0:
-                    delta = times[t] - times[t - 1]
-                    denom = times[-1] - times[t - 1]
-                    inter = (delta / denom).reshape(-1, 1, 1, 1)
-                    scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
-                # stochastic noise update for Xt and Xt2
-                Xt = self.real_A0 if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
-                    (scale * tau).sqrt() * torch.randn_like(Xt).to(self.real_A0.device)
-                time_idx = (t * torch.ones(size=[self.real_A0.shape[0]]).to(self.real_A0.device)).long()
-                z = torch.randn(size=[self.real_A0.shape[0], 4 * self.opt.ngf]).to(self.real_A0.device)
-                time = times[time_idx]
-                Xt_1 = self.netG(Xt.detach(), time, z)
-                Xt2 = self.real_A1 if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
-                    (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.real_A1.device)
-                time_idx = (t * torch.ones(size=[self.real_A1.shape[0]]).to(self.real_A1.device)).long()
-                z = torch.randn(size=[self.real_A1.shape[0], 4 * self.opt.ngf]).to(self.real_A1.device)
-                Xt_12 = self.netG(Xt2.detach(), time, z)
-                self.fake_B0_list.append(Xt_1)
-                self.fake_B1_list.append(Xt_12)
-        self.fake_B0_1 = self.fake_B0_list[0]
-        self.fake_B1_1 = self.fake_B0_list[0]
-        self.fake_B0 = self.fake_B0_list[-1]
-        self.fake_B1 = self.fake_B1_list[-1]
-        self.z_in = z
-        self.z_in2 = z
-        if self.opt.phase == 'train':
+        if self.opt.isTrain:
             real_A0 = self.real_A0
             real_A1 = self.real_A1
             real_B0 = self.real_B0
             real_B1 = self.real_B1
             fake_B0 = self.fake_B0
             fake_B1 = self.fake_B1
-            self.mutil_fake_B0_tokens_list = []
-            self.mutil_fake_B1_tokens_list = []
-            for fake_B0_t in self.fake_B0_list:
-                fake_B0_t_resize = self.resize(fake_B0_t)  # resize to the ViT input size
-                tokens = self.netPreViT(fake_B0_t_resize, self.atten_layers, get_tokens=True)
-                self.mutil_fake_B0_tokens_list.append(tokens)
-            for fake_B1_t in self.fake_B1_list:
-                fake_B1_t_resize = self.resize(fake_B1_t)
-                tokens = self.netPreViT(fake_B1_t_resize, self.atten_layers, get_tokens=True)
-                self.mutil_fake_B1_tokens_list.append(tokens)
             self.real_A0_resize = self.resize(real_A0)
             self.real_A1_resize = self.resize(real_A1)
             real_B0 = self.resize(real_B0)
             real_B1 = self.resize(real_B1)
             self.fake_B0_resize = self.resize(fake_B0)
             self.fake_B1_resize = self.resize(fake_B1)
             self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
             self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
             self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
             self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
-            # [[1,576,768],[1,576,768],[1,576,768]]
-            # [3,576,768]
+            self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
+            self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
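The deleted block built a nonuniform time grid on [0.5, 1] and iterated a noisy interpolation toward each generator prediction. A standalone sketch of that schedule and update rule, for reference (NumPy only; scalar state, with a fixed value standing in for the generator prediction `Xt_1 = netG(Xt, time, z)`; `T` and `tau` are illustrative stand-ins for `--num_timesteps` and `--tau`):

import numpy as np

T, tau = 5, 0.01                        # illustrative values for --num_timesteps / --tau
incs = np.array([0] + [1 / (i + 1) for i in range(T - 1)])
times = np.cumsum(incs)
times = times / times[-1]               # normalize to [0, 1]
times = 0.5 * times[-1] + 0.5 * times   # compress into [0.5, 1]
times = np.concatenate([np.zeros(1), times])
print(times)                            # [0.  0.5  0.74  0.86  0.94  1.]

# one trajectory of the removed update rule, with a fixed scalar
# "generator prediction" x_pred standing in for Xt_1
rng = np.random.default_rng(0)
x, x_pred = 0.0, 1.0
for t in range(1, len(times)):
    delta = times[t] - times[t - 1]
    denom = times[-1] - times[t - 1]
    inter = delta / denom                # interpolation weight toward the prediction
    scale = delta * (1 - delta / denom)  # variance of the injected noise
    x = (1 - inter) * x + inter * x_pred + np.sqrt(scale * tau) * rng.standard_normal()

At the last step `inter` reaches 1 and `scale` reaches 0, so the trajectory lands exactly on the final prediction, which is why the old code took `fake_B0_list[-1]` as its output; the new single call `self.netG(self.real_A0)` replaces the whole loop now that the generator no longer takes `time` and `z`.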
     def compute_D_loss(self):
         """Calculate GAN loss with Content-Aware Optimization"""
@@ -414,27 +348,25 @@ class RomaUnsbModel(BaseModel):
         real_B1_tokens = self.mutil_real_B1_tokens[0]
         pred_real1, real_features1 = self.netD_ViT(real_B1_tokens)  # scores, features
-        for fake0_token, fake1_token in zip(self.mutil_fake_B0_tokens_list, self.mutil_fake_B1_tokens_list):
-            pre_fake0, fake_features0 = self.netD_ViT(fake0_token[0].detach())
-            pre_fake1, fake_features1 = self.netD_ViT(fake1_token[0].detach())
-            loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
-                D_real=real_features0,
-                D_fake=fake_features0,
-                real_scores=pred_real0,
-                fake_scores=pre_fake0
-            )
-            loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
-                D_real=real_features1,
-                D_fake=fake_features1,
-                real_scores=pred_real1,
-                fake_scores=pre_fake1
-            )
-            loss_cao += loss_cao0 + loss_cao1
+        pre_fake0, fake_features0 = self.netD_ViT(self.mutil_fake_B0_tokens[0].detach())
+        pre_fake1, fake_features1 = self.netD_ViT(self.mutil_fake_B1_tokens[0].detach())
+        loss_cao0, self.weight_real0, self.weight_fake0 = self.cao(
+            D_real=real_features0,
+            D_fake=fake_features0,
+            real_scores=pred_real0,
+            fake_scores=pre_fake0
+        )
+        loss_cao1, self.weight_real1, self.weight_fake1 = self.cao(
+            D_real=real_features1,
+            D_fake=fake_features1,
+            real_scores=pred_real1,
+            fake_scores=pre_fake1
+        )
+        loss_cao += loss_cao0 + loss_cao1
         # ===== combined loss =====
-        total_steps = len(self.fake_B0_list)
-        self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT / total_steps
+        self.loss_D_ViT = loss_cao * 0.5 * lambda_D_ViT
         # log loss values for visualization
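Only the call signature of `self.cao` appears in this diff. For orientation, a minimal stand-in consistent with that signature is sketched below; the weighting rule (per-token feature energy, hinge terms on the top `--eta_ratio` fraction of tokens) is an illustrative assumption, not the repository's actual content-aware optimization module:

import torch

def cao_stub(D_real, D_fake, real_scores, fake_scores, eta_ratio=0.4):
    """Hypothetical content-aware D loss over ViT tokens.

    D_real / D_fake: [B, N, C] discriminator features;
    real_scores / fake_scores: [B, N] per-token critic scores.
    """
    # assumed proxy for content richness: per-token feature energy
    weight_real = D_real.norm(dim=-1)            # [B, N]
    weight_fake = D_fake.norm(dim=-1)
    k = max(1, int(eta_ratio * weight_real.shape[1]))
    idx_real = weight_real.topk(k, dim=1).indices
    idx_fake = weight_fake.topk(k, dim=1).indices
    # hinge-style terms restricted to the selected content-rich tokens
    loss_real = torch.relu(1.0 - real_scores.gather(1, idx_real)).mean()
    loss_fake = torch.relu(1.0 + fake_scores.gather(1, idx_fake)).mean()
    return loss_real + loss_fake, weight_real, weight_fake

# shapes follow the [1, 576, 768] token layout noted in the removed comments
tok_real, tok_fake = torch.randn(1, 576, 768), torch.randn(1, 576, 768)
s_real, s_fake = torch.randn(1, 576), torch.randn(1, 576)
loss, w_r, w_f = cao_stub(tok_real, tok_fake, s_real, s_fake)

Note also that the removed `/ total_steps` averaged the accumulated loss over the sampled generation steps; with a single generator pass there is only one term, so the normalization is dropped.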
@@ -458,8 +390,8 @@ class RomaUnsbModel(BaseModel):
         self.warped_fake_B1 = warp(self.fake_B1, self.f_content1)
         # second pass through the generator
-        self.warped_fake_B0_2 = self.netG(self.warped_real_A0, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in)
-        self.warped_fake_B1_2 = self.netG(self.warped_real_A1, self.times[torch.zeros(size=[1]).cuda().long()], self.z_in2)
+        self.warped_fake_B0_2 = self.netG(self.warped_real_A0)
+        self.warped_fake_B1_2 = self.netG(self.warped_real_A1)
         warped_fake_B0_2 = self.warped_fake_B0_2
         warped_fake_B1_2 = self.warped_fake_B1_2
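The unchanged context here implements a warp-based temporal-consistency check: the generated frame is warped by a content flow (`warp(self.fake_B1, self.f_content1)`), the input is warped by the same flow and translated a second time, and the two results can then be compared. The `warp` helper is outside this diff; a plausible backward-warping sketch built on `grid_sample` (an assumption, not the repository's implementation):

import torch
import torch.nn.functional as F

def warp(x, flow):
    """Hypothetical backward warp via grid_sample (stand-in for the repo's warp)."""
    _, _, h, w = x.shape
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    base = torch.stack((xs, ys), dim=0).float().to(x.device)  # [2, H, W], (x, y) order
    coords = base.unsqueeze(0) + flow                         # displace by flow [B, 2, H, W]
    coords[:, 0] = 2 * coords[:, 0] / (w - 1) - 1             # normalize x to [-1, 1]
    coords[:, 1] = 2 * coords[:, 1] / (h - 1) - 1             # normalize y to [-1, 1]
    return F.grid_sample(x, coords.permute(0, 2, 3, 1), align_corners=True)

# sanity check: zero flow should reproduce the input exactly
x = torch.rand(1, 3, 8, 8)
flow = torch.zeros(1, 2, 8, 8)
assert torch.allclose(warp(x, flow), x, atol=1e-5)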
@@ -472,8 +404,8 @@ class RomaUnsbModel(BaseModel):
         if self.opt.lambda_GAN > 0.0:
-            pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens_list[-1][0])
-            pred_fake1, _ = self.netD_ViT(self.mutil_fake_B1_tokens_list[-1][0])
+            pred_fake0, _ = self.netD_ViT(self.mutil_fake_B0_tokens[0])
+            pred_fake1, _ = self.netD_ViT(self.mutil_fake_B1_tokens[0])
             self.loss_G_GAN0 = self.criterionGAN(pred_fake0, True).mean()
             self.loss_G_GAN1 = self.criterionGAN(pred_fake1, True).mean()
             self.loss_G_GAN = (self.loss_G_GAN0 + self.loss_G_GAN1) * 0.5


@@ -36,7 +36,7 @@ class BaseOptions():
         parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
         parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
-        parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks', 'resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
+        parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
         parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
         parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
         parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
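This default change is consistent with the model edits above: `resnet_9blocks_cond` appears to take a timestep embedding and a latent `z` (the old calls were `self.netG(Xt, time, z)`), while the forward pass now calls `self.netG(x)` with a single argument. A quick argparse check of the changed default, mirroring the option line above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--netG', type=str, default='resnet_9blocks',
                    choices=['resnet_9blocks', 'resnet_9blocks_mask', 'resnet_6blocks',
                             'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2',
                             'resnet_cat', 'resnet_9blocks_cond'],
                    help='specify generator architecture')
opt = parser.parse_args([])
assert opt.netG == 'resnet_9blocks'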

scripts/traincp.sh (new file)

@@ -0,0 +1,17 @@
+python train.py \
+  --dataroot /home/openxs/kunyu/datasets/InfraredCity-Lite/Double/Moitor \
+  --name cp_1 \
+  --dataset_mode unaligned_double \
+  --display_env CP \
+  --model roma_unsb \
+  --lambda_ctn 10 \
+  --lambda_inc 1.0 \
+  --lambda_global 6.0 \
+  --lambda_spatial 6.0 \
+  --gamma_stride 20 \
+  --lr 0.000001 \
+  --gpu_id 2 \
+  --eta_ratio 0.4 \
+  --n_epochs 100 \
+  --n_epochs_decay 100
+# cp1: reproduce the results of cptrans