fix about netG
This commit is contained in:
parent
a798da6b32
commit
1caa5f0625
@ -6,3 +6,22 @@
|
|||||||
================ Training Loss (Sun Feb 23 16:06:44 2025) ================
|
================ Training Loss (Sun Feb 23 16:06:44 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 16:09:38 2025) ================
|
================ Training Loss (Sun Feb 23 16:09:38 2025) ================
|
||||||
================ Training Loss (Sun Feb 23 16:44:56 2025) ================
|
================ Training Loss (Sun Feb 23 16:44:56 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:49:46 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:51:03 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 16:51:23 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:04:02 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:04:39 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:05:17 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:06:40 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:11:48 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:13:31 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:14:11 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:14:29 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:16:27 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:16:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:20:39 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:21:44 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:35:27 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:39:21 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:40:15 2025) ================
|
||||||
|
================ Training Loss (Sun Feb 23 18:41:15 2025) ================
|
||||||
|
|||||||
@ -43,6 +43,7 @@
|
|||||||
n_epochs: 100
|
n_epochs: 100
|
||||||
n_epochs_decay: 100
|
n_epochs_decay: 100
|
||||||
n_layers_D: 3
|
n_layers_D: 3
|
||||||
|
n_mlp: 3
|
||||||
name: ROMA_UNSB_001 [default: experiment_name]
|
name: ROMA_UNSB_001 [default: experiment_name]
|
||||||
nce_T: 0.07
|
nce_T: 0.07
|
||||||
nce_idt: False [default: True]
|
nce_idt: False [default: True]
|
||||||
@ -52,7 +53,7 @@ nce_includes_all_negatives_from_minibatch: False
|
|||||||
netD: basic
|
netD: basic
|
||||||
netF: mlp_sample
|
netF: mlp_sample
|
||||||
netF_nc: 256
|
netF_nc: 256
|
||||||
netG: resnet_9blocks
|
netG: resnet_9blocks_cond
|
||||||
ngf: 64
|
ngf: 64
|
||||||
no_antialias: False
|
no_antialias: False
|
||||||
no_antialias_up: False
|
no_antialias_up: False
|
||||||
@ -63,7 +64,7 @@ nce_includes_all_negatives_from_minibatch: False
|
|||||||
normG: instance
|
normG: instance
|
||||||
num_patches: 256
|
num_patches: 256
|
||||||
num_threads: 4
|
num_threads: 4
|
||||||
num_timesteps: 10
|
num_timesteps: 10 [default: 5]
|
||||||
output_nc: 3
|
output_nc: 3
|
||||||
phase: train
|
phase: train
|
||||||
pool_size: 0
|
pool_size: 0
|
||||||
@ -77,7 +78,7 @@ nce_includes_all_negatives_from_minibatch: False
|
|||||||
serial_batches: False
|
serial_batches: False
|
||||||
stylegan2_G_num_downsampling: 1
|
stylegan2_G_num_downsampling: 1
|
||||||
suffix:
|
suffix:
|
||||||
tau: 0.1
|
tau: 0.1 [default: 0.01]
|
||||||
update_html_freq: 1000
|
update_html_freq: 1000
|
||||||
use_idt: False
|
use_idt: False
|
||||||
verbose: False
|
verbose: False
|
||||||
|
|||||||
BIN
models/__pycache__/ncsn_networks.cpython-39.pyc
Normal file
BIN
models/__pycache__/ncsn_networks.cpython-39.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
719
models/ncsn_networks.py
Normal file
719
models/ncsn_networks.py
Normal file
@ -0,0 +1,719 @@
|
|||||||
|
"""
|
||||||
|
The network architectures are based on the implementations of CycleGAN and CUT
|
||||||
|
Original PyTorch repo of CycleGAN: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
|
||||||
|
Original PyTorch repo of CUT: https://github.com/taesungp/contrastive-unpaired-translation
|
||||||
|
Original CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
|
||||||
|
Original CUT paper: https://arxiv.org/pdf/2007.15651.pdf
|
||||||
|
We use the network architecture for our default modal image translation
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import functools
|
||||||
|
import numpy as np
|
||||||
|
from torch.nn import init
|
||||||
|
import math
|
||||||
|
class PixelNorm(nn.Module):
    """Rescale each feature vector to unit root-mean-square along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` divided by its RMS over dim 1.

        An epsilon of 1e-8 is added to the mean square before the reciprocal
        square root, so an all-zero vector maps to zeros instead of NaN.
        """
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return input * inv_rms
|
||||||
|
|
||||||
|
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Parameters:
        timesteps (Tensor)   -- 1-D tensor of timestep values (cast to float)
        embedding_dim (int)  -- width of the returned embedding
        max_positions (int)  -- base of the geometric frequency progression

    Returns a float tensor of shape (len(timesteps), embedding_dim) laid out as
    [sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})] with h = embedding_dim // 2,
    zero-padded by one column when embedding_dim is odd.

    NOTE: this module defines a function with this name twice; the later
    definition is the one bound at import time.
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # Guard the denominator: half_dim == 1 (embedding_dim 2 or 3) used to
    # divide by zero. With a single frequency the exponent is multiplied by
    # index 0 anyway, so the chosen denominator does not affect the result.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
|
||||||
|
|
||||||
|
##################################################################################
|
||||||
|
# Discriminator
|
||||||
|
##################################################################################
|
||||||
|
class D_NLayersMulti(nn.Module):
    """Multi-scale PatchGAN-style discriminator.

    With ``num_D == 1`` a single conv stack is stored as ``self.model`` and
    ``forward`` returns its output tensor. Otherwise ``num_D`` stacks are
    registered as ``model_0 .. model_{num_D-1}``; each later stack uses
    ``ndf / 2**i`` base filters and sees the input average-pooled once more.
    ``forward`` then returns a list with one output per scale.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3,
                 norm_layer=nn.BatchNorm2d, num_D=1):
        super(D_NLayersMulti, self).__init__()
        self.num_D = num_D
        full_res = nn.Sequential(*self.get_layers(input_nc, ndf, n_layers, norm_layer))
        if num_D == 1:
            self.model = full_res
            return

        self.add_module("model_0", full_res)
        self.down = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
        for scale in range(1, num_D):
            # halve the base width for every additional (coarser) scale
            width = int(round(ndf / (2**scale)))
            stack = nn.Sequential(*self.get_layers(input_nc, width, n_layers, norm_layer))
            self.add_module("model_%d" % scale, stack)

    def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Return the list of layers for one discriminator stack.

        Layout: a stride-2 conv + LeakyReLU, then ``n_layers - 1`` stride-2
        conv/norm/LeakyReLU stages, one stride-1 stage, and a final stride-1
        conv producing a single channel. The channel multiplier doubles per
        stage and is capped at 8.
        """
        kw, padw = 4, 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # (exponent, stride) for every conv/norm/activation stage
        stages = [(n, 2) for n in range(1, n_layers)] + [(n_layers, 1)]
        in_mult = 1
        for exponent, stride in stages:
            out_mult = min(2**exponent, 8)
            sequence += [
                nn.Conv2d(ndf * in_mult, ndf * out_mult,
                          kernel_size=kw, stride=stride, padding=padw),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ]
            in_mult = out_mult

        sequence.append(nn.Conv2d(ndf * in_mult, 1,
                                  kernel_size=kw, stride=1, padding=padw))
        return sequence

    def forward(self, input):
        """Run the discriminator; returns a tensor (num_D == 1) or a list."""
        if self.num_D == 1:
            return self.model(input)

        outputs = []
        current = input
        last_scale = self.num_D - 1
        for scale in range(self.num_D):
            stack = getattr(self, "model_%d" % scale)
            outputs.append(stack(current))
            if scale != last_scale:
                current = self.down(current)
        return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBlock_cond(nn.Module):
    """Conv block whose activations are shifted by a projected time embedding.

    Pipeline: conv -> add per-channel ``dense(t_emb)`` -> optional norm ->
    LeakyReLU(0.2) -> optional ``Downsample``.
    """

    def __init__(self, in_channel, out_channel, t_emb_dim, kernel_size=4, stride=1, padding=1,
                 norm_layer=None, downsample=True, use_bias=None):
        super().__init__()
        self.downsample = downsample
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=use_bias)

        self.use_norm = norm_layer is not None
        if self.use_norm:
            self.norm = norm_layer(out_channel)
        self.act = nn.LeakyReLU(0.2, True)
        # Built even when ``downsample`` is False; it is only applied in
        # ``forward`` when the flag is set.
        self.down = Downsample(out_channel)

        # Projects the time embedding to one bias value per output channel.
        self.dense = nn.Linear(t_emb_dim, out_channel)

    def forward(self, input, t_emb):
        """Apply the block to ``input`` conditioned on ``t_emb``."""
        out = self.conv1(input)
        time_bias = self.dense(t_emb)[..., None, None]
        out.add_(time_bias)
        if self.use_norm:
            out = self.norm(out)
        out = self.act(out)
        return self.down(out) if self.downsample else out
|
||||||
|
|
||||||
|
class NLayerDiscriminator_ncsn(nn.Module):
    """Defines a time-conditioned PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)      -- the number of channels in input images
            ndf (int)           -- the number of filters in the first conv layer
            n_layers (int)      -- the number of conv layers in the discriminator
            norm_layer          -- normalization layer
            no_antialias (bool) -- if True, downsample with stride-2 convs
                                   instead of stride-1 conv + blur Downsample

        Every block is a ConvBlock_cond fed with a time embedding of width 4*ndf.
        """
        super(NLayerDiscriminator_ncsn, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kw = 4
        padw = 1
        t_dim = 4 * ndf
        # Previously the no_antialias branch filled a local ``sequence`` list
        # that was never attached to the module, leaving ``model_main`` with a
        # single block whose input width did not match the image. Both modes
        # now build conditioned blocks; only the downsampling mechanism differs.
        stride = 2 if no_antialias else 1
        blur_down = not no_antialias

        self.model_main = nn.ModuleList()
        self.model_main.append(
            ConvBlock_cond(input_nc, ndf, t_dim, kernel_size=kw, stride=stride, padding=padw,
                           use_bias=use_bias, downsample=blur_down)
        )

        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            self.model_main.append(
                ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, t_dim, kernel_size=kw,
                               stride=stride, padding=padw, use_bias=use_bias,
                               norm_layer=norm_layer, downsample=blur_down)
            )

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        # last conditioned block keeps the spatial resolution
        self.model_main.append(
            ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, t_dim, kernel_size=kw, stride=1,
                           padding=padw, use_bias=use_bias, norm_layer=norm_layer,
                           downsample=False)
        )
        self.final_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        self.t_embed = TimestepEmbedding(
            embedding_dim=t_dim,
            hidden_dim=t_dim,
            output_dim=t_dim,
            act=nn.LeakyReLU(0.2),
        )

    def forward(self, input, t_emb, input2=None):
        """Standard forward.

        Parameters:
            input (Tensor)  -- image batch
            t_emb (Tensor)  -- 1-D tensor of timesteps; despite the name it is
                               embedded here by ``self.t_embed``
            input2 (Tensor) -- optional second image, concatenated with
                               ``input`` along the channel dimension

        Returns the single-channel patch prediction map.
        """
        t_emb = self.t_embed(t_emb)
        out = input if input2 is None else torch.cat([input, input2], dim=1)
        for layer in self.model_main:
            out = layer(out, t_emb)
        return self.final_conv(out)
|
||||||
|
|
||||||
|
class PixelDiscriminator(nn.Module):
    """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        """Construct a 1x1 PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the first conv layer
            norm_layer      -- normalization layer (class or functools.partial)
        """
        super(PixelDiscriminator, self).__init__()
        # no need to use bias as BatchNorm2d has affine parameters
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        # Build the layers in a local list so ``self.net`` is only ever bound
        # to the final nn.Sequential.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward: returns a one-channel map with the input's spatial size."""
        return self.net(input)
|
||||||
|
|
||||||
|
|
||||||
|
##################################################################################
|
||||||
|
# Generator
|
||||||
|
##################################################################################
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a two-layer MLP.

    Parameters:
        embedding_dim (int) -- width of the sinusoidal embedding
        hidden_dim (int)    -- width of the hidden linear layer
        output_dim (int)    -- width of the returned embedding
        act (nn.Module)     -- activation applied after each linear layer
    """

    def __init__(self, embedding_dim, hidden_dim, output_dim, act=nn.LeakyReLU(0.2)):
        super().__init__()
        import copy

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        # ``act`` used to be ignored in favour of a hard-coded LeakyReLU(0.2).
        # It is honoured now; each position gets its own copy so the shared
        # default instance is never registered in more than one place.
        self.main = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            copy.deepcopy(act),
            nn.Linear(hidden_dim, output_dim),
            copy.deepcopy(act),
        )

    def forward(self, temp):
        """Embed the 1-D timestep tensor ``temp`` into ``output_dim`` features."""
        temb = get_timestep_embedding(temp, self.embedding_dim)
        temb = self.main(temb)
        return temb
|
||||||
|
|
||||||
|
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Build sinusoidal embeddings for a 1-D tensor of timesteps.

    Parameters:
        timesteps (Tensor)   -- 1-D tensor of timestep values (cast to float)
        embedding_dim (int)  -- width of the returned embedding
        max_positions (int)  -- base of the geometric frequency progression

    Returns a float tensor of shape (len(timesteps), embedding_dim) laid out as
    [sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})] with h = embedding_dim // 2,
    zero-padded by one column when embedding_dim is odd.

    NOTE: a function with this name is also defined earlier in this module;
    this later definition is the one bound at import time.
    """
    assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
    half_dim = embedding_dim // 2
    # magic number 10000 is from transformers
    # Guard the denominator: half_dim == 1 (embedding_dim 2 or 3) used to
    # divide by zero. With a single frequency the exponent is multiplied by
    # index 0 anyway, so the chosen denominator does not affect the result.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step
    )
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = F.pad(emb, (0, 1), mode='constant')
    assert emb.shape == (timesteps.shape[0], embedding_dim)
    return emb
|
||||||
|
|
||||||
|
class AdaptiveLayer(nn.Module):
    """Per-channel affine modulation predicted from a style vector.

    A linear layer maps the style vector to ``2 * in_channel`` values that are
    split into a scale (gamma) and a shift (beta). The bias starts at 1 for
    the scale half and 0 for the shift half.
    """

    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.style_net = nn.Linear(style_dim, in_channel * 2)

        with torch.no_grad():
            self.style_net.bias[:in_channel].fill_(1)
            self.style_net.bias[in_channel:].fill_(0)

    def forward(self, input, style):
        """Return ``gamma * input + beta`` with gamma/beta derived from ``style``."""
        modulation = self.style_net(style)[:, :, None, None]
        gamma, beta = torch.chunk(modulation, 2, dim=1)
        return gamma * input + beta
|
||||||
|
|
||||||
|
class ResnetGenerator_ncsn(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

    The Resnet blocks are conditioned on a timestep embedding and on a mapped
    latent code ``z``.

    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9,
                 padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
        """Construct a Resnet-based generator

        Parameters:
            input_nc (int)         -- the number of channels in input images
            output_nc (int)        -- the number of channels in output images
            ngf (int)              -- the number of filters in the first conv layer
            norm_layer             -- normalization layer
            use_dropout (bool)     -- if use dropout layers
            n_blocks (int)         -- the number of ResNet blocks
            padding_type (str)     -- the name of padding layer in conv layers: reflect | replicate | zero
            no_antialias (bool)    -- downsample with stride-2 convs instead of conv + Downsample
            no_antialias_up (bool) -- upsample with ConvTranspose2d instead of Upsample + conv
            opt                    -- options object; ``opt.n_mlp`` sets the number of
                                      extra layers in the latent mapping network

        Raises:
            ValueError -- if ``opt`` is None (``opt.n_mlp`` is required)
        """
        assert(n_blocks >= 0)
        if opt is None:
            # Fail early with a clear message instead of an AttributeError on
            # ``None.n_mlp`` halfway through construction.
            raise ValueError('ResnetGenerator_ncsn requires `opt` with an `n_mlp` attribute')
        super(ResnetGenerator_ncsn, self).__init__()
        self.opt = opt
        norm_cls = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        # ---- encoder: 7x7 stem followed by two downsampling stages ----
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        self.ngf = ngf
        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            if no_antialias:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True)]
            else:
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
                          norm_layer(ngf * mult * 2),
                          nn.ReLU(True),
                          Downsample(ngf * mult * 2)]

        # ---- conditioned residual trunk ----
        self.model_res = nn.ModuleList()
        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks
            self.model_res.append(
                ResnetBlock_cond(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                 use_dropout=use_dropout, use_bias=use_bias,
                                 temb_dim=4 * ngf, z_dim=4 * ngf)
            )

        # ---- decoder: two upsampling stages and a 7x7 tanh head ----
        model_upsample = []
        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            if no_antialias_up:
                model_upsample += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                                      kernel_size=3, stride=2,
                                                      padding=1, output_padding=1,
                                                      bias=use_bias),
                                   norm_layer(int(ngf * mult / 2)),
                                   nn.ReLU(True)]
            else:
                model_upsample += [
                    Upsample(ngf * mult),
                    nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1, bias=use_bias),
                    norm_layer(int(ngf * mult / 2)),
                    nn.ReLU(True)]
        model_upsample += [nn.ReflectionPad2d(3)]
        model_upsample += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model_upsample += [nn.Tanh()]

        self.model = nn.Sequential(*model)
        self.model_upsample = nn.Sequential(*model_upsample)

        # ---- latent mapping network: PixelNorm + (1 + n_mlp) linear layers ----
        mapping_layers = [PixelNorm(),
                          nn.Linear(self.ngf * 4, self.ngf * 4),
                          nn.LeakyReLU(0.2)]
        for _ in range(opt.n_mlp):
            mapping_layers.append(nn.Linear(self.ngf * 4, self.ngf * 4))
            mapping_layers.append(nn.LeakyReLU(0.2))
        self.z_transform = nn.Sequential(*mapping_layers)

        # ---- timestep MLP: ngf -> 4*ngf -> 4*ngf, biases zero-initialised ----
        modules_emb = []
        modules_emb += [nn.Linear(self.ngf, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)
        modules_emb += [nn.LeakyReLU(0.2)]
        modules_emb += [nn.Linear(self.ngf * 4, self.ngf * 4)]
        nn.init.zeros_(modules_emb[-1].bias)
        modules_emb += [nn.LeakyReLU(0.2)]
        self.time_embed = nn.Sequential(*modules_emb)

    def forward(self, x, time_cond, z, layers=[], encode_only=False):
        """Translate ``x`` or collect intermediate features.

        Parameters:
            x (Tensor)          -- input image batch
            time_cond (Tensor)  -- 1-D tensor of timesteps
            z (Tensor)          -- latent code whose last dim is 4*ngf (input
                                   width of the mapping network)
            layers (sequence)   -- layer indices whose outputs are collected;
                                   encoder layers are numbered first, residual
                                   blocks continue from ``len(self.model)``.
                                   Only read, never mutated.
            encode_only (bool)  -- when features are requested, return as soon
                                   as residual block ``layers[-1]`` is reached

        Returns:
            * no ``layers``: the generated image (decoder output);
            * ``layers`` and the early exit triggers: the feature list;
            * ``layers`` otherwise: ``(feat, feats)`` where ``feat`` is the
              last residual-block output (the decoder is not applied).
        """
        z_embed = self.z_transform(z)
        temb = get_timestep_embedding(time_cond, self.ngf)
        time_embed = self.time_embed(temb)

        if len(layers) > 0:
            feat = x
            feats = []
            for layer_id, layer in enumerate(self.model):
                feat = layer(feat)
                if layer_id in layers:
                    feats.append(feat)

            offset = len(self.model)
            for layer_id, layer in enumerate(self.model_res):
                feat = layer(feat, time_embed, z_embed)
                global_id = layer_id + offset
                if global_id in layers:
                    feats.append(feat)
                # the early exit is only checked inside the residual trunk
                if global_id == layers[-1] and encode_only:
                    return feats
            return feat, feats

        out = self.model(x)
        for layer in self.model_res:
            out = layer(out, time_embed, z_embed)
        out = self.model_upsample(out)
        return out
|
||||||
|
##################################################################################
|
||||||
|
# Basic Blocks
|
||||||
|
##################################################################################
|
||||||
|
class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block

        The block is two 3x3 conv stages wrapped by an identity skip
        connection; the stages come from <build_conv_block> and the skip is
        added in <forward>.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return (explicit padding layers, conv ``padding`` value) for a padding type."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not

        Returns an nn.Sequential: pad/conv/norm/ReLU, optional Dropout(0.5),
        then pad/conv/norm (no trailing activation).
        """
        # first stage
        stages, p = self._padding(padding_type)
        stages.append(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias))
        stages.append(norm_layer(dim))
        stages.append(nn.ReLU(True))
        if use_dropout:
            stages.append(nn.Dropout(0.5))

        # second stage
        pad_layers, p = self._padding(padding_type)
        stages.extend(pad_layers)
        stages.append(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias))
        stages.append(norm_layer(dim))

        return nn.Sequential(*stages)

    def forward(self, x):
        """Forward function (with skip connections)"""
        residual = self.conv_block(x)
        return x + residual  # add skip connections
|
||||||
|
|
||||||
|
class ResnetBlock_cond(nn.Module):
    """Define a Resnet block conditioned on a time embedding and a latent code"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim):
        """Initialize the Resnet block

        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf

        Parameters:
            dim (int)       -- the number of channels in the conv layers
            temb_dim (int)  -- width of the time embedding fed to <forward>
            z_dim (int)     -- width of the latent code fed to <forward>
            (remaining parameters: see <build_conv_block>)
        """
        super(ResnetBlock_cond, self).__init__()
        self.conv_block, self.adaptive, self.conv_fin = self.build_conv_block(
            dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim):
        """Construct the pieces of the conditioned block.

        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not
            temb_dim (int)      -- input width of the time projection
            z_dim (int)         -- input width of the style projections

        Returns (conv_block, adaptive, conv_fin):
            conv_block -- ModuleList: [pad,] conv, norm
            adaptive   -- AdaptiveLayer applied between the two halves
            conv_fin   -- ModuleList: ReLU, [Dropout,] [pad,] conv, norm
        Also registers ``self.Dense_time`` and ``self.style`` as a side effect.
        """
        self.conv_block = nn.ModuleList()
        self.conv_fin = nn.ModuleList()

        p = 0
        if padding_type == 'reflect':
            self.conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            self.conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        self.conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
        self.adaptive = AdaptiveLayer(dim, z_dim)
        self.conv_fin += [nn.ReLU(True)]
        if use_dropout:
            self.conv_fin += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            self.conv_fin += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            self.conv_fin += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        self.conv_fin += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        # Projects the time embedding to one value per channel.
        self.Dense_time = nn.Linear(temb_dim, dim)
        nn.init.zeros_(self.Dense_time.bias)

        # Not used by <forward> (modulation goes through ``self.adaptive``);
        # kept so the parameter set / state_dict keys stay unchanged.
        self.style = nn.Linear(z_dim, dim * 2)
        self.style.bias.data[:dim] = 1
        self.style.bias.data[dim:] = 0

        return self.conv_block, self.adaptive, self.conv_fin

    def forward(self, x, time_cond, z):
        """Forward function (with skip connections)

        Parameters:
            x (Tensor)          -- feature map with ``dim`` channels
            time_cond (Tensor)  -- time embedding, last dim ``temb_dim``
            z (Tensor)          -- latent embedding, last dim ``z_dim``

        NOTE: earlier revisions called ``layer(x)`` for every layer of
        ``conv_block``, so only the last layer (the norm) applied to the raw
        input survived; the first conv and the time embedding were silently
        dropped. The layers are chained now. Checkpoints trained with the old
        forward will not reproduce their outputs, since the first conv never
        received gradients.
        """
        time_input = self.Dense_time(time_cond)
        out = x
        for n, layer in enumerate(self.conv_block):
            out = layer(out)
            if n == 0:
                # inject the time signal after the first layer of the block
                out = out + time_input[:, :, None, None]
        out = self.adaptive(out, z)
        for layer in self.conv_fin:
            out = layer(out)
        out = x + out  # add skip connections
        return out
|
||||||
|
###############################################################################
|
||||||
|
# Helper Functions
|
||||||
|
###############################################################################
|
||||||
|
def get_filter(filt_size=3):
    """Return a normalized 2-D binomial (blur) kernel.

    Parameters:
        filt_size (int) -- side length of the kernel; any integer >= 1

    The 1-D profile is row ``filt_size - 1`` of Pascal's triangle
    ([1], [1, 1], [1, 2, 1], [1, 3, 3, 1], ...). The 2-D kernel is its outer
    product with itself, scaled to sum to 1.

    Returns a float32 tensor of shape (filt_size, filt_size).

    Raises:
        ValueError -- if ``filt_size`` is not an integer >= 1
    """
    size = int(filt_size)
    if size != filt_size or size < 1:
        # sizes outside the old hard-coded 1..7 table used to fail with an
        # UnboundLocalError
        raise ValueError('filt_size must be a positive integer, got %r' % (filt_size,))

    # Repeated convolution with [1, 1] walks down Pascal's triangle, which
    # reproduces the previous lookup table exactly for sizes 1..7.
    a = np.ones(1)
    for _ in range(size - 1):
        a = np.convolve(a, [1., 1.])

    filt = torch.Tensor(a[:, None] * a[None, :])
    filt = filt / torch.sum(filt)

    return filt
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample(nn.Module):
    """Blur-then-subsample layer: pad, depthwise-filter with a fixed kernel, stride.

    The kernel from ``get_filter`` is stored as a buffer with one copy per
    channel and applied as a grouped convolution.
    """

    def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off

        # left/top get the floor and right/bottom the ceil of half the kernel
        # extent, each shifted by pad_off
        half_extent = 1. * (filt_size - 1) / 2
        pad_lo = int(half_extent) + pad_off
        pad_hi = int(np.ceil(half_extent)) + pad_off
        self.pad_sizes = [pad_lo, pad_hi, pad_lo, pad_hi]

        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        kernel = get_filter(filt_size=self.filt_size)
        self.register_buffer('filt', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        pad_cls = get_pad_layer(pad_type)
        self.pad = pad_cls(self.pad_sizes)

    def forward(self, inp):
        """Downsample ``inp`` by ``self.stride`` along the last two dims."""
        if self.filt_size != 1:
            padded = self.pad(inp)
            return F.conv2d(padded, self.filt, stride=self.stride, groups=inp.shape[1])

        # 1x1 kernel: plain strided slicing, padded first only if pad_off is set
        source = inp if self.pad_off == 0 else self.pad(inp)
        return source[:, :, ::self.stride, ::self.stride]
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample2(nn.Module):
    """Thin module wrapper around ``torch.nn.functional.interpolate``."""

    def __init__(self, scale_factor, mode='nearest'):
        super().__init__()
        self.factor, self.mode = scale_factor, mode

    def forward(self, x):
        """Resize ``x`` by ``self.factor`` using interpolation mode ``self.mode``."""
        resized = torch.nn.functional.interpolate(
            x, scale_factor=self.factor, mode=self.mode
        )
        return resized
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample(nn.Module):
    """Upsampling by a grouped transposed convolution with a fixed kernel.

    The kernel from ``get_filter`` is scaled by ``stride**2`` and stored as a
    buffer with one copy per channel.
    """

    def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
        super(Upsample, self).__init__()
        self.filt_size = filt_size
        self.filt_odd = np.mod(filt_size, 2) == 1
        self.pad_size = int((filt_size - 1) / 2)
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        kernel = get_filter(filt_size=self.filt_size) * (stride**2)
        self.register_buffer('filt', kernel[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        pad_cls = get_pad_layer(pad_type)
        self.pad = pad_cls([1, 1, 1, 1])

    def forward(self, inp):
        """Upsample ``inp`` by ``self.stride`` along the last two dims."""
        padded = self.pad(inp)
        expanded = F.conv_transpose2d(
            padded,
            self.filt,
            stride=self.stride,
            padding=1 + self.pad_size,
            groups=inp.shape[1],
        )
        # drop the first row/column produced by the one-pixel input padding
        ret_val = expanded[:, :, 1:, 1:]
        if not self.filt_odd:
            # even-sized kernels yield one extra trailing row/column
            ret_val = ret_val[:, :, :-1, :-1]
        return ret_val
|
||||||
|
|
||||||
|
|
||||||
|
def get_pad_layer(pad_type):
    """Return the padding layer class for ``pad_type``.

    Parameters:
        pad_type (str) -- 'refl' | 'reflect' | 'repl' | 'replicate' | 'zero'

    Returns the nn.Module class (not an instance); callers instantiate it
    with their pad sizes.

    Raises:
        NotImplementedError -- for an unknown ``pad_type``. This used to print
        a message and then fail with an UnboundLocalError on the return.
    """
    if pad_type in ['refl', 'reflect']:
        return nn.ReflectionPad2d
    if pad_type in ['repl', 'replicate']:
        return nn.ReplicationPad2d
    if pad_type == 'zero':
        return nn.ZeroPad2d
    raise NotImplementedError('Pad type [%s] not recognized' % pad_type)
|
||||||
|
|
||||||
|
|
||||||
|
class Identity(nn.Module):
    """No-op module: ``forward`` hands its input back unchanged."""

    def forward(self, x):
        """Return ``x`` as-is."""
        return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm_layer(norm_type='instance'):
    """Build a factory for the requested normalization layer.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    'batch' yields BatchNorm2d with learnable affine parameters and tracked running statistics.
    'instance' yields InstanceNorm2d without affine parameters and without running statistics.
    'none' yields a function that ignores its argument and returns an Identity module.

    Raises NotImplementedError for any other name.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)

    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

    if norm_type == 'none':
        def norm_layer(x):
            # The argument (normally the channel count) is deliberately unused.
            return Identity()
        return norm_layer

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
    """Initialize network weights in place.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
        debug (bool)      -- if True, print the class name of every Conv/Linear module that is initialized

    Conv* and Linear* modules get their weights drawn with the chosen method and
    their bias (when present) zeroed. BatchNorm2d modules get weight ~ N(1, init_gain)
    and zero bias; BatchNorm2d layers built with affine=False have neither and are skipped.

    Raises NotImplementedError if a Conv/Linear module is met and ``init_type`` is unknown.
    """
    def init_func(m):  # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        is_conv_or_linear = classname.find('Conv') != -1 or classname.find('Linear') != -1
        if hasattr(m, 'weight') and is_conv_or_linear:
            if debug:
                print(classname)
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
            # With affine=False both weight and bias are None, so guard each access
            # instead of crashing on ``None.data``.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>
|
||||||
|
|
||||||
|
|
||||||
|
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None, debug=False, initialize_weights=True):
    """Initialize a network: 1. move it to the first requested GPU (if any); 2. initialize the network weights

    Parameters:
        net (network)             -- the network to be initialized
        init_type (str)           -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)         -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list)        -- which GPUs the network runs on: e.g., 0,1,2; None or empty leaves
                                     the network where it is. Only gpu_ids[0] is used here.
        debug (bool)              -- forwarded to init_weights
        initialize_weights (bool) -- if False, skip the weight initialization step

    Return an initialized network.
    """
    # None replaces the former mutable default ``[]``; both mean "no GPU requested".
    if gpu_ids is not None and len(gpu_ids) > 0:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
        net.to(gpu_ids[0])
    if initialize_weights:
        init_weights(net, init_type, init_gain=init_gain, debug=debug)
    return net
|
||||||
@ -7,6 +7,7 @@ from torch.optim import lr_scheduler
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
|
from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
|
||||||
|
from .ncsn_networks import NLayerDiscriminator_ncsn, ResnetGenerator_ncsn
|
||||||
|
|
||||||
###############################################################################
|
###############################################################################
|
||||||
# Helper Functions
|
# Helper Functions
|
||||||
@ -266,6 +267,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
|
|||||||
elif netG == 'resnet_cat':
|
elif netG == 'resnet_cat':
|
||||||
n_blocks = 8
|
n_blocks = 8
|
||||||
net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
|
net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
|
||||||
|
elif netG == 'resnet_9blocks_cond':
|
||||||
|
net = ResnetGenerator_ncsn(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
||||||
return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
|
return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
|
||||||
@ -977,6 +980,7 @@ class ResnetGenerator(nn.Module):
|
|||||||
feats = []
|
feats = []
|
||||||
for layer_id, layer in enumerate(self.model):
|
for layer_id, layer in enumerate(self.model):
|
||||||
# print(layer_id, layer)
|
# print(layer_id, layer)
|
||||||
|
print(feat.shape)
|
||||||
feat = layer(feat)
|
feat = layer(feat)
|
||||||
if layer_id in layers:
|
if layer_id in layers:
|
||||||
# print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
# print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
||||||
@ -984,6 +988,11 @@ class ResnetGenerator(nn.Module):
|
|||||||
else:
|
else:
|
||||||
# print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
# print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
||||||
pass
|
pass
|
||||||
|
print(f"layer_id: {layer_id}, type(layer_id): {type(layer_id)}")
|
||||||
|
print(f"len(list(self.model)) - 1: {len(list(self.model)) - 1}, type(len(list(self.model)) - 1): {type(len(list(self.model)) - 1)}")
|
||||||
|
# Print layers to see what it is. If layers is still a tensor here, we need to understand why.
|
||||||
|
print(f"layers: {layers}, type(layers): {type(layers)}")
|
||||||
|
print(f"encode_only'shape: {encode_only.shape}, type(encode_only): {type(encode_only)}")
|
||||||
if layer_id == len(list(self.model)) - 1 and encode_only:
|
if layer_id == len(list(self.model)) - 1 and encode_only:
|
||||||
# print('encoder only return features')
|
# print('encoder only return features')
|
||||||
return feats # return intermediate features alone; stop in the last layers
|
return feats # return intermediate features alone; stop in the last layers
|
||||||
|
|||||||
@ -221,8 +221,10 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
|
parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
|
||||||
|
|
||||||
parser.add_argument('--tau', type=float, default=0.1, help='used in unsb')
|
parser.add_argument('--tau', type=float, default=0.01, help='Entropy parameter')
|
||||||
parser.add_argument('--num_timesteps', type=int, default=10, help='used in unsb')
|
parser.add_argument('--num_timesteps', type=int, default=5, help='# of discrim filters in the first conv layer')
|
||||||
|
|
||||||
|
parser.add_argument('--n_mlp', type=int, default=3, help='only used if netD==n_layers')
|
||||||
|
|
||||||
parser.set_defaults(pool_size=0) # no image pooling
|
parser.set_defaults(pool_size=0) # no image pooling
|
||||||
|
|
||||||
@ -261,6 +263,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
else:
|
else:
|
||||||
self.model_names = ['G']
|
self.model_names = ['G']
|
||||||
|
|
||||||
|
print(f'input_nc = {self.opt.input_nc}')
|
||||||
# 创建网络
|
# 创建网络
|
||||||
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
|
||||||
|
|
||||||
@ -269,7 +272,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||||
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
|
||||||
|
|
||||||
self.resize = tfs.Resize(size=(384,384))
|
self.resize = tfs.Resize(size=(384,384), antialias=True)
|
||||||
|
|
||||||
# 加入预训练VIT
|
# 加入预训练VIT
|
||||||
self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
|
self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
|
||||||
@ -397,16 +400,20 @@ class RomaUnsbModel(BaseModel):
|
|||||||
"""执行前向传递以生成输出图像"""
|
"""执行前向传递以生成输出图像"""
|
||||||
|
|
||||||
if self.opt.isTrain:
|
if self.opt.isTrain:
|
||||||
|
print(f'before resize: {self.real_A0.shape}')
|
||||||
real_A0 = self.resize(self.real_A0)
|
real_A0 = self.resize(self.real_A0)
|
||||||
real_A1 = self.resize(self.real_A1)
|
real_A1 = self.resize(self.real_A1)
|
||||||
real_B0 = self.resize(self.real_B0)
|
real_B0 = self.resize(self.real_B0)
|
||||||
real_B1 = self.resize(self.real_B1)
|
real_B1 = self.resize(self.real_B1)
|
||||||
# 使用VIT
|
# 使用VIT
|
||||||
|
|
||||||
|
print(f'before vit: {real_A0.shape}')
|
||||||
self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
|
self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
|
||||||
self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
|
self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
|
||||||
|
|
||||||
self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=1).to(self.device)
|
print(f'before cat: len = {len(self.mutil_real_A0_tokens)}\n{self.mutil_real_A0_tokens[0].shape}')
|
||||||
self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=1).to(self.device)
|
self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=0).unsqueeze(0).to(self.device)
|
||||||
|
self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=0).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# 执行一次SB模块
|
# 执行一次SB模块
|
||||||
|
|
||||||
@ -436,6 +443,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
inter = (delta / denom).reshape(-1, 1, 1, 1)
|
inter = (delta / denom).reshape(-1, 1, 1, 1)
|
||||||
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
|
scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
|
||||||
|
|
||||||
|
print(f'before noisy: {self.mutil_real_A0_tokens.shape}')
|
||||||
# 对 Xt、Xt2 进行随机噪声更新
|
# 对 Xt、Xt2 进行随机噪声更新
|
||||||
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
|
Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
|
||||||
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
|
(scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
|
||||||
@ -454,7 +462,8 @@ class RomaUnsbModel(BaseModel):
|
|||||||
self.real_A_noisy = Xt.detach()
|
self.real_A_noisy = Xt.detach()
|
||||||
self.real_A_noisy2 = Xt2.detach()
|
self.real_A_noisy2 = Xt2.detach()
|
||||||
# 保存noisy_map
|
# 保存noisy_map
|
||||||
self.noisy_map = self.real_A_noisy - self.real_A0
|
print(f'after noisy map: {self.real_A_noisy.shape}')
|
||||||
|
self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens
|
||||||
|
|
||||||
# ============ 第三步:拼接输入并执行网络推理 =============
|
# ============ 第三步:拼接输入并执行网络推理 =============
|
||||||
bs = self.mutil_real_A0_tokens.size(0)
|
bs = self.mutil_real_A0_tokens.size(0)
|
||||||
@ -518,7 +527,7 @@ class RomaUnsbModel(BaseModel):
|
|||||||
|
|
||||||
if self.opt.phase == 'train':
|
if self.opt.phase == 'train':
|
||||||
# 真实图像的梯度
|
# 真实图像的梯度
|
||||||
real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0]
|
real_gradient = torch.autograd.grad(self.real_B0.sum(), self.real_B0, create_graph=True)[0]
|
||||||
# 生成图像的梯度
|
# 生成图像的梯度
|
||||||
fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
|
fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
|
||||||
# 梯度图
|
# 梯度图
|
||||||
|
|||||||
Binary file not shown.
@ -36,7 +36,7 @@ class BaseOptions():
|
|||||||
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
|
||||||
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
|
||||||
parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
|
||||||
parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
|
parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture')
|
||||||
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
|
||||||
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
|
||||||
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
|
||||||
|
|||||||
@ -30,4 +30,4 @@ python train.py \
|
|||||||
--eta_ratio 0.1 \
|
--eta_ratio 0.1 \
|
||||||
--tau 0.1 \
|
--tau 0.1 \
|
||||||
--num_timesteps 10 \
|
--num_timesteps 10 \
|
||||||
--input_nc 1
|
--input_nc 3
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user