without cnt

bishe 2025-02-23 23:15:25 +08:00
parent b4f00f4378
commit 67151c73f7
9 changed files with 73 additions and 36 deletions

View File

@@ -46,3 +46,25 @@
 ================ Training Loss (Sun Feb 23 22:33:48 2025) ================
 ================ Training Loss (Sun Feb 23 22:39:16 2025) ================
 ================ Training Loss (Sun Feb 23 22:39:48 2025) ================
+================ Training Loss (Sun Feb 23 22:41:34 2025) ================
+================ Training Loss (Sun Feb 23 22:42:01 2025) ================
+================ Training Loss (Sun Feb 23 22:44:17 2025) ================
+================ Training Loss (Sun Feb 23 22:45:53 2025) ================
+================ Training Loss (Sun Feb 23 22:46:48 2025) ================
+================ Training Loss (Sun Feb 23 22:47:42 2025) ================
+================ Training Loss (Sun Feb 23 22:49:44 2025) ================
+================ Training Loss (Sun Feb 23 22:50:29 2025) ================
+================ Training Loss (Sun Feb 23 22:51:47 2025) ================
+================ Training Loss (Sun Feb 23 22:55:56 2025) ================
+================ Training Loss (Sun Feb 23 22:56:19 2025) ================
+================ Training Loss (Sun Feb 23 22:57:58 2025) ================
+================ Training Loss (Sun Feb 23 22:59:09 2025) ================
+================ Training Loss (Sun Feb 23 23:02:36 2025) ================
+================ Training Loss (Sun Feb 23 23:03:56 2025) ================
+================ Training Loss (Sun Feb 23 23:09:21 2025) ================
+================ Training Loss (Sun Feb 23 23:10:05 2025) ================
+================ Training Loss (Sun Feb 23 23:11:43 2025) ================
+================ Training Loss (Sun Feb 23 23:12:41 2025) ================
+================ Training Loss (Sun Feb 23 23:13:05 2025) ================
+================ Training Loss (Sun Feb 23 23:13:59 2025) ================
+================ Training Loss (Sun Feb 23 23:14:59 2025) ================

View File

@@ -1,5 +1,5 @@
 ----------------- Options ---------------
-atten_layers: 1,3,5
+atten_layers: 5
 batch_size: 1
 beta1: 0.5
 beta2: 0.999
@@ -28,10 +28,12 @@
 init_type: xavier
 input_nc: 3
 isTrain: True [default: None]
+lambda_D_ViT: 1.0
 lambda_GAN: 8.0 [default: 1.0]
 lambda_NCE: 8.0 [default: 1.0]
 lambda_SB: 0.1
 lambda_ctn: 1.0
+lambda_global: 1.0
 lambda_inc: 1.0
 lmda_1: 0.1
 load_size: 286
@@ -50,7 +52,7 @@
 nce_includes_all_negatives_from_minibatch: False
 nce_layers: 0,4,8,12,16
 ndf: 64
-netD: basic
+netD: basic_cond
 netF: mlp_sample
 netF_nc: 256
 netG: resnet_9blocks_cond
@@ -78,7 +80,7 @@ nce_includes_all_negatives_from_minibatch: False
 serial_batches: False
 stylegan2_G_num_downsampling: 1
 suffix:
-tau: 0.1 [default: 0.01]
+tau: 0.01
 update_html_freq: 1000
 use_idt: False
 verbose: False
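
Editor's note: atten_layers drops from three ViT layers to one. Comma-separated layer options like this are conventionally split into integer lists at model setup (nce_layers in CUT-style code is handled the same way); a minimal sketch, assuming that convention (the repo's exact parsing may differ):

# Sketch: how '--atten_layers 5' is presumably turned into layer indices.
atten_layers = '5'  # new default; was '1,3,5'
layer_ids = [int(i) for i in atten_layers.split(',')]
print(layer_ids)  # [5] -> Cross-Similarity is now computed on a single ViT layer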

View File

@@ -331,6 +331,8 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
         net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
     elif 'stylegan2' in netD:
         net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
+    elif netD == 'basic_cond':  # more options
+        net = NLayerDiscriminator_ncsn(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias)
     else:
         raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
     return init_net(net, init_type, init_gain, gpu_ids,
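
Editor's note: with this branch in place, 'basic_cond' maps to the time-conditional NLayerDiscriminator_ncsn, which is what netE needs now that the netD default changed. A hedged usage sketch, mirroring the netE construction in the model file below; arguments past those shown in the hunk header are assumed to have defaults, and it must run inside this repo:

# Sketch only: request the new conditional discriminator through define_D.
from models import networks  # assumes the repo's models/networks.py layout

# Positional args follow the signature shown above:
# (input_nc, ndf, netD, n_layers_D, norm, ...)
netE = networks.define_D(3 * 4,         # opt.output_nc * 4, as in the netE call
                         64,            # opt.ndf
                         'basic_cond',  # hits the new elif branch above
                         3,             # opt.n_layers_D
                         'instance')    # opt.normD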

View File

@@ -200,6 +200,8 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
         parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
         parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
+        parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
+        parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
         parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
         parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
@@ -220,7 +222,7 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
         parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
-        parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
+        parser.add_argument('--atten_layers', type=str, default='5', help='compute Cross-Similarity on which layers')
         parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
         parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
@@ -258,7 +260,7 @@ class RomaUnsbModel(BaseModel):
             self.visual_names += ['idt_B']
         if self.isTrain:
-            self.model_names = ['G', 'D', 'E']
+            self.model_names = ['G', 'D_ViT', 'E']
         else:
@@ -270,22 +272,24 @@ class RomaUnsbModel(BaseModel):
         if self.isTrain:
-            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
             self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
             self.resize = tfs.Resize(size=(384,384), antialias=True)
+            self.netD_ViT = networks.MLPDiscriminator().to(self.device)

             # add a pretrained ViT
             self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)

             # define the loss functions
+            self.criterionL1 = torch.nn.L1Loss().to(self.device)
             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
             self.criterionNCE = []

             for nce_layer in self.nce_layers:
                 self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

             self.criterionIdt = torch.nn.L1Loss().to(self.device)
             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
-            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+            self.optimizer_D = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
             self.optimizer_E = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
             self.optimizers = [self.optimizer_G, self.optimizer_D, self.optimizer_E]
@@ -320,10 +324,10 @@ class RomaUnsbModel(BaseModel):
         self.netG.train()
         self.netE.train()
-        self.netD.train()
+        self.netD_ViT.train()

         # update D
-        self.set_requires_grad(self.netD, True)
+        self.set_requires_grad(self.netD_ViT, True)
         self.optimizer_D.zero_grad()
         self.loss_D = self.compute_D_loss()
         self.loss_D.backward()
@@ -337,7 +341,7 @@ class RomaUnsbModel(BaseModel):
         self.optimizer_E.step()

         # update G
-        self.set_requires_grad(self.netD, False)
+        self.set_requires_grad(self.netD_ViT, False)
         self.set_requires_grad(self.netE, False)
         self.optimizer_G.zero_grad()
@@ -443,7 +447,7 @@ class RomaUnsbModel(BaseModel):
         # ============ Step 3: concatenate the inputs and run the network ============
         bs = self.real_A0.size(0)
-        z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.real_A0.device)
+        z_in = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A0.device)
         z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.real_A1.device)
         # concatenate real_A and real_B (when nce_idt=True), and handle real_A_noisy and XtB the same way
         self.real = self.real_A0
@@ -455,9 +459,10 @@ class RomaUnsbModel(BaseModel):
             self.real = torch.flip(self.real, [3])
             self.realt = torch.flip(self.realt, [3])

+        print(f'fake_B0: {self.real_A0.shape}, fake_B1: {self.real_A1.shape}')
         self.fake_B0 = self.netG(self.real_A0, self.time, z_in)
         self.fake_B1 = self.netG(self.real_A1, self.time, z_in2)
+        print(f'fake_B0: {self.fake_B0.shape}, fake_B1: {self.fake_B1.shape}')

         if self.opt.phase == 'train':
             real_A0 = self.real_A0
@@ -507,23 +512,35 @@ class RomaUnsbModel(BaseModel):
         #self.mutil_fake_B0_2_tokens = self.netPreViT(self.warped_fake_B0_2_resize, self.atten_layers, get_tokens=True)

-    def compute_D_loss(self):
-        """Compute the discriminator's GAN loss"""
-        fake = self.cat_results(self.fake_B.detach())
-        pred_fake = self.netD(fake, self.time)
-        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
-        self.pred_real = self.netD(self.real_B0, self.time)
-        loss_D_real = self.criterionGAN(self.pred_real, True)
-        self.loss_D_real = loss_D_real.mean()
-        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-        return self.loss_D
+    def compute_D_loss(self):  # the discriminator itself still hasn't been changed
+        """Calculate GAN loss for the discriminator"""
+        lambda_D_ViT = self.opt.lambda_D_ViT
+        fake_B0_tokens = self.mutil_fake_B0_tokens[0].detach()
+        fake_B1_tokens = self.mutil_fake_B1_tokens[0].detach()
+        real_B0_tokens = self.mutil_real_B0_tokens[0]
+        real_B1_tokens = self.mutil_real_B1_tokens[0]
+
+        pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
+        pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
+        self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
+        pred_real0_ViT = self.netD_ViT(real_B0_tokens)
+        pred_real1_ViT = self.netD_ViT(real_B1_tokens)
+        self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
+        self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
+        return self.loss_D_ViT

     def compute_E_loss(self):
         """Compute the loss for discriminator E"""
+        print(f'resl_A_noisy: {self.real_A_noisy.shape} \n fake_B0: {self.fake_B0.shape}')
         XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B0.detach()], dim=1)
         XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B1.detach()], dim=1)
         temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
@@ -534,14 +551,8 @@ class RomaUnsbModel(BaseModel):
     def compute_G_loss(self):
         """Compute the generator's GAN loss"""
-        bs = self.real_A0.size(0)
-        tau = self.opt.tau
-        fake = self.fake_B0
-        std = torch.rand(size=[1]).item() * self.opt.std

         if self.opt.lambda_GAN > 0.0:
-            pred_fake = self.netD(fake, self.time)
+            pred_fake = self.netD_ViT(self.mutil_fake_B0_tokens[0])
             self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
         else:
             self.loss_G_GAN = 0.0
@@ -555,7 +566,7 @@ class RomaUnsbModel(BaseModel):
         # eq.9
         ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
         self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
-        self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
+        self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B0) ** 2)

         if self.opt.lambda_global > 0.0:
             loss_global = self.calculate_similarity(self.real_A0, self.fake_B0) + self.calculate_similarity(self.real_A1, self.fake_B1)
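
Editor's note: this file calls networks.MLPDiscriminator() on per-layer ViT tokens from vit_base_patch16_384 (inputs resized to 384x384, so token tensors are roughly (B, 576 patch tokens, 768)). The implementation is not shown in this commit; the sketch below is only a plausible, shape-compatible stand-in (the token dim of 768 and hidden width of 256 are assumptions), scoring each token so criterionGAN(...).mean() averages over tokens:

import torch
import torch.nn as nn

class MLPDiscriminator(nn.Module):
    """Hypothetical token-level MLP head; the repo's real class may differ."""
    def __init__(self, token_dim=768, hidden=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(token_dim, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
        )

    def forward(self, tokens):
        # tokens: (B, N, token_dim) from the ViT; returns one logit per token
        return self.mlp(tokens)

# Shape check with dummy tokens (576 = (384/16)^2 patches):
logits = MLPDiscriminator()(torch.randn(2, 576, 768))
print(logits.shape)  # torch.Size([2, 576, 1])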

View File

@@ -35,7 +35,7 @@ class BaseOptions():
         parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
         parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
-        parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+        parser.add_argument('--netD', type=str, default='basic_cond', choices=['basic_cond', 'basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
         parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
         parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
         parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')

View File

@@ -28,6 +28,6 @@ python train.py \
     --num_patches 256 \
     --flip_equivariance False \
     --eta_ratio 0.1 \
-    --tau 0.1 \
+    --tau 0.01 \
    --num_timesteps 10 \
    --input_nc 3
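
Editor's note: this change just brings the launch script in line with the new tau default of 0.01. For intuition, tau scales both terms of the SB objective in compute_G_loss (see the eq.9 hunk above); a toy computation with made-up tensors, not project code:

import torch

tau, num_timesteps, t = 0.01, 10, 3    # tau was 0.1 before this commit
real_A_noisy = torch.randn(1, 3, 256, 256)
fake_B0 = torch.randn(1, 3, 256, 256)
ET_XY = torch.tensor(0.5)              # stand-in for the eq.9 energy estimate

loss_SB = -(num_timesteps - t) / num_timesteps * tau * ET_XY
loss_SB = loss_SB + tau * torch.mean((real_A_noisy - fake_B0) ** 2)
print(loss_SB)  # the whole SB term shrinks 10x relative to tau=0.1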