roma_unsb/models/template_model.py
2025-02-22 14:21:54 +08:00

100 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Model class template
This module provides a template for users to implement custom models.
You can specify '--model template' to use this model.
The class name should be consistent with both the filename and its model option.
The filename should be <model>_model.py
The class name should be <Model>Model
It implements a simple image-to-image translation baseline based on regression loss.
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
min_<netG> ||netG(data_A) - data_B||_1
You need to implement the following functions:
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
<__init__>: Initialize this model class.
<set_input>: Unpack input data and perform data pre-processing.
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
<optimize_parameters>: Update network weights; it will be called in every training iteration.
"""
import torch
from .base_model import BaseModel
from . import networks
class TemplateModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new model-specific options and rewrite default values for existing options.
Parameters:
parser -- the option parser
is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
if is_train:
parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
return parser
def __init__(self, opt):
"""Initialize this model class.
Parameters:
opt -- training/test options
A few things can be done here.
- (required) call the initialization function of BaseModel
- define loss function, visualization images, model names, and optimizers
"""
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
self.loss_names = ['loss_G']
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
self.visual_names = ['data_A', 'data_B', 'output']
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
self.model_names = ['G']
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
if self.isTrain: # only defined during training time
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
self.criterionLoss = torch.nn.L1Loss()
# define and initialize optimizers. You can define one optimizer for each network.
# If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = [self.optimizer]
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
def forward(self):
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
self.output = self.netG(self.data_A) # generate output image given the input data_A
def backward(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# caculate the intermediate results if necessary; here self.output has been computed during function <forward>
# calculate loss given the input and intermediate results
self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
def optimize_parameters(self):
"""Update network weights; it will be called in every training iteration."""
self.forward() # first call forward to calculate intermediate results
self.optimizer.zero_grad() # clear network G's existing gradients
self.backward() # calculate gradients for network G
self.optimizer.step() # update gradients for network G