diff --git a/checkpoints/ROMA_UNSB_001/loss_log.txt b/checkpoints/ROMA_UNSB_001/loss_log.txt index 8d9f361..19fcafc 100644 --- a/checkpoints/ROMA_UNSB_001/loss_log.txt +++ b/checkpoints/ROMA_UNSB_001/loss_log.txt @@ -6,3 +6,22 @@ ================ Training Loss (Sun Feb 23 16:06:44 2025) ================ ================ Training Loss (Sun Feb 23 16:09:38 2025) ================ ================ Training Loss (Sun Feb 23 16:44:56 2025) ================ +================ Training Loss (Sun Feb 23 16:49:46 2025) ================ +================ Training Loss (Sun Feb 23 16:51:03 2025) ================ +================ Training Loss (Sun Feb 23 16:51:23 2025) ================ +================ Training Loss (Sun Feb 23 18:04:02 2025) ================ +================ Training Loss (Sun Feb 23 18:04:39 2025) ================ +================ Training Loss (Sun Feb 23 18:05:17 2025) ================ +================ Training Loss (Sun Feb 23 18:06:40 2025) ================ +================ Training Loss (Sun Feb 23 18:11:48 2025) ================ +================ Training Loss (Sun Feb 23 18:13:31 2025) ================ +================ Training Loss (Sun Feb 23 18:14:11 2025) ================ +================ Training Loss (Sun Feb 23 18:14:29 2025) ================ +================ Training Loss (Sun Feb 23 18:16:27 2025) ================ +================ Training Loss (Sun Feb 23 18:16:44 2025) ================ +================ Training Loss (Sun Feb 23 18:20:39 2025) ================ +================ Training Loss (Sun Feb 23 18:21:44 2025) ================ +================ Training Loss (Sun Feb 23 18:35:27 2025) ================ +================ Training Loss (Sun Feb 23 18:39:21 2025) ================ +================ Training Loss (Sun Feb 23 18:40:15 2025) ================ +================ Training Loss (Sun Feb 23 18:41:15 2025) ================ diff --git a/checkpoints/ROMA_UNSB_001/train_opt.txt b/checkpoints/ROMA_UNSB_001/train_opt.txt index b15424d..d7766e4 100644 --- a/checkpoints/ROMA_UNSB_001/train_opt.txt +++ b/checkpoints/ROMA_UNSB_001/train_opt.txt @@ -43,6 +43,7 @@ n_epochs: 100 n_epochs_decay: 100 n_layers_D: 3 + n_mlp: 3 name: ROMA_UNSB_001 [default: experiment_name] nce_T: 0.07 nce_idt: False [default: True] @@ -52,7 +53,7 @@ nce_includes_all_negatives_from_minibatch: False netD: basic netF: mlp_sample netF_nc: 256 - netG: resnet_9blocks + netG: resnet_9blocks_cond ngf: 64 no_antialias: False no_antialias_up: False @@ -63,7 +64,7 @@ nce_includes_all_negatives_from_minibatch: False normG: instance num_patches: 256 num_threads: 4 - num_timesteps: 10 + num_timesteps: 10 [default: 5] output_nc: 3 phase: train pool_size: 0 @@ -77,7 +78,7 @@ nce_includes_all_negatives_from_minibatch: False serial_batches: False stylegan2_G_num_downsampling: 1 suffix: - tau: 0.1 + tau: 0.1 [default: 0.01] update_html_freq: 1000 use_idt: False verbose: False diff --git a/models/__pycache__/ncsn_networks.cpython-39.pyc b/models/__pycache__/ncsn_networks.cpython-39.pyc new file mode 100644 index 0000000..2830371 Binary files /dev/null and b/models/__pycache__/ncsn_networks.cpython-39.pyc differ diff --git a/models/__pycache__/networks.cpython-39.pyc b/models/__pycache__/networks.cpython-39.pyc index 800e7d3..40fbb5c 100644 Binary files a/models/__pycache__/networks.cpython-39.pyc and b/models/__pycache__/networks.cpython-39.pyc differ diff --git a/models/__pycache__/roma_unsb_model.cpython-39.pyc b/models/__pycache__/roma_unsb_model.cpython-39.pyc index 4914986..fe4d91c 100644 
Binary files a/models/__pycache__/roma_unsb_model.cpython-39.pyc and b/models/__pycache__/roma_unsb_model.cpython-39.pyc differ diff --git a/models/ncsn_networks.py b/models/ncsn_networks.py new file mode 100644 index 0000000..2eee7e7 --- /dev/null +++ b/models/ncsn_networks.py @@ -0,0 +1,719 @@ +""" +The network architectures are based on the implementations of CycleGAN and CUT +Original PyTorch repo of CycleGAN: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix +Original PyTorch repo of CUT: https://github.com/taesungp/contrastive-unpaired-translation +Original CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf +Original CUT paper: https://arxiv.org/pdf/2007.15651.pdf +We use these network architectures for our default image translation models +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools +import numpy as np +from torch.nn import init +import math +class PixelNorm(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) + +def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): + assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 + half_dim = embedding_dim // 2 + # magic number 10000 is from transformers + emb = math.log(max_positions) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = F.pad(emb, (0, 1), mode='constant') + assert emb.shape == (timesteps.shape[0], embedding_dim) + return emb + +################################################################################## +# Discriminator +################################################################################## +class D_NLayersMulti(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, + norm_layer=nn.BatchNorm2d, num_D=1): + super(D_NLayersMulti, self).__init__() + self.num_D = num_D + if num_D == 1: + layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) + self.model = nn.Sequential(*layers) + else: + layers = self.get_layers(input_nc, ndf, n_layers, norm_layer) + self.add_module("model_0", nn.Sequential(*layers)) + self.down = nn.AvgPool2d(3, stride=2, padding=[ 1, 1], count_include_pad=False) + for i in range(1, num_D): + ndf_i = int(round(ndf / (2**i))) + layers = self.get_layers(input_nc, ndf_i, n_layers, norm_layer) + self.add_module("model_%d" % i, nn.Sequential(*layers)) + + def get_layers(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=2, padding=padw), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, + kernel_size=kw, stride=1, padding=padw), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, + kernel_size=kw, stride=1, padding=padw)] + + return sequence + + def forward(self, input): + if self.num_D == 1: + return self.model(input) + result = [] + down = input +
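# Multi-scale discrimination: run each scale's head on the current input, then average-pool it before feeding the next, smaller head. +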
for i in range(self.num_D): + model = getattr(self, "model_%d" % i) + result.append(model(down)) + if i != self.num_D - 1: + down = self.down(down) + return result + + + + +class ConvBlock_cond(nn.Module): + def __init__(self, in_channel, out_channel, t_emb_dim, kernel_size=4, stride=1, padding=1, norm_layer=None, downsample=True, use_bias=None): + super().__init__() + self.downsample = downsample + self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding, bias=use_bias) + + if norm_layer is not None: + self.use_norm = True + self.norm = norm_layer(out_channel) + else: + self.use_norm = False + self.act = nn.LeakyReLU(0.2, True) + self.down = Downsample(out_channel) + + self.dense = nn.Linear(t_emb_dim, out_channel) + + def forward(self, input, t_emb): + out = self.conv1(input) + out += self.dense(t_emb)[..., None, None] + if self.use_norm: + out = self.norm(out) + out = self.act(out) + if self.downsample: + out = self.down(out) + + return out + +class NLayerDiscriminator_ncsn(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False): + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator_ncsn, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + self.model_main = nn.ModuleList() + kw = 4 + padw = 1 + if no_antialias: + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + else: + self.model_main.append(ConvBlock_cond(input_nc, ndf, 4*ndf, kernel_size=kw, stride=1, padding=padw, use_bias=use_bias)) + + nf_mult = 1 + nf_mult_prev = 1 + + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + if no_antialias: + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True)] + else: + self.model_main.append( + ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4*ndf, kernel_size=kw, stride=1, padding=padw, use_bias=use_bias, norm_layer=norm_layer) + + ) + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + self.model_main.append( + ConvBlock_cond(ndf * nf_mult_prev, ndf * nf_mult, 4*ndf, kernel_size=kw, stride=1, padding=padw, use_bias=use_bias, norm_layer=norm_layer, downsample=False) + + ) + self.final_conv = nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + self.t_embed = TimestepEmbedding( + embedding_dim=4*ndf, + hidden_dim=4*ndf, + output_dim=4*ndf, + act=nn.LeakyReLU(0.2), + ) + + def forward(self, input, t_emb, input2=None): + """Standard forward.""" + t_emb = self.t_embed(t_emb) + if input2 is not None: + out = torch.cat([input, input2], dim=1) + else: + out = input + for layer in self.model_main: + out = layer(out, t_emb) + + return self.final_conv(out) + +class PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) --
the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) + + +################################################################################## +# Generator +################################################################################## + +class TimestepEmbedding(nn.Module): + def __init__(self, embedding_dim, hidden_dim, output_dim, act=nn.LeakyReLU(0.2)): + super().__init__() + + self.embedding_dim = embedding_dim + self.output_dim = output_dim + self.hidden_dim = hidden_dim + + self.main = nn.Sequential( + nn.Linear(embedding_dim, hidden_dim), + nn.LeakyReLU(0.2), + nn.Linear(hidden_dim, output_dim), + nn.LeakyReLU(0.2), + # EqualLinear(hidden_dim, output_dim,bias_init = 0, activation='fused_lrelu') + ) + + def forward(self, temp): + temb = get_timestep_embedding(temp, self.embedding_dim) + temb = self.main(temb) + return temb + +class AdaptiveLayer(nn.Module): + def __init__(self, in_channel, style_dim): + super().__init__() + + self.style_net = nn.Linear(style_dim, in_channel * 2) + + self.style_net.bias.data[:in_channel] = 1 + self.style_net.bias.data[in_channel:] = 0 + + def forward(self, input, style): + + style = self.style_net(style).unsqueeze(2).unsqueeze(3) + gamma, beta = style.chunk(2, 1) + + out = gamma * input + beta + + return out + +class ResnetGenerator_ncsn(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+ + We adapt the Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=9, + padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert(n_blocks >= 0) + super(ResnetGenerator_ncsn, self).__init__() + self.opt = opt + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + self.ngf = ngf + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + if no_antialias: + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + else: + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True), + Downsample(ngf * mult * 2) + # nn.AvgPool2d(kernel_size=2, stride=2) + ] + self.model_res = nn.ModuleList() + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + self.model_res += [ResnetBlock_cond(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias, temb_dim=4*ngf, z_dim=4*ngf)] + + model_upsample = [] + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + if no_antialias_up: + model_upsample += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + else: + model_upsample += [ + Upsample(ngf * mult), + # nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), + nn.Conv2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1, bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model_upsample += [nn.ReflectionPad2d(3)] + model_upsample += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model_upsample += [nn.Tanh()] + + self.model = nn.Sequential(*model) + self.model_upsample = nn.Sequential(*model_upsample) + mapping_layers = [PixelNorm(), + nn.Linear(self.ngf*4, self.ngf*4), + nn.LeakyReLU(0.2)] + for _ in range(opt.n_mlp): + mapping_layers.append(nn.Linear(self.ngf*4, self.ngf*4)) + mapping_layers.append(nn.LeakyReLU(0.2)) + self.z_transform = nn.Sequential(*mapping_layers) + modules_emb = [] + modules_emb += [nn.Linear(self.ngf, self.ngf*4)] + + nn.init.zeros_(modules_emb[-1].bias) + modules_emb += [nn.LeakyReLU(0.2)] + modules_emb += [nn.Linear(self.ngf*4, self.ngf*4)] + + nn.init.zeros_(modules_emb[-1].bias) + modules_emb += [nn.LeakyReLU(0.2)] + self.time_embed = nn.Sequential(*modules_emb) + + def forward(self, x, time_cond, z, layers=[], encode_only=False): + z_embed = self.z_transform(z) + temb =
get_timestep_embedding(time_cond, self.ngf) + time_embed = self.time_embed(temb) + if len(layers) > 0: + feat = x + feats = [] + for layer_id, layer in enumerate(self.model): + feat = layer(feat) + if layer_id in layers: + feats.append(feat) + + for layer_id, layer in enumerate(self.model_res): + feat = layer(feat, time_embed, z_embed) + if layer_id+len(self.model) in layers: + feats.append(feat) + if layer_id+len(self.model) == layers[-1] and encode_only: + return feats + return feat, feats + else: + + out = self.model(x) + for layer in self.model_res: + out = layer(out, time_embed, z_embed) + out = self.model_upsample(out) + return out +################################################################################## +# Basic Blocks +################################################################################## +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with the build_conv_block function, + and implement the skip connections in the forward function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + +class ResnetBlock_cond(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with the build_conv_block function, + and implement the skip connections in the forward function.
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock_cond, self).__init__() + self.conv_block, self.adaptive, self.conv_fin = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, temb_dim, z_dim): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + + self.conv_block = nn.ModuleList() + self.conv_fin = nn.ModuleList() + p = 0 + if padding_type == 'reflect': + self.conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + self.conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + self.conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + self.adaptive = AdaptiveLayer(dim, z_dim) + self.conv_fin += [nn.ReLU(True)] + if use_dropout: + self.conv_fin += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + self.conv_fin += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + self.conv_fin += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + self.conv_fin += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + self.Dense_time = nn.Linear(temb_dim, dim) + # self.Dense_time.weight.data = default_init()(self.Dense_time.weight.data.shape) + nn.init.zeros_(self.Dense_time.bias) + + self.style = nn.Linear(z_dim, dim * 2) + + self.style.bias.data[:dim] = 1 + self.style.bias.data[dim:] = 0 + + return self.conv_block, self.adaptive, self.conv_fin + + def forward(self, x, time_cond, z): + """Forward function (with skip connections)""" + time_input = self.Dense_time(time_cond) + out = x + for n, layer in enumerate(self.conv_block): + out = layer(out) # chain the layers instead of re-applying each to x + if n == 0: + out += time_input[:, :, None, None] + out = self.adaptive(out, z) + for layer in self.conv_fin: + out = layer(out) + out = x + out # add skip connections + return out +############################################################################### +# Helper Functions +############################################################################### +def get_filter(filt_size=3): + if(filt_size == 1): + a = np.array([1., ]) + elif(filt_size == 2): + a = np.array([1., 1.]) + elif(filt_size == 3): + a = np.array([1., 2., 1.]) + elif(filt_size == 4): + a = np.array([1., 3., 3., 1.]) + elif(filt_size == 5): + a = np.array([1., 4., 6., 4., 1.]) + elif(filt_size == 6): + a = np.array([1., 5., 10., 10., 5., 1.]) + elif(filt_size == 7): + a = np.array([1., 6., 15., 20., 15., 6., 1.]) + + filt = torch.Tensor(a[:, None] * a[None, :]) + filt = filt / torch.sum(filt) + + return filt + + +class Downsample(nn.Module): + def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0): + super(Downsample, self).__init__() + self.filt_size = filt_size + self.pad_off = pad_off + self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1.
* (filt_size - 1) / 2))] + self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] + self.stride = stride + self.off = int((self.stride - 1) / 2.) + self.channels = channels + + filt = get_filter(filt_size=self.filt_size) + self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + self.pad = get_pad_layer(pad_type)(self.pad_sizes) + + def forward(self, inp): + if(self.filt_size == 1): + if(self.pad_off == 0): + return inp[:, :, ::self.stride, ::self.stride] + else: + return self.pad(inp)[:, :, ::self.stride, ::self.stride] + else: + return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) + + +class Upsample2(nn.Module): + def __init__(self, scale_factor, mode='nearest'): + super().__init__() + self.factor = scale_factor + self.mode = mode + + def forward(self, x): + return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode) + + +class Upsample(nn.Module): + def __init__(self, channels, pad_type='repl', filt_size=4, stride=2): + super(Upsample, self).__init__() + self.filt_size = filt_size + self.filt_odd = np.mod(filt_size, 2) == 1 + self.pad_size = int((filt_size - 1) / 2) + self.stride = stride + self.off = int((self.stride - 1) / 2.) + self.channels = channels + + filt = get_filter(filt_size=self.filt_size) * (stride**2) + self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1))) + + self.pad = get_pad_layer(pad_type)([1, 1, 1, 1]) + + def forward(self, inp): + ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:] + if(self.filt_odd): + return ret_val + else: + return ret_val[:, :, :-1, :-1] + + +def get_pad_layer(pad_type): + if(pad_type in ['refl', 'reflect']): + PadLayer = nn.ReflectionPad2d + elif(pad_type in ['repl', 'replicate']): + PadLayer = nn.ReplicationPad2d + elif(pad_type == 'zero'): + PadLayer = nn.ZeroPad2d + else: + raise NotImplementedError('pad type [%s] is not recognized' % pad_type) + return PadLayer + + +class Identity(nn.Module): + def forward(self, x): + return x + + +def get_norm_layer(norm_type='instance'): + """Return a normalization layer + + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. + """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def init_weights(net, init_type='normal', init_gain=0.02, debug=False): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself.
+ """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if debug: + print(classname) + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + if initialize_weights: + init_weights(net, init_type, init_gain=init_gain, debug=debug) + return net \ No newline at end of file diff --git a/models/networks.py b/models/networks.py index ebae8aa..74343e6 100644 --- a/models/networks.py +++ b/models/networks.py @@ -7,6 +7,7 @@ from torch.optim import lr_scheduler import numpy as np import random from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator +from .ncsn_networks import NLayerDiscriminator_ncsn, ResnetGenerator_ncsn ############################################################################### # Helper Functions @@ -266,6 +267,8 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in elif netG == 'resnet_cat': n_blocks = 8 net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu') + elif netG == 'resnet_9blocks_cond': + net = ResnetGenerator_ncsn(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG)) @@ -977,6 +980,7 @@ class ResnetGenerator(nn.Module): feats = [] for layer_id, layer in enumerate(self.model): # print(layer_id, layer) + print(feat.shape) feat = layer(feat) if layer_id in layers: # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1))) @@ -984,6 +988,11 @@ class ResnetGenerator(nn.Module): else: # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1))) pass + print(f"layer_id: {layer_id}, type(layer_id): {type(layer_id)}") + print(f"len(list(self.model)) - 1: {len(list(self.model)) - 1}, 
+ # Print layers to see what it holds. If layers is still a tensor here, we need to understand why. + print(f"layers: {layers}, type(layers): {type(layers)}") + print(f"encode_only: {encode_only}, type(encode_only): {type(encode_only)}") if layer_id == len(list(self.model)) - 1 and encode_only: # print('encoder only return features') return feats # return intermediate features alone; stop in the last layers diff --git a/models/roma_unsb_model.py b/models/roma_unsb_model.py index 11ce30c..b9c2ad7 100644 --- a/models/roma_unsb_model.py +++ b/models/roma_unsb_model.py @@ -221,8 +221,10 @@ class RomaUnsbModel(BaseModel): parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers') - parser.add_argument('--tau', type=float, default=0.1, help='used in unsb') - parser.add_argument('--num_timesteps', type=int, default=10, help='used in unsb') + parser.add_argument('--tau', type=float, default=0.01, help='entropy parameter used in UNSB') + parser.add_argument('--num_timesteps', type=int, default=5, help='number of timesteps used in UNSB') + + parser.add_argument('--n_mlp', type=int, default=3, help='number of MLP layers in the generator z mapping network') + parser.set_defaults(pool_size=0) # no image pooling @@ -260,7 +262,8 @@ class RomaUnsbModel(BaseModel): else: self.model_names = ['G'] - + + print(f'input_nc = {self.opt.input_nc}') # Create the networks self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt) @@ -269,7 +272,7 @@ self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt) - self.resize = tfs.Resize(size=(384,384)) + self.resize = tfs.Resize(size=(384,384), antialias=True) # Add the pretrained ViT self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device) @@ -397,16 +400,20 @@ """Run the forward pass to generate the output images""" if self.opt.isTrain: + print(f'before resize: {self.real_A0.shape}') real_A0 = self.resize(self.real_A0) real_A1 = self.resize(self.real_A1) real_B0 = self.resize(self.real_B0) real_B1 = self.resize(self.real_B1) # Run the ViT + + print(f'before vit: {real_A0.shape}') self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True) self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True) - self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=1).to(self.device) - self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=1).to(self.device) + print(f'before cat: len = {len(self.mutil_real_A0_tokens)}\n{self.mutil_real_A0_tokens[0].shape}') + self.mutil_real_A0_tokens = torch.cat(self.mutil_real_A0_tokens, dim=0).unsqueeze(0).to(self.device) + self.mutil_real_A1_tokens = torch.cat(self.mutil_real_A1_tokens, dim=0).unsqueeze(0).to(self.device) # Run the SB module once @@ -436,6 +443,7 @@ inter = (delta / denom).reshape(-1, 1, 1, 1) scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1) + print(f'before noisy: {self.mutil_real_A0_tokens.shape}') # Apply the random noise updates to Xt and Xt2 Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter *
Xt_1.detach() + \ (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device) @@ -454,7 +462,8 @@ self.real_A_noisy = Xt.detach() self.real_A_noisy2 = Xt2.detach() # Save the noisy map - self.noisy_map = self.real_A_noisy - self.real_A0 + print(f'after noisy map: {self.real_A_noisy.shape}') + self.noisy_map = self.real_A_noisy - self.mutil_real_A0_tokens # ============ Step 3: concatenate the inputs and run network inference ============= bs = self.mutil_real_A0_tokens.size(0) @@ -518,7 +527,7 @@ if self.opt.phase == 'train': # Gradient of the real images - real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0] + real_gradient = torch.autograd.grad(self.real_B0.sum(), self.real_B0, create_graph=True)[0] # Gradient of the generated images fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0] # Gradient maps diff --git a/options/__pycache__/base_options.cpython-39.pyc b/options/__pycache__/base_options.cpython-39.pyc index 066b359..55ab7e1 100644 Binary files a/options/__pycache__/base_options.cpython-39.pyc and b/options/__pycache__/base_options.cpython-39.pyc differ diff --git a/options/base_options.py b/options/base_options.py index 5837dd5..f9de39b 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -36,7 +36,7 @@ class BaseOptions(): parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') - parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture') + parser.add_argument('--netG', type=str, default='resnet_9blocks_cond', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat', 'resnet_9blocks_cond'], help='specify generator architecture') parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G') parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D') diff --git a/scripts/train.sh b/scripts/train.sh index 1c60f02..93a5f96 100755 --- a/scripts/train.sh +++ b/scripts/train.sh @@ -30,4 +30,4 @@ python train.py \ --eta_ratio 0.1 \ --tau 0.1 \ --num_timesteps 10 \ - --input_nc 1 + --input_nc 3
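For reference, a minimal standalone sketch (not part of the patch; the example timestep values are assumptions) of the sinusoidal timestep embedding that models/ncsn_networks.py introduces. The function body matches the new file; in the generator, this embedding is further passed through self.time_embed and injected into each ResnetBlock_cond as a per-channel shift via Dense_time, alongside the z-derived scale/shift from AdaptiveLayer.

import math
import torch
import torch.nn.functional as F

def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    # Transformer-style sinusoidal embedding: half sine, half cosine features.
    assert len(timesteps.shape) == 1
    half_dim = embedding_dim // 2
    emb = math.log(max_positions) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad odd embedding sizes
        emb = F.pad(emb, (0, 1), mode='constant')
    return emb

t = torch.tensor([0, 1, 2, 3])              # assumed example: a batch of 4 sampled timesteps
print(get_timestep_embedding(t, 64).shape)  # -> torch.Size([4, 64])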