fix netG

bishe 2025-02-23 18:42:21 +08:00
parent a798da6b32
commit 1caa5f0625
11 changed files with 770 additions and 13 deletions


@ -6,3 +6,22 @@
 ================ Training Loss (Sun Feb 23 16:06:44 2025) ================
 ================ Training Loss (Sun Feb 23 16:09:38 2025) ================
 ================ Training Loss (Sun Feb 23 16:44:56 2025) ================
+================ Training Loss (Sun Feb 23 16:49:46 2025) ================
+================ Training Loss (Sun Feb 23 16:51:03 2025) ================
+================ Training Loss (Sun Feb 23 16:51:23 2025) ================
+================ Training Loss (Sun Feb 23 18:04:02 2025) ================
+================ Training Loss (Sun Feb 23 18:04:39 2025) ================
+================ Training Loss (Sun Feb 23 18:05:17 2025) ================
+================ Training Loss (Sun Feb 23 18:06:40 2025) ================
+================ Training Loss (Sun Feb 23 18:11:48 2025) ================
+================ Training Loss (Sun Feb 23 18:13:31 2025) ================
+================ Training Loss (Sun Feb 23 18:14:11 2025) ================
+================ Training Loss (Sun Feb 23 18:14:29 2025) ================
+================ Training Loss (Sun Feb 23 18:16:27 2025) ================
+================ Training Loss (Sun Feb 23 18:16:44 2025) ================
+================ Training Loss (Sun Feb 23 18:20:39 2025) ================
+================ Training Loss (Sun Feb 23 18:21:44 2025) ================
+================ Training Loss (Sun Feb 23 18:35:27 2025) ================
+================ Training Loss (Sun Feb 23 18:39:21 2025) ================
+================ Training Loss (Sun Feb 23 18:40:15 2025) ================
+================ Training Loss (Sun Feb 23 18:41:15 2025) ================


@ -43,6 +43,7 @@
 n_epochs: 100
 n_epochs_decay: 100
 n_layers_D: 3
+n_mlp: 3
 name: ROMA_UNSB_001 [default: experiment_name]
 nce_T: 0.07
 nce_idt: False [default: True]
@ -52,7 +53,7 @@ nce_includes_all_negatives_from_minibatch: False
 netD: basic
 netF: mlp_sample
 netF_nc: 256
-netG: resnet_9blocks
+netG: resnet_9blocks_cond
 ngf: 64
 no_antialias: False
 no_antialias_up: False
@ -63,7 +64,7 @@ nce_includes_all_negatives_from_minibatch: False
 normG: instance
 num_patches: 256
 num_threads: 4
-num_timesteps: 10
+num_timesteps: 10 [default: 5]
 output_nc: 3
 phase: train
 pool_size: 0
@ -77,7 +78,7 @@ nce_includes_all_negatives_from_minibatch: False
 serial_batches: False
 stylegan2_G_num_downsampling: 1
 suffix:
-tau: 0.1
+tau: 0.1 [default: 0.01]
 update_html_freq: 1000
 use_idt: False
 verbose: False

Binary file not shown.

models/ncsn_networks.py (new file, 719 lines added)

@ -0,0 +1,719 @@
"""
The network architectures is based on the implementation of CycleGAN and CUT
Original PyTorch repo of CycleGAN: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Original PyTorch repo of CUT: https://github.com/taesungp/contrastive-unpaired-translation
Original CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
Original CUT paper: https://arxiv.org/pdf/2007.15651.pdf
We use the network architecture for our default modal image translation
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools
import numpy as np
from torch.nn import init
import math


class PixelNorm(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
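

# Note: PixelNorm is the per-pixel feature normalization introduced in ProGAN and
# used in StyleGAN-style mapping networks; here it normalizes the latent z before
# the mapping MLP in ResnetGenerator_ncsn below.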


def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
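

# Minimal usage sketch (the shapes are illustrative, not from the source):
# get_timestep_embedding(torch.tensor([0, 1, 2, 3]), 128) returns a (4, 128)
# tensor of sin/cos features, the same sinusoidal scheme as Transformer
# positional encodings, indexed by diffusion timestep instead of token position.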


##################################################################################
# Discriminator
##################################################################################
class D_NLayersMulti(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, num_D=1):
        super(D_NLayersMulti, self).__init__()
        # st()
        self.num_D = num_D
        if num_D == 1:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.model = nn.Sequential(*layers)
        else:
            layers = self.get_layers(input_nc, ndf, n_layers, norm_layer)
            self.add_module("model_0", nn.Sequential(*layers))
            self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
            for i in range(1, num_D):
                ndf_i = int(round(ndf / (2 ** i)))
                layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer)
                self.add_module("model_%d" % i, nn.Sequential(*layers))

    def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw,
                              stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1,
                               kernel_size=kw, stride=1, padding=padw)]
        return sequence

    def forward(self, input):
        if self.num_D == 1:
            return self.model(input)
        result = []
        down = input
        for i in range(self.num_D):
            model = getattr(self, "model_%d" % i)
            result.append(model(down))
            if i != self.num_D - 1:
                down = self.down(down)
        return result
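

# D_NLayersMulti evaluates the same PatchGAN stack at num_D scales: each extra
# scale sees the input average-pooled once more, and forward returns one patch
# map per scale (a list when num_D > 1).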


class ConvBlock_cond(nn.Module):
    def __init__(self, in_channel, out_channel, t_emb_dim, kernel_size=4, stride=1, padding=1,
                 norm_layer=None, downsample=True, use_bias=None):
        super().__init__()
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                               padding=padding, bias=use_bias)
        if norm_layer is not None:
            self.use_norm = True
            self.norm = norm_layer(out_channel)
        else:
            self.use_norm = False
        self.act = nn.LeakyReLU(0.2, True)
        self.down = Downsample(out_channel)
        self.dense = nn.Linear(t_emb_dim, out_channel)

    def forward(self, input, t_emb):
        out = self.conv1(input)
        out += self.dense(t_emb)[..., None, None]
        if self.use_norm:
            out = self.norm(out)
        out = self.act(out)
        if self.downsample:
            out = self.down(out)
        return out
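

# The timestep conditioning in ConvBlock_cond is a learned per-channel bias:
# the embedded t is projected to out_channel features by `dense` and
# broadcast-added to the conv output before normalization and the LeakyReLU.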


class NLayerDiscriminator_ncsn(nn.Module):
    """Defines a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator_ncsn, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        self.model_main = nn.ModuleList()
        kw = 4
        padw = 1
        if no_antialias:
            sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                        nn.LeakyReLU(0.2, True)]
        else:
            self.model_main.append(ConvBlock_cond(input_nc, ndf, 4 * ndf, kernel_size=kw,
                                                  stride=1, padding=padw, use_bias=use_bias))
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            if no_antialias:
                sequence += [
                    nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2,
                              padding=padw, bias=use_bias),
                    norm_layer(ndf * nf_mult),
                    nn.LeakyReLU(0.2, True)]
            else:
                self.model_main.append(
                    ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4 * ndf, kernel_size=kw,
                                   stride=1, padding=padw, use_bias=use_bias, norm_layer=norm_layer)
                )
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        self.model_main.append(
            ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4 * ndf, kernel_size=kw, stride=1,
                           padding=padw, use_bias=use_bias, norm_layer=norm_layer, downsample=False)
        )
        self.final_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        self.t_embed = TimestepEmbedding(
            embedding_dim=4 * ndf,
            hidden_dim=4 * ndf,
            output_dim=4 * ndf,
            act=nn.LeakyReLU(0.2),
        )

    def forward(self, input, t_emb, input2=None):
        """Standard forward."""
        t_emb = self.t_embed(t_emb)
        if input2 is not None:
            out = torch.cat([input, input2], dim=1)
        else:
            out = input
        for layer in self.model_main:
            out = layer(out, t_emb)
        return self.final_conv(out)
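

# Usage sketch (shapes are assumptions): input is (B, C, H, W) and t_emb a (B,)
# tensor of integer timesteps; when input2 is given it is concatenated along the
# channel axis, so input_nc at construction time must match the combined channel
# count. Note the `sequence` lists built in the no_antialias branches are never
# registered, so only the antialiased ConvBlock_cond path is actually used.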


class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        """
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        """Standard forward."""
        return self.net(input)


##################################################################################
# Generator
##################################################################################
class TimestepEmbedding(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, output_dim, act=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.main = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, output_dim),
            nn.LeakyReLU(0.2),
            # EqualLinear(hidden_dim, output_dim, bias_init=0, activation='fused_lrelu')
        )

    def forward(self, temp):
        temb = get_timestep_embedding(temp, self.embedding_dim)
        temb = self.main(temb)
        return temb


class AdaptiveLayer(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()
        self.style_net = nn.Linear(style_dim, in_channel * 2)
        self.style_net.bias.data[:in_channel] = 1
        self.style_net.bias.data[in_channel:] = 0

    def forward(self, input, style):
        style = self.style_net(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1)
        out = gamma * input + beta
        return out
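

# AdaptiveLayer is an AdaIN/FiLM-style modulation: a linear map turns the style
# vector into per-channel (gamma, beta), and the bias initialization (gamma=1,
# beta=0) makes the layer start out as an identity transform.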


class ResnetGenerator_ncsn(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    We adapt Torch code and ideas from Justin Johnson's neural style transfer project
    (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9,
                 padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert (n_blocks >= 0)
        super(ResnetGenerator_ncsn, self).__init__()
        self.opt = opt
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        self.ngf = ngf
        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            if no_antialias:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
            else:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True),
                          Downsample(ngf * mult * 2)
                          # nn.AvgPool2d(kernel_size=2, stride=2)
                          ]

        self.model_res = nn.ModuleList()
        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            self.model_res += [ResnetBlock_cond(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                                use_dropout=use_dropout, use_bias=use_bias,
                                                temb_dim=4 * ngf, z_dim=4 * ngf)]

        model_upsample = []
        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            if no_antialias_up:
                model_upsample += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2,
                                                      padding=1, output_padding=1, bias=use_bias),
                                   norm_layer(int(ngf * mult / 2)),
                                   nn.ReLU(True)]
            else:
                model_upsample += [
                    Upsample(ngf * mult),
                    # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1, bias=use_bias),
                    norm_layer(int(ngf * mult / 2)),
                    nn.ReLU(True)]
        model_upsample += [nn.ReflectionPad2d(3)]
        model_upsample += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model_upsample += [nn.Tanh()]

        self.model = nn.Sequential(*model)
        self.model_upsample = nn.Sequential(*model_upsample)

        mapping_layers = [PixelNorm(),
                          nn.Linear(self.ngf * 4, self.ngf * 4),
                          nn.LeakyReLU(0.2)]
        for _ in range(opt.n_mlp):
            mapping_layers.append(nn.Linear(self.ngf * 4, self.ngf * 4))
            mapping_layers.append(nn.LeakyReLU(0.2))
        self.z_transform = nn.Sequential(*mapping_layers)

        modules_emb = []
        modules_emb += [nn.Linear(self.ngf, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)
        modules_emb += [nn.LeakyReLU(0.2)]
        modules_emb += [nn.Linear(self.ngf * 4, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)
        modules_emb += [nn.LeakyReLU(0.2)]
        self.time_embed = nn.Sequential(*modules_emb)

    def forward(self, x, time_cond, z, layers=[], encode_only=False):
        z_embed = self.z_transform(z)
        # print(z_embed.shape)
        temb = get_timestep_embedding(time_cond, self.ngf)
        time_embed = self.time_embed(temb)
        if len(layers) > 0:
            feat = x
            feats = []
            for layer_id, layer in enumerate(self.model):
                feat = layer(feat)
                if layer_id in layers:
                    feats.append(feat)
            for layer_id, layer in enumerate(self.model_res):
                feat = layer(feat, time_embed, z_embed)
                if layer_id + len(self.model) in layers:
                    feats.append(feat)
                if layer_id + len(self.model) == layers[-1] and encode_only:
                    return feats
            return feat, feats
        else:
            out = self.model(x)
            for layer in self.model_res:
                out = layer(out, time_embed, z_embed)
            out = self.model_upsample(out)
            return out
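

# Forward-pass sketch (shapes are assumptions): x is (B, input_nc, H, W),
# time_cond a (B,) tensor of integer timesteps, and z a (B, 4*ngf) latent.
# z goes through z_transform and the timestep through time_embed, and both
# condition every ResnetBlock_cond; the default path returns a
# (B, output_nc, H, W) image in [-1, 1]. With layers=[...] the call instead
# collects intermediate encoder features (for the patch-wise contrastive loss),
# returning them alone when encode_only=True.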


##################################################################################
# Basic Blocks
##################################################################################
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement the skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out


class ResnetBlock_cond(nn.Module):
    """Define a time- and latent-conditioned Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections.
        We construct a conv block with the build_conv_block function,
        and implement the skip connections in the <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock_cond, self).__init__()
        self.conv_block, self.adaptive, self.conv_fin = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        self.conv_block = nn.ModuleList()
        self.conv_fin = nn.ModuleList()
        p = 0
        if padding_type == 'reflect':
            self.conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            self.conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        self.conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        self.adaptive = AdaptiveLayer(dim, z_dim)

        self.conv_fin += [nn.ReLU(True)]
        if use_dropout:
            self.conv_fin += [nn.Dropout(0.5)]
        p = 0
        if padding_type == 'reflect':
            self.conv_fin += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            self.conv_fin += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        self.conv_fin += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        self.Dense_time = nn.Linear(temb_dim, dim)
        # self.Dense_time.weight.data = default_init()(self.Dense_time.weight.data.shape)
        nn.init.zeros_(self.Dense_time.bias)
        self.style = nn.Linear(z_dim, dim * 2)
        self.style.bias.data[:dim] = 1
        self.style.bias.data[dim:] = 0
        return self.conv_block, self.adaptive, self.conv_fin

    def forward(self, x, time_cond, z):
        """Forward function (with skip connections)"""
        time_input = self.Dense_time(time_cond)
        for n, layer in enumerate(self.conv_block):
            out = layer(x)
            if n == 0:
                out += time_input[:, :, None, None]
        out = self.adaptive(out, z)
        for layer in self.conv_fin:
            out = layer(out)
        out = x + out  # add skip connections
        return out
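

# Note: as written, each loop iteration applies `layer` to the original x, so
# only the last conv_block layer's output (plus the adaptive/conv_fin path)
# survives; chaining the layers (out = layer(out) for n > 0) is the likely
# intent, as in ResnetBlock above. The time embedding enters as an additive
# per-channel bias and z through the AdaptiveLayer scale/shift between the
# two conv stacks.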


###############################################################################
# Helper Functions
###############################################################################
def get_filter(filt_size=3):
    if (filt_size == 1):
        a = np.array([1., ])
    elif (filt_size == 2):
        a = np.array([1., 1.])
    elif (filt_size == 3):
        a = np.array([1., 2., 1.])
    elif (filt_size == 4):
        a = np.array([1., 3., 3., 1.])
    elif (filt_size == 5):
        a = np.array([1., 4., 6., 4., 1.])
    elif (filt_size == 6):
        a = np.array([1., 5., 10., 10., 5., 1.])
    elif (filt_size == 7):
        a = np.array([1., 6., 15., 20., 15., 6., 1.])
    filt = torch.Tensor(a[:, None] * a[None, :])
    filt = filt / torch.sum(filt)
    return filt
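

# These taps are rows of Pascal's triangle, i.e. the binomial filters used for
# blur-pooling in "Making Convolutional Networks Shift-Invariant Again"
# (Zhang, 2019); the outer product yields a normalized 2-D low-pass kernel.
# A filt_size outside 1..7 would leave `a` unbound and raise at runtime.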


class Downsample(nn.Module):
    def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)),
                          int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels
        filt = get_filter(filt_size=self.filt_size)
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if (self.filt_size == 1):
            if (self.pad_off == 0):
                return inp[:, :, ::self.stride, ::self.stride]
            else:
                return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])


class Upsample2(nn.Module):
    def __init__(self, scale_factor, mode='nearest'):
        super().__init__()
        self.factor = scale_factor
        self.mode = mode

    def forward(self, x):
        return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)


class Upsample(nn.Module):
    def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
        super(Upsample, self).__init__()
        self.filt_size = filt_size
        self.filt_odd = np.mod(filt_size, 2) == 1
        self.pad_size = int((filt_size - 1) / 2)
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels
        filt = get_filter(filt_size=self.filt_size) * (stride ** 2)
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
        self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])

    def forward(self, inp):
        ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride,
                                     padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
        if (self.filt_odd):
            return ret_val
        else:
            return ret_val[:, :, :-1, :-1]


def get_pad_layer(pad_type):
    if (pad_type in ['refl', 'reflect']):
        PadLayer = nn.ReflectionPad2d
    elif (pad_type in ['repl', 'replicate']):
        PadLayer = nn.ReplicationPad2d
    elif (pad_type == 'zero'):
        PadLayer = nn.ZeroPad2d
    else:
        raise NotImplementedError('pad type [%s] is not recognized' % pad_type)
    return PadLayer


class Identity(nn.Module):
    def forward(self, x):
        return x


def get_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == 'none':
        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer
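

# Usage sketch: norm_layer = get_norm_layer('instance'); norm_layer(64) then
# builds an InstanceNorm2d over 64 channels, which is why the classes above
# take norm_layer as a callable rather than a module instance.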


def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights

    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        gain (float)       -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
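

# A hypothetical end-to-end construction (the argument values are illustrative;
# `opt` is assumed to carry an n_mlp attribute for the z mapping network):
# net = ResnetGenerator_ncsn(3, 3, ngf=64, norm_layer=get_norm_layer('instance'), opt=opt)
# net = init_net(net, init_type='xavier', init_gain=0.02, gpu_ids=[0])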


@ -7,6 +7,7 @@ from torch.optim import lr_scheduler
 import numpy as np
 import random
 from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
+from .ncsn_networks import NLayerDiscriminator_ncsn, ResnetGenerator_ncsn
 ###############################################################################
 # Helper Functions
@ -266,6 +267,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
     elif netG == 'resnet_cat':
         n_blocks = 8
         net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
+    elif netG == 'resnet_9blocks_cond':
+        net = ResnetGenerator_ncsn(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
     else:
         raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
     return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
@ -977,6 +980,7 @@ class ResnetGenerator(nn.Module):
             feats = []
             for layer_id, layer in enumerate(self.model):
                 # print(layer_id, layer)
+                print(feat.shape)
                 feat = layer(feat)
                 if layer_id in layers:
                     # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
@ -984,6 +988,11 @@ class ResnetGenerator(nn.Module):
                 else:
                     # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
                     pass
+                print(f"layer_id: {layer_id}, type(layer_id): {type(layer_id)}")
+                print(f"len(list(self.model)) - 1: {len(list(self.model)) - 1}, type(len(list(self.model)) - 1): {type(len(list(self.model)) - 1)}")
+                # Print layers to see what it is. If layers is still a tensor here, we need to understand why.
+                print(f"layers: {layers}, type(layers): {type(layers)}")
+                print(f"encode_only: {encode_only}, type(encode_only): {type(encode_only)}")
                 if layer_id == len(list(self.model)) - 1 and encode_only:
                     # print('encoder only return features')
                     return feats  # return intermediate features alone; stop in the last layers


@ -221,8 +221,10 @@ class RomaUnsbModel(BaseModel):
         parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
-        parser.add_argument('--tau', type=float, default=0.1, help='used in unsb')
-        parser.add_argument('--num_timesteps', type=int, default=10, help='used in unsb')
+        parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
+        parser.add_argument('--num_timesteps', type=int, default=5, help='# of SB diffusion timesteps')
+        parser.add_argument('--n_mlp', type=int, default=3, help='# of layers in the latent mapping MLP of the conditional generator')
         parser.set_defaults(pool_size=0)  # no image pooling
@ -260,7 +262,8 @@ class RomaUnsbModel(BaseModel):
         else:
             self.model_names = ['G']
+        print(f'input_nc = {self.opt.input_nc}')
         # Build the networks
         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
@ -269,7 +272,7 @@ class RomaUnsbModel(BaseModel):
         self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
         self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
-        self.resize = tfs.Resize(size=(384,384))
+        self.resize = tfs.Resize(size=(384,384), antialias=True)
         # Load the pretrained ViT
         self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
@ -397,16 +400,20 @@ class RomaUnsbModel(BaseModel):
         """Run the forward pass to generate the output images"""
         if self.opt.isTrain:
+            print(f'before resize: {self.real_A0.shape}')
             real_A0 = self.resize(self.real_A0)
             real_A1 = self.resize(self.real_A1)
             real_B0 = self.resize(self.real_B0)
             real_B1 = self.resize(self.real_B1)
             # Encode with the ViT
+            print(f'before vit: {real_A0.shape}')
             self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
             self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
-            self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=1).to(self.device)
-            self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=1).to(self.device)
+            print(f'before cat: len = {len(self.mutil_real_A0_tokens)}\n{self.mutil_real_A0_tokens[0].shape}')
+            self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=0).unsqueeze(0).to(self.device)
+            self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=0).unsqueeze(0).to(self.device)
             # Run the SB module once
@ -436,6 +443,7 @@ class RomaUnsbModel(BaseModel):
         inter = (delta / denom).reshape(-1, 1, 1, 1)
         scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
+        print(f'before noisy: {self.mutil_real_A0_tokens.shape}')
         # Apply the random-noise update to Xt and Xt2
         Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
             (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
@ -454,7 +462,8 @@ class RomaUnsbModel(BaseModel):
         self.real_A_noisy = Xt.detach()
         self.real_A_noisy2 = Xt2.detach()
         # Save the noise map
-        self.noisy_map = self.real_A_noisy - self.real_A0
+        print(f'after noisy map: {self.real_A_noisy.shape}')
+        self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens
         # ============ Step 3: concatenate the inputs and run the network =============
         bs = self.mutil_real_A0_tokens.size(0)
@ -518,7 +527,7 @@ class RomaUnsbModel(BaseModel):
         if self.opt.phase == 'train':
             # Gradient of the real images
-            real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0]
+            real_gradient = torch.autograd.grad(self.real_B0.sum(), self.real_B0, create_graph=True)[0]
             # Gradient of the generated images
             fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
             # Gradient map


@ -36,7 +36,7 @@ class BaseOptions():
         parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
         parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
-        parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks', 'resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
+        parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks', 'resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
         parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
         parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
         parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')


@ -30,4 +30,4 @@ python train.py \
 --eta_ratio 0.1 \
 --tau 0.1 \
 --num_timesteps 10 \
---input_nc 1
+--input_nc 3