commit 5cb1f588525679c4f02b0d58208fe4aa2349d535
Author: areszz <1031614818@qq.com>
Date: Sat Feb 22 14:21:54 2025 +0800
first commit
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4e8168c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,142 @@
+# ROMA
+This repository is the official PyTorch implementation of the ACM MM'22 paper
+"ROMA: Cross-Domain Region Similarity Matching for Unpaired Nighttime Infrared to Daytime Visible Video Translation". [[Arxiv]](https://arxiv.org/abs/2204.12367)
+
+**Examples of Object Detection:**
+
+
+
+
+
+**Examples of Video Fusion:**
+
+
+
+More experimental results can be obtained by contacting us.
+
+# Introduction
+
+## Method
+
+
+- The domain gap between unpaired nighttime infrared and daytime visible videos is even larger than that between paired videos captured at the same time; establishing an effective translation mapping between them would benefit a wide range of downstream applications.
+- Our proposed cross-similarity, computed across domains, encourages the generative process to focus on the structural correspondence between real and synthesized frames, free from the negative effects of their differing styles (see the sketch below).
+
+
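+The following minimal PyTorch sketch only illustrates this idea of region-wise cross-domain similarity; it is not the exact loss implemented in this repository. Feature maps of a real frame and of the synthesized frame from the other domain are pooled into regions, and the region-to-region similarity structure is matched across the two domains.
+
+```python
+import torch.nn.functional as F
+
+
+def region_similarity(feat, region_size=8):
+    """Pool a feature map (N, C, H, W) into regions and return the
+    pairwise cosine similarities between region descriptors (N, R, R)."""
+    regions = F.avg_pool2d(feat, region_size)        # (N, C, H/r, W/r)
+    regions = regions.flatten(2).transpose(1, 2)     # (N, R, C)
+    regions = F.normalize(regions, dim=-1)
+    return regions @ regions.transpose(1, 2)         # (N, R, R)
+
+
+def cross_similarity_loss(feat_real, feat_fake):
+    """Match the region-similarity structure of the synthesized frame
+    to that of the real frame from the other domain."""
+    return F.l1_loss(region_similarity(feat_fake), region_similarity(feat_real))
+```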
+
+## Training
+The following is the required dataset structure. In the video mode, a single input sample is the concatenation of **two adjacent frames**; in the image mode, a single input sample is **a single image**.
+```
+Video/Image mode:
+ trainA: \Path\of\trainA
+ trainB: \Path\of\trainB
+
+```
+Concrete training and testing examples are provided in the scripts `./scripts/train.sh` and `./scripts/test.sh`, respectively.
+
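+In the video mode, each training image is therefore expected to hold two adjacent frames placed side by side; the loader in `data/unaligned_double_dataset.py` splits every sample back into its left and right 256x256 halves. Below is a minimal sketch of preparing such a sample, assuming 256x256 source frames (the helper name is ours, not part of the codebase):
+
+```python
+from PIL import Image
+
+
+def concat_adjacent_frames(frame_t_path, frame_t1_path, out_path):
+    """Place frame t and frame t+1 side by side in one 512x256 image."""
+    frame_t = Image.open(frame_t_path).convert('RGB').resize((256, 256))
+    frame_t1 = Image.open(frame_t1_path).convert('RGB').resize((256, 256))
+    pair = Image.new('RGB', (512, 256))
+    pair.paste(frame_t, (0, 0))
+    pair.paste(frame_t1, (256, 0))
+    pair.save(out_path)
+```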
+
+
+
+## InfraredCity and InfraredCity-Lite Dataset
+
+
+
+
+
+| InfraredCity | Total Frames |
+| --- | --- |
+| Nighttime Infrared | 201,856 |
+| Nighttime Visible | 178,698 |
+| Daytime Visible | 199,430 |
+
+| InfraredCity-Lite | | Infrared Train | Infrared Test | Visible Train | Total |
+| --- | --- | --- | --- | --- | --- |
+| City | clearday | 5,538 | 1,000 | 5,360 | 15,180 |
+| City | overcast | 2,282 | 1,000 | | |
+| Highway | clearday | 4,412 | 1,000 | 6,463 | 15,853 |
+| Highway | overcast | 2,978 | 1,000 | | |
+| Monitor | | 5,612 | 500 | 4,194 | 10,306 |
+
+(For City and Highway, the "Visible Train" and "Total" columns cover the clearday and overcast splits together.)
+
+
+
+
+The datasets and further details are available from [InfiRay](http://openai.raytrontek.com/apply/Infrared_city.html/).
+
+
+## Citation
+If you find our work useful in your research or publication, please cite our work:
+```
+@inproceedings{ROMA2022,
+ title = {ROMA: Cross-Domain Region Similarity Matching for Unpaired Nighttime Infrared to Daytime Visible Video Translation},
+ author = {Zhenjie Yu and Kai Chen and Shuang Li and Bingfeng Han and Chi Harold Liu and Shuigen Wang},
+ booktitle = {ACM MM},
+ pages = {5294--5302},
+ year = {2022}
+}
+```
+
+## Acknowledgements
+This code borrows heavily from the PyTorch implementation of [Cycle-GAN and Pix2Pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [CUT](https://github.com/taesungp/contrastive-unpaired-translation).
+A huge thanks to them!
+```
+@inproceedings{CycleGAN2017,
+ title = {Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
+ author = {Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
+ booktitle = {ICCV},
+ year = {2017}
+}
+
+@inproceedings{CUT2020,
+ author = {Taesung Park and Alexei A. Efros and Richard Zhang and Jun{-}Yan Zhu},
+ title = {Contrastive Learning for Unpaired Image-to-Image Translation},
+ booktitle = {ECCV},
+ pages = {319--345},
+ year = {2020},
+}
+```
\ No newline at end of file
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..a7dd29b
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,98 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point from data loader.
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import importlib
+import torch.utils.data
+from data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+ """Import the module "data/[dataset_name]_dataset.py".
+
+ In the file, the class called DatasetNameDataset() will
+ be instantiated. It has to be a subclass of BaseDataset,
+ and it is case-insensitive.
+ """
+ dataset_filename = "data." + dataset_name + "_dataset"
+ datasetlib = importlib.import_module(dataset_filename)
+
+ dataset = None
+ target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+ for name, cls in datasetlib.__dict__.items():
+ if name.lower() == target_dataset_name.lower() \
+ and issubclass(cls, BaseDataset):
+ dataset = cls
+
+ if dataset is None:
+ raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+ return dataset
+
+
+def get_option_setter(dataset_name):
+ """Return the static method of the dataset class."""
+ dataset_class = find_dataset_using_name(dataset_name)
+ return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt):
+ """Create a dataset given the option.
+
+ This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from data import create_dataset
+ >>> dataset = create_dataset(opt)
+ """
+ data_loader = CustomDatasetDataLoader(opt)
+ dataset = data_loader.load_data()
+ return dataset
+
+
+class CustomDatasetDataLoader():
+ """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+ def __init__(self, opt):
+ """Initialize this class
+
+ Step 1: create a dataset instance given the name [dataset_mode]
+ Step 2: create a multi-threaded data loader.
+ """
+ self.opt = opt
+ dataset_class = find_dataset_using_name(opt.dataset_mode)
+ self.dataset = dataset_class(opt)
+ print("dataset [%s] was created" % type(self.dataset).__name__)
+ self.dataloader = torch.utils.data.DataLoader(
+ self.dataset,
+ batch_size=opt.batch_size,
+ shuffle=not opt.serial_batches,
+ num_workers=int(opt.num_threads),
+ drop_last=True if opt.isTrain else False,
+ )
+
+ def set_epoch(self, epoch):
+ self.dataset.current_epoch = epoch
+
+ def load_data(self):
+ return self
+
+ def __len__(self):
+ """Return the number of data in the dataset"""
+ return min(len(self.dataset), self.opt.max_dataset_size)
+
+ def __iter__(self):
+ """Return a batch of data"""
+ for i, data in enumerate(self.dataloader):
+ if i * self.opt.batch_size >= self.opt.max_dataset_size:
+ break
+ yield data
diff --git a/data/__pycache__/__init__.cpython-36.pyc b/data/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..b5af05d
Binary files /dev/null and b/data/__pycache__/__init__.cpython-36.pyc differ
diff --git a/data/__pycache__/base_dataset.cpython-36.pyc b/data/__pycache__/base_dataset.cpython-36.pyc
new file mode 100644
index 0000000..6904eb2
Binary files /dev/null and b/data/__pycache__/base_dataset.cpython-36.pyc differ
diff --git a/data/__pycache__/image_folder.cpython-36.pyc b/data/__pycache__/image_folder.cpython-36.pyc
new file mode 100644
index 0000000..130d5a1
Binary files /dev/null and b/data/__pycache__/image_folder.cpython-36.pyc differ
diff --git a/data/__pycache__/unaligned_dataset.cpython-36.pyc b/data/__pycache__/unaligned_dataset.cpython-36.pyc
new file mode 100644
index 0000000..499082c
Binary files /dev/null and b/data/__pycache__/unaligned_dataset.cpython-36.pyc differ
diff --git a/data/__pycache__/unaligned_double_dataset.cpython-36.pyc b/data/__pycache__/unaligned_double_dataset.cpython-36.pyc
new file mode 100644
index 0000000..288527d
Binary files /dev/null and b/data/__pycache__/unaligned_double_dataset.cpython-36.pyc differ
diff --git a/data/base_dataset.py b/data/base_dataset.py
new file mode 100644
index 0000000..5748a9d
--- /dev/null
+++ b/data/base_dataset.py
@@ -0,0 +1,230 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+ """This class is an abstract base class (ABC) for datasets.
+
+ To create a subclass, you need to implement the following four functions:
+ -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
+ -- <__len__>: return the size of dataset.
+ -- <__getitem__>: get a data point.
+ -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the class; save the options in the class
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ self.opt = opt
+ self.root = opt.dataroot
+ self.current_epoch = 0
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return 0
+
+ @abstractmethod
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index - - a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+ """
+ pass
+
+
+def get_params(opt, size):
+ w, h = size
+ new_h = h
+ new_w = w
+ if opt.preprocess == 'resize_and_crop':
+ new_h = new_w = opt.load_size
+ elif opt.preprocess == 'scale_width_and_crop':
+ new_w = opt.load_size
+ new_h = opt.load_size * h // w
+
+ x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
+ y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
+
+ flip = random.random() > 0.5
+
+ return {'crop_pos': (x, y), 'flip': flip}
+
+
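+# Typical usage (illustrative): compute the random parameters once with get_params()
+# and pass the same dict to get_transform() for several images, so that, e.g., a pair
+# of frames receives an identical crop position and flip decision:
+#
+#     params = get_params(opt, img_A.size)
+#     transform = get_transform(opt, params)
+#     A, B = transform(img_A), transform(img_B)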
+def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
+ transform_list = []
+ if grayscale:
+ transform_list.append(transforms.Grayscale(1))
+ if 'fixsize' in opt.preprocess:
+ transform_list.append(transforms.Resize(params["size"], method))
+ if 'resize' in opt.preprocess:
+ osize = [opt.load_size, opt.load_size]
+ if "gta2cityscapes" in opt.dataroot:
+ osize[0] = opt.load_size // 2
+ transform_list.append(transforms.Resize(osize, method))
+ elif 'scale_width' in opt.preprocess:
+ transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
+ elif 'scale_shortside' in opt.preprocess:
+ transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))
+
+ if 'zoom' in opt.preprocess:
+ if params is None:
+ transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
+ else:
+ transform_list.append(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))
+
+ if 'crop' in opt.preprocess:
+ if params is None or 'crop_pos' not in params:
+ transform_list.append(transforms.RandomCrop(opt.crop_size))
+ else:
+ transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
+
+ if 'patch' in opt.preprocess:
+ transform_list.append(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))
+
+ if 'trim' in opt.preprocess:
+ transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))
+
+ # if opt.preprocess == 'none':
+ transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
+
+ if not opt.no_flip:
+ if params is None or 'flip' not in params:
+ transform_list.append(transforms.RandomHorizontalFlip())
+ elif 'flip' in params:
+ transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
+
+ if convert:
+ transform_list += [transforms.ToTensor()]
+ if grayscale:
+ transform_list += [transforms.Normalize((0.5,), (0.5,))]
+ else:
+ transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+ return transforms.Compose(transform_list)
+
+
+def __make_power_2(img, base, method=Image.BICUBIC):
+ ow, oh = img.size
+ h = int(round(oh / base) * base)
+ w = int(round(ow / base) * base)
+ if h == oh and w == ow:
+ return img
+
+ return img.resize((w, h), method)
+
+
+def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
+ if factor is None:
+ zoom_level = np.random.uniform(0.8, 1.0, size=[2])
+ else:
+ zoom_level = (factor[0], factor[1])
+ iw, ih = img.size
+ zoomw = max(crop_width, iw * zoom_level[0])
+ zoomh = max(crop_width, ih * zoom_level[1])
+ img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
+ return img
+
+
+def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
+ ow, oh = img.size
+ shortside = min(ow, oh)
+ if shortside >= target_width:
+ return img
+ else:
+ scale = target_width / shortside
+ return img.resize((round(ow * scale), round(oh * scale)), method)
+
+
+def __trim(img, trim_width):
+ ow, oh = img.size
+ if ow > trim_width:
+ xstart = np.random.randint(ow - trim_width)
+ xend = xstart + trim_width
+ else:
+ xstart = 0
+ xend = ow
+ if oh > trim_width:
+ ystart = np.random.randint(oh - trim_width)
+ yend = ystart + trim_width
+ else:
+ ystart = 0
+ yend = oh
+ return img.crop((xstart, ystart, xend, yend))
+
+
+def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
+ ow, oh = img.size
+ if ow == target_width and oh >= crop_width:
+ return img
+ w = target_width
+ h = int(max(target_width * oh / ow, crop_width))
+ return img.resize((w, h), method)
+
+
+def __crop(img, pos, size):
+ ow, oh = img.size
+ x1, y1 = pos
+ tw = th = size
+ if (ow > tw or oh > th):
+ return img.crop((x1, y1, x1 + tw, y1 + th))
+ return img
+
+
+def __patch(img, index, size):
+ ow, oh = img.size
+ nw, nh = ow // size, oh // size
+ roomx = ow - nw * size
+ roomy = oh - nh * size
+ startx = np.random.randint(int(roomx) + 1)
+ starty = np.random.randint(int(roomy) + 1)
+
+ index = index % (nw * nh)
+ ix = index // nh
+ iy = index % nh
+ gridx = startx + ix * size
+ gridy = starty + iy * size
+ return img.crop((gridx, gridy, gridx + size, gridy + size))
+
+
+def __flip(img, flip):
+ if flip:
+ return img.transpose(Image.FLIP_LEFT_RIGHT)
+ return img
+
+
+def __print_size_warning(ow, oh, w, h):
+ """Print warning information about image size(only print once)"""
+ if not hasattr(__print_size_warning, 'has_printed'):
+ print("The image size needs to be a multiple of 4. "
+ "The loaded image size was (%d, %d), so it was adjusted to "
+ "(%d, %d). This adjustment will be done to all images "
+ "whose sizes are not multiples of 4" % (ow, oh, w, h))
+ __print_size_warning.has_printed = True
diff --git a/data/image_folder.py b/data/image_folder.py
new file mode 100644
index 0000000..2a137d3
--- /dev/null
+++ b/data/image_folder.py
@@ -0,0 +1,66 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+ '.jpg', '.JPG', '.jpeg', '.JPEG',
+ '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+ '.tif', '.TIF', '.tiff', '.TIFF',
+]
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+ images = []
+ assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
+
+ for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+ for fname in fnames:
+ if is_image_file(fname):
+ path = os.path.join(root, fname)
+ images.append(path)
+ return images[:min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+ return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+ def __init__(self, root, transform=None, return_paths=False,
+ loader=default_loader):
+ imgs = make_dataset(root)
+ if len(imgs) == 0:
+ raise(RuntimeError("Found 0 images in: " + root + "\n"
+ "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+ self.root = root
+ self.imgs = imgs
+ self.transform = transform
+ self.return_paths = return_paths
+ self.loader = loader
+
+ def __getitem__(self, index):
+ path = self.imgs[index]
+ img = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.return_paths:
+ return img, path
+ else:
+ return img
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/data/single_dataset.py b/data/single_dataset.py
new file mode 100644
index 0000000..9a5c323
--- /dev/null
+++ b/data/single_dataset.py
@@ -0,0 +1,40 @@
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+
+
+class SingleDataset(BaseDataset):
+ """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
+
+ It can be used for generating CycleGAN results only for one side with the model option '--model test'.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+ self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
+ input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
+ self.transform = get_transform(opt, grayscale=(input_nc == 1))
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index - - a random integer for data indexing
+
+ Returns a dictionary that contains A and A_paths
+ A(tensor) - - an image in one domain
+ A_paths(str) - - the path of the image
+ """
+ A_path = self.A_paths[index]
+ A_img = Image.open(A_path).convert('RGB')
+ A = self.transform(A_img)
+ return {'A': A, 'A_paths': A_path}
+
+ def __len__(self):
+ """Return the total number of images in the dataset."""
+ return len(self.A_paths)
diff --git a/data/singleimage_dataset.py b/data/singleimage_dataset.py
new file mode 100644
index 0000000..0a9f1b5
--- /dev/null
+++ b/data/singleimage_dataset.py
@@ -0,0 +1,108 @@
+import numpy as np
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+
+
+class SingleImageDataset(BaseDataset):
+ """
+ This dataset class can load unaligned/unpaired datasets.
+
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
+ and from domain B '/path/to/data/trainB' respectively.
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ Similarly, you need to prepare two directories:
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+
+ self.dir_A = os.path.join(opt.dataroot, 'trainA') # create a path '/path/to/data/trainA'
+ self.dir_B = os.path.join(opt.dataroot, 'trainB') # create a path '/path/to/data/trainB'
+
+ if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
+ self.A_size = len(self.A_paths) # get the size of dataset A
+ self.B_size = len(self.B_paths) # get the size of dataset B
+
+ assert len(self.A_paths) == 1 and len(self.B_paths) == 1,\
+ "SingleImageDataset class should be used with one image in each domain"
+ A_img = Image.open(self.A_paths[0]).convert('RGB')
+ B_img = Image.open(self.B_paths[0]).convert('RGB')
+ print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))
+
+ self.A_img = A_img
+ self.B_img = B_img
+
+ # In single-image translation, we augment the data loader by applying
+ # random scaling. Still, we design the data loader such that the
+ # amount of scaling is the same within a minibatch. To do this,
+ # we precompute the random scaling values, and repeat them by |batch_size|.
+ A_zoom = 1 / self.opt.random_scale_max
+ zoom_levels_A = np.random.uniform(A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
+ self.zoom_levels_A = np.reshape(np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2])
+
+ B_zoom = 1 / self.opt.random_scale_max
+ zoom_levels_B = np.random.uniform(B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2))
+ self.zoom_levels_B = np.reshape(np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2])
+
+ # While the crop locations are randomized, the negative samples should
+ # not come from the same location. To do this, we precompute the
+ # crop locations with no repetition.
+ self.patch_indices_A = list(range(len(self)))
+ random.shuffle(self.patch_indices_A)
+ self.patch_indices_B = list(range(len(self)))
+ random.shuffle(self.patch_indices_B)
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ A (tensor) -- an image in the input domain
+ B (tensor) -- its corresponding image in the target domain
+ A_paths (str) -- image paths
+ B_paths (str) -- image paths
+ """
+ A_path = self.A_paths[0]
+ B_path = self.B_paths[0]
+ A_img = self.A_img
+ B_img = self.B_img
+
+ # apply image transformation
+ if self.opt.phase == "train":
+ param = {'scale_factor': self.zoom_levels_A[index],
+ 'patch_index': self.patch_indices_A[index],
+ 'flip': random.random() > 0.5}
+
+ transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
+ A = transform_A(A_img)
+
+ param = {'scale_factor': self.zoom_levels_B[index],
+ 'patch_index': self.patch_indices_B[index],
+ 'flip': random.random() > 0.5}
+ transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
+ B = transform_B(B_img)
+ else:
+ transform = get_transform(self.opt, method=Image.BILINEAR)
+ A = transform(A_img)
+ B = transform(B_img)
+
+ return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
+
+ def __len__(self):
+ """ Let's pretend the single image contains 100,000 crops for convenience.
+ """
+ return 100000
diff --git a/data/template_dataset.py b/data/template_dataset.py
new file mode 100644
index 0000000..bfdf16b
--- /dev/null
+++ b/data/template_dataset.py
@@ -0,0 +1,75 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be <dataset_mode>_dataset.py
+The class name should be <Dataset_mode>Dataset.py
+You need to implement the following functions:
+ -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
+ -- <__init__>: Initialize this dataset class.
+ -- <__getitem__>: Return a data point and its metadata information.
+ -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset, get_transform
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+ """A template dataset class for you to implement custom datasets."""
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
+ parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0) # specify dataset-specific default values
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ A few things can be done here.
+ - save the options (have been done in BaseDataset)
+ - get image paths and meta information of the dataset.
+ - define the image transformation.
+ """
+ # save the option and dataset root
+ BaseDataset.__init__(self, opt)
+ # get the image paths of your dataset;
+ self.image_paths = [] # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+ # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
+ self.transform = get_transform(opt)
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index -- a random integer for data indexing
+
+ Returns:
+ a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+ Step 1: get a random image path: e.g., path = self.image_paths[index]
+ Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+ Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
+ Step 4: return a data point as a dictionary.
+ """
+ path = 'temp' # needs to be a string
+ data_A = None # needs to be a tensor
+ data_B = None # needs to be a tensor
+ return {'data_A': data_A, 'data_B': data_B, 'path': path}
+
+ def __len__(self):
+ """Return the total number of images."""
+ return len(self.image_paths)
diff --git a/data/unaligned_dataset.py b/data/unaligned_dataset.py
new file mode 100644
index 0000000..b8df773
--- /dev/null
+++ b/data/unaligned_dataset.py
@@ -0,0 +1,79 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+
+
+class UnalignedDataset(BaseDataset):
+ """
+ This dataset class can load unaligned/unpaired datasets.
+
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
+ and from domain B '/path/to/data/trainB' respectively.
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ Similarly, you need to prepare two directories:
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseDataset.__init__(self, opt)
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
+
+ if opt.phase == "test" and not os.path.exists(self.dir_A) \
+ and os.path.exists(os.path.join(opt.dataroot, "valA")):
+ self.dir_A = os.path.join(opt.dataroot, "valA")
+ self.dir_B = os.path.join(opt.dataroot, "valB")
+
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
+ self.A_size = len(self.A_paths) # get the size of dataset A
+ self.B_size = len(self.B_paths) # get the size of dataset B
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ A (tensor) -- an image in the input domain
+ B (tensor) -- its corresponding image in the target domain
+ A_paths (str) -- image paths
+ B_paths (str) -- image paths
+ """
+ A_path = self.A_paths[index % self.A_size] # make sure index is within the range
+ if self.opt.serial_batches: # make sure index is within the range
+ index_B = index % self.B_size
+ else: # randomize the index for domain B to avoid fixed pairs.
+ index_B = random.randint(0, self.B_size - 1)
+ B_path = self.B_paths[index_B]
+ A_img = Image.open(A_path).convert('RGB')
+ B_img = Image.open(B_path).convert('RGB')
+
+ # Apply image transformation
+ # For FastCUT mode, if in finetuning phase (learning rate is decaying),
+ # do not perform resize-crop data augmentation of CycleGAN.
+# print('current_epoch', self.current_epoch)
+ is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
+ modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
+ transform = get_transform(modified_opt)
+ A = transform(A_img)
+ B = transform(B_img)
+
+ return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
+
+ def __len__(self):
+ """Return the total number of images in the dataset.
+
+ As the two datasets may contain different numbers of images,
+ we take the maximum of the two.
+ """
+ return max(self.A_size, self.B_size)
diff --git a/data/unaligned_double_dataset.py b/data/unaligned_double_dataset.py
new file mode 100644
index 0000000..245984a
--- /dev/null
+++ b/data/unaligned_double_dataset.py
@@ -0,0 +1,100 @@
+import os.path
+from data.base_dataset import BaseDataset, get_transform
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+import torchvision.transforms.functional as TF
+from torchvision.transforms import transforms as tfs
+
+class UnalignedDoubleDataset(BaseDataset):
+ """
+ This dataset class can load unaligned/unpaired datasets.
+
+ It requires two directories to host training images from domain A '/path/to/data/trainA'
+ and from domain B '/path/to/data/trainB' respectively.
+ You can train the model with the dataset flag '--dataroot /path/to/data'.
+ Similarly, you need to prepare two directories:
+ '/path/to/data/testA' and '/path/to/data/testB' during test time.
+ """
+
+ def __init__(self, opt):
+ """Initialize this dataset class.
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ # self.use_resize_crop = opt.use_resize_crop
+ BaseDataset.__init__(self, opt)
+ self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
+ self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
+ self.opt = opt
+ if opt.phase == "test" and not os.path.exists(self.dir_A) \
+ and os.path.exists(os.path.join(opt.dataroot, "valA")):
+ self.dir_A = os.path.join(opt.dataroot, "valA")
+ self.dir_B = os.path.join(opt.dataroot, "valB")
+
+ self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
+ self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
+ self.A_size = len(self.A_paths) # get the size of dataset A
+ self.B_size = len(self.B_paths) # get the size of dataset B
+
+ def __getitem__(self, index):
+ """Return a data point and its metadata information.
+
+ Parameters:
+ index (int) -- a random integer for data indexing
+
+ Returns a dictionary that contains A, B, A_paths and B_paths
+ A (tensor) -- an image in the input domain
+ B (tensor) -- its corresponding image in the target domain
+ A_paths (str) -- image paths
+ B_paths (str) -- image paths
+ """
+ A_path = self.A_paths[index % self.A_size] # make sure index is within the range
+ if self.opt.serial_batches: # make sure index is within the range
+ index_B = index % self.B_size
+ else: # randomize the index for domain B to avoid fixed pairs.
+ index_B = random.randint(0, self.B_size - 1)
+ B_path = self.B_paths[index_B]
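+ # Each loaded sample is assumed to hold two adjacent 256x256 frames stored
+ # side by side; split it into its left and right halves below.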
+ A_img = Image.open(A_path).convert('RGB')
+ A0 = A_img.crop((0,0,256,256))
+ A1 = A_img.crop((256,0,512,256))
+ B_img = Image.open(B_path).convert('RGB')
+ B0 = B_img.crop((0,0,256,256))
+ B1 = B_img.crop((256,0,512,256))
+
+ # Apply image transformation
+ # For FastCUT mode, if in finetuning phase (learning rate is decaying),
+ # do not perform resize-crop data augmentation of CycleGAN.
+# print('current_epoch', self.current_epoch)
+ is_finetuning = self.opt.isTrain and self.current_epoch > self.opt.n_epochs
+ modified_opt = util.copyconf(self.opt, load_size=self.opt.crop_size if is_finetuning else self.opt.load_size)
+
+ resize = tfs.Resize(size=(self.opt.load_size, self.opt.load_size))
+ imgA = resize(A0)
+ param = dict()
+ i, j, h, w = tfs.RandomCrop.get_params(
+ imgA, output_size=(self.opt.crop_size, self.opt.crop_size))
+ param['crop_pos'] = (i, j)
+ transform = get_transform(modified_opt, param)
+ # print(transform)
+ # sys.exit(0)
+ # A = transform(A_img)
+ # B = transform(B_img)
+
+ A0 = transform(A0)
+ B0 = transform(B0)
+ A1 = transform(A1)
+ B1 = transform(B1)
+
+ return {'A0': A0, 'A1': A1, 'B0': B0, 'B1': B1, 'A_paths': A_path, 'B_paths': B_path}
+
+ def __len__(self):
+ """Return the total number of images in the dataset.
+
+ As the two datasets may contain different numbers of images,
+ we take the maximum of the two.
+ """
+ return max(self.A_size, self.B_size)
diff --git a/datasets/bibtex/cityscapes.tex b/datasets/bibtex/cityscapes.tex
new file mode 100644
index 0000000..a87bdbf
--- /dev/null
+++ b/datasets/bibtex/cityscapes.tex
@@ -0,0 +1,6 @@
+@inproceedings{Cordts2016Cityscapes,
+title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
+author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
+booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+year={2016}
+}
diff --git a/datasets/bibtex/facades.tex b/datasets/bibtex/facades.tex
new file mode 100644
index 0000000..08b773e
--- /dev/null
+++ b/datasets/bibtex/facades.tex
@@ -0,0 +1,7 @@
+@INPROCEEDINGS{Tylecek13,
+ author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
+ title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
+ booktitle = {Proc. GCPR},
+ year = {2013},
+ address = {Saarbrucken, Germany},
+}
diff --git a/datasets/bibtex/handbags.tex b/datasets/bibtex/handbags.tex
new file mode 100644
index 0000000..b79710c
--- /dev/null
+++ b/datasets/bibtex/handbags.tex
@@ -0,0 +1,13 @@
+@inproceedings{zhu2016generative,
+ title={Generative Visual Manipulation on the Natural Image Manifold},
+ author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
+ booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
+ year={2016}
+}
+
+@InProceedings{xie15hed,
+ author = {"Xie, Saining and Tu, Zhuowen"},
+ Title = {Holistically-Nested Edge Detection},
+ Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
+ Year = {2015},
+}
diff --git a/datasets/bibtex/shoes.tex b/datasets/bibtex/shoes.tex
new file mode 100644
index 0000000..e67e158
--- /dev/null
+++ b/datasets/bibtex/shoes.tex
@@ -0,0 +1,14 @@
+@InProceedings{fine-grained,
+ author = {A. Yu and K. Grauman},
+ title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
+ booktitle = {Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2014}
+}
+
+@InProceedings{xie15hed,
+ author = {"Xie, Saining and Tu, Zhuowen"},
+ Title = {Holistically-Nested Edge Detection},
+ Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
+ Year = {2015},
+}
diff --git a/datasets/bibtex/transattr.tex b/datasets/bibtex/transattr.tex
new file mode 100644
index 0000000..0585849
--- /dev/null
+++ b/datasets/bibtex/transattr.tex
@@ -0,0 +1,8 @@
+@article {Laffont14,
+ title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes},
+ author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays},
+ journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)},
+ volume = {33},
+ number = {4},
+ year = {2014}
+}
diff --git a/datasets/combine_A_and_B.py b/datasets/combine_A_and_B.py
new file mode 100644
index 0000000..329b1ec
--- /dev/null
+++ b/datasets/combine_A_and_B.py
@@ -0,0 +1,48 @@
+import os
+import numpy as np
+import cv2
+import argparse
+
+parser = argparse.ArgumentParser('create image pairs')
+parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
+parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
+parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
+parser.add_argument('--num_imgs', dest='num_imgs', help='number of images', type=int, default=1000000)
+parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true')
+args = parser.parse_args()
+
+for arg in vars(args):
+ print('[%s] = ' % arg, getattr(args, arg))
+
+splits = os.listdir(args.fold_A)
+
+for sp in splits:
+ img_fold_A = os.path.join(args.fold_A, sp)
+ img_fold_B = os.path.join(args.fold_B, sp)
+ img_list = os.listdir(img_fold_A)
+ if args.use_AB:
+ img_list = [img_path for img_path in img_list if '_A.' in img_path]
+
+ num_imgs = min(args.num_imgs, len(img_list))
+ print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
+ img_fold_AB = os.path.join(args.fold_AB, sp)
+ if not os.path.isdir(img_fold_AB):
+ os.makedirs(img_fold_AB)
+ print('split = %s, number of images = %d' % (sp, num_imgs))
+ for n in range(num_imgs):
+ name_A = img_list[n]
+ path_A = os.path.join(img_fold_A, name_A)
+ if args.use_AB:
+ name_B = name_A.replace('_A.', '_B.')
+ else:
+ name_B = name_A
+ path_B = os.path.join(img_fold_B, name_B)
+ if os.path.isfile(path_A) and os.path.isfile(path_B):
+ name_AB = name_A
+ if args.use_AB:
+ name_AB = name_AB.replace('_A.', '.') # remove _A
+ path_AB = os.path.join(img_fold_AB, name_AB)
+ im_A = cv2.imread(path_A, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
+ im_B = cv2.imread(path_B, 1) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
+ im_AB = np.concatenate([im_A, im_B], 1)
+ cv2.imwrite(path_AB, im_AB)
diff --git a/datasets/detect_cat_face.py b/datasets/detect_cat_face.py
new file mode 100644
index 0000000..13cfd61
--- /dev/null
+++ b/datasets/detect_cat_face.py
@@ -0,0 +1,64 @@
+import cv2
+import os
+import glob
+import argparse
+
+
+def get_file_paths(folder):
+ image_file_paths = []
+ for root, dirs, filenames in os.walk(folder):
+ filenames = sorted(filenames)
+ for filename in filenames:
+ input_path = os.path.abspath(root)
+ file_path = os.path.join(input_path, filename)
+ if filename.endswith('.png') or filename.endswith('.jpg'):
+ image_file_paths.append(file_path)
+
+ break # prevent descending into subfolders
+ return image_file_paths
+
+
+SF = 1.05
+N = 3
+
+
+def detect_cat(img_path, cat_cascade, output_dir, ratio=0.05, border_ratio=0.25):
+ print('processing {}'.format(img_path))
+ output_width = 286
+ img = cv2.imread(img_path)
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ H, W = img.shape[0], img.shape[1]
+ minH = int(H * ratio)
+ minW = int(W * ratio)
+ cats = cat_cascade.detectMultiScale(gray, scaleFactor=SF, minNeighbors=N, minSize=(minH, minW))
+
+ for cat_id, (x, y, w, h) in enumerate(cats):
+ x1 = max(0, x - w * border_ratio)
+ x2 = min(W, x + w * (1 + border_ratio))
+ y1 = max(0, y - h * border_ratio)
+ y2 = min(H, y + h * (1 + border_ratio))
+ img_crop = img[int(y1):int(y2), int(x1):int(x2)]
+ img_name = os.path.basename(img_path)
+ out_path = os.path.join(output_dir, img_name.replace('.jpg', '_cat%d.jpg' % cat_id))
+ print('write', out_path)
+ img_crop = cv2.resize(img_crop, (output_width, output_width), interpolation=cv2.INTER_CUBIC)
+ cv2.imwrite(out_path, img_crop, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='detecting cat faces using opencv detector')
+ parser.add_argument('--input_dir', type=str, help='input image directory')
+ parser.add_argument('--output_dir', type=str, help='which directory to store cropped cat faces')
+ parser.add_argument('--use_ext', action='store_true', help='if use haarcascade_frontalcatface_extended or not')
+ args = parser.parse_args()
+
+ if args.use_ext:
+ cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface_extended.xml')
+ else:
+ cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface.xml')
+ img_paths = get_file_paths(args.input_dir)
+ print('total number of images {} from {}'.format(len(img_paths), args.input_dir))
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+ for img_path in img_paths:
+ detect_cat(img_path, cat_cascade, args.output_dir)
diff --git a/datasets/download_cut_dataset.sh b/datasets/download_cut_dataset.sh
new file mode 100644
index 0000000..d1ff919
--- /dev/null
+++ b/datasets/download_cut_dataset.sh
@@ -0,0 +1,23 @@
+set -ex
+
+FILE=$1
+
+if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" && $FILE != "grumpifycat" ]]; then
+ echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos, grumpifycat"
+ exit 1
+fi
+
+if [[ $FILE == "cityscapes" ]]; then
+ echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
+ echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
+ exit 1
+fi
+
+echo "Specified [$FILE]"
+URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
+ZIP_FILE=./datasets/$FILE.zip
+TARGET_DIR=./datasets/$FILE/
+wget --no-check-certificate -N $URL -O $ZIP_FILE
+mkdir $TARGET_DIR
+unzip $ZIP_FILE -d ./datasets/
+rm $ZIP_FILE
diff --git a/datasets/download_pix2pix_dataset.sh b/datasets/download_pix2pix_dataset.sh
new file mode 100644
index 0000000..a7d09da
--- /dev/null
+++ b/datasets/download_pix2pix_dataset.sh
@@ -0,0 +1,24 @@
+set -ex
+
+FILE=$1
+
+if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
+ echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
+ exit 1
+fi
+
+if [[ $FILE == "cityscapes" ]]; then
+ echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
+ echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
+ exit 1
+fi
+
+echo "Specified [$FILE]"
+
+URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
+TAR_FILE=./datasets/$FILE.tar.gz
+TARGET_DIR=./datasets/$FILE/
+wget -N $URL -O $TAR_FILE
+mkdir -p $TARGET_DIR
+tar -zxvf $TAR_FILE -C ./datasets/
+rm $TAR_FILE
diff --git a/datasets/make_dataset_aligned.py b/datasets/make_dataset_aligned.py
new file mode 100644
index 0000000..739c767
--- /dev/null
+++ b/datasets/make_dataset_aligned.py
@@ -0,0 +1,63 @@
+import os
+
+from PIL import Image
+
+
+def get_file_paths(folder):
+ image_file_paths = []
+ for root, dirs, filenames in os.walk(folder):
+ filenames = sorted(filenames)
+ for filename in filenames:
+ input_path = os.path.abspath(root)
+ file_path = os.path.join(input_path, filename)
+ if filename.endswith('.png') or filename.endswith('.jpg'):
+ image_file_paths.append(file_path)
+
+ break # prevent descending into subfolders
+ return image_file_paths
+
+
+def align_images(a_file_paths, b_file_paths, target_path):
+ if not os.path.exists(target_path):
+ os.makedirs(target_path)
+
+ for i in range(len(a_file_paths)):
+ img_a = Image.open(a_file_paths[i])
+ img_b = Image.open(b_file_paths[i])
+ assert(img_a.size == img_b.size)
+
+ aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
+ aligned_image.paste(img_a, (0, 0))
+ aligned_image.paste(img_b, (img_a.size[0], 0))
+ aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--dataset-path',
+ dest='dataset_path',
+ help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)'
+ )
+ args = parser.parse_args()
+
+ dataset_folder = args.dataset_path
+ print(dataset_folder)
+
+ test_a_path = os.path.join(dataset_folder, 'testA')
+ test_b_path = os.path.join(dataset_folder, 'testB')
+ test_a_file_paths = get_file_paths(test_a_path)
+ test_b_file_paths = get_file_paths(test_b_path)
+ assert(len(test_a_file_paths) == len(test_b_file_paths))
+ test_path = os.path.join(dataset_folder, 'test')
+
+ train_a_path = os.path.join(dataset_folder, 'trainA')
+ train_b_path = os.path.join(dataset_folder, 'trainB')
+ train_a_file_paths = get_file_paths(train_a_path)
+ train_b_file_paths = get_file_paths(train_b_path)
+ assert(len(train_a_file_paths) == len(train_b_file_paths))
+ train_path = os.path.join(dataset_folder, 'train')
+
+ align_images(test_a_file_paths, test_b_file_paths, test_path)
+ align_images(train_a_file_paths, train_b_file_paths, train_path)
diff --git a/datasets/prepare_cityscapes_dataset.py b/datasets/prepare_cityscapes_dataset.py
new file mode 100644
index 0000000..2ff21af
--- /dev/null
+++ b/datasets/prepare_cityscapes_dataset.py
@@ -0,0 +1,90 @@
+import os
+import glob
+from PIL import Image
+
+help_msg = """
+The dataset can be downloaded from https://cityscapes-dataset.com.
+Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them.
+gtFine contains the semantic segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory.
+leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
+The processed images will be placed at --output_dir.
+
+Example usage:
+
+python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/
+"""
+
+
+def load_resized_img(path):
+ return Image.open(path).convert('RGB').resize((256, 256))
+
+
+def check_matching_pair(segmap_path, photo_path):
+ segmap_identifier = os.path.basename(segmap_path).replace('_gtFine_color', '')
+ photo_identifier = os.path.basename(photo_path).replace('_leftImg8bit', '')
+
+ assert segmap_identifier == photo_identifier, \
+ "[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path)
+
+
+def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
+ save_phase = 'test' if phase == 'val' else 'train'
+ savedir = os.path.join(output_dir, save_phase)
+ os.makedirs(savedir, exist_ok=True)
+ os.makedirs(savedir + 'A', exist_ok=True)
+ os.makedirs(savedir + 'B', exist_ok=True)
+ print("Directory structure prepared at %s" % output_dir)
+
+ segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
+ segmap_paths = glob.glob(segmap_expr)
+ segmap_paths = sorted(segmap_paths)
+
+ photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
+ photo_paths = glob.glob(photo_expr)
+ photo_paths = sorted(photo_paths)
+
+ assert len(segmap_paths) == len(photo_paths), \
+ "%d images that match [%s], and %d images that match [%s]. Aborting." % (len(segmap_paths), segmap_expr, len(photo_paths), photo_expr)
+
+ for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
+ check_matching_pair(segmap_path, photo_path)
+ segmap = load_resized_img(segmap_path)
+ photo = load_resized_img(photo_path)
+
+ # data for pix2pix where the two images are placed side-by-side
+ sidebyside = Image.new('RGB', (512, 256))
+ sidebyside.paste(segmap, (256, 0))
+ sidebyside.paste(photo, (0, 0))
+ savepath = os.path.join(savedir, "%d.jpg" % i)
+ sidebyside.save(savepath, format='JPEG', subsampling=0, quality=100)
+
+ # data for cyclegan where the two images are stored at two distinct directories
+ savepath = os.path.join(savedir + 'A', "%d_A.jpg" % i)
+ photo.save(savepath, format='JPEG', subsampling=0, quality=100)
+ savepath = os.path.join(savedir + 'B', "%d_B.jpg" % i)
+ segmap.save(savepath, format='JPEG', subsampling=0, quality=100)
+
+ if i % (len(segmap_paths) // 10) == 0:
+ print("%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath))
+
+
+if __name__ == '__main__':
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--gtFine_dir', type=str, required=True,
+ help='Path to the Cityscapes gtFine directory.')
+ parser.add_argument('--leftImg8bit_dir', type=str, required=True,
+ help='Path to the Cityscapes leftImg8bit_trainvaltest directory.')
+ parser.add_argument('--output_dir', type=str, required=True,
+ default='./datasets/cityscapes',
+ help='Directory the output images will be written to.')
+ opt = parser.parse_args()
+
+ print(help_msg)
+
+ print('Preparing Cityscapes Dataset for val phase')
+ process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
+ print('Preparing Cityscapes Dataset for train phase')
+ process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
+
+ print('Done')
diff --git a/datasets/single_image_monet_etretat/trainA/monet.jpg b/datasets/single_image_monet_etretat/trainA/monet.jpg
new file mode 100644
index 0000000..738c1cd
Binary files /dev/null and b/datasets/single_image_monet_etretat/trainA/monet.jpg differ
diff --git a/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg b/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg
new file mode 100644
index 0000000..41aabf6
Binary files /dev/null and b/datasets/single_image_monet_etretat/trainB/etretat-normandy-france.jpg differ
diff --git a/images/method_final.jpg b/images/method_final.jpg
new file mode 100644
index 0000000..c443960
Binary files /dev/null and b/images/method_final.jpg differ
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..fc01113
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+ This function instantiates the model class specified by the option '--model'.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
diff --git a/models/__pycache__/__init__.cpython-36.pyc b/models/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..111de2f
Binary files /dev/null and b/models/__pycache__/__init__.cpython-36.pyc differ
diff --git a/models/__pycache__/base_model.cpython-36.pyc b/models/__pycache__/base_model.cpython-36.pyc
new file mode 100644
index 0000000..9927d69
Binary files /dev/null and b/models/__pycache__/base_model.cpython-36.pyc differ
diff --git a/models/__pycache__/cut_model.cpython-36.pyc b/models/__pycache__/cut_model.cpython-36.pyc
new file mode 100644
index 0000000..5a8ac7f
Binary files /dev/null and b/models/__pycache__/cut_model.cpython-36.pyc differ
diff --git a/models/__pycache__/mae.cpython-36.pyc b/models/__pycache__/mae.cpython-36.pyc
new file mode 100644
index 0000000..fd29e4a
Binary files /dev/null and b/models/__pycache__/mae.cpython-36.pyc differ
diff --git a/models/__pycache__/models_mae.cpython-36.pyc b/models/__pycache__/models_mae.cpython-36.pyc
new file mode 100644
index 0000000..7e36cbe
Binary files /dev/null and b/models/__pycache__/models_mae.cpython-36.pyc differ
diff --git a/models/__pycache__/mutilvitgloballocal_model.cpython-36.pyc b/models/__pycache__/mutilvitgloballocal_model.cpython-36.pyc
new file mode 100644
index 0000000..0efc91c
Binary files /dev/null and b/models/__pycache__/mutilvitgloballocal_model.cpython-36.pyc differ
diff --git a/models/__pycache__/networks.cpython-36.pyc b/models/__pycache__/networks.cpython-36.pyc
new file mode 100644
index 0000000..b3e43d1
Binary files /dev/null and b/models/__pycache__/networks.cpython-36.pyc differ
diff --git a/models/__pycache__/patchnce.cpython-36.pyc b/models/__pycache__/patchnce.cpython-36.pyc
new file mode 100644
index 0000000..daec1cc
Binary files /dev/null and b/models/__pycache__/patchnce.cpython-36.pyc differ
diff --git a/models/__pycache__/region0_model.cpython-36.pyc b/models/__pycache__/region0_model.cpython-36.pyc
new file mode 100644
index 0000000..fd2ae28
Binary files /dev/null and b/models/__pycache__/region0_model.cpython-36.pyc differ
diff --git a/models/__pycache__/region_model.cpython-36.pyc b/models/__pycache__/region_model.cpython-36.pyc
new file mode 100644
index 0000000..fd86bf6
Binary files /dev/null and b/models/__pycache__/region_model.cpython-36.pyc differ
diff --git a/models/__pycache__/stylegan_networks.cpython-36.pyc b/models/__pycache__/stylegan_networks.cpython-36.pyc
new file mode 100644
index 0000000..f046dc1
Binary files /dev/null and b/models/__pycache__/stylegan_networks.cpython-36.pyc differ
diff --git a/models/__pycache__/vit2Gmask_model.cpython-36.pyc b/models/__pycache__/vit2Gmask_model.cpython-36.pyc
new file mode 100644
index 0000000..f9e9017
Binary files /dev/null and b/models/__pycache__/vit2Gmask_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vit2_model.cpython-36.pyc b/models/__pycache__/vit2_model.cpython-36.pyc
new file mode 100644
index 0000000..860549e
Binary files /dev/null and b/models/__pycache__/vit2_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vit2patchmask_model.cpython-36.pyc b/models/__pycache__/vit2patchmask_model.cpython-36.pyc
new file mode 100644
index 0000000..2186fcf
Binary files /dev/null and b/models/__pycache__/vit2patchmask_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vit2tokenmask_model.cpython-36.pyc b/models/__pycache__/vit2tokenmask_model.cpython-36.pyc
new file mode 100644
index 0000000..a9a67f4
Binary files /dev/null and b/models/__pycache__/vit2tokenmask_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vitD_model.cpython-36.pyc b/models/__pycache__/vitD_model.cpython-36.pyc
new file mode 100644
index 0000000..43e653f
Binary files /dev/null and b/models/__pycache__/vitD_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vit_model.cpython-36.pyc b/models/__pycache__/vit_model.cpython-36.pyc
new file mode 100644
index 0000000..c39955e
Binary files /dev/null and b/models/__pycache__/vit_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vitdonly2_model.cpython-36.pyc b/models/__pycache__/vitdonly2_model.cpython-36.pyc
new file mode 100644
index 0000000..fed8f18
Binary files /dev/null and b/models/__pycache__/vitdonly2_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vitdonly_model.cpython-36.pyc b/models/__pycache__/vitdonly_model.cpython-36.pyc
new file mode 100644
index 0000000..1d39131
Binary files /dev/null and b/models/__pycache__/vitdonly_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vitgloballocal_model.cpython-36.pyc b/models/__pycache__/vitgloballocal_model.cpython-36.pyc
new file mode 100644
index 0000000..f071980
Binary files /dev/null and b/models/__pycache__/vitgloballocal_model.cpython-36.pyc differ
diff --git a/models/__pycache__/vitlocalgloballocal_model.cpython-36.pyc b/models/__pycache__/vitlocalgloballocal_model.cpython-36.pyc
new file mode 100644
index 0000000..a6b59ad
Binary files /dev/null and b/models/__pycache__/vitlocalgloballocal_model.cpython-36.pyc differ
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100644
index 0000000..37bc25f
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,258 @@
+import os
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call <BaseModel.__init__(self, opt)>.
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.gpu_ids = opt.gpu_ids
+ self.isTrain = opt.isTrain
+ self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+ torch.backends.cudnn.benchmark = True
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def dict_grad_hook_factory(add_func=lambda x: x):
+ saved_dict = dict()
+
+ def hook_gen(name):
+ def grad_hook(grad):
+ saved_vals = add_func(grad)
+ saved_dict[name] = saved_vals
+ return grad_hook
+ return hook_gen, saved_dict
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+ if not self.isTrain or opt.continue_train:
+ load_suffix = opt.epoch
+ self.load_networks(load_suffix)
+
+ self.print_networks(opt.verbose)
+
+ def parallelize(self):
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))
+
+ def data_dependent_initialize(self, data):
+ pass
+
+ def eval(self):
+ """Make models eval mode during test time"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
+ It also calls <compute_visuals> to produce additional visualization results
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ save_filename = '%s_net_%s.pth' % (epoch, name)
+ save_path = os.path.join(self.save_dir, save_filename)
+ net = getattr(self, 'net' + name)
+
+ if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+ torch.save(net.module.cpu().state_dict(), save_path)
+ net.cuda(self.gpu_ids[0])
+ else:
+ torch.save(net.cpu().state_dict(), save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ load_filename = '%s_net_%s.pth' % (epoch, name)
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+ else:
+ load_dir = self.save_dir
+
+ load_path = os.path.join(load_dir, load_filename)
+ net = getattr(self, 'net' + name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ print('loading the model from %s' % load_path)
+ # if you are using PyTorch newer than 0.4 (e.g., built from
+ # GitHub source), you can remove str() on self.device
+ state_dict = torch.load(load_path, map_location=str(self.device))
+ if hasattr(state_dict, '_metadata'):
+ del state_dict._metadata
+
+ # patch InstanceNorm checkpoints prior to 0.4
+ # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
+ # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+ net.load_state_dict(state_dict)
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, 'net' + name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ return {}
diff --git a/models/cut_model.py b/models/cut_model.py
new file mode 100644
index 0000000..cd4a191
--- /dev/null
+++ b/models/cut_model.py
@@ -0,0 +1,214 @@
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+from .patchnce import PatchNCELoss
+import util.util as util
+
+
+class CUTModel(BaseModel):
+ """ This class implements CUT and FastCUT model, described in the paper
+ Contrastive Learning for Unpaired Image-to-Image Translation
+ Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
+ ECCV, 2020
+
+ The code borrows heavily from the PyTorch implementation of CycleGAN
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
+ """
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """ Configures options specific for CUT model
+ """
+ parser.add_argument('--CUT_mode', type=str, default="CUT", choices='(CUT, cut, FastCUT, fastcut)')
+
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
+ parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
+ parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
+ parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
+ parser.add_argument('--nce_includes_all_negatives_from_minibatch',
+ type=util.str2bool, nargs='?', const=True, default=False,
+ help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
+ parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
+ parser.add_argument('--netF_nc', type=int, default=256)
+ parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
+ parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
+ parser.add_argument('--flip_equivariance',
+ type=util.str2bool, nargs='?', const=True, default=False,
+ help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
+
+ parser.set_defaults(pool_size=0) # no image pooling
+
+ opt, _ = parser.parse_known_args()
+
+ # Set default parameters for CUT and FastCUT
+ if opt.CUT_mode.lower() == "cut":
+ parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
+ elif opt.CUT_mode.lower() == "fastcut":
+ parser.set_defaults(
+ nce_idt=False, lambda_NCE=10.0, flip_equivariance=True,
+ n_epochs=150, n_epochs_decay=50
+ )
+ else:
+ raise ValueError(opt.CUT_mode)
+
+ return parser
+
+ def __init__(self, opt):
+ BaseModel.__init__(self, opt)
+
+ # specify the training losses you want to print out.
+ # The training/test scripts will call <BaseModel.get_current_losses>
+ self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE']
+ self.visual_names = ['real_A', 'fake_B', 'real_B']
+ self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
+
+ if opt.nce_idt and self.isTrain:
+ self.loss_names += ['NCE_Y']
+ self.visual_names += ['idt_B']
+
+ if self.isTrain:
+ self.model_names = ['G', 'F', 'D']
+ else: # during test time, only load G
+ self.model_names = ['G']
+
+ # define networks (both generator and discriminator)
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
+ self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+
+ if self.isTrain:
+ self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+ self.criterionNCE = []
+
+ for nce_layer in self.nce_layers:
+ self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
+
+ self.criterionIdt = torch.nn.L1Loss().to(self.device)
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D)
+
+ def data_dependent_initialize(self, data):
+ """
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
+ features of the encoder portion of netG. Because of this, the weights of netF are
+ initialized at the first feedforward pass with some input images.
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
+ """
+ self.set_input(data)
+ bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
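+ # run this warm-up pass on a single GPU's share of the batch; it only exists to build netF's layers
+ # (see PatchSampleF.create_mlp) and to register optimizer_F afterwards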
+ self.real_A = self.real_A[:bs_per_gpu]
+ self.real_B = self.real_B[:bs_per_gpu]
+ self.forward() # compute fake images: G(A)
+ if self.opt.isTrain:
+ self.compute_D_loss().backward() # calculate gradients for D
+ self.compute_G_loss().backward() # calculate gradients for G
+ if self.opt.lambda_NCE > 0.0:
+ self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
+ self.optimizers.append(self.optimizer_F)
+
+ def optimize_parameters(self):
+ # forward
+ self.forward()
+
+ # update D
+ self.set_requires_grad(self.netD, True)
+ self.optimizer_D.zero_grad()
+ self.loss_D = self.compute_D_loss()
+ self.loss_D.backward()
+ self.optimizer_D.step()
+
+ # update G
+ self.set_requires_grad(self.netD, False)
+ self.optimizer_G.zero_grad()
+ if self.opt.netF == 'mlp_sample':
+ self.optimizer_F.zero_grad()
+ self.loss_G = self.compute_G_loss()
+ self.loss_G.backward()
+ self.optimizer_G.step()
+ if self.opt.netF == 'mlp_sample':
+ self.optimizer_F.step()
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+ The option 'direction' can be used to swap domain A and domain B.
+ """
+ AtoB = self.opt.direction == 'AtoB'
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
+ if self.opt.flip_equivariance:
+ self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
+ if self.flipped_for_equivariance:
+ self.real = torch.flip(self.real, [3])
+
+ self.fake = self.netG(self.real)
+ self.fake_B = self.fake[:self.real_A.size(0)]
+ if self.opt.nce_idt:
+ self.idt_B = self.fake[self.real_A.size(0):]
+
+ def compute_D_loss(self):
+ """Calculate GAN loss for the discriminator"""
+ fake = self.fake_B.detach()
+ # Fake; stop backprop to the generator by detaching fake_B
+ pred_fake = self.netD(fake)
+ self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
+ # Real
+ self.pred_real = self.netD(self.real_B)
+ loss_D_real = self.criterionGAN(self.pred_real, True)
+ self.loss_D_real = loss_D_real.mean()
+
+ # combine loss and calculate gradients
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+ return self.loss_D
+
+ def compute_G_loss(self):
+ """Calculate GAN and NCE loss for the generator"""
+ fake = self.fake_B
+ # First, G(A) should fake the discriminator
+ if self.opt.lambda_GAN > 0.0:
+ pred_fake = self.netD(fake)
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
+ else:
+ self.loss_G_GAN = 0.0
+
+ if self.opt.lambda_NCE > 0.0:
+ self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
+ else:
+ self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
+
+ if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
+ self.loss_NCE_Y = self.calculate_NCE_loss(self.real_B, self.idt_B)
+ loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y) * 0.5
+ else:
+ loss_NCE_both = self.loss_NCE
+
+ self.loss_G = self.loss_G_GAN + loss_NCE_both
+ return self.loss_G
+
+ def calculate_NCE_loss(self, src, tgt):
+ n_layers = len(self.nce_layers)
+ feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
+
+ if self.opt.flip_equivariance and self.flipped_for_equivariance:
+ feat_q = [torch.flip(fq, [3]) for fq in feat_q]
+
+ feat_k = self.netG(src, self.nce_layers, encode_only=True)
+ feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
+ feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
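+ # the sample ids returned for the key (source) features are passed back in for the query (translated) features,
+ # so positive pairs come from matching spatial locations, as in the original CUT implementation;
+ # the remaining sampled patches act as negatives inside PatchNCELoss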
+
+ total_nce_loss = 0.0
+ for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
+ loss = crit(f_q, f_k) * self.opt.lambda_NCE
+ total_nce_loss += loss.mean()
+
+ return total_nce_loss / n_layers
diff --git a/models/cycle_gan_model.py b/models/cycle_gan_model.py
new file mode 100644
index 0000000..0e0874b
--- /dev/null
+++ b/models/cycle_gan_model.py
@@ -0,0 +1,222 @@
+import torch
+import itertools
+from util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+try:
+ from apex import amp
+except ImportError as error:
+ print(error)
+
+
+class CycleGANModel(BaseModel):
+ """
+ This class implements the CycleGAN model, for learning image-to-image translation without paired data.
+
+ The model training requires '--dataset_mode unaligned' dataset.
+ By default, it uses a '--netG resnet_9blocks' ResNet generator,
+ a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
+ and a least-square GANs objective ('--gan_mode lsgan').
+
+ CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
+ """
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new dataset-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+
+ For CycleGAN, in addition to GAN losses, we introduce lambda_A, lambda_B, and lambda_identity for the following losses.
+ A (source domain), B (target domain).
+ Generators: G_A: A -> B; G_B: B -> A.
+ Discriminators: D_A: G_A(A) vs. B; D_B: G_B(B) vs. A.
+ Forward cycle loss: lambda_A * ||G_B(G_A(A)) - A|| (Eqn. (2) in the paper)
+ Backward cycle loss: lambda_B * ||G_A(G_B(B)) - B|| (Eqn. (2) in the paper)
+ Identity loss (optional): lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A) (Sec 5.2 "Photo generation from paintings" in the paper)
+ Dropout is not used in the original CycleGAN paper.
+ """
+ # parser.set_defaults(no_dropout=True, no_antialias=True, no_antialias_up=True) # default CycleGAN did not use dropout
+ # parser.set_defaults(no_dropout=True)
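+ # Full generator objective assembled in backward_G():
+ #   L_G = L_GAN(G_A) + L_GAN(G_B) + lambda_A*||G_B(G_A(A)) - A|| + lambda_B*||G_A(G_B(B)) - B||
+ #         + lambda_identity*(lambda_B*||G_A(B) - B|| + lambda_A*||G_B(A) - A||)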
+ if is_train:
+ parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
+ parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
+ parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize the CycleGAN class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ BaseModel.__init__(self, opt)
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+ self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+ visual_names_A = ['real_A', 'fake_B', 'rec_A']
+ visual_names_B = ['real_B', 'fake_A', 'rec_B']
+ if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) and idt_A=G_B(A)
+ visual_names_A.append('idt_B')
+ visual_names_B.append('idt_A')
+
+ self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
+ if self.isTrain:
+ self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
+ else: # during test time, only load Gs
+ self.model_names = ['G_A', 'G_B']
+
+ # define networks (both Generators and discriminators)
+ # The naming is different from those used in the paper.
+ # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
+ self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG,
+ not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt=opt)
+ self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.normG,
+ not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt=opt)
+
+ if self.isTrain: # define discriminators
+ self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt=opt)
+ self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
+ opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt=opt)
+
+ if self.isTrain:
+ if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
+ assert(opt.input_nc == opt.output_nc)
+ self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
+ self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
+ self.criterionCycle = torch.nn.L1Loss()
+ self.criterionIdt = torch.nn.L1Loss()
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+ self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D)
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+
+ The option 'direction' can be used to swap domain A and domain B.
+ """
+ AtoB = self.opt.direction == 'AtoB'
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ self.fake_B = self.netG_A(self.real_A) # G_A(A)
+ self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
+ self.fake_A = self.netG_B(self.real_B) # G_B(B)
+ self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
+
+ def backward_D_basic(self, netD, real, fake):
+ """Calculate GAN loss for the discriminator
+
+ Parameters:
+ netD (network) -- the discriminator D
+ real (tensor array) -- real images
+ fake (tensor array) -- images generated by a generator
+
+ Return the discriminator loss.
+ We also call loss_D.backward() to calculate the gradients.
+ """
+ # Real
+ pred_real = netD(real)
+ loss_D_real = self.criterionGAN(pred_real, True)
+ # Fake
+ pred_fake = netD(fake.detach())
+ loss_D_fake = self.criterionGAN(pred_fake, False)
+ # Combined loss and calculate gradients
+ loss_D = (loss_D_real + loss_D_fake) * 0.5
+ if self.opt.amp:
+ with amp.scale_loss(loss_D, self.optimizer_D) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ loss_D.backward()
+ return loss_D
+
+ def backward_D_A(self):
+ """Calculate GAN loss for discriminator D_A"""
+ fake_B = self.fake_B_pool.query(self.fake_B)
+ self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
+
+ def backward_D_B(self):
+ """Calculate GAN loss for discriminator D_B"""
+ fake_A = self.fake_A_pool.query(self.fake_A)
+ self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
+
+ def backward_G(self):
+ """Calculate the loss for generators G_A and G_B"""
+ lambda_idt = self.opt.lambda_identity
+ lambda_A = self.opt.lambda_A
+ lambda_B = self.opt.lambda_B
+ # Identity loss
+ if lambda_idt > 0:
+ # G_A should be identity if real_B is fed: ||G_A(B) - B||
+ self.idt_A = self.netG_A(self.real_B)
+ self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+ # G_B should be identity if real_A is fed: ||G_B(A) - A||
+ self.idt_B = self.netG_B(self.real_A)
+ self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+ else:
+ self.loss_idt_A = 0
+ self.loss_idt_B = 0
+
+ # GAN loss D_A(G_A(A))
+ self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
+ # GAN loss D_B(G_B(B))
+ self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
+ # Forward cycle loss || G_B(G_A(A)) - A||
+ self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
+ # Backward cycle loss || G_A(G_B(B)) - B||
+ self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
+ # combined loss and calculate gradients
+ self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+ if self.opt.amp:
+ with amp.scale_loss(self.loss_G, self.optimizer_G) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ self.loss_G.backward()
+
+ def data_dependent_initialize(self):
+ return
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ with torch.no_grad():
+ visuals = {}
+ AtoB = self.opt.direction == "AtoB"
+ G = self.netG_A
+ source = data["A" if AtoB else "B"].to(self.device)
+ if mode == "forward":
+ visuals["fake_B"] = G(source)
+ else:
+ raise ValueError("mode %s is not recognized" % mode)
+ return visuals
+
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ # forward
+ self.forward() # compute fake images and reconstruction images.
+ # G_A and G_B
+ self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
+ self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
+ self.backward_G() # calculate gradients for G_A and G_B
+ self.optimizer_G.step() # update G_A and G_B's weights
+ # D_A and D_B
+ self.set_requires_grad([self.netD_A, self.netD_B], True)
+ self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
+ self.backward_D_A() # calculate gradients for D_A
+ self.backward_D_B() # calculate gradients for D_B
+ self.optimizer_D.step() # update D_A and D_B's weights
diff --git a/models/networks.py b/models/networks.py
new file mode 100644
index 0000000..933f792
--- /dev/null
+++ b/models/networks.py
@@ -0,0 +1,1530 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import numpy as np
+import random
+from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
+
+###############################################################################
+# Helper Functions
+###############################################################################
+
+
+def get_filter(filt_size=3):
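+ # 1-D binomial coefficients (rows of Pascal's triangle); their outer product below gives the 2-D blur kernel
+ # used for anti-aliased down/upsampling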
+ if(filt_size == 1):
+ a = np.array([1., ])
+ elif(filt_size == 2):
+ a = np.array([1., 1.])
+ elif(filt_size == 3):
+ a = np.array([1., 2., 1.])
+ elif(filt_size == 4):
+ a = np.array([1., 3., 3., 1.])
+ elif(filt_size == 5):
+ a = np.array([1., 4., 6., 4., 1.])
+ elif(filt_size == 6):
+ a = np.array([1., 5., 10., 10., 5., 1.])
+ elif(filt_size == 7):
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
+
+ filt = torch.Tensor(a[:, None] * a[None, :])
+ filt = filt / torch.sum(filt)
+
+ return filt
+
+
+class Downsample(nn.Module):
+ def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
+ super(Downsample, self).__init__()
+ self.filt_size = filt_size
+ self.pad_off = pad_off
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+ self.stride = stride
+ self.off = int((self.stride - 1) / 2.)
+ self.channels = channels
+
+ filt = get_filter(filt_size=self.filt_size)
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+ def forward(self, inp):
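+ # filt_size == 1 reduces to plain strided subsampling; otherwise blur with the binomial kernel, then subsample via stride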
+ if(self.filt_size == 1):
+ if(self.pad_off == 0):
+ return inp[:, :, ::self.stride, ::self.stride]
+ else:
+ return self.pad(inp)[:, :, ::self.stride, ::self.stride]
+ else:
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+
+class Upsample2(nn.Module):
+ def __init__(self, scale_factor, mode='nearest'):
+ super().__init__()
+ self.factor = scale_factor
+ self.mode = mode
+
+ def forward(self, x):
+ return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
+
+
+class Upsample(nn.Module):
+ def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
+ super(Upsample, self).__init__()
+ self.filt_size = filt_size
+ self.filt_odd = np.mod(filt_size, 2) == 1
+ self.pad_size = int((filt_size - 1) / 2)
+ self.stride = stride
+ self.off = int((self.stride - 1) / 2.)
+ self.channels = channels
+
+ filt = get_filter(filt_size=self.filt_size) * (stride**2)
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
+
+ self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
+
+ def forward(self, inp):
+ ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
+ if(self.filt_odd):
+ return ret_val
+ else:
+ return ret_val[:, :, :-1, :-1]
+
+
+def get_pad_layer(pad_type):
+ if(pad_type in ['refl', 'reflect']):
+ PadLayer = nn.ReflectionPad2d
+ elif(pad_type in ['repl', 'replicate']):
+ PadLayer = nn.ReplicationPad2d
+ elif(pad_type == 'zero'):
+ PadLayer = nn.ZeroPad2d
+ else:
+ print('Pad type [%s] not recognized' % pad_type)
+ return PadLayer
+
+
+class Identity(nn.Module):
+ def forward(self, x):
+ return x
+
+
+def get_norm_layer(norm_type='instance'):
+ """Return a normalization layer
+
+ Parameters:
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
+
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
+ """
+ if norm_type == 'batch':
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
+ elif norm_type == 'instance':
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+ elif norm_type == 'none':
+ def norm_layer(x):
+ return Identity()
+ else:
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
+ return norm_layer
+
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For 'linear', we keep the same learning rate for the first epochs
+ and linearly decay the rate to zero over the next epochs.
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
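+ # multiplier on the initial lr: stays at 1.0 for the first n_epochs, then decays linearly to 0 over n_epochs_decay epochs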
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+ return scheduler
+
+
+def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
+ """Initialize network weights.
+
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
+ work better for some applications. Feel free to try yourself.
+ """
+ def init_func(m): # define the initialization function
+ classname = m.__class__.__name__
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+ if debug:
+ print(classname)
+ if init_type == 'normal':
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == 'xavier':
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == 'kaiming':
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ else:
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
+ init.normal_(m.weight.data, 1.0, init_gain)
+ init.constant_(m.bias.data, 0.0)
+
+ net.apply(init_func) # apply the initialization function
+
+
+def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
+ Parameters:
+ net (network) -- the network to be initialized
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Return an initialized network.
+ """
+ if len(gpu_ids) > 0:
+ assert(torch.cuda.is_available())
+ net.to(gpu_ids[0])
+ # if not amp:
+ # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
+ if initialize_weights:
+ init_weights(net, init_type, init_gain=init_gain, debug=debug)
+ return net
+
+
+def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
+ init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
+ """Create a generator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
+ use_dropout (bool) -- if use dropout layers.
+ init_type (str) -- the name of our initialization method.
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Returns a generator
+
+ Our current implementation provides two types of generators:
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
+
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
+
+
+ The generator has been initialized by <init_net>. It uses RELU for non-linearity.
+ """
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if netG == 'resnet_9blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
+ elif netG == 'resnet_9blocks_mask':
+ net = ResnetGeneratorMask(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
+ elif netG == 'resnet_6blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
+ elif netG == 'resnet_4blocks':
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
+ elif netG == 'unet_128':
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'unet_256':
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+ elif netG == 'stylegan2':
+ net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
+ elif netG == 'smallstylegan2':
+ net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
+ elif netG == 'resnet_cat':
+ n_blocks = 8
+ net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
+ else:
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
+ return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
+
+
+def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
+ if netF == 'global_pool':
+ net = PoolingF()
+ elif netF == 'reshape':
+ net = ReshapeF()
+ elif netF == 'sample':
+ net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
+ elif netF == 'mlp_sample':
+ net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
+ elif netF == 'strided_conv':
+ net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
+ else:
+ raise NotImplementedError('projection model name [%s] is not recognized' % netF)
+ return init_net(net, init_type, init_gain, gpu_ids)
+
+
+def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
+ """Create a discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the first conv layer
+ netD (str) -- the architecture's name: basic | n_layers | pixel
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
+ norm (str) -- the type of normalization layers used in the network.
+ init_type (str) -- the name of the initialization method.
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
+
+ Returns a discriminator
+
+ Our current implementation provides three types of discriminators:
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
+ It can classify whether 70×70 overlapping patches are real or fake.
+ Such a patch-level discriminator architecture has fewer parameters
+ than a full-image discriminator and can work on arbitrarily-sized images
+ in a fully convolutional fashion.
+
+ [n_layers]: With this mode, you can specify the number of conv layers in the discriminator
+ with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
+
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
+ It encourages greater color diversity but has no effect on spatial statistics.
+
+ The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
+ """
+ net = None
+ norm_layer = get_norm_layer(norm_type=norm)
+
+ if netD == 'basic': # default PatchGAN classifier
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias,)
+ elif netD == 'n_layers': # more options
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
+ elif netD == 'pixel': # classify if each pixel is real or fake
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
+ elif 'stylegan2' in netD:
+ net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
+ else:
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
+ return init_net(net, init_type, init_gain, gpu_ids,
+ initialize_weights=('stylegan2' not in netD))
+
+
+##############################################################################
+# Classes
+##############################################################################
+class GANLoss(nn.Module):
+ """Define different GAN objectives.
+
+ The GANLoss class abstracts away the need to create the target label tensor
+ that has the same size as the input.
+ """
+
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+ """ Initialize the GANLoss class.
+
+ Parameters:
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
+ target_real_label (bool) - - label for a real image
+ target_fake_label (bool) - - label of a fake image
+
+ Note: Do not use sigmoid as the last layer of Discriminator.
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
+ """
+ super(GANLoss, self).__init__()
+ self.register_buffer('real_label', torch.tensor(target_real_label))
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
+ self.gan_mode = gan_mode
+ if gan_mode == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif gan_mode == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif gan_mode in ['wgangp', 'nonsaturating']:
+ self.loss = None
+ else:
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
+
+ def get_target_tensor(self, prediction, target_is_real):
+ """Create label tensors with the same size as the input.
+
+ Parameters:
+ prediction (tensor) - - typically the prediction from a discriminator
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+ Returns:
+ A label tensor filled with ground truth label, and with the size of the input
+ """
+
+ if target_is_real:
+ target_tensor = self.real_label
+ else:
+ target_tensor = self.fake_label
+ return target_tensor.expand_as(prediction)
+
+ def __call__(self, prediction, target_is_real):
+ """Calculate loss given Discriminator's output and grount truth labels.
+
+ Parameters:
+ prediction (tensor) - - typically the prediction output from a discriminator
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
+
+ Returns:
+ the calculated loss.
+ """
+ bs = prediction.size(0)
+ if self.gan_mode in ['lsgan', 'vanilla']:
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
+ # print(prediction.shape, target_is_real.shape)
+ loss = self.loss(prediction, target_tensor)
+ elif self.gan_mode == 'wgangp':
+ if target_is_real:
+ loss = -prediction.mean()
+ else:
+ loss = prediction.mean()
+ elif self.gan_mode == 'nonsaturating':
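+ # non-saturating GAN loss: softplus(-D(x)) for real targets, softplus(D(x)) for fakes, averaged per sample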
+ if target_is_real:
+ loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
+ else:
+ loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
+ return loss
+
+
+def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
+
+ Arguments:
+ netD (network) -- discriminator network
+ real_data (tensor array) -- real images
+ fake_data (tensor array) -- generated images from the generator
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
+ constant (float) -- the constant used in formula (||gradient||_2 - constant)^2
+ lambda_gp (float) -- weight for this loss
+
+ Returns the gradient penalty loss
+ """
+ if lambda_gp > 0.0:
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
+ interpolatesv = real_data
+ elif type == 'fake':
+ interpolatesv = fake_data
+ elif type == 'mixed':
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
+ else:
+ raise NotImplementedError('{} not implemented'.format(type))
+ interpolatesv.requires_grad_(True)
+ disc_interpolates = netD(interpolatesv)
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True)
+ gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
+ return gradient_penalty, gradients
+ else:
+ return 0.0, None
+
+
+class Normalize(nn.Module):
+
+ def __init__(self, power=2):
+ super(Normalize, self).__init__()
+ self.power = power
+
+ def forward(self, x):
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
+ out = x.div(norm + 1e-7)
+ return out
+
+
+class PoolingF(nn.Module):
+ def __init__(self):
+ super(PoolingF, self).__init__()
+ model = [nn.AdaptiveMaxPool2d(1)]
+ self.model = nn.Sequential(*model)
+ self.l2norm = Normalize(2)
+
+ def forward(self, x):
+ return self.l2norm(self.model(x))
+
+
+class ReshapeF(nn.Module):
+ def __init__(self):
+ super(ReshapeF, self).__init__()
+ model = [nn.AdaptiveAvgPool2d(4)]
+ self.model = nn.Sequential(*model)
+ self.l2norm = Normalize(2)
+
+ def forward(self, x):
+ x = self.model(x)
+ x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
+ return self.l2norm(x_reshape)
+
+
+class StridedConvF(nn.Module):
+ def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
+ super().__init__()
+ # self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
+ # self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
+ self.l2_norm = Normalize(2)
+ self.mlps = {}
+ self.moving_averages = {}
+ self.init_type = init_type
+ self.init_gain = init_gain
+ self.gpu_ids = gpu_ids
+
+ def create_mlp(self, x):
+ C, H = x.shape[1], x.shape[2]
+ n_down = int(np.rint(np.log2(H / 32)))
+ mlp = []
+ for i in range(n_down):
+ mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
+ mlp.append(nn.ReLU())
+ C = max(C // 2, 64)
+ mlp.append(nn.Conv2d(C, 64, 3))
+ mlp = nn.Sequential(*mlp)
+ init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
+ return mlp
+
+ def update_moving_average(self, key, x):
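+ # exponential moving average with momentum 0.999; forward() subtracts it to center the projected features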
+ if key not in self.moving_averages:
+ self.moving_averages[key] = x.detach()
+
+ self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
+
+ def forward(self, x, use_instance_norm=False):
+ C, H = x.shape[1], x.shape[2]
+ key = '%d_%d' % (C, H)
+ if key not in self.mlps:
+ self.mlps[key] = self.create_mlp(x)
+ self.add_module("child_%s" % key, self.mlps[key])
+ mlp = self.mlps[key]
+ x = mlp(x)
+ self.update_moving_average(key, x)
+ x = x - self.moving_averages[key]
+ if use_instance_norm:
+ x = F.instance_norm(x)
+ return self.l2_norm(x)
+
+
+class PatchSampleF(nn.Module):
+ def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
+ # potential issues: currently, we use the same patch_ids for multiple images in the batch
+ super(PatchSampleF, self).__init__()
+ self.l2norm = Normalize(2)
+ self.use_mlp = use_mlp
+ self.nc = nc # hard-coded
+ self.mlp_init = False
+ self.init_type = init_type
+ self.init_gain = init_gain
+ self.gpu_ids = gpu_ids
+
+ def create_mlp(self, feats):
+ for mlp_id, feat in enumerate(feats):
+ input_nc = feat.shape[-1]
+ # mlp = nn.Sequential(*[nn.Linear(input_nc, input_nc), nn.ReLU(), nn.Linear(input_nc, input_nc)])
+ mlp = nn.Sequential(*[nn.Linear(input_nc, input_nc)])
+ if len(self.gpu_ids) > 0:
+ mlp.cuda()
+ setattr(self, 'mlp_%d' % mlp_id, mlp)
+ init_net(self, self.init_type, self.init_gain, self.gpu_ids)
+ self.mlp_init = True
+
+ def forward(self, feats, num_patches=64, patch_ids=None):
+
+ return_feats = []
+ if self.use_mlp and not self.mlp_init:
+ self.create_mlp(feats)
+ for feat_id, feat in enumerate(feats):
+ mlp = getattr(self, 'mlp_%d' % feat_id)
+ res = mlp(feat)
+ return_feats.append(res)
+
+ return return_feats
+
+
+class G_Resnet(nn.Module):
+ def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
+ norm=None, nl_layer=None):
+ super(G_Resnet, self).__init__()
+ n_downsample = num_downs
+ pad_type = 'reflect'
+ self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
+ if nz == 0:
+ self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
+ else:
+ self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
+
+ def decode(self, content, style=None):
+ return self.dec(content, style)
+
+ def forward(self, image, style=None, nce_layers=[], encode_only=False):
+ content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
+ if encode_only:
+ return feats
+ else:
+ images_recon = self.decode(content, style)
+ if len(nce_layers) > 0:
+ return images_recon, feats
+ else:
+ return images_recon
+
+##################################################################################
+# Encoder and Decoders
+##################################################################################
+
+
+class E_adaIN(nn.Module):
+ def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
+ norm=None, nl_layer=None, vae=False):
+ # style encoder
+ super(E_adaIN, self).__init__()
+ self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
+
+ def forward(self, image):
+ style = self.enc_style(image)
+ return style
+
+
+class StyleEncoder(nn.Module):
+ def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
+ super(StyleEncoder, self).__init__()
+ self.vae = vae
+ self.model = []
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
+ for i in range(2):
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
+ dim *= 2
+ for i in range(n_downsample - 2):
+ self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
+ self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
+ if self.vae:
+ self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
+ self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
+ else:
+ self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
+
+ self.model = nn.Sequential(*self.model)
+ self.output_dim = dim
+
+ def forward(self, x):
+ if self.vae:
+ output = self.model(x)
+ output = output.view(x.size(0), -1)
+ output_mean = self.fc_mean(output)
+ output_var = self.fc_var(output)
+ return output_mean, output_var
+ else:
+ return self.model(x).view(x.size(0), -1)
+
+
+class ContentEncoder(nn.Module):
+ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
+ super(ContentEncoder, self).__init__()
+ self.model = []
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
+ # downsampling blocks
+ for i in range(n_downsample):
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
+ dim *= 2
+ # residual blocks
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
+ self.model = nn.Sequential(*self.model)
+ self.output_dim = dim
+
+ def forward(self, x, nce_layers=[], encode_only=False):
+ if len(nce_layers) > 0:
+ feat = x
+ feats = []
+ for layer_id, layer in enumerate(self.model):
+ feat = layer(feat)
+ if layer_id in nce_layers:
+ feats.append(feat)
+ if layer_id == nce_layers[-1] and encode_only:
+ return None, feats
+ return feat, feats
+ else:
+ return self.model(x), None
+
+
+
+class Decoder_all(nn.Module):
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
+ super(Decoder_all, self).__init__()
+ # AdaIN residual blocks
+ self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
+ self.n_blocks = 0
+ # upsampling blocks
+ for i in range(n_upsample):
+ block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
+ setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
+ self.n_blocks += 1
+ dim //= 2
+ # use reflection padding in the last conv layer
+ setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
+ self.n_blocks += 1
+
+ def forward(self, x, y=None):
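+        # Note: G_Resnet only builds Decoder_all when nz > 0, so a style code y is expected here;
+        # calling this forward with y=None would fall through and return None.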
+ if y is not None:
+ output = self.resnet_block(cat_feature(x, y))
+ for n in range(self.n_blocks):
+ block = getattr(self, 'block_{:d}'.format(n))
+ if n > 0:
+ output = block(cat_feature(output, y))
+ else:
+ output = block(output)
+ return output
+
+
+class Decoder(nn.Module):
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
+ super(Decoder, self).__init__()
+
+ self.model = []
+ # AdaIN residual blocks
+ self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
+ # upsampling blocks
+ for i in range(n_upsample):
+ if i == 0:
+ input_dim = dim + nz
+ else:
+ input_dim = dim
+ self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
+ dim //= 2
+ # use reflection padding in the last conv layer
+ self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
+ self.model = nn.Sequential(*self.model)
+
+ def forward(self, x, y=None):
+ if y is not None:
+ return self.model(cat_feature(x, y))
+ else:
+ return self.model(x)
+
+##################################################################################
+# Sequential Models
+##################################################################################
+
+
+class ResBlocks(nn.Module):
+ def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
+ super(ResBlocks, self).__init__()
+ self.model = []
+ for i in range(num_blocks):
+ self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
+ self.model = nn.Sequential(*self.model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+##################################################################################
+# Basic Blocks
+##################################################################################
+def cat_feature(x, y):
+ y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
+ y.size(0), y.size(1), x.size(2), x.size(3))
+ x_cat = torch.cat([x, y_expand], 1)
+ return x_cat
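+
+
+# Shape sketch: cat_feature broadcasts a per-image style code over the spatial grid, e.g.
+#     x = torch.randn(2, 256, 64, 64); y = torch.randn(2, 8)
+#     cat_feature(x, y).shape   # -> torch.Size([2, 264, 64, 64])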
+
+
+class ResBlock(nn.Module):
+ def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
+ super(ResBlock, self).__init__()
+
+ model = []
+ model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
+ model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ residual = x
+ out = self.model(x)
+ out += residual
+ return out
+
+
+class Conv2dBlock(nn.Module):
+ def __init__(self, input_dim, output_dim, kernel_size, stride,
+ padding=0, norm='none', activation='relu', pad_type='zero'):
+ super(Conv2dBlock, self).__init__()
+ self.use_bias = True
+ # initialize padding
+ if pad_type == 'reflect':
+ self.pad = nn.ReflectionPad2d(padding)
+ elif pad_type == 'zero':
+ self.pad = nn.ZeroPad2d(padding)
+ else:
+ assert 0, "Unsupported padding type: {}".format(pad_type)
+
+ # initialize normalization
+ norm_dim = output_dim
+ if norm == 'batch':
+ self.norm = nn.BatchNorm2d(norm_dim)
+ elif norm == 'inst':
+ self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
+ elif norm == 'ln':
+ self.norm = LayerNorm(norm_dim)
+ elif norm == 'none':
+ self.norm = None
+ else:
+ assert 0, "Unsupported normalization: {}".format(norm)
+
+ # initialize activation
+ if activation == 'relu':
+ self.activation = nn.ReLU(inplace=True)
+ elif activation == 'lrelu':
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
+ elif activation == 'prelu':
+ self.activation = nn.PReLU()
+ elif activation == 'selu':
+ self.activation = nn.SELU(inplace=True)
+ elif activation == 'tanh':
+ self.activation = nn.Tanh()
+ elif activation == 'none':
+ self.activation = None
+ else:
+ assert 0, "Unsupported activation: {}".format(activation)
+
+ # initialize convolution
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
+
+ def forward(self, x):
+ x = self.conv(self.pad(x))
+ if self.norm:
+ x = self.norm(x)
+ if self.activation:
+ x = self.activation(x)
+ return x
+
+
+class LinearBlock(nn.Module):
+ def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
+ super(LinearBlock, self).__init__()
+ use_bias = True
+ # initialize fully connected layer
+ self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
+
+ # initialize normalization
+ norm_dim = output_dim
+ if norm == 'batch':
+ self.norm = nn.BatchNorm1d(norm_dim)
+ elif norm == 'inst':
+ self.norm = nn.InstanceNorm1d(norm_dim)
+ elif norm == 'ln':
+ self.norm = LayerNorm(norm_dim)
+ elif norm == 'none':
+ self.norm = None
+ else:
+ assert 0, "Unsupported normalization: {}".format(norm)
+
+ # initialize activation
+ if activation == 'relu':
+ self.activation = nn.ReLU(inplace=True)
+ elif activation == 'lrelu':
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
+ elif activation == 'prelu':
+ self.activation = nn.PReLU()
+ elif activation == 'selu':
+ self.activation = nn.SELU(inplace=True)
+ elif activation == 'tanh':
+ self.activation = nn.Tanh()
+ elif activation == 'none':
+ self.activation = None
+ else:
+ assert 0, "Unsupported activation: {}".format(activation)
+
+ def forward(self, x):
+ out = self.fc(x)
+ if self.norm:
+ out = self.norm(out)
+ if self.activation:
+ out = self.activation(out)
+ return out
+
+##################################################################################
+# Normalization layers
+##################################################################################
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, num_features, eps=1e-5, affine=True):
+ super(LayerNorm, self).__init__()
+ self.num_features = num_features
+ self.affine = affine
+ self.eps = eps
+
+ if self.affine:
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
+ self.beta = nn.Parameter(torch.zeros(num_features))
+
+ def forward(self, x):
+ shape = [-1] + [1] * (x.dim() - 1)
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
+ std = x.view(x.size(0), -1).std(1).view(*shape)
+ x = (x - mean) / (std + self.eps)
+
+ if self.affine:
+ shape = [1, -1] + [1] * (x.dim() - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ResnetGenerator(nn.Module):
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
+ """
+
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
+ """Construct a Resnet-based generator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+ assert(n_blocks >= 0)
+ super(ResnetGenerator, self).__init__()
+ self.opt = opt
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ model = [nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2 ** i
+ if(no_antialias):
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+ else:
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True),
+ Downsample(ngf * mult * 2)]
+
+ mult = 2 ** n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ if no_antialias_up:
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ else:
+ model += [Upsample(ngf * mult),
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=1,
+ padding=1, # output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ model += [nn.ReflectionPad2d(3)]
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input, layers=[], encode_only=False):
+ if -1 in layers:
+ layers.append(len(self.model))
+ if len(layers) > 0:
+ feat = input
+ feats = []
+ for layer_id, layer in enumerate(self.model):
+ # print(layer_id, layer)
+ feat = layer(feat)
+ if layer_id in layers:
+ # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
+ feats.append(feat)
+ else:
+ # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
+ pass
+ if layer_id == layers[-1] and encode_only:
+ # print('encoder only return features')
+ return feats # return intermediate features alone; stop in the last layers
+
+ return feat, feats # return both output and intermediate features
+ else:
+ """Standard forward"""
+ fake = self.model(input)
+ return fake
+
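+# Feature-extraction sketch (assuming the Downsample/Upsample helpers defined earlier in this file):
+#     netG = ResnetGenerator(3, 3, ngf=64, n_blocks=9)
+#     feats = netG(torch.randn(1, 3, 256, 256), layers=[0, 4, 8], encode_only=True)  # list of 3 feature maps
+#     fake = netG(torch.randn(1, 3, 256, 256))                                       # full translation pass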
+
+
+
+class ResnetGeneratorMask(nn.Module):
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
+
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
+ """
+
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
+ """Construct a Resnet-based generator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+ assert(n_blocks >= 0)
+ super(ResnetGeneratorMask, self).__init__()
+ self.opt = opt
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ model = [nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2 ** i
+ if(no_antialias):
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+ else:
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True),
+ Downsample(ngf * mult * 2)]
+
+ mult = 2 ** n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+
+ self.layer_encoder = len(model) - 1
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ if no_antialias_up:
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ else:
+ model += [Upsample(ngf * mult),
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=1,
+ padding=1, # output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ model += [nn.ReflectionPad2d(3)]
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input, layers=[], encode_only=False, mask_rate=0.0):
+ if -1 in layers:
+ layers.append(len(self.model))
+ if len(layers) > 0:
+ feat = input
+ feats = []
+ for layer_id, layer in enumerate(self.model):
+ # print(layer_id, layer)
+ feat = layer(feat)
+ if layer_id in layers:
+ # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
+ feats.append(feat)
+ else:
+ # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
+ pass
+ if layer_id == layers[-1] and encode_only:
+ # print('encoder only return features')
+ return feats # return intermediate features alone; stop in the last layers
+
+ return feat, feats # return both output and intermediate features
+ elif mask_rate > 0.0:
+ feat = input
+ rate = random.uniform(0.0, mask_rate)
+ for layer_id, layer in enumerate(self.model):
+ feat = layer(feat)
+ # print(layer_id, self.layer_encoder)
+ if layer_id == self.layer_encoder:
+ # print('shape:', feat.shape)
+ B , C, H, W = feat.shape[0], feat.shape[1], feat.shape[2], feat.shape[3]
+ feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
+ all_num = feat_reshape.shape[1]
+ point_num = all_num * rate
+ point_id = np.random.permutation(all_num)
+ point_id = point_id[:int(min(point_num, all_num))]
+ feat_reshape[:,point_id,:] = 0
+ feat = feat_reshape.permute(0, 2, 1).reshape([B, C, H, W])
+ # print('rec', feat.shape)
+
+ return feat
+
+
+ else:
+ """Standard forward"""
+ fake = self.model(input)
+ return fake
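+
+
+# Masking sketch: when mask_rate > 0 and no intermediate layers are requested, a random fraction of
+# spatial positions in the bottleneck feature map is zeroed before decoding (assuming the module-level
+# `random` and `np` imports of this file):
+#     netG_mask = ResnetGeneratorMask(3, 3, ngf=64, n_blocks=9)
+#     out = netG_mask(torch.randn(1, 3, 256, 256), mask_rate=0.3)   # [1, 3, 256, 256]
+
+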
+
+class ResnetDecoder(nn.Module):
+ """Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
+ """
+
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
+ """Construct a Resnet-based decoder
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+ assert(n_blocks >= 0)
+ super(ResnetDecoder, self).__init__()
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+ model = []
+ n_downsampling = 2
+ mult = 2 ** n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+
+ for i in range(n_downsampling): # add upsampling layers
+ mult = 2 ** (n_downsampling - i)
+ if(no_antialias):
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=2,
+ padding=1, output_padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ else:
+ model += [Upsample(ngf * mult),
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
+ kernel_size=3, stride=1,
+ padding=1,
+ bias=use_bias),
+ norm_layer(int(ngf * mult / 2)),
+ nn.ReLU(True)]
+ model += [nn.ReflectionPad2d(3)]
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+ model += [nn.Tanh()]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class ResnetEncoder(nn.Module):
+ """Resnet-based encoder that consists of a few downsampling + several Resnet blocks
+ """
+
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
+ """Construct a Resnet-based encoder
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers
+ n_blocks (int) -- the number of ResNet blocks
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
+ """
+ assert(n_blocks >= 0)
+ super(ResnetEncoder, self).__init__()
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ model = [nn.ReflectionPad2d(3),
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+ norm_layer(ngf),
+ nn.ReLU(True)]
+
+ n_downsampling = 2
+ for i in range(n_downsampling): # add downsampling layers
+ mult = 2 ** i
+ if(no_antialias):
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True)]
+ else:
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
+ norm_layer(ngf * mult * 2),
+ nn.ReLU(True),
+ Downsample(ngf * mult * 2)]
+
+ mult = 2 ** n_downsampling
+ for i in range(n_blocks): # add ResNet blocks
+
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
+
+class ResnetBlock(nn.Module):
+ """Define a Resnet block"""
+
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+ """Initialize the Resnet block
+
+ A resnet block is a conv block with skip connections
+ We construct a conv block with build_conv_block function,
+ and implement skip connections in function.
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
+ """
+ super(ResnetBlock, self).__init__()
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
+
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+ """Construct a convolutional block.
+
+ Parameters:
+ dim (int) -- the number of channels in the conv layer.
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ use_bias (bool) -- if the conv layer uses bias or not
+
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
+ """
+ conv_block = []
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
+ if use_dropout:
+ conv_block += [nn.Dropout(0.5)]
+
+ p = 0
+ if padding_type == 'reflect':
+ conv_block += [nn.ReflectionPad2d(1)]
+ elif padding_type == 'replicate':
+ conv_block += [nn.ReplicationPad2d(1)]
+ elif padding_type == 'zero':
+ p = 1
+ else:
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
+
+ return nn.Sequential(*conv_block)
+
+ def forward(self, x):
+ """Forward function (with skip connections)"""
+ out = x + self.conv_block(x) # add skip connections
+ return out
+
+
+class UnetGenerator(nn.Module):
+ """Create a Unet-based generator"""
+
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet generator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ output_nc (int) -- the number of channels in output images
+            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
+                                an image of size 128x128 will become of size 1x1 at the bottleneck
+ ngf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+
+ We construct the U-Net from the innermost layer to the outermost layer.
+ It is a recursive process.
+ """
+ super(UnetGenerator, self).__init__()
+ # construct unet structure
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+ # gradually reduce the number of filters from ngf * 8 to ngf
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
+
+ def forward(self, input):
+ """Standard forward"""
+ return self.model(input)
+
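+# Construction sketch: with num_downs=8 a 256x256 input is reduced to 1x1 at the bottleneck, e.g.
+#     unet = UnetGenerator(3, 3, num_downs=8)
+#     out = unet(torch.randn(1, 3, 256, 256))   # [1, 3, 256, 256]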
+
+class UnetSkipConnectionBlock(nn.Module):
+ """Defines the Unet submodule with skip connection.
+ X -------------------identity----------------------
+ |-- downsampling -- |submodule| -- upsampling --|
+ """
+
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+ """Construct a Unet submodule with skip connections.
+
+ Parameters:
+ outer_nc (int) -- the number of filters in the outer conv layer
+ inner_nc (int) -- the number of filters in the inner conv layer
+ input_nc (int) -- the number of channels in input images/features
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
+ outermost (bool) -- if this module is the outermost module
+ innermost (bool) -- if this module is the innermost module
+ norm_layer -- normalization layer
+ use_dropout (bool) -- if use dropout layers.
+ """
+ super(UnetSkipConnectionBlock, self).__init__()
+ self.outermost = outermost
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+ if input_nc is None:
+ input_nc = outer_nc
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+ stride=2, padding=1, bias=use_bias)
+ downrelu = nn.LeakyReLU(0.2, True)
+ downnorm = norm_layer(inner_nc)
+ uprelu = nn.ReLU(True)
+ upnorm = norm_layer(outer_nc)
+
+ if outermost:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1)
+ down = [downconv]
+ up = [uprelu, upconv, nn.Tanh()]
+ model = down + [submodule] + up
+ elif innermost:
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv]
+ up = [uprelu, upconv, upnorm]
+ model = down + up
+ else:
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+ kernel_size=4, stride=2,
+ padding=1, bias=use_bias)
+ down = [downrelu, downconv, downnorm]
+ up = [uprelu, upconv, upnorm]
+
+ if use_dropout:
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
+ else:
+ model = down + [submodule] + up
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ if self.outermost:
+ return self.model(x)
+ else: # add skip connections
+ return torch.cat([x, self.model(x)], 1)
+
+
+class MLPDiscriminator(nn.Module):
+ def __init__(self, in_feat=768, hid_feat = 768, out_feat = 768, dropout = 0.):
+ super().__init__()
+ if not hid_feat:
+ hid_feat = in_feat
+ if not out_feat:
+ out_feat = in_feat
+ self.linear1 = nn.Linear(in_feat, hid_feat)
+ self.activation = nn.GELU()
+ self.linear2 = nn.Linear(hid_feat, out_feat)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.activation(x)
+ x = self.dropout(x)
+ x = self.linear2(x)
+ return self.dropout(x)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator"""
+
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
+ """Construct a PatchGAN discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ norm_layer -- normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ kw = 4
+ padw = 1
+ if(no_antialias):
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ else:
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ if(no_antialias):
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+ else:
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ Downsample(ndf * nf_mult)]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.model(input)
+
+
+class PixelDiscriminator(nn.Module):
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
+
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
+ """Construct a 1x1 PatchGAN discriminator
+
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ norm_layer -- normalization layer
+ """
+ super(PixelDiscriminator, self).__init__()
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ self.net = [
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
+ norm_layer(ndf * 2),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
+
+ self.net = nn.Sequential(*self.net)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.net(input)
+
+
+class PatchDiscriminator(NLayerDiscriminator):
+ """Defines a PatchGAN discriminator"""
+
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
+ super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
+
+ def forward(self, input):
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
+ size = 16
+ Y = H // size
+ X = W // size
+ input = input.view(B, C, Y, size, X, size)
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
+ return super().forward(input)
+
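+# Patch-splitting sketch: the forward above scores non-overlapping 16x16 crops independently, e.g.
+#     netD = PatchDiscriminator(3)
+#     pred = netD(torch.randn(2, 3, 128, 128))   # (2 * 8 * 8) crops of size 16x16, one prediction map each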
+
+class GroupedChannelNorm(nn.Module):
+ def __init__(self, num_groups):
+ super().__init__()
+ self.num_groups = num_groups
+
+ def forward(self, x):
+ shape = list(x.shape)
+ new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
+ x = x.view(*new_shape)
+ mean = x.mean(dim=2, keepdim=True)
+ std = x.std(dim=2, keepdim=True)
+ x_norm = (x - mean) / (std + 1e-7)
+ return x_norm.view(*shape)
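+
+
+# Normalization sketch: channels are split into num_groups groups and standardized within each group:
+#     GroupedChannelNorm(num_groups=4)(torch.randn(2, 8, 5, 5)).shape   # -> torch.Size([2, 8, 5, 5])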
diff --git a/models/patchnce.py b/models/patchnce.py
new file mode 100644
index 0000000..475793c
--- /dev/null
+++ b/models/patchnce.py
@@ -0,0 +1,55 @@
+from packaging import version
+import torch
+from torch import nn
+
+
+class PatchNCELoss(nn.Module):
+ def __init__(self, opt):
+ super().__init__()
+ self.opt = opt
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
+
+ def forward(self, feat_q, feat_k):
+ num_patches = feat_q.shape[0]
+ dim = feat_q.shape[1]
+ feat_k = feat_k.detach()
+
+ # pos logit
+ l_pos = torch.bmm(
+ feat_q.view(num_patches, 1, -1), feat_k.view(num_patches, -1, 1))
+ l_pos = l_pos.view(num_patches, 1)
+
+ # neg logit
+
+ # Should the negatives from the other samples of a minibatch be utilized?
+ # In CUT and FastCUT, we found that it's best to only include negatives
+ # from the same image. Therefore, we set
+ # --nce_includes_all_negatives_from_minibatch as False
+ # However, for single-image translation, the minibatch consists of
+ # crops from the "same" high-resolution image.
+ # Therefore, we will include the negatives from the entire minibatch.
+ if self.opt.nce_includes_all_negatives_from_minibatch:
+ # reshape features as if they are all negatives of minibatch of size 1.
+ batch_dim_for_bmm = 1
+ else:
+ batch_dim_for_bmm = self.opt.batch_size
+
+ # reshape features to batch size
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
+ npatches = feat_q.size(1)
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
+
+ # diagonal entries are similarity between same features, and hence meaningless.
+        # just fill the diagonal with a large negative number (-10), so its contribution exp(-10) is almost zero
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
+ l_neg = l_neg_curbatch.view(-1, npatches)
+
+ out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
+
+ loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
+ device=feat_q.device))
+
+ return loss
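+
+# Usage sketch (assuming `opt` carries nce_includes_all_negatives_from_minibatch, batch_size and
+# nce_T, e.g. an argparse.Namespace):
+#     opt = argparse.Namespace(nce_includes_all_negatives_from_minibatch=False, batch_size=1, nce_T=0.07)
+#     criterion = PatchNCELoss(opt)
+#     nce = criterion(torch.randn(256, 64), torch.randn(256, 64)).mean()   # per-patch losses -> scalar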
diff --git a/models/roma_model.py b/models/roma_model.py
new file mode 100644
index 0000000..48c307e
--- /dev/null
+++ b/models/roma_model.py
@@ -0,0 +1,363 @@
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+from .patchnce import PatchNCELoss
+import util.util as util
+import timm
+import time
+import torch.nn.functional as F
+import sys
+from functools import partial
+import torch.nn as nn
+import math
+
+from torchvision.transforms import transforms as tfs
+
+class ROMAModel(BaseModel):
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """ Configures options specific for CUT model
+ """
+ parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
+ parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
+ parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
+ parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
+ parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
+ parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
+ parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
+ parser.add_argument('--local_nums', type=int, default=256)
+ parser.add_argument('--which_D_layer', type=int, default=-1)
+ parser.add_argument('--side_length', type=int, default=7)
+
+ parser.set_defaults(pool_size=0)
+
+ opt, _ = parser.parse_known_args()
+
+ return parser
+
+ def __init__(self, opt):
+ BaseModel.__init__(self, opt)
+
+
+ self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial', 'motion']
+ self.visual_names = ['real_A0', 'real_A1', 'fake_B0', 'fake_B1', 'real_B0', 'real_B1']
+ self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
+
+
+ if self.isTrain:
+ self.model_names = ['G', 'D_ViT']
+ else: # during test time, only load G
+ self.model_names = ['G']
+
+
+ # define networks (both generator and discriminator)
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
+
+
+ if self.isTrain:
+
+ self.netD_ViT = networks.MLPDiscriminator().to(self.device)
+ self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
+
+
+ self.norm = F.softmax
+
+ self.resize = tfs.Resize(size=(384,384))
+
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+ self.criterionNCE = []
+
+ for atten_layer in self.atten_layers:
+ self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
+
+ self.criterionL1 = torch.nn.L1Loss().to(self.device)
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D_ViT)
+
+ def data_dependent_initialize(self, data):
+ """
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
+ features of the encoder portion of netG. Because of this, the weights of netF are
+ initialized at the first feedforward pass with some input images.
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
+ """
+ pass
+
+
+ def optimize_parameters(self):
+ # forward
+ self.forward()
+
+ # update D
+ self.set_requires_grad(self.netD_ViT, True)
+ self.optimizer_D_ViT.zero_grad()
+ self.loss_D = self.compute_D_loss()
+ self.loss_D.backward()
+ self.optimizer_D_ViT.step()
+
+ # update G
+ self.set_requires_grad(self.netD_ViT, False)
+ self.optimizer_G.zero_grad()
+ self.loss_G = self.compute_G_loss()
+ self.loss_G.backward()
+ self.optimizer_G.step()
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+ The option 'direction' can be used to swap domain A and domain B.
+ """
+ AtoB = self.opt.direction == 'AtoB'
+ self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
+ self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
+ self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
+ self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+
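+        # NOTE: the multi-step noising pass below assumes extra options (tau, num_timesteps,
+        # flip_equivariance) that are not registered in modify_commandline_options above, and it
+        # reads token features (mutil_real_A0_tokens / mutil_real_A1_tokens) that are only computed
+        # at the end of this method; the ROMA losses only rely on fake_B0 / fake_B1 and those tokens.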
+        # ============ Step 1: build the time schedule for the multi-step stochastic generation of real_A / real_A2 ============
+ tau = self.opt.tau
+ T = self.opt.num_timesteps
+ incs = np.array([0] + [1/(i+1) for i in range(T-1)])
+ times = np.cumsum(incs)
+ times = times / times[-1]
+ times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
+ times = np.concatenate([np.zeros(1), times])
+ times = torch.tensor(times).float().cuda()
+ self.times = times
+ bs = self.mutil_real_A0_tokens.size(0)
+ time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
+ self.time_idx = time_idx
+
+ with torch.no_grad():
+ self.netG.eval()
+            # ============ Step 2: run the multi-step stochastic generation process on real_A / real_A2 ============
+ for t in range(self.time_idx.int().item() + 1):
+                # compute the increment delta and the inter/scale factors used for per-timestep interpolation
+ if t > 0:
+ delta = times[t] - times[t - 1]
+ denom = times[-1] - times[t - 1]
+ inter = (delta / denom).reshape(-1, 1, 1, 1)
+ scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
+
+                # stochastic noise update for Xt and Xt2
+ Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
+ (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
+ time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
+ z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+ self.time = times[time_idx]
+ Xt_1 = self.netG(Xt, self.time, z)
+
+ Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
+ (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
+ time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
+ z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+ Xt_12 = self.netG(Xt2, self.time, z)
+
+            # cache the intermediate noisy results (real_A_noisy, etc.) for the concatenation in the next step
+ self.real_A_noisy = Xt.detach()
+ self.real_A_noisy2 = Xt2.detach()
+            # store the noise map
+ self.noisy_map = self.real_A_noisy - self.real_A
+
+        # ============ Step 3: concatenate the inputs and run the generator ============
+ bs = self.mutil_real_A0_tokens.size(0)
+ z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+ z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+        # concatenate real_A and real_B (when nce_idt=True) and treat real_A_noisy and XtB the same way
+ self.real = self.mutil_real_A0_tokens
+ self.realt = self.real_A_noisy
+
+ if self.opt.flip_equivariance:
+ self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
+ if self.flipped_for_equivariance:
+ self.real = torch.flip(self.real, [3])
+ self.realt = torch.flip(self.realt, [3])
+
+
+ self.fake_B0 = self.netG(self.real_A0)
+ self.fake_B1 = self.netG(self.real_A1)
+
+ if self.opt.isTrain:
+ real_A0 = self.real_A0
+ real_A1 = self.real_A1
+ real_B0 = self.real_B0
+ real_B1 = self.real_B1
+ fake_B0 = self.fake_B0
+ fake_B1 = self.fake_B1
+ self.real_A0_resize = self.resize(real_A0)
+ self.real_A1_resize = self.resize(real_A1)
+ real_B0 = self.resize(real_B0)
+ real_B1 = self.resize(real_B1)
+ self.fake_B0_resize = self.resize(fake_B0)
+ self.fake_B1_resize = self.resize(fake_B1)
+ self.mutil_real_A0_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True)
+ self.mutil_real_A1_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True)
+ self.mutil_real_B0_tokens = self.netPreViT(real_B0, self.atten_layers, get_tokens=True)
+ self.mutil_real_B1_tokens = self.netPreViT(real_B1, self.atten_layers, get_tokens=True)
+ self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True)
+ self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True)
+
+ def tokens_concat(self, origin_tokens, adjacent_size):
+ adj_size = adjacent_size
+ B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
+ S = int(math.sqrt(token_num))
+ if S * S != token_num:
+ print('Error! Not a square!')
+ token_map = origin_tokens.clone().reshape(B,S,S,C)
+ cut_patch_list = []
+ for i in range(0, S, adj_size):
+ for j in range(0, S, adj_size):
+ i_left = i
+ i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
+ j_left = j
+ j_right = j + adj_size if j + adj_size <= S else S + 1
+
+ cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
+ cut_patch= cut_patch.reshape(B,-1,C)
+ cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
+ cut_patch_list.append(cut_patch)
+
+
+ result = torch.cat(cut_patch_list,dim=1)
+ return result
+
+
+ def cat_results(self, origin_tokens, adj_size_list):
+ res_list = [origin_tokens]
+ for ad_s in adj_size_list:
+ cat_result = self.tokens_concat(origin_tokens, ad_s)
+ res_list.append(cat_result)
+
+ result = torch.cat(res_list, dim=1)
+
+ return result
+
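+    # Aggregation sketch: with 576 tokens from the 24x24 ViT grid, cat_results concatenates the
+    # original tokens with their pooled versions at every scale in adj_size_list, e.g.
+    #     tokens = torch.randn(1, 576, 768)
+    #     multi = self.cat_results(tokens, [2, 4, 6, 8, 12])   # [1, 576 + 144 + 36 + 16 + 9 + 4, 768]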
+
+
+ def compute_D_loss(self):
+ """Calculate GAN loss for the discriminator"""
+
+
+ lambda_D_ViT = self.opt.lambda_D_ViT
+ fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer].detach()
+ fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer].detach()
+
+ real_B0_tokens = self.mutil_real_B0_tokens[self.opt.which_D_layer]
+ real_B1_tokens = self.mutil_real_B1_tokens[self.opt.which_D_layer]
+
+
+ fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
+ fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
+
+
+
+ real_B0_tokens = self.cat_results(real_B0_tokens, self.opt.adj_size_list)
+ real_B1_tokens = self.cat_results(real_B1_tokens, self.opt.adj_size_list)
+
+ pre_fake0_ViT = self.netD_ViT(fake_B0_tokens)
+ pre_fake1_ViT = self.netD_ViT(fake_B1_tokens)
+
+ self.loss_D_fake_ViT = (self.criterionGAN(pre_fake0_ViT, False).mean() + self.criterionGAN(pre_fake1_ViT, False).mean()) * 0.5 * lambda_D_ViT
+
+ pred_real0_ViT = self.netD_ViT(real_B0_tokens)
+ pred_real1_ViT = self.netD_ViT(real_B1_tokens)
+ self.loss_D_real_ViT = (self.criterionGAN(pred_real0_ViT, True).mean() + self.criterionGAN(pred_real1_ViT, True).mean()) * 0.5 * lambda_D_ViT
+
+ self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
+
+
+ return self.loss_D_ViT
+
+ def compute_G_loss(self):
+
+ if self.opt.lambda_GAN > 0.0:
+
+ fake_B0_tokens = self.mutil_fake_B0_tokens[self.opt.which_D_layer]
+ fake_B1_tokens = self.mutil_fake_B1_tokens[self.opt.which_D_layer]
+ fake_B0_tokens = self.cat_results(fake_B0_tokens, self.opt.adj_size_list)
+ fake_B1_tokens = self.cat_results(fake_B1_tokens, self.opt.adj_size_list)
+ pred_fake0_ViT = self.netD_ViT(fake_B0_tokens)
+ pred_fake1_ViT = self.netD_ViT(fake_B1_tokens)
+ self.loss_G_GAN_ViT = (self.criterionGAN(pred_fake0_ViT, True) + self.criterionGAN(pred_fake1_ViT, True)) * 0.5 * self.opt.lambda_GAN
+ else:
+ self.loss_G_GAN_ViT = 0.0
+
+ if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
+ self.loss_global, self.loss_spatial = self.calculate_attention_loss()
+ else:
+ self.loss_global, self.loss_spatial = 0.0, 0.0
+
+ if self.opt.lambda_motion > 0.0:
+ self.loss_motion = 0.0
+ for real_A0_tokens, real_A1_tokens, fake_B0_tokens, fake_B1_tokens in zip(self.mutil_real_A0_tokens, self.mutil_real_A1_tokens, self.mutil_fake_B0_tokens, self.mutil_fake_B1_tokens):
+ A0_B1 = real_A0_tokens.bmm(fake_B1_tokens.permute(0,2,1))
+ B0_A1 = fake_B0_tokens.bmm(real_A1_tokens.permute(0,2,1))
+ cos_dis_global = F.cosine_similarity(A0_B1, B0_A1, dim=-1)
+ self.loss_motion += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
+ else:
+ self.loss_motion = 0.0
+
+ self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial + self.loss_motion
+ return self.loss_G
+
+ def calculate_attention_loss(self):
+ n_layers = len(self.atten_layers)
+ mutil_real_A0_tokens = self.mutil_real_A0_tokens
+ mutil_real_A1_tokens = self.mutil_real_A1_tokens
+ mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
+ mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
+
+
+ if self.opt.lambda_global > 0.0:
+ loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
+ loss_global *= 0.5
+
+ else:
+ loss_global = 0.0
+
+ if self.opt.lambda_spatial > 0.0:
+ loss_spatial = 0.0
+ local_nums = self.opt.local_nums
+ tokens_cnt = 576
+ local_id = np.random.permutation(tokens_cnt)
+ local_id = local_id[:int(min(local_nums, tokens_cnt))]
+
+ mutil_real_A0_local_tokens = self.netPreViT(self.real_A0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
+ mutil_real_A1_local_tokens = self.netPreViT(self.real_A1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
+
+ mutil_fake_B0_local_tokens = self.netPreViT(self.fake_B0_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
+ mutil_fake_B1_local_tokens = self.netPreViT(self.fake_B1_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
+
+ loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
+ loss_spatial *= 0.5
+
+ else:
+ loss_spatial = 0.0
+
+
+
+ return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
+
+ def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
+ loss = 0.0
+ n_layers = len(self.atten_layers)
+
+ for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
+
+ src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
+ tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
+ cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
+ loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
+
+ loss = loss / n_layers
+ return loss
+
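+    # Cross-similarity sketch for a single layer: the two token-affinity maps are pushed to agree,
+    # which is what calculate_similarity averages over the selected ViT layers:
+    #     src, tgt = torch.randn(1, 576, 768), torch.randn(1, 576, 768)
+    #     sim = F.cosine_similarity(src.bmm(tgt.permute(0, 2, 1)), tgt.bmm(src.permute(0, 2, 1)), dim=-1)
+    #     layer_loss = torch.nn.L1Loss()(torch.ones_like(sim), sim)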
diff --git a/models/roma_single_model.py b/models/roma_single_model.py
new file mode 100644
index 0000000..3c94d86
--- /dev/null
+++ b/models/roma_single_model.py
@@ -0,0 +1,272 @@
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+from .patchnce import PatchNCELoss
+import util.util as util
+import timm
+import time
+import torch.nn.functional as F
+import sys
+from functools import partial
+import torch.nn as nn
+import math
+
+from torchvision.transforms import transforms as tfs
+
+class ROMASingleModel(BaseModel):
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """ Configures options specific for CUT model
+ """
+ parser.add_argument('--adj_size_list', type=list, default=[2, 4, 6, 8, 12], help='different scales of perception field')
+ parser.add_argument('--lambda_mlp', type=float, default=1.0, help='weight of lr for discriminator')
+ parser.add_argument('--lambda_motion', type=float, default=1.0, help='weight for Temporal Consistency')
+ parser.add_argument('--lambda_D_ViT', type=float, default=1.0, help='weight for discriminator')
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
+ parser.add_argument('--lambda_global', type=float, default=1.0, help='weight for Global Structural Consistency')
+ parser.add_argument('--lambda_spatial', type=float, default=1.0, help='weight for Local Structural Consistency')
+ parser.add_argument('--atten_layers', type=str, default='1,3,5', help='compute Cross-Similarity on which layers')
+ parser.add_argument('--local_nums', type=int, default=256)
+ parser.add_argument('--which_D_layer', type=int, default=-1)
+ parser.add_argument('--side_length', type=int, default=7)
+
+ parser.set_defaults(pool_size=0)
+
+ opt, _ = parser.parse_known_args()
+
+ return parser
+
+ def __init__(self, opt):
+ BaseModel.__init__(self, opt)
+
+
+ self.loss_names = ['G_GAN_ViT', 'D_real_ViT', 'D_fake_ViT', 'global', 'spatial']
+ self.visual_names = ['real_A', 'fake_B', 'real_B']
+ self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
+
+
+ if self.isTrain:
+ self.model_names = ['G', 'D_ViT']
+ else: # during test time, only load G
+ self.model_names = ['G']
+
+
+ # define networks (both generator and discriminator)
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
+
+
+ if self.isTrain:
+
+ self.netD_ViT = networks.MLPDiscriminator().to(self.device)
+ # self.netPreViT = timm.create_model("vit_base_patch32_384",pretrained=True).to(self.device)
+ self.netPreViT = timm.create_model("vit_base_patch16_384",pretrained=True).to(self.device)
+
+
+ self.norm = F.softmax
+
+ self.resize = tfs.Resize(size=(384,384))
+ # self.resize = tfs.Resize(size=(224, 224))
+
+ # define loss functions
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+ self.criterionNCE = []
+
+ for atten_layer in self.atten_layers:
+ self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
+
+ self.criterionL1 = torch.nn.L1Loss().to(self.device)
+ self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizer_D_ViT = torch.optim.Adam(self.netD_ViT.parameters(), lr=opt.lr * opt.lambda_mlp, betas=(opt.beta1, opt.beta2))
+ self.optimizers.append(self.optimizer_G)
+ self.optimizers.append(self.optimizer_D_ViT)
+
+ def data_dependent_initialize(self, data):
+ """
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
+ features of the encoder portion of netG. Because of this, the weights of netF are
+ initialized at the first feedforward pass with some input images.
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
+ """
+ pass
+
+
+ def optimize_parameters(self):
+ # forward
+ self.forward()
+
+ # update D
+ self.set_requires_grad(self.netD_ViT, True)
+ self.optimizer_D_ViT.zero_grad()
+ self.loss_D = self.compute_D_loss()
+ self.loss_D.backward()
+ self.optimizer_D_ViT.step()
+
+ # update G
+ self.set_requires_grad(self.netD_ViT, False)
+ self.optimizer_G.zero_grad()
+ self.loss_G = self.compute_G_loss()
+ self.loss_G.backward()
+ self.optimizer_G.step()
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+ The option 'direction' can be used to swap domain A and domain B.
+ """
+ AtoB = self.opt.direction == 'AtoB'
+ self.real_A = input['A' if AtoB else 'B'].to(self.device)
+ self.real_B = input['B' if AtoB else 'A'].to(self.device)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+ def forward(self):
+ """Run forward pass; called by both functions and ."""
+ self.fake_B = self.netG(self.real_A)
+
+ if self.opt.isTrain:
+ real_A = self.real_A
+ real_B = self.real_B
+ fake_B = self.fake_B
+ self.real_A_resize = self.resize(real_A)
+ real_B = self.resize(real_B)
+ self.fake_B_resize = self.resize(fake_B)
+ self.mutil_real_A_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True)
+ self.mutil_real_B_tokens = self.netPreViT(real_B, self.atten_layers, get_tokens=True)
+ self.mutil_fake_B_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True)
+
+ def tokens_concat(self, origin_tokens, adjacent_size):
+ adj_size = adjacent_size
+ B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
+ S = int(math.sqrt(token_num))
+ if S * S != token_num:
+ print('Error! Not a square!')
+ token_map = origin_tokens.clone().reshape(B,S,S,C)
+ cut_patch_list = []
+ for i in range(0, S, adj_size):
+ for j in range(0, S, adj_size):
+ i_left = i
+ i_right = i + adj_size + 1 if i + adj_size <= S else S + 1
+ j_left = j
+ j_right = j + adj_size if j + adj_size <= S else S + 1
+
+ cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
+ cut_patch= cut_patch.reshape(B,-1,C)
+ cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
+ cut_patch_list.append(cut_patch)
+
+
+ result = torch.cat(cut_patch_list,dim=1)
+ return result
+
+
+ def cat_results(self, origin_tokens, adj_size_list):
+ res_list = [origin_tokens]
+ for ad_s in adj_size_list:
+ cat_result = self.tokens_concat(origin_tokens, ad_s)
+ res_list.append(cat_result)
+
+ result = torch.cat(res_list, dim=1)
+
+ return result
+
+
+
+ def compute_D_loss(self):
+ """Calculate GAN loss for the discriminator"""
+
+
+ lambda_D_ViT = self.opt.lambda_D_ViT
+ fake_B_tokens = self.mutil_fake_B_tokens[self.opt.which_D_layer].detach()
+
+ real_B_tokens = self.mutil_real_B_tokens[self.opt.which_D_layer]
+
+
+ fake_B_tokens = self.cat_results(fake_B_tokens, self.opt.adj_size_list)
+
+ real_B_tokens = self.cat_results(real_B_tokens, self.opt.adj_size_list)
+
+ pre_fake_ViT = self.netD_ViT(fake_B_tokens)
+
+
+ self.loss_D_fake_ViT = self.criterionGAN(pre_fake_ViT, False).mean() * lambda_D_ViT
+
+ pred_real_ViT = self.netD_ViT(real_B_tokens)
+ self.loss_D_real_ViT = self.criterionGAN(pred_real_ViT, True).mean() * lambda_D_ViT
+
+ self.loss_D_ViT = (self.loss_D_fake_ViT + self.loss_D_real_ViT) * 0.5
+
+
+ return self.loss_D_ViT
+
+ def compute_G_loss(self):
+
+ if self.opt.lambda_GAN > 0.0:
+
+ fake_B_tokens = self.mutil_fake_B_tokens[self.opt.which_D_layer]
+ fake_B_tokens = self.cat_results(fake_B_tokens, self.opt.adj_size_list)
+ pred_fake_ViT = self.netD_ViT(fake_B_tokens)
+ self.loss_G_GAN_ViT = self.criterionGAN(pred_fake_ViT, True) * self.opt.lambda_GAN
+ else:
+ self.loss_G_GAN_ViT = 0.0
+
+ if self.opt.lambda_global > 0.0 or self.opt.lambda_spatial > 0.0:
+ self.loss_global, self.loss_spatial = self.calculate_attention_loss()
+ else:
+ self.loss_global, self.loss_spatial = 0.0, 0.0
+
+
+
+ self.loss_G = self.loss_G_GAN_ViT + self.loss_global + self.loss_spatial
+ return self.loss_G
+
+ def calculate_attention_loss(self):
+ n_layers = len(self.atten_layers)
+ mutil_real_A_tokens = self.mutil_real_A_tokens
+ mutil_fake_B_tokens = self.mutil_fake_B_tokens
+
+
+
+ if self.opt.lambda_global > 0.0:
+ loss_global = self.calculate_similarity(mutil_real_A_tokens, mutil_fake_B_tokens)
+
+
+ else:
+ loss_global = 0.0
+
+ if self.opt.lambda_spatial > 0.0:
+ loss_spatial = 0.0
+ local_nums = self.opt.local_nums
+ tokens_cnt = 576
+ local_id = np.random.permutation(tokens_cnt)
+ local_id = local_id[:int(min(local_nums, tokens_cnt))]
+
+ mutil_real_A_local_tokens = self.netPreViT(self.real_A_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
+
+ mutil_fake_B_local_tokens = self.netPreViT(self.fake_B_resize, self.atten_layers, get_tokens=True, local_id=local_id, side_length = self.opt.side_length)
+
+ loss_spatial = self.calculate_similarity(mutil_real_A_local_tokens, mutil_fake_B_local_tokens)
+
+
+ else:
+ loss_spatial = 0.0
+
+
+
+ return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
+
+ def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
+ loss = 0.0
+ n_layers = len(self.atten_layers)
+
+ for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
+
+ src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
+ tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
+ cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
+ loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
+
+ loss = loss / n_layers
+ return loss
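+
+    # Cross-similarity sketch: for per-layer tokens S, T of shape [B, N, C], the
+    # affinity maps S @ T^T and T @ S^T are compared row-wise with cosine similarity,
+    # which is pushed towards 1 via an L1 loss and averaged over the attention layers.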
+
diff --git a/models/self_build.py b/models/self_build.py
new file mode 100644
index 0000000..1cb0f37
--- /dev/null
+++ b/models/self_build.py
@@ -0,0 +1,655 @@
+import numpy as np
+import math
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import GaussianBlur
+from .base_model import BaseModel
+from . import networks
+from .patchnce import PatchNCELoss
+import util.util as util
+
+from torchvision.transforms import transforms as tfs
+
+def warp(image, flow):  # flow-based warping
+    """
+    Warp an image with an optical-flow field.
+    Args:
+        image: [B, C, H, W] input image
+        flow: [B, 2, H, W] optical-flow field (x/y displacements)
+    Returns:
+        warped: [B, C, H, W] warped image
+    """
+ B, C, H, W = image.shape
+    # Build the base sampling grid (channel 0 = x, channel 1 = y), each of shape [H, W]
+    grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W))
+    grid = torch.stack((grid_x, grid_y), dim=0).float().to(image.device) # [2,H,W]
+ grid = grid.unsqueeze(0).repeat(B,1,1,1) # [B,2,H,W]
+
+    # Apply the flow displacement and normalize coordinates to [-1, 1]
+    new_grid = grid + flow
+    new_grid[:,0,:,:] = 2.0 * new_grid[:,0,:,:] / (W-1) - 1.0 # x direction
+    new_grid[:,1,:,:] = 2.0 * new_grid[:,1,:,:] / (H-1) - 1.0 # y direction
+    new_grid = new_grid.permute(0,2,3,1) # [B,H,W,2]
+
+    # Bilinear sampling
+ return F.grid_sample(image, new_grid, align_corners=True)
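+
+# Minimal usage sketch (illustrative only; shapes are assumptions, not part of training):
+#   img  = torch.randn(1, 3, 64, 64)
+#   flow = torch.zeros(1, 2, 64, 64)   # zero displacement
+#   out  = warp(img, flow)             # ~= img up to interpolation error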
+
+# Content-aware temporal normalization loss
+def compute_ctn_loss(G, x, F_content):  # Eq. 10
+    """
+    Compute the content-aware temporal normalization loss.
+    Args:
+        G: generator
+        x: input infrared image [B,C,H,W]
+        F_content: generated optical-flow field [B,2,H,W]
+    """
+
+    # Generate the visible-light image
+    y_fake = G(x)  # [B,3,H,W]
+
+    # Warp the generated result with the flow
+    warped_fake = warp(y_fake, F_content)  # [B,3,H,W]
+
+    # Warp the input with the same flow, then generate
+    warped_x = warp(x, F_content)  # [B,C,H,W]
+    y_fake_warped = G(warped_x)  # [B,3,H,W]
+
+    # L2 loss between the two results
+ loss = F.mse_loss(warped_fake, y_fake_warped)
+ return loss
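+
+# Usage sketch (illustrative; G stands for any image-to-image generator and the
+# flow would come from ContentAwareTemporalNorm below):
+#   F_content = ContentAwareTemporalNorm()(weight_map)
+#   loss_ctn  = compute_ctn_loss(G, x, F_content)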
+
+class ContentAwareOptimization(nn.Module):
+ def __init__(self, lambda_inc=2.0, eta_ratio=0.4):
+ super().__init__()
+        self.lambda_inc = lambda_inc  # weight-boosting coefficient
+        self.eta_ratio = eta_ratio  # fraction of patches treated as content-rich
+
+ def compute_cosine_similarity(self, gradients):
+ """
+ 计算每个patch梯度与平均梯度的余弦相似度
+ Args:
+ gradients: [B, N, D] 判别器输出的每个patch的梯度(N=w*h)
+ Returns:
+ cosine_sim: [B, N] 每个patch的余弦相似度
+ """
+ mean_grad = torch.mean(gradients, dim=1, keepdim=True) # [B, 1, D]
+ # 计算余弦相似度
+ cosine_sim = F.cosine_similarity(gradients, mean_grad, dim=2) # [B, N]
+ return cosine_sim
+
+ def generate_weight_map(self, gradients_real, gradients_fake):
+ """
+ 生成内容感知权重图
+ Args:
+ gradients_real: [B, N, D] 真实图像判别器梯度
+ gradients_fake: [B, N, D] 生成图像判别器梯度
+ Returns:
+ weight_real: [B, N] 真实图像权重图
+ weight_fake: [B, N] 生成图像权重图
+ """
+ # 计算真实图像块的余弦相似度
+ cosine_real = self.compute_cosine_similarity(gradients_real) # [B, N] 公式5
+ # 计算生成图像块的余弦相似度
+ cosine_fake = self.compute_cosine_similarity(gradients_fake) # [B, N]
+
+ # 选择内容丰富的区域(余弦相似度最低的eta_ratio比例)
+ k = int(self.eta_ratio * cosine_real.shape[1])
+
+ # 对真实图像生成权重图
+ _, real_indices = torch.topk(-cosine_real, k, dim=1) # 选择最不相似的区域
+ weight_real = torch.ones_like(cosine_real)
+ for b in range(cosine_real.shape[0]):
+ weight_real[b, real_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_real[b, real_indices[b]])) #公式6
+
+ # 对生成图像生成权重图(同理)
+ _, fake_indices = torch.topk(-cosine_fake, k, dim=1)
+ weight_fake = torch.ones_like(cosine_fake)
+ for b in range(cosine_fake.shape[0]):
+ weight_fake[b, fake_indices[b]] = self.lambda_inc / (1e-6 + torch.abs(cosine_fake[b, fake_indices[b]]))
+
+ return weight_real, weight_fake
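+
+    # Worked example of the weighting rule (numbers are illustrative): with the default
+    # lambda_inc=2.0 and eta_ratio=0.4, the 40% of patches whose gradients are least
+    # aligned with the mean gradient get weight 2.0 / |cos|, all other patches keep 1.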
+
+ def forward(self, D_real, D_fake, real_scores, fake_scores):
+ """
+ 计算内容感知对抗损失
+ Args:
+ D_real: 判别器对真实图像的特征输出 [B, C, H, W]
+ D_fake: 判别器对生成图像的特征输出 [B, C, H, W]
+ real_scores: 真实图像的判别器预测 [B, N] (N=H*W)
+ fake_scores: 生成图像的判别器预测 [B, N]
+ Returns:
+ loss_co_adv: 内容感知对抗损失
+ """
+ B, C, H, W = D_real.shape
+ N = H * W
+
+        # Register hooks to capture per-patch gradients
+ gradients_real = []
+ gradients_fake = []
+
+ def hook_real(grad):
+ gradients_real.append(grad.detach().view(B, N, -1))
+
+ def hook_fake(grad):
+ gradients_fake.append(grad.detach().view(B, N, -1))
+
+ D_real.register_hook(hook_real)
+ D_fake.register_hook(hook_fake)
+
+        # Compute the vanilla adversarial loss to trigger gradient computation
+        loss_real = torch.mean(torch.log(real_scores + 1e-8))
+        loss_fake = torch.mean(torch.log(1 - fake_scores + 1e-8))
+        # Add a dummy term involving D_real/D_fake so gradients reach them
+        loss_dummy = 1e-8 * (D_real.sum() + D_fake.sum())
+        total_loss = loss_real + loss_fake + loss_dummy
+        total_loss.backward(retain_graph=True)
+
+        # Fetch the captured gradients
+ gradients_real = gradients_real[0] # [B, N, D]
+ gradients_fake = gradients_fake[0] # [B, N, D]
+
+        # Generate the weight maps
+        self.weight_real, self.weight_fake = self.generate_weight_map(gradients_real, gradients_fake)
+
+        # Apply the weights to the adversarial loss
+        loss_co_real = torch.mean(self.weight_real * torch.log(real_scores + 1e-8))
+        loss_co_fake = torch.mean(self.weight_fake * torch.log(1 - fake_scores + 1e-8))
+
+        # Final content-aware adversarial loss
+ loss_co_adv = -(loss_co_real + loss_co_fake)
+
+ return loss_co_adv
+
+class ContentAwareTemporalNorm(nn.Module):
+ def __init__(self, gamma_stride=0.1, kernel_size=21, sigma=5.0):
+ super().__init__()
+        self.gamma_stride = gamma_stride  # controls the overall motion magnitude
+        self.smoother = GaussianBlur(kernel_size, sigma=sigma)  # Gaussian smoothing layer
+
+ def forward(self, weight_map):
+ """
+ 生成内容感知光流
+ Args:
+ weight_map: [B, 1, H, W] 权重图(来自内容感知优化模块)
+ Returns:
+ F_content: [B, 2, H, W] 生成的光流场(x/y方向位移)
+ """
+ B, _, H, W = weight_map.shape
+
+        # 1. Normalize the weight map
+        # keeps the relative strength of regions while bounding the value range
+        weight_norm = F.normalize(weight_map, p=1, dim=(2,3))  # L1 normalization [B,1,H,W]
+
+        # 2. Gaussian noise with the same size as the flow field
+        z = torch.randn(B, 2, H, W, device=weight_map.device)  # [B,2,H,W]
+
+        # 3. Compose the raw flow
+        # expand the weight map to 2 channels (x/y directions share the weights)
+        weight_expanded = weight_norm.expand(-1, 2, -1, -1)  # [B,2,H,W]
+        F_raw = self.gamma_stride * weight_expanded * z  # [B,2,H,W], Eq. 9
+
+        # 4. Smoothing (keeps the flow structurally continuous)
+        # Gaussian blur is applied to each channel independently
+        F_smooth = self.smoother(F_raw)  # [B,2,H,W]
+
+        # 5. Dynamic-range adjustment (optional)
+        # limit the flow magnitude to avoid extreme displacements
+        F_content = torch.tanh(F_smooth)  # squashed to [-1, 1]
+
+ return F_content
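+
+# Shape sketch (illustrative; during training the weight map comes from ContentAwareOptimization):
+#   ctn = ContentAwareTemporalNorm()
+#   w   = torch.rand(1, 1, 256, 256)
+#   f   = ctn(w)                       # -> [1, 2, 256, 256], values in [-1, 1]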
+
+class CTNxModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """配置 CTNx 模型的特定选项"""
+
+ parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss:GAN(G(X))')
+ parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
+ parser.add_argument('--lambda_SB', type=float, default=0.1, help='weight for SB loss')
+ parser.add_argument('--lambda_ctn', type=float, default=1.0, help='weight for content-aware temporal norm')
+
+ parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
+ parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
+ parser.add_argument('--nce_includes_all_negatives_from_minibatch',
+ type=util.str2bool, nargs='?', const=True, default=False,
+ help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
+
+ parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
+ parser.add_argument('--netF_nc', type=int, default=256)
+ parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
+
+ parser.add_argument('--lmda_1', type=float, default=0.1)
+ parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
+ parser.add_argument('--flip_equivariance',
+ type=util.str2bool, nargs='?', const=True, default=False,
+ help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CUT")
+
+ parser.add_argument('--lambda_inc', type=float, default=1.0, help='incremental weight for content-aware optimization')
+ parser.add_argument('--eta_ratio', type=float, default=0.1, help='ratio of content-rich regions')
+
+
+ parser.set_defaults(pool_size=0) # no image pooling
+
+ opt, _ = parser.parse_known_args()
+
+        # Use the SB configuration directly
+ parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
+
+ return parser
+
+ def __init__(self, opt):
+ """初始化 CTNx 模型"""
+ BaseModel.__init__(self, opt)
+
+        # Training losses to print and log
+ self.loss_names = ['G_GAN_1', 'D_real_1', 'D_fake_1', 'G_1', 'NCE_1', 'SB_1',
+ 'G_2']
+ self.visual_names = ['real_A', 'real_A_noisy', 'fake_B', 'real_B']
+ self.atten_layers = [int(i) for i in self.opt.atten_layers.split(',')]
+
+ if self.opt.phase == 'test':
+ self.visual_names = ['real']
+ for NFE in range(self.opt.num_timesteps):
+ fake_name = 'fake_' + str(NFE+1)
+ self.visual_names.append(fake_name)
+ self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
+
+ if opt.nce_idt and self.isTrain:
+ self.loss_names += ['NCE_Y']
+ self.visual_names += ['idt_B']
+
+ if self.isTrain:
+ self.model_names = ['G1', 'F1', 'D1', 'E1',
+ 'G2']
+
+
+ else:
+ self.model_names = ['G1']
+
+        # Create the networks
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
+
+
+ if self.isTrain:
+ self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+ self.netE = networks.define_D(opt.output_nc*4, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+
+ self.resize = tfs.Resize(size=(384,384))
+
+            # Load the pretrained ViT
+ self.netPreViT = timm.create_model("vit_base_patch16_384", pretrained=True).to(self.device)
+
+            # Define the loss functions
+ self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+ self.criterionNCE = []
+ for nce_layer in self.nce_layers:
+ self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
+ self.criterionIdt = torch.nn.L1Loss().to(self.device)
+ self.optimizer_G1 = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizer_D1 = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizer_E1 = torch.optim.Adam(self.netE.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+ self.optimizers = [self.optimizer_G1, self.optimizer_D1, self.optimizer_E1]
+
+            self.cao = ContentAwareOptimization(opt.lambda_inc, opt.eta_ratio)  # content-aware loss module
+            self.ctn = ContentAwareTemporalNorm()  # generates the pseudo optical flow
+
+ def data_dependent_initialize(self, data):
+ """
+ The feature network netF is defined in terms of the shape of the intermediate, extracted
+ features of the encoder portion of netG. Because of this, the weights of netF are
+ initialized at the first feedforward pass with some input images.
+ Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
+ """
+ #bs_per_gpu = data["A"].size(0) // max(len(self.opt.gpu_ids), 1)
+ #self.set_input(data)
+ #self.real_A = self.real_A[:bs_per_gpu]
+ #self.real_B = self.real_B[:bs_per_gpu]
+ #self.forward() # compute fake images: G(A)
+ #if self.opt.isTrain:
+ #
+ # self.compute_G_loss().backward()
+ # self.compute_D_loss().backward()
+ # self.compute_E_loss().backward()
+ # if self.opt.lambda_NCE > 0.0:
+ # self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
+ # self.optimizers.append(self.optimizer_F)
+ pass
+
+ def optimize_parameters(self):
+ # forward
+ self.forward()
+
+ self.netG.train()
+ self.netE.train()
+ self.netD.train()
+
+ # update D
+ self.set_requires_grad(self.netD, True)
+        self.optimizer_D1.zero_grad()
+        self.loss_D = self.compute_D_loss()
+        self.loss_D.backward()
+        self.optimizer_D1.step()
+
+ self.set_requires_grad(self.netE, True)
+        self.optimizer_E1.zero_grad()
+        self.loss_E = self.compute_E_loss()
+        self.loss_E.backward()
+        self.optimizer_E1.step()
+
+ # update G
+ self.set_requires_grad(self.netD, False)
+ self.set_requires_grad(self.netE, False)
+
+        self.optimizer_G1.zero_grad()
+
+        self.loss_G = self.compute_G_loss()
+        self.loss_G.backward()
+        self.optimizer_G1.step()
+
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ Parameters:
+ input (dict): include the data itself and its metadata information.
+ The option 'direction' can be used to swap domain A and domain B.
+ """
+ AtoB = self.opt.direction == 'AtoB'
+ self.real_A0 = input['A0' if AtoB else 'B0'].to(self.device)
+ self.real_A1 = input['A1' if AtoB else 'B1'].to(self.device)
+ self.real_B0 = input['B0' if AtoB else 'A0'].to(self.device)
+ self.real_B1 = input['B1' if AtoB else 'A1'].to(self.device)
+ self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+
+ def tokens_concat(self, origin_tokens, adjacent_size):
+ adj_size = adjacent_size
+ B, token_num, C = origin_tokens.shape[0], origin_tokens.shape[1], origin_tokens.shape[2]
+ S = int(math.sqrt(token_num))
+ if S * S != token_num:
+ print('Error! Not a square!')
+ token_map = origin_tokens.clone().reshape(B,S,S,C)
+ cut_patch_list = []
+ for i in range(0, S, adj_size):
+ for j in range(0, S, adj_size):
+ i_left = i
+            i_right = i + adj_size if i + adj_size <= S else S
+            j_left = j
+            j_right = j + adj_size if j + adj_size <= S else S
+
+ cut_patch = token_map[:, i_left:i_right, j_left: j_right, :]
+ cut_patch= cut_patch.reshape(B,-1,C)
+ cut_patch = torch.mean(cut_patch, dim=1, keepdim=True)
+ cut_patch_list.append(cut_patch)
+
+
+ result = torch.cat(cut_patch_list,dim=1)
+ return result
+
+ def cat_results(self, origin_tokens, adj_size_list):
+ res_list = [origin_tokens]
+ for ad_s in adj_size_list:
+ cat_result = self.tokens_concat(origin_tokens, ad_s)
+ res_list.append(cat_result)
+
+ result = torch.cat(res_list, dim=1)
+
+ return result
+
+
+
+ def forward(self):
+ """执行前向传递以生成输出图像"""
+
+ if self.opt.isTrain:
+ real_A0 = self.resize(self.real_A0)
+ real_A1 = self.resize(self.real_A1)
+ real_B0 = self.resize(self.real_B0)
+ real_B1 = self.resize(self.real_B1)
+            # Extract ViT tokens
+ self.mutil_real_A0_tokens = self.netPreViT(real_A0, self.atten_layers, get_tokens=True)
+ self.mutil_real_A1_tokens = self.netPreViT(real_A1, self.atten_layers, get_tokens=True)
+
+            # Run the SB module once
+
+            # ============ Step 1: initialize the timesteps and the time index ============
+            # Compute `times` and randomly pick time_idx as the current timestep
+ tau = self.opt.tau
+ T = self.opt.num_timesteps
+ incs = np.array([0] + [1/(i+1) for i in range(T-1)])
+ times = np.cumsum(incs)
+ times = times / times[-1]
+ times = 0.5 * times[-1] + 0.5 * times #[0.5,1]
+ times = np.concatenate([np.zeros(1), times])
+ times = torch.tensor(times).float().cuda()
+ self.times = times
+ bs = self.mutil_real_A0_tokens.size(0)
+ time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
+ self.time_idx = time_idx
+
+ with torch.no_grad():
+ self.netG.eval()
+                # ============ Step 2: multi-step stochastic generation for real_A / real_A2 ============
+ for t in range(self.time_idx.int().item() + 1):
+                    # Compute delta and the inter/scale terms used for per-timestep interpolation
+ if t > 0:
+ delta = times[t] - times[t - 1]
+ denom = times[-1] - times[t - 1]
+ inter = (delta / denom).reshape(-1, 1, 1, 1)
+ scale = (delta * (1 - delta / denom)).reshape(-1, 1, 1, 1)
+
+                    # Stochastic noise update for Xt and Xt2
+ Xt = self.mutil_real_A0_tokens if (t == 0) else (1 - inter) * Xt + inter * Xt_1.detach() + \
+ (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
+ time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
+ z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+ self.time = times[time_idx]
+ Xt_1 = self.netG(Xt, self.time, z)
+
+ Xt2 = self.mutil_real_A1_tokens if (t == 0) else (1 - inter) * Xt2 + inter * Xt_12.detach() + \
+ (scale * tau).sqrt() * torch.randn_like(Xt2).to(self.mutil_real_A1_tokens.device)
+ time_idx = (t * torch.ones(size=[self.mutil_real_A1_tokens.shape[0]]).to(self.mutil_real_A1_tokens.device)).long()
+ z = torch.randn(size=[self.mutil_real_A1_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+ Xt_12 = self.netG(Xt2, self.time, z)
+
+                # Keep the denoised intermediate results (real_A_noisy, etc.) for the next step
+                self.real_A_noisy = Xt.detach()
+                self.real_A_noisy2 = Xt2.detach()
+                # Keep the noise map
+ self.noisy_map = self.real_A_noisy - self.real_A
+
+            # ============ Step 3: assemble the inputs and run the network ============
+ bs = self.mutil_real_A0_tokens.size(0)
+ z_in = torch.randn(size=[2 * bs, 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+ z_in2 = torch.randn(size=[bs, 4 * self.opt.ngf]).to(self.mutil_real_A1_tokens.device)
+            # Concatenate real_A and real_B (when nce_idt=True) and treat real_A_noisy and XtB the same way
+ self.real = self.mutil_real_A0_tokens
+ self.realt = self.real_A_noisy
+
+ if self.opt.flip_equivariance:
+ self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
+ if self.flipped_for_equivariance:
+ self.real = torch.flip(self.real, [3])
+ self.realt = torch.flip(self.realt, [3])
+
+            # Use netG to produce the final fake_B, fake_B2, etc.
+ self.fake_B = self.netG(self.realt, self.time, z_in)
+ self.fake_B2 = self.netG(self.real, self.time, z_in2)
+
+ self.fake_B = self.resize(self.fake_B)
+ self.fake_B2 = self.resize(self.fake_B2)
+
+ self.fake_B0 = self.fake_B
+ self.fake_B1 = self.fake_B2
+
+            # Extract ViT tokens
+ self.mutil_fake_B0_tokens = self.netPreViT(self.fake_B, self.atten_layers, get_tokens=True)
+ self.mutil_fake_B1_tokens = self.netPreViT(self.fake_B2, self.atten_layers, get_tokens=True)
+
+        # ============ Step 4: multi-step sampling at test time ============
+ if self.opt.phase == 'test':
+ tau = self.opt.tau
+ T = self.opt.num_timesteps
+ incs = np.array([0] + [1/(i+1) for i in range(T-1)])
+ times = np.cumsum(incs)
+ times = times / times[-1]
+ times = 0.5 * times[-1] + 0.5 * times
+ times = np.concatenate([np.zeros(1),times])
+ times = torch.tensor(times).float().cuda()
+ self.times = times
+ bs = self.real.size(0)
+ time_idx = (torch.randint(T, size=[1]).cuda() * torch.ones(size=[1]).cuda()).long()
+ self.time_idx = time_idx
+ visuals = []
+ with torch.no_grad():
+ self.netG.eval()
+ for t in range(self.opt.num_timesteps):
+
+ if t > 0:
+ delta = times[t] - times[t-1]
+ denom = times[-1] - times[t-1]
+ inter = (delta / denom).reshape(-1,1,1,1)
+ scale = (delta * (1 - delta / denom)).reshape(-1,1,1,1)
+ Xt = self.mutil_real_A0_tokens if (t == 0) else (1-inter) * Xt + inter * Xt_1.detach() + (scale * tau).sqrt() * torch.randn_like(Xt).to(self.mutil_real_A0_tokens.device)
+ time_idx = (t * torch.ones(size=[self.mutil_real_A0_tokens.shape[0]]).to(self.mutil_real_A0_tokens.device)).long()
+ time = times[time_idx]
+ z = torch.randn(size=[self.mutil_real_A0_tokens.shape[0], 4 * self.opt.ngf]).to(self.mutil_real_A0_tokens.device)
+ Xt_1 = self.netG(Xt, time_idx, z)
+
+ setattr(self, "fake_"+str(t+1), Xt_1)
+
+ if self.opt.phase == 'train':
+            # Gradient w.r.t. the real image
+            real_gradient = torch.autograd.grad(self.real_B.sum(), self.real_B, create_graph=True)[0]
+            # Gradient w.r.t. the generated image
+            fake_gradient = torch.autograd.grad(self.fake_B.sum(), self.fake_B, create_graph=True)[0]
+            # Content-aware weight maps from the gradients
+            self.weight_real, self.weight_fake = self.cao.generate_weight_map(real_gradient, fake_gradient)
+
+            # CTN pseudo optical flow for the generated image
+            self.f_content = self.ctn(self.weight_fake)
+
+            # Add the noise map back onto the generated image
+            self.fake_B_2 = self.fake_B + self.noisy_map
+
+            # Warped image
+            warped_fake_B = warp(self.fake_B, self.f_content)
+
+            # Second pass through the generator
+            self.fake_B_2 = self.netG(warped_fake_B, self.time, z_in)
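+
+            # Recap of the CTN branch above: the gradients computed above feed
+            # self.cao.generate_weight_map to get content-aware weight maps, self.ctn
+            # turns weight_fake into a pseudo optical flow, fake_B is warped with it
+            # and passed through the generator again; compute_G_loss later compares
+            # fake_B_2 against the warped fake_B when lambda_ctn > 0.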
+
+ def compute_D_loss(self):
+ """计算判别器的 GAN 损失"""
+
+        fake = self.fake_B.detach()
+ pred_fake = self.netD(fake, self.time)
+ self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
+
+ self.pred_real = self.netD(self.real_B0, self.time)
+ loss_D_real = self.criterionGAN(self.pred_real, True)
+ self.loss_D_real = loss_D_real.mean()
+
+ self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+ return self.loss_D
+
+ def compute_E_loss(self):
+ """计算判别器 E 的损失"""
+
+ XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B.detach()], dim=1)
+ XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2.detach()], dim=1)
+ temp = torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0).mean()
+ self.loss_E = -self.netE(XtXt_1, self.time, XtXt_1).mean() + temp + temp**2
+
+ return self.loss_E
+
+ def compute_G_loss(self):
+ """计算生成器的 GAN 损失"""
+
+ bs = self.mutil_real_A0_tokens.size(0)
+ tau = self.opt.tau
+
+ fake = self.fake_B
+ std = torch.rand(size=[1]).item() * self.opt.std
+
+ if self.opt.lambda_GAN > 0.0:
+ pred_fake = self.netD(fake, self.time)
+ self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
+ else:
+ self.loss_G_GAN = 0.0
+ self.loss_SB = 0
+ if self.opt.lambda_SB > 0.0:
+ XtXt_1 = torch.cat([self.real_A_noisy, self.fake_B], dim=1)
+ XtXt_2 = torch.cat([self.real_A_noisy2, self.fake_B2], dim=1)
+
+ bs = self.opt.batch_size
+
+ # eq.9
+ ET_XY = self.netE(XtXt_1, self.time, XtXt_1).mean() - torch.logsumexp(self.netE(XtXt_1, self.time, XtXt_2).reshape(-1), dim=0)
+ self.loss_SB = -(self.opt.num_timesteps - self.time[0]) / self.opt.num_timesteps * self.opt.tau * ET_XY
+ self.loss_SB += self.opt.tau * torch.mean((self.real_A_noisy - self.fake_B) ** 2)
+
+ if self.opt.lambda_global > 0.0:
+ loss_global = self.calculate_similarity(self.mutil_real_A0_tokens, self.mutil_fake_B0_tokens) + self.calculate_similarity(self.mutil_real_A1_tokens, self.mutil_fake_B1_tokens)
+ loss_global *= 0.5
+ else:
+ loss_global = 0.0
+
+        if self.opt.lambda_ctn > 0.0:
+            warped_fake_B = warp(self.fake_B, self.f_content)  # use the updated self.f_content
+            self.l2_loss = F.mse_loss(self.fake_B_2, warped_fake_B)  # CTN consistency loss
+        else:
+            self.l2_loss = 0.0
+
+ self.loss_G = self.loss_G_GAN + self.opt.lambda_SB * self.loss_SB + self.opt.lambda_ctn * self.l2_loss + loss_global * self.opt.lambda_global
+ return self.loss_G
+
+ def calculate_attention_loss(self):
+ n_layers = len(self.atten_layers)
+ mutil_real_A0_tokens = self.mutil_real_A0_tokens
+ mutil_real_A1_tokens = self.mutil_real_A1_tokens
+ mutil_fake_B0_tokens = self.mutil_fake_B0_tokens
+ mutil_fake_B1_tokens = self.mutil_fake_B1_tokens
+
+
+ if self.opt.lambda_global > 0.0:
+ loss_global = self.calculate_similarity(mutil_real_A0_tokens, mutil_fake_B0_tokens) + self.calculate_similarity(mutil_real_A1_tokens, mutil_fake_B1_tokens)
+ loss_global *= 0.5
+
+ else:
+ loss_global = 0.0
+
+ if self.opt.lambda_spatial > 0.0:
+ loss_spatial = 0.0
+ local_nums = self.opt.local_nums
+ tokens_cnt = 576
+ local_id = np.random.permutation(tokens_cnt)
+ local_id = local_id[:int(min(local_nums, tokens_cnt))]
+
+ mutil_real_A0_local_tokens = self.netPreViT(self.resize(self.real_A0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
+ mutil_real_A1_local_tokens = self.netPreViT(self.resize(self.real_A1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
+
+ mutil_fake_B0_local_tokens = self.netPreViT(self.resize(self.fake_B0), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
+ mutil_fake_B1_local_tokens = self.netPreViT(self.resize(self.fake_B1), self.atten_layers, get_tokens=True, local_id=local_id, side_length=self.opt.side_length)
+
+ loss_spatial = self.calculate_similarity(mutil_real_A0_local_tokens, mutil_fake_B0_local_tokens) + self.calculate_similarity(mutil_real_A1_local_tokens, mutil_fake_B1_local_tokens)
+ loss_spatial *= 0.5
+
+ else:
+ loss_spatial = 0.0
+
+ return loss_global * self.opt.lambda_global, loss_spatial * self.opt.lambda_spatial
+
+ def calculate_similarity(self, mutil_src_tokens, mutil_tgt_tokens):
+ loss = 0.0
+ n_layers = len(self.atten_layers)
+
+ for src_tokens, tgt_tokens in zip(mutil_src_tokens, mutil_tgt_tokens):
+ src_tgt = src_tokens.bmm(tgt_tokens.permute(0,2,1))
+ tgt_src = tgt_tokens.bmm(src_tokens.permute(0,2,1))
+ cos_dis_global = F.cosine_similarity(src_tgt, tgt_src, dim=-1)
+ loss += self.criterionL1(torch.ones_like(cos_dis_global), cos_dis_global).mean()
+
+ loss = loss / n_layers
+ return loss
+
+
+
\ No newline at end of file
diff --git a/models/stylegan_networks.py b/models/stylegan_networks.py
new file mode 100644
index 0000000..a3c625d
--- /dev/null
+++ b/models/stylegan_networks.py
@@ -0,0 +1,914 @@
+"""
+The network architectures are based on the PyTorch implementation of StyleGAN2Encoder.
+Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
+Original StyleGAN2 paper: https://github.com/NVlabs/stylegan2
+We use this network architecture for our single-image training setting.
+"""
+
+import math
+import numpy as np
+import random
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ return F.leaky_relu(input + bias, negative_slope) * scale
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ # print("FusedLeakyReLU: ", input.abs().mean())
+ out = fused_leaky_relu(input, self.bias,
+ self.negative_slope,
+ self.scale)
+ # print("FusedLeakyReLU: ", out.abs().mean())
+ return out
+
+
+def upfirdn2d_native(
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
+):
+ _, minor, in_h, in_w = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
+
+ out = F.pad(
+ out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
+ )
+ out = out[
+ :,
+ :,
+ max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
+ ]
+
+ # out = out.permute(0, 3, 1, 2)
+ out = out.reshape(
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
+ )
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ # out = out.permute(0, 2, 3, 1)
+
+ return out[:, :, ::down_y, ::down_x]
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
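+
+# Note: upfirdn2d = upsample -> FIR filter -> downsample, StyleGAN2's resampling
+# primitive; this pure-PyTorch fallback trades speed for portability compared with
+# the custom CUDA kernel in the original repo.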
+
+
+class PixelNorm(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, input):
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+
+ if len(k.shape) == 1:
+ k = k[None, :] * k[:, None]
+
+ k /= k.sum()
+
+ return k
+
+
+class Upsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel) * (factor ** 2)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
+
+ return out
+
+
+class Downsample(nn.Module):
+ def __init__(self, kernel, factor=2):
+ super().__init__()
+
+ self.factor = factor
+ kernel = make_kernel(kernel)
+ self.register_buffer('kernel', kernel)
+
+ p = kernel.shape[0] - factor
+
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.pad = (pad0, pad1)
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
+
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
+
+ return out
+
+
+class EqualConv2d(nn.Module):
+ def __init__(
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
+ )
+ self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+
+ else:
+ self.bias = None
+
+ def forward(self, input):
+ # print("Before EqualConv2d: ", input.abs().mean())
+ out = F.conv2d(
+ input,
+ self.weight * self.scale,
+ bias=self.bias,
+ stride=self.stride,
+ padding=self.padding,
+ )
+ # print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
+ ):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+
+ else:
+ out = F.linear(
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
+ )
+
+ return out
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
+ )
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
+
+ return out * math.sqrt(2)
+
+
+class ModulatedConv2d(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ demodulate=True,
+ upsample=False,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ ):
+ super().__init__()
+
+ self.eps = 1e-8
+ self.kernel_size = kernel_size
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+ self.upsample = upsample
+ self.downsample = downsample
+
+ if upsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
+ pad0 = (p + 1) // 2 + factor - 1
+ pad1 = p // 2 + 1
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+ fan_in = in_channel * kernel_size ** 2
+ self.scale = math.sqrt(1) / math.sqrt(fan_in)
+ self.padding = kernel_size // 2
+
+ self.weight = nn.Parameter(
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+ )
+
+ if style_dim is not None and style_dim > 0:
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+ self.demodulate = demodulate
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+ f'upsample={self.upsample}, downsample={self.downsample})'
+ )
+
+ def forward(self, input, style):
+ batch, in_channel, height, width = input.shape
+
+ if style is not None:
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+ else:
+ style = torch.ones(batch, 1, in_channel, 1, 1).cuda()
+ weight = self.scale * self.weight * style
+
+ if self.demodulate:
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+ weight = weight.view(
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+
+ if self.upsample:
+ input = input.view(1, batch * in_channel, height, width)
+ weight = weight.view(
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+ )
+ weight = weight.transpose(1, 2).reshape(
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+ )
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+ out = self.blur(out)
+
+ elif self.downsample:
+ input = self.blur(input)
+ _, _, height, width = input.shape
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ else:
+ input = input.view(1, batch * in_channel, height, width)
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+ _, _, height, width = out.shape
+ out = out.view(batch, self.out_channel, height, width)
+
+ return out
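+
+    # Note: the per-sample modulated weights are folded into the batch dimension and
+    # applied with a grouped convolution (groups=batch), the standard StyleGAN2 trick
+    # for weight (de)modulation without an explicit per-sample loop.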
+
+
+class NoiseInjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.zeros(1))
+
+ def forward(self, image, noise=None):
+ if noise is None:
+ batch, _, height, width = image.shape
+ noise = image.new_empty(batch, 1, height, width).normal_()
+
+ return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+ def __init__(self, channel, size=4):
+ super().__init__()
+
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+ def forward(self, input):
+ batch = input.shape[0]
+ out = self.input.repeat(batch, 1, 1, 1)
+
+ return out
+
+
+class StyledConv(nn.Module):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim=None,
+ upsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ demodulate=True,
+ inject_noise=True,
+ ):
+ super().__init__()
+
+ self.inject_noise = inject_noise
+ self.conv = ModulatedConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ style_dim,
+ upsample=upsample,
+ blur_kernel=blur_kernel,
+ demodulate=demodulate,
+ )
+
+ self.noise = NoiseInjection()
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+ # self.activate = ScaledLeakyReLU(0.2)
+ self.activate = FusedLeakyReLU(out_channel)
+
+ def forward(self, input, style=None, noise=None):
+ out = self.conv(input, style)
+ if self.inject_noise:
+ out = self.noise(out, noise=noise)
+ # out = out + self.bias
+ out = self.activate(out)
+
+ return out
+
+
+class ToRGB(nn.Module):
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ if upsample:
+ self.upsample = Upsample(blur_kernel)
+
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+ def forward(self, input, style, skip=None):
+ out = self.conv(input, style)
+ out = out + self.bias
+
+ if skip is not None:
+ skip = self.upsample(skip)
+
+ out = out + skip
+
+ return out
+
+
+class Generator(nn.Module):
+ def __init__(
+ self,
+ size,
+ style_dim,
+ n_mlp,
+ channel_multiplier=2,
+ blur_kernel=[1, 3, 3, 1],
+ lr_mlp=0.01,
+ ):
+ super().__init__()
+
+ self.size = size
+
+ self.style_dim = style_dim
+
+ layers = [PixelNorm()]
+
+ for i in range(n_mlp):
+ layers.append(
+ EqualLinear(
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
+
+ self.style = nn.Sequential(*layers)
+
+ self.channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256 * channel_multiplier,
+ 128: 128 * channel_multiplier,
+ 256: 64 * channel_multiplier,
+ 512: 32 * channel_multiplier,
+ 1024: 16 * channel_multiplier,
+ }
+
+ self.input = ConstantInput(self.channels[4])
+ self.conv1 = StyledConv(
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+ )
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+ self.log_size = int(math.log(size, 2))
+ self.num_layers = (self.log_size - 2) * 2 + 1
+
+ self.convs = nn.ModuleList()
+ self.upsamples = nn.ModuleList()
+ self.to_rgbs = nn.ModuleList()
+ self.noises = nn.Module()
+
+ in_channel = self.channels[4]
+
+ for layer_idx in range(self.num_layers):
+ res = (layer_idx + 5) // 2
+ shape = [1, 1, 2 ** res, 2 ** res]
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+ for i in range(3, self.log_size + 1):
+ out_channel = self.channels[2 ** i]
+
+ self.convs.append(
+ StyledConv(
+ in_channel,
+ out_channel,
+ 3,
+ style_dim,
+ upsample=True,
+ blur_kernel=blur_kernel,
+ )
+ )
+
+ self.convs.append(
+ StyledConv(
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+ )
+ )
+
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+ in_channel = out_channel
+
+ self.n_latent = self.log_size * 2 - 2
+
+ def make_noise(self):
+ device = self.input.input.device
+
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+ for i in range(3, self.log_size + 1):
+ for _ in range(2):
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+ return noises
+
+ def mean_latent(self, n_latent):
+ latent_in = torch.randn(
+ n_latent, self.style_dim, device=self.input.input.device
+ )
+ latent = self.style(latent_in).mean(0, keepdim=True)
+
+ return latent
+
+ def get_latent(self, input):
+ return self.style(input)
+
+ def forward(
+ self,
+ styles,
+ return_latents=False,
+ inject_index=None,
+ truncation=1,
+ truncation_latent=None,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ ):
+ if not input_is_latent:
+ styles = [self.style(s) for s in styles]
+
+ if noise is None:
+ if randomize_noise:
+ noise = [None] * self.num_layers
+ else:
+ noise = [
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+ ]
+
+ if truncation < 1:
+ style_t = []
+
+ for style in styles:
+ style_t.append(
+ truncation_latent + truncation * (style - truncation_latent)
+ )
+
+ styles = style_t
+
+ if len(styles) < 2:
+ inject_index = self.n_latent
+
+ if len(styles[0].shape) < 3:
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+
+ else:
+ latent = styles[0]
+
+ else:
+ if inject_index is None:
+ inject_index = random.randint(1, self.n_latent - 1)
+
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+ latent = torch.cat([latent, latent2], 1)
+
+ out = self.input(latent)
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+ skip = self.to_rgb1(out, latent[:, 1])
+
+ i = 1
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
+ out = conv1(out, latent[:, i], noise=noise1)
+ out = conv2(out, latent[:, i + 1], noise=noise2)
+ skip = to_rgb(out, latent[:, i + 2], skip)
+
+ i += 2
+
+ image = skip
+
+ if return_latents:
+ return image, latent
+
+ else:
+ return image, None
+
+
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(
+ EqualConv2d(
+ in_channel,
+ out_channel,
+ kernel_size,
+ padding=self.padding,
+ stride=stride,
+ bias=bias and not activate,
+ )
+ )
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
+ super().__init__()
+
+ self.skip_gain = skip_gain
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
+
+ if in_channel != out_channel or downsample:
+ self.skip = ConvLayer(
+ in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
+ )
+ else:
+ self.skip = nn.Identity()
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
+
+ return out
+
+
+class StyleGAN2Discriminator(nn.Module):
+ def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
+ super().__init__()
+ self.opt = opt
+ self.stddev_group = 16
+ if size is None:
+ size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
+ if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
+ size = 2 ** int(np.log2(self.opt.D_patch_size))
+
+ blur_kernel = [1, 3, 3, 1]
+ channel_multiplier = ndf / 64
+ channels = {
+ 4: min(384, int(4096 * channel_multiplier)),
+ 8: min(384, int(2048 * channel_multiplier)),
+ 16: min(384, int(1024 * channel_multiplier)),
+ 32: min(384, int(512 * channel_multiplier)),
+ 64: int(256 * channel_multiplier),
+ 128: int(128 * channel_multiplier),
+ 256: int(64 * channel_multiplier),
+ 512: int(32 * channel_multiplier),
+ 1024: int(16 * channel_multiplier),
+ }
+
+ convs = [ConvLayer(3, channels[size], 1)]
+
+ log_size = int(math.log(size, 2))
+
+ in_channel = channels[size]
+
+ if "smallpatch" in self.opt.netD:
+ final_res_log2 = 4
+ elif "patch" in self.opt.netD:
+ final_res_log2 = 3
+ else:
+ final_res_log2 = 2
+
+ for i in range(log_size, final_res_log2, -1):
+ out_channel = channels[2 ** (i - 1)]
+
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+ in_channel = out_channel
+
+ self.convs = nn.Sequential(*convs)
+
+ if False and "tile" in self.opt.netD:
+ in_channel += 1
+ self.final_conv = ConvLayer(in_channel, channels[4], 3)
+ if "patch" in self.opt.netD:
+ self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
+ else:
+ self.final_linear = nn.Sequential(
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+ EqualLinear(channels[4], 1),
+ )
+
+ def forward(self, input, get_minibatch_features=False):
+ if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
+ h, w = input.size(2), input.size(3)
+ y = torch.randint(h - self.opt.D_patch_size, ())
+ x = torch.randint(w - self.opt.D_patch_size, ())
+ input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
+ out = input
+ for i, conv in enumerate(self.convs):
+ out = conv(out)
+ # print(i, out.abs().mean())
+ # out = self.convs(input)
+
+ batch, channel, height, width = out.shape
+
+ if False and "tile" in self.opt.netD:
+ group = min(batch, self.stddev_group)
+ stddev = out.view(
+ group, -1, 1, channel // 1, height, width
+ )
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+ stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
+ stddev = stddev.repeat(group, 1, height, width)
+ out = torch.cat([out, stddev], 1)
+
+ out = self.final_conv(out)
+ # print(out.abs().mean())
+
+ if "patch" not in self.opt.netD:
+ out = out.view(batch, -1)
+ out = self.final_linear(out)
+
+ return out
+
+
+class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
+ def forward(self, input):
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
+ size = self.opt.D_patch_size
+ Y = H // size
+ X = W // size
+ input = input.view(B, C, Y, size, X, size)
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
+ return super().forward(input)
+
+
+class StyleGAN2Encoder(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
+ super().__init__()
+ assert opt is not None
+ self.opt = opt
+ channel_multiplier = ngf / 32
+ channels = {
+ 4: min(512, int(round(4096 * channel_multiplier))),
+ 8: min(512, int(round(2048 * channel_multiplier))),
+ 16: min(512, int(round(1024 * channel_multiplier))),
+ 32: min(512, int(round(512 * channel_multiplier))),
+ 64: int(round(256 * channel_multiplier)),
+ 128: int(round(128 * channel_multiplier)),
+ 256: int(round(64 * channel_multiplier)),
+ 512: int(round(32 * channel_multiplier)),
+ 1024: int(round(16 * channel_multiplier)),
+ }
+
+ blur_kernel = [1, 3, 3, 1]
+
+ cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
+ convs = [nn.Identity(),
+ ConvLayer(3, channels[cur_res], 1)]
+
+ num_downsampling = self.opt.stylegan2_G_num_downsampling
+ for i in range(num_downsampling):
+ in_channel = channels[cur_res]
+ out_channel = channels[cur_res // 2]
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
+ cur_res = cur_res // 2
+
+ for i in range(n_blocks // 2):
+ n_channel = channels[cur_res]
+ convs.append(ResBlock(n_channel, n_channel, downsample=False))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, input, layers=[], get_features=False):
+ feat = input
+ feats = []
+ if -1 in layers:
+ layers.append(len(self.convs) - 1)
+ for layer_id, layer in enumerate(self.convs):
+ feat = layer(feat)
+ # print(layer_id, " features ", feat.abs().mean())
+ if layer_id in layers:
+ feats.append(feat)
+
+ if get_features:
+ return feat, feats
+ else:
+ return feat
+
+
+class StyleGAN2Decoder(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
+ super().__init__()
+ assert opt is not None
+ self.opt = opt
+
+ blur_kernel = [1, 3, 3, 1]
+
+ channel_multiplier = ngf / 32
+ channels = {
+ 4: min(512, int(round(4096 * channel_multiplier))),
+ 8: min(512, int(round(2048 * channel_multiplier))),
+ 16: min(512, int(round(1024 * channel_multiplier))),
+ 32: min(512, int(round(512 * channel_multiplier))),
+ 64: int(round(256 * channel_multiplier)),
+ 128: int(round(128 * channel_multiplier)),
+ 256: int(round(64 * channel_multiplier)),
+ 512: int(round(32 * channel_multiplier)),
+ 1024: int(round(16 * channel_multiplier)),
+ }
+
+ num_downsampling = self.opt.stylegan2_G_num_downsampling
+ cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
+ convs = []
+
+ for i in range(n_blocks // 2):
+ n_channel = channels[cur_res]
+ convs.append(ResBlock(n_channel, n_channel, downsample=False))
+
+ for i in range(num_downsampling):
+ in_channel = channels[cur_res]
+ out_channel = channels[cur_res * 2]
+ inject_noise = "small" not in self.opt.netG
+ convs.append(
+ StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
+ )
+ cur_res = cur_res * 2
+
+ convs.append(ConvLayer(channels[cur_res], 3, 1))
+
+ self.convs = nn.Sequential(*convs)
+
+ def forward(self, input):
+ return self.convs(input)
+
+
+class StyleGAN2Generator(nn.Module):
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
+ super().__init__()
+ self.opt = opt
+ self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
+ self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
+
+ def forward(self, input, layers=[], encode_only=False):
+ feat, feats = self.encoder(input, layers, True)
+ if encode_only:
+ return feats
+ else:
+ fake = self.decoder(feat)
+
+ if len(layers) > 0:
+ return fake, feats
+ else:
+ return fake
diff --git a/models/template_model.py b/models/template_model.py
new file mode 100644
index 0000000..68cdaf6
--- /dev/null
+++ b/models/template_model.py
@@ -0,0 +1,99 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be <model>_model.py
+The class name should be <Model>Model
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+    min_<netG> ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+    <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
+    <__init__>: Initialize this model class.
+    <set_input>: Unpack input data and perform data pre-processing.
+    <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
+    <optimize_parameters>: Update network weights; it will be called in every training iteration.
+"""
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class TemplateModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new model-specific options and rewrite default values for existing options.
+
+ Parameters:
+ parser -- the option parser
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+ if is_train:
+ parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+ self.loss_names = ['loss_G']
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+ self.visual_names = ['data_A', 'data_B', 'output']
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+ self.model_names = ['G']
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+ if self.isTrain: # only defined during training time
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+ self.criterionLoss = torch.nn.L1Loss()
+ # define and initialize optimizers. You can define one optimizer for each network.
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = [self.optimizer]
+
+        # Our program will automatically call <setup> to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ AtoB = self.opt.direction == 'AtoB' # use to swap data_A and data_B
+ self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
+ self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
+ self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
+
+ def forward(self):
+ """Run forward pass. This will be called by both functions and ."""
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
+
+ def backward(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
+ # calculate loss given the input and intermediate results
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
+
+ def optimize_parameters(self):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward() # first call forward to calculate intermediate results
+ self.optimizer.zero_grad() # clear network G's existing gradients
+ self.backward() # calculate gradients for network G
+ self.optimizer.step() # update gradients for network G
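+
+# Usage sketch (assumption: the standard train.py entry point and the --dataroot/--name/--model
+# flags inherited from the CycleGAN/pix2pix codebase this template comes from):
+#   python train.py --dataroot ./datasets/example --name template_run --model template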
diff --git a/models/util/__pycache__/pos_embed.cpython-36.pyc b/models/util/__pycache__/pos_embed.cpython-36.pyc
new file mode 100644
index 0000000..b5141c6
Binary files /dev/null and b/models/util/__pycache__/pos_embed.cpython-36.pyc differ
diff --git a/models/util/crop.py b/models/util/crop.py
new file mode 100644
index 0000000..fcb2612
--- /dev/null
+++ b/models/util/crop.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+
+from torchvision import transforms
+from torchvision.transforms import functional as F
+
+
+class RandomResizedCrop(transforms.RandomResizedCrop):
+ """
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
+ This may lead to results different with torchvision's version.
+ Following BYOL's TF code:
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
+ """
+ @staticmethod
+ def get_params(img, scale, ratio):
+ width, height = F._get_image_size(img)
+ area = height * width
+
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
+ log_ratio = torch.log(torch.tensor(ratio))
+ aspect_ratio = torch.exp(
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
+ ).item()
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ w = min(w, width)
+ h = min(h, height)
+
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
+
+ return i, j, h, w
\ No newline at end of file
diff --git a/models/util/datasets.py b/models/util/datasets.py
new file mode 100644
index 0000000..0dde1f4
--- /dev/null
+++ b/models/util/datasets.py
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+
+import os
+import PIL
+
+from torchvision import datasets, transforms
+
+from timm.data import create_transform
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+
+def build_dataset(is_train, args):
+ transform = build_transform(is_train, args)
+
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
+ dataset = datasets.ImageFolder(root, transform=transform)
+
+ print(dataset)
+
+ return dataset
+
+
+def build_transform(is_train, args):
+ mean = IMAGENET_DEFAULT_MEAN
+ std = IMAGENET_DEFAULT_STD
+ # train transform
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=args.input_size,
+ is_training=True,
+ color_jitter=args.color_jitter,
+ auto_augment=args.aa,
+ interpolation='bicubic',
+ re_prob=args.reprob,
+ re_mode=args.remode,
+ re_count=args.recount,
+ mean=mean,
+ std=std,
+ )
+ return transform
+
+ # eval transform
+ t = []
+ if args.input_size <= 224:
+ crop_pct = 224 / 256
+ else:
+ crop_pct = 1.0
+ size = int(args.input_size / crop_pct)
+ t.append(
+ transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(args.input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(mean, std))
+ return transforms.Compose(t)
diff --git a/models/util/lars.py b/models/util/lars.py
new file mode 100644
index 0000000..509c5f6
--- /dev/null
+++ b/models/util/lars.py
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# LARS optimizer, implementation from MoCo v3:
+# https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+
+import torch
+
+
+class LARS(torch.optim.Optimizer):
+ """
+ LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
+ """
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self):
+ for g in self.param_groups:
+ for p in g['params']:
+ dp = p.grad
+
+ if dp is None:
+ continue
+
+ if p.ndim > 1: # if not normalization gamma/beta or bias
+ dp = dp.add(p, alpha=g['weight_decay'])
+ param_norm = torch.norm(p)
+ update_norm = torch.norm(dp)
+ one = torch.ones_like(param_norm)
+ q = torch.where(param_norm > 0.,
+ torch.where(update_norm > 0,
+ (g['trust_coefficient'] * param_norm / update_norm), one),
+ one)
+ dp = dp.mul(q)
+
+ param_state = self.state[p]
+ if 'mu' not in param_state:
+ param_state['mu'] = torch.zeros_like(p)
+ mu = param_state['mu']
+ mu.mul_(g['momentum']).add_(dp)
+ p.add_(mu, alpha=-g['lr'])
\ No newline at end of file
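
A minimal sketch of a single LARS update on a toy model. Because the optimizer itself skips rate scaling and weight decay for parameters with `ndim <= 1`, biases and normalization parameters can share one parameter group; the learning rate below is only illustrative.

```python
# Illustrative sketch only: toy model and hyperparameters are assumptions.
import torch

from models.util.lars import LARS  # path taken from this diff

model = torch.nn.Linear(128, 10)
optimizer = LARS(model.parameters(), lr=1.6, weight_decay=1e-6, momentum=0.9)

loss = model(torch.randn(4, 128)).sum()
loss.backward()
optimizer.step()       # trust-ratio scaling applied only to the 2D weight
optimizer.zero_grad()
```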
diff --git a/models/util/lr_decay.py b/models/util/lr_decay.py
new file mode 100644
index 0000000..7fa11f1
--- /dev/null
+++ b/models/util/lr_decay.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# ELECTRA https://github.com/google-research/electra
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import json
+
+
+def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
+ """
+ Parameter groups for layer-wise lr decay
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+ """
+ param_group_names = {}
+ param_groups = {}
+
+ num_layers = len(model.blocks) + 1
+
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+
+ # no decay: all 1D parameters and model specific ones
+ if p.ndim == 1 or n in no_weight_decay_list:
+ g_decay = "no_decay"
+ this_decay = 0.
+ else:
+ g_decay = "decay"
+ this_decay = weight_decay
+
+ layer_id = get_layer_id_for_vit(n, num_layers)
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+ if group_name not in param_group_names:
+ this_scale = layer_scales[layer_id]
+
+ param_group_names[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+ param_groups[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+
+ param_group_names[group_name]["params"].append(n)
+ param_groups[group_name]["params"].append(p)
+
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+ return list(param_groups.values())
+
+
+def get_layer_id_for_vit(name, num_layers):
+ """
+ Assign a parameter with its layer id
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if name in ['cls_token', 'pos_embed']:
+ return 0
+ elif name.startswith('patch_embed'):
+ return 0
+ elif name.startswith('blocks'):
+ return int(name.split('.')[1]) + 1
+ else:
+ return num_layers
\ No newline at end of file
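
A sketch of wiring the decayed parameter groups into AdamW, assuming a timm-style ViT that exposes `.blocks` and a `no_weight_decay()` set (which is what `param_groups_lrd` expects); the `lr_scale` stored on each group is later consumed by `adjust_learning_rate` in `lr_sched.py`.

```python
# Illustrative sketch only: model name and hyperparameters are assumptions.
import timm
import torch

import models.util.lr_decay as lrd  # path taken from this diff

vit = timm.create_model('vit_base_patch16_224', pretrained=False)
param_groups = lrd.param_groups_lrd(
    vit,
    weight_decay=0.05,
    no_weight_decay_list=vit.no_weight_decay(),
    layer_decay=0.75,
)
# Extra keys such as lr_scale are kept on the groups and ignored by AdamW itself.
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```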
diff --git a/models/util/lr_sched.py b/models/util/lr_sched.py
new file mode 100644
index 0000000..4cb682b
--- /dev/null
+++ b/models/util/lr_sched.py
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+def adjust_learning_rate(optimizer, epoch, args):
+ """Decay the learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = args.lr * epoch / args.warmup_epochs
+ else:
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
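
A short sketch of the schedule with illustrative hyperparameters: calling the function with a fractional epoch (`epoch + i / len(loader)`) gives a linear warmup to `lr` over `warmup_epochs`, then a half-cycle cosine decay down to `min_lr`.

```python
# Illustrative sketch only: values and the dummy optimizer are assumptions.
from argparse import Namespace

import torch

from models.util.lr_sched import adjust_learning_rate  # path taken from this diff

args = Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=100)
optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=args.lr)

for frac_epoch in (0.0, 2.5, 5.0, 50.0, 99.0):
    print(frac_epoch, adjust_learning_rate(optimizer, frac_epoch, args))
```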
diff --git a/models/util/misc.py b/models/util/misc.py
new file mode 100644
index 0000000..ad9a786
--- /dev/null
+++ b/models/util/misc.py
@@ -0,0 +1,340 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from math import inf  # torch._six was removed in newer PyTorch; math.inf is what it re-exported
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ if args.dist_on_itp:
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+ args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+ os.environ['LOCAL_RANK'] = str(args.gpu)
+ os.environ['RANK'] = str(args.rank)
+ os.environ['WORLD_SIZE'] = str(args.world_size)
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}, gpu {}'.format(
+ args.rank, args.dist_url, args.gpu), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
+ output_dir = Path(args.output_dir)
+ epoch_name = str(epoch)
+ if loss_scaler is not None:
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
+ for checkpoint_path in checkpoint_paths:
+ to_save = {
+ 'model': model_without_ddp.state_dict(),
+ 'optimizer': optimizer.state_dict(),
+ 'epoch': epoch,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }
+
+ save_on_master(to_save, checkpoint_path)
+ else:
+ client_state = {'epoch': epoch}
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
+
+
+def load_model(args, model_without_ddp, optimizer, loss_scaler):
+ if args.resume:
+ if args.resume.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.resume, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.resume, map_location='cpu')
+ model_without_ddp.load_state_dict(checkpoint['model'])
+ print("Resume checkpoint %s" % args.resume)
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ args.start_epoch = checkpoint['epoch'] + 1
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ print("With optim & sched!")
+
+
+def all_reduce_mean(x):
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
\ No newline at end of file
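
A sketch of one mixed-precision training step with the scaler above, assuming a CUDA device is available; the model, data, and clip value are placeholders.

```python
# Illustrative sketch only: backward, optional gradient clipping, and the
# optimizer step all happen inside the scaler's __call__.
import torch

from models.util.misc import NativeScalerWithGradNormCount  # path taken from this diff

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

x = torch.randn(4, 8, device='cuda')
with torch.cuda.amp.autocast():
    loss = model(x).mean()

optimizer.zero_grad()
grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
                        parameters=model.parameters(), update_grad=True)
print('grad norm:', grad_norm)
```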
diff --git a/models/util/pos_embed.py b/models/util/pos_embed.py
new file mode 100644
index 0000000..6acf8bd
--- /dev/null
+++ b/models/util/pos_embed.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float alias is removed in NumPy >= 1.24
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
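
A small sketch of building the fixed embedding for a 14x14 patch grid (224-pixel images with 16-pixel patches) plus one cls-token slot, converted to the tensor shape a ViT encoder would register as a buffer; the embedding width is an assumption.

```python
# Illustrative sketch only: grid size and embedding width are assumptions.
import torch

from models.util.pos_embed import get_2d_sincos_pos_embed  # path taken from this diff

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
print(pos_embed.shape)  # torch.Size([1, 197, 768])
```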
diff --git a/options/__init__.py b/options/__init__.py
new file mode 100644
index 0000000..e7eedeb
--- /dev/null
+++ b/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/options/__pycache__/__init__.cpython-36.pyc b/options/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..2dc8563
Binary files /dev/null and b/options/__pycache__/__init__.cpython-36.pyc differ
diff --git a/options/__pycache__/base_options.cpython-36.pyc b/options/__pycache__/base_options.cpython-36.pyc
new file mode 100644
index 0000000..45d1cc8
Binary files /dev/null and b/options/__pycache__/base_options.cpython-36.pyc differ
diff --git a/options/__pycache__/test_options.cpython-36.pyc b/options/__pycache__/test_options.cpython-36.pyc
new file mode 100644
index 0000000..cefdf9a
Binary files /dev/null and b/options/__pycache__/test_options.cpython-36.pyc differ
diff --git a/options/__pycache__/train_options.cpython-36.pyc b/options/__pycache__/train_options.cpython-36.pyc
new file mode 100644
index 0000000..879b9d3
Binary files /dev/null and b/options/__pycache__/train_options.cpython-36.pyc differ
diff --git a/options/base_options.py b/options/base_options.py
new file mode 100644
index 0000000..5837dd5
--- /dev/null
+++ b/options/base_options.py
@@ -0,0 +1,167 @@
+import argparse
+import os
+from util import util
+import torch
+import models
+import data
+
+
+class BaseOptions():
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+    It also gathers additional options defined in the modify_commandline_options functions of both the dataset and model classes.
+ """
+
+ def __init__(self, cmd_line=None):
+ """Reset the class; indicates the class hasn't been initailized"""
+ self.initialized = False
+ self.cmd_line = None
+ if cmd_line is not None:
+ self.cmd_line = cmd_line.split()
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument('--dataroot', default='placeholder', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
+ parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
+ parser.add_argument('--easy_label', type=str, default='experiment_name', help='Interpretable name')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--use_idt', action='store_true', help='use_idt')
+ parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+ # model parameters
+ parser.add_argument('--model', type=str, default='cut', help='chooses which model to use.')
+ parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
+ parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
+ parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
+ parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
+ parser.add_argument('--netD', type=str, default='basic', choices=['basic', 'n_layers', 'pixel', 'patch', 'tilestylegan2', 'stylegan2'], help='specify discriminator architecture. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+ parser.add_argument('--netG', type=str, default='resnet_9blocks', choices=['resnet_9blocks','resnet_9blocks_mask', 'resnet_6blocks', 'unet_256', 'unet_128', 'stylegan2', 'smallstylegan2', 'resnet_cat'], help='specify generator architecture')
+ parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
+ parser.add_argument('--normG', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for G')
+ parser.add_argument('--normD', type=str, default='instance', choices=['instance', 'batch', 'none'], help='instance normalization or batch normalization for D')
+ parser.add_argument('--init_type', type=str, default='xavier', choices=['normal', 'xavier', 'kaiming', 'orthogonal'], help='network initialization')
+ parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
+ parser.add_argument('--no_dropout', type=util.str2bool, nargs='?', const=True, default=True,
+ help='no dropout for the generator')
+ parser.add_argument('--no_antialias', action='store_true', help='if specified, use stride=2 convs instead of antialiased-downsampling (sad)')
+ parser.add_argument('--no_antialias_up', action='store_true', help='if specified, use [upconv(learned filter)] instead of [upconv(hard-coded [1,3,3,1] filter), conv]')
+ # dataset parameters
+ parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+ parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
+ parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
+ parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
+ parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
+ parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
+ parser.add_argument('--random_scale_max', type=float, default=3.0,
+ help='(used for single image translation) Randomly scale the image by the specified factor as data augmentation.')
+ # additional parameters
+ parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+ # parameters related to StyleGAN2-based networks
+ parser.add_argument('--stylegan2_G_num_downsampling',
+ default=1, type=int,
+ help='Number of downsampling layers used by StyleGAN2Generator')
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ """Initialize our parser with basic options(only once).
+ Add additional model-specific and dataset-specific options.
+ These options are defined in the function
+ in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args()
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line)
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = models.get_option_setter(model_name)
+
+ parser = model_option_setter(parser, self.isTrain)
+ if self.cmd_line is None:
+ print(parser)
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
+
+ # modify dataset-related parser options
+ dataset_name = opt.dataset_mode
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ if self.cmd_line is None:
+ return parser.parse_args()
+ else:
+ return parser.parse_args(self.cmd_line)
+
+ def print_options(self, opt):
+ """Print and save options
+
+        It will print both the current options and the default values (if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ try:
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+ except PermissionError as error:
+ print("permission error {}".format(error))
+ pass
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+ self.print_options(opt)
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ opt.gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ opt.gpu_ids.append(id)
+ if len(opt.gpu_ids) > 0:
+ torch.cuda.set_device(opt.gpu_ids[0])
+
+ self.opt = opt
+ return self.opt
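
The options classes are normally driven through their Train/Test subclasses, which set `isTrain` before `parse()` runs. A hedged sketch of programmatic use via the `cmd_line` argument follows; the flag values are placeholders and assume the `roma` model and `unaligned` dataset modules shipped elsewhere in this repository.

```python
# Illustrative sketch only: parses options without touching sys.argv.
from options.train_options import TrainOptions  # subclass sets isTrain = True

opt = TrainOptions(
    cmd_line='--dataroot /path/of/dataset --name demo --model roma --dataset_mode unaligned'
).parse()
print(opt.isTrain, opt.gpu_ids)
```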
diff --git a/options/test_options.py b/options/test_options.py
new file mode 100644
index 0000000..e4559ad
--- /dev/null
+++ b/options/test_options.py
@@ -0,0 +1,21 @@
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+        # Dropout and Batchnorm have different behavior during training and test.
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
+ parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
+
+ # To avoid cropping, the load_size should be the same as crop_size
+ parser.set_defaults(load_size=parser.get_default('crop_size'))
+ self.isTrain = False
+ return parser
diff --git a/options/train_options.py b/options/train_options.py
new file mode 100644
index 0000000..5df79aa
--- /dev/null
+++ b/options/train_options.py
@@ -0,0 +1,47 @@
+from .base_options import BaseOptions
+
+
+class TrainOptions(BaseOptions):
+ """This class includes training options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # visdom and HTML visualization parameters
+ parser.add_argument('--display_freq', type=int, default=50, help='frequency of showing training results on screen')
+ parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
+ parser.add_argument('--display_id', type=int, default=None, help='window id of the web display. Default is random window id')
+ parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
+ parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
+ parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
+ parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+ parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
+ # network saving and loading parameters
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
+ parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
+ parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+
+ # parser.add_argument('--use_mlp', action='store_true', help='use_mlp')
+ # parser.add_argument('--use_tgt_style_src', action='store_true', help='use_tgt_style_src')
+        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
+
+ # training parameters
+ parser.add_argument('--n_epochs', type=int, default=100, help='number of epochs with the initial learning rate')
+ parser.add_argument('--n_epochs_decay', type=int, default=100, help='number of epochs to linearly decay learning rate to zero')
+ parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
+ parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
+ parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
+ parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+ parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
+ parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
+ parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
+
+ self.isTrain = True
+ return parser
diff --git a/scripts/test.sh b/scripts/test.sh
new file mode 100644
index 0000000..8e8b29a
--- /dev/null
+++ b/scripts/test.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0 python test.py --dataroot /path/of/test_dataset --checkpoints_dir ./checkpoints --name train1 --model roma_single --num_test 10000 --epoch latest
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100644
index 0000000..f5765f3
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1,5 @@
+# Train for video mode
+CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned_double --no_flip --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --lambda_motion 1.0 --atten_layers 1,3,5 --lr 0.00001
+
+# Train for image mode
+CUDA_VISIBLE_DEVICES=0 python train.py --dataroot /path --name ROMA_name --dataset_mode unaligned --local_nums 64 --display_env ROMA_env --model roma --side_length 7 --lambda_spatial 5.0 --lambda_global 5.0 --atten_layers 1,3,5 --lr 0.00001
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..fddd57b
--- /dev/null
+++ b/test.py
@@ -0,0 +1,70 @@
+"""General-purpose test script for image-to-image translation.
+
+Once you have trained your model with train.py, you can use this script to test the model.
+It will load a saved model from --checkpoints_dir and save the results to --results_dir.
+
+It first creates model and dataset given the option. It will hard-code some parameters.
+It then runs inference for --num_test images and save results to an HTML file.
+
+Example (You need to train models first or download pre-trained models from our website):
+ Test a CycleGAN model (both sides):
+ python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
+
+ Test a CycleGAN model (one side only):
+ python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
+
+ The option '--model test' is used for generating CycleGAN results only for one side.
+ This option will automatically set '--dataset_mode single', which only loads the images from one set.
+ On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
+ which is sometimes unnecessary. The results will be saved at ./results/.
+    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
+
+ Test a pix2pix model:
+ python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
+
+See options/base_options.py and options/test_options.py for more test options.
+See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
+See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
+"""
+import os
+from options.test_options import TestOptions
+from data import create_dataset
+from models import create_model
+from util.visualizer import save_images
+from util import html
+import util.util as util
+
+
+if __name__ == '__main__':
+ opt = TestOptions().parse() # get test options
+ # hard-code some parameters for test
+    opt.num_threads = 0   # test code only supports num_threads = 0
+ opt.batch_size = 1 # test code only supports batch_size = 1
+ opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
+ opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
+ opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
+ dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
+ # train_dataset = create_dataset(util.copyconf(opt, phase="train"))
+ model = create_model(opt) # create a model given opt.model and other options
+ # create a webpage for viewing the results
+ web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory
+ print('creating web directory', web_dir)
+ webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
+
+ for i, data in enumerate(dataset):
+ if i == 0:
+ model.data_dependent_initialize(data)
+ model.setup(opt) # regular setup: load and print networks; create schedulers
+ model.parallelize()
+ if opt.eval:
+ model.eval()
+ if i >= opt.num_test: # only apply our model to opt.num_test images.
+ break
+ model.set_input(data) # unpack data from data loader
+ model.test() # run inference
+ visuals = model.get_current_visuals() # get image results
+ img_path = model.get_image_paths() # get image paths
+ if i % 5 == 0: # save images to an HTML file
+ print('processing (%04d)-th image... %s' % (i, img_path))
+ save_images(webpage, visuals, img_path, width=opt.display_winsize)
+ webpage.save() # save the HTML
diff --git a/timm/__init__.py b/timm/__init__.py
new file mode 100644
index 0000000..04ec7e5
--- /dev/null
+++ b/timm/__init__.py
@@ -0,0 +1,4 @@
+from .version import __version__
+from .models import create_model, list_models, is_model, list_modules, model_entrypoint, \
+ is_scriptable, is_exportable, set_scriptable, set_exportable, has_model_default_key, is_model_default_key, \
+ get_model_default_value, is_model_pretrained
diff --git a/timm/__pycache__/__init__.cpython-36.pyc b/timm/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..c81fb8b
Binary files /dev/null and b/timm/__pycache__/__init__.cpython-36.pyc differ
diff --git a/timm/__pycache__/version.cpython-36.pyc b/timm/__pycache__/version.cpython-36.pyc
new file mode 100644
index 0000000..0b2c6b1
Binary files /dev/null and b/timm/__pycache__/version.cpython-36.pyc differ
diff --git a/timm/data/__init__.py b/timm/data/__init__.py
new file mode 100644
index 0000000..7d3cb2b
--- /dev/null
+++ b/timm/data/__init__.py
@@ -0,0 +1,12 @@
+from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
+ rand_augment_transform, auto_augment_transform
+from .config import resolve_data_config
+from .constants import *
+from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
+from .dataset_factory import create_dataset
+from .loader import create_loader
+from .mixup import Mixup, FastCollateMixup
+from .parsers import create_parser
+from .real_labels import RealLabelsImagenet
+from .transforms import *
+from .transforms_factory import create_transform
\ No newline at end of file
diff --git a/timm/data/__pycache__/__init__.cpython-36.pyc b/timm/data/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..f75de55
Binary files /dev/null and b/timm/data/__pycache__/__init__.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/auto_augment.cpython-36.pyc b/timm/data/__pycache__/auto_augment.cpython-36.pyc
new file mode 100644
index 0000000..f0105bf
Binary files /dev/null and b/timm/data/__pycache__/auto_augment.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/config.cpython-36.pyc b/timm/data/__pycache__/config.cpython-36.pyc
new file mode 100644
index 0000000..6549eb0
Binary files /dev/null and b/timm/data/__pycache__/config.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/constants.cpython-36.pyc b/timm/data/__pycache__/constants.cpython-36.pyc
new file mode 100644
index 0000000..d06bfd1
Binary files /dev/null and b/timm/data/__pycache__/constants.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/dataset.cpython-36.pyc b/timm/data/__pycache__/dataset.cpython-36.pyc
new file mode 100644
index 0000000..5e1d453
Binary files /dev/null and b/timm/data/__pycache__/dataset.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/dataset_factory.cpython-36.pyc b/timm/data/__pycache__/dataset_factory.cpython-36.pyc
new file mode 100644
index 0000000..3db11e8
Binary files /dev/null and b/timm/data/__pycache__/dataset_factory.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/distributed_sampler.cpython-36.pyc b/timm/data/__pycache__/distributed_sampler.cpython-36.pyc
new file mode 100644
index 0000000..489ad09
Binary files /dev/null and b/timm/data/__pycache__/distributed_sampler.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/loader.cpython-36.pyc b/timm/data/__pycache__/loader.cpython-36.pyc
new file mode 100644
index 0000000..0ce97c9
Binary files /dev/null and b/timm/data/__pycache__/loader.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/mixup.cpython-36.pyc b/timm/data/__pycache__/mixup.cpython-36.pyc
new file mode 100644
index 0000000..28d2edf
Binary files /dev/null and b/timm/data/__pycache__/mixup.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/random_erasing.cpython-36.pyc b/timm/data/__pycache__/random_erasing.cpython-36.pyc
new file mode 100644
index 0000000..a8f50ab
Binary files /dev/null and b/timm/data/__pycache__/random_erasing.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/real_labels.cpython-36.pyc b/timm/data/__pycache__/real_labels.cpython-36.pyc
new file mode 100644
index 0000000..7098218
Binary files /dev/null and b/timm/data/__pycache__/real_labels.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/transforms.cpython-36.pyc b/timm/data/__pycache__/transforms.cpython-36.pyc
new file mode 100644
index 0000000..6651158
Binary files /dev/null and b/timm/data/__pycache__/transforms.cpython-36.pyc differ
diff --git a/timm/data/__pycache__/transforms_factory.cpython-36.pyc b/timm/data/__pycache__/transforms_factory.cpython-36.pyc
new file mode 100644
index 0000000..6cc4504
Binary files /dev/null and b/timm/data/__pycache__/transforms_factory.cpython-36.pyc differ
diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py
new file mode 100644
index 0000000..8907e50
--- /dev/null
+++ b/timm/data/auto_augment.py
@@ -0,0 +1,865 @@
+""" AutoAugment, RandAugment, and AugMix for PyTorch
+
+This code implements the searched ImageNet policies with various tweaks and improvements and
+does not include any of the search code.
+
+AA and RA Implementation adapted from:
+ https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
+
+AugMix adapted from:
+ https://github.com/google-research/augmix
+
+Papers:
+ AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
+ Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
+ RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
+ AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import random
+import math
+import re
+from PIL import Image, ImageOps, ImageEnhance, ImageChops
+import PIL
+import numpy as np
+
+
+_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
+
+_FILL = (128, 128, 128)
+
+_LEVEL_DENOM = 10. # denominator for conversion from 'Mx' magnitude scale to fractional aug level for op arguments
+
+_HPARAMS_DEFAULT = dict(
+ translate_const=250,
+ img_mean=_FILL,
+)
+
+_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+
+
+def _interpolation(kwargs):
+ interpolation = kwargs.pop('resample', Image.BILINEAR)
+ if isinstance(interpolation, (list, tuple)):
+ return random.choice(interpolation)
+ else:
+ return interpolation
+
+
+def _check_args_tf(kwargs):
+ if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
+ kwargs.pop('fillcolor')
+ kwargs['resample'] = _interpolation(kwargs)
+
+
+def shear_x(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+
+
+def shear_y(img, factor, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+
+
+def translate_x_rel(img, pct, **kwargs):
+ pixels = pct * img.size[0]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_rel(img, pct, **kwargs):
+ pixels = pct * img.size[1]
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def translate_x_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+
+
+def translate_y_abs(img, pixels, **kwargs):
+ _check_args_tf(kwargs)
+ return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+
+
+def rotate(img, degrees, **kwargs):
+ _check_args_tf(kwargs)
+ if _PIL_VER >= (5, 2):
+ return img.rotate(degrees, **kwargs)
+ elif _PIL_VER >= (5, 0):
+ w, h = img.size
+ post_trans = (0, 0)
+ rotn_center = (w / 2.0, h / 2.0)
+ angle = -math.radians(degrees)
+ matrix = [
+ round(math.cos(angle), 15),
+ round(math.sin(angle), 15),
+ 0.0,
+ round(-math.sin(angle), 15),
+ round(math.cos(angle), 15),
+ 0.0,
+ ]
+
+ def transform(x, y, matrix):
+ (a, b, c, d, e, f) = matrix
+ return a * x + b * y + c, d * x + e * y + f
+
+ matrix[2], matrix[5] = transform(
+ -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
+ )
+ matrix[2] += rotn_center[0]
+ matrix[5] += rotn_center[1]
+ return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
+ else:
+ return img.rotate(degrees, resample=kwargs['resample'])
+
+
+def auto_contrast(img, **__):
+ return ImageOps.autocontrast(img)
+
+
+def invert(img, **__):
+ return ImageOps.invert(img)
+
+
+def equalize(img, **__):
+ return ImageOps.equalize(img)
+
+
+def solarize(img, thresh, **__):
+ return ImageOps.solarize(img, thresh)
+
+
+def solarize_add(img, add, thresh=128, **__):
+ lut = []
+ for i in range(256):
+ if i < thresh:
+ lut.append(min(255, i + add))
+ else:
+ lut.append(i)
+ if img.mode in ("L", "RGB"):
+ if img.mode == "RGB" and len(lut) == 256:
+ lut = lut + lut + lut
+ return img.point(lut)
+ else:
+ return img
+
+
+def posterize(img, bits_to_keep, **__):
+ if bits_to_keep >= 8:
+ return img
+ return ImageOps.posterize(img, bits_to_keep)
+
+
+def contrast(img, factor, **__):
+ return ImageEnhance.Contrast(img).enhance(factor)
+
+
+def color(img, factor, **__):
+ return ImageEnhance.Color(img).enhance(factor)
+
+
+def brightness(img, factor, **__):
+ return ImageEnhance.Brightness(img).enhance(factor)
+
+
+def sharpness(img, factor, **__):
+ return ImageEnhance.Sharpness(img).enhance(factor)
+
+
+def _randomly_negate(v):
+ """With 50% prob, negate the value"""
+ return -v if random.random() > 0.5 else v
+
+
+def _rotate_level_to_arg(level, _hparams):
+ # range [-30, 30]
+ level = (level / _LEVEL_DENOM) * 30.
+ level = _randomly_negate(level)
+ return level,
+
+
+def _enhance_level_to_arg(level, _hparams):
+ # range [0.1, 1.9]
+ return (level / _LEVEL_DENOM) * 1.8 + 0.1,
+
+
+def _enhance_increasing_level_to_arg(level, _hparams):
+ # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend
+ # range [0.1, 1.9] if level <= _LEVEL_DENOM
+ level = (level / _LEVEL_DENOM) * .9
+ level = max(0.1, 1.0 + _randomly_negate(level)) # keep it >= 0.1
+ return level,
+
+
+def _shear_level_to_arg(level, _hparams):
+ # range [-0.3, 0.3]
+ level = (level / _LEVEL_DENOM) * 0.3
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_abs_level_to_arg(level, hparams):
+ translate_const = hparams['translate_const']
+ level = (level / _LEVEL_DENOM) * float(translate_const)
+ level = _randomly_negate(level)
+ return level,
+
+
+def _translate_rel_level_to_arg(level, hparams):
+ # default range [-0.45, 0.45]
+ translate_pct = hparams.get('translate_pct', 0.45)
+ level = (level / _LEVEL_DENOM) * translate_pct
+ level = _randomly_negate(level)
+ return level,
+
+
+def _posterize_level_to_arg(level, _hparams):
+ # As per Tensorflow TPU EfficientNet impl
+ # range [0, 4], 'keep 0 up to 4 MSB of original image'
+ # intensity/severity of augmentation decreases with level
+ return int((level / _LEVEL_DENOM) * 4),
+
+
+def _posterize_increasing_level_to_arg(level, hparams):
+ # As per Tensorflow models research and UDA impl
+ # range [4, 0], 'keep 4 down to 0 MSB of original image',
+ # intensity/severity of augmentation increases with level
+ return 4 - _posterize_level_to_arg(level, hparams)[0],
+
+
+def _posterize_original_level_to_arg(level, _hparams):
+ # As per original AutoAugment paper description
+ # range [4, 8], 'keep 4 up to 8 MSB of image'
+ # intensity/severity of augmentation decreases with level
+ return int((level / _LEVEL_DENOM) * 4) + 4,
+
+
+def _solarize_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation decreases with level
+ return int((level / _LEVEL_DENOM) * 256),
+
+
+def _solarize_increasing_level_to_arg(level, _hparams):
+ # range [0, 256]
+ # intensity/severity of augmentation increases with level
+ return 256 - _solarize_level_to_arg(level, _hparams)[0],
+
+
+def _solarize_add_level_to_arg(level, _hparams):
+ # range [0, 110]
+ return int((level / _LEVEL_DENOM) * 110),
+
+
+LEVEL_TO_ARG = {
+ 'AutoContrast': None,
+ 'Equalize': None,
+ 'Invert': None,
+ 'Rotate': _rotate_level_to_arg,
+ # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers
+ 'Posterize': _posterize_level_to_arg,
+ 'PosterizeIncreasing': _posterize_increasing_level_to_arg,
+ 'PosterizeOriginal': _posterize_original_level_to_arg,
+ 'Solarize': _solarize_level_to_arg,
+ 'SolarizeIncreasing': _solarize_increasing_level_to_arg,
+ 'SolarizeAdd': _solarize_add_level_to_arg,
+ 'Color': _enhance_level_to_arg,
+ 'ColorIncreasing': _enhance_increasing_level_to_arg,
+ 'Contrast': _enhance_level_to_arg,
+ 'ContrastIncreasing': _enhance_increasing_level_to_arg,
+ 'Brightness': _enhance_level_to_arg,
+ 'BrightnessIncreasing': _enhance_increasing_level_to_arg,
+ 'Sharpness': _enhance_level_to_arg,
+ 'SharpnessIncreasing': _enhance_increasing_level_to_arg,
+ 'ShearX': _shear_level_to_arg,
+ 'ShearY': _shear_level_to_arg,
+ 'TranslateX': _translate_abs_level_to_arg,
+ 'TranslateY': _translate_abs_level_to_arg,
+ 'TranslateXRel': _translate_rel_level_to_arg,
+ 'TranslateYRel': _translate_rel_level_to_arg,
+}
+
+
+NAME_TO_OP = {
+ 'AutoContrast': auto_contrast,
+ 'Equalize': equalize,
+ 'Invert': invert,
+ 'Rotate': rotate,
+ 'Posterize': posterize,
+ 'PosterizeIncreasing': posterize,
+ 'PosterizeOriginal': posterize,
+ 'Solarize': solarize,
+ 'SolarizeIncreasing': solarize,
+ 'SolarizeAdd': solarize_add,
+ 'Color': color,
+ 'ColorIncreasing': color,
+ 'Contrast': contrast,
+ 'ContrastIncreasing': contrast,
+ 'Brightness': brightness,
+ 'BrightnessIncreasing': brightness,
+ 'Sharpness': sharpness,
+ 'SharpnessIncreasing': sharpness,
+ 'ShearX': shear_x,
+ 'ShearY': shear_y,
+ 'TranslateX': translate_x_abs,
+ 'TranslateY': translate_y_abs,
+ 'TranslateXRel': translate_x_rel,
+ 'TranslateYRel': translate_y_rel,
+}
+
+
+class AugmentOp:
+
+ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ self.name = name
+ self.aug_fn = NAME_TO_OP[name]
+ self.level_fn = LEVEL_TO_ARG[name]
+ self.prob = prob
+ self.magnitude = magnitude
+ self.hparams = hparams.copy()
+ self.kwargs = dict(
+ fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+ resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+ )
+
+ # If magnitude_std is > 0, we introduce some randomness
+ # in the usually fixed policy and sample magnitude from a normal distribution
+ # with mean `magnitude` and std-dev of `magnitude_std`.
+ # NOTE This is my own hack, being tested, not in papers or reference impls.
+ # If magnitude_std is inf, we sample magnitude from a uniform distribution
+ self.magnitude_std = self.hparams.get('magnitude_std', 0)
+ self.magnitude_max = self.hparams.get('magnitude_max', None)
+
+ def __call__(self, img):
+ if self.prob < 1.0 and random.random() > self.prob:
+ return img
+ magnitude = self.magnitude
+ if self.magnitude_std > 0:
+ # magnitude randomization enabled
+ if self.magnitude_std == float('inf'):
+ magnitude = random.uniform(0, magnitude)
+ elif self.magnitude_std > 0:
+ magnitude = random.gauss(magnitude, self.magnitude_std)
+ # default upper_bound for the timm RA impl is _LEVEL_DENOM (10)
+ # setting magnitude_max overrides this to allow M > 10 (behaviour closer to Google TF RA impl)
+ upper_bound = self.magnitude_max or _LEVEL_DENOM
+ magnitude = max(0., min(magnitude, upper_bound))
+ level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple()
+ return self.aug_fn(img, *level_args, **self.kwargs)
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(name={self.name}, p={self.prob}'
+ fs += f', m={self.magnitude}, mstd={self.magnitude_std}'
+ if self.magnitude_max is not None:
+ fs += f', mmax={self.magnitude_max}'
+ fs += ')'
+ return fs
+
+
+def auto_augment_policy_v0(hparams):
+ # ImageNet v0 policy from TPU EfficientNet impl, cannot find a paper reference.
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('Posterize', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('Posterize', 0.8, 2), ('Solarize', 0.6, 10)], # This results in black image with Tpu posterize
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_v0r(hparams):
+ # ImageNet v0 policy from TPU EfficientNet impl, with variation of Posterize used
+ # in Google research implementation (number of bits discarded increases with magnitude)
+ policy = [
+ [('Equalize', 0.8, 1), ('ShearY', 0.8, 4)],
+ [('Color', 0.4, 9), ('Equalize', 0.6, 3)],
+ [('Color', 0.4, 1), ('Rotate', 0.6, 8)],
+ [('Solarize', 0.8, 3), ('Equalize', 0.4, 7)],
+ [('Solarize', 0.4, 2), ('Solarize', 0.6, 2)],
+ [('Color', 0.2, 0), ('Equalize', 0.8, 8)],
+ [('Equalize', 0.4, 8), ('SolarizeAdd', 0.8, 3)],
+ [('ShearX', 0.2, 9), ('Rotate', 0.6, 8)],
+ [('Color', 0.6, 1), ('Equalize', 1.0, 2)],
+ [('Invert', 0.4, 9), ('Rotate', 0.6, 0)],
+ [('Equalize', 1.0, 9), ('ShearY', 0.6, 3)],
+ [('Color', 0.4, 7), ('Equalize', 0.6, 0)],
+ [('PosterizeIncreasing', 0.4, 6), ('AutoContrast', 0.4, 7)],
+ [('Solarize', 0.6, 8), ('Color', 0.6, 9)],
+ [('Solarize', 0.2, 4), ('Rotate', 0.8, 9)],
+ [('Rotate', 1.0, 7), ('TranslateYRel', 0.8, 9)],
+ [('ShearX', 0.0, 0), ('Solarize', 0.8, 4)],
+ [('ShearY', 0.8, 0), ('Color', 0.6, 4)],
+ [('Color', 1.0, 0), ('Rotate', 0.6, 2)],
+ [('Equalize', 0.8, 4), ('Equalize', 0.0, 8)],
+ [('Equalize', 1.0, 4), ('AutoContrast', 0.6, 2)],
+ [('ShearY', 0.4, 7), ('SolarizeAdd', 0.6, 7)],
+ [('PosterizeIncreasing', 0.8, 2), ('Solarize', 0.6, 10)],
+ [('Solarize', 0.6, 8), ('Equalize', 0.6, 1)],
+ [('Color', 0.8, 6), ('Rotate', 0.4, 5)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_original(hparams):
+ # ImageNet policy from https://arxiv.org/abs/1805.09501
+ policy = [
+ [('PosterizeOriginal', 0.4, 8), ('Rotate', 0.6, 9)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ [('PosterizeOriginal', 0.6, 7), ('PosterizeOriginal', 0.6, 6)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+ [('PosterizeOriginal', 0.8, 5), ('Equalize', 1.0, 2)],
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+ [('Equalize', 0.6, 8), ('PosterizeOriginal', 0.4, 6)],
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy_originalr(hparams):
+ # ImageNet policy from https://arxiv.org/abs/1805.09501 with research posterize variation
+ policy = [
+ [('PosterizeIncreasing', 0.4, 8), ('Rotate', 0.6, 9)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ [('PosterizeIncreasing', 0.6, 7), ('PosterizeIncreasing', 0.6, 6)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
+ [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
+ [('PosterizeIncreasing', 0.8, 5), ('Equalize', 1.0, 2)],
+ [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
+ [('Equalize', 0.6, 8), ('PosterizeIncreasing', 0.4, 6)],
+ [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
+ [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
+ [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
+ [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
+ [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
+ [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
+ [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
+ [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
+ [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
+ [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
+ [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
+ [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
+ ]
+ pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
+ return pc
+
+
+def auto_augment_policy(name='v0', hparams=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ if name == 'original':
+ return auto_augment_policy_original(hparams)
+ elif name == 'originalr':
+ return auto_augment_policy_originalr(hparams)
+ elif name == 'v0':
+ return auto_augment_policy_v0(hparams)
+ elif name == 'v0r':
+ return auto_augment_policy_v0r(hparams)
+ else:
+ assert False, 'Unknown AA policy (%s)' % name
+
+
+class AutoAugment:
+
+ def __init__(self, policy):
+ self.policy = policy
+
+ def __call__(self, img):
+ sub_policy = random.choice(self.policy)
+ for op in sub_policy:
+ img = op(img)
+ return img
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(policy='
+ for p in self.policy:
+ fs += '\n\t['
+ fs += ', '.join([str(op) for op in p])
+ fs += ']'
+ fs += ')'
+ return fs
+
+
+def auto_augment_transform(config_str, hparams):
+ """
+    Create an AutoAugment transform
+
+ :param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
+ dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
+    The remaining sections, which are not order specific, determine
+ 'mstd' - float std deviation of magnitude noise applied
+ Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
+
+ :param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
+
+ :return: A PyTorch compatible Transform
+ """
+ config = config_str.split('-')
+ policy_name = config[0]
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ else:
+ assert False, 'Unknown AutoAugment config section'
+ aa_policy = auto_augment_policy(policy_name, hparams=hparams)
+ return AutoAugment(aa_policy)
+
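+# Example usage (illustrative sketch, not part of the original source; hparams keys mirror the usual timm defaults):
+#   transform = auto_augment_transform('original-mstd0.5', hparams=dict(translate_const=250, img_mean=(128, 128, 128)))
+#   augmented = transform(pil_img)  # pil_img is any PIL.Image; policy 'original', magnitude_std 0.5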
+
+_RAND_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'Posterize',
+ 'Solarize',
+ 'SolarizeAdd',
+ 'Color',
+ 'Contrast',
+ 'Brightness',
+ 'Sharpness',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+    #'Cutout' # NOTE I've implemented this as random erasing separately
+]
+
+
+_RAND_INCREASING_TRANSFORMS = [
+ 'AutoContrast',
+ 'Equalize',
+ 'Invert',
+ 'Rotate',
+ 'PosterizeIncreasing',
+ 'SolarizeIncreasing',
+ 'SolarizeAdd',
+ 'ColorIncreasing',
+ 'ContrastIncreasing',
+ 'BrightnessIncreasing',
+ 'SharpnessIncreasing',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+    #'Cutout' # NOTE I've implemented this as random erasing separately
+]
+
+
+
+# These experimental weights are based loosely on the relative improvements mentioned in the paper.
+# They may not result in increased performance, but could likely be tuned to do so.
+_RAND_CHOICE_WEIGHTS_0 = {
+ 'Rotate': 0.3,
+ 'ShearX': 0.2,
+ 'ShearY': 0.2,
+ 'TranslateXRel': 0.1,
+ 'TranslateYRel': 0.1,
+ 'Color': .025,
+ 'Sharpness': 0.025,
+ 'AutoContrast': 0.025,
+ 'Solarize': .005,
+ 'SolarizeAdd': .005,
+ 'Contrast': .005,
+ 'Brightness': .005,
+ 'Equalize': .005,
+ 'Posterize': 0,
+ 'Invert': 0,
+}
+
+
+def _select_rand_weights(weight_idx=0, transforms=None):
+ transforms = transforms or _RAND_TRANSFORMS
+ assert weight_idx == 0 # only one set of weights currently
+ rand_weights = _RAND_CHOICE_WEIGHTS_0
+ probs = [rand_weights[k] for k in transforms]
+ probs /= np.sum(probs)
+ return probs
+
+
+def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _RAND_TRANSFORMS
+ return [AugmentOp(
+ name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class RandAugment:
+ def __init__(self, ops, num_layers=2, choice_weights=None):
+ self.ops = ops
+ self.num_layers = num_layers
+ self.choice_weights = choice_weights
+
+ def __call__(self, img):
+ # no replacement when using weighted choice
+ ops = np.random.choice(
+ self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
+ for op in ops:
+ img = op(img)
+ return img
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(n={self.num_layers}, ops='
+ for op in self.ops:
+ fs += f'\n\t{op}'
+ fs += ')'
+ return fs
+
+
+def rand_augment_transform(config_str, hparams):
+ """
+ Create a RandAugment transform
+
+ :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
+ dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
+    sections, which are not order specific, determine
+ 'm' - integer magnitude of rand augment
+ 'n' - integer num layers (number of transform ops selected per image)
+        'w' - integer probability weight index (index of a set of weights to influence choice of op)
+ 'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
+ 'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
+ 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
+ Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+ 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+
+ :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+
+ :return: A PyTorch compatible Transform
+ """
+ magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
+ num_layers = 2 # default to 2 ops per image
+ weight_idx = None # default to no probability weights for op choice
+ transforms = _RAND_TRANSFORMS
+ config = config_str.split('-')
+ assert config[0] == 'rand'
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param / randomization of magnitude values
+ mstd = float(val)
+ if mstd > 100:
+ # use uniform sampling in 0 to magnitude if mstd is > 100
+ mstd = float('inf')
+ hparams.setdefault('magnitude_std', mstd)
+ elif key == 'mmax':
+ # clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
+ hparams.setdefault('magnitude_max', int(val))
+ elif key == 'inc':
+ if bool(val):
+ transforms = _RAND_INCREASING_TRANSFORMS
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'n':
+ num_layers = int(val)
+ elif key == 'w':
+ weight_idx = int(val)
+ else:
+ assert False, 'Unknown RandAugment config section'
+ ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
+ choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+ return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
+
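+# Example usage (illustrative sketch, not part of the original source):
+#   transform = rand_augment_transform('rand-m9-mstd0.5-inc1', hparams=dict(img_mean=(124, 116, 104)))
+#   augmented = transform(pil_img)  # 2 ops per image, magnitude ~ N(9, 0.5), 'increasing' transform set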
+
+_AUGMIX_TRANSFORMS = [
+ 'AutoContrast',
+ 'ColorIncreasing', # not in paper
+ 'ContrastIncreasing', # not in paper
+ 'BrightnessIncreasing', # not in paper
+ 'SharpnessIncreasing', # not in paper
+ 'Equalize',
+ 'Rotate',
+ 'PosterizeIncreasing',
+ 'SolarizeIncreasing',
+ 'ShearX',
+ 'ShearY',
+ 'TranslateXRel',
+ 'TranslateYRel',
+]
+
+
+def augmix_ops(magnitude=10, hparams=None, transforms=None):
+ hparams = hparams or _HPARAMS_DEFAULT
+ transforms = transforms or _AUGMIX_TRANSFORMS
+ return [AugmentOp(
+ name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
+
+
+class AugMixAugment:
+ """ AugMix Transform
+ Adapted and improved from impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
+    From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
+ https://arxiv.org/abs/1912.02781
+ """
+ def __init__(self, ops, alpha=1., width=3, depth=-1, blended=False):
+ self.ops = ops
+ self.alpha = alpha
+ self.width = width
+ self.depth = depth
+ self.blended = blended # blended mode is faster but not well tested
+
+ def _calc_blended_weights(self, ws, m):
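+        # Convert the per-chain Dirichlet mixing weights (scaled by the overall Beta weight m) into
+        # sequential PIL blend alphas, so applying Image.blend once per chain approximates the weighted sum.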
+ ws = ws * m
+ cump = 1.
+ rws = []
+ for w in ws[::-1]:
+ alpha = w / cump
+ cump *= (1 - alpha)
+ rws.append(alpha)
+ return np.array(rws[::-1], dtype=np.float32)
+
+ def _apply_blended(self, img, mixing_weights, m):
+        # This is my first crack at implementing a slightly faster mixed augmentation. Instead
+ # of accumulating the mix for each chain in a Numpy array and then blending with original,
+ # it recomputes the blending coefficients and applies one PIL image blend per chain.
+ # TODO the results appear in the right ballpark but they differ by more than rounding.
+ img_orig = img.copy()
+ ws = self._calc_blended_weights(mixing_weights, m)
+ for w in ws:
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+ ops = np.random.choice(self.ops, depth, replace=True)
+ img_aug = img_orig # no ops are in-place, deep copy not necessary
+ for op in ops:
+ img_aug = op(img_aug)
+ img = Image.blend(img, img_aug, w)
+ return img
+
+ def _apply_basic(self, img, mixing_weights, m):
+ # This is a literal adaptation of the paper/official implementation without normalizations and
+ # PIL <-> Numpy conversions between every op. It is still quite CPU compute heavy compared to the
+ # typical augmentation transforms, could use a GPU / Kornia implementation.
+ img_shape = img.size[0], img.size[1], len(img.getbands())
+ mixed = np.zeros(img_shape, dtype=np.float32)
+ for mw in mixing_weights:
+ depth = self.depth if self.depth > 0 else np.random.randint(1, 4)
+ ops = np.random.choice(self.ops, depth, replace=True)
+ img_aug = img # no ops are in-place, deep copy not necessary
+ for op in ops:
+ img_aug = op(img_aug)
+ mixed += mw * np.asarray(img_aug, dtype=np.float32)
+ np.clip(mixed, 0, 255., out=mixed)
+ mixed = Image.fromarray(mixed.astype(np.uint8))
+ return Image.blend(img, mixed, m)
+
+ def __call__(self, img):
+ mixing_weights = np.float32(np.random.dirichlet([self.alpha] * self.width))
+ m = np.float32(np.random.beta(self.alpha, self.alpha))
+ if self.blended:
+ mixed = self._apply_blended(img, mixing_weights, m)
+ else:
+ mixed = self._apply_basic(img, mixing_weights, m)
+ return mixed
+
+ def __repr__(self):
+ fs = self.__class__.__name__ + f'(alpha={self.alpha}, width={self.width}, depth={self.depth}, ops='
+ for op in self.ops:
+ fs += f'\n\t{op}'
+ fs += ')'
+ return fs
+
+
+def augment_and_mix_transform(config_str, hparams):
+ """ Create AugMix PyTorch transform
+
+    :param config_str: String defining configuration of the AugMix augmentation. Consists of multiple sections
+    separated by dashes ('-'). The first section defines the variant (currently only 'augmix'). The remaining
+    sections, which are not order specific, determine
+ 'm' - integer magnitude (severity) of augmentation mix (default: 3)
+ 'w' - integer width of augmentation chain (default: 3)
+ 'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
+ 'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
+ 'mstd' - float std deviation of magnitude noise applied (default: 0)
+ Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
+
+ :param hparams: Other hparams (kwargs) for the Augmentation transforms
+
+ :return: A PyTorch compatible Transform
+ """
+ magnitude = 3
+ width = 3
+ depth = -1
+ alpha = 1.
+ blended = False
+ config = config_str.split('-')
+ assert config[0] == 'augmix'
+ config = config[1:]
+ for c in config:
+ cs = re.split(r'(\d.*)', c)
+ if len(cs) < 2:
+ continue
+ key, val = cs[:2]
+ if key == 'mstd':
+ # noise param injected via hparams for now
+ hparams.setdefault('magnitude_std', float(val))
+ elif key == 'm':
+ magnitude = int(val)
+ elif key == 'w':
+ width = int(val)
+ elif key == 'd':
+ depth = int(val)
+ elif key == 'a':
+ alpha = float(val)
+ elif key == 'b':
+ blended = bool(val)
+ else:
+ assert False, 'Unknown AugMix config section'
+ hparams.setdefault('magnitude_std', float('inf')) # default to uniform sampling (if not set via mstd arg)
+ ops = augmix_ops(magnitude=magnitude, hparams=hparams)
+ return AugMixAugment(ops, alpha=alpha, width=width, depth=depth, blended=blended)
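+
+
+# Example usage (illustrative sketch, not part of the original source):
+#   transform = augment_and_mix_transform('augmix-m5-w4-d2', hparams=dict(img_mean=(128, 128, 128)))
+#   mixed = transform(pil_img)  # severity 5, 4 chains, depth 2, blended with the original via a Beta-sampled weight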
diff --git a/timm/data/config.py b/timm/data/config.py
new file mode 100644
index 0000000..38f5689
--- /dev/null
+++ b/timm/data/config.py
@@ -0,0 +1,78 @@
+import logging
+from .constants import *
+
+
+_logger = logging.getLogger(__name__)
+
+
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
+ new_config = {}
+ default_cfg = default_cfg
+ if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
+ default_cfg = model.default_cfg
+
+ # Resolve input/image size
+ in_chans = 3
+ if 'chans' in args and args['chans'] is not None:
+ in_chans = args['chans']
+
+ input_size = (in_chans, 224, 224)
+ if 'input_size' in args and args['input_size'] is not None:
+ assert isinstance(args['input_size'], (tuple, list))
+ assert len(args['input_size']) == 3
+ input_size = tuple(args['input_size'])
+ in_chans = input_size[0] # input_size overrides in_chans
+ elif 'img_size' in args and args['img_size'] is not None:
+ assert isinstance(args['img_size'], int)
+ input_size = (in_chans, args['img_size'], args['img_size'])
+ else:
+ if use_test_size and 'test_input_size' in default_cfg:
+ input_size = default_cfg['test_input_size']
+ elif 'input_size' in default_cfg:
+ input_size = default_cfg['input_size']
+ new_config['input_size'] = input_size
+
+ # resolve interpolation method
+ new_config['interpolation'] = 'bicubic'
+ if 'interpolation' in args and args['interpolation']:
+ new_config['interpolation'] = args['interpolation']
+ elif 'interpolation' in default_cfg:
+ new_config['interpolation'] = default_cfg['interpolation']
+
+ # resolve dataset + model mean for normalization
+ new_config['mean'] = IMAGENET_DEFAULT_MEAN
+ if 'mean' in args and args['mean'] is not None:
+ mean = tuple(args['mean'])
+ if len(mean) == 1:
+ mean = tuple(list(mean) * in_chans)
+ else:
+ assert len(mean) == in_chans
+ new_config['mean'] = mean
+ elif 'mean' in default_cfg:
+ new_config['mean'] = default_cfg['mean']
+
+ # resolve dataset + model std deviation for normalization
+ new_config['std'] = IMAGENET_DEFAULT_STD
+ if 'std' in args and args['std'] is not None:
+ std = tuple(args['std'])
+ if len(std) == 1:
+ std = tuple(list(std) * in_chans)
+ else:
+ assert len(std) == in_chans
+ new_config['std'] = std
+ elif 'std' in default_cfg:
+ new_config['std'] = default_cfg['std']
+
+ # resolve default crop percentage
+ new_config['crop_pct'] = DEFAULT_CROP_PCT
+ if 'crop_pct' in args and args['crop_pct'] is not None:
+ new_config['crop_pct'] = args['crop_pct']
+ elif 'crop_pct' in default_cfg:
+ new_config['crop_pct'] = default_cfg['crop_pct']
+
+ if verbose:
+ _logger.info('Data processing configuration for current model + dataset:')
+ for n, v in new_config.items():
+ _logger.info('\t%s: %s' % (n, str(v)))
+
+ return new_config
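+
+
+# Example usage (illustrative sketch; assumes a model exposing a timm-style default_cfg):
+#   config = resolve_data_config(vars(args), model=model, verbose=True)
+#   # -> {'input_size': (3, 224, 224), 'interpolation': 'bicubic', 'mean': ..., 'std': ..., 'crop_pct': 0.875}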
diff --git a/timm/data/constants.py b/timm/data/constants.py
new file mode 100644
index 0000000..d6d4a01
--- /dev/null
+++ b/timm/data/constants.py
@@ -0,0 +1,7 @@
+DEFAULT_CROP_PCT = 0.875
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
+IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
+IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
+IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
diff --git a/timm/data/dataset.py b/timm/data/dataset.py
new file mode 100644
index 0000000..d3603a2
--- /dev/null
+++ b/timm/data/dataset.py
@@ -0,0 +1,152 @@
+""" Quick n Simple Image Folder, Tarfile based DataSet
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch.utils.data as data
+import os
+import torch
+import logging
+
+from PIL import Image
+
+from .parsers import create_parser
+
+_logger = logging.getLogger(__name__)
+
+
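+# number of consecutive unreadable samples to skip (by advancing the index) before the underlying error is re-raised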
+_ERROR_RETRY = 50
+
+
+class ImageDataset(data.Dataset):
+
+ def __init__(
+ self,
+ root,
+ parser=None,
+ class_map=None,
+ load_bytes=False,
+ transform=None,
+ target_transform=None,
+ ):
+ if parser is None or isinstance(parser, str):
+ parser = create_parser(parser or '', root=root, class_map=class_map)
+ self.parser = parser
+ self.load_bytes = load_bytes
+ self.transform = transform
+ self.target_transform = target_transform
+ self._consecutive_errors = 0
+
+ def __getitem__(self, index):
+ img, target = self.parser[index]
+ try:
+ img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
+ except Exception as e:
+ _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
+ self._consecutive_errors += 1
+ if self._consecutive_errors < _ERROR_RETRY:
+ return self.__getitem__((index + 1) % len(self.parser))
+ else:
+ raise e
+ self._consecutive_errors = 0
+ if self.transform is not None:
+ img = self.transform(img)
+ if target is None:
+ target = -1
+ elif self.target_transform is not None:
+ target = self.target_transform(target)
+ return img, target
+
+ def __len__(self):
+ return len(self.parser)
+
+ def filename(self, index, basename=False, absolute=False):
+ return self.parser.filename(index, basename, absolute)
+
+ def filenames(self, basename=False, absolute=False):
+ return self.parser.filenames(basename, absolute)
+
+
+class IterableImageDataset(data.IterableDataset):
+
+ def __init__(
+ self,
+ root,
+ parser=None,
+ split='train',
+ is_training=False,
+ batch_size=None,
+ repeats=0,
+ download=False,
+ transform=None,
+ target_transform=None,
+ ):
+ assert parser is not None
+ if isinstance(parser, str):
+ self.parser = create_parser(
+ parser, root=root, split=split, is_training=is_training,
+ batch_size=batch_size, repeats=repeats, download=download)
+ else:
+ self.parser = parser
+ self.transform = transform
+ self.target_transform = target_transform
+ self._consecutive_errors = 0
+
+ def __iter__(self):
+ for img, target in self.parser:
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+ yield img, target
+
+ def __len__(self):
+ if hasattr(self.parser, '__len__'):
+ return len(self.parser)
+ else:
+ return 0
+
+ def filename(self, index, basename=False, absolute=False):
+ assert False, 'Filename lookup by index not supported, use filenames().'
+
+ def filenames(self, basename=False, absolute=False):
+ return self.parser.filenames(basename, absolute)
+
+
+class AugMixDataset(torch.utils.data.Dataset):
+ """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""
+
+ def __init__(self, dataset, num_splits=2):
+ self.augmentation = None
+ self.normalize = None
+ self.dataset = dataset
+ if self.dataset.transform is not None:
+ self._set_transforms(self.dataset.transform)
+ self.num_splits = num_splits
+
+ def _set_transforms(self, x):
+ assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
+ self.dataset.transform = x[0]
+ self.augmentation = x[1]
+ self.normalize = x[2]
+
+ @property
+ def transform(self):
+ return self.dataset.transform
+
+ @transform.setter
+ def transform(self, x):
+ self._set_transforms(x)
+
+ def _normalize(self, x):
+ return x if self.normalize is None else self.normalize(x)
+
+ def __getitem__(self, i):
+ x, y = self.dataset[i] # all splits share the same dataset base transform
+ x_list = [self._normalize(x)] # first split only normalizes (this is the 'clean' split)
+ # run the full augmentation on the remaining splits
+ for _ in range(self.num_splits - 1):
+ x_list.append(self._normalize(self.augmentation(x)))
+ return tuple(x_list), y
+
+ def __len__(self):
+ return len(self.dataset)
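+
+
+# Example usage (illustrative sketch): wrap a dataset whose transform is the (base, augmentation, normalize)
+# 3-tuple produced by a separate transform factory (e.g. create_transform(..., separate=True)); each item then
+# yields a tuple of num_splits views (the first only normalized, the rest augmented) plus the target.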
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
new file mode 100644
index 0000000..e86bcc2
--- /dev/null
+++ b/timm/data/dataset_factory.py
@@ -0,0 +1,139 @@
+import os
+
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST, ImageNet, ImageFolder
+try:
+ from torchvision.datasets import Places365
+ has_places365 = True
+except ImportError:
+ has_places365 = False
+try:
+ from torchvision.datasets import INaturalist
+ has_inaturalist = True
+except ImportError:
+ has_inaturalist = False
+
+from .dataset import IterableImageDataset, ImageDataset
+
+_TORCH_BASIC_DS = dict(
+ cifar10=CIFAR10,
+ cifar100=CIFAR100,
+ mnist=MNIST,
+    qmnist=QMNIST,
+ kmnist=KMNIST,
+ fashion_mnist=FashionMNIST,
+)
+_TRAIN_SYNONYM = {'train', 'training'}
+_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
+
+
+def _search_split(root, split):
+ # look for sub-folder with name of split in root and use that if it exists
+ split_name = split.split('[')[0]
+ try_root = os.path.join(root, split_name)
+ if os.path.exists(try_root):
+ return try_root
+
+ def _try(syn):
+ for s in syn:
+ try_root = os.path.join(root, s)
+ if os.path.exists(try_root):
+ return try_root
+ return root
+ if split_name in _TRAIN_SYNONYM:
+ root = _try(_TRAIN_SYNONYM)
+ elif split_name in _EVAL_SYNONYM:
+ root = _try(_EVAL_SYNONYM)
+ return root
+
+
+def create_dataset(
+ name,
+ root,
+ split='validation',
+ search_split=True,
+ class_map=None,
+ load_bytes=False,
+ is_training=False,
+ download=False,
+ batch_size=None,
+ repeats=0,
+ **kwargs
+):
+ """ Dataset factory method
+
+    In parentheses after each arg is the type of dataset supported for that arg, one of:
+ * folder - default, timm folder (or tar) based ImageDataset
+ * torch - torchvision based datasets
+      * TFDS - Tensorflow-datasets wrapper in an IterableDataset interface via IterableImageDataset
+ * all - any of the above
+
+ Args:
+ name: dataset name, empty is okay for folder based datasets
+ root: root folder of dataset (all)
+ split: dataset split (all)
+        search_split: search for split specific child folder from root so one can specify
+ `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
+ class_map: specify class -> index mapping via text file or dict (folder)
+ load_bytes: load data, return images as undecoded bytes (folder)
+ download: download dataset if not present and supported (TFDS, torch)
+        is_training: create dataset in train mode; this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
+        batch_size: batch size hint for the parser (TFDS)
+ repeats: dataset repeats per iteration i.e. epoch (TFDS)
+ **kwargs: other args to pass to dataset
+
+ Returns:
+ Dataset object
+ """
+ name = name.lower()
+ if name.startswith('torch/'):
+ name = name.split('/', 2)[-1]
+ torch_kwargs = dict(root=root, download=download, **kwargs)
+ if name in _TORCH_BASIC_DS:
+ ds_class = _TORCH_BASIC_DS[name]
+ use_train = split in _TRAIN_SYNONYM
+ ds = ds_class(train=use_train, **torch_kwargs)
+ elif name == 'inaturalist' or name == 'inat':
+            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for INaturalist'
+ target_type = 'full'
+ split_split = split.split('/')
+ if len(split_split) > 1:
+ target_type = split_split[0].split('_')
+ if len(target_type) == 1:
+ target_type = target_type[0]
+ split = split_split[-1]
+ if split in _TRAIN_SYNONYM:
+ split = '2021_train'
+ elif split in _EVAL_SYNONYM:
+ split = '2021_valid'
+ ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
+ elif name == 'places365':
+ assert has_places365, 'Please update to a newer PyTorch and torchvision for Places365 dataset.'
+ if split in _TRAIN_SYNONYM:
+ split = 'train-standard'
+ elif split in _EVAL_SYNONYM:
+ split = 'val'
+ ds = Places365(split=split, **torch_kwargs)
+ elif name == 'imagenet':
+ if split in _EVAL_SYNONYM:
+ split = 'val'
+ ds = ImageNet(split=split, **torch_kwargs)
+ elif name == 'image_folder' or name == 'folder':
+ # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
+ if search_split and os.path.isdir(root):
+ # look for split specific sub-folder in root
+ root = _search_split(root, split)
+ ds = ImageFolder(root, **kwargs)
+ else:
+ assert False, f"Unknown torchvision dataset {name}"
+ elif name.startswith('tfds/'):
+ ds = IterableImageDataset(
+ root, parser=name, split=split, is_training=is_training,
+ download=download, batch_size=batch_size, repeats=repeats, **kwargs)
+ else:
+        # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
+ if search_split and os.path.isdir(root):
+ # look for split specific sub-folder in root
+ root = _search_split(root, split)
+ ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
+ return ds
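+
+
+# Example usage (illustrative sketch; paths are placeholders):
+#   train_ds = create_dataset('', root='/data/imagenet', split='train', is_training=True)
+#   cifar_ds = create_dataset('torch/cifar10', root='/data', split='validation', download=True)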
diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py
new file mode 100644
index 0000000..fa403d0
--- /dev/null
+++ b/timm/data/distributed_sampler.py
@@ -0,0 +1,128 @@
+import math
+import torch
+from torch.utils.data import Sampler
+import torch.distributed as dist
+
+
+class OrderedDistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ indices = list(range(len(self.dataset)))
+
+ # add extra samples to make it evenly divisible
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
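+        # e.g. 7 samples, 2 replicas -> padded to [0, 1, 2, 3, 4, 5, 6, 0]; rank 0 takes [0, 2, 4, 6], rank 1 takes [1, 3, 5, 0]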
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+
+class RepeatAugSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset for distributed,
+ with repeated augmentation.
+    It ensures that each augmented version of a sample will be visible to a
+    different process (GPU). Heavily based on torch.utils.data.DistributedSampler
+
+    This sampler was taken from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py,
+    Copyright (c) 2015-present, Facebook, Inc.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ num_repeats=3,
+ selected_round=256,
+ selected_ratio=0,
+ ):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.shuffle = shuffle
+ self.num_repeats = num_repeats
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ # Determine the number of samples to select per epoch for each rank.
+ # num_selected logic defaults to be the same as original RASampler impl, but this one can be tweaked
+ # via selected_ratio and selected_round args.
+ selected_ratio = selected_ratio or num_replicas # ratio to reduce selected samples by, num_replicas if 0
+ if selected_round:
+ self.num_selected_samples = int(math.floor(
+ len(self.dataset) // selected_round * selected_round / selected_ratio))
+ else:
+ self.num_selected_samples = int(math.ceil(len(self.dataset) / selected_ratio))
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ if self.shuffle:
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
+ indices = [x for x in indices for _ in range(self.num_repeats)]
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ indices += indices[:padding_size]
+ assert len(indices) == self.total_size
+
+ # subsample per rank
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ # return up to num selected samples
+ return iter(indices[:self.num_selected_samples])
+
+ def __len__(self):
+ return self.num_selected_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
\ No newline at end of file
diff --git a/timm/data/loader.py b/timm/data/loader.py
new file mode 100644
index 0000000..a02399a
--- /dev/null
+++ b/timm/data/loader.py
@@ -0,0 +1,289 @@
+""" Loader Factory, Fast Collate, CUDA Prefetcher
+
+Prefetcher and Fast Collate inspired by NVIDIA APEX example at
+https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import random
+from functools import partial
+from typing import Callable
+
+import torch.utils.data
+import numpy as np
+
+from .transforms_factory import create_transform
+from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
+from .random_erasing import RandomErasing
+from .mixup import FastCollateMixup
+
+
+def fast_collate(batch):
+ """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
+ assert isinstance(batch[0], tuple)
+ batch_size = len(batch)
+ if isinstance(batch[0][0], tuple):
+ # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
+        # such that all tuple elements at position n end up in the nth chunk of a torch.split(tensor, batch_size)
+ inner_tuple_size = len(batch[0][0])
+ flattened_batch_size = batch_size * inner_tuple_size
+ targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
+ tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
+ for j in range(inner_tuple_size):
+ targets[i + j * batch_size] = batch[i][1]
+ tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+ return tensor, targets
+ elif isinstance(batch[0][0], np.ndarray):
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+ assert len(targets) == batch_size
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ tensor[i] += torch.from_numpy(batch[i][0])
+ return tensor, targets
+ elif isinstance(batch[0][0], torch.Tensor):
+ targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+ assert len(targets) == batch_size
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ for i in range(batch_size):
+ tensor[i].copy_(batch[i][0])
+ return tensor, targets
+ else:
+ assert False
+
+
+class PrefetchLoader:
+
+ def __init__(self,
+ loader,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ fp16=False,
+ re_prob=0.,
+ re_mode='const',
+ re_count=1,
+ re_num_splits=0):
+ self.loader = loader
+ self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
+ self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
+ self.fp16 = fp16
+ if fp16:
+ self.mean = self.mean.half()
+ self.std = self.std.half()
+ if re_prob > 0.:
+ self.random_erasing = RandomErasing(
+ probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
+ else:
+ self.random_erasing = None
+
+ def __iter__(self):
+ stream = torch.cuda.Stream()
+ first = True
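+        # copy/normalize the next batch on a side CUDA stream while the caller consumes the current one;
+        # the first pass only primes the pipeline, so yielding starts one batch behind the wrapped loader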
+
+ for next_input, next_target in self.loader:
+ with torch.cuda.stream(stream):
+ next_input = next_input.cuda(non_blocking=True)
+ next_target = next_target.cuda(non_blocking=True)
+ if self.fp16:
+ next_input = next_input.half().sub_(self.mean).div_(self.std)
+ else:
+ next_input = next_input.float().sub_(self.mean).div_(self.std)
+ if self.random_erasing is not None:
+ next_input = self.random_erasing(next_input)
+
+ if not first:
+ yield input, target
+ else:
+ first = False
+
+ torch.cuda.current_stream().wait_stream(stream)
+ input = next_input
+ target = next_target
+
+ yield input, target
+
+ def __len__(self):
+ return len(self.loader)
+
+ @property
+ def sampler(self):
+ return self.loader.sampler
+
+ @property
+ def dataset(self):
+ return self.loader.dataset
+
+ @property
+ def mixup_enabled(self):
+ if isinstance(self.loader.collate_fn, FastCollateMixup):
+ return self.loader.collate_fn.mixup_enabled
+ else:
+ return False
+
+ @mixup_enabled.setter
+ def mixup_enabled(self, x):
+ if isinstance(self.loader.collate_fn, FastCollateMixup):
+ self.loader.collate_fn.mixup_enabled = x
+
+
+def _worker_init(worker_id, worker_seeding='all'):
+ worker_info = torch.utils.data.get_worker_info()
+ assert worker_info.id == worker_id
+ if isinstance(worker_seeding, Callable):
+ seed = worker_seeding(worker_info)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed % (2 ** 32 - 1))
+ else:
+ assert worker_seeding in ('all', 'part')
+ # random / torch seed already called in dataloader iter class w/ worker_info.seed
+ # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
+ if worker_seeding == 'all':
+ np.random.seed(worker_info.seed % (2 ** 32 - 1))
+
+
+def create_loader(
+ dataset,
+ input_size,
+ batch_size,
+ is_training=False,
+ use_prefetcher=True,
+ no_aug=False,
+ re_prob=0.,
+ re_mode='const',
+ re_count=1,
+ re_split=False,
+ scale=None,
+ ratio=None,
+ hflip=0.5,
+ vflip=0.,
+ color_jitter=0.4,
+ auto_augment=None,
+ num_aug_repeats=0,
+ num_aug_splits=0,
+ interpolation='bilinear',
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ num_workers=1,
+ distributed=False,
+ crop_pct=None,
+ collate_fn=None,
+ pin_memory=False,
+ fp16=False,
+ tf_preprocessing=False,
+ use_multi_epochs_loader=False,
+ persistent_workers=True,
+ worker_seeding='all',
+):
+ re_num_splits = 0
+ if re_split:
+ # apply RE to second half of batch if no aug split otherwise line up with aug split
+ re_num_splits = num_aug_splits or 2
+ dataset.transform = create_transform(
+ input_size,
+ is_training=is_training,
+ use_prefetcher=use_prefetcher,
+ no_aug=no_aug,
+ scale=scale,
+ ratio=ratio,
+ hflip=hflip,
+ vflip=vflip,
+ color_jitter=color_jitter,
+ auto_augment=auto_augment,
+ interpolation=interpolation,
+ mean=mean,
+ std=std,
+ crop_pct=crop_pct,
+ tf_preprocessing=tf_preprocessing,
+ re_prob=re_prob,
+ re_mode=re_mode,
+ re_count=re_count,
+ re_num_splits=re_num_splits,
+ separate=num_aug_splits > 0,
+ )
+
+ sampler = None
+ if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
+ if is_training:
+ if num_aug_repeats:
+ sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
+ else:
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+ else:
+ # This will add extra duplicate entries to result in equal num
+            # of samples per-process, which will slightly alter validation results
+ sampler = OrderedDistributedSampler(dataset)
+ else:
+ assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
+
+ if collate_fn is None:
+ collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
+
+ loader_class = torch.utils.data.DataLoader
+ if use_multi_epochs_loader:
+ loader_class = MultiEpochsDataLoader
+
+ loader_args = dict(
+ batch_size=batch_size,
+ shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
+ num_workers=num_workers,
+ sampler=sampler,
+ collate_fn=collate_fn,
+ pin_memory=pin_memory,
+ drop_last=is_training,
+ worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
+ persistent_workers=persistent_workers
+ )
+ try:
+ loader = loader_class(dataset, **loader_args)
+ except TypeError as e:
+ loader_args.pop('persistent_workers') # only in Pytorch 1.7+
+ loader = loader_class(dataset, **loader_args)
+ if use_prefetcher:
+ prefetch_re_prob = re_prob if is_training and not no_aug else 0.
+ loader = PrefetchLoader(
+ loader,
+ mean=mean,
+ std=std,
+ fp16=fp16,
+ re_prob=prefetch_re_prob,
+ re_mode=re_mode,
+ re_count=re_count,
+ re_num_splits=re_num_splits
+ )
+
+ return loader
+
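+# Example usage (illustrative sketch; the dataset construction is a placeholder):
+#   loader = create_loader(
+#       ImageDataset('/data/imagenet/train'), input_size=(3, 224, 224), batch_size=128,
+#       is_training=True, auto_augment='rand-m9-mstd0.5', re_prob=0.25, num_workers=4)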
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._DataLoader__initialized = False
+ self.batch_sampler = _RepeatSampler(self.batch_sampler)
+ self._DataLoader__initialized = True
+ self.iterator = super().__iter__()
+
+ def __len__(self):
+ return len(self.batch_sampler.sampler)
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+ """ Sampler that repeats forever.
+
+ Args:
+ sampler (Sampler)
+ """
+
+ def __init__(self, sampler):
+ self.sampler = sampler
+
+ def __iter__(self):
+ while True:
+ yield from iter(self.sampler)
diff --git a/timm/data/mixup.py b/timm/data/mixup.py
new file mode 100644
index 0000000..7e382c5
--- /dev/null
+++ b/timm/data/mixup.py
@@ -0,0 +1,316 @@
+""" Mixup and Cutmix
+
+Papers:
+mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
+
+CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
+
+Code Reference:
+CutMix: https://github.com/clovaai/CutMix-PyTorch
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import numpy as np
+import torch
+
+
+def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
+ x = x.long().view(-1, 1)
+ return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
+
+
+def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
+ off_value = smoothing / num_classes
+ on_value = 1. - smoothing + off_value
+ y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
+ y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
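+    # e.g. lam=0.7, smoothing=0.1: each row is 0.7 * smoothed one-hot(target[i]) + 0.3 * smoothed one-hot(target[-(i + 1)])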
+ return y1 * lam + y2 * (1. - lam)
+
+
+def rand_bbox(img_shape, lam, margin=0., count=None):
+ """ Standard CutMix bounding-box
+ Generates a random square bbox based on lambda value. This impl includes
+ support for enforcing a border margin as percent of bbox dimensions.
+
+ Args:
+ img_shape (tuple): Image shape as tuple
+ lam (float): Cutmix lambda value
+ margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
+ count (int): Number of bbox to generate
+ """
+ ratio = np.sqrt(1 - lam)
+ img_h, img_w = img_shape[-2:]
+ cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+ margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
+ cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+ cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+ yl = np.clip(cy - cut_h // 2, 0, img_h)
+ yh = np.clip(cy + cut_h // 2, 0, img_h)
+ xl = np.clip(cx - cut_w // 2, 0, img_w)
+ xh = np.clip(cx + cut_w // 2, 0, img_w)
+ return yl, yh, xl, xh
+
+
+def rand_bbox_minmax(img_shape, minmax, count=None):
+ """ Min-Max CutMix bounding-box
+ Inspired by Darknet cutmix impl, generates a random rectangular bbox
+ based on min/max percent values applied to each dimension of the input image.
+
+    Typical defaults for minmax are usually in the .2-.3 range for min and .8-.9 range for max.
+
+ Args:
+ img_shape (tuple): Image shape as tuple
+ minmax (tuple or list): Min and max bbox ratios (as percent of image size)
+ count (int): Number of bbox to generate
+ """
+ assert len(minmax) == 2
+ img_h, img_w = img_shape[-2:]
+ cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+ cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+ yl = np.random.randint(0, img_h - cut_h, size=count)
+ xl = np.random.randint(0, img_w - cut_w, size=count)
+ yu = yl + cut_h
+ xu = xl + cut_w
+ return yl, yu, xl, xu
+
+
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+ """ Generate bbox and apply lambda correction.
+ """
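+    # e.g. cutmix_bbox_and_lam((B, C, 224, 224), lam=0.6) returns a box covering ~40% of the image and,
+    # with correct_lam=True, a lambda recomputed from the actual (possibly clipped) box area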
+ if ratio_minmax is not None:
+ yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+ else:
+ yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+ if correct_lam or ratio_minmax is not None:
+ bbox_area = (yu - yl) * (xu - xl)
+ lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
+ return (yl, yu, xl, xu), lam
+
+
+class Mixup:
+ """ Mixup/Cutmix that applies different params to each element or whole batch
+
+ Args:
+ mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+ cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+ cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
+ prob (float): probability of applying mixup or cutmix per batch or element
+ switch_prob (float): probability of switching to cutmix instead of mixup when both are active
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
+ correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
+ label_smoothing (float): apply label smoothing to the mixed target tensor
+ num_classes (int): number of classes for target
+ """
+ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+ mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
+ self.mixup_alpha = mixup_alpha
+ self.cutmix_alpha = cutmix_alpha
+ self.cutmix_minmax = cutmix_minmax
+ if self.cutmix_minmax is not None:
+ assert len(self.cutmix_minmax) == 2
+ # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+ self.cutmix_alpha = 1.0
+ self.mix_prob = prob
+ self.switch_prob = switch_prob
+ self.label_smoothing = label_smoothing
+ self.num_classes = num_classes
+ self.mode = mode
+ self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)
+
+ def _params_per_elem(self, batch_size):
+ lam = np.ones(batch_size, dtype=np.float32)
+        use_cutmix = np.zeros(batch_size, dtype=bool)  # builtin bool; the np.bool alias was removed from newer NumPy
+ if self.mixup_enabled:
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+ use_cutmix = np.random.rand(batch_size) < self.switch_prob
+ lam_mix = np.where(
+ use_cutmix,
+ np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+ np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+ elif self.mixup_alpha > 0.:
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+ elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size, dtype=bool)
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+ else:
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+ lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
+ return lam, use_cutmix
+
+ def _params_per_batch(self):
+ lam = 1.
+ use_cutmix = False
+ if self.mixup_enabled and np.random.rand() < self.mix_prob:
+ if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+ use_cutmix = np.random.rand() < self.switch_prob
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+ np.random.beta(self.mixup_alpha, self.mixup_alpha)
+ elif self.mixup_alpha > 0.:
+ lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+ elif self.cutmix_alpha > 0.:
+ use_cutmix = True
+ lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+ else:
+ assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+ lam = float(lam_mix)
+ return lam, use_cutmix
+
+ def _mix_elem(self, x):
+ batch_size = len(x)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size)
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
+ for i in range(batch_size):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ if lam != 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+ def _mix_pair(self, x):
+ batch_size = len(x)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+ x_orig = x.clone() # need to keep an unmodified original for mixing source
+ for i in range(batch_size // 2):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ if lam != 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+ x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+ x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+ return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
+ def _mix_batch(self, x):
+ lam, use_cutmix = self._params_per_batch()
+ if lam == 1.:
+ return 1.
+ if use_cutmix:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
+ else:
+ x_flipped = x.flip(0).mul_(1. - lam)
+ x.mul_(lam).add_(x_flipped)
+ return lam
+
+ def __call__(self, x, target):
+ assert len(x) % 2 == 0, 'Batch size should be even when using this'
+ if self.mode == 'elem':
+ lam = self._mix_elem(x)
+ elif self.mode == 'pair':
+ lam = self._mix_pair(x)
+ else:
+ lam = self._mix_batch(x)
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
+ return x, target
+
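+# Example usage (illustrative sketch, applied inside a training loop; batch size must be even):
+#   mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
+#   inputs, soft_targets = mixup_fn(inputs, targets)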
+
+class FastCollateMixup(Mixup):
+ """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
+
+ A Mixup impl that's performed while collating the batches.
+ """
+
+ def _mix_elem_collate(self, output, batch, half=False):
+ batch_size = len(batch)
+ num_elem = batch_size // 2 if half else batch_size
+ assert len(output) == num_elem
+ lam_batch, use_cutmix = self._params_per_elem(num_elem)
+ for i in range(num_elem):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ mixed = batch[i][0]
+ if lam != 1.:
+ if use_cutmix[i]:
+ if not half:
+ mixed = mixed.copy()
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+ lam_batch[i] = lam
+ else:
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+ np.rint(mixed, out=mixed)
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
+ if half:
+ lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
+ return torch.tensor(lam_batch).unsqueeze(1)
+
+ def _mix_pair_collate(self, output, batch):
+ batch_size = len(batch)
+ lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+ for i in range(batch_size // 2):
+ j = batch_size - i - 1
+ lam = lam_batch[i]
+ mixed_i = batch[i][0]
+ mixed_j = batch[j][0]
+ assert 0 <= lam <= 1.0
+ if lam < 1.:
+ if use_cutmix[i]:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+ mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
+ mixed_j[:, yl:yh, xl:xh] = patch_i
+ lam_batch[i] = lam
+ else:
+ mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+ mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+ mixed_i = mixed_temp
+ np.rint(mixed_j, out=mixed_j)
+ np.rint(mixed_i, out=mixed_i)
+ output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
+ output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+ lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+ return torch.tensor(lam_batch).unsqueeze(1)
+
+ def _mix_batch_collate(self, output, batch):
+ batch_size = len(batch)
+ lam, use_cutmix = self._params_per_batch()
+ if use_cutmix:
+ (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+ output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+ for i in range(batch_size):
+ j = batch_size - i - 1
+ mixed = batch[i][0]
+ if lam != 1.:
+ if use_cutmix:
+ mixed = mixed.copy() # don't want to modify the original while iterating
+ mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
+ else:
+ mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+ np.rint(mixed, out=mixed)
+ output[i] += torch.from_numpy(mixed.astype(np.uint8))
+ return lam
+
+ def __call__(self, batch, _=None):
+ batch_size = len(batch)
+ assert batch_size % 2 == 0, 'Batch size should be even when using this'
+ half = 'half' in self.mode
+ if half:
+ batch_size //= 2
+ output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+ if self.mode == 'elem' or self.mode == 'half':
+ lam = self._mix_elem_collate(output, batch, half=half)
+ elif self.mode == 'pair':
+ lam = self._mix_pair_collate(output, batch)
+ else:
+ lam = self._mix_batch_collate(output, batch)
+ target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
+ target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
+ target = target[:batch_size]
+ return output, target
+
diff --git a/timm/data/parsers/__init__.py b/timm/data/parsers/__init__.py
new file mode 100644
index 0000000..eeb44e3
--- /dev/null
+++ b/timm/data/parsers/__init__.py
@@ -0,0 +1 @@
+from .parser_factory import create_parser
diff --git a/timm/data/parsers/__pycache__/__init__.cpython-36.pyc b/timm/data/parsers/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..c27e10b
Binary files /dev/null and b/timm/data/parsers/__pycache__/__init__.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/class_map.cpython-36.pyc b/timm/data/parsers/__pycache__/class_map.cpython-36.pyc
new file mode 100644
index 0000000..58fe839
Binary files /dev/null and b/timm/data/parsers/__pycache__/class_map.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/constants.cpython-36.pyc b/timm/data/parsers/__pycache__/constants.cpython-36.pyc
new file mode 100644
index 0000000..5fcdc42
Binary files /dev/null and b/timm/data/parsers/__pycache__/constants.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/parser.cpython-36.pyc b/timm/data/parsers/__pycache__/parser.cpython-36.pyc
new file mode 100644
index 0000000..5044634
Binary files /dev/null and b/timm/data/parsers/__pycache__/parser.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/parser_factory.cpython-36.pyc b/timm/data/parsers/__pycache__/parser_factory.cpython-36.pyc
new file mode 100644
index 0000000..29045fa
Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_factory.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/parser_image_folder.cpython-36.pyc b/timm/data/parsers/__pycache__/parser_image_folder.cpython-36.pyc
new file mode 100644
index 0000000..eb55787
Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_folder.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-36.pyc b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-36.pyc
new file mode 100644
index 0000000..ae87757
Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_in_tar.cpython-36.pyc differ
diff --git a/timm/data/parsers/__pycache__/parser_image_tar.cpython-36.pyc b/timm/data/parsers/__pycache__/parser_image_tar.cpython-36.pyc
new file mode 100644
index 0000000..4f3a7b8
Binary files /dev/null and b/timm/data/parsers/__pycache__/parser_image_tar.cpython-36.pyc differ
diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py
new file mode 100644
index 0000000..6b6fe45
--- /dev/null
+++ b/timm/data/parsers/class_map.py
@@ -0,0 +1,19 @@
+import os
+
+
+def load_class_map(map_or_filename, root=''):
+ if isinstance(map_or_filename, dict):
+ assert map_or_filename, 'class_map dict must be non-empty'
+ return map_or_filename
+ class_map_path = map_or_filename
+ if not os.path.exists(class_map_path):
+ class_map_path = os.path.join(root, class_map_path)
+ assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
+ class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
+ if class_map_ext == '.txt':
+ with open(class_map_path) as f:
+ class_to_idx = {v.strip(): k for k, v in enumerate(f)}
+ else:
+ assert False, f'Unsupported class map file extension ({class_map_ext}).'
+ return class_to_idx
+
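`load_class_map` accepts either a ready dict or a `.txt` file with one class name per line, where the line order defines the class index. A small illustrative example with a throwaway file path:
```
# Illustrative only: write a tiny class map, then load it.
from timm.data.parsers.class_map import load_class_map

with open('class_map.txt', 'w') as f:
    f.write('cat\ndog\nhorse\n')

class_to_idx = load_class_map('class_map.txt')
print(class_to_idx)   # {'cat': 0, 'dog': 1, 'horse': 2}
```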
diff --git a/timm/data/parsers/constants.py b/timm/data/parsers/constants.py
new file mode 100644
index 0000000..e7ba484
--- /dev/null
+++ b/timm/data/parsers/constants.py
@@ -0,0 +1 @@
+IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
diff --git a/timm/data/parsers/parser.py b/timm/data/parsers/parser.py
new file mode 100644
index 0000000..76ab6d1
--- /dev/null
+++ b/timm/data/parsers/parser.py
@@ -0,0 +1,17 @@
+from abc import abstractmethod
+
+
+class Parser:
+ def __init__(self):
+ pass
+
+ @abstractmethod
+ def _filename(self, index, basename=False, absolute=False):
+ pass
+
+ def filename(self, index, basename=False, absolute=False):
+ return self._filename(index, basename=basename, absolute=absolute)
+
+ def filenames(self, basename=False, absolute=False):
+ return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
+
diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py
new file mode 100644
index 0000000..892090a
--- /dev/null
+++ b/timm/data/parsers/parser_factory.py
@@ -0,0 +1,29 @@
+import os
+
+from .parser_image_folder import ParserImageFolder
+from .parser_image_tar import ParserImageTar
+from .parser_image_in_tar import ParserImageInTar
+
+
+def create_parser(name, root, split='train', **kwargs):
+ name = name.lower()
+ name = name.split('/', 2)
+ prefix = ''
+ if len(name) > 1:
+ prefix = name[0]
+ name = name[-1]
+
+ # FIXME improve the selection right now just tfds prefix or fallback path, will need options to
+ # explicitly select other options shortly
+ if prefix == 'tfds':
+ from .parser_tfds import ParserTfds # defer tensorflow import
+ parser = ParserTfds(root, name, split=split, **kwargs)
+ else:
+ assert os.path.exists(root)
+ # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
+ # FIXME support split here, in parser?
+ if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
+ parser = ParserImageInTar(root, **kwargs)
+ else:
+ parser = ParserImageFolder(root, **kwargs)
+ return parser
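A hedged sketch of how the factory resolves a parser (paths below are placeholders): an empty or plain name falls through to the filesystem check, while a `tfds/` prefix defers to the TFDS wrapper.
```
from timm.data.parsers import create_parser

# Plain folder (or a single .tar file) -> ParserImageFolder / ParserImageInTar
parser = create_parser('', root='/data/imagenet/train')

# 'tfds/' prefix -> ParserTfds (requires tensorflow-datasets to be installed)
# parser = create_parser('tfds/imagenet2012', root='/data/tfds', split='train',
#                        is_training=True, batch_size=128)
```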
diff --git a/timm/data/parsers/parser_image_folder.py b/timm/data/parsers/parser_image_folder.py
new file mode 100644
index 0000000..ed34900
--- /dev/null
+++ b/timm/data/parsers/parser_image_folder.py
@@ -0,0 +1,69 @@
+""" A dataset parser that reads images from folders
+
+Folders are scanned recursively to find image files. Labels are based
+on the folder hierarchy, just leaf folders by default.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
+ labels = []
+ filenames = []
+ for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
+ rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+ label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+ for f in files:
+ base, ext = os.path.splitext(f)
+ if ext.lower() in types:
+ filenames.append(os.path.join(root, f))
+ labels.append(label)
+ if class_to_idx is None:
+ # building class index
+ unique_labels = set(labels)
+ sorted_labels = list(sorted(unique_labels, key=natural_key))
+ class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+ images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+ if sort:
+ images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
+ return images_and_targets, class_to_idx
+
+
+class ParserImageFolder(Parser):
+
+ def __init__(
+ self,
+ root,
+ class_map=''):
+ super().__init__()
+
+ self.root = root
+ class_to_idx = None
+ if class_map:
+ class_to_idx = load_class_map(class_map, root)
+ self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
+ if len(self.samples) == 0:
+ raise RuntimeError(
+ f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
+
+ def __getitem__(self, index):
+ path, target = self.samples[index]
+ return open(path, 'rb'), target
+
+ def __len__(self):
+ return len(self.samples)
+
+ def _filename(self, index, basename=False, absolute=False):
+ filename = self.samples[index][0]
+ if basename:
+ filename = os.path.basename(filename)
+ elif not absolute:
+ filename = os.path.relpath(filename, self.root)
+ return filename
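Labels in `ParserImageFolder` come from the leaf folder names, so a layout such as `train/cat/001.jpg`, `train/dog/004.jpg` yields a two-class index. A hedged usage sketch with a placeholder path:
```
from timm.data.parsers.parser_image_folder import ParserImageFolder

parser = ParserImageFolder('/data/pets/train')    # placeholder path
print(parser.class_to_idx)                        # e.g. {'cat': 0, 'dog': 1}
fileobj, target = parser[0]                       # open binary file handle + class index
print(parser.filename(0, basename=True), target)
```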
diff --git a/timm/data/parsers/parser_image_in_tar.py b/timm/data/parsers/parser_image_in_tar.py
new file mode 100644
index 0000000..c6ada96
--- /dev/null
+++ b/timm/data/parsers/parser_image_in_tar.py
@@ -0,0 +1,222 @@
+""" A dataset parser that reads tarfile based datasets
+
+This parser can read and extract image samples from:
+* a single tar of image files
+* a folder of multiple tarfiles containing imagefiles
+* a tar of tars containing image files
+
+Labels are based on the combined folder and/or tar name structure.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import tarfile
+import pickle
+import logging
+import numpy as np
+from glob import glob
+from typing import List, Dict
+
+from timm.utils.misc import natural_key
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+
+
+_logger = logging.getLogger(__name__)
+CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
+
+
+class TarState:
+
+ def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
+ self.tf: tarfile.TarFile = tf
+ self.ti: tarfile.TarInfo = ti
+ self.children: Dict[str, TarState] = {} # child states (tars within tars)
+
+ def reset(self):
+ self.tf = None
+
+
+def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
+ sample_count = 0
+ for i, ti in enumerate(tf):
+ if not ti.isfile():
+ continue
+ dirname, basename = os.path.split(ti.path)
+ name, ext = os.path.splitext(basename)
+ ext = ext.lower()
+ if ext == '.tar':
+ with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
+ child_info = dict(
+ name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
+ sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
+ _logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
+ parent_info['children'].append(child_info)
+ elif ext in extensions:
+ parent_info['samples'].append(ti)
+ sample_count += 1
+ return sample_count
+
+
+def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
+ root_is_tar = False
+ if os.path.isfile(root):
+ assert os.path.splitext(root)[-1].lower() == '.tar'
+ tar_filenames = [root]
+ root, root_name = os.path.split(root)
+ root_name = os.path.splitext(root_name)[0]
+ root_is_tar = True
+ else:
+ root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
+ tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
+ num_tars = len(tar_filenames)
+ tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
+ assert num_tars, f'No .tar files found at specified path ({root}).'
+
+ _logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
+ info = dict(tartrees=[])
+ cache_path = ''
+ if cache_tarinfo is None:
+ cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
+ if cache_tarinfo:
+ cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
+ cache_path = os.path.join(root, cache_filename)
+ if os.path.exists(cache_path):
+ _logger.info(f'Reading tar info from cache file {cache_path}.')
+ with open(cache_path, 'rb') as pf:
+ info = pickle.load(pf)
+ assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
+ else:
+ for i, fn in enumerate(tar_filenames):
+ path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
+ with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode
+ parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
+ num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
+ num_children = len(parent_info["children"])
+ _logger.debug(
+ f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
+ info['tartrees'].append(parent_info)
+ if cache_path:
+ _logger.info(f'Writing tar info to cache file {cache_path}.')
+ with open(cache_path, 'wb') as pf:
+ pickle.dump(info, pf)
+
+ samples = []
+ labels = []
+ build_class_map = False
+ if class_name_to_idx is None:
+ build_class_map = True
+
+ # Flatten tartree info into lists of samples and targets w/ targets based on label id via
+ # class map arg or from unique paths.
+ # NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
+ # this covers my current use cases and keeps things a little easier to test for now.
+ tarfiles = []
+
+ def _label_from_paths(*path, leaf_only=True):
+ path = os.path.join(*path).strip(os.path.sep)
+ return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
+
+ def _add_samples(info, fn):
+ added = 0
+ for s in info['samples']:
+ label = _label_from_paths(info['path'], os.path.dirname(s.path))
+ if not build_class_map and label not in class_name_to_idx:
+ continue
+ samples.append((s, fn, info['ti']))
+ labels.append(label)
+ added += 1
+ return added
+
+ _logger.info(f'Collecting samples and building tar states.')
+ for parent_info in info['tartrees']:
+ # if tartree has children, we assume all samples are at the child level
+ tar_name = None if root_is_tar else parent_info['name']
+ tar_state = TarState()
+ parent_added = 0
+ for child_info in parent_info['children']:
+ child_added = _add_samples(child_info, fn=tar_name)
+ if child_added:
+ tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
+ parent_added += child_added
+ parent_added += _add_samples(parent_info, fn=tar_name)
+ if parent_added:
+ tarfiles.append((tar_name, tar_state))
+ del info
+
+ if build_class_map:
+ # build class index
+ sorted_labels = list(sorted(set(labels), key=natural_key))
+ class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+
+ _logger.info(f'Mapping targets and sorting samples.')
+ samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
+ if sort:
+ samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
+ samples, targets = zip(*samples_and_targets)
+ samples = np.array(samples)
+ targets = np.array(targets)
+ _logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
+ return samples, targets, class_name_to_idx, tarfiles
+
+
+class ParserImageInTar(Parser):
+ """ Multi-tarfile dataset parser where there is one .tar file per class
+ """
+
+ def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
+ super().__init__()
+
+ class_name_to_idx = None
+ if class_map:
+ class_name_to_idx = load_class_map(class_map, root)
+ self.root = root
+ self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
+ self.root,
+ class_name_to_idx=class_name_to_idx,
+ cache_tarinfo=cache_tarinfo,
+ extensions=IMG_EXTENSIONS)
+ self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
+ if len(tarfiles) == 1 and tarfiles[0][0] is None:
+ self.root_is_tar = True
+ self.tar_state = tarfiles[0][1]
+ else:
+ self.root_is_tar = False
+ self.tar_state = dict(tarfiles)
+ self.cache_tarfiles = cache_tarfiles
+
+ def __len__(self):
+ return len(self.samples)
+
+ def __getitem__(self, index):
+ sample = self.samples[index]
+ target = self.targets[index]
+ sample_ti, parent_fn, child_ti = sample
+ parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
+
+ tf = None
+ cache_state = None
+ if self.cache_tarfiles:
+ cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
+ tf = cache_state.tf
+ if tf is None:
+ tf = tarfile.open(parent_abs)
+ if self.cache_tarfiles:
+ cache_state.tf = tf
+ if child_ti is not None:
+ ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
+ if ctf is None:
+ ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
+ if self.cache_tarfiles:
+ cache_state.children[child_ti.name].tf = ctf
+ tf = ctf
+
+ return tf.extractfile(sample_ti), target
+
+ def _filename(self, index, basename=False, absolute=False):
+ filename = self.samples[index][0].name
+ if basename:
+ filename = os.path.basename(filename)
+ return filename
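A hedged usage sketch for `ParserImageInTar`, pointed at a folder of per-class tar files (or a single tar of tars); the path is a placeholder:
```
from timm.data.parsers.parser_image_in_tar import ParserImageInTar

parser = ParserImageInTar('/data/imagenet_tars')   # e.g. n01440764.tar, n01443537.tar, ...
fileobj, target = parser[0]                        # undecoded image bytes stream + class index
print(len(parser), parser.filename(0, basename=True))
```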
diff --git a/timm/data/parsers/parser_image_tar.py b/timm/data/parsers/parser_image_tar.py
new file mode 100644
index 0000000..467537f
--- /dev/null
+++ b/timm/data/parsers/parser_image_tar.py
@@ -0,0 +1,72 @@
+""" A dataset parser that reads single tarfile based datasets
+
+This parser can read datasets consisting of a single tarfile containing images.
+I am planning to deprecate it in favour of ParserImageInTar.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import tarfile
+
+from .parser import Parser
+from .class_map import load_class_map
+from .constants import IMG_EXTENSIONS
+from timm.utils.misc import natural_key
+
+
+def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
+ files = []
+ labels = []
+ for ti in tarfile.getmembers():
+ if not ti.isfile():
+ continue
+ dirname, basename = os.path.split(ti.path)
+ label = os.path.basename(dirname)
+ ext = os.path.splitext(basename)[1]
+ if ext.lower() in IMG_EXTENSIONS:
+ files.append(ti)
+ labels.append(label)
+ if class_to_idx is None:
+ unique_labels = set(labels)
+ sorted_labels = list(sorted(unique_labels, key=natural_key))
+ class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+ tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
+ if sort:
+ tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
+ return tarinfo_and_targets, class_to_idx
+
+
+class ParserImageTar(Parser):
+ """ Single tarfile dataset where classes are mapped to folders within tar
+ NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
+ operate on folders of tars or tars in tars.
+ """
+ def __init__(self, root, class_map=''):
+ super().__init__()
+
+ class_to_idx = None
+ if class_map:
+ class_to_idx = load_class_map(class_map, root)
+ assert os.path.isfile(root)
+ self.root = root
+
+ with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
+ self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
+ self.imgs = self.samples
+ self.tarfile = None # lazy init in __getitem__
+
+ def __getitem__(self, index):
+ if self.tarfile is None:
+ self.tarfile = tarfile.open(self.root)
+ tarinfo, target = self.samples[index]
+ fileobj = self.tarfile.extractfile(tarinfo)
+ return fileobj, target
+
+ def __len__(self):
+ return len(self.samples)
+
+ def _filename(self, index, basename=False, absolute=False):
+ filename = self.samples[index][0].name
+ if basename:
+ filename = os.path.basename(filename)
+ return filename
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
new file mode 100644
index 0000000..ee5893c
--- /dev/null
+++ b/timm/data/parsers/parser_tfds.py
@@ -0,0 +1,297 @@
+""" Dataset parser interface that wraps TFDS datasets
+
+Wraps many (most?) TFDS image-classification datasets
+from https://github.com/tensorflow/datasets
+https://www.tensorflow.org/datasets/catalog/overview#image_classification
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import torch
+import torch.distributed as dist
+from PIL import Image
+
+try:
+ import tensorflow as tf
+ tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
+ import tensorflow_datasets as tfds
+ try:
+ tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
+ has_buggy_even_splits = False
+ except TypeError:
+ print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+ "Please update or use tfds-nightly for better fine-grained split behaviour.")
+ has_buggy_even_splits = True
+except ImportError as e:
+ print(e)
+ print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
+ exit(1)
+from .parser import Parser
+
+
+MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
+SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue
+PREFETCH_SIZE = 2048 # examples to prefetch
+
+
+def even_split_indices(split, n, num_examples):
+ partitions = [round(i * num_examples / n) for i in range(n + 1)]
+ return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
+
+
+def get_class_labels(info):
+ if 'label' not in info.features:
+ return {}
+ class_label = info.features['label']
+ class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
+ return class_to_idx
+
+
+class ParserTfds(Parser):
+ """ Wrap Tensorflow Datasets for use in PyTorch
+
+ There are several things to be aware of:
+ * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
+ dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
+ https://github.com/pytorch/pytorch/issues/33413
+ * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
+ from each worker could be a different size. For training this is worked around by option above, for
+ validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
+ across replicas are of the same size. This will slightly alter the results; distributed validation will not be
+ 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
+ since there are up to N * J extra examples with IterableDatasets.
+ * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
+ replicas and dataloader workers you can use. For really small datasets that only contain a few shards
+ you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
+ benefit of distributed training or fast dataloading should be much less for small datasets.
+ * This wrapper is currently configured to return individual, decompressed image examples from the TFDS
+ dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
+ to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
+ components.
+
+ """
+
+ def __init__(
+ self,
+ root,
+ name,
+ split='train',
+ is_training=False,
+ batch_size=None,
+ download=False,
+ repeats=0,
+ seed=42,
+ input_name='image',
+ input_image='RGB',
+ target_name='label',
+ target_image='',
+ prefetch_size=None,
+ shuffle_size=None,
+ max_threadpool_size=None
+ ):
+ """ Tensorflow-datasets Wrapper
+
+ Args:
+ root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir)
+ name: tfds dataset name (eg `imagenet2012`)
+ split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
+ is_training: training mode, shuffle enabled, dataset len rounded by batch_size
+ batch_size: batch_size used to ensure total examples % batch_size == 0 in training across all dist nodes
+ download: download and build TFDS dataset if set, otherwise must use tfds CLI
+ repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
+ seed: common seed for shard shuffle across all distributed/worker instances
+ input_name: name of Feature to return as data (input)
+ input_image: image mode if input is an image (currently PIL mode string)
+ target_name: name of Feature to return as target (label)
+ target_image: image mode if target is an image (currently PIL mode string)
+ prefetch_size: override default tf.data prefetch buffer size
+ shuffle_size: override default tf.data shuffle buffer size
+ max_threadpool_size: override default threadpool size for tf.data
+ """
+ super().__init__()
+ self.root = root
+ self.split = split
+ self.is_training = is_training
+ if self.is_training:
+ assert batch_size is not None, \
+ "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
+ self.batch_size = batch_size
+ self.repeats = repeats
+ self.common_seed = seed # a seed that's fixed across all worker / distributed instances
+
+ # performance settings
+ self.prefetch_size = prefetch_size or PREFETCH_SIZE
+ self.shuffle_size = shuffle_size or SHUFFLE_SIZE
+ self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
+
+ # TFDS builder and split information
+ self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature
+ self.input_image = input_image
+ self.target_name = target_name
+ self.target_image = target_image
+ self.builder = tfds.builder(name, data_dir=root)
+ # NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
+ if download:
+ self.builder.download_and_prepare()
+ self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
+ self.split_info = self.builder.info.splits[split]
+ self.num_examples = self.split_info.num_examples
+
+ # Distributed world state
+ self.dist_rank = 0
+ self.dist_num_replicas = 1
+ if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
+ self.dist_rank = dist.get_rank()
+ self.dist_num_replicas = dist.get_world_size()
+
+ # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
+ self.global_num_workers = 1
+ self.worker_info = None
+ self.worker_seed = 0 # seed unique to each work instance
+ self.subsplit = None # set when data is distributed across workers using sub-splits
+ self.ds = None # initialized lazily on each dataloader worker process
+
+ def _lazy_init(self):
+ """ Lazily initialize the dataset.
+
+ This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
+ will be using the dataset instance. The __init__ method is called on the main process,
+ this will be called in a dataloader worker process.
+
+ NOTE: There will be problems if you try to re-use this dataset across different loader/worker
+ instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
+ before it is passed to dataloader.
+ """
+ worker_info = torch.utils.data.get_worker_info()
+
+ # setup input context to split dataset across distributed processes
+ num_workers = 1
+ global_worker_id = 0
+ if worker_info is not None:
+ self.worker_info = worker_info
+ self.worker_seed = worker_info.seed
+ num_workers = worker_info.num_workers
+ self.global_num_workers = self.dist_num_replicas * num_workers
+ global_worker_id = self.dist_rank * num_workers + worker_info.id
+
+ """ Data sharding
+ InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
+ My understanding is that when using split, the underlying TFRecord files will shuffle (shuffle_files=True)
+ between the splits each iteration, but that understanding could be wrong.
+
+ I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
+ the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
+ in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
+ for validation where we can't drop examples and need to minimize uneven splits to avoid padding.
+ """
+ should_subsplit = self.global_num_workers > 1 and (
+ self.split_info.num_shards < self.global_num_workers or not self.is_training)
+ if should_subsplit:
+ # split the dataset w/o using sharding for more even examples / worker, can result in less optimal
+ # read patterns for distributed training (overlap across shards) so better to use InputContext there
+ if has_buggy_even_splits:
+ # my even_split workaround doesn't work on subsplits, upgrade tfds!
+ if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
+ subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
+ self.subsplit = subsplits[global_worker_id]
+ else:
+ subsplits = tfds.even_splits(self.split, self.global_num_workers)
+ self.subsplit = subsplits[global_worker_id]
+
+ input_context = None
+ if self.global_num_workers > 1 and self.subsplit is None:
+ # set input context to divide shards among distributed replicas
+ input_context = tf.distribute.InputContext(
+ num_input_pipelines=self.global_num_workers,
+ input_pipeline_id=global_worker_id,
+ num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
+ )
+ read_config = tfds.ReadConfig(
+ shuffle_seed=self.common_seed,
+ shuffle_reshuffle_each_iteration=True,
+ input_context=input_context)
+ ds = self.builder.as_dataset(
+ split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
+ # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
+ options = tf.data.Options()
+ thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
+ getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
+ getattr(options, thread_member).max_intra_op_parallelism = 1
+ ds = ds.with_options(options)
+ if self.is_training or self.repeats > 1:
+ # to prevent excessive drop_last batch behaviour w/ IterableDatasets
+ # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
+ ds = ds.repeat() # allow wrap around and break iteration manually
+ if self.is_training:
+ ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
+ ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
+ self.ds = tfds.as_numpy(ds)
+
+ def __iter__(self):
+ if self.ds is None:
+ self._lazy_init()
+
+ # Compute a rounded up sample count that is used to:
+ # 1. make batches even cross workers & replicas in distributed validation.
+ # This adds extra examples and will slightly alter validation results.
+ # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
+ # batches are produced (underlying tfds iter wraps around)
+ target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
+ if self.is_training:
+ # round up to nearest batch_size per worker-replica
+ target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
+
+ # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
+ example_count = 0
+ for example in self.ds:
+ input_data = example[self.input_name]
+ if self.input_image:
+ input_data = Image.fromarray(input_data, mode=self.input_image)
+ target_data = example[self.target_name]
+ if self.target_image:
+ target_data = Image.fromarray(target_data, mode=self.target_image)
+ yield input_data, target_data
+ example_count += 1
+ if self.is_training and example_count >= target_example_count:
+ # Need to break out of loop when repeat() is enabled for training w/ oversampling
+ # this results in extra examples per epoch but seems more desirable than dropping
+ # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
+ break
+
+ # Pad across distributed nodes (make counts equal by adding examples)
+ if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
+ 0 < example_count < target_example_count:
+ # Validation batch padding only done for distributed training where results are reduced across nodes.
+ # For single process case, it won't matter if workers return different batch sizes.
+ # If using input_context or % based splits, sample count can vary significantly across workers and this
+ # approach should not be used (hence disabled if self.subsplit isn't set).
+ while example_count < target_example_count:
+ yield input_data, target_data # yield prev sample again
+ example_count += 1
+
+ def __len__(self):
+ # this is just an estimate and does not factor in extra examples added to pad batches based on
+ # complete worker & replica info (not available until init in dataloader).
+ return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
+
+ def _filename(self, index, basename=False, absolute=False):
+ assert False, "Not supported" # no random access to examples
+
+ def filenames(self, basename=False, absolute=False):
+ """ Return all filenames in dataset, overrides base"""
+ if self.ds is None:
+ self._lazy_init()
+ names = []
+ for sample in self.ds:
+ if len(names) > self.num_examples:
+ break # safety for ds.repeat() case
+ if 'file_name' in sample:
+ name = sample['file_name']
+ elif 'filename' in sample:
+ name = sample['filename']
+ elif 'id' in sample:
+ name = sample['id']
+ else:
+ assert False, "No supported name field present"
+ names.append(name)
+ return names
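`ParserTfds` has no random access and is meant to be iterated; in timm it is normally consumed through an iterable dataset wrapper inside a dataloader worker, but it can be iterated directly. A hedged sketch with a placeholder TFDS data dir:
```
from timm.data.parsers.parser_tfds import ParserTfds

parser = ParserTfds(root='/data/tfds', name='cifar10', split='train',
                    is_training=True, batch_size=128, download=True)
for img, label in parser:      # img is a PIL Image, label an integer class id
    print(img.size, label)
    break
```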
diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py
new file mode 100644
index 0000000..2fa6315
--- /dev/null
+++ b/timm/data/random_erasing.py
@@ -0,0 +1,103 @@
+""" Random Erasing (Cutout)
+
+Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0
+Copyright Zhun Zhong & Liang Zheng
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import random
+import math
+import torch
+
+
+def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
+ # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
+ # paths, flip the order so normal is run on CPU if this becomes a problem
+ # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
+ if per_pixel:
+ return torch.empty(patch_size, dtype=dtype, device=device).normal_()
+ elif rand_color:
+ return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
+ else:
+ return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
+
+
+class RandomErasing:
+ """ Randomly selects a rectangle region in an image and erases its pixels.
+ 'Random Erasing Data Augmentation' by Zhong et al.
+ See https://arxiv.org/pdf/1708.04896.pdf
+
+ This variant of RandomErasing is intended to be applied to either a batch
+ or single image tensor after it has been normalized by dataset mean and std.
+ Args:
+ probability: Probability that the Random Erasing operation will be performed.
+ min_area: Minimum percentage of erased area wrt input image area.
+ max_area: Maximum percentage of erased area wrt input image area.
+ min_aspect: Minimum aspect ratio of erased area.
+ mode: pixel color mode, one of 'const', 'rand', or 'pixel'
+ 'const' - erase block is constant color of 0 for all channels
+ 'rand' - erase block is same per-channel random (normal) color
+ 'pixel' - erase block is per-pixel random (normal) color
+ max_count: maximum number of erasing blocks per image, area per box is scaled by count.
+ per-image count is randomly chosen between 1 and this value.
+ """
+
+ def __init__(
+ self,
+ probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
+ mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
+ self.probability = probability
+ self.min_area = min_area
+ self.max_area = max_area
+ max_aspect = max_aspect or 1 / min_aspect
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+ self.min_count = min_count
+ self.max_count = max_count or min_count
+ self.num_splits = num_splits
+ self.mode = mode.lower()
+ self.rand_color = False
+ self.per_pixel = False
+ if self.mode == 'rand':
+ self.rand_color = True # per block random normal
+ elif self.mode == 'pixel':
+ self.per_pixel = True # per pixel random normal
+ else:
+ assert not self.mode or self.mode == 'const'
+ self.device = device
+
+ def _erase(self, img, chan, img_h, img_w, dtype):
+ if random.random() > self.probability:
+ return
+ area = img_h * img_w
+ count = self.min_count if self.min_count == self.max_count else \
+ random.randint(self.min_count, self.max_count)
+ for _ in range(count):
+ for attempt in range(10):
+ target_area = random.uniform(self.min_area, self.max_area) * area / count
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
+ if w < img_w and h < img_h:
+ top = random.randint(0, img_h - h)
+ left = random.randint(0, img_w - w)
+ img[:, top:top + h, left:left + w] = _get_pixels(
+ self.per_pixel, self.rand_color, (chan, h, w),
+ dtype=dtype, device=self.device)
+ break
+
+ def __call__(self, input):
+ if len(input.size()) == 3:
+ self._erase(input, *input.size(), input.dtype)
+ else:
+ batch_size, chan, img_h, img_w = input.size()
+ # skip first slice of batch if num_splits is set (for clean portion of samples)
+ batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+ for i in range(batch_start, batch_size):
+ self._erase(input[i], chan, img_h, img_w, input.dtype)
+ return input
+
+ def __repr__(self):
+ # NOTE simplified state for repr
+ fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}'
+ fs += f', count=({self.min_count}, {self.max_count}))'
+ return fs
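`RandomErasing` expects tensors that have already been normalized, either a single CHW image or a BCHW batch. A minimal sketch on random data:
```
import torch
from timm.data.random_erasing import RandomErasing

erase = RandomErasing(probability=1.0, mode='pixel', max_count=2, device='cpu')
batch = torch.randn(8, 3, 224, 224)   # stand-in for already-normalized images
erased = erase(batch)                 # erased in place and returned
print(erase)                          # RandomErasing(p=1.0, mode=pixel, count=(1, 2))
```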
diff --git a/timm/data/real_labels.py b/timm/data/real_labels.py
new file mode 100644
index 0000000..939c348
--- /dev/null
+++ b/timm/data/real_labels.py
@@ -0,0 +1,42 @@
+""" Real labels evaluator for ImageNet
+Paper: `Are we done with ImageNet?` - https://arxiv.org/abs/2006.07159
+Based on Numpy example at https://github.com/google-research/reassessed-imagenet
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import os
+import json
+import numpy as np
+
+
+class RealLabelsImagenet:
+
+ def __init__(self, filenames, real_json='real.json', topk=(1, 5)):
+ with open(real_json) as real_labels:
+ real_labels = json.load(real_labels)
+ real_labels = {f'ILSVRC2012_val_{i + 1:08d}.JPEG': labels for i, labels in enumerate(real_labels)}
+ self.real_labels = real_labels
+ self.filenames = filenames
+ assert len(self.filenames) == len(self.real_labels)
+ self.topk = topk
+ self.is_correct = {k: [] for k in topk}
+ self.sample_idx = 0
+
+ def add_result(self, output):
+ maxk = max(self.topk)
+ _, pred_batch = output.topk(maxk, 1, True, True)
+ pred_batch = pred_batch.cpu().numpy()
+ for pred in pred_batch:
+ filename = self.filenames[self.sample_idx]
+ filename = os.path.basename(filename)
+ if self.real_labels[filename]:
+ for k in self.topk:
+ self.is_correct[k].append(
+ any([p in self.real_labels[filename] for p in pred[:k]]))
+ self.sample_idx += 1
+
+ def get_accuracy(self, k=None):
+ if k is None:
+ return {k: float(np.mean(self.is_correct[k])) * 100 for k in self.topk}
+ else:
+ return float(np.mean(self.is_correct[k])) * 100
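A hedged, self-contained toy of the ReaL-labels evaluator: the two-entry JSON file and filenames below are stand-ins for the real `real.json` (one list of acceptable labels per validation image, in filename order) released with the "Are we done with ImageNet?" paper.
```
import json
import torch
from timm.data.real_labels import RealLabelsImagenet

# Toy stand-in for real.json: acceptable labels per validation image, in filename order.
with open('real_toy.json', 'w') as f:
    json.dump([[1], [0, 2]], f)
filenames = ['ILSVRC2012_val_00000001.JPEG', 'ILSVRC2012_val_00000002.JPEG']

real_labels = RealLabelsImagenet(filenames, real_json='real_toy.json', topk=(1,))
real_labels.add_result(torch.tensor([[0.1, 0.9, 0.0],    # top-1 = class 1 -> correct
                                     [0.8, 0.1, 0.1]]))  # top-1 = class 0 -> correct
print(real_labels.get_accuracy(k=1))                     # 100.0
```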
diff --git a/timm/data/tf_preprocessing.py b/timm/data/tf_preprocessing.py
new file mode 100644
index 0000000..44b4a3a
--- /dev/null
+++ b/timm/data/tf_preprocessing.py
@@ -0,0 +1,232 @@
+""" Tensorflow Preprocessing Adapter
+
+Allows use of Tensorflow preprocessing pipeline in PyTorch Transform
+
+Copyright of original Tensorflow code below.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""ImageNet preprocessing for MnasNet."""
+import tensorflow as tf
+import numpy as np
+
+IMAGE_SIZE = 224
+CROP_PADDING = 32
+
+
+def distorted_bounding_box_crop(image_bytes,
+ bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=(0.75, 1.33),
+ area_range=(0.05, 1.0),
+ max_attempts=100,
+ scope=None):
+ """Generates cropped_image using one of the bboxes randomly distorted.
+
+ See `tf.image.sample_distorted_bounding_box` for more documentation.
+
+ Args:
+ image_bytes: `Tensor` of binary image data.
+ bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
+ where each coordinate is [0, 1) and the coordinates are arranged
+ as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
+ image.
+ min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
+ area of the image must contain at least this fraction of any bounding
+ box supplied.
+ aspect_ratio_range: An optional list of `float`s. The cropped area of the
+ image must have an aspect ratio = width / height within this range.
+ area_range: An optional list of `float`s. The cropped area of the image
+ must contain a fraction of the supplied image within this range.
+ max_attempts: An optional `int`. Number of attempts at generating a cropped
+ region of the image of the specified constraints. After `max_attempts`
+ failures, return the entire image.
+ scope: Optional `str` for name scope.
+ Returns:
+ cropped image `Tensor`
+ """
+ with tf.name_scope(scope, 'distorted_bounding_box_crop', [image_bytes, bbox]):
+ shape = tf.image.extract_jpeg_shape(image_bytes)
+ sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+ shape,
+ bounding_boxes=bbox,
+ min_object_covered=min_object_covered,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=max_attempts,
+ use_image_if_no_bounding_boxes=True)
+ bbox_begin, bbox_size, _ = sample_distorted_bounding_box
+
+ # Crop the image to the specified bounding box.
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+ image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
+
+ return image
+
+
+def _at_least_x_are_equal(a, b, x):
+ """At least `x` of `a` and `b` `Tensors` are equal."""
+ match = tf.equal(a, b)
+ match = tf.cast(match, tf.int32)
+ return tf.greater_equal(tf.reduce_sum(match), x)
+
+
+def _decode_and_random_crop(image_bytes, image_size, resize_method):
+ """Make a random crop of image_size."""
+ bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+ image = distorted_bounding_box_crop(
+ image_bytes,
+ bbox,
+ min_object_covered=0.1,
+ aspect_ratio_range=(3. / 4, 4. / 3.),
+ area_range=(0.08, 1.0),
+ max_attempts=10,
+ scope=None)
+ original_shape = tf.image.extract_jpeg_shape(image_bytes)
+ bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)
+
+ image = tf.cond(
+ bad,
+ lambda: _decode_and_center_crop(image_bytes, image_size, resize_method),
+ lambda: tf.image.resize([image], [image_size, image_size], resize_method)[0])
+
+ return image
+
+
+def _decode_and_center_crop(image_bytes, image_size, resize_method):
+ """Crops to center of image with padding then scales image_size."""
+ shape = tf.image.extract_jpeg_shape(image_bytes)
+ image_height = shape[0]
+ image_width = shape[1]
+
+ padded_center_crop_size = tf.cast(
+ ((image_size / (image_size + CROP_PADDING)) *
+ tf.cast(tf.minimum(image_height, image_width), tf.float32)),
+ tf.int32)
+
+ offset_height = ((image_height - padded_center_crop_size) + 1) // 2
+ offset_width = ((image_width - padded_center_crop_size) + 1) // 2
+ crop_window = tf.stack([offset_height, offset_width,
+ padded_center_crop_size, padded_center_crop_size])
+ image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
+ image = tf.image.resize([image], [image_size, image_size], resize_method)[0]
+
+ return image
+
+
+def _flip(image):
+ """Random horizontal image flip."""
+ image = tf.image.random_flip_left_right(image)
+ return image
+
+
+def preprocess_for_train(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
+ """Preprocesses the given image for evaluation.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ use_bfloat16: `bool` for whether to use bfloat16.
+ image_size: image size.
+ interpolation: image interpolation method
+
+ Returns:
+ A preprocessed image `Tensor`.
+ """
+ resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
+ image = _decode_and_random_crop(image_bytes, image_size, resize_method)
+ image = _flip(image)
+ image = tf.reshape(image, [image_size, image_size, 3])
+ image = tf.image.convert_image_dtype(
+ image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
+ return image
+
+
+def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interpolation='bicubic'):
+ """Preprocesses the given image for evaluation.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ use_bfloat16: `bool` for whether to use bfloat16.
+ image_size: image size.
+ interpolation: image interpolation method
+
+ Returns:
+ A preprocessed image `Tensor`.
+ """
+ resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
+ image = _decode_and_center_crop(image_bytes, image_size, resize_method)
+ image = tf.reshape(image, [image_size, image_size, 3])
+ image = tf.image.convert_image_dtype(
+ image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)
+ return image
+
+
+def preprocess_image(image_bytes,
+ is_training=False,
+ use_bfloat16=False,
+ image_size=IMAGE_SIZE,
+ interpolation='bicubic'):
+ """Preprocesses the given image.
+
+ Args:
+ image_bytes: `Tensor` representing an image binary of arbitrary size.
+ is_training: `bool` for whether the preprocessing is for training.
+ use_bfloat16: `bool` for whether to use bfloat16.
+ image_size: image size.
+ interpolation: image interpolation method
+
+ Returns:
+ A preprocessed image `Tensor` with value range of [0, 255].
+ """
+ if is_training:
+ return preprocess_for_train(image_bytes, use_bfloat16, image_size, interpolation)
+ else:
+ return preprocess_for_eval(image_bytes, use_bfloat16, image_size, interpolation)
+
+
+class TfPreprocessTransform:
+
+ def __init__(self, is_training=False, size=224, interpolation='bicubic'):
+ self.is_training = is_training
+ self.size = size[0] if isinstance(size, tuple) else size
+ self.interpolation = interpolation
+ self._image_bytes = None
+ self.process_image = self._build_tf_graph()
+ self.sess = None
+
+ def _build_tf_graph(self):
+ with tf.device('/cpu:0'):
+ self._image_bytes = tf.placeholder(
+ shape=[],
+ dtype=tf.string,
+ )
+ img = preprocess_image(
+ self._image_bytes, self.is_training, False, self.size, self.interpolation)
+ return img
+
+ def __call__(self, image_bytes):
+ if self.sess is None:
+ self.sess = tf.Session()
+ img = self.sess.run(self.process_image, feed_dict={self._image_bytes: image_bytes})
+ img = img.round().clip(0, 255).astype(np.uint8)
+ if img.ndim < 3:
+ img = np.expand_dims(img, axis=-1)
+ img = np.rollaxis(img, 2) # HWC to CHW
+ return img
diff --git a/timm/data/transforms.py b/timm/data/transforms.py
new file mode 100644
index 0000000..45c078f
--- /dev/null
+++ b/timm/data/transforms.py
@@ -0,0 +1,185 @@
+import torch
+import torchvision.transforms.functional as F
+try:
+ from torchvision.transforms.functional import InterpolationMode
+ has_interpolation_mode = True
+except ImportError:
+ has_interpolation_mode = False
+from PIL import Image
+import warnings
+import math
+import random
+import numpy as np
+
+
+class ToNumpy:
+
+ def __call__(self, pil_img):
+ np_img = np.array(pil_img, dtype=np.uint8)
+ if np_img.ndim < 3:
+ np_img = np.expand_dims(np_img, axis=-1)
+ np_img = np.rollaxis(np_img, 2) # HWC to CHW
+ return np_img
+
+
+class ToTensor:
+
+ def __init__(self, dtype=torch.float32):
+ self.dtype = dtype
+
+ def __call__(self, pil_img):
+ np_img = np.array(pil_img, dtype=np.uint8)
+ if np_img.ndim < 3:
+ np_img = np.expand_dims(np_img, axis=-1)
+ np_img = np.rollaxis(np_img, 2) # HWC to CHW
+ return torch.from_numpy(np_img).to(dtype=self.dtype)
+
+
+_pil_interpolation_to_str = {
+ Image.NEAREST: 'nearest',
+ Image.BILINEAR: 'bilinear',
+ Image.BICUBIC: 'bicubic',
+ Image.BOX: 'box',
+ Image.HAMMING: 'hamming',
+ Image.LANCZOS: 'lanczos',
+}
+_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
+
+
+if has_interpolation_mode:
+ _torch_interpolation_to_str = {
+ InterpolationMode.NEAREST: 'nearest',
+ InterpolationMode.BILINEAR: 'bilinear',
+ InterpolationMode.BICUBIC: 'bicubic',
+ InterpolationMode.BOX: 'box',
+ InterpolationMode.HAMMING: 'hamming',
+ InterpolationMode.LANCZOS: 'lanczos',
+ }
+ _str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
+else:
+ _pil_interpolation_to_torch = {}
+ _torch_interpolation_to_str = {}
+
+
+def str_to_pil_interp(mode_str):
+ return _str_to_pil_interpolation[mode_str]
+
+
+def str_to_interp_mode(mode_str):
+ if has_interpolation_mode:
+ return _str_to_torch_interpolation[mode_str]
+ else:
+ return _str_to_pil_interpolation[mode_str]
+
+
+def interp_mode_to_str(mode):
+ if has_interpolation_mode:
+ return _torch_interpolation_to_str[mode]
+ else:
+ return _pil_interpolation_to_str[mode]
+
+
+_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
+
+
+class RandomResizedCropAndInterpolation:
+ """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+
+ A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+ aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+ is finally resized to given size.
+ This is popularly used to train the Inception networks.
+
+ Args:
+ size: expected output size of each edge
+ scale: range of size of the origin size cropped
+ ratio: range of aspect ratio of the origin aspect ratio cropped
+ interpolation: Default: PIL.Image.BILINEAR
+ """
+
+ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
+ interpolation='bilinear'):
+ if isinstance(size, (list, tuple)):
+ self.size = tuple(size)
+ else:
+ self.size = (size, size)
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+ warnings.warn("range should be of kind (min, max)")
+
+ if interpolation == 'random':
+ self.interpolation = _RANDOM_INTERPOLATION
+ else:
+ self.interpolation = str_to_interp_mode(interpolation)
+ self.scale = scale
+ self.ratio = ratio
+
+ @staticmethod
+ def get_params(img, scale, ratio):
+ """Get parameters for ``crop`` for a random sized crop.
+
+ Args:
+ img (PIL Image): Image to be cropped.
+ scale (tuple): range of size of the origin size cropped
+ ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+
+ Returns:
+ tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+ sized crop.
+ """
+ area = img.size[0] * img.size[1]
+
+ for attempt in range(10):
+ target_area = random.uniform(*scale) * area
+ log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+ if w <= img.size[0] and h <= img.size[1]:
+ i = random.randint(0, img.size[1] - h)
+ j = random.randint(0, img.size[0] - w)
+ return i, j, h, w
+
+ # Fallback to central crop
+ in_ratio = img.size[0] / img.size[1]
+ if in_ratio < min(ratio):
+ w = img.size[0]
+ h = int(round(w / min(ratio)))
+ elif in_ratio > max(ratio):
+ h = img.size[1]
+ w = int(round(h * max(ratio)))
+ else: # whole image
+ w = img.size[0]
+ h = img.size[1]
+ i = (img.size[1] - h) // 2
+ j = (img.size[0] - w) // 2
+ return i, j, h, w
+
+ def __call__(self, img):
+ """
+ Args:
+ img (PIL Image): Image to be cropped and resized.
+
+ Returns:
+ PIL Image: Randomly cropped and resized image.
+ """
+ i, j, h, w = self.get_params(img, self.scale, self.ratio)
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolation = random.choice(self.interpolation)
+ else:
+ interpolation = self.interpolation
+ return F.resized_crop(img, i, j, h, w, self.size, interpolation)
+
+ def __repr__(self):
+ if isinstance(self.interpolation, (tuple, list)):
+ interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
+ else:
+ interpolate_str = interp_mode_to_str(self.interpolation)
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+ format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+ format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+ format_string += ', interpolation={0})'.format(interpolate_str)
+ return format_string
+
+
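The main difference from torchvision's `RandomResizedCrop` is the `'random'` interpolation option, which picks bilinear or bicubic per call. A small sketch:
```
from PIL import Image
from timm.data.transforms import RandomResizedCropAndInterpolation

rrc = RandomResizedCropAndInterpolation(224, scale=(0.08, 1.0), interpolation='random')
img = Image.new('RGB', (320, 240))    # stand-in image
out = rrc(img)                        # 224x224 PIL Image
print(rrc)
```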
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
new file mode 100644
index 0000000..d4815d9
--- /dev/null
+++ b/timm/data/transforms_factory.py
@@ -0,0 +1,236 @@
+""" Transforms Factory
+Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+
+import torch
+from torchvision import transforms
+
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
+from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
+from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
+from timm.data.random_erasing import RandomErasing
+
+
+def transforms_noaug_train(
+ img_size=224,
+ interpolation='bilinear',
+ use_prefetcher=False,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+):
+ if interpolation == 'random':
+ # random interpolation not supported with no-aug
+ interpolation = 'bilinear'
+ tfl = [
+ transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
+ transforms.CenterCrop(img_size)
+ ]
+ if use_prefetcher:
+ # prefetcher and collate will handle tensor conversion and norm
+ tfl += [ToNumpy()]
+ else:
+ tfl += [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std))
+ ]
+ return transforms.Compose(tfl)
+
+
+def transforms_imagenet_train(
+ img_size=224,
+ scale=None,
+ ratio=None,
+ hflip=0.5,
+ vflip=0.,
+ color_jitter=0.4,
+ auto_augment=None,
+ interpolation='random',
+ use_prefetcher=False,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ re_prob=0.,
+ re_mode='const',
+ re_count=1,
+ re_num_splits=0,
+ separate=False,
+):
+ """
+ If separate==True, the transforms are returned as a tuple of 3 separate transforms
+ for use in a mixing dataset that passes
+ * all data through the first (primary) transform, called the 'clean' data
+ * a portion of the data through the secondary transform
+ * normalizes and converts the branches above with the third, final transform
+ """
+ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
+ ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
+ primary_tfl = [
+ RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
+ if hflip > 0.:
+ primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
+ if vflip > 0.:
+ primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
+
+ secondary_tfl = []
+ if auto_augment:
+ assert isinstance(auto_augment, str)
+ if isinstance(img_size, (tuple, list)):
+ img_size_min = min(img_size)
+ else:
+ img_size_min = img_size
+ aa_params = dict(
+ translate_const=int(img_size_min * 0.45),
+ img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+ )
+ if interpolation and interpolation != 'random':
+ aa_params['interpolation'] = str_to_pil_interp(interpolation)
+ if auto_augment.startswith('rand'):
+ secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
+ elif auto_augment.startswith('augmix'):
+ aa_params['translate_pct'] = 0.3
+ secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
+ else:
+ secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
+ elif color_jitter is not None:
+ # color jitter is enabled when not using AA
+ if isinstance(color_jitter, (list, tuple)):
+ # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
+ # or 4 if also augmenting hue
+ assert len(color_jitter) in (3, 4)
+ else:
+ # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
+ color_jitter = (float(color_jitter),) * 3
+ secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+
+ final_tfl = []
+ if use_prefetcher:
+ # prefetcher and collate will handle tensor conversion and norm
+ final_tfl += [ToNumpy()]
+ else:
+ final_tfl += [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std))
+ ]
+ if re_prob > 0.:
+ final_tfl.append(
+ RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
+
+ if separate:
+ return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
+ else:
+ return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
+
+
+def transforms_imagenet_eval(
+ img_size=224,
+ crop_pct=None,
+ interpolation='bilinear',
+ use_prefetcher=False,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD):
+ crop_pct = crop_pct or DEFAULT_CROP_PCT
+
+ if isinstance(img_size, (tuple, list)):
+ assert len(img_size) == 2
+ if img_size[-1] == img_size[-2]:
+ # fall-back to older behaviour so Resize scales to shortest edge if target is square
+ scale_size = int(math.floor(img_size[0] / crop_pct))
+ else:
+ scale_size = tuple([int(x / crop_pct) for x in img_size])
+ else:
+ scale_size = int(math.floor(img_size / crop_pct))
+
+ tfl = [
+ transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
+ transforms.CenterCrop(img_size),
+ ]
+ if use_prefetcher:
+ # prefetcher and collate will handle tensor conversion and norm
+ tfl += [ToNumpy()]
+ else:
+ tfl += [
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std))
+ ]
+
+ return transforms.Compose(tfl)
+
+
+def create_transform(
+ input_size,
+ is_training=False,
+ use_prefetcher=False,
+ no_aug=False,
+ scale=None,
+ ratio=None,
+ hflip=0.5,
+ vflip=0.,
+ color_jitter=0.4,
+ auto_augment=None,
+ interpolation='bilinear',
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+ re_prob=0.,
+ re_mode='const',
+ re_count=1,
+ re_num_splits=0,
+ crop_pct=None,
+ tf_preprocessing=False,
+ separate=False):
+
+ if isinstance(input_size, (tuple, list)):
+ img_size = input_size[-2:]
+ else:
+ img_size = input_size
+
+ if tf_preprocessing and use_prefetcher:
+ assert not separate, "Separate transforms not supported for TF preprocessing"
+ from timm.data.tf_preprocessing import TfPreprocessTransform
+ transform = TfPreprocessTransform(
+ is_training=is_training, size=img_size, interpolation=interpolation)
+ else:
+ if is_training and no_aug:
+ assert not separate, "Cannot perform split augmentation with no_aug"
+ transform = transforms_noaug_train(
+ img_size,
+ interpolation=interpolation,
+ use_prefetcher=use_prefetcher,
+ mean=mean,
+ std=std)
+ elif is_training:
+ transform = transforms_imagenet_train(
+ img_size,
+ scale=scale,
+ ratio=ratio,
+ hflip=hflip,
+ vflip=vflip,
+ color_jitter=color_jitter,
+ auto_augment=auto_augment,
+ interpolation=interpolation,
+ use_prefetcher=use_prefetcher,
+ mean=mean,
+ std=std,
+ re_prob=re_prob,
+ re_mode=re_mode,
+ re_count=re_count,
+ re_num_splits=re_num_splits,
+ separate=separate)
+ else:
+ assert not separate, "Separate transforms not supported for validation preprocessing"
+ transform = transforms_imagenet_eval(
+ img_size,
+ interpolation=interpolation,
+ use_prefetcher=use_prefetcher,
+ mean=mean,
+ std=std,
+ crop_pct=crop_pct)
+
+ return transform
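
For reference, a minimal usage sketch of this factory (assuming the vendored package is importable as `timm` and `create_transform` is re-exported from `timm.data`, as in upstream timm):

```
from PIL import Image
from timm.data import create_transform

# Training pipeline: random resized crop + flips + RandAugment + ToTensor/Normalize + RandomErasing.
train_tf = create_transform(
    input_size=224, is_training=True, auto_augment='rand-m9-mstd0.5', re_prob=0.25)
# Eval pipeline: resize short side to 224 / 0.875, center crop to 224, ToTensor/Normalize.
eval_tf = create_transform(input_size=224, is_training=False, crop_pct=0.875)

img = Image.new('RGB', (320, 240))   # stand-in for a real image
print(train_tf(img).shape)           # torch.Size([3, 224, 224])
print(eval_tf(img).shape)            # torch.Size([3, 224, 224])
```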
diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py
new file mode 100644
index 0000000..ea7f15f
--- /dev/null
+++ b/timm/loss/__init__.py
@@ -0,0 +1,4 @@
+from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
+from .binary_cross_entropy import BinaryCrossEntropy
+from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from .jsd import JsdCrossEntropy
diff --git a/timm/loss/asymmetric_loss.py b/timm/loss/asymmetric_loss.py
new file mode 100644
index 0000000..a8b10f9
--- /dev/null
+++ b/timm/loss/asymmetric_loss.py
@@ -0,0 +1,97 @@
+import torch
+import torch.nn as nn
+
+
+class AsymmetricLossMultiLabel(nn.Module):
+ def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
+ super(AsymmetricLossMultiLabel, self).__init__()
+
+ self.gamma_neg = gamma_neg
+ self.gamma_pos = gamma_pos
+ self.clip = clip
+ self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
+ self.eps = eps
+
+ def forward(self, x, y):
+ """"
+ Parameters
+ ----------
+ x: input logits
+ y: targets (multi-label binarized vector)
+ """
+
+ # Calculating Probabilities
+ x_sigmoid = torch.sigmoid(x)
+ xs_pos = x_sigmoid
+ xs_neg = 1 - x_sigmoid
+
+ # Asymmetric Clipping
+ if self.clip is not None and self.clip > 0:
+ xs_neg = (xs_neg + self.clip).clamp(max=1)
+
+ # Basic CE calculation
+ los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
+ los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
+ loss = los_pos + los_neg
+
+ # Asymmetric Focusing
+ if self.gamma_neg > 0 or self.gamma_pos > 0:
+ if self.disable_torch_grad_focal_loss:
+ torch._C.set_grad_enabled(False)
+ pt0 = xs_pos * y
+ pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
+ pt = pt0 + pt1
+ one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
+ one_sided_w = torch.pow(1 - pt, one_sided_gamma)
+ if self.disable_torch_grad_focal_loss:
+ torch._C.set_grad_enabled(True)
+ loss *= one_sided_w
+
+ return -loss.sum()
+
+
+class AsymmetricLossSingleLabel(nn.Module):
+ def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
+ super(AsymmetricLossSingleLabel, self).__init__()
+
+ self.eps = eps
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
+ self.targets_classes = [] # reuse buffer to avoid repeated GPU memory allocation
+ self.gamma_pos = gamma_pos
+ self.gamma_neg = gamma_neg
+ self.reduction = reduction
+
+ def forward(self, inputs, target, reduction=None):
+ """"
+ Parameters
+ ----------
+ x: input logits
+ y: targets (1-hot vector)
+ """
+
+ num_classes = inputs.size()[-1]
+ log_preds = self.logsoftmax(inputs)
+ self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
+
+ # ASL weights
+ targets = self.targets_classes
+ anti_targets = 1 - targets
+ xs_pos = torch.exp(log_preds)
+ xs_neg = 1 - xs_pos
+ xs_pos = xs_pos * targets
+ xs_neg = xs_neg * anti_targets
+ asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
+ self.gamma_pos * targets + self.gamma_neg * anti_targets)
+ log_preds = log_preds * asymmetric_w
+
+ if self.eps > 0: # label smoothing
+ self.targets_classes.mul_(1 - self.eps).add_(self.eps / num_classes)
+
+ # loss calculation
+ loss = - self.targets_classes.mul(log_preds)
+
+ loss = loss.sum(dim=-1)
+ if self.reduction == 'mean':
+ loss = loss.mean()
+
+ return loss
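
A small usage sketch for both asymmetric losses (assuming the `timm.loss` exports defined in this commit); note that the multi-label variant reduces with a sum over the batch while the single-label variant defaults to a mean:

```
import torch
from timm.loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel

logits = torch.randn(4, 10)                                  # batch of 4, 10 classes

# Multi-label: {0,1} target vectors; gamma_neg > gamma_pos down-weights easy negatives more.
multi_targets = torch.randint(0, 2, (4, 10)).float()
ml_loss = AsymmetricLossMultiLabel(gamma_neg=4, gamma_pos=1, clip=0.05)(logits, multi_targets)

# Single-label: class indices; one-hot encoding and label smoothing happen inside forward().
sl_targets = torch.randint(0, 10, (4,))
sl_loss = AsymmetricLossSingleLabel(gamma_pos=1, gamma_neg=4, eps=0.1)(logits, sl_targets)
print(ml_loss.item(), sl_loss.item())                        # sum vs. mean reduction
```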
diff --git a/timm/loss/binary_cross_entropy.py b/timm/loss/binary_cross_entropy.py
new file mode 100644
index 0000000..ed76c1e
--- /dev/null
+++ b/timm/loss/binary_cross_entropy.py
@@ -0,0 +1,47 @@
+""" Binary Cross Entropy w/ a few extras
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BinaryCrossEntropy(nn.Module):
+ """ BCE with optional one-hot from dense targets, label smoothing, thresholding
+ NOTE for experiments comparing CE to BCE w/ label smoothing; may be removed later
+ """
+ def __init__(
+ self, smoothing=0.1, target_threshold: Optional[float] = None, weight: Optional[torch.Tensor] = None,
+ reduction: str = 'mean', pos_weight: Optional[torch.Tensor] = None):
+ super(BinaryCrossEntropy, self).__init__()
+ assert 0. <= smoothing < 1.0
+ self.smoothing = smoothing
+ self.target_threshold = target_threshold
+ self.reduction = reduction
+ self.register_buffer('weight', weight)
+ self.register_buffer('pos_weight', pos_weight)
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ assert x.shape[0] == target.shape[0]
+ if target.shape != x.shape:
+ # NOTE smoothing is applied here when converting index targets to one-hot; targets that already
+ # match x.shape (dense/soft labels) are assumed to have been smoothed/softened upstream
+ num_classes = x.shape[-1]
+ # FIXME should off/on be different for smoothing w/ BCE? Other impl out there differ
+ off_value = self.smoothing / num_classes
+ on_value = 1. - self.smoothing + off_value
+ target = target.long().view(-1, 1)
+ target = torch.full(
+ (target.size()[0], num_classes),
+ off_value,
+ device=x.device, dtype=x.dtype).scatter_(1, target, on_value)
+ if self.target_threshold is not None:
+ # Binarize targets to 0/1 when a threshold is set
+ target = target.gt(self.target_threshold).to(dtype=target.dtype)
+ return F.binary_cross_entropy_with_logits(
+ x, target,
+ self.weight,
+ pos_weight=self.pos_weight,
+ reduction=self.reduction)
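
A brief sketch of the two target paths this loss handles, index targets versus already-dense targets (assuming the `timm.loss` export above):

```
import torch
from timm.loss import BinaryCrossEntropy

logits = torch.randn(8, 100)
bce = BinaryCrossEntropy(smoothing=0.1)

# Index targets: converted to a smoothed one-hot inside forward(), then BCE-with-logits.
hard_targets = torch.randint(0, 100, (8,))
loss_hard = bce(logits, hard_targets)

# Already-dense targets (e.g. from mixup) are used as-is; target_threshold can binarize them.
soft_targets = torch.rand(8, 100)
loss_soft = BinaryCrossEntropy(smoothing=0.1, target_threshold=0.2)(logits, soft_targets)
print(loss_hard.item(), loss_soft.item())
```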
diff --git a/timm/loss/cross_entropy.py b/timm/loss/cross_entropy.py
new file mode 100644
index 0000000..8519810
--- /dev/null
+++ b/timm/loss/cross_entropy.py
@@ -0,0 +1,36 @@
+""" Cross Entropy w/ smoothing or soft targets
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+ """ NLL loss with label smoothing.
+ """
+ def __init__(self, smoothing=0.1):
+ super(LabelSmoothingCrossEntropy, self).__init__()
+ assert smoothing < 1.0
+ self.smoothing = smoothing
+ self.confidence = 1. - smoothing
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ logprobs = F.log_softmax(x, dim=-1)
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+ nll_loss = nll_loss.squeeze(1)
+ smooth_loss = -logprobs.mean(dim=-1)
+ loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+ return loss.mean()
+
+
+class SoftTargetCrossEntropy(nn.Module):
+
+ def __init__(self):
+ super(SoftTargetCrossEntropy, self).__init__()
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
+ return loss.mean()
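
`LabelSmoothingCrossEntropy` computes `(1 - smoothing) * NLL(target) + smoothing * mean(-log p)` per sample, while `SoftTargetCrossEntropy` expects a full target distribution per sample (e.g. produced by mixup/cutmix). A minimal sketch, assuming the `timm.loss` exports above:

```
import torch
import torch.nn.functional as F
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

logits = torch.randn(4, 10)
targets = torch.randint(0, 10, (4,))

# Smoothed CE on hard class indices.
ls_loss = LabelSmoothingCrossEntropy(smoothing=0.1)(logits, targets)

# Soft-target CE on a full per-sample distribution.
soft = F.one_hot(targets, 10).float() * 0.9 + 0.01           # rows sum to 1.0
st_loss = SoftTargetCrossEntropy()(logits, soft)
print(ls_loss.item(), st_loss.item())
```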
diff --git a/timm/loss/jsd.py b/timm/loss/jsd.py
new file mode 100644
index 0000000..dd64e15
--- /dev/null
+++ b/timm/loss/jsd.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cross_entropy import LabelSmoothingCrossEntropy
+
+
+class JsdCrossEntropy(nn.Module):
+ """ Jensen-Shannon Divergence + Cross-Entropy Loss
+
+ Based on impl here: https://github.com/google-research/augmix/blob/master/imagenet.py
+ From paper: 'AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty' -
+ https://arxiv.org/abs/1912.02781
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ def __init__(self, num_splits=3, alpha=12, smoothing=0.1):
+ super().__init__()
+ self.num_splits = num_splits
+ self.alpha = alpha
+ if smoothing is not None and smoothing > 0:
+ self.cross_entropy_loss = LabelSmoothingCrossEntropy(smoothing)
+ else:
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
+
+ def __call__(self, output, target):
+ split_size = output.shape[0] // self.num_splits
+ assert split_size * self.num_splits == output.shape[0]
+ logits_split = torch.split(output, split_size)
+
+ # Cross-entropy is only computed on clean images
+ loss = self.cross_entropy_loss(logits_split[0], target[:split_size])
+ probs = [F.softmax(logits, dim=1) for logits in logits_split]
+
+ # Clamp mixture distribution to avoid exploding KL divergence
+ logp_mixture = torch.clamp(torch.stack(probs).mean(axis=0), 1e-7, 1).log()
+ loss += self.alpha * sum([F.kl_div(
+ logp_mixture, p_split, reduction='batchmean') for p_split in probs]) / len(probs)
+ return loss
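
The expected batch layout is `[clean | augmented_1 | ... | augmented_{num_splits-1}]` concatenated along dim 0; cross-entropy uses only the clean split, and the JSD term ties the predictive distributions of all splits together. A minimal sketch with synthetic logits (assuming the `timm.loss` export above):

```
import torch
from timm.loss import JsdCrossEntropy

# Batch layout: [clean | augmix_1 | augmix_2] concatenated along dim 0 (num_splits=3).
batch, num_classes = 6, 10
logits = torch.randn(batch * 3, num_classes)
targets = torch.randint(0, num_classes, (batch,))

criterion = JsdCrossEntropy(num_splits=3, alpha=12, smoothing=0.1)
loss = criterion(logits, targets)    # CE on the clean split + alpha * mean KL to the mixture
print(loss.item())
```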
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
new file mode 100644
index 0000000..0982b6e
--- /dev/null
+++ b/timm/models/__init__.py
@@ -0,0 +1,58 @@
+from .beit import *
+from .byoanet import *
+from .byobnet import *
+from .cait import *
+from .coat import *
+from .convit import *
+from .convmixer import *
+from .crossvit import *
+from .cspnet import *
+from .densenet import *
+from .dla import *
+from .dpn import *
+from .efficientnet import *
+from .ghostnet import *
+from .gluon_resnet import *
+from .gluon_xception import *
+from .hardcorenas import *
+from .hrnet import *
+from .inception_resnet_v2 import *
+from .inception_v3 import *
+from .inception_v4 import *
+from .levit import *
+from .mlp_mixer import *
+from .mobilenetv3 import *
+from .nasnet import *
+from .nest import *
+from .nfnet import *
+from .pit import *
+from .pnasnet import *
+from .regnet import *
+from .res2net import *
+from .resnest import *
+from .resnet import *
+from .resnetv2 import *
+from .rexnet import *
+from .selecsls import *
+from .senet import *
+from .sknet import *
+from .swin_transformer import *
+from .tnt import *
+from .tresnet import *
+from .twins import *
+from .vgg import *
+from .visformer import *
+from .vision_transformer import *
+from .vision_transformer_hybrid import *
+from .vovnet import *
+from .xception import *
+from .xception_aligned import *
+from .xcit import *
+
+from .factory import create_model, split_model_name, safe_model_name
+from .helpers import load_checkpoint, resume_checkpoint, model_parameters
+from .layers import TestTimePoolHead, apply_test_time_pool
+from .layers import convert_splitbn_model
+from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
+from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
+ has_model_default_key, is_model_default_key, get_model_default_value, is_model_pretrained
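
Each model file registers its entrypoints with `@register_model`, and the factory resolves them by name. A minimal sketch, assuming the vendored package is importable as `timm`:

```
import torch
from timm.models import create_model, list_models

# List registered entrypoints by wildcard, then build one through the factory.
print(list_models('beit_*'))                       # e.g. ['beit_base_patch16_224', ...]

model = create_model('resnet50', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))
print(out.shape)                                   # torch.Size([1, 10])
```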
diff --git a/timm/models/__pycache__/__init__.cpython-36.pyc b/timm/models/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..db8276a
Binary files /dev/null and b/timm/models/__pycache__/__init__.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/beit.cpython-36.pyc b/timm/models/__pycache__/beit.cpython-36.pyc
new file mode 100644
index 0000000..325294a
Binary files /dev/null and b/timm/models/__pycache__/beit.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/byoanet.cpython-36.pyc b/timm/models/__pycache__/byoanet.cpython-36.pyc
new file mode 100644
index 0000000..58e9926
Binary files /dev/null and b/timm/models/__pycache__/byoanet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/byobnet.cpython-36.pyc b/timm/models/__pycache__/byobnet.cpython-36.pyc
new file mode 100644
index 0000000..06ede29
Binary files /dev/null and b/timm/models/__pycache__/byobnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/cait.cpython-36.pyc b/timm/models/__pycache__/cait.cpython-36.pyc
new file mode 100644
index 0000000..9756431
Binary files /dev/null and b/timm/models/__pycache__/cait.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/coat.cpython-36.pyc b/timm/models/__pycache__/coat.cpython-36.pyc
new file mode 100644
index 0000000..a33e31c
Binary files /dev/null and b/timm/models/__pycache__/coat.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/convit.cpython-36.pyc b/timm/models/__pycache__/convit.cpython-36.pyc
new file mode 100644
index 0000000..6192cad
Binary files /dev/null and b/timm/models/__pycache__/convit.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/convmixer.cpython-36.pyc b/timm/models/__pycache__/convmixer.cpython-36.pyc
new file mode 100644
index 0000000..ac3549b
Binary files /dev/null and b/timm/models/__pycache__/convmixer.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/crossvit.cpython-36.pyc b/timm/models/__pycache__/crossvit.cpython-36.pyc
new file mode 100644
index 0000000..a96f59c
Binary files /dev/null and b/timm/models/__pycache__/crossvit.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/cspnet.cpython-36.pyc b/timm/models/__pycache__/cspnet.cpython-36.pyc
new file mode 100644
index 0000000..ea881c2
Binary files /dev/null and b/timm/models/__pycache__/cspnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/densenet.cpython-36.pyc b/timm/models/__pycache__/densenet.cpython-36.pyc
new file mode 100644
index 0000000..716a407
Binary files /dev/null and b/timm/models/__pycache__/densenet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/dla.cpython-36.pyc b/timm/models/__pycache__/dla.cpython-36.pyc
new file mode 100644
index 0000000..31dd1ab
Binary files /dev/null and b/timm/models/__pycache__/dla.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/dpn.cpython-36.pyc b/timm/models/__pycache__/dpn.cpython-36.pyc
new file mode 100644
index 0000000..044bea2
Binary files /dev/null and b/timm/models/__pycache__/dpn.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/efficientnet.cpython-36.pyc b/timm/models/__pycache__/efficientnet.cpython-36.pyc
new file mode 100644
index 0000000..eae7974
Binary files /dev/null and b/timm/models/__pycache__/efficientnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/efficientnet_blocks.cpython-36.pyc b/timm/models/__pycache__/efficientnet_blocks.cpython-36.pyc
new file mode 100644
index 0000000..b0298ca
Binary files /dev/null and b/timm/models/__pycache__/efficientnet_blocks.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/efficientnet_builder.cpython-36.pyc b/timm/models/__pycache__/efficientnet_builder.cpython-36.pyc
new file mode 100644
index 0000000..ff28783
Binary files /dev/null and b/timm/models/__pycache__/efficientnet_builder.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/factory.cpython-36.pyc b/timm/models/__pycache__/factory.cpython-36.pyc
new file mode 100644
index 0000000..d099862
Binary files /dev/null and b/timm/models/__pycache__/factory.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/features.cpython-36.pyc b/timm/models/__pycache__/features.cpython-36.pyc
new file mode 100644
index 0000000..6462ee0
Binary files /dev/null and b/timm/models/__pycache__/features.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/fx_features.cpython-36.pyc b/timm/models/__pycache__/fx_features.cpython-36.pyc
new file mode 100644
index 0000000..ce3cbc8
Binary files /dev/null and b/timm/models/__pycache__/fx_features.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/ghostnet.cpython-36.pyc b/timm/models/__pycache__/ghostnet.cpython-36.pyc
new file mode 100644
index 0000000..0abefd0
Binary files /dev/null and b/timm/models/__pycache__/ghostnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/gluon_resnet.cpython-36.pyc b/timm/models/__pycache__/gluon_resnet.cpython-36.pyc
new file mode 100644
index 0000000..7ffa096
Binary files /dev/null and b/timm/models/__pycache__/gluon_resnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/gluon_xception.cpython-36.pyc b/timm/models/__pycache__/gluon_xception.cpython-36.pyc
new file mode 100644
index 0000000..7b46ebb
Binary files /dev/null and b/timm/models/__pycache__/gluon_xception.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/hardcorenas.cpython-36.pyc b/timm/models/__pycache__/hardcorenas.cpython-36.pyc
new file mode 100644
index 0000000..6a7fca9
Binary files /dev/null and b/timm/models/__pycache__/hardcorenas.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/helpers.cpython-36.pyc b/timm/models/__pycache__/helpers.cpython-36.pyc
new file mode 100644
index 0000000..67e30c6
Binary files /dev/null and b/timm/models/__pycache__/helpers.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/hrnet.cpython-36.pyc b/timm/models/__pycache__/hrnet.cpython-36.pyc
new file mode 100644
index 0000000..9075e06
Binary files /dev/null and b/timm/models/__pycache__/hrnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/hub.cpython-36.pyc b/timm/models/__pycache__/hub.cpython-36.pyc
new file mode 100644
index 0000000..ad969b1
Binary files /dev/null and b/timm/models/__pycache__/hub.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/inception_resnet_v2.cpython-36.pyc b/timm/models/__pycache__/inception_resnet_v2.cpython-36.pyc
new file mode 100644
index 0000000..ae88182
Binary files /dev/null and b/timm/models/__pycache__/inception_resnet_v2.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/inception_v3.cpython-36.pyc b/timm/models/__pycache__/inception_v3.cpython-36.pyc
new file mode 100644
index 0000000..9ab31c9
Binary files /dev/null and b/timm/models/__pycache__/inception_v3.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/inception_v4.cpython-36.pyc b/timm/models/__pycache__/inception_v4.cpython-36.pyc
new file mode 100644
index 0000000..eb608f4
Binary files /dev/null and b/timm/models/__pycache__/inception_v4.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/levit.cpython-36.pyc b/timm/models/__pycache__/levit.cpython-36.pyc
new file mode 100644
index 0000000..c8ac5e7
Binary files /dev/null and b/timm/models/__pycache__/levit.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/mlp_mixer.cpython-36.pyc b/timm/models/__pycache__/mlp_mixer.cpython-36.pyc
new file mode 100644
index 0000000..03d6894
Binary files /dev/null and b/timm/models/__pycache__/mlp_mixer.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/mobilenetv3.cpython-36.pyc b/timm/models/__pycache__/mobilenetv3.cpython-36.pyc
new file mode 100644
index 0000000..8aefbd0
Binary files /dev/null and b/timm/models/__pycache__/mobilenetv3.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/nasnet.cpython-36.pyc b/timm/models/__pycache__/nasnet.cpython-36.pyc
new file mode 100644
index 0000000..124d8b1
Binary files /dev/null and b/timm/models/__pycache__/nasnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/nest.cpython-36.pyc b/timm/models/__pycache__/nest.cpython-36.pyc
new file mode 100644
index 0000000..3412d30
Binary files /dev/null and b/timm/models/__pycache__/nest.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/nfnet.cpython-36.pyc b/timm/models/__pycache__/nfnet.cpython-36.pyc
new file mode 100644
index 0000000..c0112a5
Binary files /dev/null and b/timm/models/__pycache__/nfnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/pit.cpython-36.pyc b/timm/models/__pycache__/pit.cpython-36.pyc
new file mode 100644
index 0000000..20192c3
Binary files /dev/null and b/timm/models/__pycache__/pit.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/pnasnet.cpython-36.pyc b/timm/models/__pycache__/pnasnet.cpython-36.pyc
new file mode 100644
index 0000000..1929f0f
Binary files /dev/null and b/timm/models/__pycache__/pnasnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/registry.cpython-36.pyc b/timm/models/__pycache__/registry.cpython-36.pyc
new file mode 100644
index 0000000..da13b3b
Binary files /dev/null and b/timm/models/__pycache__/registry.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/regnet.cpython-36.pyc b/timm/models/__pycache__/regnet.cpython-36.pyc
new file mode 100644
index 0000000..ae87c17
Binary files /dev/null and b/timm/models/__pycache__/regnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/res2net.cpython-36.pyc b/timm/models/__pycache__/res2net.cpython-36.pyc
new file mode 100644
index 0000000..2b9fbc2
Binary files /dev/null and b/timm/models/__pycache__/res2net.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/resnest.cpython-36.pyc b/timm/models/__pycache__/resnest.cpython-36.pyc
new file mode 100644
index 0000000..f2e0e31
Binary files /dev/null and b/timm/models/__pycache__/resnest.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/resnet.cpython-36.pyc b/timm/models/__pycache__/resnet.cpython-36.pyc
new file mode 100644
index 0000000..9609c1b
Binary files /dev/null and b/timm/models/__pycache__/resnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/resnetv2.cpython-36.pyc b/timm/models/__pycache__/resnetv2.cpython-36.pyc
new file mode 100644
index 0000000..1243362
Binary files /dev/null and b/timm/models/__pycache__/resnetv2.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/rexnet.cpython-36.pyc b/timm/models/__pycache__/rexnet.cpython-36.pyc
new file mode 100644
index 0000000..2212e55
Binary files /dev/null and b/timm/models/__pycache__/rexnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/selecsls.cpython-36.pyc b/timm/models/__pycache__/selecsls.cpython-36.pyc
new file mode 100644
index 0000000..09cd380
Binary files /dev/null and b/timm/models/__pycache__/selecsls.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/senet.cpython-36.pyc b/timm/models/__pycache__/senet.cpython-36.pyc
new file mode 100644
index 0000000..e604798
Binary files /dev/null and b/timm/models/__pycache__/senet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/sknet.cpython-36.pyc b/timm/models/__pycache__/sknet.cpython-36.pyc
new file mode 100644
index 0000000..dd38016
Binary files /dev/null and b/timm/models/__pycache__/sknet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/swin_transformer.cpython-36.pyc b/timm/models/__pycache__/swin_transformer.cpython-36.pyc
new file mode 100644
index 0000000..47de810
Binary files /dev/null and b/timm/models/__pycache__/swin_transformer.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/tnt.cpython-36.pyc b/timm/models/__pycache__/tnt.cpython-36.pyc
new file mode 100644
index 0000000..d89fb98
Binary files /dev/null and b/timm/models/__pycache__/tnt.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/tresnet.cpython-36.pyc b/timm/models/__pycache__/tresnet.cpython-36.pyc
new file mode 100644
index 0000000..5716d57
Binary files /dev/null and b/timm/models/__pycache__/tresnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/twins.cpython-36.pyc b/timm/models/__pycache__/twins.cpython-36.pyc
new file mode 100644
index 0000000..91ed400
Binary files /dev/null and b/timm/models/__pycache__/twins.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/vgg.cpython-36.pyc b/timm/models/__pycache__/vgg.cpython-36.pyc
new file mode 100644
index 0000000..ab11f2c
Binary files /dev/null and b/timm/models/__pycache__/vgg.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/visformer.cpython-36.pyc b/timm/models/__pycache__/visformer.cpython-36.pyc
new file mode 100644
index 0000000..73ee0cd
Binary files /dev/null and b/timm/models/__pycache__/visformer.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/vision_transformer.cpython-36.pyc b/timm/models/__pycache__/vision_transformer.cpython-36.pyc
new file mode 100644
index 0000000..f28afe9
Binary files /dev/null and b/timm/models/__pycache__/vision_transformer.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/vision_transformer_hybrid.cpython-36.pyc b/timm/models/__pycache__/vision_transformer_hybrid.cpython-36.pyc
new file mode 100644
index 0000000..2734bd8
Binary files /dev/null and b/timm/models/__pycache__/vision_transformer_hybrid.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/vovnet.cpython-36.pyc b/timm/models/__pycache__/vovnet.cpython-36.pyc
new file mode 100644
index 0000000..35456cd
Binary files /dev/null and b/timm/models/__pycache__/vovnet.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/xception.cpython-36.pyc b/timm/models/__pycache__/xception.cpython-36.pyc
new file mode 100644
index 0000000..51fc3b5
Binary files /dev/null and b/timm/models/__pycache__/xception.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/xception_aligned.cpython-36.pyc b/timm/models/__pycache__/xception_aligned.cpython-36.pyc
new file mode 100644
index 0000000..ad4c8b2
Binary files /dev/null and b/timm/models/__pycache__/xception_aligned.cpython-36.pyc differ
diff --git a/timm/models/__pycache__/xcit.cpython-36.pyc b/timm/models/__pycache__/xcit.cpython-36.pyc
new file mode 100644
index 0000000..76a0416
Binary files /dev/null and b/timm/models/__pycache__/xcit.cpython-36.pyc differ
diff --git a/timm/models/beit.py b/timm/models/beit.py
new file mode 100644
index 0000000..f644b65
--- /dev/null
+++ b/timm/models/beit.py
@@ -0,0 +1,416 @@
+""" BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+
+Model from official source: https://github.com/microsoft/unilm/tree/master/beit
+
+At this point only the 1k fine-tuned classification weights and model configs have been added,
+see original source above for pre-training models and procedure.
+
+Modifications by / Copyright 2021 Ross Wightman, original copyrights below
+"""
+# --------------------------------------------------------
+# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+# Github source: https://github.com/microsoft/unilm/tree/master/beit
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# By Hangbo Bao
+# Based on timm and DeiT code bases
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+import math
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .helpers import build_model_with_cfg
+from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
+from .registry import register_model
+from .vision_transformer import checkpoint_filter_fn
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'beit_base_patch16_224': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
+ 'beit_base_patch16_384': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0,
+ ),
+ 'beit_base_patch16_224_in22k': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth',
+ num_classes=21841,
+ ),
+ 'beit_large_patch16_224': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
+ 'beit_large_patch16_384': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0,
+ ),
+ 'beit_large_patch16_512': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
+ input_size=(3, 512, 512), crop_pct=1.0,
+ ),
+ 'beit_large_patch16_224_in22k': _cfg(
+ url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth',
+ num_classes=21841,
+ ),
+}
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, dim, num_heads=8, qkv_bias=False, attn_drop=0.,
+ proj_drop=0., window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = None
+ self.k_bias = None
+ self.v_bias = None
+
+ if window_size:
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+ else:
+ self.window_size = None
+ self.relative_position_bias_table = None
+ self.relative_position_index = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
+ B, N, C = x.shape
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ if self.relative_position_bias_table is not None:
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if rel_pos_bias is not None:
+ attn = attn + rel_pos_bias
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+ window_size=None, attn_head_dim=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ window_size=window_size, attn_head_dim=attn_head_dim)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if init_values:
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ else:
+ self.gamma_1, self.gamma_2 = None, None
+
+ def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
+ if self.gamma_1 is None:
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ else:
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class RelativePositionBias(nn.Module):
+
+ def __init__(self, window_size, num_heads):
+ super().__init__()
+ self.window_size = window_size
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ # cls to token & token 2 cls & cls to cls
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = \
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
+
+ def forward(self):
+ relative_position_bias = \
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1] + 1,
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+
+
+class Beit(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None,
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
+ use_mean_pooling=True, init_scale=0.001):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ if use_abs_pos_emb:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ else:
+ self.pos_embed = None
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if use_shared_rel_pos_bias:
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.grid_size, num_heads=num_heads)
+ else:
+ self.rel_pos_bias = None
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.use_rel_pos_bias = use_rel_pos_bias
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None)
+ for i in range(depth)])
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+ if self.pos_embed is not None:
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ # trunc_normal_(self.mask_token, std=.02)
+ self.fix_init_weight()
+ if isinstance(self.head, nn.Linear):
+ trunc_normal_(self.head.weight, std=.02)
+ self.head.weight.data.mul_(init_scale)
+ self.head.bias.data.mul_(init_scale)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def get_num_layers(self):
+ return len(self.blocks)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ batch_size, seq_len, _ = x.size()
+
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+ if self.pos_embed is not None:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
+ for blk in self.blocks:
+ x = blk(x, rel_pos_bias=rel_pos_bias)
+
+ x = self.norm(x)
+ if self.fc_norm is not None:
+ t = x[:, 1:, :]
+ return self.fc_norm(t.mean(1))
+ else:
+ return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_beit(variant, pretrained=False, default_cfg=None, **kwargs):
+ default_cfg = default_cfg or default_cfgs[variant]
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Beit models.')
+
+ model = build_model_with_cfg(
+ Beit, variant, pretrained,
+ default_cfg=default_cfg,
+ # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def beit_base_patch16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
+ model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def beit_base_patch16_384(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
+ model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def beit_base_patch16_224_in22k(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1, **kwargs)
+ model = _create_beit('beit_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def beit_large_patch16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+ model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def beit_large_patch16_384(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+ model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def beit_large_patch16_512(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+ model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+ model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+ return model
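
A quick sanity-check sketch for the entrypoints above (no pretrained weights downloaded); note that these configs disable absolute position embeddings and rely on per-block relative position bias:

```
import torch
from timm.models import create_model

model = create_model('beit_base_patch16_224', pretrained=False)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)           # torch.Size([2, 1000])
print(model.get_num_layers())   # 12 transformer blocks
```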
diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py
new file mode 100644
index 0000000..f44040b
--- /dev/null
+++ b/timm/models/byoanet.py
@@ -0,0 +1,443 @@
+""" Bring-Your-Own-Attention Network
+
+A flexible network w/ dataclass based config for stacking NN blocks including
+self-attention (or similar) layers.
+
+Currently used to implement experimental variants of:
+ * Bottleneck Transformers
+ * Lambda ResNets
+ * HaloNets
+
+Consider all of the models definitions here as experimental WIP and likely to change.
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
+from .helpers import build_model_with_cfg
+from .registry import register_model
+
+__all__ = []
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.95, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+ 'fixed_input_size': False, 'min_input_size': (3, 224, 224),
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # Bottleneck Transformer (BotNet) weights
+ 'botnet26t_256': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+ 'sebotnet33ts_256': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+ 'botnet50ts_256': _cfg(
+ url='',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+ 'eca_botnext26ts_256': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+
+ 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+ 'halonet26t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+ 'sehalonet33ts': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+ 'halonet50ts': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+ 'eca_halonext26ts': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
+
+ 'lambda_resnet26t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
+ min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+ 'lambda_resnet50ts': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
+ min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
+ 'lambda_resnet26rpt_256': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+
+ 'haloregnetz_b': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
+
+ 'lamhalobotnet50ts_256': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+ 'halo2botnet50ts_256': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
+}
+
+
+model_cfgs = dict(
+
+ botnet26t=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ fixed_input_size=True,
+ self_attn_layer='bottleneck',
+ self_attn_kwargs=dict()
+ ),
+ sebotnet33ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
+ ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ act_layer='silu',
+ num_features=1280,
+ attn_layer='se',
+ self_attn_layer='bottleneck',
+ self_attn_kwargs=dict()
+ ),
+ botnet50ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ fixed_input_size=True,
+ self_attn_layer='bottleneck',
+ self_attn_kwargs=dict()
+ ),
+ eca_botnext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ fixed_input_size=True,
+ act_layer='silu',
+ attn_layer='eca',
+ self_attn_layer='bottleneck',
+ self_attn_kwargs=dict(dim_head=16)
+ ),
+
+ halonet_h1=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
+ ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
+ ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
+ ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
+ ),
+ stem_chs=64,
+ stem_type='7x7',
+ stem_pool='maxpool',
+
+ self_attn_layer='halo',
+ self_attn_kwargs=dict(block_size=8, halo_size=3),
+ ),
+ halonet26t=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ self_attn_layer='halo',
+ self_attn_kwargs=dict(block_size=8, halo_size=2)
+ ),
+ sehalonet33ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
+ ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ act_layer='silu',
+ num_features=1280,
+ attn_layer='se',
+ self_attn_layer='halo',
+ self_attn_kwargs=dict(block_size=8, halo_size=3)
+ ),
+ halonet50ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
+ self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
+ interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ self_attn_layer='halo',
+ self_attn_kwargs=dict(block_size=8, halo_size=3)
+ ),
+ eca_halonext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ attn_layer='eca',
+ self_attn_layer='halo',
+ self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
+ ),
+
+ lambda_resnet26t=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ self_attn_layer='lambda',
+ self_attn_kwargs=dict(r=9)
+ ),
+ lambda_resnet50ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ self_attn_layer='lambda',
+ self_attn_kwargs=dict(r=9)
+ ),
+ lambda_resnet26rpt_256=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ self_attn_layer='lambda',
+ self_attn_kwargs=dict(r=None)
+ ),
+
+ # experimental
+ haloregnetz_b=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
+ ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
+ interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
+ ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
+ ),
+ stem_chs=32,
+ stem_pool='',
+ downsample='',
+ num_features=1536,
+ act_layer='silu',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ self_attn_layer='halo',
+ self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
+ ),
+
+ # experimental
+ lamhalobotnet50ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
+ self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
+ self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
+ self_attn_layer='bottleneck', self_attn_kwargs=dict()),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ act_layer='silu',
+ ),
+ halo2botnet50ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
+ self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
+ self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
+ interleave_blocks(
+ types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
+ self_attn_layer='bottleneck', self_attn_kwargs=dict()),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ act_layer='silu',
+ ),
+)
+
+
+def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ByobNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
+ feature_cfg=dict(flatten_sequential=True),
+ **kwargs)
+
+
+@register_model
+def botnet26t_256(pretrained=False, **kwargs):
+ """ Bottleneck Transformer w/ ResNet26-T backbone.
+ """
+ kwargs.setdefault('img_size', 256)
+ return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def sebotnet33ts_256(pretrained=False, **kwargs):
+ """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
+ """
+ return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def botnet50ts_256(pretrained=False, **kwargs):
+ """ Bottleneck Transformer w/ ResNet50-T backbone, silu act.
+ """
+ kwargs.setdefault('img_size', 256)
+ return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_botnext26ts_256(pretrained=False, **kwargs):
+ """ Bottleneck Transformer w/ ResNet26-T backbone, silu act.
+ """
+ kwargs.setdefault('img_size', 256)
+ return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def halonet_h1(pretrained=False, **kwargs):
+ """ HaloNet-H1. Halo attention in all stages as per the paper.
+ NOTE: This runs very slowly!
+ """
+ return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def halonet26t(pretrained=False, **kwargs):
+ """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
+ """
+ return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def sehalonet33ts(pretrained=False, **kwargs):
+ """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4.
+ """
+ return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def halonet50ts(pretrained=False, **kwargs):
+ """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
+ """
+ return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_halonext26ts(pretrained=False, **kwargs):
+ """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
+ """
+ return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lambda_resnet26t(pretrained=False, **kwargs):
+ """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
+ """
+ return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lambda_resnet50ts(pretrained=False, **kwargs):
+ """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
+ """
+ return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lambda_resnet26rpt_256(pretrained=False, **kwargs):
+ """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
+ """
+ kwargs.setdefault('img_size', 256)
+ return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def haloregnetz_b(pretrained=False, **kwargs):
+ """ Halo + RegNetZ
+ """
+ return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def lamhalobotnet50ts_256(pretrained=False, **kwargs):
+ """ Combo Attention (Lambda + Halo + Bot) Network
+ """
+ return _create_byoanet('lamhalobotnet50ts_256', 'lamhalobotnet50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def halo2botnet50ts_256(pretrained=False, **kwargs):
+ """ Combo Attention (Halo + Halo + Bot) Network
+ """
+ return _create_byoanet('halo2botnet50ts_256', 'halo2botnet50ts', pretrained=pretrained, **kwargs)
diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py
new file mode 100644
index 0000000..fa57943
--- /dev/null
+++ b/timm/models/byobnet.py
@@ -0,0 +1,1531 @@
+""" Bring-Your-Own-Blocks Network
+
+A flexible network w/ dataclass based config for stacking those NN blocks.
+
+This model is currently used to implement the following networks:
+
+GPU Efficient (ResNets) - gernet_l/m/s (the original models were named GENet, but that name was already taken by the SENet author).
+Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+Code and weights: https://github.com/idstcv/GPU-Efficient-Networks, licensed Apache 2.0
+
+RepVGG - repvgg_*
+Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT
+
+In all cases the models have been modified to fit within the design of ByobNet. I've remapped
+the original weights and verified accuracies.
+
+For the GPU Efficient nets, I kept the original block names since they were, for the most part, the
+same as the residual blocks in ResNe(X)t, DarkNet, and other existing models. Note that some of the
+changes introduced in RegNet are also present in the stem and bottleneck blocks of this model.
+
+A significant number of different network archs can be implemented here, including variants of the
+above nets that include attention.
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+import math
+from dataclasses import dataclass, field, replace
+from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, named_apply
+from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+ create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d
+from .registry import register_model
+
+__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+def _cfgr(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+ 'crop_pct': 0.9, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # GPU-Efficient (ResNet) weights
+ 'gernet_s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_s-756b4751.pth'),
+ 'gernet_m': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_m-0873c53a.pth'),
+ 'gernet_l': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-ger-weights/gernet_l-f31e2e8d.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8)),
+
+ # RepVGG weights
+ 'repvgg_a2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_a2-c1ee6d2b.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b0': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b0-80ac3f1b.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b1': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b1-77ca2989.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b1g4': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b1g4-abde5d92.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b2-25b7494e.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b2g4': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b2g4-165a85f2.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3-199bc50d.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+ 'repvgg_b3g4': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
+
+ # experimental configs
+ 'resnet51q': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
+ first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
+ test_input_size=(3, 288, 288), crop_pct=1.0),
+ 'resnet61q': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth',
+ test_input_size=(3, 288, 288), crop_pct=1.0),
+
+ 'resnext26ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth'),
+ 'gcresnext26ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth'),
+ 'seresnext26ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth'),
+ 'eca_resnext26ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth'),
+ 'bat_resnext26ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth',
+ min_input_size=(3, 256, 256)),
+
+ 'resnet32ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth'),
+ 'resnet33ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth'),
+ 'gcresnet33ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth'),
+ 'seresnet33ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth'),
+ 'eca_resnet33ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth'),
+
+ 'gcresnet50t': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth'),
+
+ 'gcresnext50ts': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth'),
+
+    # experimental models, likely to change or be removed
+ 'regnetz_b16': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', crop_pct=0.94),
+ 'regnetz_c16': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), first_conv='stem.conv', crop_pct=0.94),
+ 'regnetz_d32': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
+ 'regnetz_d8': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
+ 'regnetz_e8': _cfgr(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
+ 'regnetz_d8_evob': _cfgr(
+ url='',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
+ 'regnetz_d8_evos': _cfgr(
+ url='',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
+}
+
+
+@dataclass
+class ByoBlockCfg:
+ type: Union[str, nn.Module]
+ d: int # block depth (number of block repeats in stage)
+ c: int # number of output channels for each block in stage
+ s: int = 2 # stride of stage (first block)
+ gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
+ br: float = 1. # bottleneck-ratio of blocks in stage
+
+ # NOTE: these config items override the model cfgs that are applied to all blocks by default
+ attn_layer: Optional[str] = None
+ attn_kwargs: Optional[Dict[str, Any]] = None
+ self_attn_layer: Optional[str] = None
+ self_attn_kwargs: Optional[Dict[str, Any]] = None
+ block_kwargs: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class ByoModelCfg:
+ blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
+ downsample: str = 'conv1x1'
+ stem_type: str = '3x3'
+ stem_pool: Optional[str] = 'maxpool'
+ stem_chs: int = 32
+ width_factor: float = 1.0
+ num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
+ zero_init_last: bool = True # zero init last weight (usually bn) in residual path
+ fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
+
+ act_layer: str = 'relu'
+ norm_layer: str = 'batchnorm'
+
+ # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
+ attn_layer: Optional[str] = None
+ attn_kwargs: dict = field(default_factory=lambda: dict())
+ self_attn_layer: Optional[str] = None
+ self_attn_kwargs: dict = field(default_factory=lambda: dict())
+ block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())
+
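+# Illustrative sketch (not a registered model): a minimal two-stage config built from the
+# dataclasses above; per-block attn_layer/attn_kwargs, where set, override the model-level values.
+#
+#   _example_cfg = ByoModelCfg(
+#       blocks=(
+#           ByoBlockCfg(type='basic', d=2, c=64, s=1),
+#           ByoBlockCfg(type='bottle', d=2, c=128, s=2, br=0.25, attn_layer='se'),
+#       ),
+#       stem_chs=32,
+#       stem_type='3x3',
+#   )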
+
+def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
+ c = (64, 128, 256, 512)
+ group_size = 0
+ if groups > 0:
+ group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
+ bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
+ return bcfg
+
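+# e.g. _rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4) scales the base (64, 128, 256, 512) widths
+# and makes every second block in each stage a grouped conv (the g4 variants); groups=0 keeps
+# all blocks un-grouped.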
+
+def interleave_blocks(
+ types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
+) -> Tuple[ByoBlockCfg]:
+ """ interleave 2 block types in stack
+ """
+ assert len(types) == 2
+ if isinstance(every, int):
+ every = list(range(0 if first else every, d, every + 1))
+ if not every:
+ every = [d - 1]
+    every = set(every)  # use a set for the membership checks below
+ blocks = []
+ for i in range(d):
+ block_type = types[1] if i in every else types[0]
+ blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)]
+ return tuple(blocks)
+
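+# Example (illustrative): interleave_blocks(types=('bottle', 'self_attn'), d=4) with the default
+# every=1 expands to four d=1 block cfgs ('bottle', 'self_attn', 'bottle', 'self_attn');
+# with first=True the 'self_attn' blocks land at indices 0 and 2 instead.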
+
+model_cfgs = dict(
+ gernet_l=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
+ ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
+ ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
+ ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
+ ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
+ ),
+ stem_chs=32,
+ stem_pool=None,
+ num_features=2560,
+ ),
+ gernet_m=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
+ ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
+ ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
+ ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
+ ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
+ ),
+ stem_chs=32,
+ stem_pool=None,
+ num_features=2560,
+ ),
+ gernet_s=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
+ ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
+ ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
+ ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
+ ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
+ ),
+ stem_chs=13,
+ stem_pool=None,
+ num_features=1920,
+ ),
+
+ repvgg_a2=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b0=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b1=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b1g4=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b2=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b2g4=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b3=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+ repvgg_b3g4=ByoModelCfg(
+ blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
+ stem_type='rep',
+ stem_chs=64,
+ ),
+
+ # 4 x conv stem w/ 2 act, no maxpool, 2,4,6,4 repeats, group size 32 in first 3 blocks
+ # DW convs in last block, 2048 pre-FC, silu act
+ resnet51q=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+ ),
+ stem_chs=128,
+ stem_type='quad2',
+ stem_pool=None,
+ num_features=2048,
+ act_layer='silu',
+ ),
+
+ # 4 x conv stem w/ 4 act, no maxpool, 1,4,6,4 repeats, edge block first, group size 32 in next 2 blocks
+ # DW convs in last block, 4 conv for each bottle block, 2048 pre-FC, silu act
+ resnet61q=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()),
+ ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+ ),
+ stem_chs=128,
+ stem_type='quad',
+ stem_pool=None,
+ num_features=2048,
+ act_layer='silu',
+ block_kwargs=dict(extra_conv=True),
+ ),
+
+ # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
+ # and a tiered stem w/ maxpool
+ resnext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ ),
+ gcresnext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ attn_layer='gca',
+ ),
+ seresnext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ attn_layer='se',
+ ),
+ eca_resnext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ attn_layer='eca',
+ ),
+ bat_resnext26ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ act_layer='silu',
+ attn_layer='bat',
+ attn_kwargs=dict(block_size=8)
+ ),
+
+ # ResNet-32 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, no pre-fc feat layer, tiered stem w/o maxpool
+ resnet32ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ num_features=0,
+ act_layer='silu',
+ ),
+
+ # ResNet-33 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, 1280 pre-FC feat, tiered stem w/o maxpool
+ resnet33ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ num_features=1280,
+ act_layer='silu',
+ ),
+
+ # A series of ResNet-33 (2, 3, 3, 2) models w/ one of GC, SE, ECA attn, no groups, SiLU act, 1280 pre-FC feat
+ # and a tiered stem w/ no maxpool
+ gcresnet33ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ num_features=1280,
+ act_layer='silu',
+ attn_layer='gca',
+ ),
+ seresnet33ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ num_features=1280,
+ act_layer='silu',
+ attn_layer='se',
+ ),
+ eca_resnet33ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+ ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ num_features=1280,
+ act_layer='silu',
+ attn_layer='eca',
+ ),
+
+ gcresnet50t=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+ ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+ ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ attn_layer='gca',
+ ),
+
+ gcresnext50ts=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=6, c=1024, s=2, gs=32, br=0.25),
+ ByoBlockCfg(type='bottle', d=3, c=2048, s=2, gs=32, br=0.25),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='maxpool',
+ # stem_pool=None,
+ act_layer='silu',
+ attn_layer='gca',
+ ),
+
+ # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
+ regnetz_b16=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
+ ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
+ ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
+ ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
+ ),
+ stem_chs=32,
+ stem_pool='',
+ downsample='',
+ num_features=1536,
+ act_layer='silu',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+ regnetz_c16=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
+ ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
+ ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
+ ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
+ ),
+ stem_chs=32,
+ stem_pool='',
+ downsample='',
+ num_features=1536,
+ act_layer='silu',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+ regnetz_d32=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
+ ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
+ ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4),
+ ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ downsample='',
+ num_features=1792,
+ act_layer='silu',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+ regnetz_d8=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ downsample='',
+ num_features=1792,
+ act_layer='silu',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+ regnetz_e8=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=96, s=1, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=8, c=192, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=16, c=384, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=8, br=4),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ downsample='',
+ num_features=2048,
+ act_layer='silu',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+
+ # experimental EvoNorm configs
+ regnetz_d8_evob=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+ ),
+ stem_chs=64,
+ stem_type='tiered',
+ stem_pool='',
+ downsample='',
+ num_features=1792,
+ act_layer='silu',
+ norm_layer='evonormbatch',
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+ regnetz_d8_evos=ByoModelCfg(
+ blocks=(
+ ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+ ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+ ),
+ stem_chs=64,
+ stem_type='deep',
+ stem_pool='',
+ downsample='',
+ num_features=1792,
+ act_layer='silu',
+ norm_layer=partial(EvoNormSample2d, groups=32),
+ attn_layer='se',
+ attn_kwargs=dict(rd_ratio=0.25),
+ block_kwargs=dict(bottle_in=True, linear_out=True),
+ ),
+)
+
+@register_model
+def gernet_l(pretrained=False, **kwargs):
+ """ GEResNet-Large (GENet-Large from official impl)
+ `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+ """
+ return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gernet_m(pretrained=False, **kwargs):
+ """ GEResNet-Medium (GENet-Normal from official impl)
+ `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+ """
+ return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gernet_s(pretrained=False, **kwargs):
+ """ EResNet-Small (GENet-Small from official impl)
+ `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+ """
+ return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_a2(pretrained=False, **kwargs):
+ """ RepVGG-A2
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b0(pretrained=False, **kwargs):
+ """ RepVGG-B0
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b1(pretrained=False, **kwargs):
+ """ RepVGG-B1
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b1g4(pretrained=False, **kwargs):
+ """ RepVGG-B1g4
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b2(pretrained=False, **kwargs):
+ """ RepVGG-B2
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b2g4(pretrained=False, **kwargs):
+ """ RepVGG-B2g4
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b3(pretrained=False, **kwargs):
+ """ RepVGG-B3
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b3g4(pretrained=False, **kwargs):
+ """ RepVGG-B3g4
+ `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+ """
+ return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet51q(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet61q(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnext26ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('resnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnext26ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('gcresnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def seresnext26ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('seresnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_resnext26ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('eca_resnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def bat_resnext26ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('bat_resnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet32ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('resnet32ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet33ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('resnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnet33ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('gcresnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def seresnet33ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('seresnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_resnet33ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('eca_resnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnet50t(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnext50ts(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_b16(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_c16(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d32(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d8(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_d8', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_e8(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d8_evob(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_d8_evob', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d8_evos(pretrained=False, **kwargs):
+ """
+ """
+ return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)
+
+
+def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
+ if not isinstance(stage_blocks_cfg, Sequence):
+ stage_blocks_cfg = (stage_blocks_cfg,)
+ block_cfgs = []
+ for i, cfg in enumerate(stage_blocks_cfg):
+ block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)]
+ return block_cfgs
+
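+# e.g. expand_blocks_cfg(ByoBlockCfg(type='bottle', d=3, c=256)) yields a list of three d=1 cfgs,
+# one per block in the stage; a Sequence of cfgs (mixed block types in one stage) is expanded
+# element-wise the same way.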
+
+def num_groups(group_size, channels):
+ if not group_size: # 0 or None
+ return 1 # normal conv with 1 group
+ else:
+ # NOTE group_size == 1 -> depthwise conv
+ assert channels % group_size == 0
+ return channels // group_size
+
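+# e.g. num_groups(0, 256) -> 1 (standard conv), num_groups(1, 256) -> 256 (depthwise),
+# num_groups(32, 256) -> 8 groups of 32 channels each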
+
+@dataclass
+class LayerFn:
+ conv_norm_act: Callable = ConvBnAct
+ norm_act: Callable = BatchNormAct2d
+ act: Callable = nn.ReLU
+ attn: Optional[Callable] = None
+ self_attn: Optional[Callable] = None
+
+
+class DownsampleAvg(nn.Module):
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1, apply_act=False, layers: LayerFn = None):
+ """ AvgPool Downsampling as in 'D' ResNet variants."""
+ super(DownsampleAvg, self).__init__()
+ layers = layers or LayerFn()
+ avg_stride = stride if dilation == 1 else 1
+ if stride > 1 or dilation > 1:
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+ else:
+ self.pool = nn.Identity()
+ self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)
+
+ def forward(self, x):
+ return self.conv(self.pool(x))
+
+
+def create_shortcut(downsample_type, layers: LayerFn, in_chs, out_chs, stride, dilation, **kwargs):
+ assert downsample_type in ('avg', 'conv1x1', '')
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+ if not downsample_type:
+ return None # no shortcut
+ elif downsample_type == 'avg':
+ return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs)
+ else:
+ return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs)
+ else:
+ return nn.Identity() # identity shortcut
+
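+# NOTE: create_shortcut returns None when downsample_type == '' and the in/out shapes differ;
+# the blocks below treat a None shortcut as "no residual connection" for that block.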
+
+class BasicBlock(nn.Module):
+ """ ResNet Basic Block - kxk + kxk
+ """
+
+ def __init__(
+ self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
+ downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
+ drop_path_rate=0.):
+ super(BasicBlock, self).__init__()
+ layers = layers or LayerFn()
+ mid_chs = make_divisible(out_chs * bottle_ratio)
+ groups = num_groups(group_size, mid_chs)
+
+ self.shortcut = create_shortcut(
+ downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
+ apply_act=False, layers=layers)
+
+ self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
+ self.conv2_kxk = layers.conv_norm_act(
+ mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+
+ def init_weights(self, zero_init_last: bool = False):
+ if zero_init_last and self.shortcut is not None:
+ nn.init.zeros_(self.conv2_kxk.bn.weight)
+ for attn in (self.attn, self.attn_last):
+ if hasattr(attn, 'reset_parameters'):
+ attn.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1_kxk(x)
+ x = self.conv2_kxk(x)
+ x = self.attn(x)
+ x = self.drop_path(x)
+ if self.shortcut is not None:
+ x = x + self.shortcut(shortcut)
+ return self.act(x)
+
+
+class BottleneckBlock(nn.Module):
+ """ ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
+ """
+
+ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
+ downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
+ layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+ super(BottleneckBlock, self).__init__()
+ layers = layers or LayerFn()
+ mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
+ groups = num_groups(group_size, mid_chs)
+
+ self.shortcut = create_shortcut(
+ downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
+ apply_act=False, layers=layers)
+
+ self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+ self.conv2_kxk = layers.conv_norm_act(
+ mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
+ groups=groups, drop_block=drop_block)
+ if extra_conv:
+ self.conv2b_kxk = layers.conv_norm_act(
+ mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block)
+ else:
+ self.conv2b_kxk = nn.Identity()
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
+ self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+
+ def init_weights(self, zero_init_last: bool = False):
+ if zero_init_last and self.shortcut is not None:
+ nn.init.zeros_(self.conv3_1x1.bn.weight)
+ for attn in (self.attn, self.attn_last):
+ if hasattr(attn, 'reset_parameters'):
+ attn.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1_1x1(x)
+ x = self.conv2_kxk(x)
+ x = self.conv2b_kxk(x)
+ x = self.attn(x)
+ x = self.conv3_1x1(x)
+ x = self.attn_last(x)
+ x = self.drop_path(x)
+ if self.shortcut is not None:
+ x = x + self.shortcut(shortcut)
+ return self.act(x)
+
+
+class DarkBlock(nn.Module):
+ """ DarkNet-like (1x1 + 3x3 w/ stride) block
+
+ The GE-Net impl included a 1x1 + 3x3 block in their search space. It was not used in the feature models.
+    This block is pretty much a DarkNet block (also DenseNet) hence the name. Neither DarkNet nor DenseNet
+    uses strides within the block (external 3x3 or maxpool downsampling is done in front of the block repeats).
+
+    If one does want to use a lot of these blocks w/ stride, I'd recommend using the EdgeBlock (3x3 w/ stride + 1x1)
+ for more optimal compute.
+ """
+
+ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+ downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
+ drop_path_rate=0.):
+ super(DarkBlock, self).__init__()
+ layers = layers or LayerFn()
+ mid_chs = make_divisible(out_chs * bottle_ratio)
+ groups = num_groups(group_size, mid_chs)
+
+ self.shortcut = create_shortcut(
+ downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
+ apply_act=False, layers=layers)
+
+ self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
+ self.conv2_kxk = layers.conv_norm_act(
+ mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
+ groups=groups, drop_block=drop_block, apply_act=False)
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+
+ def init_weights(self, zero_init_last: bool = False):
+ if zero_init_last and self.shortcut is not None:
+ nn.init.zeros_(self.conv2_kxk.bn.weight)
+ for attn in (self.attn, self.attn_last):
+ if hasattr(attn, 'reset_parameters'):
+ attn.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1_1x1(x)
+ x = self.attn(x)
+ x = self.conv2_kxk(x)
+ x = self.attn_last(x)
+ x = self.drop_path(x)
+ if self.shortcut is not None:
+ x = x + self.shortcut(shortcut)
+ return self.act(x)
+
+
+class EdgeBlock(nn.Module):
+ """ EdgeResidual-like (3x3 + 1x1) block
+
+ A two layer block like DarkBlock, but with the order of the 3x3 and 1x1 convs reversed.
+    Very similar to the EfficientNet Edge-Residual block, but this block ends with activations, is
+    intended to be used with either expansion or bottleneck contraction, and can use DW/group/non-grouped convs.
+
+ FIXME is there a more common 3x3 + 1x1 conv block to name this after?
+ """
+
+ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+ downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
+ drop_block=None, drop_path_rate=0.):
+ super(EdgeBlock, self).__init__()
+ layers = layers or LayerFn()
+ mid_chs = make_divisible(out_chs * bottle_ratio)
+ groups = num_groups(group_size, mid_chs)
+
+ self.shortcut = create_shortcut(
+ downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
+ apply_act=False, layers=layers)
+
+ self.conv1_kxk = layers.conv_norm_act(
+ in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
+ groups=groups, drop_block=drop_block)
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
+ self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+
+ def init_weights(self, zero_init_last: bool = False):
+ if zero_init_last and self.shortcut is not None:
+ nn.init.zeros_(self.conv2_1x1.bn.weight)
+ for attn in (self.attn, self.attn_last):
+ if hasattr(attn, 'reset_parameters'):
+ attn.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1_kxk(x)
+ x = self.attn(x)
+ x = self.conv2_1x1(x)
+ x = self.attn_last(x)
+ x = self.drop_path(x)
+ if self.shortcut is not None:
+ x = x + self.shortcut(shortcut)
+ return self.act(x)
+
+
+class RepVggBlock(nn.Module):
+ """ RepVGG Block.
+
+ Adapted from impl at https://github.com/DingXiaoH/RepVGG
+
+ This version does not currently support the deploy optimization. It is currently fixed in 'train' mode.
+ """
+
+ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
+ downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+ super(RepVggBlock, self).__init__()
+ layers = layers or LayerFn()
+ groups = num_groups(group_size, in_chs)
+
+ use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
+ self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
+ self.conv_kxk = layers.conv_norm_act(
+ in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
+ groups=groups, drop_block=drop_block, apply_act=False)
+ self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
+ self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
+ self.act = layers.act(inplace=True)
+
+ def init_weights(self, zero_init_last: bool = False):
+        # NOTE this init overrides the base model init with specific changes for the block type
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ nn.init.normal_(m.weight, .1, .1)
+ nn.init.normal_(m.bias, 0, .1)
+ if hasattr(self.attn, 'reset_parameters'):
+ self.attn.reset_parameters()
+
+ def forward(self, x):
+ if self.identity is None:
+ x = self.conv_1x1(x) + self.conv_kxk(x)
+ else:
+ identity = self.identity(x)
+ x = self.conv_1x1(x) + self.conv_kxk(x)
+ x = self.drop_path(x) # not in the paper / official impl, experimental
+ x = x + identity
+ x = self.attn(x) # no attn in the paper / official impl, experimental
+ return self.act(x)
+
+
+class SelfAttnBlock(nn.Module):
+ """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
+ """
+
+ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
+ downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
+ feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
+ super(SelfAttnBlock, self).__init__()
+ assert layers is not None
+ mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
+ groups = num_groups(group_size, mid_chs)
+
+ self.shortcut = create_shortcut(
+ downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation,
+ apply_act=False, layers=layers)
+
+ self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
+ if extra_conv:
+ self.conv2_kxk = layers.conv_norm_act(
+ mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
+ groups=groups, drop_block=drop_block)
+ stride = 1 # striding done via conv if enabled
+ else:
+ self.conv2_kxk = nn.Identity()
+ opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
+ # FIXME need to dilate self attn to have dilated network support, moop moop
+ self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
+ self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
+ self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
+
+ def init_weights(self, zero_init_last: bool = False):
+ if zero_init_last and self.shortcut is not None:
+ nn.init.zeros_(self.conv3_1x1.bn.weight)
+ if hasattr(self.self_attn, 'reset_parameters'):
+ self.self_attn.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1_1x1(x)
+ x = self.conv2_kxk(x)
+ x = self.self_attn(x)
+ x = self.post_attn(x)
+ x = self.conv3_1x1(x)
+ x = self.drop_path(x)
+ if self.shortcut is not None:
+ x = x + self.shortcut(shortcut)
+ return self.act(x)
+
+_block_registry = dict(
+ basic=BasicBlock,
+ bottle=BottleneckBlock,
+ dark=DarkBlock,
+ edge=EdgeBlock,
+ rep=RepVggBlock,
+ self_attn=SelfAttnBlock,
+)
+
+
+def register_block(block_type:str, block_fn: nn.Module):
+ _block_registry[block_type] = block_fn
+
+
+def create_block(block: Union[str, nn.Module], **kwargs):
+ if isinstance(block, (nn.Module, partial)):
+ return block(**kwargs)
+    assert block in _block_registry, f'Unknown block type ({block})'
+ return _block_registry[block](**kwargs)
+
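+# Example (illustrative): a custom block type can be registered by name, e.g.
+# register_block('my_block', MyBlock), and then referenced from a config as
+# ByoBlockCfg(type='my_block', d=2, c=256); 'MyBlock' here is a hypothetical nn.Module.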
+
+class Stem(nn.Sequential):
+
+ def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
+ num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
+ super().__init__()
+ assert stride in (2, 4)
+ layers = layers or LayerFn()
+
+ if isinstance(out_chs, (list, tuple)):
+ num_rep = len(out_chs)
+ stem_chs = out_chs
+ else:
+ stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
+
+ self.stride = stride
+ self.feature_info = [] # track intermediate features
+ prev_feat = ''
+ stem_strides = [2] + [1] * (num_rep - 1)
+ if stride == 4 and not pool:
+ # set last conv in stack to be strided if stride == 4 and no pooling layer
+ stem_strides[-1] = 2
+
+ num_act = num_rep if num_act is None else num_act
+ # if num_act < num_rep, first convs in stack won't have bn + act
+ stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
+ prev_chs = in_chs
+ curr_stride = 1
+ for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
+ layer_fn = layers.conv_norm_act if na else create_conv2d
+ conv_name = f'conv{i + 1}'
+ if i > 0 and s > 1:
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+ self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
+ prev_chs = ch
+ curr_stride *= s
+ prev_feat = conv_name
+
+ if pool and 'max' in pool.lower():
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+ self.add_module('pool', nn.MaxPool2d(3, 2, 1))
+ curr_stride *= 2
+ prev_feat = 'pool'
+
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+ assert curr_stride == stride
+
+
+def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
+ layers = layers or LayerFn()
+ assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
+ if 'quad' in stem_type:
+ # based on NFNet stem, stack of 4 3x3 convs
+ num_act = 2 if 'quad2' in stem_type else None
+ stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
+ elif 'tiered' in stem_type:
+ # 3x3 stack of 3 convs as in my ResNet-T
+ stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
+ elif 'deep' in stem_type:
+ # 3x3 stack of 3 convs as in ResNet-D
+ stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
+ elif 'rep' in stem_type:
+ stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
+ elif '7x7' in stem_type:
+ # 7x7 stem conv as in ResNet
+ if pool_type:
+ stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
+ else:
+ stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
+ else:
+ # 3x3 stem conv as in RegNet is the default
+ if pool_type:
+ stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+ else:
+ stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
+
+ if isinstance(stem, Stem):
+ feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
+ else:
+ feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+ return stem, feature_info
+
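+# e.g. a 'tiered' stem with out_chs=64 stacks 3 convs of (24, 32, 64) channels (3 * 64 // 8, 64 // 2, 64)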
+
+def reduce_feat_size(feat_size, stride=2):
+ return None if feat_size is None else tuple([s // stride for s in feat_size])
+
+
+def override_kwargs(block_kwargs, model_kwargs):
+ """ Override model level attn/self-attn/block kwargs w/ block level
+
+ NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs
+ for the block if set to anything that isn't None.
+
+ i.e. an empty block_kwargs dict will remove kwargs set at model level for that block
+ """
+ out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs
+ return out_kwargs or {} # make sure None isn't returned
+
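+# e.g. override_kwargs(None, dict(rd_ratio=0.25)) -> dict(rd_ratio=0.25), while
+# override_kwargs(dict(), dict(rd_ratio=0.25)) -> {} (an empty block-level dict clears model-level kwargs)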
+
+def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ):
+ layer_fns = block_kwargs['layers']
+
+ # override attn layer / args with block local config
+ attn_set = block_cfg.attn_layer is not None
+ if attn_set or block_cfg.attn_kwargs is not None:
+ # override attn layer config
+ if attn_set and not block_cfg.attn_layer:
+ # empty string for attn_layer type will disable attn for this block
+ attn_layer = None
+ else:
+ attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
+ attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
+ attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
+ layer_fns = replace(layer_fns, attn=attn_layer)
+
+ # override self-attn layer / args with block local cfg
+ self_attn_set = block_cfg.self_attn_layer is not None
+ if self_attn_set or block_cfg.self_attn_kwargs is not None:
+ # override attn layer config
+ if self_attn_set and not block_cfg.self_attn_layer: # attn_layer == ''
+ # empty string for self_attn_layer type will disable attn for this block
+ self_attn_layer = None
+ else:
+ self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
+ self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
+ self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
+ if self_attn_layer is not None else None
+ layer_fns = replace(layer_fns, self_attn=self_attn_layer)
+
+ block_kwargs['layers'] = layer_fns
+
+ # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
+ block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))
+
+
+def create_byob_stages(
+ cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
+ feat_size: Optional[int] = None,
+ layers: Optional[LayerFn] = None,
+ block_kwargs_fn: Optional[Callable] = update_block_kwargs):
+
+ layers = layers or LayerFn()
+ feature_info = []
+ block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
+ depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+ dilation = 1
+ net_stride = stem_feat['reduction']
+ prev_chs = stem_feat['num_chs']
+ prev_feat = stem_feat
+ stages = []
+ for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
+ stride = stage_block_cfgs[0].s
+ if stride != 1 and prev_feat:
+ feature_info.append(prev_feat)
+ if net_stride >= output_stride and stride > 1:
+ dilation *= stride
+ stride = 1
+ net_stride *= stride
+ first_dilation = 1 if dilation in (1, 2) else 2
+
+ blocks = []
+ for block_idx, block_cfg in enumerate(stage_block_cfgs):
+ out_chs = make_divisible(block_cfg.c * cfg.width_factor)
+ group_size = block_cfg.gs
+ if isinstance(group_size, Callable):
+ group_size = group_size(out_chs, block_idx)
+ block_kwargs = dict( # Blocks used in this model must accept these arguments
+ in_chs=prev_chs,
+ out_chs=out_chs,
+ stride=stride if block_idx == 0 else 1,
+ dilation=(first_dilation, dilation),
+ group_size=group_size,
+ bottle_ratio=block_cfg.br,
+ downsample=cfg.downsample,
+ drop_path_rate=dpr[stage_idx][block_idx],
+ layers=layers,
+ )
+ if block_cfg.type in ('self_attn',):
+ # add feat_size arg for blocks that support/need it
+ block_kwargs['feat_size'] = feat_size
+ block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
+ blocks += [create_block(block_cfg.type, **block_kwargs)]
+ first_dilation = dilation
+ prev_chs = out_chs
+ if stride > 1 and block_idx == 0:
+ feat_size = reduce_feat_size(feat_size, stride)
+
+ stages += [nn.Sequential(*blocks)]
+ prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+
+ feature_info.append(prev_feat)
+ return nn.Sequential(*stages), feature_info
+
+
+def get_layer_fns(cfg: ByoModelCfg):
+ act = get_act_layer(cfg.act_layer)
+ norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
+ conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
+ attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+ self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
+ layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
+ return layer_fn
+
+
+class ByobNet(nn.Module):
+ """ 'Bring-your-own-blocks' Net
+
+ A flexible network backbone that allows building model stem + blocks via
+ dataclass cfg definition w/ factory functions for module instantiation.
+
+ Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
+ """
+ def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
+ zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
+ super().__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ layers = get_layer_fns(cfg)
+ if cfg.fixed_input_size:
+ assert img_size is not None, 'img_size argument is required for fixed input size model'
+ feat_size = to_2tuple(img_size) if img_size is not None else None
+
+ self.feature_info = []
+ stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
+ self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+ self.feature_info.extend(stem_feat[:-1])
+ feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
+
+ self.stages, stage_feat = create_byob_stages(
+ cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size)
+ self.feature_info.extend(stage_feat[:-1])
+
+ prev_chs = stage_feat[-1]['num_chs']
+ if cfg.num_features:
+ self.num_features = int(round(cfg.width_factor * cfg.num_features))
+ self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
+ else:
+ self.num_features = prev_chs
+ self.final_conv = nn.Identity()
+ self.feature_info += [
+ dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ # init weights
+ named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.stages(x)
+ x = self.final_conv(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _init_weights(module, name='', zero_init_last=False):
+ if isinstance(module, nn.Conv2d):
+ fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+ fan_out //= module.groups
+ module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Linear):
+ nn.init.normal_(module.weight, mean=0.0, std=0.01)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ module.init_weights(zero_init_last=zero_init_last)
+
+
+def _create_byobnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ByobNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=model_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True),
+ **kwargs)
diff --git a/timm/models/cait.py b/timm/models/cait.py
new file mode 100644
index 0000000..69b4ba0
--- /dev/null
+++ b/timm/models/cait.py
@@ -0,0 +1,394 @@
+""" Class-Attention in Image Transformers (CaiT)
+
+Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239
+
+Original code and weights from https://github.com/facebookresearch/deit, copyright below
+
+"""
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+from functools import partial
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
+from .registry import register_model
+
+
+__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
+ 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ cait_xxs24_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/XXS24_224.pth',
+ input_size=(3, 224, 224),
+ ),
+ cait_xxs24_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/XXS24_384.pth',
+ ),
+ cait_xxs36_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/XXS36_224.pth',
+ input_size=(3, 224, 224),
+ ),
+ cait_xxs36_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/XXS36_384.pth',
+ ),
+ cait_xs24_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/XS24_384.pth',
+ ),
+ cait_s24_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/S24_224.pth',
+ input_size=(3, 224, 224),
+ ),
+ cait_s24_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/S24_384.pth',
+ ),
+ cait_s36_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/S36_384.pth',
+ ),
+ cait_m36_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/M36_384.pth',
+ ),
+ cait_m48_448=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/M48_448.pth',
+ input_size=(3, 448, 448),
+ ),
+)
+
+
+class ClassAttn(nn.Module):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to do CA
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.k = nn.Linear(dim, dim, bias=qkv_bias)
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ q = q * self.scale
+ v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ attn = (q @ k.transpose(-2, -1))
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
+ x_cls = self.proj(x_cls)
+ x_cls = self.proj_drop(x_cls)
+
+ return x_cls
+
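+# Class attention in a nutshell: only the class token acts as the query, while keys and
+# values come from the full token sequence, so the attention map is [B, heads, 1, N] and
+# the module returns an updated class token of shape [B, 1, C].
+#
+# Illustrative shape check (values are arbitrary):
+#   attn = ClassAttn(dim=192, num_heads=4)
+#   out = attn(torch.randn(2, 197, 192))   # out.shape == torch.Size([2, 1, 192])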
+
+class LayerScaleBlockClassAttn(nn.Module):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to add CA and LayerScale
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
+ mlp_block=Mlp, init_values=1e-4):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_block(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x, x_cls):
+ u = torch.cat((x_cls, x), dim=1)
+ x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))
+ x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))
+ return x_cls
+
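+# LayerScale: gamma_1 / gamma_2 are learnable per-channel scales initialized to a small
+# init_values (1e-4 by default here, 1e-5/1e-6 in the model configs below), so both residual
+# branches start close to identity and the effective depth grows gradually during training.
+# The class-attention variant updates only x_cls; the patch tokens x are read-only inputs.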
+
+class TalkingHeadAttn(nn.Module):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+
+ self.num_heads = num_heads
+
+ head_dim = dim // num_heads
+
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.proj = nn.Linear(dim, dim)
+
+ self.proj_l = nn.Linear(num_heads, num_heads)
+ self.proj_w = nn.Linear(num_heads, num_heads)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+
+ attn = (q @ k.transpose(-2, -1))
+
+ attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ attn = attn.softmax(dim=-1)
+
+ attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
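+# Talking-heads attention: proj_l linearly mixes the attention logits across heads before
+# the softmax, and proj_w mixes the normalized attention weights across heads after it
+# (both act on the head dimension via the permute/Linear/permute pattern above).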
+
+class LayerScaleBlock(nn.Module):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to add layerScale
+ def __init__(
+ self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
+ mlp_block=Mlp, init_values=1e-4):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_block(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = mlp_block(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class Cait(nn.Module):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to adapt to our cait models
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ global_pool=None,
+ block_layers=LayerScaleBlock,
+ block_layers_token=LayerScaleBlockClassAttn,
+ patch_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ attn_block=TalkingHeadAttn,
+ mlp_block=Mlp,
+ init_scale=1e-4,
+ attn_block_token_only=ClassAttn,
+ mlp_block_token_only=Mlp,
+ depth_token_only=2,
+ mlp_ratio_clstk=4.0
+ ):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim
+
+ self.patch_embed = patch_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [drop_path_rate for i in range(depth)]
+ self.blocks = nn.ModuleList([
+ block_layers(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
+ for i in range(depth)])
+
+ self.blocks_token_only = nn.ModuleList([
+ block_layers_token(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
+ drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
+ act_layer=act_layer, attn_block=attn_block_token_only,
+ mlp_block=mlp_block_token_only, init_values=init_scale)
+ for i in range(depth_token_only)])
+
+ self.norm = norm_layer(embed_dim)
+
+ self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for i, blk in enumerate(self.blocks):
+ x = blk(x)
+
+ for i, blk in enumerate(self.blocks_token_only):
+ cls_tokens = blk(x, cls_tokens)
+
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = self.norm(x)
+ return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
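+# Readout path: the patch tokens never interact with the class token inside self.blocks;
+# the class token is introduced only in the final blocks_token_only (class-attention)
+# stage, and the classifier consumes x[:, 0] after the shared LayerNorm.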
+
+def checkpoint_filter_fn(state_dict, model=None):
+ if 'model' in state_dict:
+ state_dict = state_dict['model']
+ checkpoint_no_module = {}
+ for k, v in state_dict.items():
+ checkpoint_no_module[k.replace('module.', '')] = v
+ return checkpoint_no_module
+
+
+def _create_cait(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ Cait, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def cait_xxs24_224(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_xxs24_384(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_xxs36_224(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_xxs36_384(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_xs24_384(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_s24_224(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_s24_384(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
+ model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_s36_384(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs)
+ model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_m36_384(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs)
+ model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def cait_m48_448(pretrained=False, **kwargs):
+ model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs)
+ model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args)
+ return model
diff --git a/timm/models/coat.py b/timm/models/coat.py
new file mode 100644
index 0000000..18ff8ab
--- /dev/null
+++ b/timm/models/coat.py
@@ -0,0 +1,661 @@
+"""
+CoaT architecture.
+
+Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
+
+Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
+
+Modified from timm/models/vision_transformer.py
+"""
+from copy import deepcopy
+from functools import partial
+from typing import Tuple, List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
+from .registry import register_model
+from .layers import _assert
+
+
+__all__ = [
+ "coat_tiny",
+ "coat_mini",
+ "coat_lite_tiny",
+ "coat_lite_mini",
+ "coat_lite_small"
+]
+
+
+def _cfg_coat(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed1.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'coat_tiny': _cfg_coat(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth'
+ ),
+ 'coat_mini': _cfg_coat(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth'
+ ),
+ 'coat_lite_tiny': _cfg_coat(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
+ ),
+ 'coat_lite_mini': _cfg_coat(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
+ ),
+ 'coat_lite_small': _cfg_coat(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth'
+ ),
+}
+
+
+class ConvRelPosEnc(nn.Module):
+ """ Convolutional relative position encoding. """
+ def __init__(self, Ch, h, window):
+ """
+ Initialization.
+ Ch: Channels per head.
+ h: Number of heads.
+ window: Window size(s) in convolutional relative positional encoding. It can have two forms:
+                1. An integer of window size, which assigns the same window size to all
+                    attention heads in ConvRelPosEnc.
+                2. A dict mapping window size to #attention head splits
+                    (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}).
+                    It will apply different window sizes to different attention head splits.
+ """
+ super().__init__()
+
+ if isinstance(window, int):
+ # Set the same window size for all attention heads.
+ window = {window: h}
+ self.window = window
+ elif isinstance(window, dict):
+ self.window = window
+ else:
+            raise ValueError('window must be an int or a dict mapping window size to number of head splits')
+
+ self.conv_list = nn.ModuleList()
+ self.head_splits = []
+ for cur_window, cur_head_split in window.items():
+ dilation = 1
+ # Determine padding size.
+ # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
+ padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
+ cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch,
+ kernel_size=(cur_window, cur_window),
+ padding=(padding_size, padding_size),
+ dilation=(dilation, dilation),
+ groups=cur_head_split*Ch,
+ )
+ self.conv_list.append(cur_conv)
+ self.head_splits.append(cur_head_split)
+ self.channel_splits = [x*Ch for x in self.head_splits]
+
+ def forward(self, q, v, size: Tuple[int, int]):
+ B, h, N, Ch = q.shape
+ H, W = size
+        _assert(N == 1 + H * W, 'sequence length must equal 1 (class token) + H * W')
+
+ # Convolutional relative position encoding.
+ q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
+ v_img = v[:, :, 1:, :] # [B, h, H*W, Ch]
+
+ v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
+ v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels
+ conv_v_img_list = []
+ for i, conv in enumerate(self.conv_list):
+ conv_v_img_list.append(conv(v_img_list[i]))
+ conv_v_img = torch.cat(conv_v_img_list, dim=1)
+ conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2)
+
+ EV_hat = q_img * conv_v_img
+ EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch].
+ return EV_hat
+
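+# The relative position term is computed as q * depthwise_conv(v): the image tokens of v are
+# reshaped to [B, h*Ch, H, W], convolved per head split (possibly with different window
+# sizes), reshaped back, and multiplied elementwise with the image part of q. The F.pad
+# call prepends a zero row so the class token receives no positional contribution.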
+
+class FactorAtt_ConvRelPosEnc(nn.Module):
+ """ Factorized attention with convolutional relative position encoding class. """
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ # Shared convolutional relative position encoding.
+ self.crpe = shared_crpe
+
+ def forward(self, x, size: Tuple[int, int]):
+ B, N, C = x.shape
+
+ # Generate Q, K, V.
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # [B, h, N, Ch]
+
+ # Factorized attention.
+ k_softmax = k.softmax(dim=2)
+ factor_att = k_softmax.transpose(-1, -2) @ v
+ factor_att = q @ factor_att
+
+ # Convolutional relative position encoding.
+ crpe = self.crpe(q, v, size=size) # [B, h, N, Ch]
+
+ # Merge and reshape.
+ x = self.scale * factor_att + crpe
+ x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]
+
+ # Output projection.
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
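+# Factorized attention computes softmax(K)^T @ V first ([B, h, Ch, Ch]) and then Q @ (.),
+# which is linear in the token count N instead of quadratic; the result is scaled and
+# added to the convolutional relative position term from the shared ConvRelPosEnc module.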
+
+class ConvPosEnc(nn.Module):
+ """ Convolutional Position Encoding.
+ Note: This module is similar to the conditional position encoding in CPVT.
+ """
+ def __init__(self, dim, k=3):
+ super(ConvPosEnc, self).__init__()
+ self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim)
+
+ def forward(self, x, size: Tuple[int, int]):
+ B, N, C = x.shape
+ H, W = size
+        _assert(N == 1 + H * W, 'sequence length must equal 1 (class token) + H * W')
+
+ # Extract CLS token and image tokens.
+ cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
+
+ # Depthwise convolution.
+ feat = img_tokens.transpose(1, 2).view(B, C, H, W)
+ x = self.proj(feat) + feat
+ x = x.flatten(2).transpose(1, 2)
+
+ # Combine with CLS token.
+ x = torch.cat((cls_token, x), dim=1)
+
+ return x
+
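+# The depthwise k x k convolution (plus residual) over the spatially reshaped image tokens
+# acts as a conditional position encoding in the spirit of CPVT; the class token bypasses
+# the convolution and is simply re-attached in front of the sequence.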
+
+class SerialBlock(nn.Module):
+ """ Serial block class.
+    Note: In this implementation, each serial block only contains a conv-attention and an FFN (MLP) module. """
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
+ super().__init__()
+
+ # Conv-Attention.
+ self.cpe = shared_cpe
+
+ self.norm1 = norm_layer(dim)
+ self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ # MLP.
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, size: Tuple[int, int]):
+ # Conv-Attention.
+ x = self.cpe(x, size)
+ cur = self.norm1(x)
+ cur = self.factoratt_crpe(cur, size)
+ x = x + self.drop_path(cur)
+
+ # MLP.
+ cur = self.norm2(x)
+ cur = self.mlp(cur)
+ x = x + self.drop_path(cur)
+
+ return x
+
+
+class ParallelBlock(nn.Module):
+ """ Parallel block class. """
+ def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
+ super().__init__()
+
+ # Conv-Attention.
+ self.norm12 = norm_layer(dims[1])
+ self.norm13 = norm_layer(dims[2])
+ self.norm14 = norm_layer(dims[3])
+ self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
+ dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ shared_crpe=shared_crpes[1]
+ )
+ self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
+ dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ shared_crpe=shared_crpes[2]
+ )
+ self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
+ dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
+ shared_crpe=shared_crpes[3]
+ )
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ # MLP.
+ self.norm22 = norm_layer(dims[1])
+ self.norm23 = norm_layer(dims[2])
+ self.norm24 = norm_layer(dims[3])
+ # In parallel block, we assume dimensions are the same and share the linear transformation.
+ assert dims[1] == dims[2] == dims[3]
+ assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
+ mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
+ self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
+ in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def upsample(self, x, factor: float, size: Tuple[int, int]):
+ """ Feature map up-sampling. """
+ return self.interpolate(x, scale_factor=factor, size=size)
+
+ def downsample(self, x, factor: float, size: Tuple[int, int]):
+ """ Feature map down-sampling. """
+ return self.interpolate(x, scale_factor=1.0/factor, size=size)
+
+ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
+ """ Feature map interpolation. """
+ B, N, C = x.shape
+ H, W = size
+        _assert(N == 1 + H * W, 'sequence length must equal 1 (class token) + H * W')
+
+ cls_token = x[:, :1, :]
+ img_tokens = x[:, 1:, :]
+
+ img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
+ img_tokens = F.interpolate(
+ img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False)
+ img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
+
+ out = torch.cat((cls_token, img_tokens), dim=1)
+
+ return out
+
+ def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
+ _, S2, S3, S4 = sizes
+ cur2 = self.norm12(x2)
+ cur3 = self.norm13(x3)
+ cur4 = self.norm14(x4)
+ cur2 = self.factoratt_crpe2(cur2, size=S2)
+ cur3 = self.factoratt_crpe3(cur3, size=S3)
+ cur4 = self.factoratt_crpe4(cur4, size=S4)
+ upsample3_2 = self.upsample(cur3, factor=2., size=S3)
+ upsample4_3 = self.upsample(cur4, factor=2., size=S4)
+ upsample4_2 = self.upsample(cur4, factor=4., size=S4)
+ downsample2_3 = self.downsample(cur2, factor=2., size=S2)
+ downsample3_4 = self.downsample(cur3, factor=2., size=S3)
+ downsample2_4 = self.downsample(cur2, factor=4., size=S2)
+ cur2 = cur2 + upsample3_2 + upsample4_2
+ cur3 = cur3 + upsample4_3 + downsample2_3
+ cur4 = cur4 + downsample3_4 + downsample2_4
+ x2 = x2 + self.drop_path(cur2)
+ x3 = x3 + self.drop_path(cur3)
+ x4 = x4 + self.drop_path(cur4)
+
+ # MLP.
+ cur2 = self.norm22(x2)
+ cur3 = self.norm23(x3)
+ cur4 = self.norm24(x4)
+ cur2 = self.mlp2(cur2)
+ cur3 = self.mlp3(cur3)
+ cur4 = self.mlp4(cur4)
+ x2 = x2 + self.drop_path(cur2)
+ x3 = x3 + self.drop_path(cur3)
+ x4 = x4 + self.drop_path(cur4)
+
+ return x1, x2, x3, x4
+
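+# The parallel block cross-fuses the last three scales (stages 2-4): each scale's
+# conv-attention output is bilinearly resampled (factors 2x and 4x, class token carried
+# through untouched by `interpolate`) and summed into the other scales before the residual
+# and MLP path (the three MLPs share weights). x1 (the first-stage tokens) is returned unchanged.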
+
+class CoaT(nn.Module):
+ """ CoaT class. """
+ def __init__(
+ self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
+ serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
+ super().__init__()
+ crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
+ self.return_interm_layers = return_interm_layers
+ self.out_features = out_features
+ self.embed_dims = embed_dims
+ self.num_features = embed_dims[-1]
+ self.num_classes = num_classes
+
+ # Patch embeddings.
+ img_size = to_2tuple(img_size)
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+ embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
+ self.patch_embed2 = PatchEmbed(
+ img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
+ embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
+ self.patch_embed3 = PatchEmbed(
+ img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
+ embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
+ self.patch_embed4 = PatchEmbed(
+ img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
+ embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
+
+ # Class tokens.
+ self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
+ self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
+ self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
+ self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
+
+ # Convolutional position encodings.
+ self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
+ self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
+ self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
+ self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)
+
+ # Convolutional relative position encodings.
+ self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)
+ self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)
+ self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)
+ self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)
+
+ # Disable stochastic depth.
+ dpr = drop_path_rate
+ assert dpr == 0.0
+
+ # Serial blocks 1.
+ self.serial_blocks1 = nn.ModuleList([
+ SerialBlock(
+ dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
+ shared_cpe=self.cpe1, shared_crpe=self.crpe1
+ )
+ for _ in range(serial_depths[0])]
+ )
+
+ # Serial blocks 2.
+ self.serial_blocks2 = nn.ModuleList([
+ SerialBlock(
+ dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
+ shared_cpe=self.cpe2, shared_crpe=self.crpe2
+ )
+ for _ in range(serial_depths[1])]
+ )
+
+ # Serial blocks 3.
+ self.serial_blocks3 = nn.ModuleList([
+ SerialBlock(
+ dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
+ shared_cpe=self.cpe3, shared_crpe=self.crpe3
+ )
+ for _ in range(serial_depths[2])]
+ )
+
+ # Serial blocks 4.
+ self.serial_blocks4 = nn.ModuleList([
+ SerialBlock(
+ dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
+ shared_cpe=self.cpe4, shared_crpe=self.crpe4
+ )
+ for _ in range(serial_depths[3])]
+ )
+
+ # Parallel blocks.
+ self.parallel_depth = parallel_depth
+ if self.parallel_depth > 0:
+ self.parallel_blocks = nn.ModuleList([
+ ParallelBlock(
+ dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
+ shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
+ )
+ for _ in range(parallel_depth)]
+ )
+ else:
+ self.parallel_blocks = None
+
+ # Classification head(s).
+ if not self.return_interm_layers:
+ if self.parallel_blocks is not None:
+ self.norm2 = norm_layer(embed_dims[1])
+ self.norm3 = norm_layer(embed_dims[2])
+ else:
+ self.norm2 = self.norm3 = None
+ self.norm4 = norm_layer(embed_dims[3])
+
+ if self.parallel_depth > 0:
+ # CoaT series: Aggregate features of last three scales for classification.
+ assert embed_dims[1] == embed_dims[2] == embed_dims[3]
+ self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ else:
+ # CoaT-Lite series: Use feature of last scale for classification.
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ # Initialize weights.
+ trunc_normal_(self.cls_token1, std=.02)
+ trunc_normal_(self.cls_token2, std=.02)
+ trunc_normal_(self.cls_token3, std=.02)
+ trunc_normal_(self.cls_token4, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def insert_cls(self, x, cls_token):
+ """ Insert CLS token. """
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+ return x
+
+ def remove_cls(self, x):
+ """ Remove CLS token. """
+ return x[:, 1:, :]
+
+ def forward_features(self, x0):
+ B = x0.shape[0]
+
+ # Serial blocks 1.
+ x1 = self.patch_embed1(x0)
+ H1, W1 = self.patch_embed1.grid_size
+ x1 = self.insert_cls(x1, self.cls_token1)
+ for blk in self.serial_blocks1:
+ x1 = blk(x1, size=(H1, W1))
+ x1_nocls = self.remove_cls(x1)
+ x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
+
+ # Serial blocks 2.
+ x2 = self.patch_embed2(x1_nocls)
+ H2, W2 = self.patch_embed2.grid_size
+ x2 = self.insert_cls(x2, self.cls_token2)
+ for blk in self.serial_blocks2:
+ x2 = blk(x2, size=(H2, W2))
+ x2_nocls = self.remove_cls(x2)
+ x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
+
+ # Serial blocks 3.
+ x3 = self.patch_embed3(x2_nocls)
+ H3, W3 = self.patch_embed3.grid_size
+ x3 = self.insert_cls(x3, self.cls_token3)
+ for blk in self.serial_blocks3:
+ x3 = blk(x3, size=(H3, W3))
+ x3_nocls = self.remove_cls(x3)
+ x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
+
+ # Serial blocks 4.
+ x4 = self.patch_embed4(x3_nocls)
+ H4, W4 = self.patch_embed4.grid_size
+ x4 = self.insert_cls(x4, self.cls_token4)
+ for blk in self.serial_blocks4:
+ x4 = blk(x4, size=(H4, W4))
+ x4_nocls = self.remove_cls(x4)
+ x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
+
+ # Only serial blocks: Early return.
+ if self.parallel_blocks is None:
+ if not torch.jit.is_scripting() and self.return_interm_layers:
+ # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
+ feat_out = {}
+ if 'x1_nocls' in self.out_features:
+ feat_out['x1_nocls'] = x1_nocls
+ if 'x2_nocls' in self.out_features:
+ feat_out['x2_nocls'] = x2_nocls
+ if 'x3_nocls' in self.out_features:
+ feat_out['x3_nocls'] = x3_nocls
+ if 'x4_nocls' in self.out_features:
+ feat_out['x4_nocls'] = x4_nocls
+ return feat_out
+ else:
+ # Return features for classification.
+ x4 = self.norm4(x4)
+ x4_cls = x4[:, 0]
+ return x4_cls
+
+ # Parallel blocks.
+ for blk in self.parallel_blocks:
+ x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
+ x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
+
+ if not torch.jit.is_scripting() and self.return_interm_layers:
+ # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
+ feat_out = {}
+ if 'x1_nocls' in self.out_features:
+ x1_nocls = self.remove_cls(x1)
+ x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
+ feat_out['x1_nocls'] = x1_nocls
+ if 'x2_nocls' in self.out_features:
+ x2_nocls = self.remove_cls(x2)
+ x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
+ feat_out['x2_nocls'] = x2_nocls
+ if 'x3_nocls' in self.out_features:
+ x3_nocls = self.remove_cls(x3)
+ x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
+ feat_out['x3_nocls'] = x3_nocls
+ if 'x4_nocls' in self.out_features:
+ x4_nocls = self.remove_cls(x4)
+ x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
+ feat_out['x4_nocls'] = x4_nocls
+ return feat_out
+ else:
+ x2 = self.norm2(x2)
+ x3 = self.norm3(x3)
+ x4 = self.norm4(x4)
+ x2_cls = x2[:, :1] # [B, 1, C]
+ x3_cls = x3[:, :1]
+ x4_cls = x4[:, :1]
+ merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C]
+ merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C]
+ return merged_cls
+
+ def forward(self, x):
+ if self.return_interm_layers:
+ # Return intermediate features (for down-stream tasks).
+ return self.forward_features(x)
+ else:
+ # Return features for classification.
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
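+# Classification heads: CoaT variants (parallel_depth > 0) stack the class tokens of the
+# last three scales into a [B, 3, C] tensor and reduce them with a kernel-size-1 Conv1d
+# ("aggregate") before the linear head; CoaT-Lite variants use only the last-scale class token.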
+
+def checkpoint_filter_fn(state_dict, model):
+ out_dict = {}
+ for k, v in state_dict.items():
+ # original model had unused norm layers, removing them requires filtering pretrained checkpoints
+ if k.startswith('norm1') or \
+ (model.norm2 is None and k.startswith('norm2')) or \
+ (model.norm3 is None and k.startswith('norm3')):
+ continue
+ out_dict[k] = v
+ return out_dict
+
+
+def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ CoaT, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def coat_tiny(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
+ num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
+ model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
+ return model
+
+
+@register_model
+def coat_mini(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
+ num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
+ model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
+ return model
+
+
+@register_model
+def coat_lite_tiny(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
+ num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
+ model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
+ return model
+
+
+@register_model
+def coat_lite_mini(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
+ num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
+ model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
+ return model
+
+
+@register_model
+def coat_lite_small(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
+ num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
+ model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
+ return model
\ No newline at end of file
diff --git a/timm/models/convit.py b/timm/models/convit.py
new file mode 100644
index 0000000..6ef1da7
--- /dev/null
+++ b/timm/models/convit.py
@@ -0,0 +1,351 @@
+""" ConViT Model
+
+@article{d2021convit,
+ title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
+ author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
+ journal={arXiv preprint arXiv:2103.10697},
+ year={2021}
+}
+
+Paper link: https://arxiv.org/abs/2103.10697
+Original code: https://github.com/facebookresearch/convit, original copyright below
+"""
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+#
+'''These modules are adapted from those of timm, see
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+'''
+
+import torch
+import torch.nn as nn
+from functools import partial
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
+from .registry import register_model
+from .vision_transformer_hybrid import HybridEmbed
+from .fx_features import register_notrace_module
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # ConViT
+ 'convit_tiny': _cfg(
+ url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
+ 'convit_small': _cfg(
+ url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
+ 'convit_base': _cfg(
+ url="https://dl.fbaipublicfiles.com/convit/convit_base.pth")
+}
+
+
+@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
+class GPSA(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
+ locality_strength=1.):
+ super().__init__()
+ self.num_heads = num_heads
+ self.dim = dim
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+ self.locality_strength = locality_strength
+
+ self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.pos_proj = nn.Linear(3, num_heads)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.gating_param = nn.Parameter(torch.ones(self.num_heads))
+ self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None
+
+ def forward(self, x):
+ B, N, C = x.shape
+ if self.rel_indices is None or self.rel_indices.shape[1] != N:
+ self.rel_indices = self.get_rel_indices(N)
+ attn = self.get_attention(x)
+ v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def get_attention(self, x):
+ B, N, C = x.shape
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k = qk[0], qk[1]
+ pos_score = self.rel_indices.expand(B, -1, -1, -1)
+ pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
+ patch_score = (q @ k.transpose(-2, -1)) * self.scale
+ patch_score = patch_score.softmax(dim=-1)
+ pos_score = pos_score.softmax(dim=-1)
+
+ gating = self.gating_param.view(1, -1, 1, 1)
+ attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
+ attn /= attn.sum(dim=-1).unsqueeze(-1)
+ attn = self.attn_drop(attn)
+ return attn
+
+ def get_attention_map(self, x, return_map=False):
+ attn_map = self.get_attention(x).mean(0) # average over batch
+ distances = self.rel_indices.squeeze()[:, :, -1] ** .5
+ dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
+ if return_map:
+ return dist, attn_map
+ else:
+ return dist
+
+ def local_init(self):
+ self.v.weight.data.copy_(torch.eye(self.dim))
+ locality_distance = 1 # max(1,1/locality_strength**.5)
+
+ kernel_size = int(self.num_heads ** .5)
+ center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
+ for h1 in range(kernel_size):
+ for h2 in range(kernel_size):
+ position = h1 + kernel_size * h2
+ self.pos_proj.weight.data[position, 2] = -1
+ self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
+ self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
+ self.pos_proj.weight.data *= self.locality_strength
+
+ def get_rel_indices(self, num_patches: int) -> torch.Tensor:
+ img_size = int(num_patches ** .5)
+ rel_indices = torch.zeros(1, num_patches, num_patches, 3)
+ ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
+ indx = ind.repeat(img_size, img_size)
+ indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
+ indd = indx ** 2 + indy ** 2
+ rel_indices[:, :, :, 2] = indd.unsqueeze(0)
+ rel_indices[:, :, :, 1] = indy.unsqueeze(0)
+ rel_indices[:, :, :, 0] = indx.unsqueeze(0)
+ device = self.qk.weight.device
+ return rel_indices.to(device)
+
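+# Gated positional self-attention: each head blends a content-based attention map
+# (softmax of q.k) with a purely positional map (rel_indices = (dx, dy, d^2) projected by
+# pos_proj), with per-head weight sigmoid(gating_param) on the positional map and
+# (1 - sigmoid) on the content map, then renormalizes. local_init() sets pos_proj so that
+# every head initially attends to a distinct spatial offset (a convolution-like prior whose
+# sharpness is controlled by locality_strength) and copies the identity matrix into self.v.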
+
+class MHSA(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def get_attention_map(self, x, return_map=False):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ attn_map = (q @ k.transpose(-2, -1)) * self.scale
+ attn_map = attn_map.softmax(dim=-1).mean(0)
+
+ img_size = int(N ** .5)
+ ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
+ indx = ind.repeat(img_size, img_size)
+ indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
+ indd = indx ** 2 + indy ** 2
+ distances = indd ** .5
+        distances = distances.to(x.device)  # match the input's device rather than hard-coding CUDA
+
+ dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
+ if return_map:
+ return dist, attn_map
+ else:
+ return dist
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.use_gpsa = use_gpsa
+ if self.use_gpsa:
+ self.attn = GPSA(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
+ else:
+ self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class ConViT(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
+ local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
+ super().__init__()
+ embed_dim *= num_heads
+ self.num_classes = num_classes
+ self.local_up_to_layer = local_up_to_layer
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.locality_strength = locality_strength
+ self.use_pos_embed = use_pos_embed
+
+ if hybrid_backbone is not None:
+ self.patch_embed = HybridEmbed(
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+ else:
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ if self.use_pos_embed:
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.pos_embed, std=.02)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_gpsa=True,
+ locality_strength=locality_strength)
+ if i < local_up_to_layer else
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ use_gpsa=False)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Classifier head
+ self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+ for n, m in self.named_modules():
+ if hasattr(m, 'local_init'):
+ m.local_init()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+
+ if self.use_pos_embed:
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for u, blk in enumerate(self.blocks):
+ if u == self.local_up_to_layer:
+ x = torch.cat((cls_tokens, x), dim=1)
+ x = blk(x)
+
+ x = self.norm(x)
+ return x[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
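+# Token flow: the class token is concatenated only once u == local_up_to_layer, so the
+# first `local_up_to_layer` blocks run GPSA over patch tokens alone while the remaining
+# blocks use plain MHSA over the full sequence; the classifier reads x[:, 0] after norm.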
+
+def _create_convit(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ return build_model_with_cfg(
+ ConViT, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def convit_tiny(pretrained=False, **kwargs):
+ model_args = dict(
+ local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
+ num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def convit_small(pretrained=False, **kwargs):
+ model_args = dict(
+ local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
+ num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def convit_base(pretrained=False, **kwargs):
+ model_args = dict(
+ local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
+ num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
+ return model
diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py
new file mode 100644
index 0000000..a240078
--- /dev/null
+++ b/timm/models/convmixer.py
@@ -0,0 +1,101 @@
+import torch.nn as nn
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.registry import register_model
+from .helpers import build_model_with_cfg
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .96, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head',
+ 'first_conv': 'stem.0',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'convmixer_1536_20': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1536_20_ks9_p7.pth.tar'),
+ 'convmixer_768_32': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_768_32_ks7_p7_relu.pth.tar'),
+ 'convmixer_1024_20_ks9_p14': _cfg(url='https://github.com/tmp-iclr/convmixer/releases/download/timm-v1.0/convmixer_1024_20_ks9_p14.pth.tar')
+}
+
+
+class Residual(nn.Module):
+ def __init__(self, fn):
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x):
+ return self.fn(x) + x
+
+
+class ConvMixer(nn.Module):
+ def __init__(self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, activation=nn.GELU, **kwargs):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = dim
+ self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.stem = nn.Sequential(
+ nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size),
+ activation(),
+ nn.BatchNorm2d(dim)
+ )
+ self.blocks = nn.Sequential(
+ *[nn.Sequential(
+ Residual(nn.Sequential(
+ nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
+ activation(),
+ nn.BatchNorm2d(dim)
+ )),
+ nn.Conv2d(dim, dim, kernel_size=1),
+ activation(),
+ nn.BatchNorm2d(dim)
+ ) for i in range(depth)]
+ )
+ self.pooling = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Flatten()
+ )
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.blocks(x)
+ x = self.pooling(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+
+ return x
+
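+# ConvMixer structure: a patch-embedding stem (stride = patch_size), then `depth` blocks of
+# a residual depthwise convolution followed by a pointwise (1x1) convolution, each with the
+# chosen activation (GELU by default) and BatchNorm2d. Note that forward_features() already
+# applies global average pooling and flattening, so the head is a plain nn.Linear on [B, dim].
+#
+# Illustrative instantiation (hyper-parameters are arbitrary, not a pretrained config):
+#   tiny = ConvMixer(dim=256, depth=8, kernel_size=5, patch_size=4, num_classes=10)
+#   logits = tiny(torch.randn(1, 3, 224, 224))   # -> torch.Size([1, 10])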
+
+def _create_convmixer(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(ConvMixer, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
+
+
+@register_model
+def convmixer_1536_20(pretrained=False, **kwargs):
+ model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs)
+ return _create_convmixer('convmixer_1536_20', pretrained, **model_args)
+
+
+@register_model
+def convmixer_768_32(pretrained=False, **kwargs):
+ model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, activation=nn.ReLU, **kwargs)
+ return _create_convmixer('convmixer_768_32', pretrained, **model_args)
+
+
+@register_model
+def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs):
+ model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs)
+ return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args)
\ No newline at end of file
diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py
new file mode 100644
index 0000000..ddc4f64
--- /dev/null
+++ b/timm/models/crossvit.py
@@ -0,0 +1,517 @@
+""" CrossViT Model
+
+@inproceedings{
+ chen2021crossvit,
+ title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
+ author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
+ booktitle={International Conference on Computer Vision (ICCV)},
+ year={2021}
+}
+
+Paper link: https://arxiv.org/abs/2103.14899
+Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
+
+NOTE: model names have been renamed from the originals to reflect the actual input resolution: all *_224 -> *_240 and *_384 -> *_408
+"""
+
+# Copyright IBM All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+Modified from timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+
+"""
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.hub
+from functools import partial
+from typing import List
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
+from .helpers import build_model_with_cfg
+from .layers import DropPath, to_2tuple, trunc_normal_, _assert
+from .registry import register_model
+from .vision_transformer import Mlp, Block
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
+ 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
+ 'classifier': ('head.0', 'head.1'),
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'),
+ 'crossvit_15_dagger_240': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth',
+ first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+ ),
+ 'crossvit_15_dagger_408': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
+ input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
+ ),
+ 'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
+ 'crossvit_18_dagger_240': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth',
+ first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+ ),
+ 'crossvit_18_dagger_408': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
+ input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
+ ),
+ 'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
+ 'crossvit_9_dagger_240': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth',
+ first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
+ ),
+ 'crossvit_base_240': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'),
+ 'crossvit_small_240': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'),
+ 'crossvit_tiny_240': _cfg(
+ url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'),
+}
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ if multi_conv:
+ if patch_size[0] == 12:
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
+ )
+ elif patch_size[0] == 16:
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
+ )
+ else:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ # FIXME look at relaxing size constraints
+ _assert(H == self.img_size[0],
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+ _assert(W == self.img_size[1],
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+        # NOTE the scale factor was wrong in my original version; it can be set manually to stay compatible with previous weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ # B1C -> B1H(C/H) -> BH1(C/H)
+ q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ # BNC -> BNH(C/H) -> BHN(C/H)
+ k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ # BNC -> BNH(C/H) -> BHN(C/H)
+ v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class CrossAttentionBlock(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = CrossAttention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
+
+ return x
+
+
+class MultiScaleBlock(nn.Module):
+
+ def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+
+ num_branches = len(dim)
+ self.num_branches = num_branches
+ # different branch could have different embedding size, the first one is the base
+ self.blocks = nn.ModuleList()
+ for d in range(num_branches):
+ tmp = []
+ for i in range(depth[d]):
+ tmp.append(Block(
+ dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
+ if len(tmp) != 0:
+ self.blocks.append(nn.Sequential(*tmp))
+
+ if len(self.blocks) == 0:
+ self.blocks = None
+
+ self.projs = nn.ModuleList()
+ for d in range(num_branches):
+ if dim[d] == dim[(d + 1) % num_branches] and False:
+ tmp = [nn.Identity()]
+ else:
+ tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
+ self.projs.append(nn.Sequential(*tmp))
+
+ self.fusion = nn.ModuleList()
+ for d in range(num_branches):
+ d_ = (d + 1) % num_branches
+ nh = num_heads[d_]
+            if depth[-1] == 0:  # backward compatibility:
+ self.fusion.append(
+ CrossAttentionBlock(
+ dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
+ else:
+ tmp = []
+ for _ in range(depth[-1]):
+ tmp.append(CrossAttentionBlock(
+ dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
+ self.fusion.append(nn.Sequential(*tmp))
+
+ self.revert_projs = nn.ModuleList()
+ for d in range(num_branches):
+ if dim[(d + 1) % num_branches] == dim[d] and False:
+ tmp = [nn.Identity()]
+ else:
+ tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
+ nn.Linear(dim[(d + 1) % num_branches], dim[d])]
+ self.revert_projs.append(nn.Sequential(*tmp))
+
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+
+ outs_b = []
+ for i, block in enumerate(self.blocks):
+ outs_b.append(block(x[i]))
+
+ # only take the cls token out
+ proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
+ for i, proj in enumerate(self.projs):
+ proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
+
+ # cross attention
+ outs = []
+ for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
+ tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
+ tmp = fusion(tmp)
+ reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
+ tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
+ outs.append(tmp)
+ return outs
+
+
+def _compute_num_patches(img_size, patches):
+ return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
+
+
+@register_notrace_function
+def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
+ """
+ Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
+ Args:
+ x (Tensor): input image
+ ss (tuple[int, int]): height and width to scale to
+ crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
+ Returns:
+ Tensor: the "scaled" image batch tensor
+ """
+ H, W = x.shape[-2:]
+ if H != ss[0] or W != ss[1]:
+ if crop_scale and ss[0] <= H and ss[1] <= W:
+ cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
+ x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
+ else:
+ x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
+ return x
+
+
+class CrossViT(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(
+ self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
+ embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
+ qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False, crop_scale=False,
+ ):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.img_size = to_2tuple(img_size)
+ img_scale = to_2tuple(img_scale)
+ self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
+ self.crop_scale = crop_scale # crop instead of interpolate for scale
+ num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
+ self.num_branches = len(patch_size)
+ self.embed_dim = embed_dim
+ self.num_features = embed_dim[0] # to pass the tests
+ self.patch_embed = nn.ModuleList()
+
+ # hard-coded for torch jit script
+ for i in range(self.num_branches):
+ setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
+ setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
+
+ for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
+ self.patch_embed.append(
+ PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ total_depth = sum([sum(x[-2:]) for x in depth])
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
+ dpr_ptr = 0
+ self.blocks = nn.ModuleList()
+ for idx, block_cfg in enumerate(depth):
+ curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
+ dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
+ blk = MultiScaleBlock(
+ embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
+ dpr_ptr += curr_depth
+ self.blocks.append(blk)
+
+ self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
+ self.head = nn.ModuleList([
+ nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+ for i in range(self.num_branches)])
+
+ for i in range(self.num_branches):
+ trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
+ trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ out = set()
+ for i in range(self.num_branches):
+ out.add(f'cls_token_{i}')
+ pe = getattr(self, f'pos_embed_{i}', None)
+ if pe is not None and pe.requires_grad:
+ out.add(f'pos_embed_{i}')
+ return out
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.ModuleList(
+ [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
+ range(self.num_branches)])
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ xs = []
+ for i, patch_embed in enumerate(self.patch_embed):
+ x_ = x
+ ss = self.img_size_scaled[i]
+ x_ = scale_image(x_, ss, self.crop_scale)
+ x_ = patch_embed(x_)
+ cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
+ cls_tokens = cls_tokens.expand(B, -1, -1)
+ x_ = torch.cat((cls_tokens, x_), dim=1)
+ pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
+ x_ = x_ + pos_embed
+ x_ = self.pos_drop(x_)
+ xs.append(x_)
+
+ for i, blk in enumerate(self.blocks):
+ xs = blk(xs)
+
+        # NOTE: this norm was previously applied before extracting the branch tokens; moved here to ensure all branch tokens pass through layer norm
+ xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
+ return [xo[:, 0] for xo in xs]
+
+ def forward(self, x):
+ xs = self.forward_features(x)
+ ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
+ if not isinstance(self.head[0], nn.Identity):
+ ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
+ return ce_logits
+
+
+def _create_crossvit(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ def pretrained_filter_fn(state_dict):
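+        # remap checkpoint keys to the flattened attribute names used by this impl,
+        # e.g. 'pos_embed.0' -> 'pos_embed_0' (per-branch tensors are registered as
+        # separate attributes for torchscript compatibility)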
+ new_state_dict = {}
+ for key in state_dict.keys():
+ if 'pos_embed' in key or 'cls_token' in key:
+ new_key = key.replace(".", "_")
+ else:
+ new_key = key
+ new_state_dict[new_key] = state_dict[key]
+ return new_state_dict
+
+ return build_model_with_cfg(
+ CrossViT, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=pretrained_filter_fn,
+ **kwargs)
+
+
+@register_model
+def crossvit_tiny_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
+ num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs)
+ model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_small_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
+ num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs)
+ model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_base_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
+ num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs)
+ model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_9_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
+ num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs)
+ model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_15_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
+ num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs)
+ model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_18_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
+ num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
+ model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_9_dagger_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
+ num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
+ model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_15_dagger_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
+ num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
+ model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_15_dagger_408(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
+ num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
+ model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_18_dagger_240(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
+ num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
+ model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def crossvit_18_dagger_408(pretrained=False, **kwargs):
+ model_args = dict(
+ img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
+ num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
+ model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args)
+ return model
diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py
new file mode 100644
index 0000000..39d1620
--- /dev/null
+++ b/timm/models/cspnet.py
@@ -0,0 +1,457 @@
+"""PyTorch CspNet
+
+A PyTorch implementation of Cross Stage Partial Networks including:
+* CSPResNet50
+* CSPResNeXt50
+* CSPDarkNet53
+* and DarkNet53 for good measure
+
+Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
+
+Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks
+
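+A minimal usage sketch (assuming the timm registry, so the entrypoints defined below are available):
+
+    import torch, timm
+    model = timm.create_model('cspresnet50', pretrained=False)
+    logits = model(torch.randn(1, 3, 256, 256))  # (1, 1000)
+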
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer
+from .registry import register_model
+
+
+__all__ = ['CspNet'] # model_registry will add each entrypoint fn to this
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+ 'crop_pct': 0.887, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'cspresnet50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
+ 'cspresnet50d': _cfg(url=''),
+ 'cspresnet50w': _cfg(url=''),
+ 'cspresnext50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
+ input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.875 # FIXME I trained this at 224x224, not 256 like ref impl
+ ),
+ 'cspresnext50_iabn': _cfg(url=''),
+ 'cspdarknet53': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
+ 'cspdarknet53_iabn': _cfg(url=''),
+ 'darknet53': _cfg(url=''),
+}
+
+
+model_cfgs = dict(
+ cspresnet50=dict(
+ stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'),
+ stage=dict(
+ out_chs=(128, 256, 512, 1024),
+ depth=(3, 3, 5, 2),
+ stride=(1,) + (2,) * 3,
+ exp_ratio=(2.,) * 4,
+ bottle_ratio=(0.5,) * 4,
+ block_ratio=(1.,) * 4,
+ cross_linear=True,
+ )
+ ),
+ cspresnet50d=dict(
+ stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'),
+ stage=dict(
+ out_chs=(128, 256, 512, 1024),
+ depth=(3, 3, 5, 2),
+ stride=(1,) + (2,) * 3,
+ exp_ratio=(2.,) * 4,
+ bottle_ratio=(0.5,) * 4,
+ block_ratio=(1.,) * 4,
+ cross_linear=True,
+ )
+ ),
+ cspresnet50w=dict(
+ stem=dict(out_chs=[32, 32, 64], kernel_size=3, stride=2, pool='max'),
+ stage=dict(
+ out_chs=(256, 512, 1024, 2048),
+ depth=(3, 3, 5, 2),
+ stride=(1,) + (2,) * 3,
+ exp_ratio=(1.,) * 4,
+ bottle_ratio=(0.25,) * 4,
+ block_ratio=(0.5,) * 4,
+ cross_linear=True,
+ )
+ ),
+ cspresnext50=dict(
+ stem=dict(out_chs=64, kernel_size=7, stride=2, pool='max'),
+ stage=dict(
+ out_chs=(256, 512, 1024, 2048),
+ depth=(3, 3, 5, 2),
+ stride=(1,) + (2,) * 3,
+ groups=(32,) * 4,
+ exp_ratio=(1.,) * 4,
+ bottle_ratio=(1.,) * 4,
+ block_ratio=(0.5,) * 4,
+ cross_linear=True,
+ )
+ ),
+ cspdarknet53=dict(
+ stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
+ stage=dict(
+ out_chs=(64, 128, 256, 512, 1024),
+ depth=(1, 2, 8, 8, 4),
+ stride=(2,) * 5,
+ exp_ratio=(2.,) + (1.,) * 4,
+ bottle_ratio=(0.5,) + (1.0,) * 4,
+ block_ratio=(1.,) + (0.5,) * 4,
+ down_growth=True,
+ )
+ ),
+ darknet53=dict(
+ stem=dict(out_chs=32, kernel_size=3, stride=1, pool=''),
+ stage=dict(
+ out_chs=(64, 128, 256, 512, 1024),
+ depth=(1, 2, 8, 8, 4),
+ stride=(2,) * 5,
+ bottle_ratio=(0.5,) * 5,
+ block_ratio=(1.,) * 5,
+ )
+ )
+)
+
+
+def create_stem(
+ in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
+ act_layer=None, norm_layer=None, aa_layer=None):
+ stem = nn.Sequential()
+ if not isinstance(out_chs, (tuple, list)):
+ out_chs = [out_chs]
+ assert len(out_chs)
+ in_c = in_chans
+ for i, out_c in enumerate(out_chs):
+ conv_name = f'conv{i + 1}'
+ stem.add_module(conv_name, ConvBnAct(
+ in_c, out_c, kernel_size, stride=stride if i == 0 else 1,
+ act_layer=act_layer, norm_layer=norm_layer))
+ in_c = out_c
+ last_conv = conv_name
+ if pool:
+ if aa_layer is not None:
+ stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
+ stem.add_module('aa', aa_layer(channels=in_c, stride=2))
+ else:
+ stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
+ return stem, dict(num_chs=in_c, reduction=stride, module='.'.join(['stem', last_conv]))
+
+
+class ResBottleneck(nn.Module):
+ """ ResNe(X)t Bottleneck Block
+ """
+
+ def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.25, groups=1,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_last=False,
+ attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+ super(ResBottleneck, self).__init__()
+ mid_chs = int(round(out_chs * bottle_ratio))
+ ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
+
+ self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
+ self.conv2 = ConvBnAct(mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs)
+ self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
+ self.conv3 = ConvBnAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
+ self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
+ self.drop_path = drop_path
+ self.act3 = act_layer(inplace=True)
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.conv3.bn.weight)
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ if self.attn2 is not None:
+ x = self.attn2(x)
+ x = self.conv3(x)
+ if self.attn3 is not None:
+ x = self.attn3(x)
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+ x = x + shortcut
+ # FIXME partial shortcut needed if first block handled as per original, not used for my current impl
+ #x[:, :shortcut.size(1)] += shortcut
+ x = self.act3(x)
+ return x
+
+
+class DarkBlock(nn.Module):
+ """ DarkNet Block
+ """
+
+ def __init__(self, in_chs, out_chs, dilation=1, bottle_ratio=0.5, groups=1,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
+ drop_block=None, drop_path=None):
+ super(DarkBlock, self).__init__()
+ mid_chs = int(round(out_chs * bottle_ratio))
+ ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
+ self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
+ self.conv2 = ConvBnAct(mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs)
+ self.attn = create_attn(attn_layer, channels=out_chs)
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.conv2.bn.weight)
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ if self.attn is not None:
+ x = self.attn(x)
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+ x = x + shortcut
+ return x
+
+
+class CrossStage(nn.Module):
+ """Cross Stage."""
+ def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., exp_ratio=1.,
+ groups=1, first_dilation=None, down_growth=False, cross_linear=False, block_dpr=None,
+ block_fn=ResBottleneck, **block_kwargs):
+ super(CrossStage, self).__init__()
+ first_dilation = first_dilation or dilation
+ down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
+ exp_chs = int(round(out_chs * exp_ratio))
+ block_out_chs = int(round(out_chs * block_ratio))
+ conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
+
+ if stride != 1 or first_dilation != dilation:
+ self.conv_down = ConvBnAct(
+ in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
+ aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
+ prev_chs = down_chs
+ else:
+ self.conv_down = None
+ prev_chs = in_chs
+
+        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
+        # there is a special case for the first stage of some models that results in an uneven split
+        # across the two paths. I did it this way for simplicity for now.
+ self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
+ prev_chs = exp_chs // 2 # output of conv_exp is always split in two
+
+ self.blocks = nn.Sequential()
+ for i in range(depth):
+ drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
+ self.blocks.add_module(str(i), block_fn(
+ prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
+ prev_chs = block_out_chs
+
+ # transition convs
+ self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
+ self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
+
+ def forward(self, x):
+ if self.conv_down is not None:
+ x = self.conv_down(x)
+ x = self.conv_exp(x)
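+        # split expanded features into the cross (identity) path xs and the dense block path xb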
+ split = x.shape[1] // 2
+ xs, xb = x[:, :split], x[:, split:]
+ xb = self.blocks(xb)
+ xb = self.conv_transition_b(xb).contiguous()
+ out = self.conv_transition(torch.cat([xs, xb], dim=1))
+ return out
+
+
+class DarkStage(nn.Module):
+ """DarkNet stage."""
+
+ def __init__(self, in_chs, out_chs, stride, dilation, depth, block_ratio=1., bottle_ratio=1., groups=1,
+ first_dilation=None, block_fn=ResBottleneck, block_dpr=None, **block_kwargs):
+ super(DarkStage, self).__init__()
+ first_dilation = first_dilation or dilation
+
+ self.conv_down = ConvBnAct(
+ in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
+ act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'),
+ aa_layer=block_kwargs.get('aa_layer', None))
+
+ prev_chs = out_chs
+ block_out_chs = int(round(out_chs * block_ratio))
+ self.blocks = nn.Sequential()
+ for i in range(depth):
+ drop_path = DropPath(block_dpr[i]) if block_dpr and block_dpr[i] else None
+ self.blocks.add_module(str(i), block_fn(
+ prev_chs, block_out_chs, dilation, bottle_ratio, groups, drop_path=drop_path, **block_kwargs))
+ prev_chs = block_out_chs
+
+ def forward(self, x):
+ x = self.conv_down(x)
+ x = self.blocks(x)
+ return x
+
+
+def _cfg_to_stage_args(cfg, curr_stride=2, output_stride=32, drop_path_rate=0.):
+ # get per stage args for stage and containing blocks, calculate strides to meet target output_stride
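+    # e.g. for cspresnet50 (cfg stride (1, 2, 2, 2), curr_stride=4 after the stem) with output_stride=16,
+    # the per-stage strides become (1, 2, 2, 1) and dilations (1, 1, 1, 2): once the running stride reaches
+    # the target, further reductions are converted into dilation instead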
+ num_stages = len(cfg['depth'])
+ if 'groups' not in cfg:
+ cfg['groups'] = (1,) * num_stages
+ if 'down_growth' in cfg and not isinstance(cfg['down_growth'], (list, tuple)):
+ cfg['down_growth'] = (cfg['down_growth'],) * num_stages
+ if 'cross_linear' in cfg and not isinstance(cfg['cross_linear'], (list, tuple)):
+ cfg['cross_linear'] = (cfg['cross_linear'],) * num_stages
+ cfg['block_dpr'] = [None] * num_stages if not drop_path_rate else \
+ [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg['depth'])).split(cfg['depth'])]
+ stage_strides = []
+ stage_dilations = []
+ stage_first_dilations = []
+ dilation = 1
+ for cfg_stride in cfg['stride']:
+ stage_first_dilations.append(dilation)
+ if curr_stride >= output_stride:
+ dilation *= cfg_stride
+ stride = 1
+ else:
+ stride = cfg_stride
+ curr_stride *= stride
+ stage_strides.append(stride)
+ stage_dilations.append(dilation)
+ cfg['stride'] = stage_strides
+ cfg['dilation'] = stage_dilations
+ cfg['first_dilation'] = stage_first_dilations
+ stage_args = [dict(zip(cfg.keys(), values)) for values in zip(*cfg.values())]
+ return stage_args
+
+
+class CspNet(nn.Module):
+ """Cross Stage Partial base model.
+
+ Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
+ Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks
+
+    NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
+    darknet impl. I did it this way for simplicity and fewer special cases.
+ """
+
+ def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
+ act_layer=nn.LeakyReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_path_rate=0.,
+ zero_init_last_bn=True, stage_fn=CrossStage, block_fn=ResBottleneck):
+ super().__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ assert output_stride in (8, 16, 32)
+ layer_args = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
+
+ # Construct the stem
+ self.stem, stem_feat_info = create_stem(in_chans, **cfg['stem'], **layer_args)
+ self.feature_info = [stem_feat_info]
+ prev_chs = stem_feat_info['num_chs']
+ curr_stride = stem_feat_info['reduction'] # reduction does not include pool
+ if cfg['stem']['pool']:
+ curr_stride *= 2
+
+ # Construct the stages
+ per_stage_args = _cfg_to_stage_args(
+ cfg['stage'], curr_stride=curr_stride, output_stride=output_stride, drop_path_rate=drop_path_rate)
+ self.stages = nn.Sequential()
+ for i, sa in enumerate(per_stage_args):
+ self.stages.add_module(
+ str(i), stage_fn(prev_chs, **sa, **layer_args, block_fn=block_fn))
+ prev_chs = sa['out_chs']
+ curr_stride *= sa['stride']
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
+
+ # Construct the head
+ self.num_features = prev_chs
+ self.head = ClassifierHead(
+ in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, mean=0.0, std=0.01)
+ nn.init.zeros_(m.bias)
+ if zero_init_last_bn:
+ for m in self.modules():
+ if hasattr(m, 'zero_init_last_bn'):
+ m.zero_init_last_bn()
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.stages(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_cspnet(variant, pretrained=False, **kwargs):
+ cfg_variant = variant.split('_')[0]
+ return build_model_with_cfg(
+ CspNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant],
+ **kwargs)
+
+
+@register_model
+def cspresnet50(pretrained=False, **kwargs):
+ return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnet50d(pretrained=False, **kwargs):
+ return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnet50w(pretrained=False, **kwargs):
+ return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnext50(pretrained=False, **kwargs):
+ return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnext50_iabn(pretrained=False, **kwargs):
+ norm_layer = get_norm_act_layer('iabn')
+ return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
+
+
+@register_model
+def cspdarknet53(pretrained=False, **kwargs):
+ return _create_cspnet('cspdarknet53', pretrained=pretrained, block_fn=DarkBlock, **kwargs)
+
+
+@register_model
+def cspdarknet53_iabn(pretrained=False, **kwargs):
+ norm_layer = get_norm_act_layer('iabn')
+ return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs)
+
+
+@register_model
+def darknet53(pretrained=False, **kwargs):
+ return _create_cspnet('darknet53', pretrained=pretrained, block_fn=DarkBlock, stage_fn=DarkStage, **kwargs)
diff --git a/timm/models/densenet.py b/timm/models/densenet.py
new file mode 100644
index 0000000..38a1972
--- /dev/null
+++ b/timm/models/densenet.py
@@ -0,0 +1,387 @@
+"""Pytorch Densenet implementation w/ tweaks
+This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
+fixed kwargs passthrough and addition of dynamic global avg/max pool.
+"""
+import re
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch.jit.annotations import List
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier
+from .registry import register_model
+
+__all__ = ['DenseNet']
+
+
+def _cfg(url=''):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'features.conv0', 'classifier': 'classifier',
+ }
+
+
+default_cfgs = {
+ 'densenet121': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
+ 'densenet121d': _cfg(url=''),
+ 'densenetblur121d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
+ 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
+ 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
+ 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
+ 'densenet264': _cfg(url=''),
+ 'densenet264d_iabn': _cfg(url=''),
+ 'tv_densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
+}
+
+
+class DenseLayer(nn.Module):
+ def __init__(self, num_input_features, growth_rate, bn_size, norm_layer=BatchNormAct2d,
+ drop_rate=0., memory_efficient=False):
+ super(DenseLayer, self).__init__()
+ self.add_module('norm1', norm_layer(num_input_features)),
+ self.add_module('conv1', nn.Conv2d(
+ num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
+ self.add_module('norm2', norm_layer(bn_size * growth_rate)),
+ self.add_module('conv2', nn.Conv2d(
+ bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
+ self.drop_rate = float(drop_rate)
+ self.memory_efficient = memory_efficient
+
+ def bottleneck_fn(self, xs):
+ # type: (List[torch.Tensor]) -> torch.Tensor
+ concated_features = torch.cat(xs, 1)
+ bottleneck_output = self.conv1(self.norm1(concated_features)) # noqa: T484
+ return bottleneck_output
+
+ # todo: rewrite when torchscript supports any
+ def any_requires_grad(self, x):
+ # type: (List[torch.Tensor]) -> bool
+ for tensor in x:
+ if tensor.requires_grad:
+ return True
+ return False
+
+ @torch.jit.unused # noqa: T484
+ def call_checkpoint_bottleneck(self, x):
+ # type: (List[torch.Tensor]) -> torch.Tensor
+ def closure(*xs):
+ return self.bottleneck_fn(xs)
+
+ return cp.checkpoint(closure, *x)
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (List[torch.Tensor]) -> (torch.Tensor)
+ pass
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (torch.Tensor) -> (torch.Tensor)
+ pass
+
+ # torchscript does not yet support *args, so we overload method
+ # allowing it to take either a List[Tensor] or single Tensor
+ def forward(self, x): # noqa: F811
+ if isinstance(x, torch.Tensor):
+ prev_features = [x]
+ else:
+ prev_features = x
+
+ if self.memory_efficient and self.any_requires_grad(prev_features):
+ if torch.jit.is_scripting():
+ raise Exception("Memory Efficient not supported in JIT")
+ bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+ else:
+ bottleneck_output = self.bottleneck_fn(prev_features)
+
+ new_features = self.conv2(self.norm2(bottleneck_output))
+ if self.drop_rate > 0:
+ new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
+ return new_features
+
+
+class DenseBlock(nn.ModuleDict):
+ _version = 2
+
+ def __init__(self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=nn.ReLU,
+ drop_rate=0., memory_efficient=False):
+ super(DenseBlock, self).__init__()
+ for i in range(num_layers):
+ layer = DenseLayer(
+ num_input_features + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ norm_layer=norm_layer,
+ drop_rate=drop_rate,
+ memory_efficient=memory_efficient,
+ )
+ self.add_module('denselayer%d' % (i + 1), layer)
+
+ def forward(self, init_features):
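+        # dense connectivity: each layer receives the concatenation of all preceding feature maps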
+ features = [init_features]
+ for name, layer in self.items():
+ new_features = layer(features)
+ features.append(new_features)
+ return torch.cat(features, 1)
+
+
+class DenseTransition(nn.Sequential):
+ def __init__(self, num_input_features, num_output_features, norm_layer=nn.BatchNorm2d, aa_layer=None):
+ super(DenseTransition, self).__init__()
+ self.add_module('norm', norm_layer(num_input_features))
+ self.add_module('conv', nn.Conv2d(
+ num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
+ if aa_layer is not None:
+ self.add_module('pool', aa_layer(num_output_features, stride=2))
+ else:
+ self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
+
+
+class DenseNet(nn.Module):
+ r"""Densenet-BC model class, based on
+ `"Densely Connected Convolutional Networks" `_
+
+ Args:
+ growth_rate (int) - how many filters to add each layer (`k` in paper)
+ block_config (list of 4 ints) - how many layers in each pooling block
+ bn_size (int) - multiplicative factor for number of bottle neck layers
+ (i.e. bn_size * k features in the bottleneck layer)
+ drop_rate (float) - dropout rate after each dense layer
+ num_classes (int) - number of classification classes
+ memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
+          but slower. Default: *False*. See `"paper" <https://arxiv.org/abs/1707.06990>`_
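+
+    A minimal usage sketch (via the registered entrypoints defined later in this file):
+        model = densenet121(pretrained=False)
+        logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000)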
+ """
+
+ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), bn_size=4, stem_type='',
+ num_classes=1000, in_chans=3, global_pool='avg',
+ norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, memory_efficient=False,
+ aa_stem_only=True):
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ super(DenseNet, self).__init__()
+
+ # Stem
+ deep_stem = 'deep' in stem_type # 3x3 deep stem
+ num_init_features = growth_rate * 2
+ if aa_layer is None:
+ stem_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ else:
+ stem_pool = nn.Sequential(*[
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+ aa_layer(channels=num_init_features, stride=2)])
+ if deep_stem:
+ stem_chs_1 = stem_chs_2 = growth_rate
+ if 'tiered' in stem_type:
+ stem_chs_1 = 3 * (growth_rate // 4)
+ stem_chs_2 = num_init_features if 'narrow' in stem_type else 6 * (growth_rate // 4)
+ self.features = nn.Sequential(OrderedDict([
+ ('conv0', nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False)),
+ ('norm0', norm_layer(stem_chs_1)),
+ ('conv1', nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False)),
+ ('norm1', norm_layer(stem_chs_2)),
+ ('conv2', nn.Conv2d(stem_chs_2, num_init_features, 3, stride=1, padding=1, bias=False)),
+ ('norm2', norm_layer(num_init_features)),
+ ('pool0', stem_pool),
+ ]))
+ else:
+ self.features = nn.Sequential(OrderedDict([
+ ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
+ ('norm0', norm_layer(num_init_features)),
+ ('pool0', stem_pool),
+ ]))
+ self.feature_info = [
+ dict(num_chs=num_init_features, reduction=2, module=f'features.norm{2 if deep_stem else 0}')]
+ current_stride = 4
+
+ # DenseBlocks
+ num_features = num_init_features
+ for i, num_layers in enumerate(block_config):
+ block = DenseBlock(
+ num_layers=num_layers,
+ num_input_features=num_features,
+ bn_size=bn_size,
+ growth_rate=growth_rate,
+ norm_layer=norm_layer,
+ drop_rate=drop_rate,
+ memory_efficient=memory_efficient
+ )
+ module_name = f'denseblock{(i + 1)}'
+ self.features.add_module(module_name, block)
+ num_features = num_features + num_layers * growth_rate
+ transition_aa_layer = None if aa_stem_only else aa_layer
+ if i != len(block_config) - 1:
+ self.feature_info += [
+ dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
+ current_stride *= 2
+ trans = DenseTransition(
+ num_input_features=num_features, num_output_features=num_features // 2,
+ norm_layer=norm_layer, aa_layer=transition_aa_layer)
+ self.features.add_module(f'transition{i + 1}', trans)
+ num_features = num_features // 2
+
+ # Final batch norm
+ self.features.add_module('norm5', norm_layer(num_features))
+
+ self.feature_info += [dict(num_chs=num_features, reduction=current_stride, module='features.norm5')]
+ self.num_features = num_features
+
+ # Linear layer
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ # Official init from torch repo.
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.constant_(m.bias, 0)
+
+ def get_classifier(self):
+ return self.classifier
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ return self.features(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ # both classifier and block drop?
+ # if self.drop_rate > 0.:
+ # x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.classifier(x)
+ return x
+
+
+def _filter_torchvision_pretrained(state_dict):
+ pattern = re.compile(
+ r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
+
+ for key in list(state_dict.keys()):
+ res = pattern.match(key)
+ if res:
+ new_key = res.group(1) + res.group(2)
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+ return state_dict
+
+
+def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
+ kwargs['growth_rate'] = growth_rate
+ kwargs['block_config'] = block_config
+ return build_model_with_cfg(
+ DenseNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
+ **kwargs)
+
+
+@register_model
+def densenet121(pretrained=False, **kwargs):
+ r"""Densenet-121 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def densenetblur121d(pretrained=False, **kwargs):
+ r"""Densenet-121 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
+ aa_layer=BlurPool2d, **kwargs)
+ return model
+
+
+@register_model
+def densenet121d(pretrained=False, **kwargs):
+ r"""Densenet-121 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def densenet169(pretrained=False, **kwargs):
+ r"""Densenet-169 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenet169', growth_rate=32, block_config=(6, 12, 32, 32), pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def densenet201(pretrained=False, **kwargs):
+ r"""Densenet-201 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenet201', growth_rate=32, block_config=(6, 12, 48, 32), pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def densenet161(pretrained=False, **kwargs):
+ r"""Densenet-161 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenet161', growth_rate=48, block_config=(6, 12, 36, 24), pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def densenet264(pretrained=False, **kwargs):
+ r"""Densenet-264 model from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def densenet264d_iabn(pretrained=False, **kwargs):
+ r"""Densenet-264 model with deep stem and Inplace-ABN
+ """
+ def norm_act_fn(num_features, **kwargs):
+ return create_norm_act('iabn', num_features, **kwargs)
+ model = _create_densenet(
+ 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
+ norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tv_densenet121(pretrained=False, **kwargs):
+ r"""Densenet-121 model with original Torchvision weights, from
+ `"Densely Connected Convolutional Networks" `
+ """
+ model = _create_densenet(
+ 'tv_densenet121', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, **kwargs)
+ return model
diff --git a/timm/models/dla.py b/timm/models/dla.py
new file mode 100644
index 0000000..f6e4dd2
--- /dev/null
+++ b/timm/models/dla.py
@@ -0,0 +1,443 @@
+""" Deep Layer Aggregation and DLA w/ Res2Net
+DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla
+DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484
+
+Res2Net additions from: https://github.com/gasvn/Res2Net/
+Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
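+
+A minimal usage sketch (via the registered entrypoints defined below):
+
+    import torch
+    model = dla34(pretrained=False)
+    logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000)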
+"""
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import create_classifier
+from .registry import register_model
+
+__all__ = ['DLA']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'base_layer.0', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'dla34': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla34-ba72cf86.pth'),
+ 'dla46_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46_c-2bfd52c3.pth'),
+ 'dla46x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla46x_c-d761bae7.pth'),
+ 'dla60x_c': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x_c-b870c45c.pth'),
+ 'dla60': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60-24839fc4.pth'),
+ 'dla60x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla60x-d15cacda.pth'),
+ 'dla102': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102-d94d9790.pth'),
+ 'dla102x': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x-ad62be81.pth'),
+ 'dla102x2': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla102x2-262837b6.pth'),
+ 'dla169': _cfg(url='http://dl.yf.io/dla/models/imagenet/dla169-0914e092.pth'),
+ 'dla60_res2net': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net_dla60_4s-d88db7f9.pth'),
+ 'dla60_res2next': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next_dla60_4s-d327927b.pth'),
+}
+
+
+class DlaBasic(nn.Module):
+ """DLA Basic"""
+
+ def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
+ super(DlaBasic, self).__init__()
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(
+ planes, planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.stride = stride
+
+ def forward(self, x, shortcut=None):
+ if shortcut is None:
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ out += shortcut
+ out = self.relu(out)
+
+ return out
+
+
+class DlaBottleneck(nn.Module):
+ """DLA/DLA-X Bottleneck"""
+ expansion = 2
+
+ def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64):
+ super(DlaBottleneck, self).__init__()
+ self.stride = stride
+ mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
+ mid_planes = mid_planes // self.expansion
+
+ self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(mid_planes)
+ self.conv2 = nn.Conv2d(
+ mid_planes, mid_planes, kernel_size=3, stride=stride, padding=dilation,
+ bias=False, dilation=dilation, groups=cardinality)
+ self.bn2 = nn.BatchNorm2d(mid_planes)
+ self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(outplanes)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x, shortcut=None):
+ if shortcut is None:
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += shortcut
+ out = self.relu(out)
+
+ return out
+
+
+class DlaBottle2neck(nn.Module):
+ """ Res2Net/Res2NeXT DLA Bottleneck
+ Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
+ """
+ expansion = 2
+
+ def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4):
+ super(DlaBottle2neck, self).__init__()
+ self.is_first = stride > 1
+ self.scale = scale
+ mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
+ mid_planes = mid_planes // self.expansion
+ self.width = mid_planes
+
+ self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(mid_planes * scale)
+
+ num_scale_convs = max(1, scale - 1)
+ convs = []
+ bns = []
+ for _ in range(num_scale_convs):
+ convs.append(nn.Conv2d(
+ mid_planes, mid_planes, kernel_size=3, stride=stride,
+ padding=dilation, dilation=dilation, groups=cardinality, bias=False))
+ bns.append(nn.BatchNorm2d(mid_planes))
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ if self.is_first:
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
+
+ self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(outplanes)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x, shortcut=None):
+ if shortcut is None:
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
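+        # Res2Net split: process channel groups hierarchically; each group after the first reuses the
+        # previous group's output, except in the stride-2 (is_first) case where groups are independent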
+ spx = torch.split(out, self.width, 1)
+ spo = []
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
+ sp = spx[i] if i == 0 or self.is_first else sp + spx[i]
+ sp = conv(sp)
+ sp = bn(sp)
+ sp = self.relu(sp)
+ spo.append(sp)
+ if self.scale > 1:
+ spo.append(self.pool(spx[-1]) if self.is_first else spx[-1])
+ out = torch.cat(spo, 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ out += shortcut
+ out = self.relu(out)
+
+ return out
+
+
+class DlaRoot(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, shortcut):
+ super(DlaRoot, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
+ self.bn = nn.BatchNorm2d(out_channels)
+ self.relu = nn.ReLU(inplace=True)
+ self.shortcut = shortcut
+
+ def forward(self, *x):
+ children = x
+ x = self.conv(torch.cat(x, 1))
+ x = self.bn(x)
+ if self.shortcut:
+ x += children[0]
+ x = self.relu(x)
+
+ return x
+
+
+class DlaTree(nn.Module):
+ def __init__(self, levels, block, in_channels, out_channels, stride=1,
+ dilation=1, cardinality=1, base_width=64,
+ level_root=False, root_dim=0, root_kernel_size=1, root_shortcut=False):
+ super(DlaTree, self).__init__()
+ if root_dim == 0:
+ root_dim = 2 * out_channels
+ if level_root:
+ root_dim += in_channels
+ self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
+ self.project = nn.Identity()
+ cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width)
+ if levels == 1:
+ self.tree1 = block(in_channels, out_channels, stride, **cargs)
+ self.tree2 = block(out_channels, out_channels, 1, **cargs)
+ if in_channels != out_channels:
+                # NOTE the official impl/weights have project layers in the levels > 1 case that are never
+                # used; the project layer has been moved here to avoid wasted params, but old checkpoints
+                # will need strict=False when loading.
+ self.project = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+ nn.BatchNorm2d(out_channels))
+ else:
+ cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
+ self.tree1 = DlaTree(
+ levels - 1, block, in_channels, out_channels, stride, root_dim=0, **cargs)
+ self.tree2 = DlaTree(
+ levels - 1, block, out_channels, out_channels, root_dim=root_dim + out_channels, **cargs)
+ if levels == 1:
+ self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
+ self.level_root = level_root
+ self.root_dim = root_dim
+ self.levels = levels
+
+ def forward(self, x, shortcut=None, children=None):
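+        # recursive aggregation: at the leaf level tree1/tree2 are blocks and a DlaRoot fuses their
+        # outputs with any children accumulated from shallower levels; otherwise recurse into sub-trees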
+ children = [] if children is None else children
+ bottom = self.downsample(x)
+ shortcut = self.project(bottom)
+ if self.level_root:
+ children.append(bottom)
+ x1 = self.tree1(x, shortcut)
+ if self.levels == 1:
+ x2 = self.tree2(x1)
+ x = self.root(x2, x1, *children)
+ else:
+ children.append(x1)
+ x = self.tree2(x1, children=children)
+ return x
+
+
+class DLA(nn.Module):
+ def __init__(self, levels, channels, output_stride=32, num_classes=1000, in_chans=3,
+ cardinality=1, base_width=64, block=DlaBottle2neck, shortcut_root=False,
+ drop_rate=0.0, global_pool='avg'):
+ super(DLA, self).__init__()
+ self.channels = channels
+ self.num_classes = num_classes
+ self.cardinality = cardinality
+ self.base_width = base_width
+ self.drop_rate = drop_rate
+ assert output_stride == 32 # FIXME support dilation
+
+ self.base_layer = nn.Sequential(
+ nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
+ nn.BatchNorm2d(channels[0]),
+ nn.ReLU(inplace=True))
+ self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
+ self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
+ cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
+ self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
+ self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
+ self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
+ self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
+ self.feature_info = [
+ dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level
+ dict(num_chs=channels[1], reduction=2, module='level1'),
+ dict(num_chs=channels[2], reduction=4, module='level2'),
+ dict(num_chs=channels[3], reduction=8, module='level3'),
+ dict(num_chs=channels[4], reduction=16, module='level4'),
+ dict(num_chs=channels[5], reduction=32, module='level5'),
+ ]
+
+ self.num_features = channels[-1]
+ self.global_pool, self.fc = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
+ modules = []
+ for i in range(convs):
+ modules.extend([
+ nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride if i == 0 else 1,
+ padding=dilation, bias=False, dilation=dilation),
+ nn.BatchNorm2d(planes),
+ nn.ReLU(inplace=True)])
+ inplanes = planes
+ return nn.Sequential(*modules)
+
+ def get_classifier(self):
+ return self.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.base_layer(x)
+ x = self.level0(x)
+ x = self.level1(x)
+ x = self.level2(x)
+ x = self.level3(x)
+ x = self.level4(x)
+ x = self.level5(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.fc(x)
+ x = self.flatten(x)
+ return x
+
+
+def _create_dla(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ DLA, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_strict=False,
+ feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
+ **kwargs)
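+
+# NOTE (illustrative sketch, not upstream text): `pretrained_strict=False` above pairs with the
+# relocated projection layer mentioned in DlaTree -- older checkpoints keep that projection under
+# a different parent module, so keys may not match exactly. A minimal sketch of the equivalent
+# manual loading, assuming a locally loaded `state_dict`:
+#
+#   model = dla34(pretrained=False)
+#   missing, unexpected = model.load_state_dict(state_dict, strict=False)
+#   # `missing` / `unexpected` list the mismatched keys instead of raising an error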
+
+
+@register_model
+def dla60_res2net(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
+ block=DlaBottle2neck, cardinality=1, base_width=28, **kwargs)
+ return _create_dla('dla60_res2net', pretrained, **model_kwargs)
+
+
+@register_model
+def dla60_res2next(pretrained=False,**kwargs):
+ model_kwargs = dict(
+ levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
+ block=DlaBottle2neck, cardinality=8, base_width=4, **kwargs)
+ return _create_dla('dla60_res2next', pretrained, **model_kwargs)
+
+
+@register_model
+def dla34(pretrained=False, **kwargs): # DLA-34
+ model_kwargs = dict(
+ levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512],
+ block=DlaBasic, **kwargs)
+ return _create_dla('dla34', pretrained, **model_kwargs)
+
+
+@register_model
+def dla46_c(pretrained=False, **kwargs): # DLA-46-C
+ model_kwargs = dict(
+ levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
+ block=DlaBottleneck, **kwargs)
+ return _create_dla('dla46_c', pretrained, **model_kwargs)
+
+
+@register_model
+def dla46x_c(pretrained=False, **kwargs): # DLA-X-46-C
+ model_kwargs = dict(
+ levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
+ block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
+ return _create_dla('dla46x_c', pretrained, **model_kwargs)
+
+
+@register_model
+def dla60x_c(pretrained=False, **kwargs): # DLA-X-60-C
+ model_kwargs = dict(
+ levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
+ block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
+ return _create_dla('dla60x_c', pretrained, **model_kwargs)
+
+
+@register_model
+def dla60(pretrained=False, **kwargs): # DLA-60
+ model_kwargs = dict(
+ levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
+ block=DlaBottleneck, **kwargs)
+ return _create_dla('dla60', pretrained, **model_kwargs)
+
+
+@register_model
+def dla60x(pretrained=False, **kwargs): # DLA-X-60
+ model_kwargs = dict(
+ levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
+ block=DlaBottleneck, cardinality=32, base_width=4, **kwargs)
+ return _create_dla('dla60x', pretrained, **model_kwargs)
+
+
+@register_model
+def dla102(pretrained=False, **kwargs): # DLA-102
+ model_kwargs = dict(
+ levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
+ block=DlaBottleneck, shortcut_root=True, **kwargs)
+ return _create_dla('dla102', pretrained, **model_kwargs)
+
+
+@register_model
+def dla102x(pretrained=False, **kwargs): # DLA-X-102
+ model_kwargs = dict(
+ levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
+ block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True, **kwargs)
+ return _create_dla('dla102x', pretrained, **model_kwargs)
+
+
+@register_model
+def dla102x2(pretrained=False, **kwargs): # DLA-X-102 64
+ model_kwargs = dict(
+ levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
+ block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True, **kwargs)
+ return _create_dla('dla102x2', pretrained, **model_kwargs)
+
+
+@register_model
+def dla169(pretrained=False, **kwargs): # DLA-169
+ model_kwargs = dict(
+ levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
+ block=DlaBottleneck, shortcut_root=True, **kwargs)
+ return _create_dla('dla169', pretrained, **model_kwargs)
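+
+# NOTE minimal usage sketch (assuming the surrounding `timm` package layout); the functions above
+# are also exposed through the registry, so either form works:
+#
+#   import torch
+#   from timm import create_model
+#
+#   model = create_model('dla34', pretrained=False, num_classes=10)
+#   out = model(torch.randn(1, 3, 224, 224))   # -> logits of shape (1, 10)
+#
+#   # or call the constructor directly
+#   model = dla60x(pretrained=False)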
diff --git a/timm/models/dpn.py b/timm/models/dpn.py
new file mode 100644
index 0000000..c4e380b
--- /dev/null
+++ b/timm/models/dpn.py
@@ -0,0 +1,317 @@
+""" PyTorch implementation of DualPathNetworks
+Based on original MXNet implementation https://github.com/cypw/DPNs with
+many ideas from another PyTorch implementation https://github.com/oyam/pytorch-DPNs.
+
+This implementation is compatible with the pretrained weights from cypw's MXNet implementation.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from collections import OrderedDict
+from functools import partial
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier
+from .registry import register_model
+
+__all__ = ['DPN']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
+ 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'dpn68': _cfg(
+ url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
+ 'dpn68b': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'dpn92': _cfg(
+ url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
+ 'dpn98': _cfg(
+ url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
+ 'dpn131': _cfg(
+ url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
+ 'dpn107': _cfg(
+ url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
+}
+
+
+class CatBnAct(nn.Module):
+ def __init__(self, in_chs, norm_layer=BatchNormAct2d):
+ super(CatBnAct, self).__init__()
+ self.bn = norm_layer(in_chs, eps=0.001)
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (Tuple[torch.Tensor, torch.Tensor]) -> (torch.Tensor)
+ pass
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (torch.Tensor) -> (torch.Tensor)
+ pass
+
+ def forward(self, x):
+ if isinstance(x, tuple):
+ x = torch.cat(x, dim=1)
+ return self.bn(x)
+
+
+class BnActConv2d(nn.Module):
+ def __init__(self, in_chs, out_chs, kernel_size, stride, groups=1, norm_layer=BatchNormAct2d):
+ super(BnActConv2d, self).__init__()
+ self.bn = norm_layer(in_chs, eps=0.001)
+ self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, groups=groups)
+
+ def forward(self, x):
+ return self.conv(self.bn(x))
+
+
+class DualPathBlock(nn.Module):
+ def __init__(
+ self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
+ super(DualPathBlock, self).__init__()
+ self.num_1x1_c = num_1x1_c
+ self.inc = inc
+ self.b = b
+ if block_type == 'proj':
+ self.key_stride = 1
+ self.has_proj = True
+ elif block_type == 'down':
+ self.key_stride = 2
+ self.has_proj = True
+ else:
+ assert block_type == 'normal'
+ self.key_stride = 1
+ self.has_proj = False
+
+ self.c1x1_w_s1 = None
+ self.c1x1_w_s2 = None
+ if self.has_proj:
+ # Using different member names here to allow easier parameter key matching for conversion
+ if self.key_stride == 2:
+ self.c1x1_w_s2 = BnActConv2d(
+ in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=2)
+ else:
+ self.c1x1_w_s1 = BnActConv2d(
+ in_chs=in_chs, out_chs=num_1x1_c + 2 * inc, kernel_size=1, stride=1)
+
+ self.c1x1_a = BnActConv2d(in_chs=in_chs, out_chs=num_1x1_a, kernel_size=1, stride=1)
+ self.c3x3_b = BnActConv2d(
+ in_chs=num_1x1_a, out_chs=num_3x3_b, kernel_size=3, stride=self.key_stride, groups=groups)
+ if b:
+ self.c1x1_c = CatBnAct(in_chs=num_3x3_b)
+ self.c1x1_c1 = create_conv2d(num_3x3_b, num_1x1_c, kernel_size=1)
+ self.c1x1_c2 = create_conv2d(num_3x3_b, inc, kernel_size=1)
+ else:
+ self.c1x1_c = BnActConv2d(in_chs=num_3x3_b, out_chs=num_1x1_c + inc, kernel_size=1, stride=1)
+ self.c1x1_c1 = None
+ self.c1x1_c2 = None
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
+ pass
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+ pass
+
+ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(x, tuple):
+ x_in = torch.cat(x, dim=1)
+ else:
+ x_in = x
+ if self.c1x1_w_s1 is None and self.c1x1_w_s2 is None:
+ # self.has_proj == False, torchscript requires condition on module == None
+ x_s1 = x[0]
+ x_s2 = x[1]
+ else:
+ # self.has_proj == True
+ if self.c1x1_w_s1 is not None:
+ # self.key_stride = 1
+ x_s = self.c1x1_w_s1(x_in)
+ else:
+ # self.key_stride = 2
+ x_s = self.c1x1_w_s2(x_in)
+ x_s1 = x_s[:, :self.num_1x1_c, :, :]
+ x_s2 = x_s[:, self.num_1x1_c:, :, :]
+ x_in = self.c1x1_a(x_in)
+ x_in = self.c3x3_b(x_in)
+ x_in = self.c1x1_c(x_in)
+ if self.c1x1_c1 is not None:
+ # self.b == True, using None check for torchscript compat
+ out1 = self.c1x1_c1(x_in)
+ out2 = self.c1x1_c2(x_in)
+ else:
+ out1 = x_in[:, :self.num_1x1_c, :, :]
+ out2 = x_in[:, self.num_1x1_c:, :, :]
+ resid = x_s1 + out1
+ dense = torch.cat([x_s2, out2], dim=1)
+ return resid, dense
+
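+# NOTE DualPathBlock maintains the two DPN paths side by side: a ResNet-style residual path of
+# fixed width `num_1x1_c` (`resid = x_s1 + out1`) and a DenseNet-style path that concatenates
+# `inc` new channels per block (`dense = cat([x_s2, out2])`), so the dense side of the input
+# grows by `inc` at every 'normal' block within a stage.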
+
+class DPN(nn.Module):
+ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
+ b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
+ num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU):
+ super(DPN, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ self.b = b
+ assert output_stride == 32 # FIXME look into dilation support
+ norm_layer = partial(BatchNormAct2d, eps=.001)
+ fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False)
+ bw_factor = 1 if small else 4
+ blocks = OrderedDict()
+
+ # conv1
+ blocks['conv1_1'] = ConvBnAct(
+ in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
+ blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
+
+ # conv2
+ bw = 64 * bw_factor
+ inc = inc_sec[0]
+ r = (k_r * bw) // (64 * bw_factor)
+ blocks['conv2_1'] = DualPathBlock(num_init_features, r, r, bw, inc, groups, 'proj', b)
+ in_chs = bw + 3 * inc
+ for i in range(2, k_sec[0] + 1):
+ blocks['conv2_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
+ in_chs += inc
+ self.feature_info += [dict(num_chs=in_chs, reduction=4, module=f'features.conv2_{k_sec[0]}')]
+
+ # conv3
+ bw = 128 * bw_factor
+ inc = inc_sec[1]
+ r = (k_r * bw) // (64 * bw_factor)
+ blocks['conv3_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
+ in_chs = bw + 3 * inc
+ for i in range(2, k_sec[1] + 1):
+ blocks['conv3_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
+ in_chs += inc
+ self.feature_info += [dict(num_chs=in_chs, reduction=8, module=f'features.conv3_{k_sec[1]}')]
+
+ # conv4
+ bw = 256 * bw_factor
+ inc = inc_sec[2]
+ r = (k_r * bw) // (64 * bw_factor)
+ blocks['conv4_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
+ in_chs = bw + 3 * inc
+ for i in range(2, k_sec[2] + 1):
+ blocks['conv4_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
+ in_chs += inc
+ self.feature_info += [dict(num_chs=in_chs, reduction=16, module=f'features.conv4_{k_sec[2]}')]
+
+ # conv5
+ bw = 512 * bw_factor
+ inc = inc_sec[3]
+ r = (k_r * bw) // (64 * bw_factor)
+ blocks['conv5_1'] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'down', b)
+ in_chs = bw + 3 * inc
+ for i in range(2, k_sec[3] + 1):
+ blocks['conv5_' + str(i)] = DualPathBlock(in_chs, r, r, bw, inc, groups, 'normal', b)
+ in_chs += inc
+ self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
+
+ blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer)
+
+ self.num_features = in_chs
+ self.features = nn.Sequential(blocks)
+
+ # Using 1x1 conv for the FC layer to allow the extra pooling scheme
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
+ def get_classifier(self):
+ return self.classifier
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
+ def forward_features(self, x):
+ return self.features(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.classifier(x)
+ x = self.flatten(x)
+ return x
+
+
+def _create_dpn(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ DPN, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(feature_concat=True, flatten_sequential=True),
+ **kwargs)
+
+
+@register_model
+def dpn68(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ small=True, num_init_features=10, k_r=128, groups=32,
+ k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
+ return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def dpn68b(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ small=True, num_init_features=10, k_r=128, groups=32,
+ b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
+ return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def dpn92(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ num_init_features=64, k_r=96, groups=32,
+ k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
+ return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def dpn98(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ num_init_features=96, k_r=160, groups=40,
+ k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
+ return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def dpn131(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ num_init_features=128, k_r=160, groups=40,
+ k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
+ return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def dpn107(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ num_init_features=128, k_r=200, groups=50,
+ k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
+ return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
new file mode 100644
index 0000000..b1c570b
--- /dev/null
+++ b/timm/models/efficientnet.py
@@ -0,0 +1,2286 @@
+""" The EfficientNet Family in PyTorch
+
+An implementation of EfficientNet that covers a variety of related models with efficient architectures:
+
+* EfficientNet-V2
+ - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+
+* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
+ - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
+ - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
+ - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
+ - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
+
+* MixNet (Small, Medium, and Large)
+ - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
+
+* MNasNet B1, A1 (SE), Small
+ - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626
+
+* FBNet-C
+ - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443
+
+* Single-Path NAS Pixel1
+ - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877
+
+* TinyNet
+ - Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets - https://arxiv.org/abs/2010.14819
+ - Definitions & weights borrowed from https://github.com/huawei-noah/CV-Backbones/tree/master/tinynet_pytorch
+
+* And likely more...
+
+The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available
+by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing
+the models and weights open source!
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from functools import partial
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .efficientnet_blocks import SqueezeExcite
+from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
+ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
+from .features import FeatureInfo, FeatureHooks
+from .helpers import build_model_with_cfg, default_cfg_for_features
+from .layers import create_conv2d, create_classifier
+from .registry import register_model
+
+__all__ = ['EfficientNet', 'EfficientNetFeatures']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv_stem', 'classifier': 'classifier',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'mnasnet_050': _cfg(url=''),
+ 'mnasnet_075': _cfg(url=''),
+ 'mnasnet_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'),
+ 'mnasnet_140': _cfg(url=''),
+
+ 'semnasnet_050': _cfg(url=''),
+ 'semnasnet_075': _cfg(url=''),
+ 'semnasnet_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
+ 'semnasnet_140': _cfg(url=''),
+ 'mnasnet_small': _cfg(url=''),
+
+ 'mobilenetv2_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth'),
+ 'mobilenetv2_110d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth'),
+ 'mobilenetv2_120d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth'),
+ 'mobilenetv2_140': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth'),
+
+ 'fbnetc_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
+ interpolation='bilinear'),
+ 'spnasnet_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
+ interpolation='bilinear'),
+
+ # NOTE experimenting with alternate attention
+ 'efficientnet_b0': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
+ 'efficientnet_b1': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
+ test_input_size=(3, 256, 256), crop_pct=1.0),
+ 'efficientnet_b2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), crop_pct=1.0),
+ 'efficientnet_b3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth',
+ input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
+ 'efficientnet_b4': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth',
+ input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), crop_pct=1.0),
+ 'efficientnet_b5': _cfg(
+ url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+ 'efficientnet_b6': _cfg(
+ url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
+ 'efficientnet_b7': _cfg(
+ url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+ 'efficientnet_b8': _cfg(
+ url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+ 'efficientnet_l2': _cfg(
+ url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
+
+ 'efficientnet_es': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
+ 'efficientnet_em': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth',
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+ 'efficientnet_el': _cfg(
+ url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el.pth',
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+
+ 'efficientnet_es_pruned': _cfg(
+ url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_es_pruned75.pth'),
+ 'efficientnet_el_pruned': _cfg(
+ url='https://github.com/DeGirum/pruned-models/releases/download/efficientnet_v1.0/efficientnet_el_pruned70.pth',
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+
+ 'efficientnet_cc_b0_4e': _cfg(url=''),
+ 'efficientnet_cc_b0_8e': _cfg(url=''),
+ 'efficientnet_cc_b1_8e': _cfg(url='', input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+
+ 'efficientnet_lite0': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth'),
+ 'efficientnet_lite1': _cfg(
+ url='',
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+ 'efficientnet_lite2': _cfg(
+ url='',
+ input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
+ 'efficientnet_lite3': _cfg(
+ url='',
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+ 'efficientnet_lite4': _cfg(
+ url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
+
+ 'efficientnet_b1_pruned': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb1_pruned_9ebb3fe6.pth',
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'efficientnet_b2_pruned': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb2_pruned_203f55bc.pth',
+ input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'efficientnet_b3_pruned': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45403/outputs/effnetb3_pruned_5abcc29f.pth',
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+
+ 'efficientnetv2_rw_t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_t_agc-3620981a.pth',
+ input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0),
+ 'gc_efficientnetv2_rw_t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gc_efficientnetv2_rw_t_agc-927a0bde.pth',
+ input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0),
+ 'efficientnetv2_rw_s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
+ input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
+ 'efficientnetv2_rw_m': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth',
+ input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
+
+ 'efficientnetv2_s': _cfg(
+ url='',
+ input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
+ 'efficientnetv2_m': _cfg(
+ url='',
+ input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
+ 'efficientnetv2_l': _cfg(
+ url='',
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+ 'efficientnetv2_xl': _cfg(
+ url='',
+ input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
+
+ 'tf_efficientnet_b0': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
+ input_size=(3, 224, 224)),
+ 'tf_efficientnet_b1': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+ 'tf_efficientnet_b2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
+ input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
+ 'tf_efficientnet_b3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+ 'tf_efficientnet_b4': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
+ input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
+ 'tf_efficientnet_b5': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
+ input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+ 'tf_efficientnet_b6': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
+ input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
+ 'tf_efficientnet_b7': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
+ input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+ 'tf_efficientnet_b8': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
+ input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+
+ 'tf_efficientnet_b0_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)),
+ 'tf_efficientnet_b1_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+ 'tf_efficientnet_b2_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
+ 'tf_efficientnet_b3_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+ 'tf_efficientnet_b4_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
+ 'tf_efficientnet_b5_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+ 'tf_efficientnet_b6_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
+ 'tf_efficientnet_b7_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+ 'tf_efficientnet_b8_ap': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
+
+ 'tf_efficientnet_b0_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
+ input_size=(3, 224, 224)),
+ 'tf_efficientnet_b1_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+ 'tf_efficientnet_b2_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
+ input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
+ 'tf_efficientnet_b3_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+ 'tf_efficientnet_b4_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
+ input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
+ 'tf_efficientnet_b5_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
+ input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
+ 'tf_efficientnet_b6_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
+ input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
+ 'tf_efficientnet_b7_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
+ input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
+ 'tf_efficientnet_l2_ns_475': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
+ input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
+ 'tf_efficientnet_l2_ns': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
+ input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
+
+ 'tf_efficientnet_es': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 224, 224), ),
+ 'tf_efficientnet_em': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+ 'tf_efficientnet_el': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
+
+ 'tf_efficientnet_cc_b0_4e': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_efficientnet_cc_b0_8e': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_efficientnet_cc_b1_8e': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
+
+ 'tf_efficientnet_lite0': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
+ ),
+ 'tf_efficientnet_lite1': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882,
+ interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
+ ),
+ 'tf_efficientnet_lite2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890,
+ interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
+ ),
+ 'tf_efficientnet_lite3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'),
+ 'tf_efficientnet_lite4': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
+
+ 'tf_efficientnetv2_s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+ 'tf_efficientnetv2_m': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+ 'tf_efficientnetv2_l': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+
+ 'tf_efficientnetv2_s_in21ft1k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+ 'tf_efficientnetv2_m_in21ft1k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+ 'tf_efficientnetv2_l_in21ft1k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+ 'tf_efficientnetv2_xl_in21ft1k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
+
+ 'tf_efficientnetv2_s_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+ input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
+ 'tf_efficientnetv2_m_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+ 'tf_efficientnetv2_l_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+ input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
+ 'tf_efficientnetv2_xl_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
+ input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
+
+ 'tf_efficientnetv2_b0': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth',
+ input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)),
+ 'tf_efficientnetv2_b1': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth',
+ input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882),
+ 'tf_efficientnetv2_b2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth',
+ input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890),
+ 'tf_efficientnetv2_b3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth',
+ input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904),
+
+ 'mixnet_s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth'),
+ 'mixnet_m': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
+ 'mixnet_l': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
+ 'mixnet_xl': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth'),
+ 'mixnet_xxl': _cfg(),
+
+ 'tf_mixnet_s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
+ 'tf_mixnet_m': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'),
+ 'tf_mixnet_l': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
+
+ "tinynet_a": _cfg(
+ input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86)
+ url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth'),
+ "tinynet_b": _cfg(
+ input_size=(3, 188, 188), pool_size=(6, 6), # int(224 * 0.84)
+ url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth'),
+ "tinynet_c": _cfg(
+ input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825)
+ url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth'),
+ "tinynet_d": _cfg(
+ input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68)
+ url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth'),
+ "tinynet_e": _cfg(
+ input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475)
+ url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth'),
+}
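+
+# NOTE illustrative sketch (assuming the standard timm data helpers): the cfg fields above
+# (input_size, crop_pct, interpolation, mean/std) are meant to be resolved into an eval transform
+# rather than read by the model itself, e.g.:
+#
+#   import timm
+#   from timm.data import resolve_data_config, create_transform
+#   model = timm.create_model('efficientnet_b3', pretrained=False)
+#   config = resolve_data_config({}, model=model)   # pulls defaults from model.default_cfg
+#   transform = create_transform(**config)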
+
+
+class EfficientNet(nn.Module):
+ """ (Generic) EfficientNet
+
+ A flexible and performant PyTorch implementation of efficient network architectures, including:
+ * EfficientNet-V2 Small, Medium, Large, XL & B0-B3
+ * EfficientNet B0-B8, L2
+ * EfficientNet-EdgeTPU
+ * EfficientNet-CondConv
+ * MixNet S, M, L, XL
+ * MnasNet A1, B1, and small
+ * FBNet C
+ * Single-Path NAS Pixel1
+
+ """
+
+ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32, fix_stem=False,
+ output_stride=32, pad_type='', round_chs_fn=round_channels, act_layer=None, norm_layer=None,
+ se_layer=None, drop_rate=0., drop_path_rate=0., global_pool='avg'):
+ super(EfficientNet, self).__init__()
+ act_layer = act_layer or nn.ReLU
+ norm_layer = norm_layer or nn.BatchNorm2d
+ se_layer = se_layer or SqueezeExcite
+ self.num_classes = num_classes
+ self.num_features = num_features
+ self.drop_rate = drop_rate
+
+ # Stem
+ if not fix_stem:
+ stem_size = round_chs_fn(stem_size)
+ self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+ self.bn1 = norm_layer(stem_size)
+ self.act1 = act_layer(inplace=True)
+
+ # Middle stages (IR/ER/DS Blocks)
+ builder = EfficientNetBuilder(
+ output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
+ act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
+ self.blocks = nn.Sequential(*builder(stem_size, block_args))
+ self.feature_info = builder.features
+ head_chs = builder.in_chs
+
+ # Head + Pooling
+ self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
+ self.bn2 = norm_layer(self.num_features)
+ self.act2 = act_layer(inplace=True)
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ efficientnet_init_weights(self)
+
+ def as_sequential(self):
+ layers = [self.conv_stem, self.bn1, self.act1]
+ layers.extend(self.blocks)
+ layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool])
+ layers.extend([nn.Dropout(self.drop_rate), self.classifier])
+ return nn.Sequential(*layers)
+
+ def get_classifier(self):
+ return self.classifier
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x = self.conv_stem(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.blocks(x)
+ x = self.conv_head(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ return self.classifier(x)
+
+
+class EfficientNetFeatures(nn.Module):
+ """ EfficientNet Feature Extractor
+
+ A work-in-progress feature extraction module for EfficientNet, intended for use as a backbone in
+ segmentation and object detection models.
+ """
+
+ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
+ stem_size=32, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
+ act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
+ super(EfficientNetFeatures, self).__init__()
+ act_layer = act_layer or nn.ReLU
+ norm_layer = norm_layer or nn.BatchNorm2d
+ se_layer = se_layer or SqueezeExcite
+ self.drop_rate = drop_rate
+
+ # Stem
+ if not fix_stem:
+ stem_size = round_chs_fn(stem_size)
+ self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+ self.bn1 = norm_layer(stem_size)
+ self.act1 = act_layer(inplace=True)
+
+ # Middle stages (IR/ER/DS Blocks)
+ builder = EfficientNetBuilder(
+ output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
+ act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate,
+ feature_location=feature_location)
+ self.blocks = nn.Sequential(*builder(stem_size, block_args))
+ self.feature_info = FeatureInfo(builder.features, out_indices)
+ self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
+
+ efficientnet_init_weights(self)
+
+ # Register feature extraction hooks with FeatureHooks helper
+ self.feature_hooks = None
+ if feature_location != 'bottleneck':
+ hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
+ self.feature_hooks = FeatureHooks(hooks, self.named_modules())
+
+ def forward(self, x) -> List[torch.Tensor]:
+ x = self.conv_stem(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ if self.feature_hooks is None:
+ features = []
+ if 0 in self._stage_out_idx:
+ features.append(x) # add stem out
+ for i, b in enumerate(self.blocks):
+ x = b(x)
+ if i + 1 in self._stage_out_idx:
+ features.append(x)
+ return features
+ else:
+ self.blocks(x)
+ out = self.feature_hooks.get_output(x.device)
+ return list(out.values())
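+
+ # NOTE two extraction modes are supported above: 'bottleneck' feature locations index stage
+ # outputs directly while iterating self.blocks, whereas any other feature_location registers
+ # forward hooks in __init__ and collects their outputs via FeatureHooks after a plain
+ # self.blocks(x) pass.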
+
+
+def _create_effnet(variant, pretrained=False, **kwargs):
+ features_only = False
+ model_cls = EfficientNet
+ kwargs_filter = None
+ if kwargs.pop('features_only', False):
+ features_only = True
+ kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
+ model_cls = EfficientNetFeatures
+ model = build_model_with_cfg(
+ model_cls, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_strict=not features_only,
+ kwargs_filter=kwargs_filter,
+ **kwargs)
+ if features_only:
+ model.default_cfg = default_cfg_for_features(model.default_cfg)
+ return model
+
+
+def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a mnasnet-a1 model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+ Paper: https://arxiv.org/pdf/1807.11626.pdf.
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16_noskip'],
+ # stage 1, 112x112 in
+ ['ir_r2_k3_s2_e6_c24'],
+ # stage 2, 56x56 in
+ ['ir_r3_k5_s2_e3_c40_se0.25'],
+ # stage 3, 28x28 in
+ ['ir_r4_k3_s2_e6_c80'],
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112_se0.25'],
+ # stage 5, 14x14in
+ ['ir_r3_k5_s2_e6_c160_se0.25'],
+ # stage 6, 7x7 in
+ ['ir_r1_k3_s1_e6_c320'],
+ ]
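+ # NOTE on the block-string notation (sketch of the encoding consumed by decode_arch_def,
+ # e.g. 'ir_r3_k5_s2_e3_c40_se0.25' above): ir/ds/er/cn select the block type
+ # (inverted residual, depthwise-separable, edge residual, conv-bn-act), r = repeats,
+ # k = kernel size, s = stride, e = expansion ratio, c = output channels,
+ # se = squeeze-excite ratio, and 'noskip' disables the skip connection.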
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ stem_size=32,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a mnasnet-b1 model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+ Paper: https://arxiv.org/pdf/1807.11626.pdf.
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_c16_noskip'],
+ # stage 1, 112x112 in
+ ['ir_r3_k3_s2_e3_c24'],
+ # stage 2, 56x56 in
+ ['ir_r3_k5_s2_e3_c40'],
+ # stage 3, 28x28 in
+ ['ir_r3_k5_s2_e6_c80'],
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c96'],
+ # stage 5, 14x14in
+ ['ir_r4_k5_s2_e6_c192'],
+ # stage 6, 7x7 in
+ ['ir_r1_k3_s1_e6_c320_noskip']
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ stem_size=32,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a mnasnet-b1 model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
+ Paper: https://arxiv.org/pdf/1807.11626.pdf.
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ arch_def = [
+ ['ds_r1_k3_s1_c8'],
+ ['ir_r1_k3_s2_e3_c16'],
+ ['ir_r2_k3_s2_e6_c16'],
+ ['ir_r4_k5_s2_e6_c32_se0.25'],
+ ['ir_r3_k3_s1_e6_c32_se0.25'],
+ ['ir_r3_k5_s2_e6_c88_se0.25'],
+ ['ir_r1_k3_s1_e6_c144']
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ stem_size=8,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_mobilenet_v2(
+ variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
+ """ Generate MobileNet-V2 network
+ Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
+ Paper: https://arxiv.org/abs/1801.04381
+ """
+ arch_def = [
+ ['ds_r1_k3_s1_c16'],
+ ['ir_r2_k3_s2_e6_c24'],
+ ['ir_r3_k3_s2_e6_c32'],
+ ['ir_r4_k3_s2_e6_c64'],
+ ['ir_r3_k3_s1_e6_c96'],
+ ['ir_r3_k3_s2_e6_c160'],
+ ['ir_r1_k3_s1_e6_c320'],
+ ]
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
+ num_features=1280 if fix_stem_head else round_chs_fn(1280),
+ stem_size=32,
+ fix_stem=fix_stem_head,
+ round_chs_fn=round_chs_fn,
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'relu6'),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """ FBNet-C
+
+ Paper: https://arxiv.org/abs/1812.03443
+ Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
+
+ NOTE: the ref impl above does not define the 'C' variant used here; that variant was derived from the paper,
+ and the reference impl was only used to confirm some building block details
+ """
+ arch_def = [
+ ['ir_r1_k3_s1_e1_c16'],
+ ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
+ ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
+ ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
+ ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
+ ['ir_r4_k5_s2_e6_c184'],
+ ['ir_r1_k3_s1_e6_c352'],
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ stem_size=16,
+ num_features=1984, # paper suggests this, but is not 100% clear
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates the Single-Path NAS model from search targeted for Pixel1 phone.
+
+ Paper: https://arxiv.org/abs/1904.02877
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_c16_noskip'],
+ # stage 1, 112x112 in
+ ['ir_r3_k3_s2_e3_c24'],
+ # stage 2, 56x56 in
+ ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
+ # stage 3, 28x28 in
+ ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
+ # stage 4, 14x14in
+ ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
+ # stage 5, 14x14in
+ ['ir_r4_k5_s2_e6_c192'],
+ # stage 6, 7x7 in
+ ['ir_r1_k3_s1_e6_c320_noskip']
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ stem_size=32,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates an EfficientNet model.
+
+ Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+ Paper: https://arxiv.org/abs/1905.11946
+
+ EfficientNet params
+ name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+ 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
+ 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
+ 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
+ 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
+ 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
+ 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
+ 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
+ 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
+ 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
+ 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer
+ depth_multiplier: multiplier to number of repeats per stage
+
+ """
+ arch_def = [
+ ['ds_r1_k3_s1_e1_c16_se0.25'],
+ ['ir_r2_k3_s2_e6_c24_se0.25'],
+ ['ir_r2_k5_s2_e6_c40_se0.25'],
+ ['ir_r3_k3_s2_e6_c80_se0.25'],
+ ['ir_r3_k5_s1_e6_c112_se0.25'],
+ ['ir_r4_k5_s2_e6_c192_se0.25'],
+ ['ir_r1_k3_s1_e6_c320_se0.25'],
+ ]
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=round_chs_fn(1280),
+ stem_size=32,
+ round_chs_fn=round_chs_fn,
+ act_layer=resolve_act_layer(kwargs, 'swish'),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """ Creates an EfficientNet-EdgeTPU model
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
+ """
+
+ arch_def = [
+ # NOTE the `fc` (forced in_chs) option is present to override a mismatch between stem channels and block
+ # in_chs that is not present in the other models
+ ['er_r1_k3_s1_e4_c24_fc24_noskip'],
+ ['er_r2_k3_s2_e8_c32'],
+ ['er_r4_k3_s2_e8_c48'],
+ ['ir_r5_k5_s2_e8_c96'],
+ ['ir_r4_k5_s1_e8_c144'],
+ ['ir_r2_k5_s2_e8_c192'],
+ ]
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=round_chs_fn(1280),
+ stem_size=32,
+ round_chs_fn=round_chs_fn,
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'relu'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnet_condconv(
+ variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
+ """Creates an EfficientNet-CondConv model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
+ """
+ arch_def = [
+ ['ds_r1_k3_s1_e1_c16_se0.25'],
+ ['ir_r2_k3_s2_e6_c24_se0.25'],
+ ['ir_r2_k5_s2_e6_c40_se0.25'],
+ ['ir_r3_k3_s2_e6_c80_se0.25'],
+ ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
+ ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
+ ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
+ ]
+    # NOTE unlike the official impl, this one uses a `cc<x>` option where x is the base number of experts for each
+    # stage and experts_multiplier scales that on a per-model basis, as with the depth/channel multipliers
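+    # e.g. (rough illustration) 'ir_r3_k5_s1_e6_c112_se0.25_cc4' with experts_multiplier=2 ends up with
+    # 8 experts per CondConv block; the actual multiplication happens inside decode_arch_def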
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
+ num_features=round_chs_fn(1280),
+ stem_size=32,
+ round_chs_fn=round_chs_fn,
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'swish'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates an EfficientNet-Lite model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
+ Paper: https://arxiv.org/abs/1905.11946
+
+ EfficientNet params
+ name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
+ 'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
+ 'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
+ 'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
+ 'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
+ 'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer
+ depth_multiplier: multiplier to number of repeats per stage
+ """
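+    # Sketch of the Lite differences vs the base EfficientNets above: no squeeze-excite in the block defs,
+    # ReLU6 activations, a fixed 32-channel stem and 1280-feature head (not scaled by channel_multiplier),
+    # and fix_first_last which (as I read decode_arch_def) keeps the first and last stages out of the
+    # depth scaling.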
+ arch_def = [
+ ['ds_r1_k3_s1_e1_c16'],
+ ['ir_r2_k3_s2_e6_c24'],
+ ['ir_r2_k5_s2_e6_c40'],
+ ['ir_r3_k3_s2_e6_c80'],
+ ['ir_r3_k5_s1_e6_c112'],
+ ['ir_r4_k5_s2_e6_c192'],
+ ['ir_r1_k3_s1_e6_c320'],
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
+ num_features=1280,
+ stem_size=32,
+ fix_stem=True,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ act_layer=resolve_act_layer(kwargs, 'relu6'),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnetv2_base(
+ variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """ Creates an EfficientNet-V2 base model
+
+ Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+ Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+ """
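+    # Sketch of the V2 layout below: early stages use plain conv ('cn') and Fused-MBConv ('er') blocks,
+    # later stages use MBConv ('ir') with squeeze-excite; round_limit=0. disables the usual floor that keeps
+    # rounded channel counts from shrinking too far (as I read round_channels).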
+ arch_def = [
+ ['cn_r1_k3_s1_e1_c16_skip'],
+ ['er_r2_k3_s2_e4_c32'],
+ ['er_r2_k3_s2_e4_c48'],
+ ['ir_r3_k3_s2_e4_c96_se0.25'],
+ ['ir_r5_k3_s1_e6_c112_se0.25'],
+ ['ir_r8_k3_s2_e6_c192_se0.25'],
+ ]
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=round_chs_fn(1280),
+ stem_size=32,
+ round_chs_fn=round_chs_fn,
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'silu'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnetv2_s(
+ variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs):
+ """ Creates an EfficientNet-V2 Small model
+
+ Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+ Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+
+    NOTE: the `rw` flag sets up the 'small' variant to behave like my initial v2 small model,
+    created before the reference impl was released.
+ """
+ arch_def = [
+ ['cn_r2_k3_s1_e1_c24_skip'],
+ ['er_r4_k3_s2_e4_c48'],
+ ['er_r4_k3_s2_e4_c64'],
+ ['ir_r6_k3_s2_e4_c128_se0.25'],
+ ['ir_r9_k3_s1_e6_c160_se0.25'],
+ ['ir_r15_k3_s2_e6_c256_se0.25'],
+ ]
+ num_features = 1280
+ if rw:
+ # my original variant, based on paper figure differs from the official release
+ arch_def[0] = ['er_r2_k3_s1_e1_c24']
+ arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25']
+ num_features = 1792
+
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=round_chs_fn(num_features),
+ stem_size=24,
+ round_chs_fn=round_chs_fn,
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'silu'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """ Creates an EfficientNet-V2 Medium model
+
+ Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+ Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+ """
+
+ arch_def = [
+ ['cn_r3_k3_s1_e1_c24_skip'],
+ ['er_r5_k3_s2_e4_c48'],
+ ['er_r5_k3_s2_e4_c80'],
+ ['ir_r7_k3_s2_e4_c160_se0.25'],
+ ['ir_r14_k3_s1_e6_c176_se0.25'],
+ ['ir_r18_k3_s2_e6_c304_se0.25'],
+ ['ir_r5_k3_s1_e6_c512_se0.25'],
+ ]
+
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=1280,
+ stem_size=24,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'silu'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """ Creates an EfficientNet-V2 Large model
+
+ Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+ Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+ """
+
+ arch_def = [
+ ['cn_r4_k3_s1_e1_c32_skip'],
+ ['er_r7_k3_s2_e4_c64'],
+ ['er_r7_k3_s2_e4_c96'],
+ ['ir_r10_k3_s2_e4_c192_se0.25'],
+ ['ir_r19_k3_s1_e6_c224_se0.25'],
+ ['ir_r25_k3_s2_e6_c384_se0.25'],
+ ['ir_r7_k3_s1_e6_c640_se0.25'],
+ ]
+
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=1280,
+ stem_size=32,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'silu'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """ Creates an EfficientNet-V2 Xtra-Large model
+
+ Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
+ Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
+ """
+
+ arch_def = [
+ ['cn_r4_k3_s1_e1_c32_skip'],
+ ['er_r8_k3_s2_e4_c64'],
+ ['er_r8_k3_s2_e4_c96'],
+ ['ir_r16_k3_s2_e4_c192_se0.25'],
+ ['ir_r24_k3_s1_e6_c256_se0.25'],
+ ['ir_r32_k3_s2_e6_c512_se0.25'],
+ ['ir_r8_k3_s1_e6_c640_se0.25'],
+ ]
+
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier),
+ num_features=1280,
+ stem_size=32,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'silu'),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MixNet Small model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
+ Paper: https://arxiv.org/abs/1907.09595
+ """
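+    # Notation sketch for the mixed-kernel defs below: 'k3.5.7' splits the depthwise conv across
+    # 3x3/5x5/7x7 kernel groups (MixConv); 'a1.1'/'p1.1' appear to set mixed kernel sizes for the
+    # expansion / pointwise convs; '_nsw' overrides the block activation to swish (hence the
+    # per-stage relu/swish comments).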
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
+ # stage 2, 56x56 in
+ ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
+ # stage 3, 28x28 in
+ ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
+        # stage 4, 14x14 in
+        ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
+        # stage 5, 14x14 in
+ ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
+ # 7x7
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ num_features=1536,
+ stem_size=16,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MixNet Medium-Large model.
+
+ Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
+ Paper: https://arxiv.org/abs/1907.09595
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c24'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
+ # stage 2, 56x56 in
+ ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
+ # stage 3, 28x28 in
+ ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
+        # stage 4, 14x14 in
+        ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'],  # swish
+        # stage 5, 14x14 in
+ ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
+ # 7x7
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
+ num_features=1536,
+ stem_size=24,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_tinynet(
+ variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs
+):
+ """Creates a TinyNet model.
+ """
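+    # TinyNet reuses the EfficientNet-B0 stage layout but rescales it with separate width (model_width) and
+    # depth multipliers (see the tinynet_a..e registrations below); the stem is kept fixed and the head
+    # feature width never drops below 1280.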
+ arch_def = [
+ ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'],
+ ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'],
+ ['ir_r3_k5_s1_e6_c112_se0.25'], ['ir_r4_k5_s2_e6_c192_se0.25'],
+ ['ir_r1_k3_s1_e6_c320_se0.25'],
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
+ num_features=max(1280, round_channels(1280, model_width, 8, None)),
+ stem_size=32,
+ fix_stem=True,
+ round_chs_fn=partial(round_channels, multiplier=model_width),
+ act_layer=resolve_act_layer(kwargs, 'swish'),
+ norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ **kwargs,
+ )
+ model = _create_effnet(variant, pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def mnasnet_050(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 0.5. """
+ model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mnasnet_075(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 0.75. """
+ model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mnasnet_100(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 1.0. """
+ model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mnasnet_b1(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 1.0. """
+ return mnasnet_100(pretrained, **kwargs)
+
+
+@register_model
+def mnasnet_140(pretrained=False, **kwargs):
+ """ MNASNet B1, depth multiplier of 1.4 """
+ model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def semnasnet_050(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """
+ model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def semnasnet_075(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """
+ model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def semnasnet_100(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
+ model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mnasnet_a1(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """
+ return semnasnet_100(pretrained, **kwargs)
+
+
+@register_model
+def semnasnet_140(pretrained=False, **kwargs):
+ """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """
+ model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mnasnet_small(pretrained=False, **kwargs):
+ """ MNASNet Small, depth multiplier of 1.0. """
+ model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv2_100(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.0 channel multiplier """
+ model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv2_140(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.4 channel multiplier """
+ model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv2_110d(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers"""
+ model = _gen_mobilenet_v2(
+ 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv2_120d(pretrained=False, **kwargs):
+ """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """
+ model = _gen_mobilenet_v2(
+ 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def fbnetc_100(pretrained=False, **kwargs):
+ """ FBNet-C """
+ if pretrained:
+ # pretrained model trained with non-default BN epsilon
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def spnasnet_100(pretrained=False, **kwargs):
+ """ Single-Path NAS Pixel1"""
+ model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b0(pretrained=False, **kwargs):
+ """ EfficientNet-B0 """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b1(pretrained=False, **kwargs):
+ """ EfficientNet-B1 """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b2(pretrained=False, **kwargs):
+ """ EfficientNet-B2 """
+ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b2a(pretrained=False, **kwargs):
+ """ EfficientNet-B2 @ 288x288 w/ 1.0 test crop"""
+ # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now
+ return efficientnet_b2(pretrained=pretrained, **kwargs)
+
+
+@register_model
+def efficientnet_b3(pretrained=False, **kwargs):
+ """ EfficientNet-B3 """
+ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b3a(pretrained=False, **kwargs):
+ """ EfficientNet-B3 @ 320x320 w/ 1.0 test crop-pct """
+ # WARN this model def is deprecated, different train/test res + test crop handled by default_cfg now
+ return efficientnet_b3(pretrained=pretrained, **kwargs)
+
+
+@register_model
+def efficientnet_b4(pretrained=False, **kwargs):
+ """ EfficientNet-B4 """
+ # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b5(pretrained=False, **kwargs):
+ """ EfficientNet-B5 """
+ # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b6(pretrained=False, **kwargs):
+ """ EfficientNet-B6 """
+ # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b7(pretrained=False, **kwargs):
+ """ EfficientNet-B7 """
+ # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b8(pretrained=False, **kwargs):
+ """ EfficientNet-B8 """
+ # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_l2(pretrained=False, **kwargs):
+ """ EfficientNet-L2."""
+ # NOTE for train, drop_rate should be 0.5, drop_path_rate should be 0.2
+ model = _gen_efficientnet(
+ 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_es(pretrained=False, **kwargs):
+ """ EfficientNet-Edge Small. """
+ model = _gen_efficientnet_edge(
+ 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_es_pruned(pretrained=False, **kwargs):
+ """ EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0"""
+ model = _gen_efficientnet_edge(
+ 'efficientnet_es_pruned', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_em(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Medium. """
+ model = _gen_efficientnet_edge(
+ 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_el(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Large. """
+ model = _gen_efficientnet_edge(
+ 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_el_pruned(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Large pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0"""
+ model = _gen_efficientnet_edge(
+ 'efficientnet_el_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_cc_b0_4e(pretrained=False, **kwargs):
+    """ EfficientNet-CondConv-B0 w/ 4 Experts """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet_condconv(
+ 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_cc_b0_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 8 Experts """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet_condconv(
+ 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_cc_b1_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B1 w/ 8 Experts """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet_condconv(
+ 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_lite0(pretrained=False, **kwargs):
+ """ EfficientNet-Lite0 """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_lite1(pretrained=False, **kwargs):
+ """ EfficientNet-Lite1 """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_lite2(pretrained=False, **kwargs):
+ """ EfficientNet-Lite2 """
+ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_lite3(pretrained=False, **kwargs):
+ """ EfficientNet-Lite3 """
+ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_lite4(pretrained=False, **kwargs):
+ """ EfficientNet-Lite4 """
+ # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+ model = _gen_efficientnet_lite(
+ 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b1_pruned(pretrained=False, **kwargs):
+ """ EfficientNet-B1 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ variant = 'efficientnet_b1_pruned'
+ model = _gen_efficientnet(
+ variant, channel_multiplier=1.0, depth_multiplier=1.1, pruned=True, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b2_pruned(pretrained=False, **kwargs):
+ """ EfficientNet-B2 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'efficientnet_b2_pruned', channel_multiplier=1.1, depth_multiplier=1.2, pruned=True,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnet_b3_pruned(pretrained=False, **kwargs):
+ """ EfficientNet-B3 Pruned. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'efficientnet_b3_pruned', channel_multiplier=1.2, depth_multiplier=1.4, pruned=True,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_rw_t(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Tiny (Custom variant, tiny not in paper). """
+ model = _gen_efficientnetv2_s(
+ 'efficientnetv2_rw_t', channel_multiplier=0.8, depth_multiplier=0.9, rw=False, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def gc_efficientnetv2_rw_t(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Tiny w/ Global Context Attn (Custom variant, tiny not in paper). """
+ model = _gen_efficientnetv2_s(
+ 'gc_efficientnetv2_rw_t', channel_multiplier=0.8, depth_multiplier=0.9,
+ rw=False, se_layer='gc', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_rw_s(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Small (RW variant).
+    NOTE: This is my initial (pre official code release) version w/ some differences.
+ See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding
+ """
+ model = _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_rw_m(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Medium (RW variant).
+ """
+ model = _gen_efficientnetv2_s(
+ 'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_s(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Small. """
+ model = _gen_efficientnetv2_s('efficientnetv2_s', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_m(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Medium. """
+ model = _gen_efficientnetv2_m('efficientnetv2_m', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_l(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Large. """
+ model = _gen_efficientnetv2_l('efficientnetv2_l', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def efficientnetv2_xl(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Xtra-Large. """
+ model = _gen_efficientnetv2_xl('efficientnetv2_xl', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b0(pretrained=False, **kwargs):
+ """ EfficientNet-B0. Tensorflow compatible variant """
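+    # The tf_* ports (here and below) keep the Tensorflow training setup: 'same' padding and the TF default
+    # BatchNorm epsilon, so that converted weights line up with the reference graphs.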
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b1(pretrained=False, **kwargs):
+ """ EfficientNet-B1. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b2(pretrained=False, **kwargs):
+ """ EfficientNet-B2. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b3(pretrained=False, **kwargs):
+ """ EfficientNet-B3. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b4(pretrained=False, **kwargs):
+ """ EfficientNet-B4. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b5(pretrained=False, **kwargs):
+ """ EfficientNet-B5. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b6(pretrained=False, **kwargs):
+ """ EfficientNet-B6. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b7(pretrained=False, **kwargs):
+ """ EfficientNet-B7. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b8(pretrained=False, **kwargs):
+ """ EfficientNet-B8. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b0_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B0 AdvProp. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b1_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B1 AdvProp. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b2_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B2 AdvProp. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b3_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B3 AdvProp. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b4_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B4 AdvProp. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b5_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B5 AdvProp. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b6_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B6 AdvProp. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b7_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B7 AdvProp. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b8_ap(pretrained=False, **kwargs):
+ """ EfficientNet-B8 AdvProp. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b0_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b1_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b2_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b3_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b4_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b5_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b6_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_b7_ns(pretrained=False, **kwargs):
+ """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs):
+ """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_l2_ns(pretrained=False, **kwargs):
+ """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.5
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet(
+ 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_es(pretrained=False, **kwargs):
+ """ EfficientNet-Edge Small. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_edge(
+ 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_em(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Medium. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_edge(
+ 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_el(pretrained=False, **kwargs):
+ """ EfficientNet-Edge-Large. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_edge(
+ 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_condconv(
+ 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_condconv(
+ 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
+ """ EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_condconv(
+ 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2,
+ pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_lite0(pretrained=False, **kwargs):
+    """ EfficientNet-Lite0. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_lite1(pretrained=False, **kwargs):
+    """ EfficientNet-Lite1. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_lite2(pretrained=False, **kwargs):
+    """ EfficientNet-Lite2. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_lite3(pretrained=False, **kwargs):
+    """ EfficientNet-Lite3. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnet_lite4(pretrained=False, **kwargs):
+    """ EfficientNet-Lite4. Tensorflow compatible variant """
+ # NOTE for train, drop_rate should be 0.4, drop_path_rate should be 0.2
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnet_lite(
+ 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_s(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Small. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_s('tf_efficientnetv2_s', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_m(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Medium. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_l(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Large. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_s_in21ft1k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Small. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21ft1k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_m_in21ft1k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Medium. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21ft1k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_l_in21ft1k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21ft1k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_xl_in21ft1k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Xtra-Large. Pretrained on ImageNet-21k, fine-tuned on 1k. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl_in21ft1k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_s_in21k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Small w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_s('tf_efficientnetv2_s_in21k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_m_in21k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Medium w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_m('tf_efficientnetv2_m_in21k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_l_in21k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_l('tf_efficientnetv2_l_in21k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_xl_in21k(pretrained=False, **kwargs):
+ """ EfficientNet-V2 Xtra-Large w/ ImageNet-21k pretrained weights. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl_in21k', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_b0(pretrained=False, **kwargs):
+ """ EfficientNet-V2-B0. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_b1(pretrained=False, **kwargs):
+ """ EfficientNet-V2-B1. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_base(
+ 'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_b2(pretrained=False, **kwargs):
+ """ EfficientNet-V2-B2. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_base(
+ 'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_efficientnetv2_b3(pretrained=False, **kwargs):
+ """ EfficientNet-V2-B3. Tensorflow compatible variant """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_efficientnetv2_base(
+ 'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mixnet_s(pretrained=False, **kwargs):
+ """Creates a MixNet Small model.
+ """
+ model = _gen_mixnet_s(
+ 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mixnet_m(pretrained=False, **kwargs):
+ """Creates a MixNet Medium model.
+ """
+ model = _gen_mixnet_m(
+ 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mixnet_l(pretrained=False, **kwargs):
+ """Creates a MixNet Large model.
+ """
+ model = _gen_mixnet_m(
+ 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mixnet_xl(pretrained=False, **kwargs):
+ """Creates a MixNet Extra-Large model.
+ Not a paper spec, experimental def by RW w/ depth scaling.
+ """
+ model = _gen_mixnet_m(
+ 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mixnet_xxl(pretrained=False, **kwargs):
+ """Creates a MixNet Double Extra Large model.
+ Not a paper spec, experimental def by RW w/ depth scaling.
+ """
+ model = _gen_mixnet_m(
+ 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mixnet_s(pretrained=False, **kwargs):
+ """Creates a MixNet Small model. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mixnet_s(
+ 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mixnet_m(pretrained=False, **kwargs):
+ """Creates a MixNet Medium model. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mixnet_m(
+ 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mixnet_l(pretrained=False, **kwargs):
+ """Creates a MixNet Large model. Tensorflow compatible variant
+ """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mixnet_m(
+ 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tinynet_a(pretrained=False, **kwargs):
+    """ TinyNet-A (width 1.0, depth 1.2) """
+ model = _gen_tinynet('tinynet_a', 1.0, 1.2, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tinynet_b(pretrained=False, **kwargs):
+    """ TinyNet-B (width 0.75, depth 1.1) """
+ model = _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tinynet_c(pretrained=False, **kwargs):
+    """ TinyNet-C (width 0.54, depth 0.85) """
+ model = _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tinynet_d(pretrained=False, **kwargs):
+    """ TinyNet-D (width 0.54, depth 0.695) """
+ model = _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tinynet_e(pretrained=False, **kwargs):
+    """ TinyNet-E (width 0.51, depth 0.6) """
+ model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
+ return model
diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
new file mode 100644
index 0000000..b1fec44
--- /dev/null
+++ b/timm/models/efficientnet_blocks.py
@@ -0,0 +1,323 @@
+""" EfficientNet, MobileNetV3, etc Blocks
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
+from .layers.activations import sigmoid
+
+__all__ = [
+ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
+
+
+class SqueezeExcite(nn.Module):
+ """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
+
+ Args:
+ in_chs (int): input channels to layer
+ rd_ratio (float): ratio of squeeze reduction
+ act_layer (nn.Module): activation layer of containing block
+ gate_layer (Callable): attention gate function
+ force_act_layer (nn.Module): override block's activation fn if this is set/bound
+ rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
+ """
+
+ def __init__(
+ self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
+ gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
+ super(SqueezeExcite, self).__init__()
+ if rd_channels is None:
+ rd_round_fn = rd_round_fn or round
+ rd_channels = rd_round_fn(in_chs * rd_ratio)
+ act_layer = force_act_layer or act_layer
+ self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
+ self.act1 = create_act_layer(act_layer, inplace=True)
+ self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
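+        # squeeze: global average pool over H, W -> (N, C, 1, 1); excite: 1x1 reduce/expand convs and a
+        # sigmoid-like gate that rescales the input channels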
+ x_se = x.mean((2, 3), keepdim=True)
+ x_se = self.conv_reduce(x_se)
+ x_se = self.act1(x_se)
+ x_se = self.conv_expand(x_se)
+ return x * self.gate(x_se)
+
+
+class ConvBnAct(nn.Module):
+ """ Conv + Norm Layer + Activation w/ optional skip connection
+ """
+ def __init__(
+ self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
+ skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
+ super(ConvBnAct, self).__init__()
+ self.has_residual = skip and stride == 1 and in_chs == out_chs
+ self.drop_path_rate = drop_path_rate
+ self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
+ self.bn1 = norm_layer(out_chs)
+ self.act1 = act_layer(inplace=True)
+
+ def feature_info(self, location):
+        if location == 'expansion':  # output of conv after act, same as block output
+ info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
+ else: # location == 'bottleneck', block output
+ info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
+ return info
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ if self.has_residual:
+ if self.drop_path_rate > 0.:
+ x = drop_path(x, self.drop_path_rate, self.training)
+ x += shortcut
+ return x
+
+
+class DepthwiseSeparableConv(nn.Module):
+ """ DepthwiseSeparable block
+    Used for DS convs in MobileNet-V1 and in place of IR blocks that have no expansion
+    (factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
+ """
+ def __init__(
+ self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+ noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+ se_layer=None, drop_path_rate=0.):
+ super(DepthwiseSeparableConv, self).__init__()
+ self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
+ self.has_pw_act = pw_act # activation after point-wise conv
+ self.drop_path_rate = drop_path_rate
+
+ self.conv_dw = create_conv2d(
+ in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
+ self.bn1 = norm_layer(in_chs)
+ self.act1 = act_layer(inplace=True)
+
+ # Squeeze-and-excitation
+ self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
+
+ self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
+ self.bn2 = norm_layer(out_chs)
+ self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
+
+ def feature_info(self, location):
+ if location == 'expansion': # after SE, input to PW
+ info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
+ else: # location == 'bottleneck', block output
+ info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
+ return info
+
+ def forward(self, x):
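+        # depthwise conv -> BN -> act, optional SE, then pointwise 1x1 conv -> BN (-> act only if pw_act),
+        # with an optional residual when the shape allows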
+ shortcut = x
+
+ x = self.conv_dw(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.se(x)
+
+ x = self.conv_pw(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ if self.has_residual:
+ if self.drop_path_rate > 0.:
+ x = drop_path(x, self.drop_path_rate, self.training)
+ x += shortcut
+ return x
+
+
+class InvertedResidual(nn.Module):
+ """ Inverted residual block w/ optional SE
+
+ Originally used in MobileNet-V2 - https://arxiv.org/abs/1801.04381v4, this layer is often
+    referred to as 'MBConv' (Mobile inverted bottleneck conv) and is also used in
+ * MNasNet - https://arxiv.org/abs/1807.11626
+ * EfficientNet - https://arxiv.org/abs/1905.11946
+ * MobileNet-V3 - https://arxiv.org/abs/1905.02244
+ """
+
+ def __init__(
+ self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+ noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+ norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
+ super(InvertedResidual, self).__init__()
+ conv_kwargs = conv_kwargs or {}
+ mid_chs = make_divisible(in_chs * exp_ratio)
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+ self.drop_path_rate = drop_path_rate
+
+ # Point-wise expansion
+ self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
+ self.bn1 = norm_layer(mid_chs)
+ self.act1 = act_layer(inplace=True)
+
+ # Depth-wise convolution
+ self.conv_dw = create_conv2d(
+ mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
+ padding=pad_type, depthwise=True, **conv_kwargs)
+ self.bn2 = norm_layer(mid_chs)
+ self.act2 = act_layer(inplace=True)
+
+ # Squeeze-and-excitation
+ self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
+
+ # Point-wise linear projection
+ self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
+ self.bn3 = norm_layer(out_chs)
+
+ def feature_info(self, location):
+ if location == 'expansion': # after SE, input to PWL
+ info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+ else: # location == 'bottleneck', block output
+ info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+ return info
+
+ def forward(self, x):
+ shortcut = x
+
+ # Point-wise expansion
+ x = self.conv_pw(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ # Depth-wise convolution
+ x = self.conv_dw(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ # Squeeze-and-excitation
+ x = self.se(x)
+
+ # Point-wise linear projection
+ x = self.conv_pwl(x)
+ x = self.bn3(x)
+
+ if self.has_residual:
+ if self.drop_path_rate > 0.:
+ x = drop_path(x, self.drop_path_rate, self.training)
+ x += shortcut
+
+ return x
+
+
+class CondConvResidual(InvertedResidual):
+ """ Inverted residual block w/ CondConv routing"""
+
+ def __init__(
+ self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+ noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
+ norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
+
+ self.num_experts = num_experts
+ conv_kwargs = dict(num_experts=self.num_experts)
+
+ super(CondConvResidual, self).__init__(
+ in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
+ act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
+ pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
+ drop_path_rate=drop_path_rate)
+
+ self.routing_fn = nn.Linear(in_chs, self.num_experts)
+
+ def forward(self, x):
+ shortcut = x
+
+ # CondConv routing
+ pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
+ routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
+
+ # Point-wise expansion
+ x = self.conv_pw(x, routing_weights)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ # Depth-wise convolution
+ x = self.conv_dw(x, routing_weights)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ # Squeeze-and-excitation
+ x = self.se(x)
+
+ # Point-wise linear projection
+ x = self.conv_pwl(x, routing_weights)
+ x = self.bn3(x)
+
+ if self.has_residual:
+ if self.drop_path_rate > 0.:
+ x = drop_path(x, self.drop_path_rate, self.training)
+ x += shortcut
+ return x
+
+
+class EdgeResidual(nn.Module):
+ """ Residual block with expansion convolution followed by pointwise-linear w/ stride
+
+ Originally introduced in `EfficientNet-EdgeTPU: Creating Accelerator-Optimized Neural Networks with AutoML`
+ - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html
+
+ This layer is also called FusedMBConv in the MobileDet, EfficientNet-X, and EfficientNet-V2 papers
+ * MobileDet - https://arxiv.org/abs/2004.14525
+ * EfficientNet-X - https://arxiv.org/abs/2102.05610
+ * EfficientNet-V2 - https://arxiv.org/abs/2104.00298
+ """
+
+ def __init__(
+ self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
+ force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
+ norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
+ super(EdgeResidual, self).__init__()
+ if force_in_chs > 0:
+ mid_chs = make_divisible(force_in_chs * exp_ratio)
+ else:
+ mid_chs = make_divisible(in_chs * exp_ratio)
+ self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
+ self.drop_path_rate = drop_path_rate
+
+ # Expansion convolution
+ self.conv_exp = create_conv2d(
+ in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
+ self.bn1 = norm_layer(mid_chs)
+ self.act1 = act_layer(inplace=True)
+
+ # Squeeze-and-excitation
+ self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
+
+ # Point-wise linear projection
+ self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
+ self.bn2 = norm_layer(out_chs)
+
+ def feature_info(self, location):
+ if location == 'expansion': # after SE, before PWL
+ info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+ else: # location == 'bottleneck', block output
+ info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
+ return info
+
+ def forward(self, x):
+ shortcut = x
+
+ # Expansion convolution
+ x = self.conv_exp(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ # Squeeze-and-excitation
+ x = self.se(x)
+
+ # Point-wise linear projection
+ x = self.conv_pwl(x)
+ x = self.bn2(x)
+
+ if self.has_residual:
+ if self.drop_path_rate > 0.:
+ x = drop_path(x, self.drop_path_rate, self.training)
+ x += shortcut
+
+ return x
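+
+
+# Illustrative usage sketch (not part of the upstream timm source): how the blocks above are
+# typically exercised. Assumes `torch` is imported at the top of this module; the channel
+# counts and input shape below are arbitrary examples.
+def _example_blocks():
+    x = torch.randn(1, 32, 56, 56)
+    ds = DepthwiseSeparableConv(32, 32, dw_kernel_size=3, stride=1)  # residual path (in == out, stride 1)
+    ir = InvertedResidual(32, 64, exp_ratio=6.0, stride=2)           # no residual (in != out, stride 2)
+    er = EdgeResidual(64, 64, exp_kernel_size=3, exp_ratio=4.0)      # FusedMBConv-style block
+    y = er(ir(ds(x)))
+    return y.shape  # torch.Size([1, 64, 28, 28])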
diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py
new file mode 100644
index 0000000..a23e827
--- /dev/null
+++ b/timm/models/efficientnet_builder.py
@@ -0,0 +1,463 @@
+""" EfficientNet, MobileNetV3, etc Builder
+
+Assembles EfficientNet and related network feature blocks from string definitions.
+Handles stride, dilation calculations, and selects feature extraction points.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import logging
+import math
+import re
+from copy import deepcopy
+from functools import partial
+
+import torch.nn as nn
+
+from .efficientnet_blocks import *
+from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
+
+__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+ 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
+
+_logger = logging.getLogger(__name__)
+
+
+_DEBUG_BUILDER = False
+
+# Defaults used for Google/Tensorflow training of mobile networks w/ RMSprop as per
+# papers and TF reference implementations. The PT momentum equivalent of TF decay is (1 - TF decay).
+# NOTE: momentum varies between .99 and .9997 depending on source
+# .99 in official TF TPU impl
+# .9997 (w/ .999 in search space) for paper
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+
+def get_bn_args_tf():
+ return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+ bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
+ bn_momentum = kwargs.pop('bn_momentum', None)
+ if bn_momentum is not None:
+ bn_args['momentum'] = bn_momentum
+ bn_eps = kwargs.pop('bn_eps', None)
+ if bn_eps is not None:
+ bn_args['eps'] = bn_eps
+ return bn_args
+
+
+def resolve_act_layer(kwargs, default='relu'):
+ return get_act_layer(kwargs.pop('act_layer', default))
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+ """Round number of filters based on depth multiplier."""
+ if not multiplier:
+ return channels
+ return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
+def _log_info_if(msg, condition):
+ if condition:
+ _logger.info(msg)
+
+
+def _parse_ksize(ss):
+ if ss.isdigit():
+ return int(ss)
+ else:
+ return [int(k) for k in ss.split('.')]
+
+
+def _decode_block_str(block_str):
+ """ Decode block definition string
+
+ Gets a list of block arg (dicts) through a string notation of arguments.
+ E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
+
+ All args can exist in any order with the exception of the leading string which
+ is assumed to indicate the block type.
+
+ leading string - block type (
+ ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
+ r - number of repeat blocks,
+ k - kernel size,
+ s - strides (1-9),
+ e - expansion ratio,
+ c - output channels,
+ se - squeeze/excitation ratio
+ n - activation fn ('re', 'r6', 'hs', 'sw', or 'mi')
+ Args:
+ block_str: a string representation of block arguments.
+ Returns:
+ A list of block args (dicts)
+ Raises:
+ ValueError: if the string def is not properly specified (TODO)
+ """
+ assert isinstance(block_str, str)
+ ops = block_str.split('_')
+ block_type = ops[0] # take the block type off the front
+ ops = ops[1:]
+ options = {}
+ skip = None
+ for op in ops:
+ # string options being checked on individual basis, combine if they grow
+ if op == 'noskip':
+ skip = False # force no skip connection
+ elif op == 'skip':
+ skip = True # force a skip connection
+ elif op.startswith('n'):
+ # activation fn
+ key = op[0]
+ v = op[1:]
+ if v == 're':
+ value = get_act_layer('relu')
+ elif v == 'r6':
+ value = get_act_layer('relu6')
+ elif v == 'hs':
+ value = get_act_layer('hard_swish')
+ elif v == 'sw':
+ value = get_act_layer('swish') # aka SiLU
+ elif v == 'mi':
+ value = get_act_layer('mish')
+ else:
+ continue
+ options[key] = value
+ else:
+ # all numeric options
+ splits = re.split(r'(\d.*)', op)
+ if len(splits) >= 2:
+ key, value = splits[:2]
+ options[key] = value
+
+ # if act_layer is None, the model default (passed to model init) will be used
+ act_layer = options['n'] if 'n' in options else None
+ exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+ pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+ force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
+
+ num_repeat = int(options['r'])
+ # each type of block has different valid arguments, fill accordingly
+ if block_type == 'ir':
+ block_args = dict(
+ block_type=block_type,
+ dw_kernel_size=_parse_ksize(options['k']),
+ exp_kernel_size=exp_kernel_size,
+ pw_kernel_size=pw_kernel_size,
+ out_chs=int(options['c']),
+ exp_ratio=float(options['e']),
+ se_ratio=float(options['se']) if 'se' in options else 0.,
+ stride=int(options['s']),
+ act_layer=act_layer,
+ noskip=skip is False,
+ )
+ if 'cc' in options:
+ block_args['num_experts'] = int(options['cc'])
+ elif block_type == 'ds' or block_type == 'dsa':
+ block_args = dict(
+ block_type=block_type,
+ dw_kernel_size=_parse_ksize(options['k']),
+ pw_kernel_size=pw_kernel_size,
+ out_chs=int(options['c']),
+ se_ratio=float(options['se']) if 'se' in options else 0.,
+ stride=int(options['s']),
+ act_layer=act_layer,
+ pw_act=block_type == 'dsa',
+ noskip=block_type == 'dsa' or skip is False,
+ )
+ elif block_type == 'er':
+ block_args = dict(
+ block_type=block_type,
+ exp_kernel_size=_parse_ksize(options['k']),
+ pw_kernel_size=pw_kernel_size,
+ out_chs=int(options['c']),
+ exp_ratio=float(options['e']),
+ force_in_chs=force_in_chs,
+ se_ratio=float(options['se']) if 'se' in options else 0.,
+ stride=int(options['s']),
+ act_layer=act_layer,
+ noskip=skip is False,
+ )
+ elif block_type == 'cn':
+ block_args = dict(
+ block_type=block_type,
+ kernel_size=int(options['k']),
+ out_chs=int(options['c']),
+ stride=int(options['s']),
+ act_layer=act_layer,
+ skip=skip is True,
+ )
+ else:
+ assert False, 'Unknown block type (%s)' % block_type
+
+ return block_args, num_repeat
+
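+
+# Hedged example of the notation documented above (an illustrative sketch, not upstream code):
+# 'ir_r2_k3_s2_e6_c24_se0.25' decodes to an InvertedResidual spec repeated twice, with a 3x3
+# depthwise kernel, stride 2, expansion ratio 6, 24 output channels and an SE ratio of 0.25.
+def _example_decode_block_str():
+    ba, rep = _decode_block_str('ir_r2_k3_s2_e6_c24_se0.25')
+    assert rep == 2
+    assert ba['block_type'] == 'ir' and ba['dw_kernel_size'] == 3
+    assert ba['stride'] == 2 and ba['exp_ratio'] == 6.0
+    assert ba['out_chs'] == 24 and ba['se_ratio'] == 0.25
+    return ba
+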
+
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+ """ Per-stage depth scaling
+ Scales the block repeats in each stage. This depth scaling impl maintains
+ compatibility with the EfficientNet scaling method, while allowing sensible
+ scaling for other models that may have multiple block arg definitions in each stage.
+ """
+
+ # We scale the total repeat count for each stage, there may be multiple
+ # block arg defs per stage so we need to sum.
+ num_repeat = sum(repeats)
+ if depth_trunc == 'round':
+ # Truncating to int by rounding allows stages with few repeats to remain
+ # proportionally smaller for longer. This is a good choice when stage definitions
+ # include single repeat stages that we'd prefer to keep that way as long as possible
+ num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+ else:
+ # The default for EfficientNet truncates repeats to int via 'ceil'.
+ # Any multiplier > 1.0 will result in an increased depth for every stage.
+ num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+ # Proportionally distribute repeat count scaling to each block definition in the stage.
+ # Allocation is done in reverse as it results in the first block being less likely to be scaled.
+ # The first block makes less sense to repeat in most of the arch definitions.
+ repeats_scaled = []
+ for r in repeats[::-1]:
+ rs = max(1, round((r / num_repeat * num_repeat_scaled)))
+ repeats_scaled.append(rs)
+ num_repeat -= r
+ num_repeat_scaled -= rs
+ repeats_scaled = repeats_scaled[::-1]
+
+ # Apply the calculated scaling to each block arg in the stage
+ sa_scaled = []
+ for ba, rep in zip(stack_args, repeats_scaled):
+ sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+ return sa_scaled
+
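+
+# Worked example of the scaling above (illustrative sketch, using the default 'ceil' trunc):
+# with repeats [1, 2] and depth_multiplier 2.0, the stage total 3 becomes ceil(3 * 2.0) = 6,
+# redistributed in reverse as round(2/3 * 6) = 4 then round(1/1 * 2) = 2, i.e. the first
+# block definition ends up repeated 2x and the second 4x.
+def _example_scale_stage_depth():
+    scaled = _scale_stage_depth([dict(block_type='a'), dict(block_type='b')], [1, 2], 2.0)
+    assert [ba['block_type'] for ba in scaled] == ['a', 'a', 'b', 'b', 'b', 'b']
+    return scaled
+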
+
+def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
+ arch_args = []
+ if isinstance(depth_multiplier, tuple):
+ assert len(depth_multiplier) == len(arch_def)
+ else:
+ depth_multiplier = (depth_multiplier,) * len(arch_def)
+ for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
+ assert isinstance(block_strings, list)
+ stack_args = []
+ repeats = []
+ for block_str in block_strings:
+ assert isinstance(block_str, str)
+ ba, rep = _decode_block_str(block_str)
+ if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
+ ba['num_experts'] *= experts_multiplier
+ stack_args.append(ba)
+ repeats.append(rep)
+ if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
+ arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
+ else:
+ arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
+ return arch_args
+
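+
+# Illustrative sketch (not upstream code): a minimal two-stage arch definition and how
+# decode_arch_def expands it, using the string notation documented in _decode_block_str.
+def _example_decode_arch_def():
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16'],           # stage 0: one depthwise-separable block
+        ['ir_r2_k3_s2_e6_c24_se0.25'],    # stage 1: two MBConv blocks, the first w/ stride 2
+    ]
+    stages = decode_arch_def(arch_def, depth_multiplier=1.0)
+    assert len(stages) == 2 and len(stages[0]) == 1 and len(stages[1]) == 2
+    return stages
+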
+
+class EfficientNetBuilder:
+ """ Build Trunk Blocks
+
+ This ended up being somewhat of a cross between
+ https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
+ and
+ https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
+
+ """
+ def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
+ act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
+ self.output_stride = output_stride
+ self.pad_type = pad_type
+ self.round_chs_fn = round_chs_fn
+ self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs
+ self.act_layer = act_layer
+ self.norm_layer = norm_layer
+ self.se_layer = get_attn(se_layer)
+ try:
+ self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg
+ self.se_has_ratio = True
+ except TypeError:
+ self.se_has_ratio = False
+ self.drop_path_rate = drop_path_rate
+ if feature_location == 'depthwise':
+ # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
+ _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
+ feature_location = 'expansion'
+ self.feature_location = feature_location
+ assert feature_location in ('bottleneck', 'expansion', '')
+ self.verbose = _DEBUG_BUILDER
+
+ # state updated during build, consumed by model
+ self.in_chs = None
+ self.features = []
+
+ def _make_block(self, ba, block_idx, block_count):
+ drop_path_rate = self.drop_path_rate * block_idx / block_count
+ bt = ba.pop('block_type')
+ ba['in_chs'] = self.in_chs
+ ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
+ if 'force_in_chs' in ba and ba['force_in_chs']:
+ # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
+ ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
+ ba['pad_type'] = self.pad_type
+ # block act fn overrides the model default
+ ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
+ assert ba['act_layer'] is not None
+ ba['norm_layer'] = self.norm_layer
+ ba['drop_path_rate'] = drop_path_rate
+ if bt != 'cn':
+ se_ratio = ba.pop('se_ratio')
+ if se_ratio and self.se_layer is not None:
+ if not self.se_from_exp:
+ # adjust se_ratio by expansion ratio if calculating se channels from block input
+ se_ratio /= ba.get('exp_ratio', 1.0)
+ if self.se_has_ratio:
+ ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+ else:
+ ba['se_layer'] = self.se_layer
+
+ if bt == 'ir':
+ _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
+ elif bt == 'ds' or bt == 'dsa':
+ _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = DepthwiseSeparableConv(**ba)
+ elif bt == 'er':
+ _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = EdgeResidual(**ba)
+ elif bt == 'cn':
+ _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = ConvBnAct(**ba)
+ else:
+ assert False, 'Unknown block type (%s) while building model.' % bt
+
+ self.in_chs = ba['out_chs'] # update in_chs for arg of next block
+ return block
+
+ def __call__(self, in_chs, model_block_args):
+ """ Build the blocks
+ Args:
+ in_chs: Number of input-channels passed to first block
+ model_block_args: A list of lists, outer list defines stages, inner
+ list contains strings defining block configuration(s)
+ Return:
+ List of block stacks (each stack wrapped in nn.Sequential)
+ """
+ _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
+ self.in_chs = in_chs
+ total_block_count = sum([len(x) for x in model_block_args])
+ total_block_idx = 0
+ current_stride = 2
+ current_dilation = 1
+ stages = []
+ if model_block_args[0][0]['stride'] > 1:
+ # if the first block starts with a stride, we need to extract first level feat from stem
+ feature_info = dict(
+ module='act1', num_chs=in_chs, stage=0, reduction=current_stride,
+ hook_type='forward' if self.feature_location != 'bottleneck' else '')
+ self.features.append(feature_info)
+
+ # outer list of block_args defines the stacks
+ for stack_idx, stack_args in enumerate(model_block_args):
+ last_stack = stack_idx + 1 == len(model_block_args)
+ _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
+ assert isinstance(stack_args, list)
+
+ blocks = []
+ # each stack (stage of blocks) contains a list of block arguments
+ for block_idx, block_args in enumerate(stack_args):
+ last_block = block_idx + 1 == len(stack_args)
+ _log_info_if(' Block: {}'.format(block_idx), self.verbose)
+
+ assert block_args['stride'] in (1, 2)
+ if block_idx >= 1: # only the first block in any stack can have a stride > 1
+ block_args['stride'] = 1
+
+ extract_features = False
+ if last_block:
+ next_stack_idx = stack_idx + 1
+ extract_features = next_stack_idx >= len(model_block_args) or \
+ model_block_args[next_stack_idx][0]['stride'] > 1
+
+ next_dilation = current_dilation
+ if block_args['stride'] > 1:
+ next_output_stride = current_stride * block_args['stride']
+ if next_output_stride > self.output_stride:
+ next_dilation = current_dilation * block_args['stride']
+ block_args['stride'] = 1
+ _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
+ self.output_stride), self.verbose)
+ else:
+ current_stride = next_output_stride
+ block_args['dilation'] = current_dilation
+ if next_dilation != current_dilation:
+ current_dilation = next_dilation
+
+ # create the block
+ block = self._make_block(block_args, total_block_idx, total_block_count)
+ blocks.append(block)
+
+ # stash feature module name and channel info for model feature extraction
+ if extract_features:
+ feature_info = dict(
+ stage=stack_idx + 1, reduction=current_stride, **block.feature_info(self.feature_location))
+ module_name = f'blocks.{stack_idx}.{block_idx}'
+ leaf_name = feature_info.get('module', '')
+ feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
+ self.features.append(feature_info)
+
+ total_block_idx += 1 # incr global block idx (across all stacks)
+ stages.append(nn.Sequential(*blocks))
+ return stages
+
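+
+# Hedged sketch of the builder flow (not upstream code): this mirrors how the EfficientNet /
+# MobileNetV3 model classes in this package drive the builder, with act/norm layers passed
+# explicitly and channel rounding left at its default.
+def _example_builder(arch_def, stem_chs=32):
+    block_args = decode_arch_def(arch_def)
+    builder = EfficientNetBuilder(output_stride=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d)
+    blocks = nn.Sequential(*builder(stem_chs, block_args))
+    feature_info = builder.features  # per-stage num_chs / reduction / module info
+    return blocks, feature_info
+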
+
+def _init_weight_goog(m, n='', fix_group_fanout=True):
+ """ Weight initialization as per Tensorflow official implementations.
+
+ Args:
+ m (nn.Module): module to init
+ n (str): module name
+ fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
+
+ Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
+ * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+ * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+ """
+ if isinstance(m, CondConv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ if fix_group_fanout:
+ fan_out //= m.groups
+ init_weight_fn = get_condconv_initializer(
+ lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
+ init_weight_fn(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ if fix_group_fanout:
+ fan_out //= m.groups
+ nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ fan_out = m.weight.size(0) # fan-out
+ fan_in = 0
+ if 'routing_fn' in n:
+ fan_in = m.weight.size(1)
+ init_range = 1.0 / math.sqrt(fan_in + fan_out)
+ nn.init.uniform_(m.weight, -init_range, init_range)
+ nn.init.zeros_(m.bias)
+
+
+def efficientnet_init_weights(model: nn.Module, init_fn=None):
+ init_fn = init_fn or _init_weight_goog
+ for n, m in model.named_modules():
+ init_fn(m, n)
+
diff --git a/timm/models/factory.py b/timm/models/factory.py
new file mode 100644
index 0000000..d040a9f
--- /dev/null
+++ b/timm/models/factory.py
@@ -0,0 +1,86 @@
+from .registry import is_model, is_model_in_modules, model_entrypoint
+from .helpers import load_checkpoint
+from .layers import set_layer_config
+from .hub import load_model_config_from_hf
+
+
+def split_model_name(model_name):
+ model_split = model_name.split(':', 1)
+ if len(model_split) == 1:
+ return '', model_split[0]
+ else:
+ source_name, model_name = model_split
+ assert source_name in ('timm', 'hf_hub')
+ return source_name, model_name
+
+
+def safe_model_name(model_name, remove_source=True):
+ def make_safe(name):
+ return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
+ if remove_source:
+ model_name = split_model_name(model_name)[-1]
+ return make_safe(model_name)
+
+
+def create_model(
+ model_name,
+ pretrained=False,
+ checkpoint_path='',
+ scriptable=None,
+ exportable=None,
+ no_jit=None,
+ **kwargs):
+ """Create a model
+
+ Args:
+ model_name (str): name of model to instantiate
+ pretrained (bool): load pretrained ImageNet-1k weights if true
+ checkpoint_path (str): path of checkpoint to load after model is initialized
+ scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
+ exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
+ no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
+
+ Keyword Args:
+ drop_rate (float): dropout rate for training (default: 0.0)
+ global_pool (str): global pool type (default: 'avg')
+ **: other kwargs are model specific
+ """
+ source_name, model_name = split_model_name(model_name)
+
+ # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
+ is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
+ if not is_efficientnet:
+ kwargs.pop('bn_tf', None)
+ kwargs.pop('bn_momentum', None)
+ kwargs.pop('bn_eps', None)
+
+ # handle backwards compat with drop_connect -> drop_path change
+ drop_connect_rate = kwargs.pop('drop_connect_rate', None)
+ if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
+ print("WARNING: 'drop_connect' as an argument is deprecated, please use 'drop_path'."
+ " Setting drop_path to %f." % drop_connect_rate)
+ kwargs['drop_path_rate'] = drop_connect_rate
+
+ # Parameters that aren't supported by all models or are intended to only override model defaults if set
+ # should default to None in command line args/cfg. Remove them if they are present and not set so that
+ # non-supporting models don't break and default args remain in effect.
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+ if source_name == 'hf_hub':
+ # For model names specified in the form `hf_hub:path/architecture_name#revision`,
+ # load model weights + default_cfg from Hugging Face hub.
+ hf_default_cfg, model_name = load_model_config_from_hf(model_name)
+ kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday
+
+ if is_model(model_name):
+ create_fn = model_entrypoint(model_name)
+ else:
+ raise RuntimeError('Unknown model (%s)' % model_name)
+
+ with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+ model = create_fn(pretrained=pretrained, **kwargs)
+
+ if checkpoint_path:
+ load_checkpoint(model, checkpoint_path)
+
+ return model
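+
+
+# Illustrative usage sketch (not part of the original file): typical create_model calls.
+# The model names below are examples of registry entries available once timm.models is
+# imported; the hf_hub form follows the `hf_hub:path/architecture_name` convention above
+# with a hypothetical hub id.
+def _example_create_model():
+    m1 = create_model('resnet50', pretrained=False, num_classes=10)
+    m2 = create_model('efficientnet_b0', drop_path_rate=0.2)  # EfficientNet-family only kwarg
+    # create_model('hf_hub:some-org/some-model', pretrained=True)  # config + weights from the HF hub
+    return m1, m2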
diff --git a/timm/models/features.py b/timm/models/features.py
new file mode 100644
index 0000000..b1d6890
--- /dev/null
+++ b/timm/models/features.py
@@ -0,0 +1,284 @@
+""" PyTorch Feature Extraction Helpers
+
+A collection of classes, functions, modules to help extract features from models
+and provide a common interface for describing them.
+
+The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
+https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from functools import partial
+from typing import Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+
+class FeatureInfo:
+
+ def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+ prev_reduction = 1
+ for fi in feature_info:
+ # sanity check the mandatory fields, there may be additional fields depending on the model
+ assert 'num_chs' in fi and fi['num_chs'] > 0
+ assert 'reduction' in fi and fi['reduction'] >= prev_reduction
+ prev_reduction = fi['reduction']
+ assert 'module' in fi
+ self.out_indices = out_indices
+ self.info = feature_info
+
+ def from_other(self, out_indices: Tuple[int]):
+ return FeatureInfo(deepcopy(self.info), out_indices)
+
+ def get(self, key, idx=None):
+ """ Get value by key at specified index (indices)
+ if idx is None, returns the value for key at each output index
+ if idx is an integer, returns the value for that feature module index (ignoring output indices)
+ if idx is a list/tuple, returns the value for each module index (ignoring output indices)
+ """
+ if idx is None:
+ return [self.info[i][key] for i in self.out_indices]
+ if isinstance(idx, (tuple, list)):
+ return [self.info[i][key] for i in idx]
+ else:
+ return self.info[idx][key]
+
+ def get_dicts(self, keys=None, idx=None):
+ """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
+ """
+ if idx is None:
+ if keys is None:
+ return [self.info[i] for i in self.out_indices]
+ else:
+ return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
+ if isinstance(idx, (tuple, list)):
+ return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
+ else:
+ return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
+
+ def channels(self, idx=None):
+ """ feature channels accessor
+ """
+ return self.get('num_chs', idx)
+
+ def reduction(self, idx=None):
+ """ feature reduction (output stride) accessor
+ """
+ return self.get('reduction', idx)
+
+ def module_name(self, idx=None):
+ """ feature module name accessor
+ """
+ return self.get('module', idx)
+
+ def __getitem__(self, item):
+ return self.info[item]
+
+ def __len__(self):
+ return len(self.info)
+
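+
+# Illustrative sketch (not upstream code) of how FeatureInfo is queried. The dicts below are
+# example entries in the format produced by the model builders in this package.
+def _example_feature_info():
+    info = FeatureInfo([
+        dict(num_chs=16, reduction=2, module='act1'),
+        dict(num_chs=24, reduction=4, module='blocks.1'),
+        dict(num_chs=40, reduction=8, module='blocks.2'),
+    ], out_indices=(1, 2))
+    assert info.channels() == [24, 40]
+    assert info.reduction() == [4, 8]
+    assert info.module_name(0) == 'act1'  # an integer idx bypasses out_indices
+    return info
+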
+
+class FeatureHooks:
+ """ Feature Hook Helper
+
+ This module helps with the setup and extraction of hooks for extracting features from
+ internal nodes in a model by node name. This works quite well in eager Python but needs
+ redesign for torcscript.
+ """
+
+ def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
+ # setup feature hooks
+ modules = {k: v for k, v in named_modules}
+ for i, h in enumerate(hooks):
+ hook_name = h['module']
+ m = modules[hook_name]
+ hook_id = out_map[i] if out_map else hook_name
+ hook_fn = partial(self._collect_output_hook, hook_id)
+ hook_type = h['hook_type'] if 'hook_type' in h else default_hook_type
+ if hook_type == 'forward_pre':
+ m.register_forward_pre_hook(hook_fn)
+ elif hook_type == 'forward':
+ m.register_forward_hook(hook_fn)
+ else:
+ assert False, "Unsupported hook type"
+ self._feature_outputs = defaultdict(OrderedDict)
+
+ def _collect_output_hook(self, hook_id, *args):
+ x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
+ if isinstance(x, tuple):
+ x = x[0] # unwrap input tuple
+ self._feature_outputs[x.device][hook_id] = x
+
+ def get_output(self, device) -> Dict[str, torch.Tensor]:
+ output = self._feature_outputs[device]
+ self._feature_outputs[device] = OrderedDict() # clear after reading
+ return output
+
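+
+# Hedged sketch (not upstream code): given hook specs in the format returned by
+# FeatureInfo.get_dicts(), the helper registers forward hooks on the named modules and
+# caches the tapped tensors per device until get_output() is called.
+def _example_feature_hooks(model, feature_info, x):
+    hooks = FeatureHooks(feature_info.get_dicts(), model.named_modules())
+    _ = model(x)
+    return hooks.get_output(x.device)  # OrderedDict keyed by module name
+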
+
+def _module_list(module, flatten_sequential=False):
+ # a yield/iter would be better for this but wouldn't be compatible with torchscript
+ ml = []
+ for name, module in module.named_children():
+ if flatten_sequential and isinstance(module, nn.Sequential):
+ # first level of Sequential containers is flattened into containing model
+ for child_name, child_module in module.named_children():
+ combined = [name, child_name]
+ ml.append(('_'.join(combined), '.'.join(combined), child_module))
+ else:
+ ml.append((name, name, module))
+ return ml
+
+
+def _get_feature_info(net, out_indices):
+ feature_info = getattr(net, 'feature_info')
+ if isinstance(feature_info, FeatureInfo):
+ return feature_info.from_other(out_indices)
+ elif isinstance(feature_info, (list, tuple)):
+ return FeatureInfo(net.feature_info, out_indices)
+ else:
+ assert False, "Provided feature_info is not valid"
+
+
+def _get_return_layers(feature_info, out_map):
+ module_names = feature_info.module_name()
+ return_layers = {}
+ for i, name in enumerate(module_names):
+ return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+ return return_layers
+
+
+class FeatureDictNet(nn.ModuleDict):
+ """ Feature extractor with OrderedDict return
+
+ Wrap a model and extract features as specified by the out indices, the network is
+ partially re-built from contained modules.
+
+ There is a strong assumption that the modules have been registered into the model in the same
+ order as they are used. There should be no reuse of the same nn.Module more than once, including
+ trivial modules like `self.relu = nn.ReLU()`.
+
+ Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+ one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
+ All Sequential containers that are directly assigned to the original model will have their
+ modules assigned to this module, with names like `model.features.1` re-written as `model.features_1`.
+
+ Arguments:
+ model (nn.Module): model from which we will extract the features
+ out_indices (tuple[int]): model output indices to extract features for
+ out_map (sequence): list or tuple specifying desired return id for each out index,
+ otherwise str(index) is used
+ feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
+ vs select element [0]
+ flatten_sequential (bool): whether to flatten sequential modules assigned to model
+ """
+ def __init__(
+ self, model,
+ out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+ super(FeatureDictNet, self).__init__()
+ self.feature_info = _get_feature_info(model, out_indices)
+ self.concat = feature_concat
+ self.return_layers = {}
+ return_layers = _get_return_layers(self.feature_info, out_map)
+ modules = _module_list(model, flatten_sequential=flatten_sequential)
+ remaining = set(return_layers.keys())
+ layers = OrderedDict()
+ for new_name, old_name, module in modules:
+ layers[new_name] = module
+ if old_name in remaining:
+ # return id has to be consistently str type for torchscript
+ self.return_layers[new_name] = str(return_layers[old_name])
+ remaining.remove(old_name)
+ if not remaining:
+ break
+ assert not remaining and len(self.return_layers) == len(return_layers), \
+ f'Return layers ({remaining}) are not present in model'
+ self.update(layers)
+
+ def _collect(self, x) -> (Dict[str, torch.Tensor]):
+ out = OrderedDict()
+ for name, module in self.items():
+ x = module(x)
+ if name in self.return_layers:
+ out_id = self.return_layers[name]
+ if isinstance(x, (tuple, list)):
+ # If model tap is a tuple or list, concat or select first element
+ # FIXME this may need to be more generic / flexible for some nets
+ out[out_id] = torch.cat(x, 1) if self.concat else x[0]
+ else:
+ out[out_id] = x
+ return out
+
+ def forward(self, x) -> Dict[str, torch.Tensor]:
+ return self._collect(x)
+
+
+class FeatureListNet(FeatureDictNet):
+ """ Feature extractor with list return
+
+ See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
+ In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
+ """
+ def __init__(
+ self, model,
+ out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
+ super(FeatureListNet, self).__init__(
+ model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
+ flatten_sequential=flatten_sequential)
+
+ def forward(self, x) -> (List[torch.Tensor]):
+ return list(self._collect(x).values())
+
+
+class FeatureHookNet(nn.ModuleDict):
+ """ FeatureHookNet
+
+ Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
+
+ If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
+ network in any way.
+
+ If `no_rewrite` is False, the model is re-written as in the FeatureList/FeatureDict case,
+ folding first-level modules (and the children of first-level Sequential containers) into this one.
+
+ FIXME this does not currently work with Torchscript, see FeatureHooks class
+ """
+ def __init__(
+ self, model,
+ out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
+ feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
+ super(FeatureHookNet, self).__init__()
+ assert not torch.jit.is_scripting()
+ self.feature_info = _get_feature_info(model, out_indices)
+ self.out_as_dict = out_as_dict
+ layers = OrderedDict()
+ hooks = []
+ if no_rewrite:
+ assert not flatten_sequential
+ if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
+ model.reset_classifier(0)
+ layers['body'] = model
+ hooks.extend(self.feature_info.get_dicts())
+ else:
+ modules = _module_list(model, flatten_sequential=flatten_sequential)
+ remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+ for f in self.feature_info.get_dicts()}
+ for new_name, old_name, module in modules:
+ layers[new_name] = module
+ for fn, fm in module.named_modules(prefix=old_name):
+ if fn in remaining:
+ hooks.append(dict(module=fn, hook_type=remaining[fn]))
+ del remaining[fn]
+ if not remaining:
+ break
+ assert not remaining, f'Return layers ({remaining}) are not present in model'
+ self.update(layers)
+ self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
+
+ def forward(self, x):
+ for name, module in self.items():
+ x = module(x)
+ out = self.hooks.get_output(x.device)
+ return out if self.out_as_dict else list(out.values())
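+
+
+# Illustrative usage sketch (not part of the original file): wrapping a backbone for feature
+# extraction. `create_model(..., features_only=True)` is the usual entry point; the direct
+# construction below assumes `backbone` exposes a valid `feature_info` attribute.
+def _example_feature_nets(backbone, x):
+    list_net = FeatureListNet(backbone, out_indices=(1, 2, 3), flatten_sequential=True)
+    feats = list_net(x)    # list of 3 tensors, shallowest tap first
+    hook_net = FeatureHookNet(backbone, out_indices=(1, 2, 3), no_rewrite=True)
+    hooked = hook_net(x)   # same taps, collected via forward hooks instead of module re-write
+    return feats, hooked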
diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py
new file mode 100644
index 0000000..5a25ee3
--- /dev/null
+++ b/timm/models/fx_features.py
@@ -0,0 +1,73 @@
+""" PyTorch FX Based Feature Extraction Helpers
+Using https://pytorch.org/vision/stable/feature_extraction.html
+"""
+from typing import Callable
+from torch import nn
+
+from .features import _get_feature_info
+
+try:
+ from torchvision.models.feature_extraction import create_feature_extractor
+ has_fx_feature_extraction = True
+except ImportError:
+ has_fx_feature_extraction = False
+
+# Layers we want to treat as leaf modules
+from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
+from .layers.non_local_attn import BilinearAttnTransform
+from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
+
+# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
+# BUT modules from timm.models should use the registration mechanism below
+_leaf_modules = {
+ BatchNormAct2d, # reason: flow control for jit scripting
+ BilinearAttnTransform, # reason: flow control t <= 1
+ BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
+ # Reason: get_same_padding has a max which raises a control flow error
+ Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
+ CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
+ DropPath, # reason: TypeError: rand received Proxy in `size` argument
+}
+
+try:
+ from .layers import InplaceAbn
+ _leaf_modules.add(InplaceAbn)
+except ImportError:
+ pass
+
+
+def register_notrace_module(module: nn.Module):
+ """
+ Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+ """
+ _leaf_modules.add(module)
+ return module
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+
+def register_notrace_function(func: Callable):
+ """
+ Decorator for functions which ought not to be traced through
+ """
+ _autowrap_functions.add(func)
+ return func
+
+
+class FeatureGraphNet(nn.Module):
+ def __init__(self, model, out_indices, out_map=None):
+ super().__init__()
+ assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+ self.feature_info = _get_feature_info(model, out_indices)
+ if out_map is not None:
+ assert len(out_map) == len(out_indices)
+ return_nodes = {info['module']: out_map[i] if out_map is not None else info['module']
+ for i, info in enumerate(self.feature_info) if i in out_indices}
+ self.graph_module = create_feature_extractor(
+ model, return_nodes,
+ tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)})
+
+ def forward(self, x):
+ return list(self.graph_module(x).values())
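+
+
+# Hedged usage sketch (not upstream code): marking a custom module as an FX leaf so it is not
+# traced through, then building a graph-based extractor. Assumes torchvision >= 0.11.
+@register_notrace_module
+class _ExampleDynamicPad(nn.Module):
+    # input-dependent control flow like this is what symbolic tracing cannot handle
+    def forward(self, x):
+        pad = max(0, 8 - x.shape[-1])
+        return nn.functional.pad(x, (0, pad))
+
+
+def _example_graph_net(backbone, x):
+    net = FeatureGraphNet(backbone, out_indices=(1, 2, 3))
+    return net(x)  # list of intermediate feature maps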
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
new file mode 100644
index 0000000..3b6f90a
--- /dev/null
+++ b/timm/models/ghostnet.py
@@ -0,0 +1,276 @@
+"""
+An implementation of GhostNet Model as defined in:
+GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
+The training script for this model is similar to that of MobileNetV3.
+Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
+"""
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .layers import SelectAdaptivePool2d, Linear, make_divisible
+from .efficientnet_blocks import SqueezeExcite, ConvBnAct
+from .helpers import build_model_with_cfg
+from .registry import register_model
+
+
+__all__ = ['GhostNet']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv_stem', 'classifier': 'classifier',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'ghostnet_050': _cfg(url=''),
+ 'ghostnet_100': _cfg(
+ url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'),
+ 'ghostnet_130': _cfg(url=''),
+}
+
+
+_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
+
+
+class GhostModule(nn.Module):
+ def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
+ super(GhostModule, self).__init__()
+ self.oup = oup
+ init_channels = math.ceil(oup / ratio)
+ new_channels = init_channels * (ratio - 1)
+
+ self.primary_conv = nn.Sequential(
+ nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size//2, bias=False),
+ nn.BatchNorm2d(init_channels),
+ nn.ReLU(inplace=True) if relu else nn.Sequential(),
+ )
+
+ self.cheap_operation = nn.Sequential(
+ nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size//2, groups=init_channels, bias=False),
+ nn.BatchNorm2d(new_channels),
+ nn.ReLU(inplace=True) if relu else nn.Sequential(),
+ )
+
+ def forward(self, x):
+ x1 = self.primary_conv(x)
+ x2 = self.cheap_operation(x1)
+ out = torch.cat([x1, x2], dim=1)
+ return out[:, :self.oup, :, :]
+
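+
+# Hedged example of the channel arithmetic above (not upstream code): for oup=64 and ratio=2,
+# the primary 1x1 conv produces ceil(64/2) = 32 channels, the cheap depthwise op adds another
+# 32, and the concatenation is sliced back to oup channels (the slice only matters when oup
+# is not divisible by ratio).
+def _example_ghost_module():
+    m = GhostModule(16, 64, ratio=2)
+    y = m(torch.randn(1, 16, 32, 32))
+    assert y.shape == (1, 64, 32, 32)
+    return y
+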
+
+class GhostBottleneck(nn.Module):
+ """ Ghost bottleneck w/ optional SE"""
+
+ def __init__(self, in_chs, mid_chs, out_chs, dw_kernel_size=3,
+ stride=1, act_layer=nn.ReLU, se_ratio=0.):
+ super(GhostBottleneck, self).__init__()
+ has_se = se_ratio is not None and se_ratio > 0.
+ self.stride = stride
+
+ # Point-wise expansion
+ self.ghost1 = GhostModule(in_chs, mid_chs, relu=True)
+
+ # Depth-wise convolution
+ if self.stride > 1:
+ self.conv_dw = nn.Conv2d(
+ mid_chs, mid_chs, dw_kernel_size, stride=stride,
+ padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False)
+ self.bn_dw = nn.BatchNorm2d(mid_chs)
+ else:
+ self.conv_dw = None
+ self.bn_dw = None
+
+ # Squeeze-and-excitation
+ self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
+
+ # Point-wise linear projection
+ self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)
+
+ # shortcut
+ if in_chs == out_chs and self.stride == 1:
+ self.shortcut = nn.Sequential()
+ else:
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(
+ in_chs, in_chs, dw_kernel_size, stride=stride,
+ padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
+ nn.BatchNorm2d(in_chs),
+ nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(out_chs),
+ )
+
+ def forward(self, x):
+ shortcut = x
+
+ # 1st ghost bottleneck
+ x = self.ghost1(x)
+
+ # Depth-wise convolution
+ if self.conv_dw is not None:
+ x = self.conv_dw(x)
+ x = self.bn_dw(x)
+
+ # Squeeze-and-excitation
+ if self.se is not None:
+ x = self.se(x)
+
+ # 2nd ghost bottleneck
+ x = self.ghost2(x)
+
+ x += self.shortcut(shortcut)
+ return x
+
+
+class GhostNet(nn.Module):
+ def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
+ super(GhostNet, self).__init__()
+ # setting of inverted residual blocks
+ assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
+ self.cfgs = cfgs
+ self.num_classes = num_classes
+ self.dropout = dropout
+ self.feature_info = []
+
+ # building first layer
+ stem_chs = make_divisible(16 * width, 4)
+ self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
+ self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='conv_stem'))
+ self.bn1 = nn.BatchNorm2d(stem_chs)
+ self.act1 = nn.ReLU(inplace=True)
+ prev_chs = stem_chs
+
+ # building inverted residual blocks
+ stages = nn.ModuleList([])
+ block = GhostBottleneck
+ stage_idx = 0
+ net_stride = 2
+ for cfg in self.cfgs:
+ layers = []
+ s = 1
+ for k, exp_size, c, se_ratio, s in cfg:
+ out_chs = make_divisible(c * width, 4)
+ mid_chs = make_divisible(exp_size * width, 4)
+ layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio))
+ prev_chs = out_chs
+ if s > 1:
+ net_stride *= 2
+ self.feature_info.append(dict(
+ num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
+ stages.append(nn.Sequential(*layers))
+ stage_idx += 1
+
+ out_chs = make_divisible(exp_size * width, 4)
+ stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
+ self.pool_dim = prev_chs = out_chs
+
+ self.blocks = nn.Sequential(*stages)
+
+ # building last several layers
+ self.num_features = out_chs = 1280
+ self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+ self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
+ self.act2 = nn.ReLU(inplace=True)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
+ self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
+
+ def get_classifier(self):
+ return self.classifier
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ # cannot meaningfully change pooling of efficient head after creation
+ self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
+ self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.conv_stem(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.blocks(x)
+ x = self.global_pool(x)
+ x = self.conv_head(x)
+ x = self.act2(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.flatten(x)
+ if self.dropout > 0.:
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ x = self.classifier(x)
+ return x
+
+
+def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
+ """
+ Constructs a GhostNet model
+ """
+ cfgs = [
+ # k, t, c, SE, s
+ # stage1
+ [[3, 16, 16, 0, 1]],
+ # stage2
+ [[3, 48, 24, 0, 2]],
+ [[3, 72, 24, 0, 1]],
+ # stage3
+ [[5, 72, 40, 0.25, 2]],
+ [[5, 120, 40, 0.25, 1]],
+ # stage4
+ [[3, 240, 80, 0, 2]],
+ [[3, 200, 80, 0, 1],
+ [3, 184, 80, 0, 1],
+ [3, 184, 80, 0, 1],
+ [3, 480, 112, 0.25, 1],
+ [3, 672, 112, 0.25, 1]
+ ],
+ # stage5
+ [[5, 672, 160, 0.25, 2]],
+ [[5, 960, 160, 0, 1],
+ [5, 960, 160, 0.25, 1],
+ [5, 960, 160, 0, 1],
+ [5, 960, 160, 0.25, 1]
+ ]
+ ]
+ model_kwargs = dict(
+ cfgs=cfgs,
+ width=width,
+ **kwargs,
+ )
+ return build_model_with_cfg(
+ GhostNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True),
+ **model_kwargs)
+
+
+@register_model
+def ghostnet_050(pretrained=False, **kwargs):
+ """ GhostNet-0.5x """
+ model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def ghostnet_100(pretrained=False, **kwargs):
+ """ GhostNet-1.0x """
+ model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def ghostnet_130(pretrained=False, **kwargs):
+ """ GhostNet-1.3x """
+ model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
+ return model
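+
+
+# Illustrative sketch (not part of the original file): the registered variants above are
+# normally instantiated through the model factory, e.g. create_model('ghostnet_100');
+# calling the entrypoint directly is equivalent.
+def _example_ghostnet():
+    model = ghostnet_100(pretrained=False, num_classes=10)
+    logits = model(torch.randn(1, 3, 224, 224))
+    assert logits.shape == (1, 10)
+    return model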
diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py
new file mode 100644
index 0000000..027a10b
--- /dev/null
+++ b/timm/models/gluon_resnet.py
@@ -0,0 +1,248 @@
+"""Pytorch impl of MxNet Gluon ResNet/(SE)ResNeXt variants
+This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-ResNeXt additions
+and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
+by Ross Wightman
+"""
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import SEModule
+from .registry import register_model
+from .resnet import ResNet, Bottleneck, BasicBlock
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv1', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'gluon_resnet18_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'),
+ 'gluon_resnet34_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'),
+ 'gluon_resnet50_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'),
+ 'gluon_resnet101_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'),
+ 'gluon_resnet152_v1b': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'),
+ 'gluon_resnet50_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet101_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet152_v1c': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet50_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet101_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet152_v1d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet50_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet101_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnet152_v1s': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth',
+ first_conv='conv1.0'),
+ 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
+ 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
+ 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
+ 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
+ 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
+ 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
+ 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
+ first_conv='conv1.0'),
+}
+
+
+def _create_resnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ResNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def gluon_resnet18_v1b(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ """
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
+ return _create_resnet('gluon_resnet18_v1b', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet34_v1b(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model.
+ """
+ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('gluon_resnet34_v1b', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet50_v1b(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('gluon_resnet50_v1b', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet101_v1b(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
+ return _create_resnet('gluon_resnet101_v1b', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet152_v1b(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
+ return _create_resnet('gluon_resnet152_v1b', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet50_v1c(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', **kwargs)
+ return _create_resnet('gluon_resnet50_v1c', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet101_v1c(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', **kwargs)
+ return _create_resnet('gluon_resnet101_v1c', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet152_v1c(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', **kwargs)
+ return _create_resnet('gluon_resnet152_v1c', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet50_v1d(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('gluon_resnet50_v1d', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet101_v1d(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('gluon_resnet101_v1d', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet152_v1d(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('gluon_resnet152_v1d', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet50_v1s(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=64, stem_type='deep', **kwargs)
+ return _create_resnet('gluon_resnet50_v1s', pretrained, **model_args)
+
+
+
+@register_model
+def gluon_resnet101_v1s(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], stem_width=64, stem_type='deep', **kwargs)
+ return _create_resnet('gluon_resnet101_v1s', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnet152_v1s(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], stem_width=64, stem_type='deep', **kwargs)
+ return _create_resnet('gluon_resnet152_v1s', pretrained, **model_args)
+
+
+
+@register_model
+def gluon_resnext50_32x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt50-32x4d model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('gluon_resnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnext101_32x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt-101 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('gluon_resnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def gluon_resnext101_64x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt-101 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
+ return _create_resnet('gluon_resnext101_64x4d', pretrained, **model_args)
+
+
+@register_model
+def gluon_seresnext50_32x4d(pretrained=False, **kwargs):
+ """Constructs a SEResNeXt50-32x4d model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+ block_args=dict(attn_layer=SEModule), **kwargs)
+ return _create_resnet('gluon_seresnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def gluon_seresnext101_32x4d(pretrained=False, **kwargs):
+ """Constructs a SEResNeXt-101-32x4d model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
+ block_args=dict(attn_layer=SEModule), **kwargs)
+ return _create_resnet('gluon_seresnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def gluon_seresnext101_64x4d(pretrained=False, **kwargs):
+ """Constructs a SEResNeXt-101-64x4d model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4,
+ block_args=dict(attn_layer=SEModule), **kwargs)
+ return _create_resnet('gluon_seresnext101_64x4d', pretrained, **model_args)
+
+
+@register_model
+def gluon_senet154(pretrained=False, **kwargs):
+ """Constructs an SENet-154 model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
+ down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer=SEModule), **kwargs)
+ return _create_resnet('gluon_senet154', pretrained, **model_args)
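+
+
+# Illustrative usage sketch (not part of the original file). Assuming the standard timm
+# factory `timm.create_model` is on the path, any variant registered above can be built
+# by name, e.g.:
+#
+#   import torch
+#   import timm
+#   model = timm.create_model('gluon_resnet50_v1d', pretrained=False, num_classes=10)
+#   logits = model(torch.randn(1, 3, 224, 224))  # -> shape (1, 10)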
diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py
new file mode 100644
index 0000000..fbd668a
--- /dev/null
+++ b/timm/models/gluon_xception.py
@@ -0,0 +1,246 @@
+"""Pytorch impl of Gluon Xception
+This is a port of the Gluon Xception code and weights, itself ported from a PyTorch DeepLab impl.
+
+Gluon model: (https://gluon-cv.mxnet.io/_modules/gluoncv/model_zoo/xception.html)
+Original PyTorch DeepLab impl: https://github.com/jfzhang95/pytorch-deeplab-xception
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import create_classifier, get_padding
+from .registry import register_model
+
+__all__ = ['Xception65']
+
+default_cfgs = {
+ 'gluon_xception65': {
+ 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_xception-7015a15c.pth',
+ 'input_size': (3, 299, 299),
+ 'crop_pct': 0.903,
+ 'pool_size': (10, 10),
+ 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN,
+ 'std': IMAGENET_DEFAULT_STD,
+ 'num_classes': 1000,
+ 'first_conv': 'conv1',
+ 'classifier': 'fc'
+ # The resize parameter of the validation transform should be 333, with a center crop at 299x299 (see the illustrative sketch below this dict)
+ },
+}
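+
+# Illustrative eval-transform sketch for the note in the cfg above (an assumption, not part
+# of the original file): resize the short side to 333 (~ 299 / 0.903), then center crop to
+# 299x299 and normalize with the ImageNet statistics from the cfg.
+#
+#   from torchvision import transforms
+#   eval_tfms = transforms.Compose([
+#       transforms.Resize(333),
+#       transforms.CenterCrop(299),
+#       transforms.ToTensor(),
+#       transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+#   ])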
+
+""" PADDING NOTES
+The original PyTorch and Gluon impl of these models dutifully reproduced the
+aligned padding added to Tensorflow models for Deeplab. This padding was compensating
+for Tensorflow 'SAME' padding. PyTorch symmetric padding behaves the way we'd want it to.
+"""
+
+
+class SeparableConv2d(nn.Module):
+ def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
+ super(SeparableConv2d, self).__init__()
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+
+ # depthwise convolution
+ padding = get_padding(kernel_size, stride, dilation)
+ self.conv_dw = nn.Conv2d(
+ inplanes, inplanes, kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=inplanes, bias=bias)
+ self.bn = norm_layer(num_features=inplanes)
+ # pointwise convolution
+ self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
+
+ def forward(self, x):
+ x = self.conv_dw(x)
+ x = self.bn(x)
+ x = self.conv_pw(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
+ super(Block, self).__init__()
+ if isinstance(planes, (list, tuple)):
+ assert len(planes) == 3
+ else:
+ planes = (planes,) * 3
+ outplanes = planes[-1]
+
+ if outplanes != inplanes or stride != 1:
+ self.skip = nn.Sequential()
+ self.skip.add_module('conv1', nn.Conv2d(
+ inplanes, outplanes, 1, stride=stride, bias=False))
+ self.skip.add_module('bn1', norm_layer(num_features=outplanes))
+ else:
+ self.skip = None
+
+ rep = OrderedDict()
+ for i in range(3):
+ rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
+ rep['conv%d' % (i + 1)] = SeparableConv2d(
+ inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
+ rep['bn%d' % (i + 1)] = norm_layer(planes[i])
+ inplanes = planes[i]
+
+ if not start_with_relu:
+ del rep['act1']
+ else:
+ rep['act1'] = nn.ReLU(inplace=False)
+ self.rep = nn.Sequential(rep)
+
+ def forward(self, x):
+ skip = x
+ if self.skip is not None:
+ skip = self.skip(skip)
+ x = self.rep(x) + skip
+ return x
+
+
+class Xception65(nn.Module):
+ """Modified Aligned Xception.
+
+ NOTE: only the 65-layer version is included here; the 71-layer variant
+ was not correct and had no pretrained weights
+ """
+
+ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
+ drop_rate=0., global_pool='avg'):
+ super(Xception65, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ if output_stride == 32:
+ entry_block3_stride = 2
+ exit_block20_stride = 2
+ middle_dilation = 1
+ exit_dilation = (1, 1)
+ elif output_stride == 16:
+ entry_block3_stride = 2
+ exit_block20_stride = 1
+ middle_dilation = 1
+ exit_dilation = (1, 2)
+ elif output_stride == 8:
+ entry_block3_stride = 1
+ exit_block20_stride = 1
+ middle_dilation = 2
+ exit_dilation = (2, 4)
+ else:
+ raise NotImplementedError
+
+ # Entry flow
+ self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = norm_layer(num_features=32)
+ self.act1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn2 = norm_layer(num_features=64)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
+ self.block1_act = nn.ReLU(inplace=True)
+ self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
+ self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
+
+ # Middle flow
+ self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
+ 728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
+
+ # Exit flow
+ self.block20 = Block(
+ 728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
+ self.block20_act = nn.ReLU(inplace=True)
+
+ self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+ self.bn3 = norm_layer(num_features=1536)
+ self.act3 = nn.ReLU(inplace=True)
+
+ self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+ self.bn4 = norm_layer(num_features=1536)
+ self.act4 = nn.ReLU(inplace=True)
+
+ self.num_features = 2048
+ self.conv5 = SeparableConv2d(
+ 1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+ self.bn5 = norm_layer(num_features=self.num_features)
+ self.act5 = nn.ReLU(inplace=True)
+ self.feature_info = [
+ dict(num_chs=64, reduction=2, module='act2'),
+ dict(num_chs=128, reduction=4, module='block1_act'),
+ dict(num_chs=256, reduction=8, module='block3.rep.act1'),
+ dict(num_chs=728, reduction=16, module='block20.rep.act1'),
+ dict(num_chs=2048, reduction=32, module='act5'),
+ ]
+
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def get_classifier(self):
+ return self.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ # Entry flow
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ x = self.block1(x)
+ x = self.block1_act(x)
+ # c1 = x
+ x = self.block2(x)
+ # c2 = x
+ x = self.block3(x)
+
+ # Middle flow
+ x = self.mid(x)
+ # c3 = x
+
+ # Exit flow
+ x = self.block20(x)
+ x = self.block20_act(x)
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.act3(x)
+
+ x = self.conv4(x)
+ x = self.bn4(x)
+ x = self.act4(x)
+
+ x = self.conv5(x)
+ x = self.bn5(x)
+ x = self.act5(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate:
+ x = F.dropout(x, self.drop_rate, training=self.training)
+ x = self.fc(x)
+ return x
+
+
+def _create_gluon_xception(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ Xception65, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(feature_cls='hook'),
+ **kwargs)
+
+
+@register_model
+def gluon_xception65(pretrained=False, **kwargs):
+ """ Modified Aligned Xception-65
+ """
+ return _create_gluon_xception('gluon_xception65', pretrained, **kwargs)
diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py
new file mode 100644
index 0000000..9988a04
--- /dev/null
+++ b/timm/models/hardcorenas.py
@@ -0,0 +1,152 @@
+from functools import partial
+
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .efficientnet_blocks import SqueezeExcite
+from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
+from .helpers import build_model_with_cfg, default_cfg_for_features
+from .layers import get_act_fn
+from .mobilenetv3 import MobileNetV3, MobileNetV3Features
+from .registry import register_model
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv_stem', 'classifier': 'classifier',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'hardcorenas_a': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_A_Green_38ms_75.9_23474aeb.pth'),
+ 'hardcorenas_b': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_B_Green_40ms_76.5_1f882d1e.pth'),
+ 'hardcorenas_c': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_C_Green_44ms_77.1_d4148c9e.pth'),
+ 'hardcorenas_d': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_D_Green_50ms_77.4_23e3cdde.pth'),
+ 'hardcorenas_e': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_E_Green_55ms_77.9_90f20e8a.pth'),
+ 'hardcorenas_f': _cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/public/HardCoReNAS/HardCoreNAS_F_Green_60ms_78.1_2855edf1.pth'),
+}
+
+
+def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
+ """Creates a hardcorenas model
+
+ Ref impl: https://github.com/Alibaba-MIIL/HardCoReNAS
+ Paper: https://arxiv.org/abs/2102.11646
+
+ """
+ num_features = 1280
+ se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ num_features=num_features,
+ stem_size=32,
+ norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'hard_swish'),
+ se_layer=se_layer,
+ **kwargs,
+ )
+
+ features_only = False
+ model_cls = MobileNetV3
+ kwargs_filter = None
+ if model_kwargs.pop('features_only', False):
+ features_only = True
+ kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias')
+ model_cls = MobileNetV3Features
+ model = build_model_with_cfg(
+ model_cls, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_strict=not features_only,
+ kwargs_filter=kwargs_filter,
+ **model_kwargs)
+ if features_only:
+ model.default_cfg = default_cfg_for_features(model.default_cfg)
+ return model
+
+
+@register_model
+def hardcorenas_a(pretrained=False, **kwargs):
+ """ hardcorenas_A """
+ arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+ ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
+ ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25'],
+ ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25'],
+ ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
+ model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_a', arch_def=arch_def, **kwargs)
+ return model
+
+
+@register_model
+def hardcorenas_b(pretrained=False, **kwargs):
+ """ hardcorenas_B """
+ arch_def = [['ds_r1_k3_s1_e1_c16_nre'],
+ ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25', 'ir_r1_k3_s1_e3_c24_nre'],
+ ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre'],
+ ['ir_r1_k5_s2_e3_c80', 'ir_r1_k5_s1_e3_c80', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
+ ['ir_r1_k5_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
+ ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
+ ['cn_r1_k1_s1_c960']]
+ model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_b', arch_def=arch_def, **kwargs)
+ return model
+
+
+@register_model
+def hardcorenas_c(pretrained=False, **kwargs):
+ """ hardcorenas_C """
+ arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+ ['ir_r1_k5_s2_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre', 'ir_r1_k5_s1_e3_c40_nre',
+ 'ir_r1_k5_s1_e3_c40_nre'],
+ ['ir_r1_k5_s2_e4_c80', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80', 'ir_r1_k3_s1_e3_c80'],
+ ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112', 'ir_r1_k3_s1_e3_c112'],
+ ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e3_c192_se0.25'],
+ ['cn_r1_k1_s1_c960']]
+ model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_c', arch_def=arch_def, **kwargs)
+ return model
+
+
+@register_model
+def hardcorenas_d(pretrained=False, **kwargs):
+ """ hardcorenas_D """
+ arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+ ['ir_r1_k5_s2_e3_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k3_s1_e3_c40_nre_se0.25'],
+ ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
+ 'ir_r1_k3_s1_e3_c80_se0.25'],
+ ['ir_r1_k3_s1_e4_c112_se0.25', 'ir_r1_k5_s1_e4_c112_se0.25', 'ir_r1_k3_s1_e3_c112_se0.25',
+ 'ir_r1_k5_s1_e3_c112_se0.25'],
+ ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
+ 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
+ model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_d', arch_def=arch_def, **kwargs)
+ return model
+
+
+@register_model
+def hardcorenas_e(pretrained=False, **kwargs):
+ """ hardcorenas_E """
+ arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+ ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25', 'ir_r1_k5_s1_e4_c40_nre_se0.25',
+ 'ir_r1_k3_s1_e3_c40_nre_se0.25'], ['ir_r1_k5_s2_e4_c80_se0.25', 'ir_r1_k3_s1_e6_c80_se0.25'],
+ ['ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
+ 'ir_r1_k5_s1_e3_c112_se0.25'],
+ ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25',
+ 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
+ model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_e', arch_def=arch_def, **kwargs)
+ return model
+
+
+@register_model
+def hardcorenas_f(pretrained=False, **kwargs):
+ """ hardcorenas_F """
+ arch_def = [['ds_r1_k3_s1_e1_c16_nre'], ['ir_r1_k5_s2_e3_c24_nre_se0.25', 'ir_r1_k5_s1_e3_c24_nre_se0.25'],
+ ['ir_r1_k5_s2_e6_c40_nre_se0.25', 'ir_r1_k5_s1_e6_c40_nre_se0.25'],
+ ['ir_r1_k5_s2_e6_c80_se0.25', 'ir_r1_k5_s1_e6_c80_se0.25', 'ir_r1_k3_s1_e3_c80_se0.25',
+ 'ir_r1_k3_s1_e3_c80_se0.25'],
+ ['ir_r1_k3_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25', 'ir_r1_k5_s1_e6_c112_se0.25',
+ 'ir_r1_k3_s1_e3_c112_se0.25'],
+ ['ir_r1_k5_s2_e6_c192_se0.25', 'ir_r1_k5_s1_e6_c192_se0.25', 'ir_r1_k3_s1_e6_c192_se0.25',
+ 'ir_r1_k3_s1_e6_c192_se0.25'], ['cn_r1_k1_s1_c960']]
+ model = _gen_hardcorenas(pretrained=pretrained, variant='hardcorenas_f', arch_def=arch_def, **kwargs)
+ return model
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
new file mode 100644
index 0000000..16ce64d
--- /dev/null
+++ b/timm/models/helpers.py
@@ -0,0 +1,518 @@
+""" Model creation / weight loading / state_dict helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import logging
+import os
+import math
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Any, Callable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.hub import load_state_dict_from_url
+
+from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
+from .fx_features import FeatureGraphNet
+from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
+from .layers import Conv2dSame, Linear
+
+
+_logger = logging.getLogger(__name__)
+
+
+def load_state_dict(checkpoint_path, use_ema=False):
+ if checkpoint_path and os.path.isfile(checkpoint_path):
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ state_dict_key = ''
+ if isinstance(checkpoint, dict):
+ if use_ema and checkpoint.get('state_dict_ema', None) is not None:
+ state_dict_key = 'state_dict_ema'
+ elif use_ema and checkpoint.get('model_ema', None) is not None:
+ state_dict_key = 'model_ema'
+ elif 'state_dict' in checkpoint:
+ state_dict_key = 'state_dict'
+ elif 'model' in checkpoint:
+ state_dict_key = 'model'
+ if state_dict_key:
+ state_dict = checkpoint[state_dict_key]
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ # strip `module.` prefix
+ name = k[7:] if k.startswith('module') else k
+ new_state_dict[name] = v
+ state_dict = new_state_dict
+ else:
+ state_dict = checkpoint
+ _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
+ return state_dict
+ else:
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
+ raise FileNotFoundError()
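+
+
+# Illustrative usage sketch (not part of the original file): pulling EMA weights out of a
+# training checkpoint and loading them into an already constructed model. The checkpoint
+# path and `my_model` are placeholders.
+#
+#   state_dict = load_state_dict('path/to/checkpoint.pth.tar', use_ema=True)
+#   my_model.load_state_dict(state_dict)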
+
+
+def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
+ if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
+ # numpy checkpoint, try to load via model specific load_pretrained fn
+ if hasattr(model, 'load_pretrained'):
+ model.load_pretrained(checkpoint_path)
+ else:
+ raise NotImplementedError('Model cannot load numpy checkpoint')
+ return
+ state_dict = load_state_dict(checkpoint_path, use_ema)
+ model.load_state_dict(state_dict, strict=strict)
+
+
+def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
+ resume_epoch = None
+ if os.path.isfile(checkpoint_path):
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+ if log_info:
+ _logger.info('Restoring model state from checkpoint...')
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint['state_dict'].items():
+ name = k[7:] if k.startswith('module') else k
+ new_state_dict[name] = v
+ model.load_state_dict(new_state_dict)
+
+ if optimizer is not None and 'optimizer' in checkpoint:
+ if log_info:
+ _logger.info('Restoring optimizer state from checkpoint...')
+ optimizer.load_state_dict(checkpoint['optimizer'])
+
+ if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
+ if log_info:
+ _logger.info('Restoring AMP loss scaler state from checkpoint...')
+ loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
+
+ if 'epoch' in checkpoint:
+ resume_epoch = checkpoint['epoch']
+ if 'version' in checkpoint and checkpoint['version'] > 1:
+ resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
+
+ if log_info:
+ _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
+ else:
+ model.load_state_dict(checkpoint)
+ if log_info:
+ _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
+ return resume_epoch
+ else:
+ _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
+ raise FileNotFoundError()
+
+
+def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
+ r"""Loads a custom (read non .pth) weight file
+
+ Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
+ a passed-in custom load fn, or the `load_pretrained` model member fn.
+
+ If the object is already present in `model_dir`, it's deserialized and returned.
+ The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
+ `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
+
+ Args:
+ model: The instantiated model to load weights into
+ default_cfg (dict): Default pretrained model cfg
+ load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
+ 'load_pretrained' on the model will be called if it exists
+ progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
+ check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
+ ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
+ digits of the SHA256 hash of the contents of the file. The hash is used to
+ ensure unique names and to verify the contents of the file. Default: False
+ """
+ default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+ pretrained_url = default_cfg.get('url', None)
+ if not pretrained_url:
+ _logger.warning("No pretrained weights exist for this model. Using random initialization.")
+ return
+ cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)
+
+ if load_fn is not None:
+ load_fn(model, cached_file)
+ elif hasattr(model, 'load_pretrained'):
+ model.load_pretrained(cached_file)
+ else:
+ _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
+
+
+def adapt_input_conv(in_chans, conv_weight):
+ conv_type = conv_weight.dtype
+ conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
+ O, I, J, K = conv_weight.shape
+ if in_chans == 1:
+ if I > 3:
+ assert conv_weight.shape[1] % 3 == 0
+ # For models with space2depth stems
+ conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+ conv_weight = conv_weight.sum(dim=2, keepdim=False)
+ else:
+ conv_weight = conv_weight.sum(dim=1, keepdim=True)
+ elif in_chans != 3:
+ if I != 3:
+ raise NotImplementedError('Weight format not supported by conversion.')
+ else:
+ # NOTE this strategy should be better than random init, but there could be other combinations of
+ # the original RGB input layer weights that'd work better for specific cases.
+ repeat = int(math.ceil(in_chans / 3))
+ conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+ conv_weight *= (3 / float(in_chans))
+ conv_weight = conv_weight.to(conv_type)
+ return conv_weight
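+
+
+# Worked example (illustrative, not part of the original file) of how the conversion above
+# behaves for a 3-channel RGB stem conv:
+#
+#   import torch
+#   w = torch.randn(64, 3, 7, 7)          # (out, in, kH, kW) RGB stem weights
+#   adapt_input_conv(1, w).shape          # -> (64, 1, 7, 7): RGB channels summed per filter
+#   adapt_input_conv(6, w).shape          # -> (64, 6, 7, 7): RGB tiled twice, scaled by 3/6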
+
+
+def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+ """ Load pretrained checkpoint
+
+ Args:
+ model (nn.Module) : PyTorch model module
+ default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
+ num_classes (int): num_classes for model
+ in_chans (int): in_chans for model
+ filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
+ strict (bool): strict load of checkpoint
+ progress (bool): enable progress bar for weight download
+
+ """
+ default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
+ pretrained_url = default_cfg.get('url', None)
+ hf_hub_id = default_cfg.get('hf_hub', None)
+ if not pretrained_url and not hf_hub_id:
+ _logger.warning("No pretrained weights exist for this model. Using random initialization.")
+ return
+ if pretrained_url:
+ _logger.info(f'Loading pretrained weights from url ({pretrained_url})')
+ state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
+ elif hf_hub_id and has_hf_hub(necessary=True):
+ _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
+ state_dict = load_state_dict_from_hf(hf_hub_id)
+ if filter_fn is not None:
+ # for backwards compat with filter fns that take one arg: try one arg first, then two
+ try:
+ state_dict = filter_fn(state_dict)
+ except TypeError:
+ state_dict = filter_fn(state_dict, model)
+
+ input_convs = default_cfg.get('first_conv', None)
+ if input_convs is not None and in_chans != 3:
+ if isinstance(input_convs, str):
+ input_convs = (input_convs,)
+ for input_conv_name in input_convs:
+ weight_name = input_conv_name + '.weight'
+ try:
+ state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
+ _logger.info(
+ f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
+ except NotImplementedError as e:
+ del state_dict[weight_name]
+ strict = False
+ _logger.warning(
+ f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
+
+ classifiers = default_cfg.get('classifier', None)
+ label_offset = default_cfg.get('label_offset', 0)
+ if classifiers is not None:
+ if isinstance(classifiers, str):
+ classifiers = (classifiers,)
+ if num_classes != default_cfg['num_classes']:
+ for classifier_name in classifiers:
+ # completely discard fully connected if model num_classes doesn't match pretrained weights
+ del state_dict[classifier_name + '.weight']
+ del state_dict[classifier_name + '.bias']
+ strict = False
+ elif label_offset > 0:
+ for classifier_name in classifiers:
+ # special case for pretrained weights with an extra background class in pretrained weights
+ classifier_weight = state_dict[classifier_name + '.weight']
+ state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+ classifier_bias = state_dict[classifier_name + '.bias']
+ state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+
+ model.load_state_dict(state_dict, strict=strict)
+
+
+def extract_layer(model, layer):
+ layer = layer.split('.')
+ module = model
+ if hasattr(model, 'module') and layer[0] != 'module':
+ module = model.module
+ if not hasattr(model, 'module') and layer[0] == 'module':
+ layer = layer[1:]
+ for l in layer:
+ if hasattr(module, l):
+ if not l.isdigit():
+ module = getattr(module, l)
+ else:
+ module = module[int(l)]
+ else:
+ return module
+ return module
+
+
+def set_layer(model, layer, val):
+ layer = layer.split('.')
+ module = model
+ if hasattr(model, 'module') and layer[0] != 'module':
+ module = model.module
+ lst_index = 0
+ module2 = module
+ for l in layer:
+ if hasattr(module2, l):
+ if not l.isdigit():
+ module2 = getattr(module2, l)
+ else:
+ module2 = module2[int(l)]
+ lst_index += 1
+ lst_index -= 1
+ for l in layer[:lst_index]:
+ if not l.isdigit():
+ module = getattr(module, l)
+ else:
+ module = module[int(l)]
+ l = layer[lst_index]
+ setattr(module, l, val)
+
+
+def adapt_model_from_string(parent_module, model_string):
+ separator = '***'
+ state_dict = {}
+ lst_shape = model_string.split(separator)
+ for k in lst_shape:
+ k = k.split(':')
+ key = k[0]
+ shape = k[1][1:-1].split(',')
+ if shape[0] != '':
+ state_dict[key] = [int(i) for i in shape]
+
+ new_module = deepcopy(parent_module)
+ for n, m in parent_module.named_modules():
+ old_module = extract_layer(parent_module, n)
+ if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
+ if isinstance(old_module, Conv2dSame):
+ conv = Conv2dSame
+ else:
+ conv = nn.Conv2d
+ s = state_dict[n + '.weight']
+ in_channels = s[1]
+ out_channels = s[0]
+ g = 1
+ if old_module.groups > 1:
+ in_channels = out_channels
+ g = in_channels
+ new_conv = conv(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
+ bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
+ groups=g, stride=old_module.stride)
+ set_layer(new_module, n, new_conv)
+ if isinstance(old_module, nn.BatchNorm2d):
+ new_bn = nn.BatchNorm2d(
+ num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
+ affine=old_module.affine, track_running_stats=True)
+ set_layer(new_module, n, new_bn)
+ if isinstance(old_module, nn.Linear):
+ # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
+ num_features = state_dict[n + '.weight'][1]
+ new_fc = Linear(
+ in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
+ set_layer(new_module, n, new_fc)
+ if hasattr(new_module, 'num_features'):
+ new_module.num_features = num_features
+ new_module.eval()
+ parent_module.eval()
+
+ return new_module
+
+
+def adapt_model_from_file(parent_module, model_variant):
+ adapt_file = os.path.join(os.path.dirname(__file__), 'pruned', model_variant + '.txt')
+ with open(adapt_file, 'r') as f:
+ return adapt_model_from_string(parent_module, f.read().strip())
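+
+
+# Format sketch for the pruned '.txt' files read above (inferred from adapt_model_from_string,
+# not documented in this file): '***'-separated 'module.name:[shape]' entries, e.g.
+#
+#   conv1.weight:[24, 3, 7, 7]***bn1.weight:[24]***fc.weight:[1000, 1536]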
+
+
+def default_cfg_for_features(default_cfg):
+ default_cfg = deepcopy(default_cfg)
+ # remove default pretrained cfg fields that don't have much relevance for feature backbone
+ to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size?
+ for tr in to_remove:
+ default_cfg.pop(tr, None)
+ return default_cfg
+
+
+def overlay_external_default_cfg(default_cfg, kwargs):
+ """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
+ """
+ external_default_cfg = kwargs.pop('external_default_cfg', None)
+ if external_default_cfg:
+ default_cfg.pop('url', None) # url should come from external cfg
+ default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
+ default_cfg.update(external_default_cfg)
+
+
+def set_default_kwargs(kwargs, names, default_cfg):
+ for n in names:
+ # for legacy reasons, model __init__ args use img_size + in_chans as separate args while
+ # default_cfg has one input_size=(C, H, W) entry
+ if n == 'img_size':
+ input_size = default_cfg.get('input_size', None)
+ if input_size is not None:
+ assert len(input_size) == 3
+ kwargs.setdefault(n, input_size[-2:])
+ elif n == 'in_chans':
+ input_size = default_cfg.get('input_size', None)
+ if input_size is not None:
+ assert len(input_size) == 3
+ kwargs.setdefault(n, input_size[0])
+ else:
+ default_val = default_cfg.get(n, None)
+ if default_val is not None:
+ kwargs.setdefault(n, default_cfg[n])
+
+
+def filter_kwargs(kwargs, names):
+ if not kwargs or not names:
+ return
+ for n in names:
+ kwargs.pop(n, None)
+
+
+def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+ """ Update the default_cfg and kwargs before passing to model
+
+ FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+ could/should be replaced by an improved configuration mechanism
+
+ Args:
+ default_cfg: input default_cfg (updated in-place)
+ kwargs: keyword args passed to model build fn (updated in-place)
+ kwargs_filter: keyword arg keys that must be removed before model __init__
+ """
+ # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+ overlay_external_default_cfg(default_cfg, kwargs)
+ # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+ default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
+ if default_cfg.get('fixed_input_size', False):
+ # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
+ default_kwarg_names += ('img_size',)
+ set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
+ # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+ filter_kwargs(kwargs, names=kwargs_filter)
+
+
+def build_model_with_cfg(
+ model_cls: Callable,
+ variant: str,
+ pretrained: bool,
+ default_cfg: dict,
+ model_cfg: Optional[Any] = None,
+ feature_cfg: Optional[dict] = None,
+ pretrained_strict: bool = True,
+ pretrained_filter_fn: Optional[Callable] = None,
+ pretrained_custom_load: bool = False,
+ kwargs_filter: Optional[Tuple[str]] = None,
+ **kwargs):
+ """ Build model with specified default_cfg and optional model_cfg
+
+ This helper fn aids in the construction of a model including:
+ * handling default_cfg and associated pretrained weight loading
+ * passing through optional model_cfg for models with config based arch spec
+ * features_only model adaptation
+ * pruning config / model adaptation
+
+ Args:
+ model_cls (nn.Module): model class
+ variant (str): model variant name
+ pretrained (bool): load pretrained weights
+ default_cfg (dict): model's default pretrained/task config
+ model_cfg (Optional[Dict]): model's architecture config
+ feature_cfg (Optional[Dict]): feature extraction adapter config
+ pretrained_strict (bool): load pretrained weights strictly
+ pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
+ pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
+ kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+ **kwargs: model args passed through to model __init__
+ """
+ pruned = kwargs.pop('pruned', False)
+ features = False
+ feature_cfg = feature_cfg or {}
+ default_cfg = deepcopy(default_cfg) if default_cfg else {}
+ update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+ default_cfg.setdefault('architecture', variant)
+
+ # Setup for feature extraction wrapper done at end of this fn
+ if kwargs.pop('features_only', False):
+ features = True
+ feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
+ if 'out_indices' in kwargs:
+ feature_cfg['out_indices'] = kwargs.pop('out_indices')
+
+ # Build the model
+ model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
+ model.default_cfg = default_cfg
+
+ if pruned:
+ model = adapt_model_from_file(model, variant)
+
+ # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
+ num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
+ if pretrained:
+ if pretrained_custom_load:
+ load_custom_pretrained(model)
+ else:
+ load_pretrained(
+ model,
+ num_classes=num_classes_pretrained,
+ in_chans=kwargs.get('in_chans', 3),
+ filter_fn=pretrained_filter_fn,
+ strict=pretrained_strict)
+
+ # Wrap the model in a feature extraction module if enabled
+ if features:
+ feature_cls = FeatureListNet
+ if 'feature_cls' in feature_cfg:
+ feature_cls = feature_cfg.pop('feature_cls')
+ if isinstance(feature_cls, str):
+ feature_cls = feature_cls.lower()
+ if 'hook' in feature_cls:
+ feature_cls = FeatureHookNet
+ elif feature_cls == 'fx':
+ feature_cls = FeatureGraphNet
+ else:
+ assert False, f'Unknown feature class {feature_cls}'
+ model = feature_cls(model, **feature_cfg)
+ model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
+
+ return model
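+
+
+# Illustrative sketch of a typical model entrypoint built on this helper (mirroring the
+# factory functions elsewhere in this commit, e.g. _create_hrnet; `MyNet` and `my_cfgs`
+# are hypothetical):
+#
+#   def _create_mynet(variant, pretrained=False, **kwargs):
+#       return build_model_with_cfg(
+#           MyNet, variant, pretrained,
+#           default_cfg=my_cfgs[variant],
+#           feature_cfg=dict(out_indices=(1, 2, 3, 4)),
+#           **kwargs)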
+
+
+def model_parameters(model, exclude_head=False):
+ if exclude_head:
+ # FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
+ return [p for p in model.parameters()][:-2]
+ else:
+ return model.parameters()
+
+
+def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = '.'.join((name, child_name)) if name else child_name
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
+ if not depth_first and include_root:
+ yield name, module
+ for child_name, child_module in module.named_children():
+ child_name = '.'.join((name, child_name)) if name else child_name
+ yield from named_modules(
+ module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+ if depth_first and include_root:
+ yield name, module
diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py
new file mode 100644
index 0000000..c56964f
--- /dev/null
+++ b/timm/models/hrnet.py
@@ -0,0 +1,836 @@
+""" HRNet
+
+Copied from https://github.com/HRNet/HRNet-Image-Classification
+
+Original header:
+ Copyright (c) Microsoft
+ Licensed under the MIT License.
+ Written by Bin Xiao (Bin.Xiao@microsoft.com)
+ Modified by Ke Sun (sunk@mail.ustc.edu.cn)
+"""
+import logging
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .features import FeatureInfo
+from .helpers import build_model_with_cfg, default_cfg_for_features
+from .layers import create_classifier
+from .registry import register_model
+from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE
+
+_BN_MOMENTUM = 0.1
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv1', 'classifier': 'classifier',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'hrnet_w18_small': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v1-f460c6bc.pth'),
+ 'hrnet_w18_small_v2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnet_w18_small_v2-4c50a8cb.pth'),
+ 'hrnet_w18': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w18-8cb57bb9.pth'),
+ 'hrnet_w30': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w30-8d7f8dab.pth'),
+ 'hrnet_w32': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w32-90d8c5fb.pth'),
+ 'hrnet_w40': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w40-7cd397a4.pth'),
+ 'hrnet_w44': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w44-c9ac8c18.pth'),
+ 'hrnet_w48': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w48-abd2e6ab.pth'),
+ 'hrnet_w64': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-hrnet/hrnetv2_w64-b47cc881.pth'),
+}
+
+cfg_cls = dict(
+ hrnet_w18_small=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(1,),
+ NUM_CHANNELS=(32,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(2, 2),
+ NUM_CHANNELS=(16, 32),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(2, 2, 2),
+ NUM_CHANNELS=(16, 32, 64),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(2, 2, 2, 2),
+ NUM_CHANNELS=(16, 32, 64, 128),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w18_small_v2=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(2,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(2, 2),
+ NUM_CHANNELS=(18, 36),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(2, 2, 2),
+ NUM_CHANNELS=(18, 36, 72),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=2,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(2, 2, 2, 2),
+ NUM_CHANNELS=(18, 36, 72, 144),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w18=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(18, 36),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(18, 36, 72),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(18, 36, 72, 144),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w30=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(30, 60),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(30, 60, 120),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(30, 60, 120, 240),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w32=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(32, 64),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(32, 64, 128),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(32, 64, 128, 256),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w40=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(40, 80),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(40, 80, 160),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(40, 80, 160, 320),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w44=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(44, 88),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(44, 88, 176),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(44, 88, 176, 352),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w48=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(48, 96),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(48, 96, 192),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(48, 96, 192, 384),
+ FUSE_METHOD='SUM',
+ ),
+ ),
+
+ hrnet_w64=dict(
+ STEM_WIDTH=64,
+ STAGE1=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=1,
+ BLOCK='BOTTLENECK',
+ NUM_BLOCKS=(4,),
+ NUM_CHANNELS=(64,),
+ FUSE_METHOD='SUM',
+ ),
+ STAGE2=dict(
+ NUM_MODULES=1,
+ NUM_BRANCHES=2,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4),
+ NUM_CHANNELS=(64, 128),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE3=dict(
+ NUM_MODULES=4,
+ NUM_BRANCHES=3,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4),
+ NUM_CHANNELS=(64, 128, 256),
+ FUSE_METHOD='SUM'
+ ),
+ STAGE4=dict(
+ NUM_MODULES=3,
+ NUM_BRANCHES=4,
+ BLOCK='BASIC',
+ NUM_BLOCKS=(4, 4, 4, 4),
+ NUM_CHANNELS=(64, 128, 256, 512),
+ FUSE_METHOD='SUM',
+ ),
+ )
+)
+
+
+class HighResolutionModule(nn.Module):
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
+ num_channels, fuse_method, multi_scale_output=True):
+ super(HighResolutionModule, self).__init__()
+ self._check_branches(
+ num_branches, blocks, num_blocks, num_inchannels, num_channels)
+
+ self.num_inchannels = num_inchannels
+ self.fuse_method = fuse_method
+ self.num_branches = num_branches
+
+ self.multi_scale_output = multi_scale_output
+
+ self.branches = self._make_branches(
+ num_branches, blocks, num_blocks, num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.fuse_act = nn.ReLU(False)
+
+ def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
+ error_msg = ''
+ if num_branches != len(num_blocks):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(num_branches, len(num_blocks))
+ elif num_branches != len(num_channels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(num_branches, len(num_channels))
+ elif num_branches != len(num_inchannels):
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(num_branches, len(num_inchannels))
+ if error_msg:
+ _logger.error(error_msg)
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
+ downsample = None
+ if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.num_inchannels[branch_index], num_channels[branch_index] * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=_BN_MOMENTUM),
+ )
+
+ layers = [block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)]
+ self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+ for i in range(num_branches):
+ branches.append(self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return nn.Identity()
+
+ num_branches = self.num_branches
+ num_inchannels = self.num_inchannels
+ fuse_layers = []
+ for i in range(num_branches if self.multi_scale_output else 1):
+ fuse_layer = []
+ for j in range(num_branches):
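+ # j > i: a lower-resolution branch is brought up to branch i via 1x1 conv (+BN) and nearest upsample;
+ # j == i: identity; j < i: a higher-resolution branch is downsampled by (i - j) stride-2 3x3 convs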
+ if j > i:
+ fuse_layer.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
+ nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM),
+ nn.Upsample(scale_factor=2 ** (j - i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(nn.Identity())
+ else:
+ conv3x3s = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ num_outchannels_conv3x3 = num_inchannels[i]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM)))
+ else:
+ num_outchannels_conv3x3 = num_inchannels[j]
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(num_outchannels_conv3x3, momentum=_BN_MOMENTUM),
+ nn.ReLU(False)))
+ fuse_layer.append(nn.Sequential(*conv3x3s))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def get_num_inchannels(self):
+ return self.num_inchannels
+
+ def forward(self, x: List[torch.Tensor]):
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i, branch in enumerate(self.branches):
+ x[i] = branch(x[i])
+
+ x_fuse = []
+ for i, fuse_outer in enumerate(self.fuse_layers):
+ y = x[0] if i == 0 else fuse_outer[0](x[0])
+ for j in range(1, self.num_branches):
+ if i == j:
+ y = y + x[j]
+ else:
+ y = y + fuse_outer[j](x[j])
+ x_fuse.append(self.fuse_act(y))
+
+ return x_fuse
+
+
+blocks_dict = {
+ 'BASIC': BasicBlock,
+ 'BOTTLENECK': Bottleneck
+}
+
+
+class HighResolutionNet(nn.Module):
+
+ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0, head='classification'):
+ super(HighResolutionNet, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+
+ stem_width = cfg['STEM_WIDTH']
+ self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(stem_width, momentum=_BN_MOMENTUM)
+ self.act1 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(stem_width, 64, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(64, momentum=_BN_MOMENTUM)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.stage1_cfg = cfg['STAGE1']
+ num_channels = self.stage1_cfg['NUM_CHANNELS'][0]
+ block = blocks_dict[self.stage1_cfg['BLOCK']]
+ num_blocks = self.stage1_cfg['NUM_BLOCKS'][0]
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+ stage1_out_channel = block.expansion * num_channels
+
+ self.stage2_cfg = cfg['STAGE2']
+ num_channels = self.stage2_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage2_cfg['BLOCK']]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition1 = self._make_transition_layer([stage1_out_channel], num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels)
+
+ self.stage3_cfg = cfg['STAGE3']
+ num_channels = self.stage3_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage3_cfg['BLOCK']]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels)
+
+ self.stage4_cfg = cfg['STAGE4']
+ num_channels = self.stage4_cfg['NUM_CHANNELS']
+ block = blocks_dict[self.stage4_cfg['BLOCK']]
+ num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))]
+ self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True)
+
+ self.head = head
+ self.head_channels = None # set if _make_head called
+ if head == 'classification':
+ # Classification Head
+ self.num_features = 2048
+ self.incre_modules, self.downsamp_modules, self.final_layer = self._make_head(pre_stage_channels)
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+ elif head == 'incre':
+ self.num_features = 2048
+ self.incre_modules, _, _ = self._make_head(pre_stage_channels, True)
+ else:
+ self.incre_modules = None
+ self.num_features = 256
+
+ curr_stride = 2
+ # module names aren't actually valid here, hook or FeatureNet based extraction would not work
+ self.feature_info = [dict(num_chs=64, reduction=curr_stride, module='stem')]
+ for i, c in enumerate(self.head_channels if self.head_channels else num_channels):
+ curr_stride *= 2
+ c = c * 4 if self.head_channels else c # head block expansion factor of 4
+ self.feature_info += [dict(num_chs=c, reduction=curr_stride, module=f'stage{i + 1}')]
+
+ self.init_weights()
+
+ def _make_head(self, pre_stage_channels, incre_only=False):
+ head_block = Bottleneck
+ self.head_channels = [32, 64, 128, 256]
+
+ # Increasing the #channels on each resolution
+ # from C, 2C, 4C, 8C to 128, 256, 512, 1024
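+ # e.g. for hrnet_w18 (C=18): (18, 36, 72, 144) -> Bottleneck(32, 64, 128, 256) * expansion 4
+ # = (128, 256, 512, 1024); the final 1x1 conv then maps 1024 -> self.num_features = 2048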
+ incre_modules = []
+ for i, channels in enumerate(pre_stage_channels):
+ incre_modules.append(self._make_layer(head_block, channels, self.head_channels[i], 1, stride=1))
+ incre_modules = nn.ModuleList(incre_modules)
+ if incre_only:
+ return incre_modules, None, None
+
+ # downsampling modules
+ downsamp_modules = []
+ for i in range(len(pre_stage_channels) - 1):
+ in_channels = self.head_channels[i] * head_block.expansion
+ out_channels = self.head_channels[i + 1] * head_block.expansion
+ downsamp_module = nn.Sequential(
+ nn.Conv2d(
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=2, padding=1),
+ nn.BatchNorm2d(out_channels, momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+ downsamp_modules.append(downsamp_module)
+ downsamp_modules = nn.ModuleList(downsamp_modules)
+
+ final_layer = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.head_channels[3] * head_block.expansion,
+ out_channels=self.num_features, kernel_size=1, stride=1, padding=0
+ ),
+ nn.BatchNorm2d(self.num_features, momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True)
+ )
+
+ return incre_modules, downsamp_modules, final_layer
+
+ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(nn.Sequential(
+ nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
+ nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(nn.Identity())
+ else:
+ conv3x3s = []
+ for j in range(i + 1 - num_branches_pre):
+ inchannels = num_channels_pre_layer[-1]
+ outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
+ conv3x3s.append(nn.Sequential(
+ nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
+ nn.BatchNorm2d(outchannels, momentum=_BN_MOMENTUM),
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv3x3s))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion, momentum=_BN_MOMENTUM),
+ )
+
+ layers = [block(inplanes, planes, stride, downsample)]
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
+ num_modules = layer_config['NUM_MODULES']
+ num_branches = layer_config['NUM_BRANCHES']
+ num_blocks = layer_config['NUM_BLOCKS']
+ num_channels = layer_config['NUM_CHANNELS']
+ block = blocks_dict[layer_config['BLOCK']]
+ fuse_method = layer_config['FUSE_METHOD']
+
+ modules = []
+ for i in range(num_modules):
+ # multi_scale_output is only used by the last module
+ reset_multi_scale_output = multi_scale_output or i < num_modules - 1
+ modules.append(HighResolutionModule(
+ num_branches, block, num_blocks, num_inchannels, num_channels, fuse_method, reset_multi_scale_output)
+ )
+ num_inchannels = modules[-1].get_num_inchannels()
+
+ return nn.Sequential(*modules), num_inchannels
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def get_classifier(self):
+ return self.classifier
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.classifier = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def stages(self, x) -> List[torch.Tensor]:
+ x = self.layer1(x)
+
+ xl = [t(x) for i, t in enumerate(self.transition1)]
+ yl = self.stage2(xl)
+
+ xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition2)]
+ yl = self.stage3(xl)
+
+ xl = [t(yl[-1]) if not isinstance(t, nn.Identity) else yl[i] for i, t in enumerate(self.transition3)]
+ yl = self.stage4(xl)
+ return yl
+
+ def forward_features(self, x):
+ # Stem
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ # Stages
+ yl = self.stages(x)
+
+ # Classification Head
+ y = self.incre_modules[0](yl[0])
+ for i, down in enumerate(self.downsamp_modules):
+ y = self.incre_modules[i + 1](yl[i + 1]) + down(y)
+ y = self.final_layer(y)
+ return y
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.classifier(x)
+ return x
+
+
+class HighResolutionNetFeatures(HighResolutionNet):
+ """HighResolutionNet feature extraction
+
+    The design of HRNet makes it easy to grab feature maps; this class provides a simple wrapper to do so.
+    Using the generic FeatureNet helpers would be more complicated.
+
+    `feature_location='incre'` returns increased-channel-count features taken from part of the
+    classification head. If `feature_location=''`, the default HRNet features are returned. The first
+    stem conv is used for the stride-2 features.
+ """
+
+ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0.0,
+ feature_location='incre', out_indices=(0, 1, 2, 3, 4)):
+ assert feature_location in ('incre', '')
+ super(HighResolutionNetFeatures, self).__init__(
+ cfg, in_chans=in_chans, num_classes=num_classes, global_pool=global_pool,
+ drop_rate=drop_rate, head=feature_location)
+ self.feature_info = FeatureInfo(self.feature_info, out_indices)
+ self._out_idx = {i for i in out_indices}
+
+ def forward_features(self, x):
+ assert False, 'Not supported'
+
+    def forward(self, x) -> List[torch.Tensor]:
+ out = []
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ if 0 in self._out_idx:
+ out.append(x)
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+ x = self.stages(x)
+ if self.incre_modules is not None:
+ x = [incre(f) for f, incre in zip(x, self.incre_modules)]
+ for i, f in enumerate(x):
+ if i + 1 in self._out_idx:
+ out.append(f)
+ return out
+
+
+def _create_hrnet(variant, pretrained, **model_kwargs):
+ model_cls = HighResolutionNet
+ features_only = False
+ kwargs_filter = None
+ if model_kwargs.pop('features_only', False):
+ model_cls = HighResolutionNetFeatures
+ kwargs_filter = ('num_classes', 'global_pool')
+ features_only = True
+ model = build_model_with_cfg(
+ model_cls, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=cfg_cls[variant],
+ pretrained_strict=not features_only,
+ kwargs_filter=kwargs_filter,
+ **model_kwargs)
+ if features_only:
+ model.default_cfg = default_cfg_for_features(model.default_cfg)
+ return model
+
+
+@register_model
+def hrnet_w18_small(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w18_small', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w18_small_v2(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w18_small_v2', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w18(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w18', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w30(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w30', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w32(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w32', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w40(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w40', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w44(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w44', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w48(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w48', pretrained, **kwargs)
+
+
+@register_model
+def hrnet_w64(pretrained=True, **kwargs):
+ return _create_hrnet('hrnet_w64', pretrained, **kwargs)
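+
+
+# A minimal, hedged smoke-test sketch (assumes this file is run directly and that no
+# pretrained download is wanted, hence pretrained=False). It shows how features_only
+# routes creation through HighResolutionNetFeatures and how feature_info describes
+# the returned pyramid; the 224x224 input size is only an example.
+if __name__ == '__main__':
+    backbone = hrnet_w18_small(pretrained=False, features_only=True)
+    print(backbone.feature_info.channels())      # per-stage channel counts
+    feats = backbone(torch.randn(1, 3, 224, 224))
+    for f in feats:
+        print(tuple(f.shape))                    # stride-2 .. stride-32 feature maps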
diff --git a/timm/models/hub.py b/timm/models/hub.py
new file mode 100644
index 0000000..65e7ba9
--- /dev/null
+++ b/timm/models/hub.py
@@ -0,0 +1,171 @@
+import json
+import logging
+import os
+from functools import partial
+from pathlib import Path
+from typing import Union
+
+import torch
+from torch.hub import HASH_REGEX, download_url_to_file, urlparse
+try:
+ from torch.hub import get_dir
+except ImportError:
+ from torch.hub import _get_torch_home as get_dir
+
+from timm import __version__
+try:
+ from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
+ cached_download = partial(cached_download, library_name="timm", library_version=__version__)
+ _has_hf_hub = True
+except ImportError:
+ cached_download = None
+ _has_hf_hub = False
+
+_logger = logging.getLogger(__name__)
+
+
+def get_cache_dir(child_dir=''):
+ """
+ Returns the location of the directory where models are cached (and creates it if necessary).
+ """
+ # Issue warning to move data if old env is set
+ if os.getenv('TORCH_MODEL_ZOO'):
+ _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+ hub_dir = get_dir()
+ child_dir = () if not child_dir else (child_dir,)
+ model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
+ os.makedirs(model_dir, exist_ok=True)
+ return model_dir
+
+
+def download_cached_file(url, check_hash=True, progress=False):
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ cached_file = os.path.join(get_cache_dir(), filename)
+ if not os.path.exists(cached_file):
+ _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+ hash_prefix = None
+ if check_hash:
+ r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
+ hash_prefix = r.group(1) if r else None
+ download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+ return cached_file
+
+
+def has_hf_hub(necessary=False):
+ if not _has_hf_hub and necessary:
+ # if no HF Hub module installed and it is necessary to continue, raise error
+ raise RuntimeError(
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
+ return _has_hf_hub
+
+
+def hf_split(hf_id):
+ rev_split = hf_id.split('@')
+ assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
+ hf_model_id = rev_split[0]
+ hf_revision = rev_split[-1] if len(rev_split) > 1 else None
+ return hf_model_id, hf_revision
+
+
+def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+
+def _download_from_hf(model_id: str, filename: str):
+ hf_model_id, hf_revision = hf_split(model_id)
+ url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
+ return cached_download(url, cache_dir=get_cache_dir('hf'))
+
+
+def load_model_config_from_hf(model_id: str):
+ assert has_hf_hub(True)
+ cached_file = _download_from_hf(model_id, 'config.json')
+ default_cfg = load_cfg_from_json(cached_file)
+ default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation
+ model_name = default_cfg.get('architecture')
+ return default_cfg, model_name
+
+
+def load_state_dict_from_hf(model_id: str):
+ assert has_hf_hub(True)
+ cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
+ state_dict = torch.load(cached_file, map_location='cpu')
+ return state_dict
+
+
+def save_for_hf(model, save_directory, model_config=None):
+ assert has_hf_hub(True)
+ model_config = model_config or {}
+ save_directory = Path(save_directory)
+ save_directory.mkdir(exist_ok=True, parents=True)
+
+ weights_path = save_directory / 'pytorch_model.bin'
+ torch.save(model.state_dict(), weights_path)
+
+ config_path = save_directory / 'config.json'
+ hf_config = model.default_cfg
+ hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
+ hf_config['num_features'] = model_config.pop('num_features', model.num_features)
+ hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
+ hf_config.update(model_config)
+
+ with config_path.open('w') as f:
+ json.dump(hf_config, f, indent=2)
+
+
+def push_to_hf_hub(
+ model,
+ local_dir,
+ repo_namespace_or_url=None,
+ commit_message='Add model',
+ use_auth_token=True,
+ git_email=None,
+ git_user=None,
+ revision=None,
+ model_config=None,
+):
+ if repo_namespace_or_url:
+ repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
+ else:
+ if isinstance(use_auth_token, str):
+ token = use_auth_token
+ else:
+ token = HfFolder.get_token()
+
+ if token is None:
+ raise ValueError(
+ "You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
+ "entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
+ "token as the `use_auth_token` argument."
+ )
+
+ repo_owner = HfApi().whoami(token)['name']
+ repo_name = Path(local_dir).name
+
+ repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
+
+ repo = Repository(
+ local_dir,
+ clone_from=repo_url,
+ use_auth_token=use_auth_token,
+ git_user=git_user,
+ git_email=git_email,
+ revision=revision,
+ )
+
+ # Prepare a default model card that includes the necessary tags to enable inference.
+    readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_name: timm\n---\n# Model card for {repo_name}'
+ with repo.commit(commit_message):
+ # Save model weights and config.
+ save_for_hf(model, repo.local_dir, model_config=model_config)
+
+ # Save a model card if it doesn't exist.
+ readme_path = Path(repo.local_dir) / 'README.md'
+ if not readme_path.exists():
+ readme_path.write_text(readme_text)
+
+ return repo.git_remote_url()
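+
+
+# A small, hedged usage sketch: hf_split is pure string handling and runs offline,
+# while download_cached_file needs network access, so it is left commented out.
+# The hub id and URL below are placeholders, not real artifacts.
+if __name__ == '__main__':
+    print(hf_split('some-org/some-model@main'))   # ('some-org/some-model', 'main')
+    print(hf_split('some-org/some-model'))        # ('some-org/some-model', None)
+    # cached_path = download_cached_file('https://example.com/weights.pth', check_hash=False)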
diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py
new file mode 100644
index 0000000..7167284
--- /dev/null
+++ b/timm/models/inception_resnet_v2.py
@@ -0,0 +1,358 @@
+""" Pytorch Inception-Resnet-V2 implementation
+Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
+based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import build_model_with_cfg
+from .layers import create_classifier
+from .registry import register_model
+
+__all__ = ['InceptionResnetV2']
+
+default_cfgs = {
+ # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
+ 'inception_resnet_v2': {
+ 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
+ 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+ 'crop_pct': 0.8975, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
+ 'label_offset': 1, # 1001 classes in pretrained weights
+ },
+ # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
+ 'ens_adv_inception_resnet_v2': {
+ 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
+ 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+ 'crop_pct': 0.8975, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
+ 'label_offset': 1, # 1001 classes in pretrained weights
+ }
+}
+
+
+class BasicConv2d(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(
+ in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm2d(out_planes, eps=.001)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class Mixed_5b(nn.Module):
+ def __init__(self):
+ super(Mixed_5b, self).__init__()
+
+ self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(192, 48, kernel_size=1, stride=1),
+ BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(192, 64, kernel_size=1, stride=1),
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.branch3 = nn.Sequential(
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+ BasicConv2d(192, 64, kernel_size=1, stride=1)
+ )
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ x3 = self.branch3(x)
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class Block35(nn.Module):
+ def __init__(self, scale=1.0):
+ super(Block35, self).__init__()
+
+ self.scale = scale
+
+ self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
+ BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(320, 32, kernel_size=1, stride=1),
+ BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ out = torch.cat((x0, x1, x2), 1)
+ out = self.conv2d(out)
+ out = out * self.scale + x
+ out = self.relu(out)
+ return out
+
+
+class Mixed_6a(nn.Module):
+ def __init__(self):
+ super(Mixed_6a, self).__init__()
+
+ self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(320, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
+ )
+
+ self.branch2 = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ out = torch.cat((x0, x1, x2), 1)
+ return out
+
+
+class Block17(nn.Module):
+ def __init__(self, scale=1.0):
+ super(Block17, self).__init__()
+
+ self.scale = scale
+
+ self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(1088, 128, kernel_size=1, stride=1),
+ BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+ BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0))
+ )
+
+ self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
+ self.relu = nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ out = torch.cat((x0, x1), 1)
+ out = self.conv2d(out)
+ out = out * self.scale + x
+ out = self.relu(out)
+ return out
+
+
+class Mixed_7a(nn.Module):
+ def __init__(self):
+ super(Mixed_7a, self).__init__()
+
+ self.branch0 = nn.Sequential(
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 384, kernel_size=3, stride=2)
+ )
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 288, kernel_size=3, stride=2)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(1088, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(288, 320, kernel_size=3, stride=2)
+ )
+
+ self.branch3 = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ x3 = self.branch3(x)
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class Block8(nn.Module):
+
+ def __init__(self, scale=1.0, no_relu=False):
+ super(Block8, self).__init__()
+
+ self.scale = scale
+
+ self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(2080, 192, kernel_size=1, stride=1),
+ BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
+ BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+ )
+
+ self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
+ self.relu = None if no_relu else nn.ReLU(inplace=False)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ out = torch.cat((x0, x1), 1)
+ out = self.conv2d(out)
+ out = out * self.scale + x
+ if self.relu is not None:
+ out = self.relu(out)
+ return out
+
+
+class InceptionResnetV2(nn.Module):
+ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
+ super(InceptionResnetV2, self).__init__()
+ self.drop_rate = drop_rate
+ self.num_classes = num_classes
+ self.num_features = 1536
+ assert output_stride == 32
+
+ self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
+ self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
+ self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
+ self.feature_info = [dict(num_chs=64, reduction=2, module='conv2d_2b')]
+
+ self.maxpool_3a = nn.MaxPool2d(3, stride=2)
+ self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
+ self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
+ self.feature_info += [dict(num_chs=192, reduction=4, module='conv2d_4a')]
+
+ self.maxpool_5a = nn.MaxPool2d(3, stride=2)
+ self.mixed_5b = Mixed_5b()
+        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
+ self.feature_info += [dict(num_chs=320, reduction=8, module='repeat')]
+
+ self.mixed_6a = Mixed_6a()
+        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
+ self.feature_info += [dict(num_chs=1088, reduction=16, module='repeat_1')]
+
+ self.mixed_7a = Mixed_7a()
+        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])
+ self.block8 = Block8(no_relu=True)
+ self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
+ self.feature_info += [dict(num_chs=self.num_features, reduction=32, module='conv2d_7b')]
+
+ self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def get_classifier(self):
+ return self.classif
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.classif = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x = self.conv2d_1a(x)
+ x = self.conv2d_2a(x)
+ x = self.conv2d_2b(x)
+ x = self.maxpool_3a(x)
+ x = self.conv2d_3b(x)
+ x = self.conv2d_4a(x)
+ x = self.maxpool_5a(x)
+ x = self.mixed_5b(x)
+ x = self.repeat(x)
+ x = self.mixed_6a(x)
+ x = self.repeat_1(x)
+ x = self.mixed_7a(x)
+ x = self.repeat_2(x)
+ x = self.block8(x)
+ x = self.conv2d_7b(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.classif(x)
+ return x
+
+
+def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ InceptionResnetV2, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def inception_resnet_v2(pretrained=False, **kwargs):
+ r"""InceptionResnetV2 model architecture from the
+ `"InceptionV4, Inception-ResNet..." ` paper.
+ """
+ return _create_inception_resnet_v2('inception_resnet_v2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ens_adv_inception_resnet_v2(pretrained=False, **kwargs):
+ r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
+ As per https://arxiv.org/abs/1705.07204 and
+ https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
+ """
+ return _create_inception_resnet_v2('ens_adv_inception_resnet_v2', pretrained=pretrained, **kwargs)
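+
+
+# A hedged sketch of the shared classifier-head API: the network is built directly
+# with random weights (no pretrained download), and reset_classifier swaps the final
+# Linear layer for a new class count while keeping the backbone intact.
+if __name__ == '__main__':
+    net = InceptionResnetV2(num_classes=1000)
+    net.eval()
+    print(net(torch.randn(1, 3, 299, 299)).shape)    # torch.Size([1, 1000])
+    net.reset_classifier(num_classes=10)
+    print(net(torch.randn(1, 3, 299, 299)).shape)    # torch.Size([1, 10])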
diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py
new file mode 100644
index 0000000..cbb1107
--- /dev/null
+++ b/timm/models/inception_v3.py
@@ -0,0 +1,470 @@
+""" Inception-V3
+
+Originally from torchvision Inception3 model
+Licensed BSD 3-Clause, https://github.com/pytorch/vision/blob/master/LICENSE
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import build_model_with_cfg
+from .registry import register_model
+from .layers import trunc_normal_, create_classifier, Linear
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'Conv2d_1a_3x3.conv', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # original PyTorch weights, ported from Tensorflow but modified
+ 'inception_v3': _cfg(
+ url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
+ has_aux=True), # checkpoint has aux logit layer weights
+ # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
+ 'tf_inception_v3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
+ num_classes=1000, has_aux=False, label_offset=1),
+ # my port of Tensorflow adversarially trained Inception V3 from
+ # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
+ 'adv_inception_v3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
+ num_classes=1000, has_aux=False, label_offset=1),
+ # from gluon pretrained models, best performing in terms of accuracy/loss metrics
+ # https://gluon-cv.mxnet.io/model_zoo/classification.html
+ 'gluon_inception_v3': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
+ mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
+ std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
+ has_aux=False,
+ )
+}
+
+
+class InceptionA(nn.Module):
+
+ def __init__(self, in_channels, pool_features, conv_block=None):
+ super(InceptionA, self).__init__()
+ if conv_block is None:
+ conv_block = BasicConv2d
+ self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
+
+ self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
+ self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
+
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
+
+ self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
+
+ def _forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return outputs
+
+ def forward(self, x):
+ outputs = self._forward(x)
+ return torch.cat(outputs, 1)
+
+
+class InceptionB(nn.Module):
+
+ def __init__(self, in_channels, conv_block=None):
+ super(InceptionB, self).__init__()
+ if conv_block is None:
+ conv_block = BasicConv2d
+ self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
+
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
+
+ def _forward(self, x):
+ branch3x3 = self.branch3x3(x)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
+ return outputs
+
+ def forward(self, x):
+ outputs = self._forward(x)
+ return torch.cat(outputs, 1)
+
+
+class InceptionC(nn.Module):
+
+ def __init__(self, in_channels, channels_7x7, conv_block=None):
+ super(InceptionC, self).__init__()
+ if conv_block is None:
+ conv_block = BasicConv2d
+ self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
+
+ c7 = channels_7x7
+ self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
+ self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+ self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
+
+ self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
+ self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+ self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+ self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+ self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
+
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+ def _forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return outputs
+
+ def forward(self, x):
+ outputs = self._forward(x)
+ return torch.cat(outputs, 1)
+
+
+class InceptionD(nn.Module):
+
+ def __init__(self, in_channels, conv_block=None):
+ super(InceptionD, self).__init__()
+ if conv_block is None:
+ conv_block = BasicConv2d
+ self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
+ self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
+
+ self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
+ self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
+ self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
+ self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
+
+ def _forward(self, x):
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = self.branch3x3_2(branch3x3)
+
+ branch7x7x3 = self.branch7x7x3_1(x)
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+ outputs = [branch3x3, branch7x7x3, branch_pool]
+ return outputs
+
+ def forward(self, x):
+ outputs = self._forward(x)
+ return torch.cat(outputs, 1)
+
+
+class InceptionE(nn.Module):
+
+ def __init__(self, in_channels, conv_block=None):
+ super(InceptionE, self).__init__()
+ if conv_block is None:
+ conv_block = BasicConv2d
+ self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
+
+ self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
+ self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+ self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+ self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
+ self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
+ self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+ self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+
+ def _forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return outputs
+
+ def forward(self, x):
+ outputs = self._forward(x)
+ return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+
+ def __init__(self, in_channels, num_classes, conv_block=None):
+ super(InceptionAux, self).__init__()
+ if conv_block is None:
+ conv_block = BasicConv2d
+ self.conv0 = conv_block(in_channels, 128, kernel_size=1)
+ self.conv1 = conv_block(128, 768, kernel_size=5)
+ self.conv1.stddev = 0.01
+ self.fc = Linear(768, num_classes)
+ self.fc.stddev = 0.001
+
+ def forward(self, x):
+ # N x 768 x 17 x 17
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
+ # N x 768 x 5 x 5
+ x = self.conv0(x)
+ # N x 128 x 5 x 5
+ x = self.conv1(x)
+ # N x 768 x 1 x 1
+ # Adaptive average pooling
+ x = F.adaptive_avg_pool2d(x, (1, 1))
+ # N x 768 x 1 x 1
+ x = torch.flatten(x, 1)
+ # N x 768
+ x = self.fc(x)
+ # N x 1000
+ return x
+
+
+class BasicConv2d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, **kwargs):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return F.relu(x, inplace=True)
+
+
+class InceptionV3(nn.Module):
+ """Inception-V3 with no AuxLogits
+    FIXME two class defs are redundant, but less screwing around with torchscript fussiness and inconsistent returns
+ """
+
+ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=False):
+ super(InceptionV3, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ self.aux_logits = aux_logits
+
+ self.Conv2d_1a_3x3 = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
+ self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
+ self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
+ self.Pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
+ self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
+ self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
+ self.Pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
+ self.Mixed_5b = InceptionA(192, pool_features=32)
+ self.Mixed_5c = InceptionA(256, pool_features=64)
+ self.Mixed_5d = InceptionA(288, pool_features=64)
+ self.Mixed_6a = InceptionB(288)
+ self.Mixed_6b = InceptionC(768, channels_7x7=128)
+ self.Mixed_6c = InceptionC(768, channels_7x7=160)
+ self.Mixed_6d = InceptionC(768, channels_7x7=160)
+ self.Mixed_6e = InceptionC(768, channels_7x7=192)
+ if aux_logits:
+ self.AuxLogits = InceptionAux(768, num_classes)
+ else:
+ self.AuxLogits = None
+ self.Mixed_7a = InceptionD(768)
+ self.Mixed_7b = InceptionE(1280)
+ self.Mixed_7c = InceptionE(2048)
+ self.feature_info = [
+ dict(num_chs=64, reduction=2, module='Conv2d_2b_3x3'),
+ dict(num_chs=192, reduction=4, module='Conv2d_4a_3x3'),
+ dict(num_chs=288, reduction=8, module='Mixed_5d'),
+ dict(num_chs=768, reduction=16, module='Mixed_6e'),
+ dict(num_chs=2048, reduction=32, module='Mixed_7c'),
+ ]
+
+ self.num_features = 2048
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+ stddev = m.stddev if hasattr(m, 'stddev') else 0.1
+ trunc_normal_(m.weight, std=stddev)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward_preaux(self, x):
+ # N x 3 x 299 x 299
+ x = self.Conv2d_1a_3x3(x)
+ # N x 32 x 149 x 149
+ x = self.Conv2d_2a_3x3(x)
+ # N x 32 x 147 x 147
+ x = self.Conv2d_2b_3x3(x)
+ # N x 64 x 147 x 147
+ x = self.Pool1(x)
+ # N x 64 x 73 x 73
+ x = self.Conv2d_3b_1x1(x)
+ # N x 80 x 73 x 73
+ x = self.Conv2d_4a_3x3(x)
+ # N x 192 x 71 x 71
+ x = self.Pool2(x)
+ # N x 192 x 35 x 35
+ x = self.Mixed_5b(x)
+ # N x 256 x 35 x 35
+ x = self.Mixed_5c(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_5d(x)
+ # N x 288 x 35 x 35
+ x = self.Mixed_6a(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6b(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6c(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6d(x)
+ # N x 768 x 17 x 17
+ x = self.Mixed_6e(x)
+ # N x 768 x 17 x 17
+ return x
+
+ def forward_postaux(self, x):
+ x = self.Mixed_7a(x)
+ # N x 1280 x 8 x 8
+ x = self.Mixed_7b(x)
+ # N x 2048 x 8 x 8
+ x = self.Mixed_7c(x)
+ # N x 2048 x 8 x 8
+ return x
+
+ def forward_features(self, x):
+ x = self.forward_preaux(x)
+ x = self.forward_postaux(x)
+ return x
+
+ def get_classifier(self):
+ return self.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.fc(x)
+ return x
+
+
+class InceptionV3Aux(InceptionV3):
+ """InceptionV3 with AuxLogits
+ """
+
+ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', aux_logits=True):
+ super(InceptionV3Aux, self).__init__(
+ num_classes, in_chans, drop_rate, global_pool, aux_logits)
+
+ def forward_features(self, x):
+ x = self.forward_preaux(x)
+ aux = self.AuxLogits(x) if self.training else None
+ x = self.forward_postaux(x)
+ return x, aux
+
+ def forward(self, x):
+ x, aux = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.fc(x)
+ return x, aux
+
+
+def _create_inception_v3(variant, pretrained=False, **kwargs):
+ default_cfg = default_cfgs[variant]
+ aux_logits = kwargs.pop('aux_logits', False)
+ if aux_logits:
+ assert not kwargs.pop('features_only', False)
+ model_cls = InceptionV3Aux
+ load_strict = default_cfg['has_aux']
+ else:
+ model_cls = InceptionV3
+ load_strict = not default_cfg['has_aux']
+ return build_model_with_cfg(
+ model_cls, variant, pretrained,
+ default_cfg=default_cfg,
+ pretrained_strict=load_strict,
+ **kwargs)
+
+
+@register_model
+def inception_v3(pretrained=False, **kwargs):
+ # original PyTorch weights, ported from Tensorflow but modified
+ model = _create_inception_v3('inception_v3', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_inception_v3(pretrained=False, **kwargs):
+ # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
+ model = _create_inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def adv_inception_v3(pretrained=False, **kwargs):
+ # my port of Tensorflow adversarially trained Inception V3 from
+ # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
+ model = _create_inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def gluon_inception_v3(pretrained=False, **kwargs):
+ # from gluon pretrained models, best performing in terms of accuracy/loss metrics
+ # https://gluon-cv.mxnet.io/model_zoo/classification.html
+ model = _create_inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
+ return model
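+
+
+# A hedged sketch of the aux-logits contract: building InceptionV3Aux directly (random
+# weights, no pretrained download) returns a (logits, aux) pair, where aux comes from
+# the 17x17 Mixed_6e tap and is only computed in training mode (None under eval()).
+if __name__ == '__main__':
+    net = InceptionV3Aux(num_classes=1000)
+    net.train()
+    logits, aux = net(torch.randn(2, 3, 299, 299))
+    print(logits.shape, aux.shape)    # torch.Size([2, 1000]) torch.Size([2, 1000])
+    net.eval()
+    _, aux = net(torch.randn(2, 3, 299, 299))
+    print(aux is None)                # True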
diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py
new file mode 100644
index 0000000..cc899e1
--- /dev/null
+++ b/timm/models/inception_v4.py
@@ -0,0 +1,316 @@
+""" Pytorch Inception-V4 implementation
+Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
+based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import build_model_with_cfg
+from .layers import create_classifier
+from .registry import register_model
+
+__all__ = ['InceptionV4']
+
+default_cfgs = {
+ 'inception_v4': {
+ 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
+ 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'features.0.conv', 'classifier': 'last_linear',
+ 'label_offset': 1, # 1001 classes in pretrained weights
+ }
+}
+
+
+class BasicConv2d(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
+ super(BasicConv2d, self).__init__()
+ self.conv = nn.Conv2d(
+ in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class Mixed3a(nn.Module):
+ def __init__(self):
+ super(Mixed3a, self).__init__()
+ self.maxpool = nn.MaxPool2d(3, stride=2)
+ self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
+
+ def forward(self, x):
+ x0 = self.maxpool(x)
+ x1 = self.conv(x)
+ out = torch.cat((x0, x1), 1)
+ return out
+
+
+class Mixed4a(nn.Module):
+ def __init__(self):
+ super(Mixed4a, self).__init__()
+
+ self.branch0 = nn.Sequential(
+ BasicConv2d(160, 64, kernel_size=1, stride=1),
+ BasicConv2d(64, 96, kernel_size=3, stride=1)
+ )
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(160, 64, kernel_size=1, stride=1),
+ BasicConv2d(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+ BasicConv2d(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+ BasicConv2d(64, 96, kernel_size=(3, 3), stride=1)
+ )
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ out = torch.cat((x0, x1), 1)
+ return out
+
+
+class Mixed5a(nn.Module):
+ def __init__(self):
+ super(Mixed5a, self).__init__()
+ self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
+ self.maxpool = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.conv(x)
+ x1 = self.maxpool(x)
+ out = torch.cat((x0, x1), 1)
+ return out
+
+
+class InceptionA(nn.Module):
+ def __init__(self):
+ super(InceptionA, self).__init__()
+ self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(384, 64, kernel_size=1, stride=1),
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(384, 64, kernel_size=1, stride=1),
+ BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
+ )
+
+ self.branch3 = nn.Sequential(
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+ BasicConv2d(384, 96, kernel_size=1, stride=1)
+ )
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ x3 = self.branch3(x)
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class ReductionA(nn.Module):
+ def __init__(self):
+ super(ReductionA, self).__init__()
+ self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(384, 192, kernel_size=1, stride=1),
+ BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
+ BasicConv2d(224, 256, kernel_size=3, stride=2)
+ )
+
+ self.branch2 = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ out = torch.cat((x0, x1, x2), 1)
+ return out
+
+
+class InceptionB(nn.Module):
+ def __init__(self):
+ super(InceptionB, self).__init__()
+ self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(1024, 192, kernel_size=1, stride=1),
+ BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+ BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
+ )
+
+ self.branch2 = nn.Sequential(
+ BasicConv2d(1024, 192, kernel_size=1, stride=1),
+ BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+ BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+ BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+ BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
+ )
+
+ self.branch3 = nn.Sequential(
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+ BasicConv2d(1024, 128, kernel_size=1, stride=1)
+ )
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ x3 = self.branch3(x)
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class ReductionB(nn.Module):
+ def __init__(self):
+ super(ReductionB, self).__init__()
+
+ self.branch0 = nn.Sequential(
+ BasicConv2d(1024, 192, kernel_size=1, stride=1),
+ BasicConv2d(192, 192, kernel_size=3, stride=2)
+ )
+
+ self.branch1 = nn.Sequential(
+ BasicConv2d(1024, 256, kernel_size=1, stride=1),
+ BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+ BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+ BasicConv2d(320, 320, kernel_size=3, stride=2)
+ )
+
+ self.branch2 = nn.MaxPool2d(3, stride=2)
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+ x1 = self.branch1(x)
+ x2 = self.branch2(x)
+ out = torch.cat((x0, x1, x2), 1)
+ return out
+
+
+class InceptionC(nn.Module):
+ def __init__(self):
+ super(InceptionC, self).__init__()
+
+ self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
+
+ self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
+ self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
+ self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+
+ self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
+ self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
+ self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
+ self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
+ self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+
+ self.branch3 = nn.Sequential(
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+ BasicConv2d(1536, 256, kernel_size=1, stride=1)
+ )
+
+ def forward(self, x):
+ x0 = self.branch0(x)
+
+ x1_0 = self.branch1_0(x)
+ x1_1a = self.branch1_1a(x1_0)
+ x1_1b = self.branch1_1b(x1_0)
+ x1 = torch.cat((x1_1a, x1_1b), 1)
+
+ x2_0 = self.branch2_0(x)
+ x2_1 = self.branch2_1(x2_0)
+ x2_2 = self.branch2_2(x2_1)
+ x2_3a = self.branch2_3a(x2_2)
+ x2_3b = self.branch2_3b(x2_2)
+ x2 = torch.cat((x2_3a, x2_3b), 1)
+
+ x3 = self.branch3(x)
+
+ out = torch.cat((x0, x1, x2, x3), 1)
+ return out
+
+
+class InceptionV4(nn.Module):
+ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
+ super(InceptionV4, self).__init__()
+ assert output_stride == 32
+ self.drop_rate = drop_rate
+ self.num_classes = num_classes
+ self.num_features = 1536
+
+ self.features = nn.Sequential(
+ BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
+ BasicConv2d(32, 32, kernel_size=3, stride=1),
+ BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
+ Mixed3a(),
+ Mixed4a(),
+ Mixed5a(),
+ InceptionA(),
+ InceptionA(),
+ InceptionA(),
+ InceptionA(),
+ ReductionA(), # Mixed6a
+ InceptionB(),
+ InceptionB(),
+ InceptionB(),
+ InceptionB(),
+ InceptionB(),
+ InceptionB(),
+ InceptionB(),
+ ReductionB(), # Mixed7a
+ InceptionC(),
+ InceptionC(),
+ InceptionC(),
+ )
+ self.feature_info = [
+ dict(num_chs=64, reduction=2, module='features.2'),
+ dict(num_chs=160, reduction=4, module='features.3'),
+ dict(num_chs=384, reduction=8, module='features.9'),
+ dict(num_chs=1024, reduction=16, module='features.17'),
+ dict(num_chs=1536, reduction=32, module='features.21'),
+ ]
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def get_classifier(self):
+ return self.last_linear
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ return self.features(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.last_linear(x)
+ return x
+
+
+def _create_inception_v4(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ InceptionV4, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True),
+ **kwargs)
+
+
+@register_model
+def inception_v4(pretrained=False, **kwargs):
+ return _create_inception_v4('inception_v4', pretrained, **kwargs)
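+
+
+# A hedged sketch inspecting the feature metadata declared above: feature_info lists
+# the channel count and stride at each tap point, and forward_features ends at the
+# 1536-channel, stride-32 map (8x8 for a 299x299 input).
+if __name__ == '__main__':
+    net = InceptionV4()
+    for f in net.feature_info:
+        print(f['module'], f['num_chs'], f['reduction'])
+    print(net.forward_features(torch.randn(1, 3, 299, 299)).shape)   # (1, 1536, 8, 8)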
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
new file mode 100644
index 0000000..4831af9
--- /dev/null
+++ b/timm/models/layers/__init__.py
@@ -0,0 +1,40 @@
+from .activations import *
+from .adaptive_avgmax_pool import \
+ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .blur_pool import BlurPool2d
+from .classifier import ClassifierHead, create_classifier
+from .cond_conv2d import CondConv2d, get_condconv_initializer
+from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
+ set_layer_config
+from .conv2d_same import Conv2dSame, conv2d_same
+from .conv_bn_act import ConvBnAct
+from .create_act import create_act_layer, get_act_layer, get_act_fn
+from .create_attn import get_attn, create_attn
+from .create_conv2d import create_conv2d
+from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
+from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
+from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
+from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .gather_excite import GatherExcite
+from .global_context import GlobalContext
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
+from .inplace_abn import InplaceAbn
+from .linear import Linear
+from .mixed_conv2d import MixedConv2d
+from .mlp import Mlp, GluMlp, GatedMlp
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+from .norm import GroupNorm, LayerNorm2d
+from .norm_act import BatchNormAct2d, GroupNormAct
+from .padding import get_padding, get_same_padding, pad_same
+from .patch_embed import PatchEmbed
+from .pool2d_same import AvgPool2dSame, create_pool2d
+from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
+from .selective_kernel import SelectiveKernel
+from .separable_conv import SeparableConv2d, SeparableConvBnAct
+from .space_to_depth import SpaceToDepthModule
+from .split_attn import SplitAttn
+from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
+from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
+from .test_time_pool import TestTimePoolHead, apply_test_time_pool
+from .trace_utils import _assert, _float_to_int
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
diff --git a/timm/models/layers/__pycache__/__init__.cpython-36.pyc b/timm/models/layers/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..bda2543
Binary files /dev/null and b/timm/models/layers/__pycache__/__init__.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/activations.cpython-36.pyc b/timm/models/layers/__pycache__/activations.cpython-36.pyc
new file mode 100644
index 0000000..91b3edf
Binary files /dev/null and b/timm/models/layers/__pycache__/activations.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/activations_jit.cpython-36.pyc b/timm/models/layers/__pycache__/activations_jit.cpython-36.pyc
new file mode 100644
index 0000000..5a0a1e4
Binary files /dev/null and b/timm/models/layers/__pycache__/activations_jit.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/activations_me.cpython-36.pyc b/timm/models/layers/__pycache__/activations_me.cpython-36.pyc
new file mode 100644
index 0000000..ba3aa66
Binary files /dev/null and b/timm/models/layers/__pycache__/activations_me.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-36.pyc b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-36.pyc
new file mode 100644
index 0000000..dfb6e33
Binary files /dev/null and b/timm/models/layers/__pycache__/adaptive_avgmax_pool.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/blur_pool.cpython-36.pyc b/timm/models/layers/__pycache__/blur_pool.cpython-36.pyc
new file mode 100644
index 0000000..e903b88
Binary files /dev/null and b/timm/models/layers/__pycache__/blur_pool.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/bottleneck_attn.cpython-36.pyc b/timm/models/layers/__pycache__/bottleneck_attn.cpython-36.pyc
new file mode 100644
index 0000000..25e8923
Binary files /dev/null and b/timm/models/layers/__pycache__/bottleneck_attn.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/cbam.cpython-36.pyc b/timm/models/layers/__pycache__/cbam.cpython-36.pyc
new file mode 100644
index 0000000..e0feef1
Binary files /dev/null and b/timm/models/layers/__pycache__/cbam.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/classifier.cpython-36.pyc b/timm/models/layers/__pycache__/classifier.cpython-36.pyc
new file mode 100644
index 0000000..3bd9f1b
Binary files /dev/null and b/timm/models/layers/__pycache__/classifier.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/cond_conv2d.cpython-36.pyc b/timm/models/layers/__pycache__/cond_conv2d.cpython-36.pyc
new file mode 100644
index 0000000..08fd2ae
Binary files /dev/null and b/timm/models/layers/__pycache__/cond_conv2d.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/config.cpython-36.pyc b/timm/models/layers/__pycache__/config.cpython-36.pyc
new file mode 100644
index 0000000..262bb71
Binary files /dev/null and b/timm/models/layers/__pycache__/config.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/conv2d_same.cpython-36.pyc b/timm/models/layers/__pycache__/conv2d_same.cpython-36.pyc
new file mode 100644
index 0000000..c82f558
Binary files /dev/null and b/timm/models/layers/__pycache__/conv2d_same.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/conv_bn_act.cpython-36.pyc b/timm/models/layers/__pycache__/conv_bn_act.cpython-36.pyc
new file mode 100644
index 0000000..e26afdd
Binary files /dev/null and b/timm/models/layers/__pycache__/conv_bn_act.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/create_act.cpython-36.pyc b/timm/models/layers/__pycache__/create_act.cpython-36.pyc
new file mode 100644
index 0000000..26ac969
Binary files /dev/null and b/timm/models/layers/__pycache__/create_act.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/create_attn.cpython-36.pyc b/timm/models/layers/__pycache__/create_attn.cpython-36.pyc
new file mode 100644
index 0000000..1996b64
Binary files /dev/null and b/timm/models/layers/__pycache__/create_attn.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/create_conv2d.cpython-36.pyc b/timm/models/layers/__pycache__/create_conv2d.cpython-36.pyc
new file mode 100644
index 0000000..14b72b5
Binary files /dev/null and b/timm/models/layers/__pycache__/create_conv2d.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/create_norm_act.cpython-36.pyc b/timm/models/layers/__pycache__/create_norm_act.cpython-36.pyc
new file mode 100644
index 0000000..f46dead
Binary files /dev/null and b/timm/models/layers/__pycache__/create_norm_act.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/drop.cpython-36.pyc b/timm/models/layers/__pycache__/drop.cpython-36.pyc
new file mode 100644
index 0000000..8c4df04
Binary files /dev/null and b/timm/models/layers/__pycache__/drop.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/eca.cpython-36.pyc b/timm/models/layers/__pycache__/eca.cpython-36.pyc
new file mode 100644
index 0000000..7d43317
Binary files /dev/null and b/timm/models/layers/__pycache__/eca.cpython-36.pyc differ
diff --git a/timm/models/layers/__pycache__/evo_norm.cpython-36.pyc b/timm/models/layers/__pycache__/evo_norm.cpython-36.pyc
new file mode 100644
index 0000000..57f5a9b
Binary files /dev/null and b/timm/models/layers/__pycache__/evo_norm.cpython-36.pyc differ
diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py
new file mode 100644
index 0000000..e16b3bd
--- /dev/null
+++ b/timm/models/layers/activations.py
@@ -0,0 +1,145 @@
+""" Activations
+
+A collection of activation functions and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+def swish(x, inplace: bool = False):
+ """Swish - Described in: https://arxiv.org/abs/1710.05941
+ """
+ return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
+
+
+class Swish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Swish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return swish(x, self.inplace)
+
+
+def mish(x, inplace: bool = False):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ NOTE: I don't have a working inplace variant
+ """
+ return x.mul(F.softplus(x).tanh())
+
+
+class Mish(nn.Module):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ """
+ def __init__(self, inplace: bool = False):
+ super(Mish, self).__init__()
+
+ def forward(self, x):
+ return mish(x)
+
+
+def sigmoid(x, inplace: bool = False):
+ return x.sigmoid_() if inplace else x.sigmoid()
+
+
+# PyTorch has this, but not with a consistent inplace argument interface
+class Sigmoid(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Sigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x.sigmoid_() if self.inplace else x.sigmoid()
+
+
+def tanh(x, inplace: bool = False):
+ return x.tanh_() if inplace else x.tanh()
+
+
+# PyTorch has this, but not with a consistent inplace argument interface
+class Tanh(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(Tanh, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return x.tanh_() if self.inplace else x.tanh()
+
+
+def hard_swish(x, inplace: bool = False):
+ inner = F.relu6(x + 3.).div_(6.)
+ return x.mul_(inner) if inplace else x.mul(inner)
+
+
+class HardSwish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_swish(x, self.inplace)
+
+
+def hard_sigmoid(x, inplace: bool = False):
+ if inplace:
+ return x.add_(3.).clamp_(0., 6.).div_(6.)
+ else:
+ return F.relu6(x + 3.) / 6.
+
+
+class HardSigmoid(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoid, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_sigmoid(x, self.inplace)
+
+
+def hard_mish(x, inplace: bool = False):
+ """ Hard Mish
+ Experimental, based on notes by Mish author Diganta Misra at
+ https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
+ """
+ if inplace:
+ return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
+ else:
+ return 0.5 * x * (x + 2).clamp(min=0, max=2)
+
+
+class HardMish(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardMish, self).__init__()
+ self.inplace = inplace
+
+ def forward(self, x):
+ return hard_mish(x, self.inplace)
+
+
+class PReLU(nn.PReLU):
+ """Applies PReLU (w/ dummy inplace arg)
+ """
+ def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
+ super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.prelu(input, self.weight)
+
+
+def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
+ return F.gelu(x)
+
+
+class GELU(nn.Module):
+ """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
+ """
+ def __init__(self, inplace: bool = False):
+ super(GELU, self).__init__()
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.gelu(input)
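
For orientation only (this snippet is not part of the commit): a minimal sketch of how the activation helpers above could be exercised, assuming the `timm.models.layers` package from this commit is importable. The shapes are arbitrary; clones are passed because the in-place variants mutate their argument.
```
import torch
from timm.models.layers.activations import swish, hard_sigmoid, HardSwish

x = torch.randn(4, 8)

# Functional out-of-place and in-place variants should agree numerically.
assert torch.allclose(swish(x), swish(x.clone(), inplace=True))
assert torch.allclose(hard_sigmoid(x), hard_sigmoid(x.clone(), inplace=True))

# Module form drops in like any stock activation layer.
act = HardSwish(inplace=False)
print(act(x).shape)  # torch.Size([4, 8])
```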
diff --git a/timm/models/layers/activations_jit.py b/timm/models/layers/activations_jit.py
new file mode 100644
index 0000000..b4a5165
--- /dev/null
+++ b/timm/models/layers/activations_jit.py
@@ -0,0 +1,90 @@
+""" Activations
+
+A collection of jit-scripted activation functions and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+All jit-scripted activations omit in-place variants on purpose: scripted kernel fusion does not
+currently work across in-place op boundaries, so performance would be equal to or worse than the
+non-scripted versions if they contained in-place ops.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@torch.jit.script
+def swish_jit(x, inplace: bool = False):
+ """Swish - Described in: https://arxiv.org/abs/1710.05941
+ """
+ return x.mul(x.sigmoid())
+
+
+@torch.jit.script
+def mish_jit(x, _inplace: bool = False):
+ """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ """
+ return x.mul(F.softplus(x).tanh())
+
+
+class SwishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(SwishJit, self).__init__()
+
+ def forward(self, x):
+ return swish_jit(x)
+
+
+class MishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(MishJit, self).__init__()
+
+ def forward(self, x):
+ return mish_jit(x)
+
+
+@torch.jit.script
+def hard_sigmoid_jit(x, inplace: bool = False):
+ # return F.relu6(x + 3.) / 6.
+ return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
+
+
+class HardSigmoidJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoidJit, self).__init__()
+
+ def forward(self, x):
+ return hard_sigmoid_jit(x)
+
+
+@torch.jit.script
+def hard_swish_jit(x, inplace: bool = False):
+ # return x * (F.relu6(x + 3.) / 6)
+ return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
+
+
+class HardSwishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwishJit, self).__init__()
+
+ def forward(self, x):
+ return hard_swish_jit(x)
+
+
+@torch.jit.script
+def hard_mish_jit(x, inplace: bool = False):
+ """ Hard Mish
+ Experimental, based on notes by Mish author Diganta Misra at
+ https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
+ """
+ return 0.5 * x * (x + 2).clamp(min=0, max=2)
+
+
+class HardMishJit(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardMishJit, self).__init__()
+
+ def forward(self, x):
+ return hard_mish_jit(x)
diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py
new file mode 100644
index 0000000..9a12bb7
--- /dev/null
+++ b/timm/models/layers/activations_me.py
@@ -0,0 +1,218 @@
+""" Activations (memory-efficient w/ custom autograd)
+
+A collection of activation functions and modules with a common interface so that they can
+easily be swapped. All have an `inplace` arg even if not used.
+
+These activations are not compatible with jit scripting or ONNX export of the model; please use
+either the JIT or basic versions of the activations instead.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+@torch.jit.script
+def swish_jit_fwd(x):
+ return x.mul(torch.sigmoid(x))
+
+
+@torch.jit.script
+def swish_jit_bwd(x, grad_output):
+ x_sigmoid = torch.sigmoid(x)
+ return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+class SwishJitAutoFn(torch.autograd.Function):
+ """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
+ Inspired by conversation between Jeremy Howard & Adam Paszke
+ https://twitter.com/jeremyphoward/status/1188251041835315200
+ """
+ @staticmethod
+ def symbolic(g, x):
+ return g.op("Mul", x, g.op("Sigmoid", x))
+
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return swish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return swish_jit_bwd(x, grad_output)
+
+
+def swish_me(x, inplace=False):
+ return SwishJitAutoFn.apply(x)
+
+
+class SwishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(SwishMe, self).__init__()
+
+ def forward(self, x):
+ return SwishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def mish_jit_fwd(x):
+ return x.mul(torch.tanh(F.softplus(x)))
+
+
+@torch.jit.script
+def mish_jit_bwd(x, grad_output):
+ x_sigmoid = torch.sigmoid(x)
+ x_tanh_sp = F.softplus(x).tanh()
+ return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+
+
+class MishJitAutoFn(torch.autograd.Function):
+ """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+ A memory efficient, jit scripted variant of Mish
+ """
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return mish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return mish_jit_bwd(x, grad_output)
+
+
+def mish_me(x, inplace=False):
+ return MishJitAutoFn.apply(x)
+
+
+class MishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(MishMe, self).__init__()
+
+ def forward(self, x):
+ return MishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_sigmoid_jit_fwd(x, inplace: bool = False):
+ return (x + 3).clamp(min=0, max=6).div(6.)
+
+
+@torch.jit.script
+def hard_sigmoid_jit_bwd(x, grad_output):
+ m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6.
+ return grad_output * m
+
+
+class HardSigmoidJitAutoFn(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_sigmoid_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_sigmoid_jit_bwd(x, grad_output)
+
+
+def hard_sigmoid_me(x, inplace: bool = False):
+ return HardSigmoidJitAutoFn.apply(x)
+
+
+class HardSigmoidMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSigmoidMe, self).__init__()
+
+ def forward(self, x):
+ return HardSigmoidJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_swish_jit_fwd(x):
+ return x * (x + 3).clamp(min=0, max=6).div(6.)
+
+
+@torch.jit.script
+def hard_swish_jit_bwd(x, grad_output):
+ m = torch.ones_like(x) * (x >= 3.)
+ m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m)
+ return grad_output * m
+
+
+class HardSwishJitAutoFn(torch.autograd.Function):
+ """A memory efficient, jit-scripted HardSwish activation"""
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_swish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_swish_jit_bwd(x, grad_output)
+
+ @staticmethod
+ def symbolic(g, self):
+ input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
+ hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+ hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+ return g.op("Mul", self, hardtanh_)
+
+
+def hard_swish_me(x, inplace=False):
+ return HardSwishJitAutoFn.apply(x)
+
+
+class HardSwishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardSwishMe, self).__init__()
+
+ def forward(self, x):
+ return HardSwishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def hard_mish_jit_fwd(x):
+ return 0.5 * x * (x + 2).clamp(min=0, max=2)
+
+
+@torch.jit.script
+def hard_mish_jit_bwd(x, grad_output):
+ m = torch.ones_like(x) * (x >= -2.)
+ m = torch.where((x >= -2.) & (x <= 0.), x + 1., m)
+ return grad_output * m
+
+
+class HardMishJitAutoFn(torch.autograd.Function):
+ """ A memory efficient, jit scripted variant of Hard Mish
+ Experimental, based on notes by Mish author Diganta Misra at
+ https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
+ """
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ return hard_mish_jit_fwd(x)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x = ctx.saved_tensors[0]
+ return hard_mish_jit_bwd(x, grad_output)
+
+
+def hard_mish_me(x, inplace: bool = False):
+ return HardMishJitAutoFn.apply(x)
+
+
+class HardMishMe(nn.Module):
+ def __init__(self, inplace: bool = False):
+ super(HardMishMe, self).__init__()
+
+ def forward(self, x):
+ return HardMishJitAutoFn.apply(x)
+
+
+
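
A quick, hedged check (not part of the diff) that the memory-efficient Swish above matches PyTorch's native SiLU in both value and gradient, assuming the module path from this commit:
```
import torch
import torch.nn.functional as F
from timm.models.layers.activations_me import swish_me

x1 = torch.randn(32, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

y1, y2 = swish_me(x1), F.silu(x2)
assert torch.allclose(y1, y2, atol=1e-6)

# The hand-written backward should match autograd through the native op.
y1.sum().backward()
y2.sum().backward()
assert torch.allclose(x1.grad, x2.grad, atol=1e-6)
```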
diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py
new file mode 100644
index 0000000..ebc6ada
--- /dev/null
+++ b/timm/models/layers/adaptive_avgmax_pool.py
@@ -0,0 +1,118 @@
+""" PyTorch selectable adaptive pooling
+Adaptive pooling with the ability to select the type of pooling from:
+ * 'avg' - Average pooling
+ * 'max' - Max pooling
+ * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
+ * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
+
+Both a functional and an nn.Module version of the pooling are provided.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def adaptive_pool_feat_mult(pool_type='avg'):
+ if pool_type == 'catavgmax':
+ return 2
+ else:
+ return 1
+
+
+def adaptive_avgmax_pool2d(x, output_size=1):
+ x_avg = F.adaptive_avg_pool2d(x, output_size)
+ x_max = F.adaptive_max_pool2d(x, output_size)
+ return 0.5 * (x_avg + x_max)
+
+
+def adaptive_catavgmax_pool2d(x, output_size=1):
+ x_avg = F.adaptive_avg_pool2d(x, output_size)
+ x_max = F.adaptive_max_pool2d(x, output_size)
+ return torch.cat((x_avg, x_max), 1)
+
+
+def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
+ """Selectable global pooling function with dynamic input kernel size
+ """
+ if pool_type == 'avg':
+ x = F.adaptive_avg_pool2d(x, output_size)
+ elif pool_type == 'avgmax':
+ x = adaptive_avgmax_pool2d(x, output_size)
+ elif pool_type == 'catavgmax':
+ x = adaptive_catavgmax_pool2d(x, output_size)
+ elif pool_type == 'max':
+ x = F.adaptive_max_pool2d(x, output_size)
+ else:
+ assert False, 'Invalid pool type: %s' % pool_type
+ return x
+
+
+class FastAdaptiveAvgPool2d(nn.Module):
+ def __init__(self, flatten=False):
+ super(FastAdaptiveAvgPool2d, self).__init__()
+ self.flatten = flatten
+
+ def forward(self, x):
+ return x.mean((2, 3), keepdim=not self.flatten)
+
+
+class AdaptiveAvgMaxPool2d(nn.Module):
+ def __init__(self, output_size=1):
+ super(AdaptiveAvgMaxPool2d, self).__init__()
+ self.output_size = output_size
+
+ def forward(self, x):
+ return adaptive_avgmax_pool2d(x, self.output_size)
+
+
+class AdaptiveCatAvgMaxPool2d(nn.Module):
+ def __init__(self, output_size=1):
+ super(AdaptiveCatAvgMaxPool2d, self).__init__()
+ self.output_size = output_size
+
+ def forward(self, x):
+ return adaptive_catavgmax_pool2d(x, self.output_size)
+
+
+class SelectAdaptivePool2d(nn.Module):
+ """Selectable global pooling layer with dynamic input kernel size
+ """
+ def __init__(self, output_size=1, pool_type='fast', flatten=False):
+ super(SelectAdaptivePool2d, self).__init__()
+ self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
+ self.flatten = nn.Flatten(1) if flatten else nn.Identity()
+ if pool_type == '':
+ self.pool = nn.Identity() # pass through
+ elif pool_type == 'fast':
+ assert output_size == 1
+ self.pool = FastAdaptiveAvgPool2d(flatten)
+ self.flatten = nn.Identity()
+ elif pool_type == 'avg':
+ self.pool = nn.AdaptiveAvgPool2d(output_size)
+ elif pool_type == 'avgmax':
+ self.pool = AdaptiveAvgMaxPool2d(output_size)
+ elif pool_type == 'catavgmax':
+ self.pool = AdaptiveCatAvgMaxPool2d(output_size)
+ elif pool_type == 'max':
+ self.pool = nn.AdaptiveMaxPool2d(output_size)
+ else:
+ assert False, 'Invalid pool type: %s' % pool_type
+
+ def is_identity(self):
+ return not self.pool_type
+
+ def forward(self, x):
+ x = self.pool(x)
+ x = self.flatten(x)
+ return x
+
+ def feat_mult(self):
+ return adaptive_pool_feat_mult(self.pool_type)
+
+ def __repr__(self):
+ return self.__class__.__name__ + ' (' \
+ + 'pool_type=' + self.pool_type \
+ + ', flatten=' + str(self.flatten) + ')'
+
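
As an illustrative sketch (not part of the commit), the selectable pooling layer above reduces a `(B, C, H, W)` map to `(B, C)` (or `(B, 2C)` for the concatenated variant), assuming the module is importable as added here:
```
import torch
from timm.models.layers.adaptive_avgmax_pool import SelectAdaptivePool2d

x = torch.randn(2, 64, 7, 7)
for pool_type in ('avg', 'max', 'avgmax', 'catavgmax'):
    pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
    # 'catavgmax' doubles the channel dim (feat_mult() == 2); the others keep 64
    print(pool_type, tuple(pool(x).shape), pool.feat_mult())
# avg / max / avgmax -> (2, 64), 1;  catavgmax -> (2, 128), 2
```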
diff --git a/timm/models/layers/attention_pool2d.py b/timm/models/layers/attention_pool2d.py
new file mode 100644
index 0000000..66e49b8
--- /dev/null
+++ b/timm/models/layers/attention_pool2d.py
@@ -0,0 +1,182 @@
+""" Attention Pool 2D
+
+Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.
+
+Based on idea in CLIP by OpenAI, licensed Apache 2.0
+https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+from typing import List, Union, Tuple
+
+import torch
+import torch.nn as nn
+
+from .helpers import to_2tuple
+from .weight_init import trunc_normal_
+
+
+def rot(x):
+ return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
+
+
+def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
+ return x * cos_emb + rot(x) * sin_emb
+
+
+def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
+ if isinstance(x, torch.Tensor):
+ x = [x]
+ return [t * cos_emb + rot(t) * sin_emb for t in x]
+
+
+class RotaryEmbedding(nn.Module):
+ """ Rotary position embedding
+
+ NOTE: This is my initial attempt at implementing rotary embedding for spatial use; it has not
+ been well tested and will likely change. It will be moved to its own file.
+
+ The following impl/resources were referenced for this impl:
+ * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
+ * https://blog.eleuther.ai/rotary-embeddings/
+ """
+ def __init__(self, dim, max_freq=4):
+ super().__init__()
+ self.dim = dim
+ self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False)
+
+ def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
+ """
+ NOTE: shape arg should include spatial dim only
+ """
+ device = device or self.bands.device
+ dtype = dtype or self.bands.dtype
+ if not isinstance(shape, torch.Size):
+ shape = torch.Size(shape)
+ N = shape.numel()
+ grid = torch.stack(torch.meshgrid(
+ [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1)
+ emb = grid * math.pi * self.bands
+ sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1)
+ cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1)
+ return sin, cos
+
+ def forward(self, x):
+ # assuming channel-first tensor where spatial dim are >= 2
+ sin_emb, cos_emb = self.get_embed(x.shape[2:])
+ return apply_rot_embed(x, sin_emb, cos_emb)
+
+
+class RotAttentionPool2d(nn.Module):
+ """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
+ This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
+
+ Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
+ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
+
+ NOTE: While this impl does not require a fixed feature size, performance at resolutions differing from
+ training varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
+ """
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int = None,
+ embed_dim: int = None,
+ num_heads: int = 4,
+ qkv_bias: bool = True,
+ ):
+ super().__init__()
+ embed_dim = embed_dim or in_features
+ out_features = out_features or in_features
+ self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(embed_dim, out_features)
+ self.num_heads = num_heads
+ assert embed_dim % num_heads == 0
+ self.head_dim = embed_dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.pos_embed = RotaryEmbedding(self.head_dim)
+
+ trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+ nn.init.zeros_(self.qkv.bias)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ N = H * W
+ sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
+ x = x.reshape(B, -1, N).permute(0, 2, 1)
+
+ x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+
+ x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = x[0], x[1], x[2]
+
+ qc, q = q[:, :, :1], q[:, :, 1:]
+ q = apply_rot_embed(q, sin_emb, cos_emb)
+ q = torch.cat([qc, q], dim=2)
+
+ kc, k = k[:, :, :1], k[:, :, 1:]
+ k = apply_rot_embed(k, sin_emb, cos_emb)
+ k = torch.cat([kc, k], dim=2)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+ x = self.proj(x)
+ return x[:, 0]
+
+
+class AttentionPool2d(nn.Module):
+ """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
+ This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.
+
+ It was based on impl in CLIP by OpenAI
+ https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py
+
+ NOTE: This requires the feature size upon construction and will prevent adaptive sizing of the network.
+ """
+ def __init__(
+ self,
+ in_features: int,
+ feat_size: Union[int, Tuple[int, int]],
+ out_features: int = None,
+ embed_dim: int = None,
+ num_heads: int = 4,
+ qkv_bias: bool = True,
+ ):
+ super().__init__()
+
+ embed_dim = embed_dim or in_features
+ out_features = out_features or in_features
+ assert embed_dim % num_heads == 0
+ self.feat_size = to_2tuple(feat_size)
+ self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
+ self.proj = nn.Linear(embed_dim, out_features)
+ self.num_heads = num_heads
+ self.head_dim = embed_dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ spatial_dim = self.feat_size[0] * self.feat_size[1]
+ self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
+ trunc_normal_(self.pos_embed, std=in_features ** -0.5)
+ trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
+ nn.init.zeros_(self.qkv.bias)
+
+ def forward(self, x):
+ B, _, H, W = x.shape
+ N = H * W
+ assert self.feat_size[0] == H
+ assert self.feat_size[1] == W
+ x = x.reshape(B, -1, N).permute(0, 2, 1)
+ x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+ x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+ x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k, v = x[0], x[1], x[2]
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
+ x = self.proj(x)
+ return x[:, 0]
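
A shape sketch for the two pooling variants above (not part of the commit; assumes the helpers referenced by this file are present from the same commit). The chosen sizes are arbitrary:
```
import torch
from timm.models.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d

x = torch.randn(2, 256, 7, 7)

# Learned absolute position embedding: feat_size is fixed at construction.
pool = AttentionPool2d(in_features=256, feat_size=7, out_features=128, num_heads=4)
print(pool(x).shape)  # torch.Size([2, 128])

# Rotary variant needs no fixed feat_size, but per the note above accuracy
# degrades away from the training resolution.
rot_pool = RotAttentionPool2d(in_features=256, out_features=128, num_heads=4)
print(rot_pool(x).shape)  # torch.Size([2, 128])
```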
diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py
new file mode 100644
index 0000000..ca4ce75
--- /dev/null
+++ b/timm/models/layers/blur_pool.py
@@ -0,0 +1,42 @@
+"""
+BlurPool layer inspired by
+ - Kornia's Max_BlurPool2d
+ - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
+
+Hacked together by Chris Ha and Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from .padding import get_padding
+
+
+class BlurPool2d(nn.Module):
+ r"""Creates a module that computes blurs and downsample a given feature map.
+ See :cite:`zhang2019shiftinvar` for more details.
+ Corresponds to the Downsample class, which does blurring and subsampling
+
+ Args:
+ channels = Number of input channels
+ filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
+ stride (int): downsampling filter stride
+
+ Returns:
+ torch.Tensor: the transformed tensor.
+ """
+ def __init__(self, channels, filt_size=3, stride=2) -> None:
+ super(BlurPool2d, self).__init__()
+ assert filt_size > 1
+ self.channels = channels
+ self.filt_size = filt_size
+ self.stride = stride
+ self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
+ coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
+ blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
+ self.register_buffer('filt', blur_filter, persistent=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, self.padding, 'reflect')
+ return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
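
For illustration (not part of the diff): BlurPool2d applies a fixed per-channel binomial kernel before subsampling, so a stride-2 layer halves the spatial size while remaining shift-robust. A minimal sketch assuming the module path above:
```
import torch
from timm.models.layers.blur_pool import BlurPool2d

x = torch.randn(1, 32, 56, 56)
bp = BlurPool2d(channels=32, filt_size=3, stride=2)
print(bp(x).shape)   # torch.Size([1, 32, 28, 28])

# The fixed kernel is the outer product of [0.25, 0.5, 0.25] (filt_size=3)
print(bp.filt[0, 0])
```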
diff --git a/timm/models/layers/bottleneck_attn.py b/timm/models/layers/bottleneck_attn.py
new file mode 100644
index 0000000..c3db464
--- /dev/null
+++ b/timm/models/layers/bottleneck_attn.py
@@ -0,0 +1,157 @@
+""" Bottleneck Self Attention (Bottleneck Transformers)
+
+Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
+
+@misc{2101.11605,
+Author = {Aravind Srinivas and Tsung-Yi Lin and Niki Parmar and Jonathon Shlens and Pieter Abbeel and Ashish Vaswani},
+Title = {Bottleneck Transformers for Visual Recognition},
+Year = {2021},
+}
+
+Based on ref gist at: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+
+This impl is a WIP, but given that it is based on the ref gist it is likely not too far off.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .helpers import to_2tuple, make_divisible
+from .weight_init import trunc_normal_
+from .trace_utils import _assert
+
+
+def rel_logits_1d(q, rel_k, permute_mask: List[int]):
+ """ Compute relative logits along one dimension
+
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+
+ Args:
+ q: (batch, heads, height, width, dim)
+ rel_k: (2 * width - 1, dim)
+ permute_mask: permute output dim according to this
+ """
+ B, H, W, dim = q.shape
+ x = (q @ rel_k.transpose(-1, -2))
+ x = x.reshape(-1, W, 2 * W - 1)
+
+ # pad to shift from relative to absolute indexing
+ x_pad = F.pad(x, [0, 1]).flatten(1)
+ x_pad = F.pad(x_pad, [0, W - 1])
+
+ # reshape and slice out the padded elements
+ x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
+ x = x_pad[:, :W, W - 1:]
+
+ # reshape and tile
+ x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
+ return x.permute(permute_mask)
+
+
+class PosEmbedRel(nn.Module):
+ """ Relative Position Embedding
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+ """
+ def __init__(self, feat_size, dim_head, scale):
+ super().__init__()
+ self.height, self.width = to_2tuple(feat_size)
+ self.dim_head = dim_head
+ self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * scale)
+ self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * scale)
+
+ def forward(self, q):
+ B, HW, _ = q.shape
+
+ # relative logits in width dimension.
+ q = q.reshape(B, self.height, self.width, -1)
+ rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
+
+ # relative logits in height dimension.
+ q = q.transpose(1, 2)
+ rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
+
+ rel_logits = rel_logits_h + rel_logits_w
+ rel_logits = rel_logits.reshape(B, HW, HW)
+ return rel_logits
+
+
+class BottleneckAttn(nn.Module):
+ """ Bottleneck Attention
+ Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
+
+ The internal dimensions of the attention module are controlled by the interaction of several arguments.
+ * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+ * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+ * the query and key (qk) dimensions are determined by
+ * num_heads * dim_head if dim_head is not None
+ * num_heads * (dim_out * qk_ratio // num_heads) if dim_head is None
+ * as seen above, qk_ratio determines the ratio of q and k relative to the output if dim_head is not used
+
+ Args:
+ dim (int): input dimension to the module
+ dim_out (int): output dimension of the module, same as dim if not set
+ stride (int): output stride of the module, avg pool used if stride == 2 (default: 1).
+ num_heads (int): parallel attention heads (default: 4)
+ dim_head (int): dimension of query and key heads, calculated from dim_out * qk_ratio // num_heads if not set
+ qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+ qkv_bias (bool): add bias to q, k, and v projections
+ scale_pos_embed (bool): scale the position embedding as well as Q @ K
+ """
+ def __init__(
+ self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=None,
+ qk_ratio=1.0, qkv_bias=False, scale_pos_embed=False):
+ super().__init__()
+ assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
+ dim_out = dim_out or dim
+ assert dim_out % num_heads == 0
+ self.num_heads = num_heads
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.dim_head_v = dim_out // self.num_heads
+ self.dim_out_qk = num_heads * self.dim_head_qk
+ self.dim_out_v = num_heads * self.dim_head_v
+ self.scale = self.dim_head_qk ** -0.5
+ self.scale_pos_embed = scale_pos_embed
+
+ self.qkv = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
+
+ # NOTE I'm only supporting relative pos embedding for now
+ self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
+
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ _assert(H == self.pos_embed.height, '')
+ _assert(W == self.pos_embed.width, '')
+
+ x = self.qkv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
+
+ # NOTE head vs channel split ordering in qkv projection was decided before I allowed qk to differ from v
+ # So, this is more verbose than if heads were before qkv splits, but throughput is not impacted.
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
+ k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
+ v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
+
+ if self.scale_pos_embed:
+ attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
+ else:
+ attn = (q @ k) * self.scale + self.pos_embed(q)
+ attn = attn.softmax(dim=-1)
+
+ out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
+ out = self.pool(out)
+ return out
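
A shape sketch (not part of the commit) of BottleneckAttn as a drop-in for a conv block, assuming the other layer helpers from this commit are importable. feat_size must match the incoming (H, W) because the relative position embedding is sized to it:
```
import torch
from timm.models.layers.bottleneck_attn import BottleneckAttn

x = torch.randn(2, 256, 14, 14)

attn = BottleneckAttn(dim=256, dim_out=256, feat_size=(14, 14), num_heads=4)
print(attn(x).shape)      # torch.Size([2, 256, 14, 14])

# stride=2 appends an average pool, halving the spatial output
attn_s2 = BottleneckAttn(dim=256, dim_out=256, feat_size=(14, 14), stride=2, num_heads=4)
print(attn_s2(x).shape)   # torch.Size([2, 256, 7, 7])
```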
diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py
new file mode 100644
index 0000000..bacf5cf
--- /dev/null
+++ b/timm/models/layers/cbam.py
@@ -0,0 +1,112 @@
+""" CBAM (sort-of) Attention
+
+Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
+
+WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
+some tasks, especially fine-grained ones it seems. I may end up removing this impl.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .conv_bn_act import ConvBnAct
+from .create_act import create_act_layer, get_act_layer
+from .helpers import make_divisible
+
+
+class ChannelAttn(nn.Module):
+ """ Original CBAM channel attention module, currently avg + max pool variant only.
+ """
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(ChannelAttn, self).__init__()
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
+ self.act = act_layer(inplace=True)
+ self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
+ x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
+ return x * self.gate(x_avg + x_max)
+
+
+class LightChannelAttn(ChannelAttn):
+ """An experimental 'lightweight' that sums avg + max pool first
+ """
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(LightChannelAttn, self).__init__(
+ channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
+
+ def forward(self, x):
+ x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
+ x_attn = self.fc2(self.act(self.fc1(x_pool)))
+ return x * self.gate(x_attn)
+
+
+class SpatialAttn(nn.Module):
+ """ Original CBAM spatial attention module
+ """
+ def __init__(self, kernel_size=7, gate_layer='sigmoid'):
+ super(SpatialAttn, self).__init__()
+ self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
+ x_attn = self.conv(x_attn)
+ return x * self.gate(x_attn)
+
+
+class LightSpatialAttn(nn.Module):
+ """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
+ """
+ def __init__(self, kernel_size=7, gate_layer='sigmoid'):
+ super(LightSpatialAttn, self).__init__()
+ self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
+ x_attn = self.conv(x_attn)
+ return x * self.gate(x_attn)
+
+
+class CbamModule(nn.Module):
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(CbamModule, self).__init__()
+ self.channel = ChannelAttn(
+ channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
+ rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
+ self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
+
+ def forward(self, x):
+ x = self.channel(x)
+ x = self.spatial(x)
+ return x
+
+
+class LightCbamModule(nn.Module):
+ def __init__(
+ self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
+ spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
+ super(LightCbamModule, self).__init__()
+ self.channel = LightChannelAttn(
+ channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
+ rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
+ self.spatial = LightSpatialAttn(spatial_kernel_size)
+
+ def forward(self, x):
+ x = self.channel(x)
+ x = self.spatial(x)
+ return x
+
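
For orientation (not part of the diff): both CBAM variants preserve the input shape, applying channel attention followed by spatial attention. A minimal sketch, assuming the rest of the layers package from this commit (ConvBnAct, create_act_layer, etc.) is available:
```
import torch
from timm.models.layers.cbam import CbamModule, LightCbamModule

x = torch.randn(2, 64, 28, 28)
print(CbamModule(channels=64)(x).shape)       # torch.Size([2, 64, 28, 28])
print(LightCbamModule(channels=64)(x).shape)  # torch.Size([2, 64, 28, 28])
```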
diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py
new file mode 100644
index 0000000..2b74541
--- /dev/null
+++ b/timm/models/layers/classifier.py
@@ -0,0 +1,56 @@
+""" Classifier head and layer factory
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .linear import Linear
+
+
+def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
+ flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
+ if not pool_type:
+ assert num_classes == 0 or use_conv,\
+ 'Pooling can only be disabled if classifier is also removed or conv classifier is used'
+ flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
+ global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
+ num_pooled_features = num_features * global_pool.feat_mult()
+ return global_pool, num_pooled_features
+
+
+def _create_fc(num_features, num_classes, use_conv=False):
+ if num_classes <= 0:
+ fc = nn.Identity() # pass-through (no classifier)
+ elif use_conv:
+ fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
+ else:
+ # NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
+ fc = Linear(num_features, num_classes, bias=True)
+ return fc
+
+
+def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
+ global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
+ fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+ return global_pool, fc
+
+
+class ClassifierHead(nn.Module):
+ """Classifier head w/ configurable global pooling and dropout."""
+
+ def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+ super(ClassifierHead, self).__init__()
+ self.drop_rate = drop_rate
+ self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+ self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+ self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
+
+ def forward(self, x):
+ x = self.global_pool(x)
+ if self.drop_rate:
+ x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+ x = self.fc(x)
+ x = self.flatten(x)
+ return x
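
A usage sketch (not part of the commit) of the classifier head factory above, assuming the module path as added here. The second case shows the pass-through configuration used when a backbone is built without a classifier:
```
import torch
from timm.models.layers.classifier import ClassifierHead, create_classifier

feats = torch.randn(2, 512, 7, 7)

head = ClassifierHead(in_chs=512, num_classes=10, pool_type='avg', drop_rate=0.2)
print(head(feats).shape)  # torch.Size([2, 10])

# num_classes=0 with pooling disabled passes features through untouched
pool, fc = create_classifier(512, num_classes=0, pool_type='')
print(fc(pool(feats)).shape)  # torch.Size([2, 512, 7, 7])
```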
diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py
new file mode 100644
index 0000000..8b4bbca
--- /dev/null
+++ b/timm/models/layers/cond_conv2d.py
@@ -0,0 +1,122 @@
+""" PyTorch Conditionally Parameterized Convolution (CondConv)
+
+Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
+(https://arxiv.org/abs/1904.04971)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import math
+from functools import partial
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .helpers import to_2tuple
+from .conv2d_same import conv2d_same
+from .padding import get_padding_value
+
+
+def get_condconv_initializer(initializer, num_experts, expert_shape):
+ def condconv_initializer(weight):
+ """CondConv initializer function."""
+ num_params = np.prod(expert_shape)
+ if (len(weight.shape) != 2 or weight.shape[0] != num_experts or
+ weight.shape[1] != num_params):
+ raise (ValueError(
+ 'CondConv variables must have shape [num_experts, num_params]'))
+ for i in range(num_experts):
+ initializer(weight[i].view(expert_shape))
+ return condconv_initializer
+
+
+class CondConv2d(nn.Module):
+ """ Conditionally Parameterized Convolution
+ Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
+
+ Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
+ https://github.com/pytorch/pytorch/issues/17983
+ """
+ __constants__ = ['in_channels', 'out_channels', 'dynamic_padding']
+
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
+ super(CondConv2d, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = to_2tuple(kernel_size)
+ self.stride = to_2tuple(stride)
+ padding_val, is_padding_dynamic = get_padding_value(
+ padding, kernel_size, stride=stride, dilation=dilation)
+ self.dynamic_padding = is_padding_dynamic # flag is checked in forward so it works with torchscript
+ self.padding = to_2tuple(padding_val)
+ self.dilation = to_2tuple(dilation)
+ self.groups = groups
+ self.num_experts = num_experts
+
+ self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
+ weight_num_param = 1
+ for wd in self.weight_shape:
+ weight_num_param *= wd
+ self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))
+
+ if bias:
+ self.bias_shape = (self.out_channels,)
+ self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
+ else:
+ self.register_parameter('bias', None)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ init_weight = get_condconv_initializer(
+ partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
+ init_weight(self.weight)
+ if self.bias is not None:
+ fan_in = np.prod(self.weight_shape[1:])
+ bound = 1 / math.sqrt(fan_in)
+ init_bias = get_condconv_initializer(
+ partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
+ init_bias(self.bias)
+
+ def forward(self, x, routing_weights):
+ B, C, H, W = x.shape
+ weight = torch.matmul(routing_weights, self.weight)
+ new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
+ weight = weight.view(new_weight_shape)
+ bias = None
+ if self.bias is not None:
+ bias = torch.matmul(routing_weights, self.bias)
+ bias = bias.view(B * self.out_channels)
+ # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
+ x = x.view(1, B * C, H, W)
+ if self.dynamic_padding:
+ out = conv2d_same(
+ x, weight, bias, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups * B)
+ else:
+ out = F.conv2d(
+ x, weight, bias, stride=self.stride, padding=self.padding,
+ dilation=self.dilation, groups=self.groups * B)
+ out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])
+
+ # Literal port (from TF definition)
+ # x = torch.split(x, 1, 0)
+ # weight = torch.split(weight, 1, 0)
+ # if self.bias is not None:
+ # bias = torch.matmul(routing_weights, self.bias)
+ # bias = torch.split(bias, 1, 0)
+ # else:
+ # bias = [None] * B
+ # out = []
+ # for xi, wi, bi in zip(x, weight, bias):
+ # wi = wi.view(*self.weight_shape)
+ # if bi is not None:
+ # bi = bi.view(*self.bias_shape)
+ # out.append(self.conv_fn(
+ # xi, wi, bi, stride=self.stride, padding=self.padding,
+ # dilation=self.dilation, groups=self.groups))
+ # out = torch.cat(out, 0)
+ return out
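
As a hedged sketch (not part of the diff): CondConv2d mixes `num_experts` kernels per sample using routing weights, which would normally come from a small gating head; here a random softmax stands in for that gate. Assumes the padding helpers from this commit are present:
```
import torch
from timm.models.layers.cond_conv2d import CondConv2d

B = 4
x = torch.randn(B, 16, 28, 28)
conv = CondConv2d(16, 32, kernel_size=3, num_experts=4)

# Per-sample mixing weights over the 4 experts (placeholder gate for illustration).
routing = torch.softmax(torch.randn(B, 4), dim=-1)
print(conv(x, routing).shape)  # torch.Size([4, 32, 28, 28])
```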
diff --git a/timm/models/layers/config.py b/timm/models/layers/config.py
new file mode 100644
index 0000000..f07b9d7
--- /dev/null
+++ b/timm/models/layers/config.py
@@ -0,0 +1,115 @@
+""" Model / Layer Config singleton state
+"""
+from typing import Any, Optional
+
+__all__ = [
+ 'is_exportable', 'is_scriptable', 'is_no_jit',
+ 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
+]
+
+# Set to True if prefer to have layers with no jit optimization (includes activations)
+_NO_JIT = False
+
+# Set to True if prefer to have activation layers with no jit optimization
+# NOTE not currently used, as there is no difference between no_jit and no_activation_jit yet: the only
+# layers obeying the jit flags so far are activations. This will change as more layers are updated and/or added.
+_NO_ACTIVATION_JIT = False
+
+# Set to True if exporting a model with Same padding via ONNX
+_EXPORTABLE = False
+
+# Set to True if wanting to use torch.jit.script on a model
+_SCRIPTABLE = False
+
+
+def is_no_jit():
+ return _NO_JIT
+
+
+class set_no_jit:
+ def __init__(self, mode: bool) -> None:
+ global _NO_JIT
+ self.prev = _NO_JIT
+ _NO_JIT = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _NO_JIT
+ _NO_JIT = self.prev
+ return False
+
+
+def is_exportable():
+ return _EXPORTABLE
+
+
+class set_exportable:
+ def __init__(self, mode: bool) -> None:
+ global _EXPORTABLE
+ self.prev = _EXPORTABLE
+ _EXPORTABLE = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _EXPORTABLE
+ _EXPORTABLE = self.prev
+ return False
+
+
+def is_scriptable():
+ return _SCRIPTABLE
+
+
+class set_scriptable:
+ def __init__(self, mode: bool) -> None:
+ global _SCRIPTABLE
+ self.prev = _SCRIPTABLE
+ _SCRIPTABLE = mode
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _SCRIPTABLE
+ _SCRIPTABLE = self.prev
+ return False
+
+
+class set_layer_config:
+ """ Layer config context manager that allows setting all layer config flags at once.
+ If a flag arg is None, it will not change the current value.
+ """
+ def __init__(
+ self,
+ scriptable: Optional[bool] = None,
+ exportable: Optional[bool] = None,
+ no_jit: Optional[bool] = None,
+ no_activation_jit: Optional[bool] = None):
+ global _SCRIPTABLE
+ global _EXPORTABLE
+ global _NO_JIT
+ global _NO_ACTIVATION_JIT
+ self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
+ if scriptable is not None:
+ _SCRIPTABLE = scriptable
+ if exportable is not None:
+ _EXPORTABLE = exportable
+ if no_jit is not None:
+ _NO_JIT = no_jit
+ if no_activation_jit is not None:
+ _NO_ACTIVATION_JIT = no_activation_jit
+
+ def __enter__(self) -> None:
+ pass
+
+ def __exit__(self, *args: Any) -> bool:
+ global _SCRIPTABLE
+ global _EXPORTABLE
+ global _NO_JIT
+ global _NO_ACTIVATION_JIT
+ _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
+ return False
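
For orientation (not part of the commit): the flags are process-wide, and the context-manager classes above restore the previous value on exit, so export- or script-friendly code paths can be selected temporarily:
```
from timm.models.layers.config import is_scriptable, set_layer_config

print(is_scriptable())                 # False by default
with set_layer_config(scriptable=True, no_jit=True):
    print(is_scriptable())             # True inside the context
print(is_scriptable())                 # restored to False on exit
```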
diff --git a/timm/models/layers/conv2d_same.py b/timm/models/layers/conv2d_same.py
new file mode 100644
index 0000000..75f0f98
--- /dev/null
+++ b/timm/models/layers/conv2d_same.py
@@ -0,0 +1,42 @@
+""" Conv2d w/ Same Padding
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Tuple, Optional
+
+from .padding import pad_same, get_padding_value
+
+
+def conv2d_same(
+ x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
+ padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
+ x = pad_same(x, weight.shape[-2:], stride, dilation)
+ return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
+
+
+class Conv2dSame(nn.Conv2d):
+ """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
+ """
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+ padding=0, dilation=1, groups=1, bias=True):
+ super(Conv2dSame, self).__init__(
+ in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
+
+ def forward(self, x):
+ return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
+ padding = kwargs.pop('padding', '')
+ kwargs.setdefault('bias', False)
+ padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
+ if is_dynamic:
+ return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
+ else:
+ return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
+
+
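
A small sketch (not part of the diff) of the 'SAME'-padding behavior, assuming the padding helpers from this commit: the output spatial size is ceil(input / stride), even when symmetric static padding cannot achieve that:
```
import torch
from timm.models.layers.conv2d_same import Conv2dSame

x = torch.randn(1, 3, 224, 224)
print(Conv2dSame(3, 16, kernel_size=3, stride=2)(x).shape)  # torch.Size([1, 16, 112, 112])
print(Conv2dSame(3, 16, kernel_size=7, stride=4)(x).shape)  # torch.Size([1, 16, 56, 56])
```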
diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py
new file mode 100644
index 0000000..33005c3
--- /dev/null
+++ b/timm/models/layers/conv_bn_act.py
@@ -0,0 +1,40 @@
+""" Conv2d + BN + Act
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .create_conv2d import create_conv2d
+from .create_norm_act import convert_norm_act
+
+
+class ConvBnAct(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
+ bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
+ drop_block=None):
+ super(ConvBnAct, self).__init__()
+ use_aa = aa_layer is not None
+
+ self.conv = create_conv2d(
+ in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
+ padding=padding, dilation=dilation, groups=groups, bias=bias)
+
+ # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+ norm_act_layer = convert_norm_act(norm_layer, act_layer)
+ self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
+ self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
+
+ @property
+ def in_channels(self):
+ return self.conv.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv.out_channels
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ if self.aa is not None:
+ x = self.aa(x)
+ return x
diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
new file mode 100644
index 0000000..aa55769
--- /dev/null
+++ b/timm/models/layers/create_act.py
@@ -0,0 +1,153 @@
+""" Activation Factory
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from typing import Union, Callable, Type
+
+from .activations import *
+from .activations_jit import *
+from .activations_me import *
+from .config import is_exportable, is_scriptable, is_no_jit
+
+# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
+# Also hardsigmoid, hardswish, and soon mish. This code will use the native versions if present.
+# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
+_has_silu = 'silu' in dir(torch.nn.functional)
+_has_hardswish = 'hardswish' in dir(torch.nn.functional)
+_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
+_has_mish = 'mish' in dir(torch.nn.functional)
+
+
+_ACT_FN_DEFAULT = dict(
+ silu=F.silu if _has_silu else swish,
+ swish=F.silu if _has_silu else swish,
+ mish=F.mish if _has_mish else mish,
+ relu=F.relu,
+ relu6=F.relu6,
+ leaky_relu=F.leaky_relu,
+ elu=F.elu,
+ celu=F.celu,
+ selu=F.selu,
+ gelu=gelu,
+ sigmoid=sigmoid,
+ tanh=tanh,
+ hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
+ hard_swish=F.hardswish if _has_hardswish else hard_swish,
+ hard_mish=hard_mish,
+)
+
+_ACT_FN_JIT = dict(
+ silu=F.silu if _has_silu else swish_jit,
+ swish=F.silu if _has_silu else swish_jit,
+ mish=F.mish if _has_mish else mish_jit,
+ hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
+ hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
+ hard_mish=hard_mish_jit
+)
+
+_ACT_FN_ME = dict(
+ silu=F.silu if _has_silu else swish_me,
+ swish=F.silu if _has_silu else swish_me,
+ mish=F.mish if _has_mish else mish_me,
+ hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
+ hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
+ hard_mish=hard_mish_me,
+)
+
+_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
+for a in _ACT_FNS:
+ a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
+ a.setdefault('hardswish', a.get('hard_swish'))
+
+
+_ACT_LAYER_DEFAULT = dict(
+ silu=nn.SiLU if _has_silu else Swish,
+ swish=nn.SiLU if _has_silu else Swish,
+ mish=nn.Mish if _has_mish else Mish,
+ relu=nn.ReLU,
+ relu6=nn.ReLU6,
+ leaky_relu=nn.LeakyReLU,
+ elu=nn.ELU,
+ prelu=PReLU,
+ celu=nn.CELU,
+ selu=nn.SELU,
+ gelu=GELU,
+ sigmoid=Sigmoid,
+ tanh=Tanh,
+ hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
+ hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
+ hard_mish=HardMish,
+)
+
+_ACT_LAYER_JIT = dict(
+ silu=nn.SiLU if _has_silu else SwishJit,
+ swish=nn.SiLU if _has_silu else SwishJit,
+ mish=nn.Mish if _has_mish else MishJit,
+ hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
+ hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
+ hard_mish=HardMishJit
+)
+
+_ACT_LAYER_ME = dict(
+ silu=nn.SiLU if _has_silu else SwishMe,
+ swish=nn.SiLU if _has_silu else SwishMe,
+ mish=nn.Mish if _has_mish else MishMe,
+ hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
+ hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
+ hard_mish=HardMishMe,
+)
+
+_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
+for a in _ACT_LAYERS:
+ a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
+ a.setdefault('hardswish', a.get('hard_swish'))
+
+
+def get_act_fn(name: Union[Callable, str] = 'relu'):
+ """ Activation Function Factory
+ Fetching activation fns by name with this function allows export or torch script friendly
+ functions to be returned dynamically based on current config.
+ """
+ if not name:
+ return None
+ if isinstance(name, Callable):
+ return name
+ if not (is_no_jit() or is_exportable() or is_scriptable()):
+ # If not exporting or scripting the model, first look for a memory-efficient version with
+ # custom autograd, then fall back to the jit-scripted or default variants
+ if name in _ACT_FN_ME:
+ return _ACT_FN_ME[name]
+ if is_exportable() and name in ('silu', 'swish'):
+ # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+ return swish
+ if not (is_no_jit() or is_exportable()):
+ if name in _ACT_FN_JIT:
+ return _ACT_FN_JIT[name]
+ return _ACT_FN_DEFAULT[name]
+
+
+def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
+ """ Activation Layer Factory
+ Fetching activation layers by name with this function allows export or torch script friendly
+ functions to be returned dynamically based on current config.
+ """
+ if not name:
+ return None
+ if isinstance(name, type):
+ return name
+ if not (is_no_jit() or is_exportable() or is_scriptable()):
+ if name in _ACT_LAYER_ME:
+ return _ACT_LAYER_ME[name]
+ if is_exportable() and name in ('silu', 'swish'):
+ # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+ return Swish
+ if not (is_no_jit() or is_exportable()):
+ if name in _ACT_LAYER_JIT:
+ return _ACT_LAYER_JIT[name]
+ return _ACT_LAYER_DEFAULT[name]
+
+
+def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
+ act_layer = get_act_layer(name)
+ if act_layer is None:
+ return None
+ return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
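A minimal usage sketch of the activation factories above, assuming the package is importable with the layout shown in this diff (`timm.models.layers.create_act`); names and shapes are illustrative.

```python
import torch
from timm.models.layers.create_act import get_act_fn, get_act_layer, create_act_layer

x = torch.randn(2, 8)

# Functional form: resolved via the ME -> JIT -> default tables depending on config flags.
silu_fn = get_act_fn('silu')
y = silu_fn(x)

# Layer form: fetch the class, or build an instance directly with create_act_layer.
act_cls = get_act_layer('hard_swish')
act = create_act_layer('hard_swish', inplace=True)
print(act_cls.__name__, type(act).__name__, y.shape)
```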
diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py
new file mode 100644
index 0000000..028c0f7
--- /dev/null
+++ b/timm/models/layers/create_attn.py
@@ -0,0 +1,89 @@
+""" Attention Factory
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+from functools import partial
+
+from .bottleneck_attn import BottleneckAttn
+from .cbam import CbamModule, LightCbamModule
+from .eca import EcaModule, CecaModule
+from .gather_excite import GatherExcite
+from .global_context import GlobalContext
+from .halo_attn import HaloAttn
+from .lambda_layer import LambdaLayer
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+from .selective_kernel import SelectiveKernel
+from .split_attn import SplitAttn
+from .squeeze_excite import SEModule, EffectiveSEModule
+
+
+def get_attn(attn_type):
+ if isinstance(attn_type, torch.nn.Module):
+ return attn_type
+ module_cls = None
+ if attn_type is not None:
+ if isinstance(attn_type, str):
+ attn_type = attn_type.lower()
+ # Lightweight attention modules (channel and/or coarse spatial).
+ # Typically added to existing network architecture blocks in addition to existing convolutions.
+ if attn_type == 'se':
+ module_cls = SEModule
+ elif attn_type == 'ese':
+ module_cls = EffectiveSEModule
+ elif attn_type == 'eca':
+ module_cls = EcaModule
+ elif attn_type == 'ecam':
+ module_cls = partial(EcaModule, use_mlp=True)
+ elif attn_type == 'ceca':
+ module_cls = CecaModule
+ elif attn_type == 'ge':
+ module_cls = GatherExcite
+ elif attn_type == 'gc':
+ module_cls = GlobalContext
+ elif attn_type == 'gca':
+ module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
+ elif attn_type == 'cbam':
+ module_cls = CbamModule
+ elif attn_type == 'lcbam':
+ module_cls = LightCbamModule
+
+ # Attention / attention-like modules w/ significant params
+ # Typically replace some of the existing workhorse convs in a network architecture.
+ # All of these accept a stride argument and can spatially downsample the input.
+ elif attn_type == 'sk':
+ module_cls = SelectiveKernel
+ elif attn_type == 'splat':
+ module_cls = SplitAttn
+
+ # Self-attention / attention-like modules w/ significant compute and/or params
+ # Typically replace some of the existing workhorse convs in a network architecture.
+ # All of these accept a stride argument and can spatially downsample the input.
+ elif attn_type == 'lambda':
+ return LambdaLayer
+ elif attn_type == 'bottleneck':
+ return BottleneckAttn
+ elif attn_type == 'halo':
+ return HaloAttn
+ elif attn_type == 'nl':
+ module_cls = NonLocalAttn
+ elif attn_type == 'bat':
+ module_cls = BatNonLocalAttn
+
+ # Woops!
+ else:
+ assert False, "Invalid attn module (%s)" % attn_type
+ elif isinstance(attn_type, bool):
+ if attn_type:
+ module_cls = SEModule
+ else:
+ module_cls = attn_type
+ return module_cls
+
+
+def create_attn(attn_type, channels, **kwargs):
+ module_cls = get_attn(attn_type)
+ if module_cls is not None:
+        # NOTE: it's expected that the first (positional) argument of all attention layers is the number of input channels
+ return module_cls(channels, **kwargs)
+ return None
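A hedged usage sketch for the attention factory above, assuming the import path from this diff. `create_attn` instantiates the module with the channel count as its first positional argument and returns `None` when no attention is requested.

```python
import torch
from timm.models.layers.create_attn import create_attn

x = torch.randn(2, 64, 16, 16)

se = create_attn('se', 64)      # lightweight channel attention (SEModule)
eca = create_attn('eca', 64)    # efficient channel attention (EcaModule)
none = create_attn(None, 64)    # None / False -> no attention module

print(type(se).__name__, se(x).shape, eca(x).shape, none)
```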
diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py
new file mode 100644
index 0000000..3a0cc03
--- /dev/null
+++ b/timm/models/layers/create_conv2d.py
@@ -0,0 +1,31 @@
+""" Create Conv2d Factory Method
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+from .mixed_conv2d import MixedConv2d
+from .cond_conv2d import CondConv2d
+from .conv2d_same import create_conv2d_pad
+
+
+def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
+ """ Select a 2d convolution implementation based on arguments
+ Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
+
+ Used extensively by EfficientNet, MobileNetv3 and related networks.
+ """
+ if isinstance(kernel_size, list):
+ assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
+ assert 'groups' not in kwargs # MixedConv groups are defined by kernel list
+        # Only lists are used to define MixedConv2d kernel groups;
+        # ints, tuples, and other iterables still go to the normal conv path and specify (h, w).
+ m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
+ else:
+ depthwise = kwargs.pop('depthwise', False)
+        # for depthwise conv, out_channels must be a multiple of in_channels since out_channels % groups must equal 0
+ groups = in_channels if depthwise else kwargs.pop('groups', 1)
+ if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
+ m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
+ else:
+ m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
+ return m
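A sketch, with the import path assumed from this diff, of how `create_conv2d` dispatches: a list `kernel_size` selects `MixedConv2d`, `depthwise=True` sets `groups == in_channels`, and `num_experts > 0` selects `CondConv2d`.

```python
import torch
from timm.models.layers.create_conv2d import create_conv2d

x = torch.randn(1, 32, 14, 14)

mixed = create_conv2d(32, 64, kernel_size=[3, 5, 7])        # list kernel -> MixedConv2d
dw = create_conv2d(32, 32, kernel_size=3, depthwise=True)   # depthwise: groups == in_channels
cond = create_conv2d(32, 64, kernel_size=3, num_experts=4)  # num_experts > 0 -> CondConv2d

# CondConv2d also expects routing weights at call time, so it is only constructed here.
print(mixed(x).shape, dw(x).shape, type(cond).__name__)
```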
diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py
new file mode 100644
index 0000000..5b56294
--- /dev/null
+++ b/timm/models/layers/create_norm_act.py
@@ -0,0 +1,83 @@
+""" NormAct (Normalizaiton + Activation Layer) Factory
+
+Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
+instances in models. Where these are used, it will be possible to swap separate BN + act layers with
+combined modules like IABN or EvoNorms.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import types
+import functools
+
+import torch
+import torch.nn as nn
+
+from .evo_norm import EvoNormBatch2d, EvoNormSample2d
+from .norm_act import BatchNormAct2d, GroupNormAct
+from .inplace_abn import InplaceAbn
+
+_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
+_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type
+
+
+def get_norm_act_layer(layer_class):
+ layer_class = layer_class.replace('_', '').lower()
+ if layer_class.startswith("batchnorm"):
+ layer = BatchNormAct2d
+ elif layer_class.startswith("groupnorm"):
+ layer = GroupNormAct
+ elif layer_class == "evonormbatch":
+ layer = EvoNormBatch2d
+ elif layer_class == "evonormsample":
+ layer = EvoNormSample2d
+ elif layer_class == "iabn" or layer_class == "inplaceabn":
+ layer = InplaceAbn
+ else:
+ assert False, "Invalid norm_act layer (%s)" % layer_class
+ return layer
+
+
+def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
+ layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu
+ assert len(layer_parts) in (1, 2)
+ layer = get_norm_act_layer(layer_parts[0])
+ #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
+ layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
+ if jit:
+ layer_instance = torch.jit.script(layer_instance)
+ return layer_instance
+
+
+def convert_norm_act(norm_layer, act_layer):
+ assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
+ assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
+ norm_act_kwargs = {}
+
+ # unbind partial fn, so args can be rebound later
+ if isinstance(norm_layer, functools.partial):
+ norm_act_kwargs.update(norm_layer.keywords)
+ norm_layer = norm_layer.func
+
+ if isinstance(norm_layer, str):
+ norm_act_layer = get_norm_act_layer(norm_layer)
+ elif norm_layer in _NORM_ACT_TYPES:
+ norm_act_layer = norm_layer
+ elif isinstance(norm_layer, types.FunctionType):
+ # if function type, must be a lambda/fn that creates a norm_act layer
+ norm_act_layer = norm_layer
+ else:
+ type_name = norm_layer.__name__.lower()
+ if type_name.startswith('batchnorm'):
+ norm_act_layer = BatchNormAct2d
+ elif type_name.startswith('groupnorm'):
+ norm_act_layer = GroupNormAct
+ else:
+ assert False, f"No equivalent norm_act layer for {type_name}"
+
+ if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
+ # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
+ # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
+ norm_act_kwargs.setdefault('act_layer', act_layer)
+ if norm_act_kwargs:
+ norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
+ return norm_act_layer
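A minimal sketch of the norm+act factory above, with the import path assumed from this diff and values illustrative. `create_norm_act` builds a fused module by name, while `convert_norm_act` maps a separate norm/act pair onto the matching NormAct type.

```python
import torch
import torch.nn as nn
from timm.models.layers.create_norm_act import create_norm_act, convert_norm_act

# Build a combined BatchNorm + ReLU module directly by name.
bn_act = create_norm_act('batchnorm', num_features=32, act_layer=nn.ReLU)

# Map an existing (norm_layer, act_layer) pair onto its fused NormAct equivalent.
norm_act_cls = convert_norm_act(nn.BatchNorm2d, nn.ReLU)
fused = norm_act_cls(32)

x = torch.randn(2, 32, 8, 8)
print(bn_act(x).shape, fused(x).shape)
```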
diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py
new file mode 100644
index 0000000..6de9e3f
--- /dev/null
+++ b/timm/models/layers/drop.py
@@ -0,0 +1,168 @@
+""" DropBlock, DropPath
+
+PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+
+Papers:
+DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+Code:
+DropBlock impl inspired by two Tensorflow impl that I liked:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def drop_block_2d(
+ x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
+ with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+ """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
+ """
+ B, C, H, W = x.shape
+ total_size = W * H
+ clipped_block_size = min(block_size, min(W, H))
+ # seed_drop_rate, the gamma parameter
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+ (W - block_size + 1) * (H - block_size + 1))
+
+ # Forces the block to be inside the feature map.
+ w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+ valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
+ ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
+ valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+
+ if batchwise:
+ # one mask for whole batch, quite a bit faster
+ uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
+ else:
+ uniform_noise = torch.rand_like(x)
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
+ block_mask = -F.max_pool2d(
+ -block_mask,
+ kernel_size=clipped_block_size, # block_size,
+ stride=1,
+ padding=clipped_block_size // 2)
+
+ if with_noise:
+ normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+ if inplace:
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+ else:
+ x = x * block_mask + normal_noise * (1 - block_mask)
+ else:
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+ if inplace:
+ x.mul_(block_mask * normalize_scale)
+ else:
+ x = x * block_mask * normalize_scale
+ return x
+
+
+def drop_block_fast_2d(
+ x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
+ gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+ """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid
+ block mask at edges.
+ """
+ B, C, H, W = x.shape
+ total_size = W * H
+ clipped_block_size = min(block_size, min(W, H))
+ gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+ (W - block_size + 1) * (H - block_size + 1))
+
+ if batchwise:
+ # one mask for whole batch, quite a bit faster
+ block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
+ else:
+ # mask per batch element
+ block_mask = torch.rand_like(x) < gamma
+ block_mask = F.max_pool2d(
+ block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
+
+ if with_noise:
+ normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+ if inplace:
+ x.mul_(1. - block_mask).add_(normal_noise * block_mask)
+ else:
+ x = x * (1. - block_mask) + normal_noise * block_mask
+ else:
+ block_mask = 1 - block_mask
+ normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
+ if inplace:
+ x.mul_(block_mask * normalize_scale)
+ else:
+ x = x * block_mask * normalize_scale
+ return x
+
+
+class DropBlock2d(nn.Module):
+ """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+ """
+ def __init__(self,
+ drop_prob=0.1,
+ block_size=7,
+ gamma_scale=1.0,
+ with_noise=False,
+ inplace=False,
+ batchwise=False,
+ fast=True):
+ super(DropBlock2d, self).__init__()
+ self.drop_prob = drop_prob
+ self.gamma_scale = gamma_scale
+ self.block_size = block_size
+ self.with_noise = with_noise
+ self.inplace = inplace
+ self.batchwise = batchwise
+ self.fast = fast # FIXME finish comparisons of fast vs not
+
+ def forward(self, x):
+ if not self.training or not self.drop_prob:
+ return x
+ if self.fast:
+ return drop_block_fast_2d(
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+ else:
+ return drop_block_2d(
+ x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+    This is the same as the DropConnect impl I created for EfficientNet, etc. networks, however,
+    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted to
+    change the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+    'survival rate' as the argument.
+
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
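A small sketch, assuming the import path from this diff, of the stochastic-depth behaviour of `DropPath`: in training, each sample in the batch is either zeroed or scaled by `1 / keep_prob` so the expected value matches the input; at inference it is the identity.

```python
import torch
from timm.models.layers.drop import DropPath

x = torch.ones(4, 3, 8, 8)
dp = DropPath(drop_prob=0.25)

dp.train()
y = dp(x)   # each sample is either all zeros or scaled by 1 / (1 - 0.25)
dp.eval()
z = dp(x)   # identity at inference time
print(y.unique(), torch.equal(z, x))
```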
diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py
new file mode 100644
index 0000000..e29be6a
--- /dev/null
+++ b/timm/models/layers/eca.py
@@ -0,0 +1,145 @@
+"""
+ECA module from ECAnet
+
+paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
+https://arxiv.org/abs/1910.03151
+
+Original ECA model borrowed from https://github.com/BangguWu/ECANet
+
+Modified circular ECA implementation and adaptation for use in the timm package
+by Chris Ha https://github.com/VRandme
+
+Original License:
+
+MIT License
+
+Copyright (c) 2019 BangguWu, Qilong Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+import math
+from torch import nn
+import torch.nn.functional as F
+
+
+from .create_act import create_act_layer
+from .helpers import make_divisible
+
+
+class EcaModule(nn.Module):
+ """Constructs an ECA module.
+
+ Args:
+        channels: Number of channels of the input feature map. When given, the kernel size is
+            selected adaptively from the channel count using the mapping function parameterized
+            by gamma and beta, per the original paper https://arxiv.org/pdf/1910.03151.pdf
+            (default=None; if the channel count is not given, kernel_size is used as-is).
+        kernel_size: Kernel size used when channels is not given (default=3)
+        gamma: used in the adaptive kernel_size calculation, see above
+        beta: used in the adaptive kernel_size calculation, see above
+ act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+ gate_layer: gating non-linearity to use
+ """
+ def __init__(
+ self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
+ rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
+ super(EcaModule, self).__init__()
+ if channels is not None:
+ t = int(abs(math.log(channels, 2) + beta) / gamma)
+ kernel_size = max(t if t % 2 else t + 1, 3)
+ assert kernel_size % 2 == 1
+ padding = (kernel_size - 1) // 2
+ if use_mlp:
+ # NOTE 'mlp' mode is a timm experiment, not in paper
+ assert channels is not None
+ if rd_channels is None:
+ rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
+ act_layer = act_layer or nn.ReLU
+ self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
+ self.act = create_act_layer(act_layer)
+ self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
+ else:
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+ self.act = None
+ self.conv2 = None
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
+ y = self.conv(y)
+ if self.conv2 is not None:
+ y = self.act(y)
+ y = self.conv2(y)
+ y = self.gate(y).view(x.shape[0], -1, 1, 1)
+ return x * y.expand_as(x)
+
+
+EfficientChannelAttn = EcaModule # alias
+
+
+class CecaModule(nn.Module):
+ """Constructs a circular ECA module.
+
+    ECA module where the conv uses circular padding rather than zero padding.
+    Unlike the spatial dimensions, the channels have no inherent ordering or locality.
+    Although this module, in essence, applies such an assumption, there is no need to prevent
+    the channels at either "edge" from being circularly adapted to each other.
+    This fundamentally increases connectivity and may improve performance metrics
+    (accuracy, robustness) without significantly impacting resource metrics
+    (parameter size, throughput, latency, etc.).
+
+ Args:
+        channels: Number of channels of the input feature map. When given, the kernel size is
+            selected adaptively from the channel count using the mapping function parameterized
+            by gamma and beta, per the original paper https://arxiv.org/pdf/1910.03151.pdf
+            (default=None; if the channel count is not given, kernel_size is used as-is).
+        kernel_size: Kernel size used when channels is not given (default=3)
+        gamma: used in the adaptive kernel_size calculation, see above
+        beta: used in the adaptive kernel_size calculation, see above
+ act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
+ gate_layer: gating non-linearity to use
+ """
+
+ def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
+ super(CecaModule, self).__init__()
+ if channels is not None:
+ t = int(abs(math.log(channels, 2) + beta) / gamma)
+ kernel_size = max(t if t % 2 else t + 1, 3)
+ has_act = act_layer is not None
+ assert kernel_size % 2 == 1
+
+ # PyTorch circular padding mode is buggy as of pytorch 1.4
+ # see https://github.com/pytorch/pytorch/pull/17240
+ # implement manual circular padding
+ self.padding = (kernel_size - 1) // 2
+ self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ y = x.mean((2, 3)).view(x.shape[0], 1, -1)
+        # Manually implement circular padding; F.pad does not appear to be affected by the bug
+ y = F.pad(y, (self.padding, self.padding), mode='circular')
+ y = self.conv(y)
+ y = self.gate(y).view(x.shape[0], -1, 1, 1)
+ return x * y.expand_as(x)
+
+
+CircularEfficientChannelAttn = CecaModule
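For reference, a standalone restatement of the adaptive kernel-size rule used by `EcaModule`/`CecaModule` above when `channels` is given: t = |log2(C) + beta| / gamma, truncated, forced odd, and clamped to at least 3. The helper name is illustrative; the expression mirrors `EcaModule.__init__`.

```python
import math

def eca_kernel_size(channels, gamma=2, beta=1):
    # Mirrors the kernel-size expression in EcaModule.__init__ above.
    t = int(abs(math.log(channels, 2) + beta) / gamma)
    return max(t if t % 2 else t + 1, 3)

for c in (64, 256, 1024):
    print(c, eca_kernel_size(c))   # 64 -> 3, 256 -> 5, 1024 -> 5
```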
diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py
new file mode 100644
index 0000000..6ef0c88
--- /dev/null
+++ b/timm/models/layers/evo_norm.py
@@ -0,0 +1,81 @@
+"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
+
+An attempt at getting decent performing EvoNorms running in PyTorch.
+While currently faster than other impl, still quite a ways off the built-in BN
+in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
+
+Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+
+from .trace_utils import _assert
+
+
+class EvoNormBatch2d(nn.Module):
+ def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
+ super(EvoNormBatch2d, self).__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ self.momentum = momentum
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
+ self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
+ self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+ if self.apply_act:
+ nn.init.ones_(self.v)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ x_type = x.dtype
+ if self.v is not None:
+ running_var = self.running_var.view(1, -1, 1, 1)
+ if self.training:
+ var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+ n = x.numel() / x.shape[1]
+ running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
+ self.running_var.copy_(running_var.view(self.running_var.shape))
+ else:
+ var = running_var
+ v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
+ d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
+ d = d.max((var + self.eps).sqrt().to(dtype=x_type))
+ x = x / d
+ return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+
+
+class EvoNormSample2d(nn.Module):
+ def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
+ super(EvoNormSample2d, self).__init__()
+ self.apply_act = apply_act # apply activation (non-linearity)
+ self.groups = groups
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
+ self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
+ self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.ones_(self.weight)
+ nn.init.zeros_(self.bias)
+ if self.apply_act:
+ nn.init.ones_(self.v)
+
+ def forward(self, x):
+ _assert(x.dim() == 4, 'expected 4D input')
+ B, C, H, W = x.shape
+ _assert(C % self.groups == 0, '')
+ if self.v is not None:
+ n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
+ x = x.reshape(B, self.groups, -1)
+ x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
+ x = x.reshape(B, C, H, W)
+ return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
diff --git a/timm/models/layers/gather_excite.py b/timm/models/layers/gather_excite.py
new file mode 100644
index 0000000..2d60dc9
--- /dev/null
+++ b/timm/models/layers/gather_excite.py
@@ -0,0 +1,90 @@
+""" Gather-Excite Attention Block
+
+Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
+
+Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
+
+I've tried to support all of the extent options, both w/ and w/o extra params. I don't believe I've seen
+another impl that covers all of the cases.
+
+NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .create_act import create_act_layer, get_act_layer
+from .create_conv2d import create_conv2d
+from .helpers import make_divisible
+from .mlp import ConvMlp
+
+
+class GatherExcite(nn.Module):
+ """ Gather-Excite Attention Module
+ """
+ def __init__(
+ self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
+ rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
+ super(GatherExcite, self).__init__()
+ self.add_maxpool = add_maxpool
+ act_layer = get_act_layer(act_layer)
+ self.extent = extent
+ if extra_params:
+ self.gather = nn.Sequential()
+ if extent == 0:
+ assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
+ self.gather.add_module(
+ 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
+ if norm_layer:
+ self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
+ else:
+ assert extent % 2 == 0
+ num_conv = int(math.log2(extent))
+ for i in range(num_conv):
+ self.gather.add_module(
+ f'conv{i + 1}',
+ create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
+ if norm_layer:
+ self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
+ if i != num_conv - 1:
+ self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
+ else:
+ self.gather = None
+ if self.extent == 0:
+ self.gk = 0
+ self.gs = 0
+ else:
+ assert extent % 2 == 0
+ self.gk = self.extent * 2 - 1
+ self.gs = self.extent
+
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ size = x.shape[-2:]
+ if self.gather is not None:
+ x_ge = self.gather(x)
+ else:
+ if self.extent == 0:
+ # global extent
+ x_ge = x.mean(dim=(2, 3), keepdims=True)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
+ else:
+ x_ge = F.avg_pool2d(
+ x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
+ x_ge = self.mlp(x_ge)
+ if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
+ x_ge = F.interpolate(x_ge, size=size)
+ return x * self.gate(x_ge)
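A usage sketch for `GatherExcite` above, assuming the import path from this diff. `extent=0` gathers globally (SE-like); a non-zero, even extent gathers over local neighbourhoods, either with average pooling or, when `extra_params=True`, with strided depthwise convolutions.

```python
import torch
from timm.models.layers.gather_excite import GatherExcite

x = torch.randn(2, 64, 32, 32)
ge_global = GatherExcite(64, extent=0)                      # global gather, no extra params
ge_local = GatherExcite(64, extent=4)                       # 7x7 avg pool with stride 4
ge_params = GatherExcite(64, extent=4, extra_params=True)   # two strided depthwise convs + BN
print(ge_global(x).shape, ge_local(x).shape, ge_params(x).shape)
```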
diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py
new file mode 100644
index 0000000..de7fb5c
--- /dev/null
+++ b/timm/models/layers/global_context.py
@@ -0,0 +1,67 @@
+""" Global Context Attention Block
+
+Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
+ - https://arxiv.org/abs/1904.11492
+
+Official code consulted as reference: https://github.com/xvjiarui/GCNet
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from torch import nn as nn
+import torch.nn.functional as F
+
+from .create_act import create_act_layer, get_act_layer
+from .helpers import make_divisible
+from .mlp import ConvMlp
+from .norm import LayerNorm2d
+
+
+class GlobalContext(nn.Module):
+
+ def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
+ rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
+ super(GlobalContext, self).__init__()
+ act_layer = get_act_layer(act_layer)
+
+ self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
+
+ if rd_channels is None:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ if fuse_add:
+ self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
+ else:
+ self.mlp_add = None
+ if fuse_scale:
+ self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
+ else:
+ self.mlp_scale = None
+
+ self.gate = create_act_layer(gate_layer)
+ self.init_last_zero = init_last_zero
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ if self.conv_attn is not None:
+ nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
+ if self.mlp_add is not None:
+ nn.init.zeros_(self.mlp_add.fc2.weight)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+
+ if self.conv_attn is not None:
+ attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
+ attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
+ context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
+ context = context.view(B, C, 1, 1)
+ else:
+ context = x.mean(dim=(2, 3), keepdim=True)
+
+ if self.mlp_scale is not None:
+ mlp_x = self.mlp_scale(context)
+ x = x * self.gate(mlp_x)
+ if self.mlp_add is not None:
+ mlp_x = self.mlp_add(context)
+ x = x + mlp_x
+
+ return x
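A sketch for `GlobalContext` above, with the import path assumed from this diff. The default uses the scale branch, matching the 'gc' entry in the attention factory; additive fusion matches the 'gca' entry.

```python
import torch
from timm.models.layers.global_context import GlobalContext

x = torch.randn(2, 64, 16, 16)
gc_scale = GlobalContext(64)                                  # fuse_scale=True (default), as in 'gc'
gc_add = GlobalContext(64, fuse_add=True, fuse_scale=False)   # additive fusion, as in 'gca'
print(gc_scale(x).shape, gc_add(x).shape)
```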
diff --git a/timm/models/layers/halo_attn.py b/timm/models/layers/halo_attn.py
new file mode 100644
index 0000000..f2ac64f
--- /dev/null
+++ b/timm/models/layers/halo_attn.py
@@ -0,0 +1,233 @@
+""" Halo Self Attention
+
+Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
+ - https://arxiv.org/abs/2103.12731
+
+@misc{2103.12731,
+Author = {Ashish Vaswani and Prajit Ramachandran and Aravind Srinivas and Niki Parmar and Blake Hechtman and
+ Jonathon Shlens},
+Title = {Scaling Local Self-Attention for Parameter Efficient Visual Backbones},
+Year = {2021},
+}
+
+Status:
+This impl is a WIP, there is no official ref impl and some details in paper weren't clear to me.
+The attention mechanism works but it's slow as implemented.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import List
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .helpers import make_divisible
+from .weight_init import trunc_normal_
+from .trace_utils import _assert
+
+
+def rel_logits_1d(q, rel_k, permute_mask: List[int]):
+ """ Compute relative logits along one dimension
+
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+
+ Args:
+ q: (batch, height, width, dim)
+ rel_k: (2 * window - 1, dim)
+ permute_mask: permute output dim according to this
+ """
+ B, H, W, dim = q.shape
+ rel_size = rel_k.shape[0]
+ win_size = (rel_size + 1) // 2
+
+ x = (q @ rel_k.transpose(-1, -2))
+ x = x.reshape(-1, W, rel_size)
+
+ # pad to shift from relative to absolute indexing
+ x_pad = F.pad(x, [0, 1]).flatten(1)
+ x_pad = F.pad(x_pad, [0, rel_size - W])
+
+ # reshape and slice out the padded elements
+ x_pad = x_pad.reshape(-1, W + 1, rel_size)
+ x = x_pad[:, :W, win_size - 1:]
+
+ # reshape and tile
+ x = x.reshape(B, H, 1, W, win_size).expand(-1, -1, win_size, -1, -1)
+ return x.permute(permute_mask)
+
+
+class PosEmbedRel(nn.Module):
+ """ Relative Position Embedding
+ As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
+ Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
+
+ """
+ def __init__(self, block_size, win_size, dim_head, scale):
+ """
+ Args:
+ block_size (int): block size
+ win_size (int): neighbourhood window size
+ dim_head (int): attention head dim
+ scale (float): scale factor (for init)
+ """
+ super().__init__()
+ self.block_size = block_size
+ self.dim_head = dim_head
+ self.height_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
+ self.width_rel = nn.Parameter(torch.randn(win_size * 2 - 1, dim_head) * scale)
+
+ def forward(self, q):
+ B, BB, HW, _ = q.shape
+
+ # relative logits in width dimension.
+ q = q.reshape(-1, self.block_size, self.block_size, self.dim_head)
+ rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
+
+ # relative logits in height dimension.
+ q = q.transpose(1, 2)
+ rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
+
+ rel_logits = rel_logits_h + rel_logits_w
+ rel_logits = rel_logits.reshape(B, BB, HW, -1)
+ return rel_logits
+
+
+class HaloAttn(nn.Module):
+ """ Halo Attention
+
+ Paper: `Scaling Local Self-Attention for Parameter Efficient Visual Backbones`
+ - https://arxiv.org/abs/2103.12731
+
+ The internal dimensions of the attention module are controlled by the interaction of several arguments.
+ * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+ * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+ * the query and key (qk) dimensions are determined by
+ * num_heads * dim_head if dim_head is not None
+ * num_heads * (dim_out * attn_ratio // num_heads) if dim_head is None
+ * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not used
+
+ Args:
+ dim (int): input dimension to the module
+ dim_out (int): output dimension of the module, same as dim if not set
+ feat_size (Tuple[int, int]): size of input feature_map (not used, for arg compat with bottle/lambda)
+ stride: output stride of the module, query downscaled if > 1 (default: 1).
+ num_heads: parallel attention heads (default: 8).
+ dim_head: dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+ block_size (int): size of blocks. (default: 8)
+ halo_size (int): size of halo overlap. (default: 3)
+ qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+ qkv_bias (bool) : add bias to q, k, and v projections
+ avg_down (bool): use average pool downsample instead of strided query blocks
+ scale_pos_embed (bool): scale the position embedding as well as Q @ K
+ """
+ def __init__(
+ self, dim, dim_out=None, feat_size=None, stride=1, num_heads=8, dim_head=None, block_size=8, halo_size=3,
+ qk_ratio=1.0, qkv_bias=False, avg_down=False, scale_pos_embed=False):
+ super().__init__()
+ dim_out = dim_out or dim
+ assert dim_out % num_heads == 0
+ assert stride in (1, 2)
+ self.num_heads = num_heads
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.dim_head_v = dim_out // self.num_heads
+ self.dim_out_qk = num_heads * self.dim_head_qk
+ self.dim_out_v = num_heads * self.dim_head_v
+ self.scale = self.dim_head_qk ** -0.5
+ self.scale_pos_embed = scale_pos_embed
+ self.block_size = self.block_size_ds = block_size
+ self.halo_size = halo_size
+ self.win_size = block_size + halo_size * 2 # neighbourhood window size
+ self.block_stride = 1
+ use_avg_pool = False
+ if stride > 1:
+ use_avg_pool = avg_down or block_size % stride != 0
+ self.block_stride = 1 if use_avg_pool else stride
+ self.block_size_ds = self.block_size // self.block_stride
+
+ # FIXME not clear if this stride behaviour is what the paper intended
+ # Also, the paper mentions using a 3D conv for dealing with the blocking/gather, and leaving
+ # data in unfolded block form. I haven't wrapped my head around how that'd look.
+ self.q = nn.Conv2d(dim, self.dim_out_qk, 1, stride=self.block_stride, bias=qkv_bias)
+ self.kv = nn.Conv2d(dim, self.dim_out_qk + self.dim_out_v, 1, bias=qkv_bias)
+
+ self.pos_embed = PosEmbedRel(
+ block_size=self.block_size_ds, win_size=self.win_size, dim_head=self.dim_head_qk, scale=self.scale)
+
+ self.pool = nn.AvgPool2d(2, 2) if use_avg_pool else nn.Identity()
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ std = self.q.weight.shape[1] ** -0.5 # fan-in
+ trunc_normal_(self.q.weight, std=std)
+ trunc_normal_(self.kv.weight, std=std)
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ _assert(H % self.block_size == 0, '')
+ _assert(W % self.block_size == 0, '')
+ num_h_blocks = H // self.block_size
+ num_w_blocks = W // self.block_size
+ num_blocks = num_h_blocks * num_w_blocks
+
+ q = self.q(x)
+ # unfold
+ q = q.reshape(
+ -1, self.dim_head_qk,
+ num_h_blocks, self.block_size_ds, num_w_blocks, self.block_size_ds).permute(0, 1, 3, 5, 2, 4)
+ # B, num_heads * dim_head * block_size ** 2, num_blocks
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1, num_blocks).transpose(1, 3)
+ # B * num_heads, num_blocks, block_size ** 2, dim_head
+
+ kv = self.kv(x)
+ # Generate overlapping windows for kv. This approach is good for GPU and CPU. However, unfold() is not
+ # lowered for PyTorch XLA so it will be very slow. See code at bottom of file for XLA friendly approach.
+ # FIXME figure out how to switch impl between this and conv2d if XLA being used.
+ kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size])
+ kv = kv.unfold(2, self.win_size, self.block_size).unfold(3, self.win_size, self.block_size).reshape(
+ B * self.num_heads, self.dim_head_qk + self.dim_head_v, num_blocks, -1).permute(0, 2, 3, 1)
+ k, v = torch.split(kv, [self.dim_head_qk, self.dim_head_v], dim=-1)
+ # B * num_heads, num_blocks, win_size ** 2, dim_head_qk or dim_head_v
+
+ if self.scale_pos_embed:
+ attn = (q @ k.transpose(-1, -2) + self.pos_embed(q)) * self.scale
+ else:
+ attn = (q @ k.transpose(-1, -2)) * self.scale + self.pos_embed(q)
+ # B * num_heads, num_blocks, block_size ** 2, win_size ** 2
+ attn = attn.softmax(dim=-1)
+
+ out = (attn @ v).transpose(1, 3) # B * num_heads, dim_head_v, block_size ** 2, num_blocks
+ # fold
+ out = out.reshape(-1, self.block_size_ds, self.block_size_ds, num_h_blocks, num_w_blocks)
+ out = out.permute(0, 3, 1, 4, 2).contiguous().view(
+ B, self.dim_out_v, H // self.block_stride, W // self.block_stride)
+ # B, dim_out, H // block_stride, W // block_stride
+ out = self.pool(out)
+ return out
+
+
+""" Three alternatives for overlapping windows.
+
+`.unfold().unfold()` is the same speed as the stride-tricks approach, with clarity similar to F.unfold()
+
+ if is_xla:
+ # This code achieves haloing on PyTorch XLA with reasonable runtime trade-off, it is
+ # EXTREMELY slow for backward on a GPU though so I need a way of selecting based on environment.
+ WW = self.win_size ** 2
+ pw = torch.eye(WW, dtype=x.dtype, device=x.device).reshape(WW, 1, self.win_size, self.win_size)
+ kv = F.conv2d(kv.reshape(-1, 1, H, W), pw, stride=self.block_size, padding=self.halo_size)
+ elif self.stride_tricks:
+ kv = F.pad(kv, [self.halo_size, self.halo_size, self.halo_size, self.halo_size]).contiguous()
+ kv = kv.as_strided((
+ B, self.dim_out_qk + self.dim_out_v, self.win_size, self.win_size, num_h_blocks, num_w_blocks),
+ stride=(kv.stride(0), kv.stride(1), kv.shape[-1], 1, self.block_size * kv.shape[-1], self.block_size))
+ else:
+ kv = F.unfold(kv, kernel_size=self.win_size, stride=self.block_size, padding=self.halo_size)
+
+ kv = kv.reshape(
+ B * self.num_heads, self.dim_head_qk + self.dim_head_v, -1, num_blocks).transpose(1, 3)
+"""
diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py
new file mode 100644
index 0000000..cc54ca7
--- /dev/null
+++ b/timm/models/layers/helpers.py
@@ -0,0 +1,31 @@
+""" Layer/Module Helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from itertools import repeat
+import collections.abc
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
+
+
+def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
+ min_value = min_value or divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not reduce the value by more than (1 - round_limit), 10% by default.
+ if new_v < round_limit * v:
+ new_v += divisor
+ return new_v
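A quick illustration of `make_divisible`, assuming the import path from this diff: it rounds a channel count to the nearest multiple of `divisor` and bumps the result up by one divisor if rounding down would lose more than `1 - round_limit` of the value.

```python
from timm.models.layers.helpers import make_divisible

print(make_divisible(37))               # 40 (nearest multiple of 8)
print(make_divisible(100, divisor=16))  # 96
print(make_divisible(10))               # 16 (8 would fall below 0.9 * 10, so bump up by a divisor)
```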
diff --git a/timm/models/layers/inplace_abn.py b/timm/models/layers/inplace_abn.py
new file mode 100644
index 0000000..3aae7cf
--- /dev/null
+++ b/timm/models/layers/inplace_abn.py
@@ -0,0 +1,87 @@
+import torch
+from torch import nn as nn
+
+try:
+ from inplace_abn.functions import inplace_abn, inplace_abn_sync
+ has_iabn = True
+except ImportError:
+ has_iabn = False
+
+ def inplace_abn(x, weight, bias, running_mean, running_var,
+ training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
+ raise ImportError(
+ "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
+
+ def inplace_abn_sync(**kwargs):
+ inplace_abn(**kwargs)
+
+
+class InplaceAbn(nn.Module):
+ """Activated Batch Normalization
+
+ This gathers a BatchNorm and an activation function in a single module
+
+ Parameters
+ ----------
+ num_features : int
+ Number of feature channels in the input and output.
+ eps : float
+ Small constant to prevent numerical issues.
+ momentum : float
+ Momentum factor applied to compute running statistics.
+ affine : bool
+ If `True` apply learned scale and shift transformation after normalization.
+ act_layer : str or nn.Module type
+ Name or type of the activation functions, one of: `leaky_relu`, `elu`
+ act_param : float
+ Negative slope for the `leaky_relu` activation.
+ """
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
+ act_layer="leaky_relu", act_param=0.01, drop_block=None):
+ super(InplaceAbn, self).__init__()
+ self.num_features = num_features
+ self.affine = affine
+ self.eps = eps
+ self.momentum = momentum
+ if apply_act:
+ if isinstance(act_layer, str):
+ assert act_layer in ('leaky_relu', 'elu', 'identity', '')
+ self.act_name = act_layer if act_layer else 'identity'
+ else:
+ # convert act layer passed as type to string
+ if act_layer == nn.ELU:
+ self.act_name = 'elu'
+ elif act_layer == nn.LeakyReLU:
+ self.act_name = 'leaky_relu'
+ elif act_layer == nn.Identity:
+ self.act_name = 'identity'
+ else:
+ assert False, f'Invalid act layer {act_layer.__name__} for IABN'
+ else:
+ self.act_name = 'identity'
+ self.act_param = act_param
+ if self.affine:
+ self.weight = nn.Parameter(torch.ones(num_features))
+ self.bias = nn.Parameter(torch.zeros(num_features))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ self.register_buffer('running_mean', torch.zeros(num_features))
+ self.register_buffer('running_var', torch.ones(num_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.constant_(self.running_mean, 0)
+ nn.init.constant_(self.running_var, 1)
+ if self.affine:
+ nn.init.constant_(self.weight, 1)
+ nn.init.constant_(self.bias, 0)
+
+ def forward(self, x):
+ output = inplace_abn(
+ x, self.weight, self.bias, self.running_mean, self.running_var,
+ self.training, self.momentum, self.eps, self.act_name, self.act_param)
+ if isinstance(output, tuple):
+ output = output[0]
+ return output
diff --git a/timm/models/layers/lambda_layer.py b/timm/models/layers/lambda_layer.py
new file mode 100644
index 0000000..e50b43c
--- /dev/null
+++ b/timm/models/layers/lambda_layer.py
@@ -0,0 +1,133 @@
+""" Lambda Layer
+
+Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
+ - https://arxiv.org/abs/2102.08602
+
+@misc{2102.08602,
+Author = {Irwan Bello},
+Title = {LambdaNetworks: Modeling Long-Range Interactions Without Attention},
+Year = {2021},
+}
+
+Status:
+This impl is a WIP. Code snippets in the paper were used as reference but
+good chance some details are missing/wrong.
+
+I've only implemented local lambda conv based pos embeddings.
+
+For a PyTorch impl that includes other embedding options checkout
+https://github.com/lucidrains/lambda-networks
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .helpers import to_2tuple, make_divisible
+from .weight_init import trunc_normal_
+
+
+def rel_pos_indices(size):
+ size = to_2tuple(size)
+ pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
+ rel_pos = pos[:, None, :] - pos[:, :, None]
+ rel_pos[0] += size[0] - 1
+ rel_pos[1] += size[1] - 1
+ return rel_pos # 2, H * W, H * W
+
+
+class LambdaLayer(nn.Module):
+ """Lambda Layer
+
+ Paper: `LambdaNetworks: Modeling Long-Range Interactions Without Attention`
+ - https://arxiv.org/abs/2102.08602
+
+ NOTE: intra-depth parameter 'u' is fixed at 1. It did not appear worth the complexity to add.
+
+ The internal dimensions of the lambda module are controlled via the interaction of several arguments.
+ * the output dimension of the module is specified by dim_out, which falls back to input dim if not set
+ * the value (v) dimension is set to dim_out // num_heads, the v projection determines the output dim
+ * the query (q) and key (k) dimension are determined by
+ * dim_head = (dim_out * attn_ratio // num_heads) if dim_head is None
+ * q = num_heads * dim_head, k = dim_head
+ * as seen above, attn_ratio determines the ratio of q and k relative to the output if dim_head not set
+
+ Args:
+ dim (int): input dimension to the module
+ dim_out (int): output dimension of the module, same as dim if not set
+ feat_size (Tuple[int, int]): size of input feature_map for relative pos variant H, W
+ stride (int): output stride of the module, avg pool used if stride == 2
+ num_heads (int): parallel attention heads.
+ dim_head (int): dimension of query and key heads, calculated from dim_out * attn_ratio // num_heads if not set
+ r (int): local lambda convolution radius. Use lambda conv if set, else relative pos if not. (default: 9)
+ qk_ratio (float): ratio of q and k dimensions to output dimension when dim_head not set. (default: 1.0)
+ qkv_bias (bool): add bias to q, k, and v projections
+ """
+ def __init__(
+ self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, dim_head=16, r=9,
+ qk_ratio=1.0, qkv_bias=False):
+ super().__init__()
+ dim_out = dim_out or dim
+        assert dim_out % num_heads == 0, 'dim_out must be divisible by num_heads'
+ self.dim_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
+ self.num_heads = num_heads
+ self.dim_v = dim_out // num_heads
+
+ self.qkv = nn.Conv2d(
+ dim,
+ num_heads * self.dim_qk + self.dim_qk + self.dim_v,
+ kernel_size=1, bias=qkv_bias)
+ self.norm_q = nn.BatchNorm2d(num_heads * self.dim_qk)
+ self.norm_v = nn.BatchNorm2d(self.dim_v)
+
+ if r is not None:
+ # local lambda convolution for pos
+ self.conv_lambda = nn.Conv3d(1, self.dim_qk, (r, r, 1), padding=(r // 2, r // 2, 0))
+ self.pos_emb = None
+ self.rel_pos_indices = None
+ else:
+ # relative pos embedding
+ assert feat_size is not None
+ feat_size = to_2tuple(feat_size)
+ rel_size = [2 * s - 1 for s in feat_size]
+ self.conv_lambda = None
+ self.pos_emb = nn.Parameter(torch.zeros(rel_size[0], rel_size[1], self.dim_qk))
+ self.register_buffer('rel_pos_indices', rel_pos_indices(feat_size), persistent=False)
+
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) # fan-in
+ if self.conv_lambda is not None:
+ trunc_normal_(self.conv_lambda.weight, std=self.dim_qk ** -0.5)
+ if self.pos_emb is not None:
+ trunc_normal_(self.pos_emb, std=.02)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ M = H * W
+ qkv = self.qkv(x)
+ q, k, v = torch.split(qkv, [
+ self.num_heads * self.dim_qk, self.dim_qk, self.dim_v], dim=1)
+ q = self.norm_q(q).reshape(B, self.num_heads, self.dim_qk, M).transpose(-1, -2) # B, num_heads, M, K
+ v = self.norm_v(v).reshape(B, self.dim_v, M).transpose(-1, -2) # B, M, V
+ k = F.softmax(k.reshape(B, self.dim_qk, M), dim=-1) # B, K, M
+
+ content_lam = k @ v # B, K, V
+ content_out = q @ content_lam.unsqueeze(1) # B, num_heads, M, V
+
+ if self.pos_emb is None:
+ position_lam = self.conv_lambda(v.reshape(B, 1, H, W, self.dim_v)) # B, H, W, V, K
+ position_lam = position_lam.reshape(B, 1, self.dim_qk, H * W, self.dim_v).transpose(2, 3) # B, 1, M, K, V
+ else:
+ # FIXME relative pos embedding path not fully verified
+ pos_emb = self.pos_emb[self.rel_pos_indices[0], self.rel_pos_indices[1]].expand(B, -1, -1, -1)
+ position_lam = (pos_emb.transpose(-1, -2) @ v.unsqueeze(1)).unsqueeze(1) # B, 1, M, K, V
+ position_out = (q.unsqueeze(-2) @ position_lam).squeeze(-2) # B, num_heads, M, V
+
+ out = (content_out + position_out).transpose(-1, -2).reshape(B, C, H, W) # B, C (num_heads * V), H, W
+ out = self.pool(out)
+ return out
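A usage sketch for `LambdaLayer` above, with the import path assumed from this diff. With `r` set (default 9), the local lambda convolution handles position and `feat_size` is not required; `r=None` switches to the relative position embedding path, which does need `feat_size`.

```python
import torch
from timm.models.layers.lambda_layer import LambdaLayer

x = torch.randn(2, 128, 32, 32)
lam = LambdaLayer(dim=128, dim_out=128, num_heads=4, dim_head=16, r=9)
print(lam(x).shape)      # torch.Size([2, 128, 32, 32])
```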
diff --git a/timm/models/layers/linear.py b/timm/models/layers/linear.py
new file mode 100644
index 0000000..38fe338
--- /dev/null
+++ b/timm/models/layers/linear.py
@@ -0,0 +1,19 @@
+""" Linear layer (alternate definition)
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn as nn
+
+
+class Linear(nn.Linear):
+ r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
+
+ Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
+ weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
+ """
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if torch.jit.is_scripting():
+ bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
+ return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
+ else:
+ return F.linear(input, self.weight, self.bias)
diff --git a/timm/models/layers/median_pool.py b/timm/models/layers/median_pool.py
new file mode 100644
index 0000000..40bd71a
--- /dev/null
+++ b/timm/models/layers/median_pool.py
@@ -0,0 +1,49 @@
+""" Median Pool
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch.nn as nn
+import torch.nn.functional as F
+from .helpers import to_2tuple, to_4tuple
+
+
+class MedianPool2d(nn.Module):
+ """ Median pool (usable as median filter when stride=1) module.
+
+ Args:
+ kernel_size: size of pooling kernel, int or 2-tuple
+ stride: pool stride, int or 2-tuple
+ padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
+ same: override padding and enforce same padding, boolean
+ """
+ def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
+ super(MedianPool2d, self).__init__()
+ self.k = to_2tuple(kernel_size)
+ self.stride = to_2tuple(stride)
+ self.padding = to_4tuple(padding) # convert to l, r, t, b
+ self.same = same
+
+ def _padding(self, x):
+ if self.same:
+ ih, iw = x.size()[2:]
+ if ih % self.stride[0] == 0:
+ ph = max(self.k[0] - self.stride[0], 0)
+ else:
+ ph = max(self.k[0] - (ih % self.stride[0]), 0)
+ if iw % self.stride[1] == 0:
+ pw = max(self.k[1] - self.stride[1], 0)
+ else:
+ pw = max(self.k[1] - (iw % self.stride[1]), 0)
+ pl = pw // 2
+ pr = pw - pl
+ pt = ph // 2
+ pb = ph - pt
+ padding = (pl, pr, pt, pb)
+ else:
+ padding = self.padding
+ return padding
+
+ def forward(self, x):
+ x = F.pad(x, self._padding(x), mode='reflect')
+ x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
+ x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
+ return x
diff --git a/timm/models/layers/mixed_conv2d.py b/timm/models/layers/mixed_conv2d.py
new file mode 100644
index 0000000..fa0ce56
--- /dev/null
+++ b/timm/models/layers/mixed_conv2d.py
@@ -0,0 +1,51 @@
+""" PyTorch Mixed Convolution
+
+Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import torch
+from torch import nn as nn
+
+from .conv2d_same import create_conv2d_pad
+
+
+def _split_channels(num_chan, num_groups):
+ split = [num_chan // num_groups for _ in range(num_groups)]
+ split[0] += num_chan - sum(split)
+ return split
+
+
+class MixedConv2d(nn.ModuleDict):
+ """ Mixed Grouped Convolution
+
+ Based on MDConv and GroupedConv in MixNet impl:
+ https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3,
+ stride=1, padding='', dilation=1, depthwise=False, **kwargs):
+ super(MixedConv2d, self).__init__()
+
+ kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
+ num_groups = len(kernel_size)
+ in_splits = _split_channels(in_channels, num_groups)
+ out_splits = _split_channels(out_channels, num_groups)
+ self.in_channels = sum(in_splits)
+ self.out_channels = sum(out_splits)
+ for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
+ conv_groups = in_ch if depthwise else 1
+ # use add_module to keep key space clean
+ self.add_module(
+ str(idx),
+ create_conv2d_pad(
+ in_ch, out_ch, k, stride=stride,
+ padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
+ )
+ self.splits = in_splits
+
+ def forward(self, x):
+ x_split = torch.split(x, self.splits, 1)
+ x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
+ x = torch.cat(x_out, 1)
+ return x
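A short sketch, assuming the import path from this diff, of how `MixedConv2d` splits channels across its per-kernel groups: as evenly as possible, with any remainder assigned to group 0.

```python
import torch
from timm.models.layers.mixed_conv2d import MixedConv2d

m = MixedConv2d(50, 64, kernel_size=[3, 5, 7])
print(m.splits)          # [18, 16, 16] -- input channels split across the three kernel sizes

x = torch.randn(1, 50, 32, 32)
print(m(x).shape)        # torch.Size([1, 64, 32, 32]) with the default symmetric padding
```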
diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py
new file mode 100644
index 0000000..a85e28d
--- /dev/null
+++ b/timm/models/layers/mlp.py
@@ -0,0 +1,119 @@
+""" MLP module w/ dropout and configurable activation layer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .helpers import to_2tuple
+
+
+class Mlp(nn.Module):
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class GluMlp(nn.Module):
+ """ MLP w/ GLU style gating
+ See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ assert hidden_features % 2 == 0
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.fc2 = nn.Linear(hidden_features // 2, out_features)
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def init_weights(self):
+ # override init of fc1 w/ gate portion set to weight near zero, bias=1
+ fc1_mid = self.fc1.bias.shape[0] // 2
+ nn.init.ones_(self.fc1.bias[fc1_mid:])
+ nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x, gates = x.chunk(2, dim=-1)
+ x = x * self.act(gates)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class GatedMlp(nn.Module):
+ """ MLP as used in gMLP
+ """
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
+ gate_layer=None, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ drop_probs = to_2tuple(drop)
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ if gate_layer is not None:
+ assert hidden_features % 2 == 0
+ self.gate = gate_layer(hidden_features)
+ hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
+ else:
+ self.gate = nn.Identity()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.gate(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class ConvMlp(nn.Module):
+ """ MLP using 1x1 convs that keeps spatial dims
+ """
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
+ self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.norm(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ return x
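
A quick shape check (not part of the commit) for the three MLP variants above; it assumes the `timm` package added by this commit is importable and uses arbitrary token/feature sizes.

```python
import torch
from timm.models.layers.mlp import Mlp, GluMlp, GatedMlp

x = torch.randn(2, 196, 64)                        # (batch, tokens, features)
print(Mlp(64, hidden_features=256)(x).shape)       # fc1 64->256, fc2 256->64
print(GluMlp(64, hidden_features=256)(x).shape)    # gate halves the hidden dim: fc2 is 128->64
print(GatedMlp(64, hidden_features=256)(x).shape)  # gate_layer=None behaves like a plain Mlp
```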
diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py
new file mode 100644
index 0000000..881fa36
--- /dev/null
+++ b/timm/models/layers/non_local_attn.py
@@ -0,0 +1,145 @@
+""" Bilinear-Attention-Transform and Non-Local Attention
+
+Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
+ - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
+Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
+"""
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv_bn_act import ConvBnAct
+from .helpers import make_divisible
+from .trace_utils import _assert
+
+
+class NonLocalAttn(nn.Module):
+ """Spatial NL block for image classification.
+
+ This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
+ Their NonLocal impl was inspired by https://github.com/facebookresearch/video-nonlocal-net.
+ """
+
+ def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
+ super(NonLocalAttn, self).__init__()
+ if rd_channels is None:
+ rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
+ self.scale = in_channels ** -0.5 if use_scale else 1.0
+ self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
+ self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
+ self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
+ self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
+ self.norm = nn.BatchNorm2d(in_channels)
+ self.reset_parameters()
+
+ def forward(self, x):
+ shortcut = x
+
+ t = self.t(x)
+ p = self.p(x)
+ g = self.g(x)
+
+ B, C, H, W = t.size()
+ t = t.view(B, C, -1).permute(0, 2, 1)
+ p = p.view(B, C, -1)
+ g = g.view(B, C, -1).permute(0, 2, 1)
+
+ att = torch.bmm(t, p) * self.scale
+ att = F.softmax(att, dim=2)
+ x = torch.bmm(att, g)
+
+ x = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x = self.z(x)
+ x = self.norm(x) + shortcut
+
+ return x
+
+ def reset_parameters(self):
+ for name, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ if len(list(m.parameters())) > 1:
+ nn.init.constant_(m.bias, 0.0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 0)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.GroupNorm):
+ nn.init.constant_(m.weight, 0)
+ nn.init.constant_(m.bias, 0)
+
+
+class BilinearAttnTransform(nn.Module):
+
+ def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+ super(BilinearAttnTransform, self).__init__()
+
+ self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
+ self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
+ self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.block_size = block_size
+ self.groups = groups
+ self.in_channels = in_channels
+
+ def resize_mat(self, x, t: int):
+ B, C, block_size, block_size1 = x.shape
+ _assert(block_size == block_size1, '')
+ if t <= 1:
+ return x
+ x = x.view(B * C, -1, 1, 1)
+ x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
+ x = x.view(B * C, block_size, block_size, t, t)
+ x = torch.cat(torch.split(x, 1, dim=1), dim=3)
+ x = torch.cat(torch.split(x, 1, dim=2), dim=4)
+ x = x.view(B, C, block_size * t, block_size * t)
+ return x
+
+ def forward(self, x):
+ _assert(x.shape[-1] % self.block_size == 0, '')
+ _assert(x.shape[-2] % self.block_size == 0, '')
+ B, C, H, W = x.shape
+ out = self.conv1(x)
+ rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
+ cp = F.adaptive_max_pool2d(out, (1, self.block_size))
+ p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
+ q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size).sigmoid()
+ p = p / p.sum(dim=3, keepdim=True)
+ q = q / q.sum(dim=2, keepdim=True)
+ p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
+ 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
+ p = p.view(B, C, self.block_size, self.block_size)
+ q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
+ 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
+ q = q.view(B, C, self.block_size, self.block_size)
+ p = self.resize_mat(p, H // self.block_size)
+ q = self.resize_mat(q, W // self.block_size)
+ y = p.matmul(x)
+ y = y.matmul(q)
+
+ y = self.conv2(y)
+ return y
+
+
+class BatNonLocalAttn(nn.Module):
+ """ BAT
+ Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
+ """
+
+ def __init__(
+ self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
+ drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
+ super().__init__()
+ if rd_channels is None:
+ rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
+ self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
+ self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
+ self.dropout = nn.Dropout2d(p=drop_rate)
+
+ def forward(self, x):
+ xl = self.conv1(x)
+ y = self.ba(xl)
+ y = self.conv2(y)
+ y = self.dropout(y)
+ return y + x
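
A rough usage sketch (not part of the commit) for the two attention blocks above; it assumes the `timm` package from this commit is importable, and the channel/spatial sizes are arbitrary except that H and W must be divisible by `block_size` for the BAT variant.

```python
import torch
from timm.models.layers.non_local_attn import NonLocalAttn, BatNonLocalAttn

x = torch.randn(2, 64, 14, 14)
print(NonLocalAttn(64)(x).shape)                             # torch.Size([2, 64, 14, 14])
print(BatNonLocalAttn(64, block_size=7, groups=2)(x).shape)  # same shape, residual added inside
```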
diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
new file mode 100644
index 0000000..8529742
--- /dev/null
+++ b/timm/models/layers/norm.py
@@ -0,0 +1,24 @@
+""" Normalization layers and wrappers
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GroupNorm(nn.GroupNorm):
+ def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
+ # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
+ super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+
+ def forward(self, x):
+ return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+
+
+class LayerNorm2d(nn.LayerNorm):
+ """ LayerNorm for channels of '2D' spatial BCHW tensors """
+ def __init__(self, num_channels):
+ super().__init__(num_channels)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return F.layer_norm(
+ x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
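
A small usage note (not part of the commit): `GroupNorm` here takes `num_channels` first so it can be dropped in wherever a norm layer is constructed with a single channel argument, and `LayerNorm2d` normalizes the channel dim of BCHW tensors. Assumes the `timm` package from this commit is importable.

```python
import torch
from timm.models.layers.norm import GroupNorm, LayerNorm2d

x = torch.randn(2, 64, 8, 8)
print(GroupNorm(64, num_groups=32)(x).shape)  # channels-first arg order, unlike nn.GroupNorm
print(LayerNorm2d(64)(x).shape)               # LayerNorm over C at every spatial position
```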
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
new file mode 100644
index 0000000..2e15181
--- /dev/null
+++ b/timm/models/layers/norm_act.py
@@ -0,0 +1,85 @@
+""" Normalization + Activation Layers
+"""
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from .create_act import get_act_layer
+
+
+class BatchNormAct2d(nn.BatchNorm2d):
+ """BatchNorm + Activation
+
+ This module performs BatchNorm + Activation in a manner that will remain backwards
+ compatible with weights trained with separate bn, act. This is why we inherit from BN
+ instead of composing it as a .bn member.
+ """
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
+ apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
+ super(BatchNormAct2d, self).__init__(
+ num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+ if isinstance(act_layer, str):
+ act_layer = get_act_layer(act_layer)
+ if act_layer is not None and apply_act:
+ act_args = dict(inplace=True) if inplace else {}
+ self.act = act_layer(**act_args)
+ else:
+ self.act = nn.Identity()
+
+ def _forward_jit(self, x):
+ """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
+ """
+ # exponential_average_factor is set to self.momentum
+ # (when it is available) only so that it gets updated
+ # in ONNX graph when this node is exported to ONNX.
+ if self.momentum is None:
+ exponential_average_factor = 0.0
+ else:
+ exponential_average_factor = self.momentum
+
+ if self.training and self.track_running_stats:
+ # TODO: if statement only here to tell the jit to skip emitting this when it is None
+ if self.num_batches_tracked is not None:
+ self.num_batches_tracked += 1
+ if self.momentum is None: # use cumulative moving average
+ exponential_average_factor = 1.0 / float(self.num_batches_tracked)
+ else: # use exponential moving average
+ exponential_average_factor = self.momentum
+
+ x = F.batch_norm(
+ x, self.running_mean, self.running_var, self.weight, self.bias,
+ self.training or not self.track_running_stats,
+ exponential_average_factor, self.eps)
+ return x
+
+ @torch.jit.ignore
+ def _forward_python(self, x):
+ return super(BatchNormAct2d, self).forward(x)
+
+ def forward(self, x):
+ # FIXME cannot call parent forward() and maintain jit.script compatibility?
+ if torch.jit.is_scripting():
+ x = self._forward_jit(x)
+ else:
+ x = self._forward_python(x)
+ x = self.act(x)
+ return x
+
+
+class GroupNormAct(nn.GroupNorm):
+ # NOTE num_channels and num_groups arg order flipped for easier layer swaps / binding of fixed args
+ def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True,
+ apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
+ super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
+ if isinstance(act_layer, str):
+ act_layer = get_act_layer(act_layer)
+ if act_layer is not None and apply_act:
+ act_args = dict(inplace=True) if inplace else {}
+ self.act = act_layer(**act_args)
+ else:
+ self.act = nn.Identity()
+
+ def forward(self, x):
+ x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+ x = self.act(x)
+ return x
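
A usage sketch (not part of the commit) showing that `BatchNormAct2d` keeps the plain `nn.BatchNorm2d` state-dict keys while folding the activation in, which is what makes it weight-compatible with separately trained BN + act pairs. Assumes the `timm` package from this commit is importable.

```python
import torch
from torch import nn
from timm.models.layers.norm_act import BatchNormAct2d, GroupNormAct

x = torch.randn(4, 32, 8, 8)
bn_act = BatchNormAct2d(32, act_layer=nn.ReLU)
gn_act = GroupNormAct(32, num_groups=8, act_layer=nn.ReLU)
print(bn_act(x).shape, gn_act(x).shape)
print(list(bn_act.state_dict().keys()))  # same keys as nn.BatchNorm2d; the act adds none
```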
diff --git a/timm/models/layers/padding.py b/timm/models/layers/padding.py
new file mode 100644
index 0000000..34afc37
--- /dev/null
+++ b/timm/models/layers/padding.py
@@ -0,0 +1,56 @@
+""" Padding Helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+from typing import List, Tuple
+
+import torch.nn.functional as F
+
+
+# Calculate symmetric padding for a convolution
+def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding
+
+
+# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
+def get_same_padding(x: int, k: int, s: int, d: int):
+ return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+
+
+# Can SAME padding for given args be done statically?
+def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
+ return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
+
+
+# Dynamically pad input x with 'SAME' padding for conv with specified args
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
+ ih, iw = x.size()[-2:]
+ pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
+ return x
+
+
+def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
+ dynamic = False
+ if isinstance(padding, str):
+ # for any string padding, the padding will be calculated for you, one of three ways
+ padding = padding.lower()
+ if padding == 'same':
+ # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
+ if is_static_pad(kernel_size, **kwargs):
+ # static case, no extra overhead
+ padding = get_padding(kernel_size, **kwargs)
+ else:
+ # dynamic 'SAME' padding, has runtime/GPU memory overhead
+ padding = 0
+ dynamic = True
+ elif padding == 'valid':
+ # 'VALID' padding, same as padding=0
+ padding = 0
+ else:
+ # Default to PyTorch style 'same'-ish symmetric padding
+ padding = get_padding(kernel_size, **kwargs)
+ return padding, dynamic
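
The helpers above are pure functions, so their behavior can be checked directly (not part of the commit): string padding resolves to either a static symmetric pad or a dynamic TF 'SAME' pad, while explicit ints pass through.

```python
from timm.models.layers.padding import get_padding, get_padding_value

print(get_padding(3, stride=1, dilation=1))    # 1  -> symmetric 'same'-ish padding
print(get_padding_value('same', 3, stride=1))  # (1, False): SAME can be done statically
print(get_padding_value('same', 3, stride=2))  # (0, True):  needs dynamic pad_same() at runtime
print(get_padding_value('valid', 7))           # (0, False)
print(get_padding_value(2, 5))                 # (2, False): explicit ints pass through untouched
```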
diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py
new file mode 100644
index 0000000..6a7face
--- /dev/null
+++ b/timm/models/layers/patch_embed.py
@@ -0,0 +1,39 @@
+""" Image to Patch Embedding using Conv2d
+
+A convolution based approach to patchifying a 2D image w/ embedding projection.
+
+Based on the impl in https://github.com/google-research/vision_transformer
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .helpers import to_2tuple
+from .trace_utils import _assert
+
+
+class PatchEmbed(nn.Module):
+ """ 2D Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.flatten = flatten
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
+ _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
+ x = self.proj(x)
+ if self.flatten:
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
+ x = self.norm(x)
+ return x
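
A shape walk-through (not part of the commit): with a 224x224 input and 16x16 patches the projection yields a 14x14 grid, i.e. 196 tokens of dimension `embed_dim`. Assumes the `timm` package from this commit is importable.

```python
import torch
from timm.models.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
tokens = embed(torch.randn(1, 3, 224, 224))
print(embed.num_patches, tokens.shape)  # 196 torch.Size([1, 196, 768])
```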
diff --git a/timm/models/layers/pool2d_same.py b/timm/models/layers/pool2d_same.py
new file mode 100644
index 0000000..4c2a1c4
--- /dev/null
+++ b/timm/models/layers/pool2d_same.py
@@ -0,0 +1,73 @@
+""" AvgPool2d w/ Same Padding
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Tuple, Optional
+
+from .helpers import to_2tuple
+from .padding import pad_same, get_padding_value
+
+
+def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+ ceil_mode: bool = False, count_include_pad: bool = True):
+ # FIXME how to deal with count_include_pad vs not for external padding?
+ x = pad_same(x, kernel_size, stride)
+ return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+ """ Tensorflow like 'SAME' wrapper for 2D average pooling
+ """
+ def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+ def forward(self, x):
+ x = pad_same(x, self.kernel_size, self.stride)
+ return F.avg_pool2d(
+ x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+
+def max_pool2d_same(
+ x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
+ dilation: List[int] = (1, 1), ceil_mode: bool = False):
+ x = pad_same(x, kernel_size, stride, value=-float('inf'))
+ return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
+
+
+class MaxPool2dSame(nn.MaxPool2d):
+ """ Tensorflow like 'SAME' wrapper for 2D max pooling
+ """
+ def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ dilation = to_2tuple(dilation)
+ super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
+
+ def forward(self, x):
+ x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
+ return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
+
+
+def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
+ stride = stride or kernel_size
+ padding = kwargs.pop('padding', '')
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
+ if is_dynamic:
+ if pool_type == 'avg':
+ return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
+ elif pool_type == 'max':
+ return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
+ else:
+ assert False, f'Unsupported pool type {pool_type}'
+ else:
+ if pool_type == 'avg':
+ return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+ elif pool_type == 'max':
+ return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
+ else:
+ assert False, f'Unsupported pool type {pool_type}'
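
A quick check (not part of the commit) that `create_pool2d` hands back a plain `nn.*Pool2d` when SAME padding can be computed statically, and the dynamic `*Pool2dSame` wrapper otherwise. Assumes the `timm` package from this commit is importable.

```python
import torch
from timm.models.layers.pool2d_same import create_pool2d

static_pool = create_pool2d('max', kernel_size=3, stride=1, padding='same')
dynamic_pool = create_pool2d('max', kernel_size=3, stride=2, padding='same')
print(type(static_pool).__name__, type(dynamic_pool).__name__)  # MaxPool2d MaxPool2dSame

x = torch.randn(1, 8, 7, 7)
print(dynamic_pool(x).shape)  # torch.Size([1, 8, 4, 4]): output is ceil(7 / 2), TF-style
```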
diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py
new file mode 100644
index 0000000..1aeb929
--- /dev/null
+++ b/timm/models/layers/selective_kernel.py
@@ -0,0 +1,120 @@
+""" Selective Kernel Convolution/Attention
+
+Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch import nn as nn
+
+from .conv_bn_act import ConvBnAct
+from .helpers import make_divisible
+from .trace_utils import _assert
+
+
+def _kernel_valid(k):
+    if isinstance(k, (list, tuple)):
+        for ki in k:
+            _kernel_valid(ki)  # validate every entry of the kernel size list, not just the first
+        return
+    assert k >= 3 and k % 2
+
+
+class SelectiveKernelAttn(nn.Module):
+ def __init__(self, channels, num_paths=2, attn_channels=32,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+ """ Selective Kernel Attention Module
+
+ Selective Kernel attention mechanism factored out into its own module.
+
+ """
+ super(SelectiveKernelAttn, self).__init__()
+ self.num_paths = num_paths
+ self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
+ self.bn = norm_layer(attn_channels)
+ self.act = act_layer(inplace=True)
+ self.fc_select = nn.Conv2d(attn_channels, channels * num_paths, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ _assert(x.shape[1] == self.num_paths, '')
+ x = x.sum(1).mean((2, 3), keepdim=True)
+ x = self.fc_reduce(x)
+ x = self.bn(x)
+ x = self.act(x)
+ x = self.fc_select(x)
+ B, C, H, W = x.shape
+ x = x.view(B, self.num_paths, C // self.num_paths, H, W)
+ x = torch.softmax(x, dim=1)
+ return x
+
+
+class SelectiveKernel(nn.Module):
+
+ def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
+ rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
+ drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
+ """ Selective Kernel Convolution Module
+
+ As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
+
+ The largest change is the input split, which divides the input channels across each convolution path; this can
+ be viewed as a form of grouping, but the output channel count of each path expands to the module-level value.
+ This keeps the parameter count from ballooning when the convolutions themselves don't use groups, yet still
+ provides a noteworthy performance increase over similar-param-count models without this attention layer. -Ross W
+
+ Args:
+ in_channels (int): module input (feature) channel count
+ out_channels (int): module output (feature) channel count
+ kernel_size (int, list): kernel size for each convolution branch
+ stride (int): stride for convolutions
+ dilation (int): dilation for module as a whole, impacts dilation of each branch
+ groups (int): number of groups for each branch
+ rd_ratio (int, float): reduction factor for attention features
+ keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
+ split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
+ can be viewed as grouping by path, output expands to module out_channels count
+ drop_block (nn.Module): drop block module
+ act_layer (nn.Module): activation layer to use
+ norm_layer (nn.Module): batchnorm/norm layer to use
+ """
+ super(SelectiveKernel, self).__init__()
+ out_channels = out_channels or in_channels
+ kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
+ _kernel_valid(kernel_size)
+ if not isinstance(kernel_size, list):
+ kernel_size = [kernel_size] * 2
+ if keep_3x3:
+ dilation = [dilation * (k - 1) // 2 for k in kernel_size]
+ kernel_size = [3] * len(kernel_size)
+ else:
+ dilation = [dilation] * len(kernel_size)
+ self.num_paths = len(kernel_size)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.split_input = split_input
+ if self.split_input:
+ assert in_channels % self.num_paths == 0
+ in_channels = in_channels // self.num_paths
+ groups = min(out_channels, groups)
+
+ conv_kwargs = dict(
+ stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
+ aa_layer=aa_layer)
+ self.paths = nn.ModuleList([
+ ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
+ for k, d in zip(kernel_size, dilation)])
+
+ attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
+ self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
+ self.drop_block = drop_block
+
+ def forward(self, x):
+ if self.split_input:
+ x_split = torch.split(x, self.in_channels // self.num_paths, 1)
+ x_paths = [op(x_split[i]) for i, op in enumerate(self.paths)]
+ else:
+ x_paths = [op(x) for op in self.paths]
+ x = torch.stack(x_paths, dim=1)
+ x_attn = self.attn(x)
+ x = x * x_attn
+ x = torch.sum(x, dim=1)
+ return x
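
A usage sketch (not part of the commit): with the defaults, two branches are built (a 3x3 and a dilated 3x3 standing in for 5x5), the attention module softmaxes over the branches, and the weighted branch outputs are summed. Assumes the `timm` package from this commit is importable; the sizes are arbitrary but `in_channels` must divide evenly across the paths when `split_input=True`.

```python
import torch
from timm.models.layers.selective_kernel import SelectiveKernel

sk = SelectiveKernel(64, 64, split_input=True)
x = torch.randn(2, 64, 32, 32)
print(sk.num_paths, sk(x).shape)  # 2 torch.Size([2, 64, 32, 32])
```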
diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py
new file mode 100644
index 0000000..1ddcb4e
--- /dev/null
+++ b/timm/models/layers/separable_conv.py
@@ -0,0 +1,73 @@
+""" Depthwise Separable Conv Modules
+
+Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
+DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from torch import nn as nn
+
+from .create_conv2d import create_conv2d
+from .create_norm_act import convert_norm_act
+
+
+class SeparableConvBnAct(nn.Module):
+ """ Separable Conv w/ trailing Norm and Activation
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
+ channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
+ apply_act=True, drop_block=None):
+ super(SeparableConvBnAct, self).__init__()
+
+ self.conv_dw = create_conv2d(
+ in_channels, int(in_channels * channel_multiplier), kernel_size,
+ stride=stride, dilation=dilation, padding=padding, depthwise=True)
+
+ self.conv_pw = create_conv2d(
+ int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
+
+ norm_act_layer = convert_norm_act(norm_layer, act_layer)
+ self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
+
+ @property
+ def in_channels(self):
+ return self.conv_dw.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv_pw.out_channels
+
+ def forward(self, x):
+ x = self.conv_dw(x)
+ x = self.conv_pw(x)
+ if self.bn is not None:
+ x = self.bn(x)
+ return x
+
+
+class SeparableConv2d(nn.Module):
+ """ Separable Conv
+ """
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
+ channel_multiplier=1.0, pw_kernel_size=1):
+ super(SeparableConv2d, self).__init__()
+
+ self.conv_dw = create_conv2d(
+ in_channels, int(in_channels * channel_multiplier), kernel_size,
+ stride=stride, dilation=dilation, padding=padding, depthwise=True)
+
+ self.conv_pw = create_conv2d(
+ int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
+
+ @property
+ def in_channels(self):
+ return self.conv_dw.in_channels
+
+ @property
+ def out_channels(self):
+ return self.conv_pw.out_channels
+
+ def forward(self, x):
+ x = self.conv_dw(x)
+ x = self.conv_pw(x)
+ return x
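
A parameter-count comparison (not part of the commit) between `SeparableConv2d` and a dense conv with the same kernel size: roughly 64*3*3 + 64*128 weights versus 64*128*3*3. Assumes the `timm` package from this commit is importable.

```python
import torch.nn as nn
from timm.models.layers.separable_conv import SeparableConv2d

sep = SeparableConv2d(64, 128, kernel_size=3)
dense = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
n_params = lambda m: sum(p.numel() for p in m.parameters())
print(n_params(sep), n_params(dense))  # the separable version is roughly 8x smaller
```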
diff --git a/timm/models/layers/space_to_depth.py b/timm/models/layers/space_to_depth.py
new file mode 100644
index 0000000..a7e8e0b
--- /dev/null
+++ b/timm/models/layers/space_to_depth.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+
+
+class SpaceToDepth(nn.Module):
+ def __init__(self, block_size=4):
+ super().__init__()
+ assert block_size == 4
+ self.bs = block_size
+
+ def forward(self, x):
+ N, C, H, W = x.size()
+ x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
+ x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
+ return x
+
+
+@torch.jit.script
+class SpaceToDepthJit(object):
+ def __call__(self, x: torch.Tensor):
+ # block_size == 4 is hard-coded here for acceleration
+ N, C, H, W = x.size()
+ x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
+ x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
+ x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
+ return x
+
+
+class SpaceToDepthModule(nn.Module):
+ def __init__(self, no_jit=False):
+ super().__init__()
+ if not no_jit:
+ self.op = SpaceToDepthJit()
+ else:
+ self.op = SpaceToDepth()
+
+ def forward(self, x):
+ return self.op(x)
+
+
+class DepthToSpace(nn.Module):
+
+ def __init__(self, block_size):
+ super().__init__()
+ self.bs = block_size
+
+ def forward(self, x):
+ N, C, H, W = x.size()
+ x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
+ x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
+ x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
+ return x
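
These modules only reshape and permute, so a round trip can be verified directly (not part of the commit): `SpaceToDepth` packs each 4x4 spatial block into channels and `DepthToSpace` undoes it exactly.

```python
import torch
from timm.models.layers.space_to_depth import SpaceToDepth, DepthToSpace

x = torch.randn(1, 3, 32, 32)
y = SpaceToDepth(block_size=4)(x)
print(y.shape)                                        # torch.Size([1, 48, 8, 8])
print(torch.equal(DepthToSpace(block_size=4)(y), x))  # True: exact inverse permutation
```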
diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py
new file mode 100644
index 0000000..dde601b
--- /dev/null
+++ b/timm/models/layers/split_attn.py
@@ -0,0 +1,85 @@
+""" Split Attention Conv2d (for ResNeSt Models)
+
+Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
+
+Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
+
+Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .helpers import make_divisible
+
+
+class RadixSoftmax(nn.Module):
+ def __init__(self, radix, cardinality):
+ super(RadixSoftmax, self).__init__()
+ self.radix = radix
+ self.cardinality = cardinality
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
+ x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttn(nn.Module):
+ """Split-Attention (aka Splat)
+ """
+ def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
+ dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
+ act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
+ super(SplitAttn, self).__init__()
+ out_channels = out_channels or in_channels
+ self.radix = radix
+ self.drop_block = drop_block
+ mid_chs = out_channels * radix
+ if rd_channels is None:
+ attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
+ else:
+ attn_chs = rd_channels * radix
+
+ padding = kernel_size // 2 if padding is None else padding
+ self.conv = nn.Conv2d(
+ in_channels, mid_chs, kernel_size, stride, padding, dilation,
+ groups=groups * radix, bias=bias, **kwargs)
+ self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
+ self.act0 = act_layer(inplace=True)
+ self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
+ self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
+ self.act1 = act_layer(inplace=True)
+ self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
+ self.rsoftmax = RadixSoftmax(radix, groups)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn0(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act0(x)
+
+ B, RC, H, W = x.shape
+ if self.radix > 1:
+ x = x.reshape((B, self.radix, RC // self.radix, H, W))
+ x_gap = x.sum(dim=1)
+ else:
+ x_gap = x
+ x_gap = x_gap.mean((2, 3), keepdim=True)
+ x_gap = self.fc1(x_gap)
+ x_gap = self.bn1(x_gap)
+ x_gap = self.act1(x_gap)
+ x_attn = self.fc2(x_gap)
+
+ x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
+ if self.radix > 1:
+ out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
+ else:
+ out = x * x_attn
+ return out.contiguous()
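
A usage sketch (not part of the commit): with `radix=2` the expanded conv output is split into two groups, the radix softmax produces per-group channel weights, and the weighted groups are summed back to `out_channels`. Assumes the `timm` package from this commit is importable.

```python
import torch
from torch import nn
from timm.models.layers.split_attn import SplitAttn

splat = SplitAttn(64, 64, kernel_size=3, radix=2, norm_layer=nn.BatchNorm2d)
x = torch.randn(2, 64, 16, 16)
print(splat(x).shape)  # torch.Size([2, 64, 16, 16])
```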
diff --git a/timm/models/layers/split_batchnorm.py b/timm/models/layers/split_batchnorm.py
new file mode 100644
index 0000000..830781b
--- /dev/null
+++ b/timm/models/layers/split_batchnorm.py
@@ -0,0 +1,75 @@
+""" Split BatchNorm
+
+A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
+a separate BN layer. The first split is passed through the parent BN layers with weight/bias
+keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
+namespace.
+
+This allows easily removing the auxiliary BN layers after training to efficiently
+achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
+'Disentangled Learning via An Auxiliary BN'
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+
+
+class SplitBatchNorm2d(torch.nn.BatchNorm2d):
+
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
+ track_running_stats=True, num_splits=2):
+ super().__init__(num_features, eps, momentum, affine, track_running_stats)
+ assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
+ self.num_splits = num_splits
+ self.aux_bn = nn.ModuleList([
+ nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
+
+ def forward(self, input: torch.Tensor):
+ if self.training: # aux BN only relevant while training
+ split_size = input.shape[0] // self.num_splits
+ assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
+ split_input = input.split(split_size)
+ x = [super().forward(split_input[0])]
+ for i, a in enumerate(self.aux_bn):
+ x.append(a(split_input[i + 1]))
+ return torch.cat(x, dim=0)
+ else:
+ return super().forward(input)
+
+
+def convert_splitbn_model(module, num_splits=2):
+ """
+ Recursively traverse module and its children to replace all instances of
+ ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`.
+ Args:
+ module (torch.nn.Module): input module
+ num_splits: number of separate batchnorm layers to split input across
+ Example::
+ >>> # model is an instance of torch.nn.Module
+ >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
+ """
+ mod = module
+ if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
+ return module
+ if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+ mod = SplitBatchNorm2d(
+ module.num_features, module.eps, module.momentum, module.affine,
+ module.track_running_stats, num_splits=num_splits)
+ mod.running_mean = module.running_mean
+ mod.running_var = module.running_var
+ mod.num_batches_tracked = module.num_batches_tracked
+ if module.affine:
+ mod.weight.data = module.weight.data.clone().detach()
+ mod.bias.data = module.bias.data.clone().detach()
+ for aux in mod.aux_bn:
+ aux.running_mean = module.running_mean.clone()
+ aux.running_var = module.running_var.clone()
+ aux.num_batches_tracked = module.num_batches_tracked.clone()
+ if module.affine:
+ aux.weight.data = module.weight.data.clone().detach()
+ aux.bias.data = module.bias.data.clone().detach()
+ for name, child in module.named_children():
+ mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
+ del module
+ return mod
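
A usage sketch (not part of the commit) of `convert_splitbn_model`: every BatchNorm in the wrapped model becomes a `SplitBatchNorm2d`, and in training mode the batch must be divisible by `num_splits` so each split gets its own statistics.

```python
import torch
from torch import nn
from timm.models.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
model = convert_splitbn_model(model, num_splits=2)
print(isinstance(model[1], SplitBatchNorm2d))  # True

model.train()
x = torch.randn(4, 3, 8, 8)   # batch of 4 -> two splits of 2
print(model(x).shape)         # torch.Size([4, 8, 8, 8])
```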
diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py
new file mode 100644
index 0000000..e5da29e
--- /dev/null
+++ b/timm/models/layers/squeeze_excite.py
@@ -0,0 +1,74 @@
+""" Squeeze-and-Excitation Channel Attention
+
+An SE implementation originally based on PyTorch SE-Net impl.
+Has since evolved with additional functionality / configuration.
+
+Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
+
+Also included is Effective Squeeze-Excitation (ESE).
+Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from torch import nn as nn
+
+from .create_act import create_act_layer
+from .helpers import make_divisible
+
+
+class SEModule(nn.Module):
+ """ SE Module as defined in original SE-Nets with a few additions
+ Additions include:
+ * divisor can be specified to keep channels % div == 0 (default: 8)
+ * reduction channels can be specified directly by arg (if rd_channels is set)
+ * reduction channels can be specified by float rd_ratio (default: 1/16)
+ * global max pooling can be added to the squeeze aggregation
+ * customizable activation, normalization, and gate layer
+ """
+ def __init__(
+ self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
+ act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
+ super(SEModule, self).__init__()
+ self.add_maxpool = add_maxpool
+ if not rd_channels:
+ rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
+ self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
+ self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
+ self.act = create_act_layer(act_layer, inplace=True)
+ self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
+ x_se = self.fc1(x_se)
+ x_se = self.act(self.bn(x_se))
+ x_se = self.fc2(x_se)
+ return x * self.gate(x_se)
+
+
+SqueezeExcite = SEModule # alias
+
+
+class EffectiveSEModule(nn.Module):
+ """ 'Effective Squeeze-Excitation
+ From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+ """
+ def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
+ super(EffectiveSEModule, self).__init__()
+ self.add_maxpool = add_maxpool
+ self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
+ self.gate = create_act_layer(gate_layer)
+
+ def forward(self, x):
+ x_se = x.mean((2, 3), keepdim=True)
+ if self.add_maxpool:
+ # experimental codepath, may remove or change
+ x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
+ x_se = self.fc(x_se)
+ return x * self.gate(x_se)
+
+
+EffectiveSqueezeExcite = EffectiveSEModule # alias
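
A usage sketch (not part of the commit); both modules return the input re-weighted per channel, so shapes are unchanged. Note the SE bottleneck width goes through `make_divisible`, so `rd_ratio=1/16` on 64 channels gives 8 (not 4) reduction channels. Assumes the `timm` package from this commit is importable.

```python
import torch
from timm.models.layers.squeeze_excite import SEModule, EffectiveSEModule

x = torch.randn(2, 64, 16, 16)
print(SEModule(64)(x).shape)           # bottleneck = make_divisible(64 / 16, 8) = 8 channels
print(EffectiveSEModule(64)(x).shape)  # single full-width 1x1 conv, no reduction
```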
diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py
new file mode 100644
index 0000000..d896ba5
--- /dev/null
+++ b/timm/models/layers/std_conv.py
@@ -0,0 +1,133 @@
+""" Convolution with Weight Standardization (StdConv and ScaledStdConv)
+
+StdConv:
+@article{weightstandardization,
+ author = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
+ title = {Weight Standardization},
+ journal = {arXiv preprint arXiv:1903.10520},
+ year = {2019},
+}
+Code: https://github.com/joe-siyuan-qiao/WeightStandardization
+
+ScaledStdConv:
+Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .padding import get_padding, get_padding_value, pad_same
+
+
+class StdConv2d(nn.Conv2d):
+ """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
+
+ Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+ https://arxiv.org/abs/1903.10520v2
+ """
+ def __init__(
+ self, in_channel, out_channels, kernel_size, stride=1, padding=None,
+ dilation=1, groups=1, bias=False, eps=1e-6):
+ if padding is None:
+ padding = get_padding(kernel_size, stride, dilation)
+ super().__init__(
+ in_channel, out_channels, kernel_size, stride=stride,
+ padding=padding, dilation=dilation, groups=groups, bias=bias)
+ self.eps = eps
+
+ def forward(self, x):
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return x
+
+
+class StdConv2dSame(nn.Conv2d):
+ """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
+
+ Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+ https://arxiv.org/abs/1903.10520v2
+ """
+ def __init__(
+ self, in_channel, out_channels, kernel_size, stride=1, padding='SAME',
+ dilation=1, groups=1, bias=False, eps=1e-6):
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
+ super().__init__(
+ in_channel, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ self.same_pad = is_dynamic
+ self.eps = eps
+
+ def forward(self, x):
+ if self.same_pad:
+ x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ x = F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+ return x
+
+
+class ScaledStdConv2d(nn.Conv2d):
+ """Conv2d layer with Scaled Weight Standardization.
+
+ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+ https://arxiv.org/abs/2101.08692
+
+ NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
+ """
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride=1, padding=None,
+ dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
+ if padding is None:
+ padding = get_padding(kernel_size, stride, dilation)
+ super().__init__(
+ in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
+ self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
+ self.eps = eps
+
+ def forward(self, x):
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ weight=(self.gain * self.scale).view(-1),
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+
+class ScaledStdConv2dSame(nn.Conv2d):
+ """Conv2d layer with Scaled Weight Standardization and Tensorflow-like SAME padding support
+
+ Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+ https://arxiv.org/abs/2101.08692
+
+ NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
+ """
+
+ def __init__(
+ self, in_channels, out_channels, kernel_size, stride=1, padding='SAME',
+ dilation=1, groups=1, bias=True, gamma=1.0, eps=1e-6, gain_init=1.0):
+ padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, dilation=dilation)
+ super().__init__(
+ in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
+ groups=groups, bias=bias)
+ self.gain = nn.Parameter(torch.full((self.out_channels, 1, 1, 1), gain_init))
+ self.scale = gamma * self.weight[0].numel() ** -0.5
+ self.same_pad = is_dynamic
+ self.eps = eps
+
+ def forward(self, x):
+ if self.same_pad:
+ x = pad_same(x, self.kernel_size, self.stride, self.dilation)
+ weight = F.batch_norm(
+ self.weight.reshape(1, self.out_channels, -1), None, None,
+ weight=(self.gain * self.scale).view(-1),
+ training=True, momentum=0., eps=self.eps).reshape_as(self.weight)
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
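
A shape sketch (not part of the commit) for the weight-standardized convs above: `StdConv2d` uses symmetric padding computed from the kernel size, while the `*Same` variants fall back to dynamic TF 'SAME' padding when the stride requires it. Assumes the `timm` package from this commit is importable.

```python
import torch
from timm.models.layers.std_conv import StdConv2d, ScaledStdConv2dSame

x = torch.randn(2, 16, 9, 9)
print(StdConv2d(16, 32, kernel_size=3)(x).shape)                      # torch.Size([2, 32, 9, 9])
print(ScaledStdConv2dSame(16, 32, kernel_size=3, stride=2)(x).shape)  # torch.Size([2, 32, 5, 5])
```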
diff --git a/timm/models/layers/test_time_pool.py b/timm/models/layers/test_time_pool.py
new file mode 100644
index 0000000..98c0bf5
--- /dev/null
+++ b/timm/models/layers/test_time_pool.py
@@ -0,0 +1,52 @@
+""" Test Time Pooling (Average-Max Pool)
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import logging
+from torch import nn
+import torch.nn.functional as F
+
+from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TestTimePoolHead(nn.Module):
+ def __init__(self, base, original_pool=7):
+ super(TestTimePoolHead, self).__init__()
+ self.base = base
+ self.original_pool = original_pool
+ base_fc = self.base.get_classifier()
+ if isinstance(base_fc, nn.Conv2d):
+ self.fc = base_fc
+ else:
+ self.fc = nn.Conv2d(
+ self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
+ self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
+ self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
+ self.base.reset_classifier(0) # delete original fc layer
+
+ def forward(self, x):
+ x = self.base.forward_features(x)
+ x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
+ x = self.fc(x)
+ x = adaptive_avgmax_pool2d(x, 1)
+ return x.view(x.size(0), -1)
+
+
+def apply_test_time_pool(model, config, use_test_size=True):
+ test_time_pool = False
+ if not hasattr(model, 'default_cfg') or not model.default_cfg:
+ return model, False
+ if use_test_size and 'test_input_size' in model.default_cfg:
+ df_input_size = model.default_cfg['test_input_size']
+ else:
+ df_input_size = model.default_cfg['input_size']
+ if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
+ _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
+ (str(config['input_size'][-2:]), str(df_input_size[-2:])))
+ model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
+ test_time_pool = True
+ return model, test_time_pool
diff --git a/timm/models/layers/trace_utils.py b/timm/models/layers/trace_utils.py
new file mode 100644
index 0000000..8397072
--- /dev/null
+++ b/timm/models/layers/trace_utils.py
@@ -0,0 +1,13 @@
+try:
+ from torch import _assert
+except ImportError:
+ def _assert(condition: bool, message: str):
+ assert condition, message
+
+
+def _float_to_int(x: float) -> int:
+ """
+ Symbolic tracing helper to substitute for inbuilt `int`.
+ Hint: Inbuilt `int` can't accept an argument of type `Proxy`
+ """
+ return int(x)
diff --git a/timm/models/layers/weight_init.py b/timm/models/layers/weight_init.py
new file mode 100644
index 0000000..305a2fd
--- /dev/null
+++ b/timm/models/layers/weight_init.py
@@ -0,0 +1,89 @@
+import torch
+import math
+import warnings
+
+from torch.nn.init import _calculate_fan_in_and_fan_out
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ # type: (Tensor, float, float, float, float) -> Tensor
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution. The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+ if mode == 'fan_in':
+ denom = fan_in
+ elif mode == 'fan_out':
+ denom = fan_out
+ elif mode == 'fan_avg':
+ denom = (fan_in + fan_out) / 2
+
+ variance = scale / denom
+
+ if distribution == "truncated_normal":
+ # constant is stddev of standard normal truncated to (-2, 2)
+ trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+ elif distribution == "normal":
+ tensor.normal_(std=math.sqrt(variance))
+ elif distribution == "uniform":
+ bound = math.sqrt(3 * variance)
+ tensor.uniform_(-bound, bound)
+ else:
+ raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+ variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
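
A quick numerical check (not part of the commit) of the initializers above, runnable against this module alone: `trunc_normal_` clips to the [a, b] bounds after scaling, and `variance_scaling_` with a uniform distribution yields a standard deviation of sqrt(scale / fan_in).

```python
import math
import torch
from timm.models.layers.weight_init import trunc_normal_, variance_scaling_

w = torch.empty(768, 768)
trunc_normal_(w, std=.02)
print(round(float(w.std()), 3), float(w.abs().max()) <= 2.0)  # ~0.02 True

conv_w = torch.empty(64, 3, 7, 7)
variance_scaling_(conv_w, mode='fan_in', distribution='uniform')
print(round(float(conv_w.std()), 3), round(1 / math.sqrt(3 * 7 * 7), 3))  # both ~0.082
```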
diff --git a/timm/models/levit.py b/timm/models/levit.py
new file mode 100644
index 0000000..9987e4b
--- /dev/null
+++ b/timm/models/levit.py
@@ -0,0 +1,563 @@
+""" LeViT
+
+Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
+ - https://arxiv.org/abs/2104.01136
+
+@article{graham2021levit,
+ title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
+ author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
+ journal={arXiv preprint arXiv:2104.01136},
+ year={2021}
+}
+
+Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
+
+This version combines both conv/linear models and fixes torchscript compatibility.
+
+Modifications by / Copyright 2021 Ross Wightman
+"""
+
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+
+# Modified from
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# Copyright 2020 Ross Wightman, Apache-2.0 License
+import itertools
+from copy import deepcopy
+from functools import partial
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import to_ntuple, get_act_layer
+from .vision_transformer import trunc_normal_
+from .registry import register_model
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ levit_128s=_cfg(
+ url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
+ ),
+ levit_128=_cfg(
+ url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
+ ),
+ levit_192=_cfg(
+ url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
+ ),
+ levit_256=_cfg(
+ url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
+ ),
+ levit_384=_cfg(
+ url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
+ ),
+)
+
+model_cfgs = dict(
+ levit_128s=dict(
+ embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
+ levit_128=dict(
+ embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
+ levit_192=dict(
+ embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
+ levit_256=dict(
+ embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
+ levit_384=dict(
+ embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
+)
+
+__all__ = ['Levit']
+
+
+@register_model
+def levit_128s(pretrained=False, use_conv=False, **kwargs):
+ return create_levit(
+ 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
+
+
+@register_model
+def levit_128(pretrained=False, use_conv=False, **kwargs):
+ return create_levit(
+ 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
+
+
+@register_model
+def levit_192(pretrained=False, use_conv=False, **kwargs):
+ return create_levit(
+ 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
+
+
+@register_model
+def levit_256(pretrained=False, use_conv=False, **kwargs):
+ return create_levit(
+ 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs)
+
+
+@register_model
+def levit_384(pretrained=False, use_conv=False, **kwargs):
+ return create_levit(
+ 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
+
+
+class ConvNorm(nn.Sequential):
+ def __init__(
+ self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
+ super().__init__()
+ self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
+ bn = nn.BatchNorm2d(b)
+ nn.init.constant_(bn.weight, bn_weight_init)
+ nn.init.constant_(bn.bias, 0)
+ self.add_module('bn', bn)
+
+ @torch.no_grad()
+ def fuse(self):
+ c, bn = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+ w = c.weight * w[:, None, None, None]
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
+ m = nn.Conv2d(
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
+ padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+
+class LinearNorm(nn.Sequential):
+ def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
+ super().__init__()
+ self.add_module('c', nn.Linear(a, b, bias=False))
+ bn = nn.BatchNorm1d(b)
+ nn.init.constant_(bn.weight, bn_weight_init)
+ nn.init.constant_(bn.bias, 0)
+ self.add_module('bn', bn)
+
+ @torch.no_grad()
+ def fuse(self):
+ l, bn = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+ w = l.weight * w[:, None]
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
+ m = nn.Linear(w.size(1), w.size(0))
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+ def forward(self, x):
+ x = self.c(x)
+ return self.bn(x.flatten(0, 1)).reshape_as(x)
+
+
+class NormLinear(nn.Sequential):
+ def __init__(self, a, b, bias=True, std=0.02):
+ super().__init__()
+ self.add_module('bn', nn.BatchNorm1d(a))
+ l = nn.Linear(a, b, bias=bias)
+ trunc_normal_(l.weight, std=std)
+ if bias:
+ nn.init.constant_(l.bias, 0)
+ self.add_module('l', l)
+
+ @torch.no_grad()
+ def fuse(self):
+ bn, l = self._modules.values()
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+ b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
+ w = l.weight * w[None, :]
+ if l.bias is None:
+ b = b @ self.l.weight.T
+ else:
+ b = (l.weight @ b[:, None]).view(-1) + self.l.bias
+ m = nn.Linear(w.size(1), w.size(0))
+ m.weight.data.copy_(w)
+ m.bias.data.copy_(b)
+ return m
+
+
+def stem_b16(in_chs, out_chs, activation, resolution=224):
+ return nn.Sequential(
+ ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
+ activation(),
+ ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
+ activation(),
+ ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
+ activation(),
+ ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
+
+
+class Residual(nn.Module):
+ def __init__(self, m, drop):
+ super().__init__()
+ self.m = m
+ self.drop = drop
+
+ def forward(self, x):
+ if self.training and self.drop > 0:
+ return x + self.m(x) * torch.rand(
+ x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
+ else:
+ return x + self.m(x)
+
+
+class Subsample(nn.Module):
+ def __init__(self, stride, resolution):
+ super().__init__()
+ self.stride = stride
+ self.resolution = resolution
+
+ def forward(self, x):
+ B, N, C = x.shape
+ x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
+ return x.reshape(B, -1, C)
+
+
+class Attention(nn.Module):
+ ab: Dict[str, torch.Tensor]
+
+ def __init__(
+ self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.scale = key_dim ** -0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ self.attn_ratio = attn_ratio
+ self.use_conv = use_conv
+ ln_layer = ConvNorm if self.use_conv else LinearNorm
+ h = self.dh + nh_kd * 2
+ self.qkv = ln_layer(dim, h, resolution=resolution)
+ self.proj = nn.Sequential(
+ act_layer(),
+ ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
+
+ points = list(itertools.product(range(resolution), range(resolution)))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
+ self.ab = {}
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and self.ab:
+ self.ab = {} # clear ab cache
+
+ def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+ if self.training:
+ return self.attention_biases[:, self.attention_bias_idxs]
+ else:
+ device_key = str(device)
+ if device_key not in self.ab:
+ self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+ return self.ab[device_key]
+
+ def forward(self, x): # x (B,C,H,W)
+ if self.use_conv:
+ B, C, H, W = x.shape
+ q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
+
+ attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
+ attn = attn.softmax(dim=-1)
+
+ x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+ else:
+ B, N, C = x.shape
+ qkv = self.qkv(x)
+ q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
+ q = q.permute(0, 2, 1, 3)
+ k = k.permute(0, 2, 1, 3)
+ v = v.permute(0, 2, 1, 3)
+
+ attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+ x = self.proj(x)
+ return x
+
+
+class AttentionSubsample(nn.Module):
+ ab: Dict[str, torch.Tensor]
+
+ def __init__(
+ self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
+ act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
+ super().__init__()
+ self.num_heads = num_heads
+ self.scale = key_dim ** -0.5
+ self.key_dim = key_dim
+ self.nh_kd = nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = self.d * self.num_heads
+ self.attn_ratio = attn_ratio
+ self.resolution_ = resolution_
+ self.resolution_2 = resolution_ ** 2
+ self.use_conv = use_conv
+ if self.use_conv:
+ ln_layer = ConvNorm
+ sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
+ else:
+ ln_layer = LinearNorm
+ sub_layer = partial(Subsample, resolution=resolution)
+
+ h = self.dh + nh_kd
+ self.kv = ln_layer(in_dim, h, resolution=resolution)
+ self.q = nn.Sequential(
+ sub_layer(stride=stride),
+ ln_layer(in_dim, nh_kd, resolution=resolution_))
+ self.proj = nn.Sequential(
+ act_layer(),
+ ln_layer(self.dh, out_dim, resolution=resolution_))
+
+ self.stride = stride
+ self.resolution = resolution
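+ # Same offset-bucketed attention biases as Attention, but between the strided
+ # (downsampled) query grid and the full-resolution key grid.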
+ points = list(itertools.product(range(resolution), range(resolution)))
+ points_ = list(itertools.product(range(resolution_), range(resolution_)))
+ N = len(points)
+ N_ = len(points_)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points_:
+ for p2 in points:
+ size = 1
+ offset = (
+ abs(p1[0] * stride - p2[0] + (size - 1) / 2),
+ abs(p1[1] * stride - p2[1] + (size - 1) / 2))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
+ self.ab = {} # per-device attention_biases cache
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ super().train(mode)
+ if mode and self.ab:
+ self.ab = {} # clear ab cache
+
+ def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+ if self.training:
+ return self.attention_biases[:, self.attention_bias_idxs]
+ else:
+ device_key = str(device)
+ if device_key not in self.ab:
+ self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+ return self.ab[device_key]
+
+ def forward(self, x):
+ if self.use_conv:
+ B, C, H, W = x.shape
+ k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
+ q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
+
+ attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
+ attn = attn.softmax(dim=-1)
+
+ x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
+ else:
+ B, N, C = x.shape
+ k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
+ k = k.permute(0, 2, 1, 3) # BHNC
+ v = v.permute(0, 2, 1, 3) # BHNC
+ q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
+
+ attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
+ attn = attn.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
+ x = self.proj(x)
+ return x
+
+
+class Levit(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+
+ NOTE: distillation defaults to True since the pretrained weights use it; this will cause
+ problems with train scripts that don't expect tuple outputs.
+ """
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ num_classes=1000,
+ embed_dim=(192,),
+ key_dim=64,
+ depth=(12,),
+ num_heads=(3,),
+ attn_ratio=2,
+ mlp_ratio=2,
+ hybrid_backbone=None,
+ down_ops=None,
+ act_layer='hard_swish',
+ attn_act_layer='hard_swish',
+ distillation=True,
+ use_conv=False,
+ drop_rate=0.,
+ drop_path_rate=0.):
+ super().__init__()
+ act_layer = get_act_layer(act_layer)
+ attn_act_layer = get_act_layer(attn_act_layer)
+ if isinstance(img_size, tuple):
+ # FIXME original impl passes a single img/res dim through the whole hierarchy;
+ # not sure this model will be used enough to spend time fixing it.
+ assert img_size[0] == img_size[1]
+ img_size = img_size[0]
+ self.num_classes = num_classes
+ self.num_features = embed_dim[-1]
+ self.embed_dim = embed_dim
+ N = len(embed_dim)
+ assert len(depth) == len(num_heads) == N
+ key_dim = to_ntuple(N)(key_dim)
+ attn_ratio = to_ntuple(N)(attn_ratio)
+ mlp_ratio = to_ntuple(N)(mlp_ratio)
+ down_ops = down_ops or (
+ # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
+ ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
+ ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
+ ('',)
+ )
+ self.distillation = distillation
+ self.use_conv = use_conv
+ ln_layer = ConvNorm if self.use_conv else LinearNorm
+
+ self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
+
+ self.blocks = []
+ resolution = img_size // patch_size
+ for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
+ zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
+ for _ in range(dpth):
+ self.blocks.append(
+ Residual(
+ Attention(
+ ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
+ resolution=resolution, use_conv=use_conv),
+ drop_path_rate))
+ if mr > 0:
+ h = int(ed * mr)
+ self.blocks.append(
+ Residual(nn.Sequential(
+ ln_layer(ed, h, resolution=resolution),
+ act_layer(),
+ ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
+ ), drop_path_rate))
+ if do[0] == 'Subsample':
+ # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
+ resolution_ = (resolution - 1) // do[5] + 1
+ self.blocks.append(
+ AttentionSubsample(
+ *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
+ attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
+ resolution=resolution, resolution_=resolution_, use_conv=use_conv))
+ resolution = resolution_
+ if do[4] > 0: # mlp_ratio
+ h = int(embed_dim[i + 1] * do[4])
+ self.blocks.append(
+ Residual(nn.Sequential(
+ ln_layer(embed_dim[i + 1], h, resolution=resolution),
+ act_layer(),
+ ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
+ ), drop_path_rate))
+ self.blocks = nn.Sequential(*self.blocks)
+
+ # Classifier head
+ self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distillation:
+ self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
+ def get_classifier(self):
+ if self.head_dist is None:
+ return self.head
+ else:
+ return self.head, self.head_dist
+
+ def reset_classifier(self, num_classes, global_pool='', distillation=None):
+ self.num_classes = num_classes
+ self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
+ if distillation is not None:
+ self.distillation = distillation
+ if self.distillation:
+ self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
+ else:
+ self.head_dist = None
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ if not self.use_conv:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.blocks(x)
+ x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.head_dist is not None:
+ x, x_dist = self.head(x), self.head_dist(x)
+ if self.training and not torch.jit.is_scripting():
+ return x, x_dist
+ else:
+ # during inference, return the average of both classifier predictions
+ return (x + x_dist) / 2
+ else:
+ x = self.head(x)
+ return x
+
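+
+# Illustrative usage sketch (hypothetical helper, not part of the original model code):
+# shows the effect of the default distillation head described in the class docstring
+# above. The LeViT-128S-style hyper-parameters below are assumptions for the example.
+def _example_levit_distillation_outputs():
+    model = Levit(
+        img_size=224, embed_dim=(128, 256, 384), key_dim=16,
+        depth=(2, 3, 4), num_heads=(4, 6, 8), num_classes=10)
+    x = torch.randn(2, 3, 224, 224)
+    model.train()
+    cls_out, dist_out = model(x)  # training returns a (cls, dist) tuple of (2, 10) logits
+    model.eval()
+    with torch.no_grad():
+        avg_out = model(x)  # eval returns the averaged prediction: (cls + dist) / 2
+    return cls_out.shape, dist_out.shape, avg_out.shape
+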
+
+def checkpoint_filter_fn(state_dict, model):
+ if 'model' in state_dict:
+ # For deit models
+ state_dict = state_dict['model']
+ D = model.state_dict()
+ for k in state_dict.keys():
+ if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
+ state_dict[k] = state_dict[k][:, :, None, None]
+ return state_dict
+
+
+def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model_cfg = dict(**model_cfgs[variant], **kwargs)
+ model = build_model_with_cfg(
+ Levit, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **model_cfg)
+ #if fuse:
+ # utils.replace_batchnorm(model)
+ return model
+
diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
new file mode 100644
index 0000000..727b655
--- /dev/null
+++ b/timm/models/mlp_mixer.py
@@ -0,0 +1,659 @@
+""" MLP-Mixer, ResMLP, and gMLP in PyTorch
+
+This impl originally based on MLP-Mixer paper.
+
+Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
+
+Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+
+@article{tolstikhin2021,
+ title={MLP-Mixer: An all-MLP Architecture for Vision},
+ author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
+ Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
+ journal={arXiv preprint arXiv:2105.01601},
+ year={2021}
+}
+
+Also supports ResMLP and a preliminary (not verified) implementation of gMLP.
+
+Code: https://github.com/facebookresearch/deit
+Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+@misc{touvron2021resmlp,
+ title={ResMLP: Feedforward networks for image classification with data-efficient training},
+ author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
+ Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
+ year={2021},
+ eprint={2105.03404},
+}
+
+Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+@misc{liu2021pay,
+ title={Pay Attention to MLPs},
+ author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
+ year={2021},
+ eprint={2105.08050},
+}
+
+Thanks to the paper authors for releasing code and weights.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+from copy import deepcopy
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
+from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
+from .registry import register_model
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ 'first_conv': 'stem.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ mixer_s32_224=_cfg(),
+ mixer_s16_224=_cfg(),
+ mixer_b32_224=_cfg(),
+ mixer_b16_224=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
+ ),
+ mixer_b16_224_in21k=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
+ num_classes=21843
+ ),
+ mixer_l32_224=_cfg(),
+ mixer_l16_224=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
+ ),
+ mixer_l16_224_in21k=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
+ num_classes=21843
+ ),
+
+ # Mixer ImageNet-21K-P pretraining
+ mixer_b16_224_miil_in21k=_cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
+ ),
+ mixer_b16_224_miil=_cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
+ ),
+
+ gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ gmixer_24_224=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ resmlp_12_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_24_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
+ #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_36_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_big_24_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ resmlp_12_distilled_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_24_distilled_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_36_distilled_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_big_24_distilled_224=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ resmlp_big_24_224_in22ft1k=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ resmlp_12_224_dino=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ resmlp_24_224_dino=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ gmlp_ti16_224=_cfg(),
+ gmlp_s16_224=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
+ ),
+ gmlp_b16_224=_cfg(),
+)
+
+
+class MixerBlock(nn.Module):
+ """ Residual Block w/ token mixing and channel MLPs
+ Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ def __init__(
+ self, dim, seq_len, mlp_ratio=(0.5, 4.0), mlp_layer=Mlp,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+ super().__init__()
+ tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
+ self.norm1 = norm_layer(dim)
+ self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
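+ # token-mixing MLP operates across the sequence dim via the (B, C, N) transpose;
+ # channel-mixing MLP then operates per token across the channel dim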
+ x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+ x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
+ return x
+
+
+class Affine(nn.Module):
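+ """ Per-channel affine transform used by ResMLP in place of LayerNorm: y = alpha * x + beta. """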
+ def __init__(self, dim):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
+ self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
+
+ def forward(self, x):
+ return torch.addcmul(self.beta, self.alpha, x)
+
+
+class ResBlock(nn.Module):
+ """ Residual MLP block w/ LayerScale and Affine 'norm'
+
+ Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ def __init__(
+ self, dim, seq_len, mlp_ratio=4, mlp_layer=Mlp, norm_layer=Affine,
+ act_layer=nn.GELU, init_values=1e-4, drop=0., drop_path=0.):
+ super().__init__()
+ channel_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim)
+ self.linear_tokens = nn.Linear(seq_len, seq_len)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop)
+ self.ls1 = nn.Parameter(init_values * torch.ones(dim))
+ self.ls2 = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+ x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))
+ return x
+
+
+class SpatialGatingUnit(nn.Module):
+ """ Spatial Gating Unit
+
+ Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):
+ super().__init__()
+ gate_dim = dim // 2
+ self.norm = norm_layer(gate_dim)
+ self.proj = nn.Linear(seq_len, seq_len)
+
+ def init_weights(self):
+ # special init for the projection gate, called as override by base model init
+ nn.init.normal_(self.proj.weight, std=1e-6)
+ nn.init.ones_(self.proj.bias)
+
+ def forward(self, x):
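+ # split channels into content (u) and gate (v); the gate is normalized, projected
+ # across the sequence dimension, and multiplies the content half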
+ u, v = x.chunk(2, dim=-1)
+ v = self.norm(v)
+ v = self.proj(v.transpose(-1, -2))
+ return u * v.transpose(-1, -2)
+
+
+class SpatialGatingBlock(nn.Module):
+ """ Residual Block w/ Spatial Gating
+
+ Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ def __init__(
+ self, dim, seq_len, mlp_ratio=4, mlp_layer=GatedMlp,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, drop=0., drop_path=0.):
+ super().__init__()
+ channel_dim = int(dim * mlp_ratio)
+ self.norm = norm_layer(dim)
+ sgu = partial(SpatialGatingUnit, seq_len=seq_len)
+ self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.mlp_channels(self.norm(x)))
+ return x
+
+
+class MlpMixer(nn.Module):
+
+ def __init__(
+ self,
+ num_classes=1000,
+ img_size=224,
+ in_chans=3,
+ patch_size=16,
+ num_blocks=8,
+ embed_dim=512,
+ mlp_ratio=(0.5, 4.0),
+ block_layer=MixerBlock,
+ mlp_layer=Mlp,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ drop_rate=0.,
+ drop_path_rate=0.,
+ nlhb=False,
+ stem_norm=False,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.stem = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+ embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
+ # FIXME drop_path (stochastic depth scaling rule or all the same?)
+ self.blocks = nn.Sequential(*[
+ block_layer(
+ embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
+ act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
+ for _ in range(num_blocks)])
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ self.init_weights(nlhb=nlhb)
+
+ def init_weights(self, nlhb=False):
+ head_bias = -math.log(self.num_classes) if nlhb else 0.
+ named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.blocks(x)
+ x = self.norm(x)
+ x = x.mean(dim=1)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
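+
+# Illustrative usage sketch (hypothetical helper, not part of the original module): a
+# tiny MlpMixer built directly from the class above; hyper-parameters are arbitrary,
+# not one of the named configs registered below.
+def _example_tiny_mixer():
+    model = MlpMixer(num_classes=10, img_size=224, patch_size=32, num_blocks=2, embed_dim=64)
+    model.eval()
+    x = torch.randn(2, 3, 224, 224)
+    with torch.no_grad():
+        logits = model(x)  # (2, 10): 49 patch tokens are mixed, normed, mean-pooled, classified
+    return logits.shape
+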
+
+def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
+ """ Mixer weight initialization (trying to match Flax defaults)
+ """
+ if isinstance(module, nn.Linear):
+ if name.startswith('head'):
+ nn.init.zeros_(module.weight)
+ nn.init.constant_(module.bias, head_bias)
+ else:
+ if flax:
+ # Flax defaults
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ else:
+ # like MLP init in vit (my original init)
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ if 'mlp' in name:
+ nn.init.normal_(module.bias, std=1e-6)
+ else:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Conv2d):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ # NOTE if a parent module contains init_weights method, it can override the init of the
+ # child modules as this will be called in depth-first order.
+ module.init_weights()
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """ Remap checkpoints if needed """
+ if 'patch_embed.proj.weight' in state_dict:
+ # Remap FB ResMlp models -> timm
+ out_dict = {}
+ for k, v in state_dict.items():
+ k = k.replace('patch_embed.', 'stem.')
+ k = k.replace('attn.', 'linear_tokens.')
+ k = k.replace('mlp.', 'mlp_channels.')
+ k = k.replace('gamma_', 'ls')
+ if k.endswith('.alpha') or k.endswith('.beta'):
+ v = v.reshape(1, 1, -1)
+ out_dict[k] = v
+ return out_dict
+ return state_dict
+
+
+def _create_mixer(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for MLP-Mixer models.')
+
+ model = build_model_with_cfg(
+ MlpMixer, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def mixer_s32_224(pretrained=False, **kwargs):
+ """ Mixer-S/32 224x224
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
+ model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_s16_224(pretrained=False, **kwargs):
+ """ Mixer-S/16 224x224
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
+ model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b32_224(pretrained=False, **kwargs):
+ """ Mixer-B/32 224x224
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b16_224(pretrained=False, **kwargs):
+ """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b16_224_in21k(pretrained=False, **kwargs):
+ """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_l32_224(pretrained=False, **kwargs):
+ """ Mixer-L/32 224x224.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
+ model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_l16_224(pretrained=False, **kwargs):
+ """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
+ model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_l16_224_in21k(pretrained=False, **kwargs):
+ """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
+ model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b16_224_miil(pretrained=False, **kwargs):
+ """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
+ """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmixer_12_224(pretrained=False, **kwargs):
+ """ Glu-Mixer-12 224x224
+ Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
+ mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
+ model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmixer_24_224(pretrained=False, **kwargs):
+ """ Glu-Mixer-24 224x224
+ Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
+ mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
+ model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_12_224(pretrained=False, **kwargs):
+ """ ResMLP-12
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_24_224(pretrained=False, **kwargs):
+ """ ResMLP-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_36_224(pretrained=False, **kwargs):
+ """ ResMLP-36
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_big_24_224(pretrained=False, **kwargs):
+ """ ResMLP-B-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_12_distilled_224(pretrained=False, **kwargs):
+ """ ResMLP-12
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_12_distilled_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_24_distilled_224(pretrained=False, **kwargs):
+ """ ResMLP-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_24_distilled_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_36_distilled_224(pretrained=False, **kwargs):
+ """ ResMLP-36
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_36_distilled_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_big_24_distilled_224(pretrained=False, **kwargs):
+ """ ResMLP-B-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_big_24_distilled_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
+ """ ResMLP-B-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_big_24_224_in22ft1k', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_12_224_dino(pretrained=False, **kwargs):
+ """ ResMLP-12
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+
+ Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_12_224_dino', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_24_224_dino(pretrained=False, **kwargs):
+ """ ResMLP-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+
+ Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_24_224_dino', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmlp_ti16_224(pretrained=False, **kwargs):
+ """ gMLP-Tiny
+ Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
+ mlp_layer=GatedMlp, **kwargs)
+ model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmlp_s16_224(pretrained=False, **kwargs):
+ """ gMLP-Small
+ Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
+ mlp_layer=GatedMlp, **kwargs)
+ model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmlp_b16_224(pretrained=False, **kwargs):
+ """ gMLP-Base
+ Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
+ mlp_layer=GatedMlp, **kwargs)
+ model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
+ return model
diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py
new file mode 100644
index 0000000..f810eb8
--- /dev/null
+++ b/timm/models/mobilenetv3.py
@@ -0,0 +1,562 @@
+
+""" MobileNet V3
+
+A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl.
+
+Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from functools import partial
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .efficientnet_blocks import SqueezeExcite
+from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
+ round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
+from .features import FeatureInfo, FeatureHooks
+from .helpers import build_model_with_cfg, default_cfg_for_features
+from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
+from .registry import register_model
+
+__all__ = ['MobileNetV3', 'MobileNetV3Features']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv_stem', 'classifier': 'classifier',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'mobilenetv3_large_075': _cfg(url=''),
+ 'mobilenetv3_large_100': _cfg(
+ interpolation='bicubic',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
+ 'mobilenetv3_large_100_miil': _cfg(
+ interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_1k_miil_78_0.pth'),
+ 'mobilenetv3_large_100_miil_in21k': _cfg(
+ interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1),
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mobilenetv3_large_100_in21k_miil.pth', num_classes=11221),
+ 'mobilenetv3_small_075': _cfg(url=''),
+ 'mobilenetv3_small_100': _cfg(url=''),
+
+ 'mobilenetv3_rw': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
+ interpolation='bicubic'),
+
+ 'tf_mobilenetv3_large_075': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_mobilenetv3_large_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_mobilenetv3_large_minimal_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_mobilenetv3_small_075': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_mobilenetv3_small_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+ 'tf_mobilenetv3_small_minimal_100': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
+ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
+
+ 'fbnetv3_b': _cfg(),
+ 'fbnetv3_d': _cfg(),
+ 'fbnetv3_g': _cfg(),
+}
+
+
+class MobileNetV3(nn.Module):
+ """ MobiletNet-V3
+
+ Based on my EfficientNet implementation and building blocks, this model utilizes the MobileNet-v3 specific
+ 'efficient head', where global pooling is done before the head convolution without a final batch-norm
+ layer before the classifier.
+
+ Paper: https://arxiv.org/abs/1905.02244
+ """
+
+ def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
+ pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
+ round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
+ super(MobileNetV3, self).__init__()
+ act_layer = act_layer or nn.ReLU
+ norm_layer = norm_layer or nn.BatchNorm2d
+ se_layer = se_layer or SqueezeExcite
+ self.num_classes = num_classes
+ self.num_features = num_features
+ self.drop_rate = drop_rate
+
+ # Stem
+ stem_size = round_chs_fn(stem_size)
+ self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+ self.bn1 = norm_layer(stem_size)
+ self.act1 = act_layer(inplace=True)
+
+ # Middle stages (IR/ER/DS Blocks)
+ builder = EfficientNetBuilder(
+ output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
+ act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
+ self.blocks = nn.Sequential(*builder(stem_size, block_args))
+ self.feature_info = builder.features
+ head_chs = builder.in_chs
+
+ # Head + Pooling
+ self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+ num_pooled_chs = head_chs * self.global_pool.feat_mult()
+ self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
+ self.act2 = act_layer(inplace=True)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
+ self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ efficientnet_init_weights(self)
+
+ def as_sequential(self):
+ layers = [self.conv_stem, self.bn1, self.act1]
+ layers.extend(self.blocks)
+ layers.extend([self.global_pool, self.conv_head, self.act2])
+ layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
+ return nn.Sequential(*layers)
+
+ def get_classifier(self):
+ return self.classifier
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ # cannot meaningfully change pooling of efficient head after creation
+ self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
+ self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.conv_stem(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.blocks(x)
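+ # MobileNet-V3 'efficient head': pool first, then apply the 1x1 head conv + act
+ # on the pooled features (no final batch-norm before the classifier)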
+ x = self.global_pool(x)
+ x = self.conv_head(x)
+ x = self.act2(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.flatten(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ return self.classifier(x)
+
+
+class MobileNetV3Features(nn.Module):
+ """ MobileNetV3 Feature Extractor
+
+ A work-in-progress feature extraction module for MobileNet-V3 to use as a backbone for segmentation
+ and object detection models.
+ """
+
+ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
+ stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
+ act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
+ super(MobileNetV3Features, self).__init__()
+ act_layer = act_layer or nn.ReLU
+ norm_layer = norm_layer or nn.BatchNorm2d
+ se_layer = se_layer or SqueezeExcite
+ self.drop_rate = drop_rate
+
+ # Stem
+ stem_size = round_chs_fn(stem_size)
+ self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
+ self.bn1 = norm_layer(stem_size)
+ self.act1 = act_layer(inplace=True)
+
+ # Middle stages (IR/ER/DS Blocks)
+ builder = EfficientNetBuilder(
+ output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
+ act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
+ drop_path_rate=drop_path_rate, feature_location=feature_location)
+ self.blocks = nn.Sequential(*builder(stem_size, block_args))
+ self.feature_info = FeatureInfo(builder.features, out_indices)
+ self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
+
+ efficientnet_init_weights(self)
+
+ # Register feature extraction hooks with FeatureHooks helper
+ self.feature_hooks = None
+ if feature_location != 'bottleneck':
+ hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
+ self.feature_hooks = FeatureHooks(hooks, self.named_modules())
+
+ def forward(self, x) -> List[torch.Tensor]:
+ x = self.conv_stem(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ if self.feature_hooks is None:
+ features = []
+ if 0 in self._stage_out_idx:
+ features.append(x) # add stem out
+ for i, b in enumerate(self.blocks):
+ x = b(x)
+ if i + 1 in self._stage_out_idx:
+ features.append(x)
+ return features
+ else:
+ self.blocks(x)
+ out = self.feature_hooks.get_output(x.device)
+ return list(out.values())
+
+
+def _create_mnv3(variant, pretrained=False, **kwargs):
+ features_only = False
+ model_cls = MobileNetV3
+ kwargs_filter = None
+ if kwargs.pop('features_only', False):
+ features_only = True
+ kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
+ model_cls = MobileNetV3Features
+ model = build_model_with_cfg(
+ model_cls, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_strict=not features_only,
+ kwargs_filter=kwargs_filter,
+ **kwargs)
+ if features_only:
+ model.default_cfg = default_cfg_for_features(model.default_cfg)
+ return model
+
+
+def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MobileNet-V3 model.
+
+ Ref impl: ?
+ Paper: https://arxiv.org/abs/1905.02244
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
+ # stage 2, 56x56 in
+ ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
+ # stage 3, 28x28 in
+ ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
+ # stage 5, 14x14in
+ ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c960'], # hard-swish
+ ]
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ head_bias=False,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=resolve_act_layer(kwargs, 'hard_swish'),
+ se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
+ **kwargs,
+ )
+ model = _create_mnv3(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """Creates a MobileNet-V3 model.
+
+ Ref impl: ?
+ Paper: https://arxiv.org/abs/1905.02244
+
+ Args:
+ channel_multiplier: multiplier to number of channels per layer.
+ """
+ if 'small' in variant:
+ num_features = 1024
+ if 'minimal' in variant:
+ act_layer = resolve_act_layer(kwargs, 'relu')
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s2_e1_c16'],
+ # stage 1, 56x56 in
+ ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'],
+ # stage 2, 28x28 in
+ ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'],
+ # stage 3, 14x14 in
+ ['ir_r2_k3_s1_e3_c48'],
+ # stage 4, 14x14in
+ ['ir_r3_k3_s2_e6_c96'],
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c576'],
+ ]
+ else:
+ act_layer = resolve_act_layer(kwargs, 'hard_swish')
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu
+ # stage 1, 56x56 in
+ ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu
+ # stage 2, 28x28 in
+ ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish
+ # stage 3, 14x14 in
+ ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish
+ # stage 4, 14x14in
+ ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c576'], # hard-swish
+ ]
+ else:
+ num_features = 1280
+ if 'minimal' in variant:
+ act_layer = resolve_act_layer(kwargs, 'relu')
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16'],
+ # stage 1, 112x112 in
+ ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'],
+ # stage 2, 56x56 in
+ ['ir_r3_k3_s2_e3_c40'],
+ # stage 3, 28x28 in
+ ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112'],
+ # stage 5, 14x14in
+ ['ir_r3_k3_s2_e6_c160'],
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c960'],
+ ]
+ else:
+ act_layer = resolve_act_layer(kwargs, 'hard_swish')
+ arch_def = [
+ # stage 0, 112x112 in
+ ['ds_r1_k3_s1_e1_c16_nre'], # relu
+ # stage 1, 112x112 in
+ ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu
+ # stage 2, 56x56 in
+ ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu
+ # stage 3, 28x28 in
+ ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish
+ # stage 4, 14x14in
+ ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish
+ # stage 5, 14x14in
+ ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish
+ # stage 6, 7x7 in
+ ['cn_r1_k1_s1_c960'], # hard-swish
+ ]
+ se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ num_features=num_features,
+ stem_size=16,
+ round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
+ norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=act_layer,
+ se_layer=se_layer,
+ **kwargs,
+ )
+ model = _create_mnv3(variant, pretrained, **model_kwargs)
+ return model
+
+
+def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
+ """ FBNetV3
+ Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
+ - https://arxiv.org/abs/2006.02049
+ FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
+ """
+ vl = variant.split('_')[-1]
+ if vl in ('a', 'b'):
+ stem_size = 16
+ arch_def = [
+ ['ds_r2_k3_s1_e1_c16'],
+ ['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
+ ['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
+ ['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
+ ['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
+ ['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
+ ['cn_r1_k1_s1_c1344'],
+ ]
+ elif vl == 'd':
+ stem_size = 24
+ arch_def = [
+ ['ds_r2_k3_s1_e1_c16'],
+ ['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
+ ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
+ ['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
+ ['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
+ ['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
+ ['cn_r1_k1_s1_c1440'],
+ ]
+ elif vl == 'g':
+ stem_size = 32
+ arch_def = [
+ ['ds_r3_k3_s1_e1_c24'],
+ ['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
+ ['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
+ ['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
+ ['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
+ ['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
+ ['cn_r1_k1_s1_c1728'],
+ ]
+ else:
+ raise NotImplementedError
+ round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
+ se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)
+ act_layer = resolve_act_layer(kwargs, 'hard_swish')
+ model_kwargs = dict(
+ block_args=decode_arch_def(arch_def),
+ num_features=1984,
+ head_bias=False,
+ stem_size=stem_size,
+ round_chs_fn=round_chs_fn,
+ se_from_exp=False,
+ norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+ act_layer=act_layer,
+ se_layer=se_layer,
+ **kwargs,
+ )
+ model = _create_mnv3(variant, pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_large_075(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_large_100(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_large_100_miil(pretrained=False, **kwargs):
+ """ MobileNet V3
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):
+ """ MobileNet V3, 21k pretraining
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_small_075(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_small_100(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def mobilenetv3_rw(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ if pretrained:
+ # pretrained model trained with non-default BN epsilon
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mobilenetv3_large_075(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mobilenetv3_large_100(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mobilenetv3_small_075(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mobilenetv3_small_100(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
+ """ MobileNet V3 """
+ kwargs['bn_eps'] = BN_EPS_TF_DEFAULT
+ kwargs['pad_type'] = 'same'
+ model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def fbnetv3_b(pretrained=False, **kwargs):
+ """ FBNetV3-B """
+ model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def fbnetv3_d(pretrained=False, **kwargs):
+ """ FBNetV3-D """
+ model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
+ return model
+
+
+@register_model
+def fbnetv3_g(pretrained=False, **kwargs):
+ """ FBNetV3-G """
+ model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
+ return model
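+
+
+# Illustrative usage sketch (hypothetical helper, not part of the original module):
+# build one of the registered variants above without pretrained weights and run a
+# forward pass through the efficient head.
+def _example_mobilenetv3_forward():
+    model = mobilenetv3_large_100(pretrained=False)
+    model.eval()
+    x = torch.randn(1, 3, 224, 224)
+    with torch.no_grad():
+        logits = model(x)  # (1, 1000) class logits
+    return logits.shape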
diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py
new file mode 100644
index 0000000..2afe82c
--- /dev/null
+++ b/timm/models/nasnet.py
@@ -0,0 +1,567 @@
+""" NasNet-A (Large)
+ nasnetalarge implementation grabbed from Cadene's pretrained models
+ https://github.com/Cadene/pretrained-models.pytorch
+"""
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .helpers import build_model_with_cfg
+from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
+from .registry import register_model
+
+__all__ = ['NASNetALarge']
+
+default_cfgs = {
+ 'nasnetalarge': {
+ 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
+ 'input_size': (3, 331, 331),
+ 'pool_size': (11, 11),
+ 'crop_pct': 0.911,
+ 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5),
+ 'std': (0.5, 0.5, 0.5),
+ 'num_classes': 1000,
+ 'first_conv': 'conv0.conv',
+ 'classifier': 'last_linear',
+ 'label_offset': 1, # 1001 classes in pretrained weights
+ },
+}
+
+
+class ActConvBn(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
+ super(ActConvBn, self).__init__()
+ self.act = nn.ReLU()
+ self.conv = create_conv2d(
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)
+
+ def forward(self, x):
+ x = self.act(x)
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class SeparableConv2d(nn.Module):
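+ """ Depthwise separable conv: a depthwise conv (groups=in_channels) followed by a 1x1 pointwise conv. """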
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
+ super(SeparableConv2d, self).__init__()
+ self.depthwise_conv2d = create_conv2d(
+ in_channels, in_channels, kernel_size=kernel_size,
+ stride=stride, padding=padding, groups=in_channels)
+ self.pointwise_conv2d = create_conv2d(
+ in_channels, out_channels, kernel_size=1, padding=0)
+
+ def forward(self, x):
+ x = self.depthwise_conv2d(x)
+ x = self.pointwise_conv2d(x)
+ return x
+
+
+class BranchSeparables(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False):
+ super(BranchSeparables, self).__init__()
+ middle_channels = out_channels if stem_cell else in_channels
+ self.act_1 = nn.ReLU()
+ self.separable_1 = SeparableConv2d(
+ in_channels, middle_channels, kernel_size, stride=stride, padding=pad_type)
+ self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001, momentum=0.1)
+ self.act_2 = nn.ReLU(inplace=True)
+ self.separable_2 = SeparableConv2d(
+ middle_channels, out_channels, kernel_size, stride=1, padding=pad_type)
+ self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)
+
+ def forward(self, x):
+ x = self.act_1(x)
+ x = self.separable_1(x)
+ x = self.bn_sep_1(x)
+ x = self.act_2(x)
+ x = self.separable_2(x)
+ x = self.bn_sep_2(x)
+ return x
+
+
+class CellStem0(nn.Module):
+ def __init__(self, stem_size, num_channels=42, pad_type=''):
+ super(CellStem0, self).__init__()
+ self.num_channels = num_channels
+ self.stem_size = stem_size
+ self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)
+
+ self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
+ self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)
+
+ self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
+ self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)
+
+ self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
+ self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)
+
+ self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
+ self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
+
+ def forward(self, x):
+ x1 = self.conv_1x1(x)
+
+ x_comb_iter_0_left = self.comb_iter_0_left(x1)
+ x_comb_iter_0_right = self.comb_iter_0_right(x)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x1)
+ x_comb_iter_1_right = self.comb_iter_1_right(x)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x1)
+ x_comb_iter_2_right = self.comb_iter_2_right(x)
+ x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
+
+ x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
+ x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
+ x_comb_iter_4_right = self.comb_iter_4_right(x1)
+ x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
+
+ x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class CellStem1(nn.Module):
+
+ def __init__(self, stem_size, num_channels, pad_type=''):
+ super(CellStem1, self).__init__()
+ self.num_channels = num_channels
+ self.stem_size = stem_size
+ self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1)
+
+ self.act = nn.ReLU()
+ self.path_1 = nn.Sequential()
+ self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
+ self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
+
+ self.path_2 = nn.Sequential()
+ self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
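+ # NOTE: ZeroPad2d with negative left/top values crops one pixel while padding one on the
+ # right/bottom, shifting the grid by one pixel so the strided avg-pool below samples
+ # positions offset from those used in path_1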
+ self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
+ self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))
+
+ self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1)
+
+ self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
+ self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)
+
+ self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
+ self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)
+
+ self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
+ self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
+
+ self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
+ self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
+
+ def forward(self, x_conv0, x_stem_0):
+ x_left = self.conv_1x1(x_stem_0)
+
+ x_relu = self.act(x_conv0)
+ # path 1
+ x_path1 = self.path_1(x_relu)
+ # path 2
+ x_path2 = self.path_2(x_relu)
+ # final path
+ x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
+
+ x_comb_iter_0_left = self.comb_iter_0_left(x_left)
+ x_comb_iter_0_right = self.comb_iter_0_right(x_right)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x_left)
+ x_comb_iter_1_right = self.comb_iter_1_right(x_right)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x_left)
+ x_comb_iter_2_right = self.comb_iter_2_right(x_right)
+ x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
+
+ x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
+ x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
+ x_comb_iter_4_right = self.comb_iter_4_right(x_left)
+ x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
+
+ x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class FirstCell(nn.Module):
+
+ def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
+ super(FirstCell, self).__init__()
+ self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)
+
+ self.act = nn.ReLU()
+ self.path_1 = nn.Sequential()
+ self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
+ self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
+
+ self.path_2 = nn.Sequential()
+ self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
+ self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
+ self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))
+
+ self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)
+
+ self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
+ self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
+
+ self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
+ self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
+
+ self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+ self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
+
+ def forward(self, x, x_prev):
+ x_relu = self.act(x_prev)
+ x_path1 = self.path_1(x_relu)
+ x_path2 = self.path_2(x_relu)
+ x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
+ x_right = self.conv_1x1(x)
+
+ x_comb_iter_0_left = self.comb_iter_0_left(x_right)
+ x_comb_iter_0_right = self.comb_iter_0_right(x_left)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x_left)
+ x_comb_iter_1_right = self.comb_iter_1_right(x_left)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x_right)
+ x_comb_iter_2 = x_comb_iter_2_left + x_left
+
+ x_comb_iter_3_left = self.comb_iter_3_left(x_left)
+ x_comb_iter_3_right = self.comb_iter_3_right(x_left)
+ x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_right)
+ x_comb_iter_4 = x_comb_iter_4_left + x_right
+
+ x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class NormalCell(nn.Module):
+
+ def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
+ super(NormalCell, self).__init__()
+ self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
+ self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
+
+ self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
+ self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
+
+ self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type)
+ self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)
+
+ self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+ self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
+
+ def forward(self, x, x_prev):
+ x_left = self.conv_prev_1x1(x_prev)
+ x_right = self.conv_1x1(x)
+
+ x_comb_iter_0_left = self.comb_iter_0_left(x_right)
+ x_comb_iter_0_right = self.comb_iter_0_right(x_left)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x_left)
+ x_comb_iter_1_right = self.comb_iter_1_right(x_left)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x_right)
+ x_comb_iter_2 = x_comb_iter_2_left + x_left
+
+ x_comb_iter_3_left = self.comb_iter_3_left(x_left)
+ x_comb_iter_3_right = self.comb_iter_3_right(x_left)
+ x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_right)
+ x_comb_iter_4 = x_comb_iter_4_left + x_right
+
+ x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class ReductionCell0(nn.Module):
+
+ def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
+ super(ReductionCell0, self).__init__()
+ self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
+ self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
+
+ self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
+ self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
+
+ self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
+ self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
+
+ self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
+ self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
+
+ self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
+ self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
+
+ def forward(self, x, x_prev):
+ x_left = self.conv_prev_1x1(x_prev)
+ x_right = self.conv_1x1(x)
+
+ x_comb_iter_0_left = self.comb_iter_0_left(x_right)
+ x_comb_iter_0_right = self.comb_iter_0_right(x_left)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x_right)
+ x_comb_iter_1_right = self.comb_iter_1_right(x_left)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x_right)
+ x_comb_iter_2_right = self.comb_iter_2_right(x_left)
+ x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
+
+ x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
+ x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
+ x_comb_iter_4_right = self.comb_iter_4_right(x_right)
+ x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
+
+ x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class ReductionCell1(nn.Module):
+
+ def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
+ super(ReductionCell1, self).__init__()
+ self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
+ self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)
+
+ self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
+ self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
+
+ self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
+ self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)
+
+ self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
+ self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
+
+ self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
+ self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)
+
+ def forward(self, x, x_prev):
+ x_left = self.conv_prev_1x1(x_prev)
+ x_right = self.conv_1x1(x)
+
+ x_comb_iter_0_left = self.comb_iter_0_left(x_right)
+ x_comb_iter_0_right = self.comb_iter_0_right(x_left)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x_right)
+ x_comb_iter_1_right = self.comb_iter_1_right(x_left)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x_right)
+ x_comb_iter_2_right = self.comb_iter_2_right(x_left)
+ x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
+
+ x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
+ x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
+ x_comb_iter_4_right = self.comb_iter_4_right(x_right)
+ x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
+
+ x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class NASNetALarge(nn.Module):
+ """NASNetALarge (6 @ 4032) """
+
+ def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
+ num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
+ super(NASNetALarge, self).__init__()
+ self.num_classes = num_classes
+ self.stem_size = stem_size
+ self.num_features = num_features
+ self.channel_multiplier = channel_multiplier
+ self.drop_rate = drop_rate
+ assert output_stride == 32
+
+ channels = self.num_features // 24
+ # 24 = 6 concatenated branches per cell x 4x base width in the final stage,
+ # so `channels` is the base cell width (168 for the default 4032 features)
+
+ self.conv0 = ConvBnAct(
+ in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
+ norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
+
+ self.cell_stem_0 = CellStem0(
+ self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
+ self.cell_stem_1 = CellStem1(
+ self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type)
+
+ self.cell_0 = FirstCell(
+ in_chs_left=channels, out_chs_left=channels // 2,
+ in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type)
+ self.cell_1 = NormalCell(
+ in_chs_left=2 * channels, out_chs_left=channels,
+ in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
+ self.cell_2 = NormalCell(
+ in_chs_left=6 * channels, out_chs_left=channels,
+ in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
+ self.cell_3 = NormalCell(
+ in_chs_left=6 * channels, out_chs_left=channels,
+ in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
+ self.cell_4 = NormalCell(
+ in_chs_left=6 * channels, out_chs_left=channels,
+ in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
+ self.cell_5 = NormalCell(
+ in_chs_left=6 * channels, out_chs_left=channels,
+ in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
+
+ self.reduction_cell_0 = ReductionCell0(
+ in_chs_left=6 * channels, out_chs_left=2 * channels,
+ in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+ self.cell_6 = FirstCell(
+ in_chs_left=6 * channels, out_chs_left=channels,
+ in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+ self.cell_7 = NormalCell(
+ in_chs_left=8 * channels, out_chs_left=2 * channels,
+ in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+ self.cell_8 = NormalCell(
+ in_chs_left=12 * channels, out_chs_left=2 * channels,
+ in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+ self.cell_9 = NormalCell(
+ in_chs_left=12 * channels, out_chs_left=2 * channels,
+ in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+ self.cell_10 = NormalCell(
+ in_chs_left=12 * channels, out_chs_left=2 * channels,
+ in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+ self.cell_11 = NormalCell(
+ in_chs_left=12 * channels, out_chs_left=2 * channels,
+ in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
+
+ self.reduction_cell_1 = ReductionCell1(
+ in_chs_left=12 * channels, out_chs_left=4 * channels,
+ in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.cell_12 = FirstCell(
+ in_chs_left=12 * channels, out_chs_left=2 * channels,
+ in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.cell_13 = NormalCell(
+ in_chs_left=16 * channels, out_chs_left=4 * channels,
+ in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.cell_14 = NormalCell(
+ in_chs_left=24 * channels, out_chs_left=4 * channels,
+ in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.cell_15 = NormalCell(
+ in_chs_left=24 * channels, out_chs_left=4 * channels,
+ in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.cell_16 = NormalCell(
+ in_chs_left=24 * channels, out_chs_left=4 * channels,
+ in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.cell_17 = NormalCell(
+ in_chs_left=24 * channels, out_chs_left=4 * channels,
+ in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
+ self.act = nn.ReLU(inplace=True)
+ self.feature_info = [
+ dict(num_chs=96, reduction=2, module='conv0'),
+ dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'),
+ dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'),
+ dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'),
+ dict(num_chs=4032, reduction=32, module='act'),
+ ]
+
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def get_classifier(self):
+ return self.last_linear
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x_conv0 = self.conv0(x)
+
+ x_stem_0 = self.cell_stem_0(x_conv0)
+ x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)
+
+ x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
+ x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
+ x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
+ x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
+ x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
+ x_cell_5 = self.cell_5(x_cell_4, x_cell_3)
+
+ x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
+ x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
+ x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
+ x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
+ x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
+ x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
+ x_cell_11 = self.cell_11(x_cell_10, x_cell_9)
+
+ x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
+ x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
+ x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
+ x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
+ x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
+ x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
+ x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
+ x = self.act(x_cell_17)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0:
+ x = F.dropout(x, self.drop_rate, training=self.training)
+ x = self.last_linear(x)
+ return x
+
+
+def _create_nasnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ NASNetALarge, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
+ **kwargs)
+
+
+@register_model
+def nasnetalarge(pretrained=False, **kwargs):
+ """NASNet-A large model architecture.
+ """
+ model_kwargs = dict(pad_type='same', **kwargs)
+ return _create_nasnet('nasnetalarge', pretrained, **model_kwargs)
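+
+
+# Example usage (illustrative sketch; assumes the surrounding timm package and model registry):
+#   import timm, torch
+#   model = timm.create_model('nasnetalarge', pretrained=False)
+#   logits = model(torch.randn(1, 3, 331, 331))   # -> shape (1, 1000)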
diff --git a/timm/models/nest.py b/timm/models/nest.py
new file mode 100644
index 0000000..22cf609
--- /dev/null
+++ b/timm/models/nest.py
@@ -0,0 +1,465 @@
+""" Nested Transformer (NesT) in PyTorch
+
+A PyTorch implementation of Aggregating Nested Transformers as described in:
+
+'Aggregating Nested Transformers'
+ - https://arxiv.org/abs/2105.12723
+
+The official Jax code is released and available at https://github.com/google-research/nested-transformer. The weights
+have been converted with convert/convert_nest_flax.py
+
+Acknowledgments:
+* The paper authors for sharing their research, code, and model weights
+* Ross Wightman's existing code, on which this implementation is based
+
+Copyright 2021 Alexander Soare
+"""
+
+import collections.abc
+import logging
+import math
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
+from .helpers import build_model_with_cfg, named_apply
+from .layers import PatchEmbed, Mlp, DropPath, create_classifier, trunc_normal_
+from .layers import _assert
+from .layers import create_conv2d, create_pool2d, to_ntuple
+from .registry import register_model
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': [14, 14],
+ 'crop_pct': .875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # (weights from official Google JAX impl)
+ 'nest_base': _cfg(),
+ 'nest_small': _cfg(),
+ 'nest_tiny': _cfg(),
+ 'jx_nest_base': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_base-8bc41011.pth'),
+ 'jx_nest_small': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_small-422eaded.pth'),
+ 'jx_nest_tiny': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/jx_nest_tiny-e3428fb9.pth'),
+}
+
+
+class Attention(nn.Module):
+ """
+ This is much like `.vision_transformer.Attention` but uses *localised* self attention by accepting an input with
+ an extra "image block" dim
+ """
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, 3*dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ """
+ x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
+ """
+ B, T, N, C = x.shape
+ # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
+ qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ # (B, H, T, N, C'), permute -> (B, T, N, C', H)
+ x = (attn @ v).permute(0, 2, 3, 4, 1).reshape(B, T, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x # (B, T, N, C)
+
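+# Minimal shape sanity check for the block-local Attention above (illustrative only, not
+# part of the upstream module):
+#   attn = Attention(dim=64, num_heads=8)
+#   x = torch.randn(2, 4, 16, 64)              # (B, T image blocks, N tokens per block, C)
+#   assert attn(x).shape == (2, 4, 16, 64)     # tokens are only mixed within each block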
+
+class TransformerLayer(nn.Module):
+ """
+ This is much like `.vision_transformer.Block` but:
+ - Called TransformerLayer here to allow for "block" as defined in the paper ("non-overlapping image blocks")
+ - Uses modified Attention layer that handles the "block" dimension
+ """
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ y = self.norm1(x)
+ x = x + self.drop_path(self.attn(y))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class ConvPool(nn.Module):
+ def __init__(self, in_channels, out_channels, norm_layer, pad_type=''):
+ super().__init__()
+ self.conv = create_conv2d(in_channels, out_channels, kernel_size=3, padding=pad_type, bias=True)
+ self.norm = norm_layer(out_channels)
+ self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=pad_type)
+
+ def forward(self, x):
+ """
+ x is expected to have shape (B, C, H, W)
+ """
+ _assert(x.shape[-2] % 2 == 0, 'BlockAggregation requires even input spatial dims')
+ _assert(x.shape[-1] % 2 == 0, 'BlockAggregation requires even input spatial dims')
+ x = self.conv(x)
+ # Layer norm done over channel dim only
+ x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ x = self.pool(x)
+ return x # (B, C, H//2, W//2)
+
+
+def blockify(x, block_size: int):
+ """image to blocks
+ Args:
+ x (Tensor): with shape (B, H, W, C)
+ block_size (int): edge length of a single square block in units of H, W
+ """
+ B, H, W, C = x.shape
+ _assert(H % block_size == 0, '`block_size` must divide input height evenly')
+ _assert(W % block_size == 0, '`block_size` must divide input width evenly')
+ grid_height = H // block_size
+ grid_width = W // block_size
+ x = x.reshape(B, grid_height, block_size, grid_width, block_size, C)
+ x = x.transpose(2, 3).reshape(B, grid_height * grid_width, -1, C)
+ return x # (B, T, N, C)
+
+
+@register_notrace_function # reason: int receives Proxy
+def deblockify(x, block_size: int):
+ """blocks to image
+ Args:
+ x (Tensor): with shape (B, T, N, C) where T is number of blocks and N is sequence size per block
+ block_size (int): edge length of a single square block in units of desired H, W
+ """
+ B, T, _, C = x.shape
+ grid_size = int(math.sqrt(T))
+ height = width = grid_size * block_size
+ x = x.reshape(B, grid_size, grid_size, block_size, block_size, C)
+ x = x.transpose(2, 3).reshape(B, height, width, C)
+ return x # (B, H, W, C)
+
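+# Round-trip sketch for blockify/deblockify (illustrative only, not part of the upstream module):
+#   x = torch.randn(2, 8, 8, 32)                  # (B, H, W, C)
+#   blocks = blockify(x, block_size=4)            # (2, 4, 16, 32): a 2x2 grid of 4x4 blocks
+#   assert torch.equal(deblockify(blocks, 4), x)  # exact inverse for square inputs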
+
+class NestLevel(nn.Module):
+ """ Single hierarchical level of a Nested Transformer
+ """
+ def __init__(
+ self, num_blocks, block_size, seq_length, num_heads, depth, embed_dim, prev_embed_dim=None,
+ mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rates=[],
+ norm_layer=None, act_layer=None, pad_type=''):
+ super().__init__()
+ self.block_size = block_size
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_blocks, seq_length, embed_dim))
+
+ if prev_embed_dim is not None:
+ self.pool = ConvPool(prev_embed_dim, embed_dim, norm_layer=norm_layer, pad_type=pad_type)
+ else:
+ self.pool = nn.Identity()
+
+ # Transformer encoder
+ if len(drop_path_rates):
+ assert len(drop_path_rates) == depth, 'Must provide as many drop path rates as there are transformer layers'
+ self.transformer_encoder = nn.Sequential(*[
+ TransformerLayer(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rates[i],
+ norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(depth)])
+
+ def forward(self, x):
+ """
+ expects x as (B, C, H, W)
+ """
+ x = self.pool(x)
+ x = x.permute(0, 2, 3, 1) # (B, H', W', C), switch to channels last for transformer
+ x = blockify(x, self.block_size) # (B, T, N, C')
+ x = x + self.pos_embed
+ x = self.transformer_encoder(x) # (B, T, N, C')
+ x = deblockify(x, self.block_size) # (B, H', W', C')
+ # Channel-first for block aggregation, and generally to replicate convnet feature map at each stage
+ return x.permute(0, 3, 1, 2) # (B, C, H', W')
+
+
+class Nest(nn.Module):
+ """ Nested Transformer (NesT)
+
+ A PyTorch impl of : `Aggregating Nested Transformers`
+ - https://arxiv.org/abs/2105.12723
+ """
+
+ def __init__(self, img_size=224, in_chans=3, patch_size=4, num_levels=3, embed_dims=(128, 256, 512),
+ num_heads=(4, 8, 16), depths=(2, 2, 20), num_classes=1000, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.5, norm_layer=None, act_layer=None,
+ pad_type='', weight_init='', global_pool='avg'):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ in_chans (int): number of input channels
+ patch_size (int): patch size
+ num_levels (int): number of block hierarchies (T_d in the paper)
+ embed_dims (int, tuple): embedding dimensions of each level
+ num_heads (int, tuple): number of attention heads for each level
+ depths (int, tuple): number of transformer layers for each level
+ num_classes (int): number of classes for classification head
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim for MLP of transformer layers
+ qkv_bias (bool): enable bias for qkv if True
+ drop_rate (float): dropout rate for MLP of transformer layers, MSA final projection layer, and classifier
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ norm_layer: (nn.Module): normalization layer for transformer layers
+ act_layer: (nn.Module): activation layer in MLP of transformer layers
+ pad_type: str: Type of padding to use '' for PyTorch symmetric, 'same' for TF SAME
+ weight_init: (str): weight init scheme
+ global_pool: (str): type of pooling operation to apply to final feature map
+
+ Notes:
+ - Default values follow NesT-B from the original Jax code.
+ - `embed_dims`, `num_heads`, `depths` should be ints or tuples with length `num_levels`.
+ - For those following the paper, Table A1 may have errors!
+ - https://github.com/google-research/nested-transformer/issues/2
+ """
+ super().__init__()
+
+ for param_name in ['embed_dims', 'num_heads', 'depths']:
+ param_value = locals()[param_name]
+ if isinstance(param_value, collections.abc.Sequence):
+ assert len(param_value) == num_levels, f'Require `len({param_name}) == num_levels`'
+
+ embed_dims = to_ntuple(num_levels)(embed_dims)
+ num_heads = to_ntuple(num_levels)(num_heads)
+ depths = to_ntuple(num_levels)(depths)
+ self.num_classes = num_classes
+ self.num_features = embed_dims[-1]
+ self.feature_info = []
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+ self.drop_rate = drop_rate
+ self.num_levels = num_levels
+ if isinstance(img_size, collections.abc.Sequence):
+ assert img_size[0] == img_size[1], 'Model only handles square inputs'
+ img_size = img_size[0]
+ assert img_size % patch_size == 0, '`patch_size` must divide `img_size` evenly'
+ self.patch_size = patch_size
+
+ # Number of blocks at each level
+ self.num_blocks = (4 ** torch.arange(num_levels)).flip(0).tolist()
+ assert (img_size // patch_size) % math.sqrt(self.num_blocks[0]) == 0, \
+ 'First level blocks don\'t fit evenly. Check `img_size`, `patch_size`, and `num_levels`'
+
+ # Block edge size in units of patches
+ # Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
+ # number of blocks along edge of image
+ self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
+
+ # Patch embedding
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], flatten=False)
+ self.num_patches = self.patch_embed.num_patches
+ self.seq_length = self.num_patches // self.num_blocks[0]
+
+ # Build up each hierarchical level
+ levels = []
+ dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+ prev_dim = None
+ curr_stride = 4
+ for i in range(len(self.num_blocks)):
+ dim = embed_dims[i]
+ levels.append(NestLevel(
+ self.num_blocks[i], self.block_size, self.seq_length, num_heads[i], depths[i], dim, prev_dim,
+ mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, dp_rates[i], norm_layer, act_layer, pad_type=pad_type))
+ self.feature_info += [dict(num_chs=dim, reduction=curr_stride, module=f'levels.{i}')]
+ prev_dim = dim
+ curr_stride *= 2
+ self.levels = nn.Sequential(*levels)
+
+ # Final normalization layer
+ self.norm = norm_layer(embed_dims[-1])
+
+ # Classifier
+ self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ self.init_weights(weight_init)
+
+ def init_weights(self, mode=''):
+ assert mode in ('nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+ for level in self.levels:
+ trunc_normal_(level.pos_embed, std=.02, a=-2, b=2)
+ named_apply(partial(_init_nest_weights, head_bias=head_bias), self)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {f'levels.{i}.pos_embed' for i in range(len(self.levels))}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.head = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ """ x shape (B, C, H, W)
+ """
+ x = self.patch_embed(x)
+ x = self.levels(x)
+ # Layer norm done over channel dim only (to NHWC and back)
+ x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x
+
+ def forward(self, x):
+ """ x shape (B, C, H, W)
+ """
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ return self.head(x)
+
+
+def _init_nest_weights(module: nn.Module, name: str = '', head_bias: float = 0.):
+ """ NesT weight initialization
+ Can replicate the official Jax initialization; otherwise follows vision_transformer.py
+ """
+ if isinstance(module, nn.Linear):
+ if name.startswith('head'):
+ trunc_normal_(module.weight, std=.02, a=-2, b=2)
+ nn.init.constant_(module.bias, head_bias)
+ else:
+ trunc_normal_(module.weight, std=.02, a=-2, b=2)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Conv2d):
+ trunc_normal_(module.weight, std=.02, a=-2, b=2)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+ nn.init.zeros_(module.bias)
+ nn.init.ones_(module.weight)
+
+
+def resize_pos_embed(posemb, posemb_new):
+ """
+ Rescale the grid of position embeddings when loading from state_dict
+ Expected shape of position embeddings is (1, T, N, C), and considers only square images
+ """
+ _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
+ seq_length_old = posemb.shape[2]
+ num_blocks_new, seq_length_new = posemb_new.shape[1:3]
+ size_new = int(math.sqrt(num_blocks_new*seq_length_new))
+ # First change to (1, C, H, W)
+ posemb = deblockify(posemb, int(math.sqrt(seq_length_old))).permute(0, 3, 1, 2)
+ posemb = F.interpolate(posemb, size=[size_new, size_new], mode='bicubic', align_corners=False)
+ # Now change to new (1, T, N, C)
+ posemb = blockify(posemb.permute(0, 2, 3, 1), int(math.sqrt(seq_length_new)))
+ return posemb
+
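+# Illustrative shape example (assumes NesT defaults: patch size 4, 3 levels, 16 first-level
+# blocks). Resizing a 224px checkpoint for a 384px model changes the first-level per-block
+# sequence length from 196 (14x14) to 576 (24x24):
+#   resize_pos_embed(posemb_224, posemb_384)   # (1, 16, 196, C) -> (1, 16, 576, C)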
+
+def checkpoint_filter_fn(state_dict, model):
+ """ resize positional embeddings of pretrained weights """
+ pos_embed_keys = [k for k in state_dict.keys() if k.startswith('pos_embed_')]
+ for k in pos_embed_keys:
+ if state_dict[k].shape != getattr(model, k).shape:
+ state_dict[k] = resize_pos_embed(state_dict[k], getattr(model, k))
+ return state_dict
+
+
+def _create_nest(variant, pretrained=False, default_cfg=None, **kwargs):
+ default_cfg = default_cfg or default_cfgs[variant]
+ model = build_model_with_cfg(
+ Nest, variant, pretrained,
+ default_cfg=default_cfg,
+ feature_cfg=dict(out_indices=(0, 1, 2), flatten_sequential=True),
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+
+ return model
+
+
+@register_model
+def nest_base(pretrained=False, **kwargs):
+ """ Nest-B @ 224x224
+ """
+ model_kwargs = dict(
+ embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
+ model = _create_nest('nest_base', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def nest_small(pretrained=False, **kwargs):
+ """ Nest-S @ 224x224
+ """
+ model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
+ model = _create_nest('nest_small', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def nest_tiny(pretrained=False, **kwargs):
+ """ Nest-T @ 224x224
+ """
+ model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
+ model = _create_nest('nest_tiny', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def jx_nest_base(pretrained=False, **kwargs):
+ """ Nest-B @ 224x224, Pretrained weights converted from official Jax impl.
+ """
+ kwargs['pad_type'] = 'same'
+ model_kwargs = dict(embed_dims=(128, 256, 512), num_heads=(4, 8, 16), depths=(2, 2, 20), **kwargs)
+ model = _create_nest('jx_nest_base', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def jx_nest_small(pretrained=False, **kwargs):
+ """ Nest-S @ 224x224, Pretrained weights converted from official Jax impl.
+ """
+ kwargs['pad_type'] = 'same'
+ model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 20), **kwargs)
+ model = _create_nest('jx_nest_small', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def jx_nest_tiny(pretrained=False, **kwargs):
+ """ Nest-T @ 224x224, Pretrained weights converted from official Jax impl.
+ """
+ kwargs['pad_type'] = 'same'
+ model_kwargs = dict(embed_dims=(96, 192, 384), num_heads=(3, 6, 12), depths=(2, 2, 8), **kwargs)
+ model = _create_nest('jx_nest_tiny', pretrained=pretrained, **model_kwargs)
+ return model
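+
+
+# Example usage (illustrative sketch; assumes the surrounding timm package and model registry):
+#   import timm, torch
+#   model = timm.create_model('nest_tiny', pretrained=False, num_classes=10)  # custom head size
+#   logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 10)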
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
new file mode 100644
index 0000000..973cbd6
--- /dev/null
+++ b/timm/models/nfnet.py
@@ -0,0 +1,968 @@
+""" Normalization Free Nets. NFNet, NF-RegNet, NF-ResNet (pre-activation) Models
+
+Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+
+Paper: `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+
+Official Deepmind JAX code: https://github.com/deepmind/deepmind-research/tree/master/nfnets
+
+Status:
+* These models are a work in progress, experiments ongoing.
+* Pretrained weights for two models so far, more to come.
+* Model details updated to more closely match the official JAX code now that it's released
+* NF-ResNet, NF-RegNet-B, and NFNet-F models supported
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+import math
+from dataclasses import dataclass, field
+from collections import OrderedDict
+from typing import Tuple, Optional
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_module
+from .helpers import build_model_with_cfg
+from .registry import register_model
+from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, ScaledStdConv2dSame,\
+ get_act_layer, get_act_fn, get_attn, make_divisible
+
+
+def _dcfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.9, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv1', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ dm_nfnet_f0=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f0-604f9c3a.pth',
+ pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), crop_pct=.9),
+ dm_nfnet_f1=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f1-fc540f82.pth',
+ pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320), crop_pct=0.91),
+ dm_nfnet_f2=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f2-89875923.pth',
+ pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352), crop_pct=0.92),
+ dm_nfnet_f3=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f3-d74ab3aa.pth',
+ pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416), crop_pct=0.94),
+ dm_nfnet_f4=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f4-0ac5b10b.pth',
+ pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512), crop_pct=0.951),
+ dm_nfnet_f5=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f5-ecb20ab1.pth',
+ pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544), crop_pct=0.954),
+ dm_nfnet_f6=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-dnf-weights/dm_nfnet_f6-e0f12116.pth',
+ pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576), crop_pct=0.956),
+
+ nfnet_f0=_dcfg(
+ url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+ nfnet_f1=_dcfg(
+ url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
+ nfnet_f2=_dcfg(
+ url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
+ nfnet_f3=_dcfg(
+ url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
+ nfnet_f4=_dcfg(
+ url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
+ nfnet_f5=_dcfg(
+ url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
+ nfnet_f6=_dcfg(
+ url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
+ nfnet_f7=_dcfg(
+ url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
+
+ nfnet_f0s=_dcfg(
+ url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256)),
+ nfnet_f1s=_dcfg(
+ url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 320, 320)),
+ nfnet_f2s=_dcfg(
+ url='', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 352, 352)),
+ nfnet_f3s=_dcfg(
+ url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 416, 416)),
+ nfnet_f4s=_dcfg(
+ url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 512, 512)),
+ nfnet_f5s=_dcfg(
+ url='', pool_size=(13, 13), input_size=(3, 416, 416), test_input_size=(3, 544, 544)),
+ nfnet_f6s=_dcfg(
+ url='', pool_size=(14, 14), input_size=(3, 448, 448), test_input_size=(3, 576, 576)),
+ nfnet_f7s=_dcfg(
+ url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),
+
+ nfnet_l0=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth',
+ pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
+ eca_nfnet_l0=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth',
+ hf_hub='timm/eca_nfnet_l0',
+ pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
+ eca_nfnet_l1=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
+ pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
+ eca_nfnet_l2=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth',
+ pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
+ eca_nfnet_l3=_dcfg(
+ url='',
+ pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0),
+
+ nf_regnet_b0=_dcfg(
+ url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
+ nf_regnet_b1=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_regnet_b1_256_ra2-ad85cfef.pth',
+ pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), first_conv='stem.conv'), # NOT to paper spec
+ nf_regnet_b2=_dcfg(
+ url='', pool_size=(8, 8), input_size=(3, 240, 240), test_input_size=(3, 272, 272), first_conv='stem.conv'),
+ nf_regnet_b3=_dcfg(
+ url='', pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 320, 320), first_conv='stem.conv'),
+ nf_regnet_b4=_dcfg(
+ url='', pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), first_conv='stem.conv'),
+ nf_regnet_b5=_dcfg(
+ url='', pool_size=(12, 12), input_size=(3, 384, 384), test_input_size=(3, 456, 456), first_conv='stem.conv'),
+
+ nf_resnet26=_dcfg(url='', first_conv='stem.conv'),
+ nf_resnet50=_dcfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nf_resnet50_ra2-9f236009.pth',
+ pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 288, 288), crop_pct=0.94, first_conv='stem.conv'),
+ nf_resnet101=_dcfg(url='', first_conv='stem.conv'),
+
+ nf_seresnet26=_dcfg(url='', first_conv='stem.conv'),
+ nf_seresnet50=_dcfg(url='', first_conv='stem.conv'),
+ nf_seresnet101=_dcfg(url='', first_conv='stem.conv'),
+
+ nf_ecaresnet26=_dcfg(url='', first_conv='stem.conv'),
+ nf_ecaresnet50=_dcfg(url='', first_conv='stem.conv'),
+ nf_ecaresnet101=_dcfg(url='', first_conv='stem.conv'),
+)
+
+
+@dataclass
+class NfCfg:
+ depths: Tuple[int, int, int, int]
+ channels: Tuple[int, int, int, int]
+ alpha: float = 0.2
+ stem_type: str = '3x3'
+ stem_chs: Optional[int] = None
+ group_size: Optional[int] = None
+ attn_layer: Optional[str] = None
+ attn_kwargs: dict = None
+ attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used
+ width_factor: float = 1.0
+ bottle_ratio: float = 0.5
+ num_features: int = 0 # num out_channels for final conv, no final_conv if 0
+ ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
+ reg: bool = False # enables EfficientNet-like options used in RegNet variants, expand from in_chs, se in middle
+ extra_conv: bool = False # extra 3x3 bottleneck convolution for NFNet models
+ gamma_in_act: bool = False
+ same_padding: bool = False
+ std_conv_eps: float = 1e-5
+ skipinit: bool = False # disabled by default, non-trivial performance impact
+ zero_init_fc: bool = False
+ act_layer: str = 'silu'
+
+
+def _nfres_cfg(
+ depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
+ attn_kwargs = attn_kwargs or {}
+ cfg = NfCfg(
+ depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
+ group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+ return cfg
+
+
+def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
+ num_features = 1280 * channels[-1] // 440
+ attn_kwargs = dict(rd_ratio=0.5)
+ cfg = NfCfg(
+ depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
+ num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
+ return cfg
+
+
+def _nfnet_cfg(
+ depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
+ act_layer='gelu', attn_layer='se', attn_kwargs=None):
+ num_features = int(channels[-1] * feat_mult)
+ attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
+ cfg = NfCfg(
+ depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
+ bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
+ attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+ return cfg
+
+
+def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
+ cfg = NfCfg(
+ depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
+ bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
+ num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
+ return cfg
+
+
+model_cfgs = dict(
+ # NFNet-F models w/ GELU compatible with DeepMind weights
+ dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),
+ dm_nfnet_f1=_dm_nfnet_cfg(depths=(2, 4, 12, 6)),
+ dm_nfnet_f2=_dm_nfnet_cfg(depths=(3, 6, 18, 9)),
+ dm_nfnet_f3=_dm_nfnet_cfg(depths=(4, 8, 24, 12)),
+ dm_nfnet_f4=_dm_nfnet_cfg(depths=(5, 10, 30, 15)),
+ dm_nfnet_f5=_dm_nfnet_cfg(depths=(6, 12, 36, 18)),
+ dm_nfnet_f6=_dm_nfnet_cfg(depths=(7, 14, 42, 21)),
+
+ # NFNet-F models w/ GELU (I will likely deprecate/remove these models and just keep dm_ ver for GELU)
+ nfnet_f0=_nfnet_cfg(depths=(1, 2, 6, 3)),
+ nfnet_f1=_nfnet_cfg(depths=(2, 4, 12, 6)),
+ nfnet_f2=_nfnet_cfg(depths=(3, 6, 18, 9)),
+ nfnet_f3=_nfnet_cfg(depths=(4, 8, 24, 12)),
+ nfnet_f4=_nfnet_cfg(depths=(5, 10, 30, 15)),
+ nfnet_f5=_nfnet_cfg(depths=(6, 12, 36, 18)),
+ nfnet_f6=_nfnet_cfg(depths=(7, 14, 42, 21)),
+ nfnet_f7=_nfnet_cfg(depths=(8, 16, 48, 24)),
+
+ # NFNet-F models w/ SiLU (much faster in PyTorch)
+ nfnet_f0s=_nfnet_cfg(depths=(1, 2, 6, 3), act_layer='silu'),
+ nfnet_f1s=_nfnet_cfg(depths=(2, 4, 12, 6), act_layer='silu'),
+ nfnet_f2s=_nfnet_cfg(depths=(3, 6, 18, 9), act_layer='silu'),
+ nfnet_f3s=_nfnet_cfg(depths=(4, 8, 24, 12), act_layer='silu'),
+ nfnet_f4s=_nfnet_cfg(depths=(5, 10, 30, 15), act_layer='silu'),
+ nfnet_f5s=_nfnet_cfg(depths=(6, 12, 36, 18), act_layer='silu'),
+ nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
+ nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),
+
+ # Experimental 'light' versions of NFNet-F that are a little leaner
+ nfnet_l0=_nfnet_cfg(
+ depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
+ attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
+ eca_nfnet_l0=_nfnet_cfg(
+ depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
+ attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
+ eca_nfnet_l1=_nfnet_cfg(
+ depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25,
+ attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
+ eca_nfnet_l2=_nfnet_cfg(
+ depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25,
+ attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
+ eca_nfnet_l3=_nfnet_cfg(
+ depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25,
+ attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
+
+ # EffNet influenced RegNet defs.
+ # NOTE: These aren't quite the official versions; ch_div=1 must be set for exact channel counts. I round to ch_div=8.
+ nf_regnet_b0=_nfreg_cfg(depths=(1, 3, 6, 6)),
+ nf_regnet_b1=_nfreg_cfg(depths=(2, 4, 7, 7)),
+ nf_regnet_b2=_nfreg_cfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488)),
+ nf_regnet_b3=_nfreg_cfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528)),
+ nf_regnet_b4=_nfreg_cfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616)),
+ nf_regnet_b5=_nfreg_cfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704)),
+ # FIXME add B6-B8
+
+ # ResNet (preact, D style deep stem/avg down) defs
+ nf_resnet26=_nfres_cfg(depths=(2, 2, 2, 2)),
+ nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
+ nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
+
+ nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
+ nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
+ nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
+
+ nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
+ nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
+ nf_ecaresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='eca', attn_kwargs=dict()),
+
+)
+
+
+class GammaAct(nn.Module):
+ def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
+ super().__init__()
+ self.act_fn = get_act_fn(act_type)
+ self.gamma = gamma
+ self.inplace = inplace
+
+ def forward(self, x):
+ return self.act_fn(x, inplace=self.inplace).mul_(self.gamma)
+
+
+def act_with_gamma(act_type, gamma: float = 1.):
+ def _create(inplace=False):
+ return GammaAct(act_type, gamma=gamma, inplace=inplace)
+ return _create
+
+
+class DownsampleAvg(nn.Module):
+ def __init__(
+ self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
+ """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
+ super(DownsampleAvg, self).__init__()
+ avg_stride = stride if dilation == 1 else 1
+ if stride > 1 or dilation > 1:
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+ else:
+ self.pool = nn.Identity()
+ self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
+
+ def forward(self, x):
+ return self.conv(self.pool(x))
+
+
+@register_notrace_module # reason: mul_ causes FX to drop a relevant node. https://github.com/pytorch/pytorch/issues/68301
+class NormFreeBlock(nn.Module):
+ """Normalization-Free pre-activation block.
+ """
+
+ def __init__(
+ self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
+ alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
+ skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
+ super().__init__()
+ first_dilation = first_dilation or dilation
+ out_chs = out_chs or in_chs
+ # RegNet variants scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
+ mid_chs = make_divisible(in_chs * bottle_ratio if reg else out_chs * bottle_ratio, ch_div)
+ groups = 1 if not group_size else mid_chs // group_size
+ if group_size and group_size % ch_div == 0:
+ mid_chs = group_size * groups # snap mid_chs to a multiple of group_size (only safe when group_size is divisible by ch_div)
+ self.alpha = alpha
+ self.beta = beta
+ self.attn_gain = attn_gain
+
+ if in_chs != out_chs or stride != 1 or dilation != first_dilation:
+ self.downsample = DownsampleAvg(
+ in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
+ else:
+ self.downsample = None
+
+ self.act1 = act_layer()
+ self.conv1 = conv_layer(in_chs, mid_chs, 1)
+ self.act2 = act_layer(inplace=True)
+ self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+ if extra_conv:
+ self.act2b = act_layer(inplace=True)
+ self.conv2b = conv_layer(mid_chs, mid_chs, 3, stride=1, dilation=dilation, groups=groups)
+ else:
+ self.act2b = None
+ self.conv2b = None
+ if reg and attn_layer is not None:
+ self.attn = attn_layer(mid_chs) # RegNet blocks apply attn btw conv2 & 3
+ else:
+ self.attn = None
+ self.act3 = act_layer()
+ self.conv3 = conv_layer(mid_chs, out_chs, 1, gain_init=1. if skipinit else 0.)
+ if not reg and attn_layer is not None:
+ self.attn_last = attn_layer(out_chs) # ResNet blocks apply attn after conv3
+ else:
+ self.attn_last = None
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
+
+ def forward(self, x):
+ out = self.act1(x) * self.beta
+
+ # shortcut branch
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(out)
+
+ # residual branch
+ out = self.conv1(out)
+ out = self.conv2(self.act2(out))
+ if self.conv2b is not None:
+ out = self.conv2b(self.act2b(out))
+ if self.attn is not None:
+ out = self.attn_gain * self.attn(out)
+ out = self.conv3(self.act3(out))
+ if self.attn_last is not None:
+ out = self.attn_gain * self.attn_last(out)
+ out = self.drop_path(out)
+
+ if self.skipinit_gain is not None:
+ out.mul_(self.skipinit_gain) # this slows things down more than expected, TBD
+ out = out * self.alpha + shortcut
+ return out
+
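+# Note on the alpha/beta scaling in the block above: as in the NF-ResNet/NFNet papers, the
+# residual branch input is pre-scaled by beta (approximately 1 / expected std of the incoming
+# activations) and the branch output is scaled by alpha before the skip connection is added,
+# keeping signal variance controlled without normalization layers; per-block beta values are
+# supplied by the enclosing NormFreeNet from the analytically tracked variance.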
+
+def create_stem(in_chs, out_chs, stem_type='', conv_layer=None, act_layer=None, preact_feature=True):
+ stem_stride = 2
+ stem_feature = dict(num_chs=out_chs, reduction=2, module='stem.conv')
+ stem = OrderedDict()
+ assert stem_type in ('', 'deep', 'deep_tiered', 'deep_quad', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
+ if 'deep' in stem_type:
+ if 'quad' in stem_type:
+ # 4 deep conv stack as in NFNet-F models
+            assert 'pool' not in stem_type
+ stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
+ strides = (2, 1, 1, 2)
+ stem_stride = 4
+ stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv3')
+ else:
+ if 'tiered' in stem_type:
+ stem_chs = (3 * out_chs // 8, out_chs // 2, out_chs) # 'T' resnets in resnet.py
+ else:
+ stem_chs = (out_chs // 2, out_chs // 2, out_chs) # 'D' ResNets
+ strides = (2, 1, 1)
+ stem_feature = dict(num_chs=out_chs // 2, reduction=2, module='stem.conv2')
+ last_idx = len(stem_chs) - 1
+ for i, (c, s) in enumerate(zip(stem_chs, strides)):
+ stem[f'conv{i + 1}'] = conv_layer(in_chs, c, kernel_size=3, stride=s)
+ if i != last_idx:
+ stem[f'act{i + 2}'] = act_layer(inplace=True)
+ in_chs = c
+ elif '3x3' in stem_type:
+ # 3x3 stem conv as in RegNet
+ stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
+ else:
+ # 7x7 stem conv as in ResNet
+ stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
+
+ if 'pool' in stem_type:
+ stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
+ stem_stride = 4
+
+ return nn.Sequential(stem), stem_stride, stem_feature
+
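+
+# Shape-check sketch for create_stem (illustrative only). A padded plain nn.Conv2d stands in for
+# the ScaledStdConv2d partial that NormFreeNet passes in; the expected shapes follow from the
+# 'deep_quad' branch above (channels out_chs//8 -> //4 -> //2 -> out_chs, strides 2, 1, 1, 2).
+def _demo_create_stem():
+    def conv(i, o, kernel_size, stride=1):
+        return nn.Conv2d(i, o, kernel_size, stride=stride, padding=kernel_size // 2)
+    stem, stem_stride, stem_feat = create_stem(3, 64, 'deep_quad', conv_layer=conv, act_layer=nn.ReLU)
+    out = stem(torch.randn(1, 3, 224, 224))
+    assert out.shape == (1, 64, 56, 56) and stem_stride == 4 and stem_feat['module'] == 'stem.conv3'
+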
+
+# from https://github.com/deepmind/deepmind-research/tree/master/nfnets
+_nonlin_gamma = dict(
+ identity=1.0,
+ celu=1.270926833152771,
+ elu=1.2716004848480225,
+ gelu=1.7015043497085571,
+ leaky_relu=1.70590341091156,
+ log_sigmoid=1.9193484783172607,
+ log_softmax=1.0002083778381348,
+ relu=1.7139588594436646,
+ relu6=1.7131484746932983,
+ selu=1.0008515119552612,
+ sigmoid=4.803835391998291,
+ silu=1.7881293296813965,
+ softsign=2.338853120803833,
+ softplus=1.9203323125839233,
+ tanh=1.5939117670059204,
+)
+
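+
+# The constants above come from the DeepMind nfnets reference code; they appear to be (close to)
+# 1 / std(act(x)) for x ~ N(0, 1), i.e. the factor that restores unit variance after the
+# non-linearity. Illustrative empirical check (agrees with the table to within about 1%):
+def _estimate_act_gamma(act_fn, n=1 << 20):
+    x = torch.randn(n)
+    return float(1.0 / act_fn(x).std())
+# e.g. _estimate_act_gamma(torch.relu) ~= 1.71, _estimate_act_gamma(torch.tanh) ~= 1.59
+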
+
+class NormFreeNet(nn.Module):
+ """ Normalization-Free Network
+
+    As described in:
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ and
+ `High-Performance Large-Scale Image Recognition Without Normalization` - https://arxiv.org/abs/2102.06171
+
+ This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
+ the (preact) ResNet models described earlier in the paper.
+
+ There are a few differences:
+ * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
+ this changes channel dim and param counts slightly from the paper models
+ * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
+ impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
+    * a config option `gamma_in_act` can be enabled to apply gamma in each activation instead of in the
+      StdConv as described above. This is slightly slower and numerically different, but matches the official impl.
+    * skipinit is disabled by default; it seems to have a rather drastic impact on GPU memory use and throughput
+      for what it is/does (approx. 8-10% throughput loss).
+ """
+ def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
+ drop_rate=0., drop_path_rate=0.):
+ super().__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
+ conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
+ if cfg.gamma_in_act:
+ act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
+ conv_layer = partial(conv_layer, eps=cfg.std_conv_eps)
+ else:
+ act_layer = get_act_layer(cfg.act_layer)
+ conv_layer = partial(conv_layer, gamma=_nonlin_gamma[cfg.act_layer], eps=cfg.std_conv_eps)
+ attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+
+ stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
+ self.stem, stem_stride, stem_feat = create_stem(
+ in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
+
+ self.feature_info = [stem_feat]
+ drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
+ prev_chs = stem_chs
+ net_stride = stem_stride
+ dilation = 1
+ expected_var = 1.0
+ stages = []
+ for stage_idx, stage_depth in enumerate(cfg.depths):
+ stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
+ if net_stride >= output_stride and stride > 1:
+ dilation *= stride
+ stride = 1
+ net_stride *= stride
+ first_dilation = 1 if dilation in (1, 2) else 2
+
+ blocks = []
+ for block_idx in range(cfg.depths[stage_idx]):
+ first_block = block_idx == 0 and stage_idx == 0
+ out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
+ blocks += [NormFreeBlock(
+ in_chs=prev_chs, out_chs=out_chs,
+ alpha=cfg.alpha,
+ beta=1. / expected_var ** 0.5,
+ stride=stride if block_idx == 0 else 1,
+ dilation=dilation,
+ first_dilation=first_dilation,
+ group_size=cfg.group_size,
+ bottle_ratio=1. if cfg.reg and first_block else cfg.bottle_ratio,
+ ch_div=cfg.ch_div,
+ reg=cfg.reg,
+ extra_conv=cfg.extra_conv,
+ skipinit=cfg.skipinit,
+ attn_layer=attn_layer,
+ attn_gain=cfg.attn_gain,
+ act_layer=act_layer,
+ conv_layer=conv_layer,
+ drop_path_rate=drop_path_rates[stage_idx][block_idx],
+ )]
+ if block_idx == 0:
+ expected_var = 1. # expected var is reset after first block of each stage
+ expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance
+ first_dilation = dilation
+ prev_chs = out_chs
+ self.feature_info += [dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')]
+ stages += [nn.Sequential(*blocks)]
+ self.stages = nn.Sequential(*stages)
+
+ if cfg.num_features:
+ # The paper NFRegNet models have an EfficientNet-like final head convolution.
+ self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
+ self.final_conv = conv_layer(prev_chs, self.num_features, 1)
+            self.feature_info[-1] = dict(num_chs=self.num_features, reduction=net_stride, module='final_conv')
+ else:
+ self.num_features = prev_chs
+ self.final_conv = nn.Identity()
+ self.final_act = act_layer(inplace=cfg.num_features > 0)
+
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ for n, m in self.named_modules():
+ if 'fc' in n and isinstance(m, nn.Linear):
+ if cfg.zero_init_fc:
+ nn.init.zeros_(m.weight)
+ else:
+ nn.init.normal_(m.weight, 0., .01)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.stages(x)
+ x = self.final_conv(x)
+ x = self.final_act(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_normfreenet(variant, pretrained=False, **kwargs):
+ model_cfg = model_cfgs[variant]
+ feature_cfg = dict(flatten_sequential=True)
+ return build_model_with_cfg(
+ NormFreeNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=model_cfg,
+ feature_cfg=feature_cfg,
+ **kwargs)
+
+
+@register_model
+def dm_nfnet_f0(pretrained=False, **kwargs):
+ """ NFNet-F0 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f1(pretrained=False, **kwargs):
+ """ NFNet-F1 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f2(pretrained=False, **kwargs):
+ """ NFNet-F2 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f3(pretrained=False, **kwargs):
+ """ NFNet-F3 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f4(pretrained=False, **kwargs):
+ """ NFNet-F4 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f5(pretrained=False, **kwargs):
+ """ NFNet-F5 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f5', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def dm_nfnet_f6(pretrained=False, **kwargs):
+ """ NFNet-F6 (DeepMind weight compatible)
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('dm_nfnet_f6', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f0(pretrained=False, **kwargs):
+ """ NFNet-F0
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f1(pretrained=False, **kwargs):
+ """ NFNet-F1
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f2(pretrained=False, **kwargs):
+ """ NFNet-F2
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f3(pretrained=False, **kwargs):
+ """ NFNet-F3
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f4(pretrained=False, **kwargs):
+ """ NFNet-F4
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f5(pretrained=False, **kwargs):
+ """ NFNet-F5
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f5', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f6(pretrained=False, **kwargs):
+ """ NFNet-F6
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f6', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f7(pretrained=False, **kwargs):
+ """ NFNet-F7
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f7', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f0s(pretrained=False, **kwargs):
+ """ NFNet-F0 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f0s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f1s(pretrained=False, **kwargs):
+ """ NFNet-F1 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f1s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f2s(pretrained=False, **kwargs):
+ """ NFNet-F2 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f2s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f3s(pretrained=False, **kwargs):
+ """ NFNet-F3 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f3s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f4s(pretrained=False, **kwargs):
+ """ NFNet-F4 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f4s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f5s(pretrained=False, **kwargs):
+ """ NFNet-F5 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f5s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f6s(pretrained=False, **kwargs):
+ """ NFNet-F6 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f6s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_f7s(pretrained=False, **kwargs):
+ """ NFNet-F7 w/ SiLU
+ `High-Performance Large-Scale Image Recognition Without Normalization`
+ - https://arxiv.org/abs/2102.06171
+ """
+ return _create_normfreenet('nfnet_f7s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nfnet_l0(pretrained=False, **kwargs):
+ """ NFNet-L0b w/ SiLU
+ My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
+ """
+ return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_nfnet_l0(pretrained=False, **kwargs):
+ """ ECA-NFNet-L0 w/ SiLU
+ My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
+ """
+ return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_nfnet_l1(pretrained=False, **kwargs):
+ """ ECA-NFNet-L1 w/ SiLU
+ My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
+ """
+ return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_nfnet_l2(pretrained=False, **kwargs):
+ """ ECA-NFNet-L2 w/ SiLU
+ My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
+ """
+ return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_nfnet_l3(pretrained=False, **kwargs):
+ """ ECA-NFNet-L3 w/ SiLU
+ My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
+ """
+ return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b0(pretrained=False, **kwargs):
+ """ Normalization-Free RegNet-B0
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b1(pretrained=False, **kwargs):
+ """ Normalization-Free RegNet-B1
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b2(pretrained=False, **kwargs):
+ """ Normalization-Free RegNet-B2
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b3(pretrained=False, **kwargs):
+ """ Normalization-Free RegNet-B3
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b4(pretrained=False, **kwargs):
+ """ Normalization-Free RegNet-B4
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b5(pretrained=False, **kwargs):
+ """ Normalization-Free RegNet-B5
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_resnet26(pretrained=False, **kwargs):
+ """ Normalization-Free ResNet-26
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_resnet50(pretrained=False, **kwargs):
+ """ Normalization-Free ResNet-50
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_resnet101(pretrained=False, **kwargs):
+ """ Normalization-Free ResNet-101
+ `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+ - https://arxiv.org/abs/2101.08692
+ """
+ return _create_normfreenet('nf_resnet101', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_seresnet26(pretrained=False, **kwargs):
+ """ Normalization-Free SE-ResNet26
+ """
+ return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_seresnet50(pretrained=False, **kwargs):
+ """ Normalization-Free SE-ResNet50
+ """
+ return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_seresnet101(pretrained=False, **kwargs):
+ """ Normalization-Free SE-ResNet101
+ """
+ return _create_normfreenet('nf_seresnet101', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_ecaresnet26(pretrained=False, **kwargs):
+ """ Normalization-Free ECA-ResNet26
+ """
+ return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_ecaresnet50(pretrained=False, **kwargs):
+ """ Normalization-Free ECA-ResNet50
+ """
+ return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_ecaresnet101(pretrained=False, **kwargs):
+ """ Normalization-Free ECA-ResNet101
+ """
+ return _create_normfreenet('nf_ecaresnet101', pretrained=pretrained, **kwargs)
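+
+
+if __name__ == '__main__':
+    # Minimal smoke-test sketch (illustrative): the registered entry points above are plain
+    # callables, so a model can be built directly without going through timm.create_model.
+    _model = dm_nfnet_f0(pretrained=False)
+    _out = _model(torch.randn(1, 3, 224, 224))
+    print(_out.shape)  # expected: torch.Size([1, 1000])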
diff --git a/timm/models/pit.py b/timm/models/pit.py
new file mode 100644
index 0000000..460824e
--- /dev/null
+++ b/timm/models/pit.py
@@ -0,0 +1,384 @@
+""" Pooling-based Vision Transformer (PiT) in PyTorch
+
+A PyTorch implementation of Pooling-based Vision Transformers as described in
+'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302
+
+This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below.
+
+Modifications for timm by / Copyright 2020 Ross Wightman
+"""
+# PiT
+# Copyright 2021-present NAVER Corp.
+# Apache License v2.0
+
+import math
+import re
+from copy import deepcopy
+from functools import partial
+from typing import Tuple
+
+import torch
+from torch import nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import trunc_normal_, to_2tuple
+from .registry import register_model
+from .vision_transformer import Block
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.conv', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+    # PiT models (weights hosted on the timm GitHub releases)
+ 'pit_ti_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'),
+ 'pit_xs_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'),
+ 'pit_s_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'),
+ 'pit_b_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
+ 'pit_ti_distilled_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
+ classifier=('head', 'head_dist')),
+ 'pit_xs_distilled_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
+ classifier=('head', 'head_dist')),
+ 'pit_s_distilled_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
+ classifier=('head', 'head_dist')),
+ 'pit_b_distilled_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
+ classifier=('head', 'head_dist')),
+}
+
+
+class SequentialTuple(nn.Sequential):
+    """ This module exists to work around TorchScript typing limitations so a Sequential can pass a (Tensor, Tensor) tuple through its children."""
+ def __init__(self, *args):
+ super(SequentialTuple, self).__init__(*args)
+
+ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ for module in self:
+ x = module(x)
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
+ super(Transformer, self).__init__()
+ self.layers = nn.ModuleList([])
+ embed_dim = base_dim * heads
+
+ self.blocks = nn.Sequential(*[
+ Block(
+ dim=embed_dim,
+ num_heads=heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=True,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=drop_path_prob[i],
+ norm_layer=partial(nn.LayerNorm, eps=1e-6)
+ )
+ for i in range(depth)])
+
+ self.pool = pool
+
+ def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ x, cls_tokens = x
+ B, C, H, W = x.shape
+ token_length = cls_tokens.shape[1]
+
+ x = x.flatten(2).transpose(1, 2)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = self.blocks(x)
+
+ cls_tokens = x[:, :token_length]
+ x = x[:, token_length:]
+ x = x.transpose(1, 2).reshape(B, C, H, W)
+
+ if self.pool is not None:
+ x, cls_tokens = self.pool(x, cls_tokens)
+ return x, cls_tokens
+
+
+class ConvHeadPooling(nn.Module):
+ def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
+ super(ConvHeadPooling, self).__init__()
+
+ self.conv = nn.Conv2d(
+ in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride,
+ padding_mode=padding_mode, groups=in_feature)
+ self.fc = nn.Linear(in_feature, out_feature)
+
+ def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ x = self.conv(x)
+ cls_token = self.fc(cls_token)
+
+ return x, cls_token
+
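+
+# Shape sketch (illustrative only): ConvHeadPooling with stride=2 uses a depthwise
+# (groups=in_feature) 3x3 conv with padding 1, halving H and W while changing the channel dim;
+# the class token is projected to the new width by the fc. 64 -> 128 channels and 27x27 -> 14x14
+# matches the first PiT-Ti stage transition at 224px input.
+def _pooling_shape_demo():
+    pool = ConvHeadPooling(64, 128, stride=2)
+    x, cls_token = pool(torch.randn(2, 64, 27, 27), torch.randn(2, 1, 64))
+    return x.shape, cls_token.shape  # (2, 128, 14, 14), (2, 1, 128)
+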
+
+class ConvEmbedding(nn.Module):
+ def __init__(self, in_channels, out_channels, patch_size, stride, padding):
+ super(ConvEmbedding, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class PoolingVisionTransformer(nn.Module):
+ """ Pooling-based Vision Transformer
+
+    A PyTorch implementation of 'Rethinking Spatial Dimensions of Vision Transformers'
+ - https://arxiv.org/abs/2103.16302
+ """
+ def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
+ mlp_ratio, num_classes=1000, in_chans=3, distilled=False,
+ attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
+ super(PoolingVisionTransformer, self).__init__()
+
+ padding = 0
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1)
+ width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1)
+
+ self.base_dims = base_dims
+ self.heads = heads
+ self.num_classes = num_classes
+ self.num_tokens = 2 if distilled else 1
+
+ self.patch_size = patch_size
+ self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width))
+ self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)
+
+ self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ transformers = []
+ # stochastic depth decay rule
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
+ for stage in range(len(depth)):
+ pool = None
+ if stage < len(heads) - 1:
+ pool = ConvHeadPooling(
+ base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2)
+ transformers += [Transformer(
+ base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool,
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage])
+ ]
+ self.transformers = SequentialTuple(*transformers)
+ self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
+ self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
+
+ # Classifier head
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.pos_embed, std=.02)
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ if self.head_dist is not None:
+ return self.head, self.head_dist
+ else:
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ if self.head_dist is not None:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = self.pos_drop(x + self.pos_embed)
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
+ x, cls_tokens = self.transformers((x, cls_tokens))
+ cls_tokens = self.norm(cls_tokens)
+ if self.head_dist is not None:
+ return cls_tokens[:, 0], cls_tokens[:, 1]
+ else:
+ return cls_tokens[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.head_dist is not None:
+ x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
+ if self.training and not torch.jit.is_scripting():
+ return x, x_dist
+ else:
+ return (x + x_dist) / 2
+ else:
+ return self.head(x)
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """ Preprocess checkpoint state dict: remap original PiT keys to this implementation's module names. """
+ out_dict = {}
+ p_blocks = re.compile(r'pools\.(\d)\.')
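+    # e.g. 'pools.0.conv.weight' in the official PiT checkpoints (which keep the pooling layers
+    # under a top-level `pools` module) becomes 'transformers.0.pool.conv.weight' here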
+ for k, v in state_dict.items():
+ # FIXME need to update resize for PiT impl
+ # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
+ # # To resize pos embedding when using model at different size from pretrained weights
+ # v = resize_pos_embed(v, model.pos_embed)
+ k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k)
+ out_dict[k] = v
+ return out_dict
+
+
+def _create_pit(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ PoolingVisionTransformer, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
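+
+# Usage sketch (illustrative, assuming the helpers fill img_size from the default cfg's input_size):
+#   model = pit_ti_distilled_224(pretrained=False).eval()
+#   logits = model(torch.randn(1, 3, 224, 224))               # eval: mean of head & head_dist -> (1, 1000)
+#   model.train()
+#   logits, logits_dist = model(torch.randn(1, 3, 224, 224))  # train: both heads, for a distillation loss
+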
+
+@register_model
+def pit_b_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=14,
+ stride=7,
+ base_dims=[64, 64, 64],
+ depth=[3, 6, 4],
+ heads=[4, 8, 16],
+ mlp_ratio=4,
+ **kwargs
+ )
+ return _create_pit('pit_b_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_s_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16,
+ stride=8,
+ base_dims=[48, 48, 48],
+ depth=[2, 6, 4],
+ heads=[3, 6, 12],
+ mlp_ratio=4,
+ **kwargs
+ )
+ return _create_pit('pit_s_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_xs_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16,
+ stride=8,
+ base_dims=[48, 48, 48],
+ depth=[2, 6, 4],
+ heads=[2, 4, 8],
+ mlp_ratio=4,
+ **kwargs
+ )
+ return _create_pit('pit_xs_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_ti_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16,
+ stride=8,
+ base_dims=[32, 32, 32],
+ depth=[2, 6, 4],
+ heads=[2, 4, 8],
+ mlp_ratio=4,
+ **kwargs
+ )
+ return _create_pit('pit_ti_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_b_distilled_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=14,
+ stride=7,
+ base_dims=[64, 64, 64],
+ depth=[3, 6, 4],
+ heads=[4, 8, 16],
+ mlp_ratio=4,
+ distilled=True,
+ **kwargs
+ )
+ return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_s_distilled_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16,
+ stride=8,
+ base_dims=[48, 48, 48],
+ depth=[2, 6, 4],
+ heads=[3, 6, 12],
+ mlp_ratio=4,
+ distilled=True,
+ **kwargs
+ )
+ return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_xs_distilled_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16,
+ stride=8,
+ base_dims=[48, 48, 48],
+ depth=[2, 6, 4],
+ heads=[2, 4, 8],
+ mlp_ratio=4,
+ distilled=True,
+ **kwargs
+ )
+ return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs)
+
+
+@register_model
+def pit_ti_distilled_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16,
+ stride=8,
+ base_dims=[32, 32, 32],
+ depth=[2, 6, 4],
+ heads=[2, 4, 8],
+ mlp_ratio=4,
+ distilled=True,
+ **kwargs
+ )
+ return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs)
\ No newline at end of file
diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py
new file mode 100644
index 0000000..9991815
--- /dev/null
+++ b/timm/models/pnasnet.py
@@ -0,0 +1,350 @@
+"""
+ pnasnet5large implementation grabbed from Cadene's pretrained models
+ Additional credit to https://github.com/creafz
+
+ https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
+
+"""
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .helpers import build_model_with_cfg
+from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
+from .registry import register_model
+
+__all__ = ['PNASNet5Large']
+
+default_cfgs = {
+ 'pnasnet5large': {
+ 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
+ 'input_size': (3, 331, 331),
+ 'pool_size': (11, 11),
+ 'crop_pct': 0.911,
+ 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5),
+ 'std': (0.5, 0.5, 0.5),
+ 'num_classes': 1000,
+ 'first_conv': 'conv_0.conv',
+ 'classifier': 'last_linear',
+ 'label_offset': 1, # 1001 classes in pretrained weights
+ },
+}
+
+
+class SeparableConv2d(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
+ super(SeparableConv2d, self).__init__()
+ self.depthwise_conv2d = create_conv2d(
+ in_channels, in_channels, kernel_size=kernel_size,
+ stride=stride, padding=padding, groups=in_channels)
+ self.pointwise_conv2d = create_conv2d(
+ in_channels, out_channels, kernel_size=1, padding=padding)
+
+ def forward(self, x):
+ x = self.depthwise_conv2d(x)
+ x = self.pointwise_conv2d(x)
+ return x
+
+
+class BranchSeparables(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''):
+ super(BranchSeparables, self).__init__()
+ middle_channels = out_channels if stem_cell else in_channels
+ self.act_1 = nn.ReLU()
+ self.separable_1 = SeparableConv2d(
+ in_channels, middle_channels, kernel_size, stride=stride, padding=padding)
+ self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
+ self.act_2 = nn.ReLU()
+ self.separable_2 = SeparableConv2d(
+ middle_channels, out_channels, kernel_size, stride=1, padding=padding)
+ self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)
+
+ def forward(self, x):
+ x = self.act_1(x)
+ x = self.separable_1(x)
+ x = self.bn_sep_1(x)
+ x = self.act_2(x)
+ x = self.separable_2(x)
+ x = self.bn_sep_2(x)
+ return x
+
+
+class ActConvBn(nn.Module):
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
+ super(ActConvBn, self).__init__()
+ self.act = nn.ReLU()
+ self.conv = create_conv2d(
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+ def forward(self, x):
+ x = self.act(x)
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class FactorizedReduction(nn.Module):
+
+ def __init__(self, in_channels, out_channels, padding=''):
+ super(FactorizedReduction, self).__init__()
+ self.act = nn.ReLU()
+ self.path_1 = nn.Sequential(OrderedDict([
+ ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
+ ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
+ ]))
+ self.path_2 = nn.Sequential(OrderedDict([
+ ('pad', nn.ZeroPad2d((-1, 1, -1, 1))), # shift
+ ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
+ ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
+ ]))
+ self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+ def forward(self, x):
+ x = self.act(x)
+ x_path1 = self.path_1(x)
+ x_path2 = self.path_2(x)
+ out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
+ return out
+
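+
+# Shape sketch (illustrative only): both paths average-pool with stride 2 and project to
+# out_channels // 2 (path_2 on a one-pixel-shifted view), so the concatenated output halves the
+# spatial size while producing exactly out_channels, e.g. for the cell_0 transition at 331px input:
+#   fr = FactorizedReduction(270, 216)
+#   fr(torch.randn(1, 270, 83, 83)).shape   # -> torch.Size([1, 216, 42, 42])
+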
+
+class CellBase(nn.Module):
+
+ def cell_forward(self, x_left, x_right):
+ x_comb_iter_0_left = self.comb_iter_0_left(x_left)
+ x_comb_iter_0_right = self.comb_iter_0_right(x_left)
+ x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+ x_comb_iter_1_left = self.comb_iter_1_left(x_right)
+ x_comb_iter_1_right = self.comb_iter_1_right(x_right)
+ x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+ x_comb_iter_2_left = self.comb_iter_2_left(x_right)
+ x_comb_iter_2_right = self.comb_iter_2_right(x_right)
+ x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
+
+ x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2)
+ x_comb_iter_3_right = self.comb_iter_3_right(x_right)
+ x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
+
+ x_comb_iter_4_left = self.comb_iter_4_left(x_left)
+ if self.comb_iter_4_right is not None:
+ x_comb_iter_4_right = self.comb_iter_4_right(x_right)
+ else:
+ x_comb_iter_4_right = x_right
+ x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
+
+ x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+ return x_out
+
+
+class CellStem0(CellBase):
+
+ def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
+ super(CellStem0, self).__init__()
+ self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
+
+ self.comb_iter_0_left = BranchSeparables(
+ in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type)
+ self.comb_iter_0_right = nn.Sequential(OrderedDict([
+ ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
+ ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)),
+ ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
+ ]))
+
+ self.comb_iter_1_left = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type)
+ self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)
+
+ self.comb_iter_2_left = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type)
+ self.comb_iter_2_right = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type)
+
+ self.comb_iter_3_left = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
+ self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(
+ in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type)
+ self.comb_iter_4_right = ActConvBn(
+ out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type)
+
+ def forward(self, x_left):
+ x_right = self.conv_1x1(x_left)
+ x_out = self.cell_forward(x_left, x_right)
+ return x_out
+
+
+class Cell(CellBase):
+
+ def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type='',
+ is_reduction=False, match_prev_layer_dims=False):
+ super(Cell, self).__init__()
+
+ # If `is_reduction` is set to `True` stride 2 is used for
+ # convolution and pooling layers to reduce the spatial size of
+ # the output of a cell approximately by a factor of 2.
+ stride = 2 if is_reduction else 1
+
+        # If `match_prev_layer_dims` is set to `True`,
+        # `FactorizedReduction` is used to reduce the spatial size
+        # of the left input of a cell approximately by a factor of 2.
+ self.match_prev_layer_dimensions = match_prev_layer_dims
+ if match_prev_layer_dims:
+ self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type)
+ else:
+ self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)
+ self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
+
+ self.comb_iter_0_left = BranchSeparables(
+ out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type)
+ self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+ self.comb_iter_1_left = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type)
+ self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+ self.comb_iter_2_left = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type)
+ self.comb_iter_2_right = BranchSeparables(
+ out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type)
+
+ self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
+ self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+ self.comb_iter_4_left = BranchSeparables(
+ out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type)
+ if is_reduction:
+ self.comb_iter_4_right = ActConvBn(
+ out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type)
+ else:
+ self.comb_iter_4_right = None
+
+ def forward(self, x_left, x_right):
+ x_left = self.conv_prev_1x1(x_left)
+ x_right = self.conv_1x1(x_right)
+ x_out = self.cell_forward(x_left, x_right)
+ return x_out
+
+
+class PNASNet5Large(nn.Module):
+ def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg', pad_type=''):
+ super(PNASNet5Large, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ self.num_features = 4320
+ assert output_stride == 32
+
+ self.conv_0 = ConvBnAct(
+ in_chans, 96, kernel_size=3, stride=2, padding=0,
+ norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
+
+ self.cell_stem_0 = CellStem0(
+ in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)
+
+ self.cell_stem_1 = Cell(
+ in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
+ match_prev_layer_dims=True, is_reduction=True)
+ self.cell_0 = Cell(
+ in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
+ match_prev_layer_dims=True)
+ self.cell_1 = Cell(
+ in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+ self.cell_2 = Cell(
+ in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+ self.cell_3 = Cell(
+ in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+
+ self.cell_4 = Cell(
+ in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
+ is_reduction=True)
+ self.cell_5 = Cell(
+ in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
+ match_prev_layer_dims=True)
+ self.cell_6 = Cell(
+ in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
+ self.cell_7 = Cell(
+ in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
+
+ self.cell_8 = Cell(
+ in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
+ is_reduction=True)
+ self.cell_9 = Cell(
+ in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
+ match_prev_layer_dims=True)
+ self.cell_10 = Cell(
+ in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
+ self.cell_11 = Cell(
+ in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
+ self.act = nn.ReLU()
+ self.feature_info = [
+ dict(num_chs=96, reduction=2, module='conv_0'),
+ dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
+ dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
+ dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
+ dict(num_chs=4320, reduction=32, module='act'),
+ ]
+
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def get_classifier(self):
+ return self.last_linear
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x_conv_0 = self.conv_0(x)
+ x_stem_0 = self.cell_stem_0(x_conv_0)
+ x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
+ x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
+ x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
+ x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
+ x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
+ x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
+ x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
+ x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
+ x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
+ x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
+ x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
+ x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
+ x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
+ x = self.act(x_cell_11)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0:
+ x = F.dropout(x, self.drop_rate, training=self.training)
+ x = self.last_linear(x)
+ return x
+
+
+def _create_pnasnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ PNASNet5Large, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
+ **kwargs)
+
+
+@register_model
+def pnasnet5large(pretrained=False, **kwargs):
+ r"""PNASNet-5 model architecture from the
+ `"Progressive Neural Architecture Search"
+    <https://arxiv.org/abs/1712.00559>`_ paper.
+ """
+ model_kwargs = dict(pad_type='same', **kwargs)
+ return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs)
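+
+
+# Usage sketch (illustrative): the default config expects 331x331 inputs and 1000 classes, e.g.
+#   model = pnasnet5large(pretrained=False)
+#   logits = model(torch.randn(1, 3, 331, 331))   # -> shape (1, 1000)
+#   model.reset_classifier(10)                    # reuse the 4320-dim pooled features for 10 classes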
diff --git a/timm/models/pruned/ecaresnet101d_pruned.txt b/timm/models/pruned/ecaresnet101d_pruned.txt
new file mode 100644
index 0000000..2589b2f
--- /dev/null
+++ b/timm/models/pruned/ecaresnet101d_pruned.txt
@@ -0,0 +1 @@
+conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[45, 64, 1, 1]***layer1.0.bn1.weight:[45]***layer1.0.conv2.weight:[25, 45, 3, 3]***layer1.0.bn2.weight:[25]***layer1.0.conv3.weight:[26, 25, 1, 1]***layer1.0.bn3.weight:[26]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[26, 64, 1, 1]***layer1.0.downsample.2.weight:[26]***layer1.1.conv1.weight:[53, 26, 1, 1]***layer1.1.bn1.weight:[53]***layer1.1.conv2.weight:[20, 53, 3, 3]***layer1.1.bn2.weight:[20]***layer1.1.conv3.weight:[26, 20, 1, 1]***layer1.1.bn3.weight:[26]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[60, 26, 1, 1]***layer1.2.bn1.weight:[60]***layer1.2.conv2.weight:[27, 60, 3, 3]***layer1.2.bn2.weight:[27]***layer1.2.conv3.weight:[26, 27, 1, 1]***layer1.2.bn3.weight:[26]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[81, 26, 1, 1]***layer2.0.bn1.weight:[81]***layer2.0.conv2.weight:[24, 81, 3, 3]***layer2.0.bn2.weight:[24]***layer2.0.conv3.weight:[142, 24, 1, 1]***layer2.0.bn3.weight:[142]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[142, 26, 1, 1]***layer2.0.downsample.2.weight:[142]***layer2.1.conv1.weight:[93, 142, 1, 1]***layer2.1.bn1.weight:[93]***layer2.1.conv2.weight:[49, 93, 3, 3]***layer2.1.bn2.weight:[49]***layer2.1.conv3.weight:[142, 49, 1, 1]***layer2.1.bn3.weight:[142]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[102, 142, 1, 1]***layer2.2.bn1.weight:[102]***layer2.2.conv2.weight:[54, 102, 3, 3]***layer2.2.bn2.weight:[54]***layer2.2.conv3.weight:[142, 54, 1, 1]***layer2.2.bn3.weight:[142]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[122, 142, 1, 1]***layer2.3.bn1.weight:[122]***layer2.3.conv2.weight:[78, 122, 3, 3]***layer2.3.bn2.weight:[78]***layer2.3.conv3.weight:[142, 78, 1, 1]***layer2.3.bn3.weight:[142]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[101, 142, 1, 1]***layer3.0.bn1.weight:[101]***layer3.0.conv2.weight:[25, 101, 3, 3]***layer3.0.bn2.weight:[25]***layer3.0.conv3.weight:[278, 25, 1, 1]***layer3.0.bn3.weight:[278]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[278, 142, 1, 1]***layer3.0.downsample.2.weight:[278]***layer3.1.conv1.weight:[239, 278, 1, 1]***layer3.1.bn1.weight:[239]***layer3.1.conv2.weight:[160, 239, 3, 3]***layer3.1.bn2.weight:[160]***layer3.1.conv3.weight:[278, 160, 1, 1]***layer3.1.bn3.weight:[278]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[234, 278, 1, 1]***layer3.2.bn1.weight:[234]***layer3.2.conv2.weight:[156, 234, 3, 3]***layer3.2.bn2.weight:[156]***layer3.2.conv3.weight:[278, 156, 1, 1]***layer3.2.bn3.weight:[278]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[250, 278, 1, 1]***layer3.3.bn1.weight:[250]***layer3.3.conv2.weight:[176, 250, 3, 3]***layer3.3.bn2.weight:[176]***layer3.3.conv3.weight:[278, 176, 1, 1]***layer3.3.bn3.weight:[278]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[253, 278, 1, 1]***layer3.4.bn1.weight:[253]***layer3.4.conv2.weight:[191, 253, 3, 3]***layer3.4.bn2.weight:[191]***layer3.4.conv3.weight:[278, 191, 1, 1]***layer3.4.bn3.weight:[278]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[251, 278, 1, 1]***layer3.5.bn1.weight:[251]***layer3.5.conv2.weight:[175, 251, 3, 3]***layer3.5.bn2.weight:[175]***layer3.5.conv3.weight:[278, 175, 1, 1]***layer3.5.bn3.weight:[278]***layer3.5.se.conv.weight:[1, 1, 5]***layer3.6.conv1.weight:[230, 278, 1, 
1]***layer3.6.bn1.weight:[230]***layer3.6.conv2.weight:[128, 230, 3, 3]***layer3.6.bn2.weight:[128]***layer3.6.conv3.weight:[278, 128, 1, 1]***layer3.6.bn3.weight:[278]***layer3.6.se.conv.weight:[1, 1, 5]***layer3.7.conv1.weight:[244, 278, 1, 1]***layer3.7.bn1.weight:[244]***layer3.7.conv2.weight:[154, 244, 3, 3]***layer3.7.bn2.weight:[154]***layer3.7.conv3.weight:[278, 154, 1, 1]***layer3.7.bn3.weight:[278]***layer3.7.se.conv.weight:[1, 1, 5]***layer3.8.conv1.weight:[244, 278, 1, 1]***layer3.8.bn1.weight:[244]***layer3.8.conv2.weight:[159, 244, 3, 3]***layer3.8.bn2.weight:[159]***layer3.8.conv3.weight:[278, 159, 1, 1]***layer3.8.bn3.weight:[278]***layer3.8.se.conv.weight:[1, 1, 5]***layer3.9.conv1.weight:[238, 278, 1, 1]***layer3.9.bn1.weight:[238]***layer3.9.conv2.weight:[97, 238, 3, 3]***layer3.9.bn2.weight:[97]***layer3.9.conv3.weight:[278, 97, 1, 1]***layer3.9.bn3.weight:[278]***layer3.9.se.conv.weight:[1, 1, 5]***layer3.10.conv1.weight:[244, 278, 1, 1]***layer3.10.bn1.weight:[244]***layer3.10.conv2.weight:[149, 244, 3, 3]***layer3.10.bn2.weight:[149]***layer3.10.conv3.weight:[278, 149, 1, 1]***layer3.10.bn3.weight:[278]***layer3.10.se.conv.weight:[1, 1, 5]***layer3.11.conv1.weight:[253, 278, 1, 1]***layer3.11.bn1.weight:[253]***layer3.11.conv2.weight:[181, 253, 3, 3]***layer3.11.bn2.weight:[181]***layer3.11.conv3.weight:[278, 181, 1, 1]***layer3.11.bn3.weight:[278]***layer3.11.se.conv.weight:[1, 1, 5]***layer3.12.conv1.weight:[245, 278, 1, 1]***layer3.12.bn1.weight:[245]***layer3.12.conv2.weight:[119, 245, 3, 3]***layer3.12.bn2.weight:[119]***layer3.12.conv3.weight:[278, 119, 1, 1]***layer3.12.bn3.weight:[278]***layer3.12.se.conv.weight:[1, 1, 5]***layer3.13.conv1.weight:[255, 278, 1, 1]***layer3.13.bn1.weight:[255]***layer3.13.conv2.weight:[216, 255, 3, 3]***layer3.13.bn2.weight:[216]***layer3.13.conv3.weight:[278, 216, 1, 1]***layer3.13.bn3.weight:[278]***layer3.13.se.conv.weight:[1, 1, 5]***layer3.14.conv1.weight:[256, 278, 1, 1]***layer3.14.bn1.weight:[256]***layer3.14.conv2.weight:[201, 256, 3, 3]***layer3.14.bn2.weight:[201]***layer3.14.conv3.weight:[278, 201, 1, 1]***layer3.14.bn3.weight:[278]***layer3.14.se.conv.weight:[1, 1, 5]***layer3.15.conv1.weight:[253, 278, 1, 1]***layer3.15.bn1.weight:[253]***layer3.15.conv2.weight:[149, 253, 3, 3]***layer3.15.bn2.weight:[149]***layer3.15.conv3.weight:[278, 149, 1, 1]***layer3.15.bn3.weight:[278]***layer3.15.se.conv.weight:[1, 1, 5]***layer3.16.conv1.weight:[254, 278, 1, 1]***layer3.16.bn1.weight:[254]***layer3.16.conv2.weight:[141, 254, 3, 3]***layer3.16.bn2.weight:[141]***layer3.16.conv3.weight:[278, 141, 1, 1]***layer3.16.bn3.weight:[278]***layer3.16.se.conv.weight:[1, 1, 5]***layer3.17.conv1.weight:[256, 278, 1, 1]***layer3.17.bn1.weight:[256]***layer3.17.conv2.weight:[190, 256, 3, 3]***layer3.17.bn2.weight:[190]***layer3.17.conv3.weight:[278, 190, 1, 1]***layer3.17.bn3.weight:[278]***layer3.17.se.conv.weight:[1, 1, 5]***layer3.18.conv1.weight:[256, 278, 1, 1]***layer3.18.bn1.weight:[256]***layer3.18.conv2.weight:[217, 256, 3, 3]***layer3.18.bn2.weight:[217]***layer3.18.conv3.weight:[278, 217, 1, 1]***layer3.18.bn3.weight:[278]***layer3.18.se.conv.weight:[1, 1, 5]***layer3.19.conv1.weight:[255, 278, 1, 1]***layer3.19.bn1.weight:[255]***layer3.19.conv2.weight:[156, 255, 3, 3]***layer3.19.bn2.weight:[156]***layer3.19.conv3.weight:[278, 156, 1, 1]***layer3.19.bn3.weight:[278]***layer3.19.se.conv.weight:[1, 1, 5]***layer3.20.conv1.weight:[256, 278, 1, 1]***layer3.20.bn1.weight:[256]***layer3.20.conv2.weight:[155, 256, 3, 
3]***layer3.20.bn2.weight:[155]***layer3.20.conv3.weight:[278, 155, 1, 1]***layer3.20.bn3.weight:[278]***layer3.20.se.conv.weight:[1, 1, 5]***layer3.21.conv1.weight:[256, 278, 1, 1]***layer3.21.bn1.weight:[256]***layer3.21.conv2.weight:[232, 256, 3, 3]***layer3.21.bn2.weight:[232]***layer3.21.conv3.weight:[278, 232, 1, 1]***layer3.21.bn3.weight:[278]***layer3.21.se.conv.weight:[1, 1, 5]***layer3.22.conv1.weight:[256, 278, 1, 1]***layer3.22.bn1.weight:[256]***layer3.22.conv2.weight:[214, 256, 3, 3]***layer3.22.bn2.weight:[214]***layer3.22.conv3.weight:[278, 214, 1, 1]***layer3.22.bn3.weight:[278]***layer3.22.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[499, 278, 1, 1]***layer4.0.bn1.weight:[499]***layer4.0.conv2.weight:[289, 499, 3, 3]***layer4.0.bn2.weight:[289]***layer4.0.conv3.weight:[2042, 289, 1, 1]***layer4.0.bn3.weight:[2042]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2042, 278, 1, 1]***layer4.0.downsample.2.weight:[2042]***layer4.1.conv1.weight:[512, 2042, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[512, 512, 3, 3]***layer4.1.bn2.weight:[512]***layer4.1.conv3.weight:[2042, 512, 1, 1]***layer4.1.bn3.weight:[2042]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2042, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[502, 512, 3, 3]***layer4.2.bn2.weight:[502]***layer4.2.conv3.weight:[2042, 502, 1, 1]***layer4.2.bn3.weight:[2042]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2042]***layer1_2_conv3_M.weight:[256, 26]***layer2_3_conv3_M.weight:[512, 142]***layer3_22_conv3_M.weight:[1024, 278]***layer4_2_conv3_M.weight:[2048, 2042]
\ No newline at end of file
diff --git a/timm/models/pruned/ecaresnet50d_pruned.txt b/timm/models/pruned/ecaresnet50d_pruned.txt
new file mode 100644
index 0000000..9a8b2bf
--- /dev/null
+++ b/timm/models/pruned/ecaresnet50d_pruned.txt
@@ -0,0 +1 @@
+conv1.0.weight:[32, 3, 3, 3]***conv1.1.weight:[32]***conv1.3.weight:[32, 32, 3, 3]***conv1.4.weight:[32]***conv1.6.weight:[64, 32, 3, 3]***bn1.weight:[64]***layer1.0.conv1.weight:[47, 64, 1, 1]***layer1.0.bn1.weight:[47]***layer1.0.conv2.weight:[18, 47, 3, 3]***layer1.0.bn2.weight:[18]***layer1.0.conv3.weight:[19, 18, 1, 1]***layer1.0.bn3.weight:[19]***layer1.0.se.conv.weight:[1, 1, 5]***layer1.0.downsample.1.weight:[19, 64, 1, 1]***layer1.0.downsample.2.weight:[19]***layer1.1.conv1.weight:[52, 19, 1, 1]***layer1.1.bn1.weight:[52]***layer1.1.conv2.weight:[22, 52, 3, 3]***layer1.1.bn2.weight:[22]***layer1.1.conv3.weight:[19, 22, 1, 1]***layer1.1.bn3.weight:[19]***layer1.1.se.conv.weight:[1, 1, 5]***layer1.2.conv1.weight:[64, 19, 1, 1]***layer1.2.bn1.weight:[64]***layer1.2.conv2.weight:[35, 64, 3, 3]***layer1.2.bn2.weight:[35]***layer1.2.conv3.weight:[19, 35, 1, 1]***layer1.2.bn3.weight:[19]***layer1.2.se.conv.weight:[1, 1, 5]***layer2.0.conv1.weight:[85, 19, 1, 1]***layer2.0.bn1.weight:[85]***layer2.0.conv2.weight:[37, 85, 3, 3]***layer2.0.bn2.weight:[37]***layer2.0.conv3.weight:[171, 37, 1, 1]***layer2.0.bn3.weight:[171]***layer2.0.se.conv.weight:[1, 1, 5]***layer2.0.downsample.1.weight:[171, 19, 1, 1]***layer2.0.downsample.2.weight:[171]***layer2.1.conv1.weight:[107, 171, 1, 1]***layer2.1.bn1.weight:[107]***layer2.1.conv2.weight:[80, 107, 3, 3]***layer2.1.bn2.weight:[80]***layer2.1.conv3.weight:[171, 80, 1, 1]***layer2.1.bn3.weight:[171]***layer2.1.se.conv.weight:[1, 1, 5]***layer2.2.conv1.weight:[120, 171, 1, 1]***layer2.2.bn1.weight:[120]***layer2.2.conv2.weight:[85, 120, 3, 3]***layer2.2.bn2.weight:[85]***layer2.2.conv3.weight:[171, 85, 1, 1]***layer2.2.bn3.weight:[171]***layer2.2.se.conv.weight:[1, 1, 5]***layer2.3.conv1.weight:[125, 171, 1, 1]***layer2.3.bn1.weight:[125]***layer2.3.conv2.weight:[87, 125, 3, 3]***layer2.3.bn2.weight:[87]***layer2.3.conv3.weight:[171, 87, 1, 1]***layer2.3.bn3.weight:[171]***layer2.3.se.conv.weight:[1, 1, 5]***layer3.0.conv1.weight:[198, 171, 1, 1]***layer3.0.bn1.weight:[198]***layer3.0.conv2.weight:[126, 198, 3, 3]***layer3.0.bn2.weight:[126]***layer3.0.conv3.weight:[818, 126, 1, 1]***layer3.0.bn3.weight:[818]***layer3.0.se.conv.weight:[1, 1, 5]***layer3.0.downsample.1.weight:[818, 171, 1, 1]***layer3.0.downsample.2.weight:[818]***layer3.1.conv1.weight:[255, 818, 1, 1]***layer3.1.bn1.weight:[255]***layer3.1.conv2.weight:[232, 255, 3, 3]***layer3.1.bn2.weight:[232]***layer3.1.conv3.weight:[818, 232, 1, 1]***layer3.1.bn3.weight:[818]***layer3.1.se.conv.weight:[1, 1, 5]***layer3.2.conv1.weight:[256, 818, 1, 1]***layer3.2.bn1.weight:[256]***layer3.2.conv2.weight:[233, 256, 3, 3]***layer3.2.bn2.weight:[233]***layer3.2.conv3.weight:[818, 233, 1, 1]***layer3.2.bn3.weight:[818]***layer3.2.se.conv.weight:[1, 1, 5]***layer3.3.conv1.weight:[253, 818, 1, 1]***layer3.3.bn1.weight:[253]***layer3.3.conv2.weight:[235, 253, 3, 3]***layer3.3.bn2.weight:[235]***layer3.3.conv3.weight:[818, 235, 1, 1]***layer3.3.bn3.weight:[818]***layer3.3.se.conv.weight:[1, 1, 5]***layer3.4.conv1.weight:[256, 818, 1, 1]***layer3.4.bn1.weight:[256]***layer3.4.conv2.weight:[225, 256, 3, 3]***layer3.4.bn2.weight:[225]***layer3.4.conv3.weight:[818, 225, 1, 1]***layer3.4.bn3.weight:[818]***layer3.4.se.conv.weight:[1, 1, 5]***layer3.5.conv1.weight:[256, 818, 1, 1]***layer3.5.bn1.weight:[256]***layer3.5.conv2.weight:[239, 256, 3, 3]***layer3.5.bn2.weight:[239]***layer3.5.conv3.weight:[818, 239, 1, 1]***layer3.5.bn3.weight:[818]***layer3.5.se.conv.weight:[1, 1, 5]***layer4.0.conv1.weight:[492, 
818, 1, 1]***layer4.0.bn1.weight:[492]***layer4.0.conv2.weight:[237, 492, 3, 3]***layer4.0.bn2.weight:[237]***layer4.0.conv3.weight:[2022, 237, 1, 1]***layer4.0.bn3.weight:[2022]***layer4.0.se.conv.weight:[1, 1, 7]***layer4.0.downsample.1.weight:[2022, 818, 1, 1]***layer4.0.downsample.2.weight:[2022]***layer4.1.conv1.weight:[512, 2022, 1, 1]***layer4.1.bn1.weight:[512]***layer4.1.conv2.weight:[500, 512, 3, 3]***layer4.1.bn2.weight:[500]***layer4.1.conv3.weight:[2022, 500, 1, 1]***layer4.1.bn3.weight:[2022]***layer4.1.se.conv.weight:[1, 1, 7]***layer4.2.conv1.weight:[512, 2022, 1, 1]***layer4.2.bn1.weight:[512]***layer4.2.conv2.weight:[490, 512, 3, 3]***layer4.2.bn2.weight:[490]***layer4.2.conv3.weight:[2022, 490, 1, 1]***layer4.2.bn3.weight:[2022]***layer4.2.se.conv.weight:[1, 1, 7]***fc.weight:[1000, 2022]***layer1_2_conv3_M.weight:[256, 19]***layer2_3_conv3_M.weight:[512, 171]***layer3_5_conv3_M.weight:[1024, 818]***layer4_2_conv3_M.weight:[2048, 2022]
\ No newline at end of file
diff --git a/timm/models/pruned/efficientnet_b1_pruned.txt b/timm/models/pruned/efficientnet_b1_pruned.txt
new file mode 100644
index 0000000..0972b52
--- /dev/null
+++ b/timm/models/pruned/efficientnet_b1_pruned.txt
@@ -0,0 +1 @@
+conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[48, 16, 1, 1]***blocks.1.0.bn1.weight:[48]***blocks.1.0.bn1.bias:[48]***blocks.1.0.bn1.running_mean:[48]***blocks.1.0.bn1.running_var:[48]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[48, 1, 3, 3]***blocks.1.0.bn2.weight:[48]***blocks.1.0.bn2.bias:[48]***blocks.1.0.bn2.running_mean:[48]***blocks.1.0.bn2.running_var:[48]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 48, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[48, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[48]***blocks.1.0.conv_pwl.weight:[12, 48, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[62, 12, 1, 1]***blocks.1.1.bn1.weight:[62]***blocks.1.1.bn1.bias:[62]***blocks.1.1.bn1.running_mean:[62]***blocks.1.1.bn1.running_var:[62]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[62, 1, 3, 3]***blocks.1.1.bn2.weight:[62]***blocks.1.1.bn2.bias:[62]***blocks.1.1.bn2.running_mean:[62]***blocks.1.1.bn2.running_var:[62]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 62, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[62, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[62]***blocks.1.1.conv_pwl.weight:[12, 62, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[48, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 
1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[70, 12, 1, 1]***blocks.2.0.bn1.weight:[70]***blocks.2.0.bn1.bias:[70]***blocks.2.0.bn1.running_mean:[70]***blocks.2.0.bn1.running_var:[70]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[70, 1, 5, 5]***blocks.2.0.bn2.weight:[70]***blocks.2.0.bn2.bias:[70]***blocks.2.0.bn2.running_mean:[70]***blocks.2.0.bn2.running_var:[70]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 70, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[70, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[70]***blocks.2.0.conv_pwl.weight:[35, 70, 1, 1]***blocks.2.0.bn3.weight:[35]***blocks.2.0.bn3.bias:[35]***blocks.2.0.bn3.running_mean:[35]***blocks.2.0.bn3.running_var:[35]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[61, 35, 1, 1]***blocks.2.1.bn1.weight:[61]***blocks.2.1.bn1.bias:[61]***blocks.2.1.bn1.running_mean:[61]***blocks.2.1.bn1.running_var:[61]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[61, 1, 5, 5]***blocks.2.1.bn2.weight:[61]***blocks.2.1.bn2.bias:[61]***blocks.2.1.bn2.running_mean:[61]***blocks.2.1.bn2.running_var:[61]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[10, 61, 1, 1]***blocks.2.1.se.conv_reduce.bias:[10]***blocks.2.1.se.conv_expand.weight:[61, 10, 1, 1]***blocks.2.1.se.conv_expand.bias:[61]***blocks.2.1.conv_pwl.weight:[35, 61, 1, 1]***blocks.2.1.bn3.weight:[35]***blocks.2.1.bn3.bias:[35]***blocks.2.1.bn3.running_mean:[35]***blocks.2.1.bn3.running_var:[35]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[51, 35, 1, 1]***blocks.2.2.bn1.weight:[51]***blocks.2.2.bn1.bias:[51]***blocks.2.2.bn1.running_mean:[51]***blocks.2.2.bn1.running_var:[51]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[51, 1, 5, 5]***blocks.2.2.bn2.weight:[51]***blocks.2.2.bn2.bias:[51]***blocks.2.2.bn2.running_mean:[51]***blocks.2.2.bn2.running_var:[51]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[10, 51, 1, 1]***blocks.2.2.se.conv_reduce.bias:[10]***blocks.2.2.se.conv_expand.weight:[51, 10, 1, 1]***blocks.2.2.se.conv_expand.bias:[51]***blocks.2.2.conv_pwl.weight:[35, 51, 1, 1]***blocks.2.2.bn3.weight:[35]***blocks.2.2.bn3.bias:[35]***blocks.2.2.bn3.running_mean:[35]***blocks.2.2.bn3.running_var:[35]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[175, 35, 1, 1]***blocks.3.0.bn1.weight:[175]***blocks.3.0.bn1.bias:[175]***blocks.3.0.bn1.running_mean:[175]***blocks.3.0.bn1.running_var:[175]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[175, 1, 3, 3]***blocks.3.0.bn2.weight:[175]***blocks.3.0.bn2.bias:[175]***blocks.3.0.bn2.running_mean:[175]***blocks.3.0.bn2.running_var:[175]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[10, 175, 1, 1]***blocks.3.0.se.conv_reduce.bias:[10]***blocks.3.0.se.conv_expand.weight:[175, 10, 1, 1]***blocks.3.0.se.conv_expand.bias:[175]***blocks.3.0.conv_pwl.weight:[74, 175, 1, 1]***blocks.3.0.bn3.weight:[74]***blocks.3.0.bn3.bias:[74]***blocks.3.0.bn3.running_mean:[74]***blocks.3.0.bn3.running_var:[74]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[188, 74, 1, 
1]***blocks.3.1.bn1.weight:[188]***blocks.3.1.bn1.bias:[188]***blocks.3.1.bn1.running_mean:[188]***blocks.3.1.bn1.running_var:[188]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[188, 1, 3, 3]***blocks.3.1.bn2.weight:[188]***blocks.3.1.bn2.bias:[188]***blocks.3.1.bn2.running_mean:[188]***blocks.3.1.bn2.running_var:[188]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[20, 188, 1, 1]***blocks.3.1.se.conv_reduce.bias:[20]***blocks.3.1.se.conv_expand.weight:[188, 20, 1, 1]***blocks.3.1.se.conv_expand.bias:[188]***blocks.3.1.conv_pwl.weight:[74, 188, 1, 1]***blocks.3.1.bn3.weight:[74]***blocks.3.1.bn3.bias:[74]***blocks.3.1.bn3.running_mean:[74]***blocks.3.1.bn3.running_var:[74]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[137, 74, 1, 1]***blocks.3.2.bn1.weight:[137]***blocks.3.2.bn1.bias:[137]***blocks.3.2.bn1.running_mean:[137]***blocks.3.2.bn1.running_var:[137]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[137, 1, 3, 3]***blocks.3.2.bn2.weight:[137]***blocks.3.2.bn2.bias:[137]***blocks.3.2.bn2.running_mean:[137]***blocks.3.2.bn2.running_var:[137]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[20, 137, 1, 1]***blocks.3.2.se.conv_reduce.bias:[20]***blocks.3.2.se.conv_expand.weight:[137, 20, 1, 1]***blocks.3.2.se.conv_expand.bias:[137]***blocks.3.2.conv_pwl.weight:[74, 137, 1, 1]***blocks.3.2.bn3.weight:[74]***blocks.3.2.bn3.bias:[74]***blocks.3.2.bn3.running_mean:[74]***blocks.3.2.bn3.running_var:[74]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[164, 74, 1, 1]***blocks.3.3.bn1.weight:[164]***blocks.3.3.bn1.bias:[164]***blocks.3.3.bn1.running_mean:[164]***blocks.3.3.bn1.running_var:[164]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[164, 1, 3, 3]***blocks.3.3.bn2.weight:[164]***blocks.3.3.bn2.bias:[164]***blocks.3.3.bn2.running_mean:[164]***blocks.3.3.bn2.running_var:[164]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[20, 164, 1, 1]***blocks.3.3.se.conv_reduce.bias:[20]***blocks.3.3.se.conv_expand.weight:[164, 20, 1, 1]***blocks.3.3.se.conv_expand.bias:[164]***blocks.3.3.conv_pwl.weight:[74, 164, 1, 1]***blocks.3.3.bn3.weight:[74]***blocks.3.3.bn3.bias:[74]***blocks.3.3.bn3.running_mean:[74]***blocks.3.3.bn3.running_var:[74]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[399, 74, 1, 1]***blocks.4.0.bn1.weight:[399]***blocks.4.0.bn1.bias:[399]***blocks.4.0.bn1.running_mean:[399]***blocks.4.0.bn1.running_var:[399]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[399, 1, 5, 5]***blocks.4.0.bn2.weight:[399]***blocks.4.0.bn2.bias:[399]***blocks.4.0.bn2.running_mean:[399]***blocks.4.0.bn2.running_var:[399]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[20, 399, 1, 1]***blocks.4.0.se.conv_reduce.bias:[20]***blocks.4.0.se.conv_expand.weight:[399, 20, 1, 1]***blocks.4.0.se.conv_expand.bias:[399]***blocks.4.0.conv_pwl.weight:[67, 399, 1, 1]***blocks.4.0.bn3.weight:[67]***blocks.4.0.bn3.bias:[67]***blocks.4.0.bn3.running_mean:[67]***blocks.4.0.bn3.running_var:[67]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[201, 67, 1, 1]***blocks.4.1.bn1.weight:[201]***blocks.4.1.bn1.bias:[201]***blocks.4.1.bn1.running_mean:[201]***blocks.4.1.bn1.running_var:[201]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[201, 1, 5, 
5]***blocks.4.1.bn2.weight:[201]***blocks.4.1.bn2.bias:[201]***blocks.4.1.bn2.running_mean:[201]***blocks.4.1.bn2.running_var:[201]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[28, 201, 1, 1]***blocks.4.1.se.conv_reduce.bias:[28]***blocks.4.1.se.conv_expand.weight:[201, 28, 1, 1]***blocks.4.1.se.conv_expand.bias:[201]***blocks.4.1.conv_pwl.weight:[67, 201, 1, 1]***blocks.4.1.bn3.weight:[67]***blocks.4.1.bn3.bias:[67]***blocks.4.1.bn3.running_mean:[67]***blocks.4.1.bn3.running_var:[67]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[160, 67, 1, 1]***blocks.4.2.bn1.weight:[160]***blocks.4.2.bn1.bias:[160]***blocks.4.2.bn1.running_mean:[160]***blocks.4.2.bn1.running_var:[160]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[160, 1, 5, 5]***blocks.4.2.bn2.weight:[160]***blocks.4.2.bn2.bias:[160]***blocks.4.2.bn2.running_mean:[160]***blocks.4.2.bn2.running_var:[160]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[28, 160, 1, 1]***blocks.4.2.se.conv_reduce.bias:[28]***blocks.4.2.se.conv_expand.weight:[160, 28, 1, 1]***blocks.4.2.se.conv_expand.bias:[160]***blocks.4.2.conv_pwl.weight:[67, 160, 1, 1]***blocks.4.2.bn3.weight:[67]***blocks.4.2.bn3.bias:[67]***blocks.4.2.bn3.running_mean:[67]***blocks.4.2.bn3.running_var:[67]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[213, 67, 1, 1]***blocks.4.3.bn1.weight:[213]***blocks.4.3.bn1.bias:[213]***blocks.4.3.bn1.running_mean:[213]***blocks.4.3.bn1.running_var:[213]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[213, 1, 5, 5]***blocks.4.3.bn2.weight:[213]***blocks.4.3.bn2.bias:[213]***blocks.4.3.bn2.running_mean:[213]***blocks.4.3.bn2.running_var:[213]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[28, 213, 1, 1]***blocks.4.3.se.conv_reduce.bias:[28]***blocks.4.3.se.conv_expand.weight:[213, 28, 1, 1]***blocks.4.3.se.conv_expand.bias:[213]***blocks.4.3.conv_pwl.weight:[67, 213, 1, 1]***blocks.4.3.bn3.weight:[67]***blocks.4.3.bn3.bias:[67]***blocks.4.3.bn3.running_mean:[67]***blocks.4.3.bn3.running_var:[67]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[637, 67, 1, 1]***blocks.5.0.bn1.weight:[637]***blocks.5.0.bn1.bias:[637]***blocks.5.0.bn1.running_mean:[637]***blocks.5.0.bn1.running_var:[637]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[637, 1, 5, 5]***blocks.5.0.bn2.weight:[637]***blocks.5.0.bn2.bias:[637]***blocks.5.0.bn2.running_mean:[637]***blocks.5.0.bn2.running_var:[637]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[27, 637, 1, 1]***blocks.5.0.se.conv_reduce.bias:[27]***blocks.5.0.se.conv_expand.weight:[637, 27, 1, 1]***blocks.5.0.se.conv_expand.bias:[637]***blocks.5.0.conv_pwl.weight:[192, 637, 1, 1]***blocks.5.0.bn3.weight:[192]***blocks.5.0.bn3.bias:[192]***blocks.5.0.bn3.running_mean:[192]***blocks.5.0.bn3.running_var:[192]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[806, 192, 1, 1]***blocks.5.1.bn1.weight:[806]***blocks.5.1.bn1.bias:[806]***blocks.5.1.bn1.running_mean:[806]***blocks.5.1.bn1.running_var:[806]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[806, 1, 5, 5]***blocks.5.1.bn2.weight:[806]***blocks.5.1.bn2.bias:[806]***blocks.5.1.bn2.running_mean:[806]***blocks.5.1.bn2.running_var:[806]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[48, 806, 1, 
1]***blocks.5.1.se.conv_reduce.bias:[48]***blocks.5.1.se.conv_expand.weight:[806, 48, 1, 1]***blocks.5.1.se.conv_expand.bias:[806]***blocks.5.1.conv_pwl.weight:[192, 806, 1, 1]***blocks.5.1.bn3.weight:[192]***blocks.5.1.bn3.bias:[192]***blocks.5.1.bn3.running_mean:[192]***blocks.5.1.bn3.running_var:[192]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[798, 192, 1, 1]***blocks.5.2.bn1.weight:[798]***blocks.5.2.bn1.bias:[798]***blocks.5.2.bn1.running_mean:[798]***blocks.5.2.bn1.running_var:[798]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[798, 1, 5, 5]***blocks.5.2.bn2.weight:[798]***blocks.5.2.bn2.bias:[798]***blocks.5.2.bn2.running_mean:[798]***blocks.5.2.bn2.running_var:[798]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[48, 798, 1, 1]***blocks.5.2.se.conv_reduce.bias:[48]***blocks.5.2.se.conv_expand.weight:[798, 48, 1, 1]***blocks.5.2.se.conv_expand.bias:[798]***blocks.5.2.conv_pwl.weight:[192, 798, 1, 1]***blocks.5.2.bn3.weight:[192]***blocks.5.2.bn3.bias:[192]***blocks.5.2.bn3.running_mean:[192]***blocks.5.2.bn3.running_var:[192]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[891, 192, 1, 1]***blocks.5.3.bn1.weight:[891]***blocks.5.3.bn1.bias:[891]***blocks.5.3.bn1.running_mean:[891]***blocks.5.3.bn1.running_var:[891]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[891, 1, 5, 5]***blocks.5.3.bn2.weight:[891]***blocks.5.3.bn2.bias:[891]***blocks.5.3.bn2.running_mean:[891]***blocks.5.3.bn2.running_var:[891]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[48, 891, 1, 1]***blocks.5.3.se.conv_reduce.bias:[48]***blocks.5.3.se.conv_expand.weight:[891, 48, 1, 1]***blocks.5.3.se.conv_expand.bias:[891]***blocks.5.3.conv_pwl.weight:[192, 891, 1, 1]***blocks.5.3.bn3.weight:[192]***blocks.5.3.bn3.bias:[192]***blocks.5.3.bn3.running_mean:[192]***blocks.5.3.bn3.running_var:[192]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[990, 192, 1, 1]***blocks.5.4.bn1.weight:[990]***blocks.5.4.bn1.bias:[990]***blocks.5.4.bn1.running_mean:[990]***blocks.5.4.bn1.running_var:[990]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[990, 1, 5, 5]***blocks.5.4.bn2.weight:[990]***blocks.5.4.bn2.bias:[990]***blocks.5.4.bn2.running_mean:[990]***blocks.5.4.bn2.running_var:[990]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[48, 990, 1, 1]***blocks.5.4.se.conv_reduce.bias:[48]***blocks.5.4.se.conv_expand.weight:[990, 48, 1, 1]***blocks.5.4.se.conv_expand.bias:[990]***blocks.5.4.conv_pwl.weight:[192, 990, 1, 1]***blocks.5.4.bn3.weight:[192]***blocks.5.4.bn3.bias:[192]***blocks.5.4.bn3.running_mean:[192]***blocks.5.4.bn3.running_var:[192]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1152, 192, 1, 1]***blocks.6.0.bn1.weight:[1152]***blocks.6.0.bn1.bias:[1152]***blocks.6.0.bn1.running_mean:[1152]***blocks.6.0.bn1.running_var:[1152]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1152, 1, 3, 3]***blocks.6.0.bn2.weight:[1152]***blocks.6.0.bn2.bias:[1152]***blocks.6.0.bn2.running_mean:[1152]***blocks.6.0.bn2.running_var:[1152]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[48, 1152, 1, 1]***blocks.6.0.se.conv_reduce.bias:[48]***blocks.6.0.se.conv_expand.weight:[1152, 48, 1, 1]***blocks.6.0.se.conv_expand.bias:[1152]***blocks.6.0.conv_pwl.weight:[320, 1152, 1, 
1]***blocks.6.0.bn3.weight:[320]***blocks.6.0.bn3.bias:[320]***blocks.6.0.bn3.running_mean:[320]***blocks.6.0.bn3.running_var:[320]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[1912, 320, 1, 1]***blocks.6.1.bn1.weight:[1912]***blocks.6.1.bn1.bias:[1912]***blocks.6.1.bn1.running_mean:[1912]***blocks.6.1.bn1.running_var:[1912]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[1912, 1, 3, 3]***blocks.6.1.bn2.weight:[1912]***blocks.6.1.bn2.bias:[1912]***blocks.6.1.bn2.running_mean:[1912]***blocks.6.1.bn2.running_var:[1912]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[80, 1912, 1, 1]***blocks.6.1.se.conv_reduce.bias:[80]***blocks.6.1.se.conv_expand.weight:[1912, 80, 1, 1]***blocks.6.1.se.conv_expand.bias:[1912]***blocks.6.1.conv_pwl.weight:[320, 1912, 1, 1]***blocks.6.1.bn3.weight:[320]***blocks.6.1.bn3.bias:[320]***blocks.6.1.bn3.running_mean:[320]***blocks.6.1.bn3.running_var:[320]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1280, 320, 1, 1]***bn2.weight:[1280]***bn2.bias:[1280]***bn2.running_mean:[1280]***bn2.running_var:[1280]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1280]***classifier.bias:[1000]
\ No newline at end of file
diff --git a/timm/models/pruned/efficientnet_b2_pruned.txt b/timm/models/pruned/efficientnet_b2_pruned.txt
new file mode 100644
index 0000000..6e3fade
--- /dev/null
+++ b/timm/models/pruned/efficientnet_b2_pruned.txt
@@ -0,0 +1 @@
+conv_stem.weight:[32, 3, 3, 3]***bn1.weight:[32]***bn1.bias:[32]***bn1.running_mean:[32]***bn1.running_var:[32]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[32, 1, 3, 3]***blocks.0.0.bn1.weight:[32]***blocks.0.0.bn1.bias:[32]***blocks.0.0.bn1.running_mean:[32]***blocks.0.0.bn1.running_var:[32]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[8, 32, 1, 1]***blocks.0.0.se.conv_reduce.bias:[8]***blocks.0.0.se.conv_expand.weight:[32, 8, 1, 1]***blocks.0.0.se.conv_expand.bias:[32]***blocks.0.0.conv_pw.weight:[16, 32, 1, 1]***blocks.0.0.bn2.weight:[16]***blocks.0.0.bn2.bias:[16]***blocks.0.0.bn2.running_mean:[16]***blocks.0.0.bn2.running_var:[16]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[16, 1, 3, 3]***blocks.0.1.bn1.weight:[16]***blocks.0.1.bn1.bias:[16]***blocks.0.1.bn1.running_mean:[16]***blocks.0.1.bn1.running_var:[16]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[4, 16, 1, 1]***blocks.0.1.se.conv_reduce.bias:[4]***blocks.0.1.se.conv_expand.weight:[16, 4, 1, 1]***blocks.0.1.se.conv_expand.bias:[16]***blocks.0.1.conv_pw.weight:[16, 16, 1, 1]***blocks.0.1.bn2.weight:[16]***blocks.0.1.bn2.bias:[16]***blocks.0.1.bn2.running_mean:[16]***blocks.0.1.bn2.running_var:[16]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[54, 16, 1, 1]***blocks.1.0.bn1.weight:[54]***blocks.1.0.bn1.bias:[54]***blocks.1.0.bn1.running_mean:[54]***blocks.1.0.bn1.running_var:[54]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[54, 1, 3, 3]***blocks.1.0.bn2.weight:[54]***blocks.1.0.bn2.bias:[54]***blocks.1.0.bn2.running_mean:[54]***blocks.1.0.bn2.running_var:[54]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[4, 54, 1, 1]***blocks.1.0.se.conv_reduce.bias:[4]***blocks.1.0.se.conv_expand.weight:[54, 4, 1, 1]***blocks.1.0.se.conv_expand.bias:[54]***blocks.1.0.conv_pwl.weight:[17, 54, 1, 1]***blocks.1.0.bn3.weight:[17]***blocks.1.0.bn3.bias:[17]***blocks.1.0.bn3.running_mean:[17]***blocks.1.0.bn3.running_var:[17]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[69, 17, 1, 1]***blocks.1.1.bn1.weight:[69]***blocks.1.1.bn1.bias:[69]***blocks.1.1.bn1.running_mean:[69]***blocks.1.1.bn1.running_var:[69]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[69, 1, 3, 3]***blocks.1.1.bn2.weight:[69]***blocks.1.1.bn2.bias:[69]***blocks.1.1.bn2.running_mean:[69]***blocks.1.1.bn2.running_var:[69]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[6, 69, 1, 1]***blocks.1.1.se.conv_reduce.bias:[6]***blocks.1.1.se.conv_expand.weight:[69, 6, 1, 1]***blocks.1.1.se.conv_expand.bias:[69]***blocks.1.1.conv_pwl.weight:[17, 69, 1, 1]***blocks.1.1.bn3.weight:[17]***blocks.1.1.bn3.bias:[17]***blocks.1.1.bn3.running_mean:[17]***blocks.1.1.bn3.running_var:[17]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[61, 17, 1, 1]***blocks.1.2.bn1.weight:[61]***blocks.1.2.bn1.bias:[61]***blocks.1.2.bn1.running_mean:[61]***blocks.1.2.bn1.running_var:[61]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[61, 1, 3, 3]***blocks.1.2.bn2.weight:[61]***blocks.1.2.bn2.bias:[61]***blocks.1.2.bn2.running_mean:[61]***blocks.1.2.bn2.running_var:[61]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[6, 61, 1, 1]***blocks.1.2.se.conv_reduce.bias:[6]***blocks.1.2.se.conv_expand.weight:[61, 6, 1, 1]***blocks.1.2.se.conv_expand.bias:[61]***blocks.1.2.conv_pwl.weight:[17, 61, 1, 
1]***blocks.1.2.bn3.weight:[17]***blocks.1.2.bn3.bias:[17]***blocks.1.2.bn3.running_mean:[17]***blocks.1.2.bn3.running_var:[17]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[86, 17, 1, 1]***blocks.2.0.bn1.weight:[86]***blocks.2.0.bn1.bias:[86]***blocks.2.0.bn1.running_mean:[86]***blocks.2.0.bn1.running_var:[86]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[86, 1, 5, 5]***blocks.2.0.bn2.weight:[86]***blocks.2.0.bn2.bias:[86]***blocks.2.0.bn2.running_mean:[86]***blocks.2.0.bn2.running_var:[86]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[6, 86, 1, 1]***blocks.2.0.se.conv_reduce.bias:[6]***blocks.2.0.se.conv_expand.weight:[86, 6, 1, 1]***blocks.2.0.se.conv_expand.bias:[86]***blocks.2.0.conv_pwl.weight:[42, 86, 1, 1]***blocks.2.0.bn3.weight:[42]***blocks.2.0.bn3.bias:[42]***blocks.2.0.bn3.running_mean:[42]***blocks.2.0.bn3.running_var:[42]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[72, 42, 1, 1]***blocks.2.1.bn1.weight:[72]***blocks.2.1.bn1.bias:[72]***blocks.2.1.bn1.running_mean:[72]***blocks.2.1.bn1.running_var:[72]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[72, 1, 5, 5]***blocks.2.1.bn2.weight:[72]***blocks.2.1.bn2.bias:[72]***blocks.2.1.bn2.running_mean:[72]***blocks.2.1.bn2.running_var:[72]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 72, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[72, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[72]***blocks.2.1.conv_pwl.weight:[42, 72, 1, 1]***blocks.2.1.bn3.weight:[42]***blocks.2.1.bn3.bias:[42]***blocks.2.1.bn3.running_mean:[42]***blocks.2.1.bn3.running_var:[42]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[98, 42, 1, 1]***blocks.2.2.bn1.weight:[98]***blocks.2.2.bn1.bias:[98]***blocks.2.2.bn1.running_mean:[98]***blocks.2.2.bn1.running_var:[98]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[98, 1, 5, 5]***blocks.2.2.bn2.weight:[98]***blocks.2.2.bn2.bias:[98]***blocks.2.2.bn2.running_mean:[98]***blocks.2.2.bn2.running_var:[98]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 98, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[98, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[98]***blocks.2.2.conv_pwl.weight:[42, 98, 1, 1]***blocks.2.2.bn3.weight:[42]***blocks.2.2.bn3.bias:[42]***blocks.2.2.bn3.running_mean:[42]***blocks.2.2.bn3.running_var:[42]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[245, 42, 1, 1]***blocks.3.0.bn1.weight:[245]***blocks.3.0.bn1.bias:[245]***blocks.3.0.bn1.running_mean:[245]***blocks.3.0.bn1.running_var:[245]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[245, 1, 3, 3]***blocks.3.0.bn2.weight:[245]***blocks.3.0.bn2.bias:[245]***blocks.3.0.bn2.running_mean:[245]***blocks.3.0.bn2.running_var:[245]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 245, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[245, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[245]***blocks.3.0.conv_pwl.weight:[85, 245, 1, 1]***blocks.3.0.bn3.weight:[85]***blocks.3.0.bn3.bias:[85]***blocks.3.0.bn3.running_mean:[85]***blocks.3.0.bn3.running_var:[85]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[274, 85, 1, 
1]***blocks.3.1.bn1.weight:[274]***blocks.3.1.bn1.bias:[274]***blocks.3.1.bn1.running_mean:[274]***blocks.3.1.bn1.running_var:[274]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[274, 1, 3, 3]***blocks.3.1.bn2.weight:[274]***blocks.3.1.bn2.bias:[274]***blocks.3.1.bn2.running_mean:[274]***blocks.3.1.bn2.running_var:[274]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[22, 274, 1, 1]***blocks.3.1.se.conv_reduce.bias:[22]***blocks.3.1.se.conv_expand.weight:[274, 22, 1, 1]***blocks.3.1.se.conv_expand.bias:[274]***blocks.3.1.conv_pwl.weight:[85, 274, 1, 1]***blocks.3.1.bn3.weight:[85]***blocks.3.1.bn3.bias:[85]***blocks.3.1.bn3.running_mean:[85]***blocks.3.1.bn3.running_var:[85]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[254, 85, 1, 1]***blocks.3.2.bn1.weight:[254]***blocks.3.2.bn1.bias:[254]***blocks.3.2.bn1.running_mean:[254]***blocks.3.2.bn1.running_var:[254]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[254, 1, 3, 3]***blocks.3.2.bn2.weight:[254]***blocks.3.2.bn2.bias:[254]***blocks.3.2.bn2.running_mean:[254]***blocks.3.2.bn2.running_var:[254]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[22, 254, 1, 1]***blocks.3.2.se.conv_reduce.bias:[22]***blocks.3.2.se.conv_expand.weight:[254, 22, 1, 1]***blocks.3.2.se.conv_expand.bias:[254]***blocks.3.2.conv_pwl.weight:[85, 254, 1, 1]***blocks.3.2.bn3.weight:[85]***blocks.3.2.bn3.bias:[85]***blocks.3.2.bn3.running_mean:[85]***blocks.3.2.bn3.running_var:[85]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[292, 85, 1, 1]***blocks.3.3.bn1.weight:[292]***blocks.3.3.bn1.bias:[292]***blocks.3.3.bn1.running_mean:[292]***blocks.3.3.bn1.running_var:[292]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[292, 1, 3, 3]***blocks.3.3.bn2.weight:[292]***blocks.3.3.bn2.bias:[292]***blocks.3.3.bn2.running_mean:[292]***blocks.3.3.bn2.running_var:[292]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[22, 292, 1, 1]***blocks.3.3.se.conv_reduce.bias:[22]***blocks.3.3.se.conv_expand.weight:[292, 22, 1, 1]***blocks.3.3.se.conv_expand.bias:[292]***blocks.3.3.conv_pwl.weight:[85, 292, 1, 1]***blocks.3.3.bn3.weight:[85]***blocks.3.3.bn3.bias:[85]***blocks.3.3.bn3.running_mean:[85]***blocks.3.3.bn3.running_var:[85]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[502, 85, 1, 1]***blocks.4.0.bn1.weight:[502]***blocks.4.0.bn1.bias:[502]***blocks.4.0.bn1.running_mean:[502]***blocks.4.0.bn1.running_var:[502]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[502, 1, 5, 5]***blocks.4.0.bn2.weight:[502]***blocks.4.0.bn2.bias:[502]***blocks.4.0.bn2.running_mean:[502]***blocks.4.0.bn2.running_var:[502]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[22, 502, 1, 1]***blocks.4.0.se.conv_reduce.bias:[22]***blocks.4.0.se.conv_expand.weight:[502, 22, 1, 1]***blocks.4.0.se.conv_expand.bias:[502]***blocks.4.0.conv_pwl.weight:[116, 502, 1, 1]***blocks.4.0.bn3.weight:[116]***blocks.4.0.bn3.bias:[116]***blocks.4.0.bn3.running_mean:[116]***blocks.4.0.bn3.running_var:[116]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[315, 116, 1, 1]***blocks.4.1.bn1.weight:[315]***blocks.4.1.bn1.bias:[315]***blocks.4.1.bn1.running_mean:[315]***blocks.4.1.bn1.running_var:[315]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[315, 1, 5, 
5]***blocks.4.1.bn2.weight:[315]***blocks.4.1.bn2.bias:[315]***blocks.4.1.bn2.running_mean:[315]***blocks.4.1.bn2.running_var:[315]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[30, 315, 1, 1]***blocks.4.1.se.conv_reduce.bias:[30]***blocks.4.1.se.conv_expand.weight:[315, 30, 1, 1]***blocks.4.1.se.conv_expand.bias:[315]***blocks.4.1.conv_pwl.weight:[116, 315, 1, 1]***blocks.4.1.bn3.weight:[116]***blocks.4.1.bn3.bias:[116]***blocks.4.1.bn3.running_mean:[116]***blocks.4.1.bn3.running_var:[116]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[354, 116, 1, 1]***blocks.4.2.bn1.weight:[354]***blocks.4.2.bn1.bias:[354]***blocks.4.2.bn1.running_mean:[354]***blocks.4.2.bn1.running_var:[354]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[354, 1, 5, 5]***blocks.4.2.bn2.weight:[354]***blocks.4.2.bn2.bias:[354]***blocks.4.2.bn2.running_mean:[354]***blocks.4.2.bn2.running_var:[354]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[30, 354, 1, 1]***blocks.4.2.se.conv_reduce.bias:[30]***blocks.4.2.se.conv_expand.weight:[354, 30, 1, 1]***blocks.4.2.se.conv_expand.bias:[354]***blocks.4.2.conv_pwl.weight:[116, 354, 1, 1]***blocks.4.2.bn3.weight:[116]***blocks.4.2.bn3.bias:[116]***blocks.4.2.bn3.running_mean:[116]***blocks.4.2.bn3.running_var:[116]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[443, 116, 1, 1]***blocks.4.3.bn1.weight:[443]***blocks.4.3.bn1.bias:[443]***blocks.4.3.bn1.running_mean:[443]***blocks.4.3.bn1.running_var:[443]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[443, 1, 5, 5]***blocks.4.3.bn2.weight:[443]***blocks.4.3.bn2.bias:[443]***blocks.4.3.bn2.running_mean:[443]***blocks.4.3.bn2.running_var:[443]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[30, 443, 1, 1]***blocks.4.3.se.conv_reduce.bias:[30]***blocks.4.3.se.conv_expand.weight:[443, 30, 1, 1]***blocks.4.3.se.conv_expand.bias:[443]***blocks.4.3.conv_pwl.weight:[116, 443, 1, 1]***blocks.4.3.bn3.weight:[116]***blocks.4.3.bn3.bias:[116]***blocks.4.3.bn3.running_mean:[116]***blocks.4.3.bn3.running_var:[116]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[719, 116, 1, 1]***blocks.5.0.bn1.weight:[719]***blocks.5.0.bn1.bias:[719]***blocks.5.0.bn1.running_mean:[719]***blocks.5.0.bn1.running_var:[719]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[719, 1, 5, 5]***blocks.5.0.bn2.weight:[719]***blocks.5.0.bn2.bias:[719]***blocks.5.0.bn2.running_mean:[719]***blocks.5.0.bn2.running_var:[719]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[30, 719, 1, 1]***blocks.5.0.se.conv_reduce.bias:[30]***blocks.5.0.se.conv_expand.weight:[719, 30, 1, 1]***blocks.5.0.se.conv_expand.bias:[719]***blocks.5.0.conv_pwl.weight:[208, 719, 1, 1]***blocks.5.0.bn3.weight:[208]***blocks.5.0.bn3.bias:[208]***blocks.5.0.bn3.running_mean:[208]***blocks.5.0.bn3.running_var:[208]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1148, 208, 1, 1]***blocks.5.1.bn1.weight:[1148]***blocks.5.1.bn1.bias:[1148]***blocks.5.1.bn1.running_mean:[1148]***blocks.5.1.bn1.running_var:[1148]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1148, 1, 5, 5]***blocks.5.1.bn2.weight:[1148]***blocks.5.1.bn2.bias:[1148]***blocks.5.1.bn2.running_mean:[1148]***blocks.5.1.bn2.running_var:[1148]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[52, 1148, 1, 
1]***blocks.5.1.se.conv_reduce.bias:[52]***blocks.5.1.se.conv_expand.weight:[1148, 52, 1, 1]***blocks.5.1.se.conv_expand.bias:[1148]***blocks.5.1.conv_pwl.weight:[208, 1148, 1, 1]***blocks.5.1.bn3.weight:[208]***blocks.5.1.bn3.bias:[208]***blocks.5.1.bn3.running_mean:[208]***blocks.5.1.bn3.running_var:[208]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[1160, 208, 1, 1]***blocks.5.2.bn1.weight:[1160]***blocks.5.2.bn1.bias:[1160]***blocks.5.2.bn1.running_mean:[1160]***blocks.5.2.bn1.running_var:[1160]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[1160, 1, 5, 5]***blocks.5.2.bn2.weight:[1160]***blocks.5.2.bn2.bias:[1160]***blocks.5.2.bn2.running_mean:[1160]***blocks.5.2.bn2.running_var:[1160]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[52, 1160, 1, 1]***blocks.5.2.se.conv_reduce.bias:[52]***blocks.5.2.se.conv_expand.weight:[1160, 52, 1, 1]***blocks.5.2.se.conv_expand.bias:[1160]***blocks.5.2.conv_pwl.weight:[208, 1160, 1, 1]***blocks.5.2.bn3.weight:[208]***blocks.5.2.bn3.bias:[208]***blocks.5.2.bn3.running_mean:[208]***blocks.5.2.bn3.running_var:[208]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1182, 208, 1, 1]***blocks.5.3.bn1.weight:[1182]***blocks.5.3.bn1.bias:[1182]***blocks.5.3.bn1.running_mean:[1182]***blocks.5.3.bn1.running_var:[1182]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1182, 1, 5, 5]***blocks.5.3.bn2.weight:[1182]***blocks.5.3.bn2.bias:[1182]***blocks.5.3.bn2.running_mean:[1182]***blocks.5.3.bn2.running_var:[1182]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[52, 1182, 1, 1]***blocks.5.3.se.conv_reduce.bias:[52]***blocks.5.3.se.conv_expand.weight:[1182, 52, 1, 1]***blocks.5.3.se.conv_expand.bias:[1182]***blocks.5.3.conv_pwl.weight:[208, 1182, 1, 1]***blocks.5.3.bn3.weight:[208]***blocks.5.3.bn3.bias:[208]***blocks.5.3.bn3.running_mean:[208]***blocks.5.3.bn3.running_var:[208]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1228, 208, 1, 1]***blocks.5.4.bn1.weight:[1228]***blocks.5.4.bn1.bias:[1228]***blocks.5.4.bn1.running_mean:[1228]***blocks.5.4.bn1.running_var:[1228]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1228, 1, 5, 5]***blocks.5.4.bn2.weight:[1228]***blocks.5.4.bn2.bias:[1228]***blocks.5.4.bn2.running_mean:[1228]***blocks.5.4.bn2.running_var:[1228]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[52, 1228, 1, 1]***blocks.5.4.se.conv_reduce.bias:[52]***blocks.5.4.se.conv_expand.weight:[1228, 52, 1, 1]***blocks.5.4.se.conv_expand.bias:[1228]***blocks.5.4.conv_pwl.weight:[208, 1228, 1, 1]***blocks.5.4.bn3.weight:[208]***blocks.5.4.bn3.bias:[208]***blocks.5.4.bn3.running_mean:[208]***blocks.5.4.bn3.running_var:[208]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1248, 208, 1, 1]***blocks.6.0.bn1.weight:[1248]***blocks.6.0.bn1.bias:[1248]***blocks.6.0.bn1.running_mean:[1248]***blocks.6.0.bn1.running_var:[1248]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1248, 1, 3, 3]***blocks.6.0.bn2.weight:[1248]***blocks.6.0.bn2.bias:[1248]***blocks.6.0.bn2.running_mean:[1248]***blocks.6.0.bn2.running_var:[1248]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[52, 1248, 1, 1]***blocks.6.0.se.conv_reduce.bias:[52]***blocks.6.0.se.conv_expand.weight:[1248, 52, 1, 1]***blocks.6.0.se.conv_expand.bias:[1248]***blocks.6.0.conv_pwl.weight:[352, 1248, 1, 
1]***blocks.6.0.bn3.weight:[352]***blocks.6.0.bn3.bias:[352]***blocks.6.0.bn3.running_mean:[352]***blocks.6.0.bn3.running_var:[352]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2112, 352, 1, 1]***blocks.6.1.bn1.weight:[2112]***blocks.6.1.bn1.bias:[2112]***blocks.6.1.bn1.running_mean:[2112]***blocks.6.1.bn1.running_var:[2112]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2112, 1, 3, 3]***blocks.6.1.bn2.weight:[2112]***blocks.6.1.bn2.bias:[2112]***blocks.6.1.bn2.running_mean:[2112]***blocks.6.1.bn2.running_var:[2112]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[88, 2112, 1, 1]***blocks.6.1.se.conv_reduce.bias:[88]***blocks.6.1.se.conv_expand.weight:[2112, 88, 1, 1]***blocks.6.1.se.conv_expand.bias:[2112]***blocks.6.1.conv_pwl.weight:[352, 2112, 1, 1]***blocks.6.1.bn3.weight:[352]***blocks.6.1.bn3.bias:[352]***blocks.6.1.bn3.running_mean:[352]***blocks.6.1.bn3.running_var:[352]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1408, 352, 1, 1]***bn2.weight:[1408]***bn2.bias:[1408]***bn2.running_mean:[1408]***bn2.running_var:[1408]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1408]***classifier.bias:[1000]
\ No newline at end of file
diff --git a/timm/models/pruned/efficientnet_b3_pruned.txt b/timm/models/pruned/efficientnet_b3_pruned.txt
new file mode 100644
index 0000000..4897817
--- /dev/null
+++ b/timm/models/pruned/efficientnet_b3_pruned.txt
@@ -0,0 +1 @@
+conv_stem.weight:[40, 3, 3, 3]***bn1.weight:[40]***bn1.bias:[40]***bn1.running_mean:[40]***bn1.running_var:[40]***bn1.num_batches_tracked:[]***blocks.0.0.conv_dw.weight:[40, 1, 3, 3]***blocks.0.0.bn1.weight:[40]***blocks.0.0.bn1.bias:[40]***blocks.0.0.bn1.running_mean:[40]***blocks.0.0.bn1.running_var:[40]***blocks.0.0.bn1.num_batches_tracked:[]***blocks.0.0.se.conv_reduce.weight:[10, 40, 1, 1]***blocks.0.0.se.conv_reduce.bias:[10]***blocks.0.0.se.conv_expand.weight:[40, 10, 1, 1]***blocks.0.0.se.conv_expand.bias:[40]***blocks.0.0.conv_pw.weight:[24, 40, 1, 1]***blocks.0.0.bn2.weight:[24]***blocks.0.0.bn2.bias:[24]***blocks.0.0.bn2.running_mean:[24]***blocks.0.0.bn2.running_var:[24]***blocks.0.0.bn2.num_batches_tracked:[]***blocks.0.1.conv_dw.weight:[24, 1, 3, 3]***blocks.0.1.bn1.weight:[24]***blocks.0.1.bn1.bias:[24]***blocks.0.1.bn1.running_mean:[24]***blocks.0.1.bn1.running_var:[24]***blocks.0.1.bn1.num_batches_tracked:[]***blocks.0.1.se.conv_reduce.weight:[6, 24, 1, 1]***blocks.0.1.se.conv_reduce.bias:[6]***blocks.0.1.se.conv_expand.weight:[24, 6, 1, 1]***blocks.0.1.se.conv_expand.bias:[24]***blocks.0.1.conv_pw.weight:[24, 24, 1, 1]***blocks.0.1.bn2.weight:[24]***blocks.0.1.bn2.bias:[24]***blocks.0.1.bn2.running_mean:[24]***blocks.0.1.bn2.running_var:[24]***blocks.0.1.bn2.num_batches_tracked:[]***blocks.1.0.conv_pw.weight:[27, 24, 1, 1]***blocks.1.0.bn1.weight:[27]***blocks.1.0.bn1.bias:[27]***blocks.1.0.bn1.running_mean:[27]***blocks.1.0.bn1.running_var:[27]***blocks.1.0.bn1.num_batches_tracked:[]***blocks.1.0.conv_dw.weight:[27, 1, 3, 3]***blocks.1.0.bn2.weight:[27]***blocks.1.0.bn2.bias:[27]***blocks.1.0.bn2.running_mean:[27]***blocks.1.0.bn2.running_var:[27]***blocks.1.0.bn2.num_batches_tracked:[]***blocks.1.0.se.conv_reduce.weight:[6, 27, 1, 1]***blocks.1.0.se.conv_reduce.bias:[6]***blocks.1.0.se.conv_expand.weight:[27, 6, 1, 1]***blocks.1.0.se.conv_expand.bias:[27]***blocks.1.0.conv_pwl.weight:[12, 27, 1, 1]***blocks.1.0.bn3.weight:[12]***blocks.1.0.bn3.bias:[12]***blocks.1.0.bn3.running_mean:[12]***blocks.1.0.bn3.running_var:[12]***blocks.1.0.bn3.num_batches_tracked:[]***blocks.1.1.conv_pw.weight:[49, 12, 1, 1]***blocks.1.1.bn1.weight:[49]***blocks.1.1.bn1.bias:[49]***blocks.1.1.bn1.running_mean:[49]***blocks.1.1.bn1.running_var:[49]***blocks.1.1.bn1.num_batches_tracked:[]***blocks.1.1.conv_dw.weight:[49, 1, 3, 3]***blocks.1.1.bn2.weight:[49]***blocks.1.1.bn2.bias:[49]***blocks.1.1.bn2.running_mean:[49]***blocks.1.1.bn2.running_var:[49]***blocks.1.1.bn2.num_batches_tracked:[]***blocks.1.1.se.conv_reduce.weight:[8, 49, 1, 1]***blocks.1.1.se.conv_reduce.bias:[8]***blocks.1.1.se.conv_expand.weight:[49, 8, 1, 1]***blocks.1.1.se.conv_expand.bias:[49]***blocks.1.1.conv_pwl.weight:[12, 49, 1, 1]***blocks.1.1.bn3.weight:[12]***blocks.1.1.bn3.bias:[12]***blocks.1.1.bn3.running_mean:[12]***blocks.1.1.bn3.running_var:[12]***blocks.1.1.bn3.num_batches_tracked:[]***blocks.1.2.conv_pw.weight:[48, 12, 1, 1]***blocks.1.2.bn1.weight:[48]***blocks.1.2.bn1.bias:[48]***blocks.1.2.bn1.running_mean:[48]***blocks.1.2.bn1.running_var:[48]***blocks.1.2.bn1.num_batches_tracked:[]***blocks.1.2.conv_dw.weight:[48, 1, 3, 3]***blocks.1.2.bn2.weight:[48]***blocks.1.2.bn2.bias:[48]***blocks.1.2.bn2.running_mean:[48]***blocks.1.2.bn2.running_var:[48]***blocks.1.2.bn2.num_batches_tracked:[]***blocks.1.2.se.conv_reduce.weight:[8, 48, 1, 1]***blocks.1.2.se.conv_reduce.bias:[8]***blocks.1.2.se.conv_expand.weight:[48, 8, 1, 1]***blocks.1.2.se.conv_expand.bias:[48]***blocks.1.2.conv_pwl.weight:[12, 48, 1, 
1]***blocks.1.2.bn3.weight:[12]***blocks.1.2.bn3.bias:[12]***blocks.1.2.bn3.running_mean:[12]***blocks.1.2.bn3.running_var:[12]***blocks.1.2.bn3.num_batches_tracked:[]***blocks.2.0.conv_pw.weight:[83, 12, 1, 1]***blocks.2.0.bn1.weight:[83]***blocks.2.0.bn1.bias:[83]***blocks.2.0.bn1.running_mean:[83]***blocks.2.0.bn1.running_var:[83]***blocks.2.0.bn1.num_batches_tracked:[]***blocks.2.0.conv_dw.weight:[83, 1, 5, 5]***blocks.2.0.bn2.weight:[83]***blocks.2.0.bn2.bias:[83]***blocks.2.0.bn2.running_mean:[83]***blocks.2.0.bn2.running_var:[83]***blocks.2.0.bn2.num_batches_tracked:[]***blocks.2.0.se.conv_reduce.weight:[8, 83, 1, 1]***blocks.2.0.se.conv_reduce.bias:[8]***blocks.2.0.se.conv_expand.weight:[83, 8, 1, 1]***blocks.2.0.se.conv_expand.bias:[83]***blocks.2.0.conv_pwl.weight:[40, 83, 1, 1]***blocks.2.0.bn3.weight:[40]***blocks.2.0.bn3.bias:[40]***blocks.2.0.bn3.running_mean:[40]***blocks.2.0.bn3.running_var:[40]***blocks.2.0.bn3.num_batches_tracked:[]***blocks.2.1.conv_pw.weight:[90, 40, 1, 1]***blocks.2.1.bn1.weight:[90]***blocks.2.1.bn1.bias:[90]***blocks.2.1.bn1.running_mean:[90]***blocks.2.1.bn1.running_var:[90]***blocks.2.1.bn1.num_batches_tracked:[]***blocks.2.1.conv_dw.weight:[90, 1, 5, 5]***blocks.2.1.bn2.weight:[90]***blocks.2.1.bn2.bias:[90]***blocks.2.1.bn2.running_mean:[90]***blocks.2.1.bn2.running_var:[90]***blocks.2.1.bn2.num_batches_tracked:[]***blocks.2.1.se.conv_reduce.weight:[12, 90, 1, 1]***blocks.2.1.se.conv_reduce.bias:[12]***blocks.2.1.se.conv_expand.weight:[90, 12, 1, 1]***blocks.2.1.se.conv_expand.bias:[90]***blocks.2.1.conv_pwl.weight:[40, 90, 1, 1]***blocks.2.1.bn3.weight:[40]***blocks.2.1.bn3.bias:[40]***blocks.2.1.bn3.running_mean:[40]***blocks.2.1.bn3.running_var:[40]***blocks.2.1.bn3.num_batches_tracked:[]***blocks.2.2.conv_pw.weight:[85, 40, 1, 1]***blocks.2.2.bn1.weight:[85]***blocks.2.2.bn1.bias:[85]***blocks.2.2.bn1.running_mean:[85]***blocks.2.2.bn1.running_var:[85]***blocks.2.2.bn1.num_batches_tracked:[]***blocks.2.2.conv_dw.weight:[85, 1, 5, 5]***blocks.2.2.bn2.weight:[85]***blocks.2.2.bn2.bias:[85]***blocks.2.2.bn2.running_mean:[85]***blocks.2.2.bn2.running_var:[85]***blocks.2.2.bn2.num_batches_tracked:[]***blocks.2.2.se.conv_reduce.weight:[12, 85, 1, 1]***blocks.2.2.se.conv_reduce.bias:[12]***blocks.2.2.se.conv_expand.weight:[85, 12, 1, 1]***blocks.2.2.se.conv_expand.bias:[85]***blocks.2.2.conv_pwl.weight:[40, 85, 1, 1]***blocks.2.2.bn3.weight:[40]***blocks.2.2.bn3.bias:[40]***blocks.2.2.bn3.running_mean:[40]***blocks.2.2.bn3.running_var:[40]***blocks.2.2.bn3.num_batches_tracked:[]***blocks.3.0.conv_pw.weight:[215, 40, 1, 1]***blocks.3.0.bn1.weight:[215]***blocks.3.0.bn1.bias:[215]***blocks.3.0.bn1.running_mean:[215]***blocks.3.0.bn1.running_var:[215]***blocks.3.0.bn1.num_batches_tracked:[]***blocks.3.0.conv_dw.weight:[215, 1, 3, 3]***blocks.3.0.bn2.weight:[215]***blocks.3.0.bn2.bias:[215]***blocks.3.0.bn2.running_mean:[215]***blocks.3.0.bn2.running_var:[215]***blocks.3.0.bn2.num_batches_tracked:[]***blocks.3.0.se.conv_reduce.weight:[12, 215, 1, 1]***blocks.3.0.se.conv_reduce.bias:[12]***blocks.3.0.se.conv_expand.weight:[215, 12, 1, 1]***blocks.3.0.se.conv_expand.bias:[215]***blocks.3.0.conv_pwl.weight:[93, 215, 1, 1]***blocks.3.0.bn3.weight:[93]***blocks.3.0.bn3.bias:[93]***blocks.3.0.bn3.running_mean:[93]***blocks.3.0.bn3.running_var:[93]***blocks.3.0.bn3.num_batches_tracked:[]***blocks.3.1.conv_pw.weight:[261, 93, 1, 
1]***blocks.3.1.bn1.weight:[261]***blocks.3.1.bn1.bias:[261]***blocks.3.1.bn1.running_mean:[261]***blocks.3.1.bn1.running_var:[261]***blocks.3.1.bn1.num_batches_tracked:[]***blocks.3.1.conv_dw.weight:[261, 1, 3, 3]***blocks.3.1.bn2.weight:[261]***blocks.3.1.bn2.bias:[261]***blocks.3.1.bn2.running_mean:[261]***blocks.3.1.bn2.running_var:[261]***blocks.3.1.bn2.num_batches_tracked:[]***blocks.3.1.se.conv_reduce.weight:[24, 261, 1, 1]***blocks.3.1.se.conv_reduce.bias:[24]***blocks.3.1.se.conv_expand.weight:[261, 24, 1, 1]***blocks.3.1.se.conv_expand.bias:[261]***blocks.3.1.conv_pwl.weight:[93, 261, 1, 1]***blocks.3.1.bn3.weight:[93]***blocks.3.1.bn3.bias:[93]***blocks.3.1.bn3.running_mean:[93]***blocks.3.1.bn3.running_var:[93]***blocks.3.1.bn3.num_batches_tracked:[]***blocks.3.2.conv_pw.weight:[219, 93, 1, 1]***blocks.3.2.bn1.weight:[219]***blocks.3.2.bn1.bias:[219]***blocks.3.2.bn1.running_mean:[219]***blocks.3.2.bn1.running_var:[219]***blocks.3.2.bn1.num_batches_tracked:[]***blocks.3.2.conv_dw.weight:[219, 1, 3, 3]***blocks.3.2.bn2.weight:[219]***blocks.3.2.bn2.bias:[219]***blocks.3.2.bn2.running_mean:[219]***blocks.3.2.bn2.running_var:[219]***blocks.3.2.bn2.num_batches_tracked:[]***blocks.3.2.se.conv_reduce.weight:[24, 219, 1, 1]***blocks.3.2.se.conv_reduce.bias:[24]***blocks.3.2.se.conv_expand.weight:[219, 24, 1, 1]***blocks.3.2.se.conv_expand.bias:[219]***blocks.3.2.conv_pwl.weight:[93, 219, 1, 1]***blocks.3.2.bn3.weight:[93]***blocks.3.2.bn3.bias:[93]***blocks.3.2.bn3.running_mean:[93]***blocks.3.2.bn3.running_var:[93]***blocks.3.2.bn3.num_batches_tracked:[]***blocks.3.3.conv_pw.weight:[254, 93, 1, 1]***blocks.3.3.bn1.weight:[254]***blocks.3.3.bn1.bias:[254]***blocks.3.3.bn1.running_mean:[254]***blocks.3.3.bn1.running_var:[254]***blocks.3.3.bn1.num_batches_tracked:[]***blocks.3.3.conv_dw.weight:[254, 1, 3, 3]***blocks.3.3.bn2.weight:[254]***blocks.3.3.bn2.bias:[254]***blocks.3.3.bn2.running_mean:[254]***blocks.3.3.bn2.running_var:[254]***blocks.3.3.bn2.num_batches_tracked:[]***blocks.3.3.se.conv_reduce.weight:[24, 254, 1, 1]***blocks.3.3.se.conv_reduce.bias:[24]***blocks.3.3.se.conv_expand.weight:[254, 24, 1, 1]***blocks.3.3.se.conv_expand.bias:[254]***blocks.3.3.conv_pwl.weight:[93, 254, 1, 1]***blocks.3.3.bn3.weight:[93]***blocks.3.3.bn3.bias:[93]***blocks.3.3.bn3.running_mean:[93]***blocks.3.3.bn3.running_var:[93]***blocks.3.3.bn3.num_batches_tracked:[]***blocks.3.4.conv_pw.weight:[236, 93, 1, 1]***blocks.3.4.bn1.weight:[236]***blocks.3.4.bn1.bias:[236]***blocks.3.4.bn1.running_mean:[236]***blocks.3.4.bn1.running_var:[236]***blocks.3.4.bn1.num_batches_tracked:[]***blocks.3.4.conv_dw.weight:[236, 1, 3, 3]***blocks.3.4.bn2.weight:[236]***blocks.3.4.bn2.bias:[236]***blocks.3.4.bn2.running_mean:[236]***blocks.3.4.bn2.running_var:[236]***blocks.3.4.bn2.num_batches_tracked:[]***blocks.3.4.se.conv_reduce.weight:[24, 236, 1, 1]***blocks.3.4.se.conv_reduce.bias:[24]***blocks.3.4.se.conv_expand.weight:[236, 24, 1, 1]***blocks.3.4.se.conv_expand.bias:[236]***blocks.3.4.conv_pwl.weight:[93, 236, 1, 1]***blocks.3.4.bn3.weight:[93]***blocks.3.4.bn3.bias:[93]***blocks.3.4.bn3.running_mean:[93]***blocks.3.4.bn3.running_var:[93]***blocks.3.4.bn3.num_batches_tracked:[]***blocks.4.0.conv_pw.weight:[480, 93, 1, 1]***blocks.4.0.bn1.weight:[480]***blocks.4.0.bn1.bias:[480]***blocks.4.0.bn1.running_mean:[480]***blocks.4.0.bn1.running_var:[480]***blocks.4.0.bn1.num_batches_tracked:[]***blocks.4.0.conv_dw.weight:[480, 1, 5, 
5]***blocks.4.0.bn2.weight:[480]***blocks.4.0.bn2.bias:[480]***blocks.4.0.bn2.running_mean:[480]***blocks.4.0.bn2.running_var:[480]***blocks.4.0.bn2.num_batches_tracked:[]***blocks.4.0.se.conv_reduce.weight:[24, 480, 1, 1]***blocks.4.0.se.conv_reduce.bias:[24]***blocks.4.0.se.conv_expand.weight:[480, 24, 1, 1]***blocks.4.0.se.conv_expand.bias:[480]***blocks.4.0.conv_pwl.weight:[120, 480, 1, 1]***blocks.4.0.bn3.weight:[120]***blocks.4.0.bn3.bias:[120]***blocks.4.0.bn3.running_mean:[120]***blocks.4.0.bn3.running_var:[120]***blocks.4.0.bn3.num_batches_tracked:[]***blocks.4.1.conv_pw.weight:[235, 120, 1, 1]***blocks.4.1.bn1.weight:[235]***blocks.4.1.bn1.bias:[235]***blocks.4.1.bn1.running_mean:[235]***blocks.4.1.bn1.running_var:[235]***blocks.4.1.bn1.num_batches_tracked:[]***blocks.4.1.conv_dw.weight:[235, 1, 5, 5]***blocks.4.1.bn2.weight:[235]***blocks.4.1.bn2.bias:[235]***blocks.4.1.bn2.running_mean:[235]***blocks.4.1.bn2.running_var:[235]***blocks.4.1.bn2.num_batches_tracked:[]***blocks.4.1.se.conv_reduce.weight:[34, 235, 1, 1]***blocks.4.1.se.conv_reduce.bias:[34]***blocks.4.1.se.conv_expand.weight:[235, 34, 1, 1]***blocks.4.1.se.conv_expand.bias:[235]***blocks.4.1.conv_pwl.weight:[120, 235, 1, 1]***blocks.4.1.bn3.weight:[120]***blocks.4.1.bn3.bias:[120]***blocks.4.1.bn3.running_mean:[120]***blocks.4.1.bn3.running_var:[120]***blocks.4.1.bn3.num_batches_tracked:[]***blocks.4.2.conv_pw.weight:[217, 120, 1, 1]***blocks.4.2.bn1.weight:[217]***blocks.4.2.bn1.bias:[217]***blocks.4.2.bn1.running_mean:[217]***blocks.4.2.bn1.running_var:[217]***blocks.4.2.bn1.num_batches_tracked:[]***blocks.4.2.conv_dw.weight:[217, 1, 5, 5]***blocks.4.2.bn2.weight:[217]***blocks.4.2.bn2.bias:[217]***blocks.4.2.bn2.running_mean:[217]***blocks.4.2.bn2.running_var:[217]***blocks.4.2.bn2.num_batches_tracked:[]***blocks.4.2.se.conv_reduce.weight:[34, 217, 1, 1]***blocks.4.2.se.conv_reduce.bias:[34]***blocks.4.2.se.conv_expand.weight:[217, 34, 1, 1]***blocks.4.2.se.conv_expand.bias:[217]***blocks.4.2.conv_pwl.weight:[120, 217, 1, 1]***blocks.4.2.bn3.weight:[120]***blocks.4.2.bn3.bias:[120]***blocks.4.2.bn3.running_mean:[120]***blocks.4.2.bn3.running_var:[120]***blocks.4.2.bn3.num_batches_tracked:[]***blocks.4.3.conv_pw.weight:[226, 120, 1, 1]***blocks.4.3.bn1.weight:[226]***blocks.4.3.bn1.bias:[226]***blocks.4.3.bn1.running_mean:[226]***blocks.4.3.bn1.running_var:[226]***blocks.4.3.bn1.num_batches_tracked:[]***blocks.4.3.conv_dw.weight:[226, 1, 5, 5]***blocks.4.3.bn2.weight:[226]***blocks.4.3.bn2.bias:[226]***blocks.4.3.bn2.running_mean:[226]***blocks.4.3.bn2.running_var:[226]***blocks.4.3.bn2.num_batches_tracked:[]***blocks.4.3.se.conv_reduce.weight:[33, 226, 1, 1]***blocks.4.3.se.conv_reduce.bias:[33]***blocks.4.3.se.conv_expand.weight:[226, 33, 1, 1]***blocks.4.3.se.conv_expand.bias:[226]***blocks.4.3.conv_pwl.weight:[120, 226, 1, 1]***blocks.4.3.bn3.weight:[120]***blocks.4.3.bn3.bias:[120]***blocks.4.3.bn3.running_mean:[120]***blocks.4.3.bn3.running_var:[120]***blocks.4.3.bn3.num_batches_tracked:[]***blocks.4.4.conv_pw.weight:[340, 120, 1, 1]***blocks.4.4.bn1.weight:[340]***blocks.4.4.bn1.bias:[340]***blocks.4.4.bn1.running_mean:[340]***blocks.4.4.bn1.running_var:[340]***blocks.4.4.bn1.num_batches_tracked:[]***blocks.4.4.conv_dw.weight:[340, 1, 5, 5]***blocks.4.4.bn2.weight:[340]***blocks.4.4.bn2.bias:[340]***blocks.4.4.bn2.running_mean:[340]***blocks.4.4.bn2.running_var:[340]***blocks.4.4.bn2.num_batches_tracked:[]***blocks.4.4.se.conv_reduce.weight:[34, 340, 1, 
1]***blocks.4.4.se.conv_reduce.bias:[34]***blocks.4.4.se.conv_expand.weight:[340, 34, 1, 1]***blocks.4.4.se.conv_expand.bias:[340]***blocks.4.4.conv_pwl.weight:[120, 340, 1, 1]***blocks.4.4.bn3.weight:[120]***blocks.4.4.bn3.bias:[120]***blocks.4.4.bn3.running_mean:[120]***blocks.4.4.bn3.running_var:[120]***blocks.4.4.bn3.num_batches_tracked:[]***blocks.5.0.conv_pw.weight:[802, 120, 1, 1]***blocks.5.0.bn1.weight:[802]***blocks.5.0.bn1.bias:[802]***blocks.5.0.bn1.running_mean:[802]***blocks.5.0.bn1.running_var:[802]***blocks.5.0.bn1.num_batches_tracked:[]***blocks.5.0.conv_dw.weight:[802, 1, 5, 5]***blocks.5.0.bn2.weight:[802]***blocks.5.0.bn2.bias:[802]***blocks.5.0.bn2.running_mean:[802]***blocks.5.0.bn2.running_var:[802]***blocks.5.0.bn2.num_batches_tracked:[]***blocks.5.0.se.conv_reduce.weight:[34, 802, 1, 1]***blocks.5.0.se.conv_reduce.bias:[34]***blocks.5.0.se.conv_expand.weight:[802, 34, 1, 1]***blocks.5.0.se.conv_expand.bias:[802]***blocks.5.0.conv_pwl.weight:[232, 802, 1, 1]***blocks.5.0.bn3.weight:[232]***blocks.5.0.bn3.bias:[232]***blocks.5.0.bn3.running_mean:[232]***blocks.5.0.bn3.running_var:[232]***blocks.5.0.bn3.num_batches_tracked:[]***blocks.5.1.conv_pw.weight:[1030, 232, 1, 1]***blocks.5.1.bn1.weight:[1030]***blocks.5.1.bn1.bias:[1030]***blocks.5.1.bn1.running_mean:[1030]***blocks.5.1.bn1.running_var:[1030]***blocks.5.1.bn1.num_batches_tracked:[]***blocks.5.1.conv_dw.weight:[1030, 1, 5, 5]***blocks.5.1.bn2.weight:[1030]***blocks.5.1.bn2.bias:[1030]***blocks.5.1.bn2.running_mean:[1030]***blocks.5.1.bn2.running_var:[1030]***blocks.5.1.bn2.num_batches_tracked:[]***blocks.5.1.se.conv_reduce.weight:[58, 1030, 1, 1]***blocks.5.1.se.conv_reduce.bias:[58]***blocks.5.1.se.conv_expand.weight:[1030, 58, 1, 1]***blocks.5.1.se.conv_expand.bias:[1030]***blocks.5.1.conv_pwl.weight:[232, 1030, 1, 1]***blocks.5.1.bn3.weight:[232]***blocks.5.1.bn3.bias:[232]***blocks.5.1.bn3.running_mean:[232]***blocks.5.1.bn3.running_var:[232]***blocks.5.1.bn3.num_batches_tracked:[]***blocks.5.2.conv_pw.weight:[924, 232, 1, 1]***blocks.5.2.bn1.weight:[924]***blocks.5.2.bn1.bias:[924]***blocks.5.2.bn1.running_mean:[924]***blocks.5.2.bn1.running_var:[924]***blocks.5.2.bn1.num_batches_tracked:[]***blocks.5.2.conv_dw.weight:[924, 1, 5, 5]***blocks.5.2.bn2.weight:[924]***blocks.5.2.bn2.bias:[924]***blocks.5.2.bn2.running_mean:[924]***blocks.5.2.bn2.running_var:[924]***blocks.5.2.bn2.num_batches_tracked:[]***blocks.5.2.se.conv_reduce.weight:[58, 924, 1, 1]***blocks.5.2.se.conv_reduce.bias:[58]***blocks.5.2.se.conv_expand.weight:[924, 58, 1, 1]***blocks.5.2.se.conv_expand.bias:[924]***blocks.5.2.conv_pwl.weight:[232, 924, 1, 1]***blocks.5.2.bn3.weight:[232]***blocks.5.2.bn3.bias:[232]***blocks.5.2.bn3.running_mean:[232]***blocks.5.2.bn3.running_var:[232]***blocks.5.2.bn3.num_batches_tracked:[]***blocks.5.3.conv_pw.weight:[1016, 232, 1, 1]***blocks.5.3.bn1.weight:[1016]***blocks.5.3.bn1.bias:[1016]***blocks.5.3.bn1.running_mean:[1016]***blocks.5.3.bn1.running_var:[1016]***blocks.5.3.bn1.num_batches_tracked:[]***blocks.5.3.conv_dw.weight:[1016, 1, 5, 5]***blocks.5.3.bn2.weight:[1016]***blocks.5.3.bn2.bias:[1016]***blocks.5.3.bn2.running_mean:[1016]***blocks.5.3.bn2.running_var:[1016]***blocks.5.3.bn2.num_batches_tracked:[]***blocks.5.3.se.conv_reduce.weight:[58, 1016, 1, 1]***blocks.5.3.se.conv_reduce.bias:[58]***blocks.5.3.se.conv_expand.weight:[1016, 58, 1, 1]***blocks.5.3.se.conv_expand.bias:[1016]***blocks.5.3.conv_pwl.weight:[232, 1016, 1, 
1]***blocks.5.3.bn3.weight:[232]***blocks.5.3.bn3.bias:[232]***blocks.5.3.bn3.running_mean:[232]***blocks.5.3.bn3.running_var:[232]***blocks.5.3.bn3.num_batches_tracked:[]***blocks.5.4.conv_pw.weight:[1130, 232, 1, 1]***blocks.5.4.bn1.weight:[1130]***blocks.5.4.bn1.bias:[1130]***blocks.5.4.bn1.running_mean:[1130]***blocks.5.4.bn1.running_var:[1130]***blocks.5.4.bn1.num_batches_tracked:[]***blocks.5.4.conv_dw.weight:[1130, 1, 5, 5]***blocks.5.4.bn2.weight:[1130]***blocks.5.4.bn2.bias:[1130]***blocks.5.4.bn2.running_mean:[1130]***blocks.5.4.bn2.running_var:[1130]***blocks.5.4.bn2.num_batches_tracked:[]***blocks.5.4.se.conv_reduce.weight:[58, 1130, 1, 1]***blocks.5.4.se.conv_reduce.bias:[58]***blocks.5.4.se.conv_expand.weight:[1130, 58, 1, 1]***blocks.5.4.se.conv_expand.bias:[1130]***blocks.5.4.conv_pwl.weight:[232, 1130, 1, 1]***blocks.5.4.bn3.weight:[232]***blocks.5.4.bn3.bias:[232]***blocks.5.4.bn3.running_mean:[232]***blocks.5.4.bn3.running_var:[232]***blocks.5.4.bn3.num_batches_tracked:[]***blocks.5.5.conv_pw.weight:[1266, 232, 1, 1]***blocks.5.5.bn1.weight:[1266]***blocks.5.5.bn1.bias:[1266]***blocks.5.5.bn1.running_mean:[1266]***blocks.5.5.bn1.running_var:[1266]***blocks.5.5.bn1.num_batches_tracked:[]***blocks.5.5.conv_dw.weight:[1266, 1, 5, 5]***blocks.5.5.bn2.weight:[1266]***blocks.5.5.bn2.bias:[1266]***blocks.5.5.bn2.running_mean:[1266]***blocks.5.5.bn2.running_var:[1266]***blocks.5.5.bn2.num_batches_tracked:[]***blocks.5.5.se.conv_reduce.weight:[58, 1266, 1, 1]***blocks.5.5.se.conv_reduce.bias:[58]***blocks.5.5.se.conv_expand.weight:[1266, 58, 1, 1]***blocks.5.5.se.conv_expand.bias:[1266]***blocks.5.5.conv_pwl.weight:[232, 1266, 1, 1]***blocks.5.5.bn3.weight:[232]***blocks.5.5.bn3.bias:[232]***blocks.5.5.bn3.running_mean:[232]***blocks.5.5.bn3.running_var:[232]***blocks.5.5.bn3.num_batches_tracked:[]***blocks.6.0.conv_pw.weight:[1392, 232, 1, 1]***blocks.6.0.bn1.weight:[1392]***blocks.6.0.bn1.bias:[1392]***blocks.6.0.bn1.running_mean:[1392]***blocks.6.0.bn1.running_var:[1392]***blocks.6.0.bn1.num_batches_tracked:[]***blocks.6.0.conv_dw.weight:[1392, 1, 3, 3]***blocks.6.0.bn2.weight:[1392]***blocks.6.0.bn2.bias:[1392]***blocks.6.0.bn2.running_mean:[1392]***blocks.6.0.bn2.running_var:[1392]***blocks.6.0.bn2.num_batches_tracked:[]***blocks.6.0.se.conv_reduce.weight:[58, 1392, 1, 1]***blocks.6.0.se.conv_reduce.bias:[58]***blocks.6.0.se.conv_expand.weight:[1392, 58, 1, 1]***blocks.6.0.se.conv_expand.bias:[1392]***blocks.6.0.conv_pwl.weight:[384, 1392, 1, 1]***blocks.6.0.bn3.weight:[384]***blocks.6.0.bn3.bias:[384]***blocks.6.0.bn3.running_mean:[384]***blocks.6.0.bn3.running_var:[384]***blocks.6.0.bn3.num_batches_tracked:[]***blocks.6.1.conv_pw.weight:[2301, 384, 1, 1]***blocks.6.1.bn1.weight:[2301]***blocks.6.1.bn1.bias:[2301]***blocks.6.1.bn1.running_mean:[2301]***blocks.6.1.bn1.running_var:[2301]***blocks.6.1.bn1.num_batches_tracked:[]***blocks.6.1.conv_dw.weight:[2301, 1, 3, 3]***blocks.6.1.bn2.weight:[2301]***blocks.6.1.bn2.bias:[2301]***blocks.6.1.bn2.running_mean:[2301]***blocks.6.1.bn2.running_var:[2301]***blocks.6.1.bn2.num_batches_tracked:[]***blocks.6.1.se.conv_reduce.weight:[96, 2301, 1, 1]***blocks.6.1.se.conv_reduce.bias:[96]***blocks.6.1.se.conv_expand.weight:[2301, 96, 1, 1]***blocks.6.1.se.conv_expand.bias:[2301]***blocks.6.1.conv_pwl.weight:[384, 2301, 1, 1]***blocks.6.1.bn3.weight:[384]***blocks.6.1.bn3.bias:[384]***blocks.6.1.bn3.running_mean:[384]***blocks.6.1.bn3.running_var:[384]***blocks.6.1.bn3.num_batches_tracked:[]***conv_head.weight:[1536, 384, 1, 
1]***bn2.weight:[1536]***bn2.bias:[1536]***bn2.running_mean:[1536]***bn2.running_var:[1536]***bn2.num_batches_tracked:[]***classifier.weight:[1000, 1536]***classifier.bias:[1000]
\ No newline at end of file
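The file above ends a long listing of parameter shapes for a pruned checkpoint. The format appears to be one long `***`-separated string of `name:[dims]` entries; below is a minimal parsing sketch under that assumption (the file path in the usage comment is illustrative, not taken from the diff).
```
import ast

def parse_pruned_shapes(text):
    """Parse a ***-separated listing of 'param_name:[d0, d1, ...]' entries into a dict."""
    shapes = {}
    for entry in text.split('***'):
        entry = entry.strip()
        if not entry:
            continue
        name, _, dims = entry.rpartition(':')
        shapes[name] = ast.literal_eval(dims)  # e.g. '[120, 340, 1, 1]' -> [120, 340, 1, 1]; '[]' -> []
    return shapes

# Hypothetical usage; substitute the actual path of the pruned-shape file in this repo:
# shapes = parse_pruned_shapes(open('path/to/efficientnet_pruned.txt').read())
# shapes['classifier.weight']  -> [1000, 1536]
```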
diff --git a/timm/models/registry.py b/timm/models/registry.py
new file mode 100644
index 0000000..f92219b
--- /dev/null
+++ b/timm/models/registry.py
@@ -0,0 +1,149 @@
+""" Model Registry
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import sys
+import re
+import fnmatch
+from collections import defaultdict
+from copy import deepcopy
+
+__all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
+ 'is_model_default_key', 'has_model_default_key', 'get_model_default_value', 'is_model_pretrained']
+
+_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
+_model_to_module = {} # mapping of model names to module names
+_model_entrypoints = {} # mapping of model names to entrypoint fns
+_model_has_pretrained = set() # set of model names that have pretrained weight url present
+_model_default_cfgs = dict() # central repo for model default_cfgs
+
+
+def register_model(fn):
+ # lookup containing module
+ mod = sys.modules[fn.__module__]
+ module_name_split = fn.__module__.split('.')
+ module_name = module_name_split[-1] if len(module_name_split) else ''
+
+ # add model to __all__ in module
+ model_name = fn.__name__
+ if hasattr(mod, '__all__'):
+ mod.__all__.append(model_name)
+ else:
+ mod.__all__ = [model_name]
+
+ # add entries to registry dict/sets
+ _model_entrypoints[model_name] = fn
+ _model_to_module[model_name] = module_name
+ _module_to_models[module_name].add(model_name)
+ has_pretrained = False # check if model has a pretrained url to allow filtering on this
+ if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
+ # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
+ # entrypoints or non-matching combos
+ has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
+ _model_default_cfgs[model_name] = deepcopy(mod.default_cfgs[model_name])
+ if has_pretrained:
+ _model_has_pretrained.add(model_name)
+ return fn
+
+
+def _natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
+ """ Return list of available model names, sorted alphabetically
+
+ Args:
+ filter (str) - Wildcard filter string that works with fnmatch
+ module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
+ pretrained (bool) - Include only models with pretrained weights if True
+ exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
+ name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
+
+ Example:
+ list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
+ list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in the 'resnet' module
+ """
+ if module:
+ all_models = list(_module_to_models[module])
+ else:
+ all_models = _model_entrypoints.keys()
+ if filter:
+ models = []
+ include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
+ for f in include_filters:
+ include_models = fnmatch.filter(all_models, f) # include these models
+ if len(include_models):
+ models = set(models).union(include_models)
+ else:
+ models = all_models
+ if exclude_filters:
+ if not isinstance(exclude_filters, (tuple, list)):
+ exclude_filters = [exclude_filters]
+ for xf in exclude_filters:
+ exclude_models = fnmatch.filter(models, xf) # exclude these models
+ if len(exclude_models):
+ models = set(models).difference(exclude_models)
+ if pretrained:
+ models = _model_has_pretrained.intersection(models)
+ if name_matches_cfg:
+ models = set(_model_default_cfgs).intersection(models)
+ return list(sorted(models, key=_natural_key))
+
+
+def is_model(model_name):
+ """ Check if a model name exists
+ """
+ return model_name in _model_entrypoints
+
+
+def model_entrypoint(model_name):
+ """Fetch a model entrypoint for specified model name
+ """
+ return _model_entrypoints[model_name]
+
+
+def list_modules():
+ """ Return list of module names that contain models / model entrypoints
+ """
+ modules = _module_to_models.keys()
+ return list(sorted(modules))
+
+
+def is_model_in_modules(model_name, module_names):
+ """Check if a model exists within a subset of modules
+ Args:
+ model_name (str) - name of model to check
+ module_names (tuple, list, set) - names of modules to search in
+ """
+ assert isinstance(module_names, (tuple, list, set))
+ return any(model_name in _module_to_models[n] for n in module_names)
+
+
+def has_model_default_key(model_name, cfg_key):
+ """ Query model default_cfgs for existence of a specific key.
+ """
+ if model_name in _model_default_cfgs and cfg_key in _model_default_cfgs[model_name]:
+ return True
+ return False
+
+
+def is_model_default_key(model_name, cfg_key):
+ """ Return truthy value for specified model default_cfg key, False if does not exist.
+ """
+ if model_name in _model_default_cfgs and _model_default_cfgs[model_name].get(cfg_key, False):
+ return True
+ return False
+
+
+def get_model_default_value(model_name, cfg_key):
+ """ Get a specific model default_cfg value by key. None if it doesn't exist.
+ """
+ if model_name in _model_default_cfgs:
+ return _model_default_cfgs[model_name].get(cfg_key, None)
+ else:
+ return None
+
+
+def is_model_pretrained(model_name):
+ return model_name in _model_has_pretrained
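The registry above is populated by the `@register_model` decorator as each model module is imported. A minimal usage sketch of the public helpers defined in this file (assuming the `timm` package in this tree is importable):
```
from timm.models.registry import list_models, is_model, model_entrypoint

# Wildcard query over registered entrypoints; pretrained=True keeps only models with weight URLs.
print(list_models('regnety_*', pretrained=True))

# Resolve a registered name to its entrypoint function and build the model.
if is_model('regnety_032'):
    create_fn = model_entrypoint('regnety_032')
    model = create_fn(pretrained=False)
```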
diff --git a/timm/models/regnet.py b/timm/models/regnet.py
new file mode 100644
index 0000000..6a38107
--- /dev/null
+++ b/timm/models/regnet.py
@@ -0,0 +1,494 @@
+"""RegNet
+
+Paper: `Designing Network Design Spaces` - https://arxiv.org/abs/2003.13678
+Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
+
+Based on the original PyTorch impl linked above, but rewritten to use my own blocks (adapted from ResNet here)
+and cleaned up with more descriptive variable names.
+
+Weights from original impl have been modified
+* first layer from BGR -> RGB as most PyTorch models are
+* removed training-specific dict entries from checkpoints, keeping the model state_dict only
+* remapped names to match the ones here
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import numpy as np
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
+from .registry import register_model
+
+
+def _mcfg(**kwargs):
+ cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
+ cfg.update(**kwargs)
+ return cfg
+
+
+# Model FLOPS = three trailing digits * 10^8
+model_cfgs = dict(
+ regnetx_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13),
+ regnetx_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22),
+ regnetx_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16),
+ regnetx_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16),
+ regnetx_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18),
+ regnetx_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25),
+ regnetx_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23),
+ regnetx_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17),
+ regnetx_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23),
+ regnetx_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19),
+ regnetx_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22),
+ regnetx_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23),
+ regnety_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25),
+ regnety_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
+ regnety_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25),
+ regnety_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25),
+ regnety_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25),
+ regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25),
+ regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25),
+ regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25),
+ regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25),
+ regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25),
+ regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25),
+ regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25),
+)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ regnetx_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth'),
+ regnetx_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth'),
+ regnetx_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth'),
+ regnetx_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth'),
+ regnetx_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth'),
+ regnetx_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth'),
+ regnetx_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth'),
+ regnetx_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth'),
+ regnetx_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth'),
+ regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
+ regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
+ regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
+ regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
+ regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
+ regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
+ regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
+ regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
+ regnety_032=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
+ crop_pct=1.0, test_input_size=(3, 288, 288)),
+ regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'),
+ regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
+ regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
+ regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
+ regnety_160=_cfg(
+ url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
+ crop_pct=1.0, test_input_size=(3, 288, 288)),
+ regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
+)
+
+
+def quantize_float(f, q):
+ """Converts a float to closest non-zero int divisible by q."""
+ return int(round(f / q) * q)
+
+
+def adjust_widths_groups_comp(widths, bottle_ratios, groups):
+ """Adjusts the compatibility of widths and groups."""
+ bottleneck_widths = [int(w * b) for w, b in zip(widths, bottle_ratios)]
+ groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_widths)]
+ bottleneck_widths = [quantize_float(w_bot, g) for w_bot, g in zip(bottleneck_widths, groups)]
+ widths = [int(w_bot / b) for w_bot, b in zip(bottleneck_widths, bottle_ratios)]
+ return widths, groups
+
+
+def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
+ """Generates per block widths from RegNet parameters."""
+ assert width_slope >= 0 and width_initial > 0 and width_mult > 1 and width_initial % q == 0
+ widths_cont = np.arange(depth) * width_slope + width_initial
+ width_exps = np.round(np.log(widths_cont / width_initial) / np.log(width_mult))
+ widths = width_initial * np.power(width_mult, width_exps)
+ widths = np.round(np.divide(widths, q)) * q
+ num_stages, max_stage = len(np.unique(widths)), width_exps.max() + 1
+ widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
+ return widths, num_stages, max_stage, widths_cont
+
+
+class Bottleneck(nn.Module):
+ """ RegNet Bottleneck
+
+ This is almost exactly the same as a ResNet Bottleneck. The main difference is the SE block is moved from
+ after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
+ """
+
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25,
+ downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
+ drop_block=None, drop_path=None):
+ super(Bottleneck, self).__init__()
+ bottleneck_chs = int(round(out_chs * bottleneck_ratio))
+ groups = bottleneck_chs // group_width
+
+ cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
+ self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
+ self.conv2 = ConvBnAct(
+ bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation,
+ groups=groups, **cargs)
+ if se_ratio:
+ se_channels = int(round(in_chs * se_ratio))
+ self.se = SEModule(bottleneck_chs, rd_channels=se_channels)
+ else:
+ self.se = None
+ cargs['act_layer'] = None
+ self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs)
+ self.act3 = act_layer(inplace=True)
+ self.downsample = downsample
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.conv3.bn.weight)
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ if self.se is not None:
+ x = self.se(x)
+ x = self.conv3(x)
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act3(x)
+ return x
+
+
+def downsample_conv(
+ in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
+ norm_layer = norm_layer or nn.BatchNorm2d
+ kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
+ dilation = dilation if kernel_size > 1 else 1
+ return ConvBnAct(
+ in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None)
+
+
+def downsample_avg(
+ in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
+ """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
+ norm_layer = norm_layer or nn.BatchNorm2d
+ avg_stride = stride if dilation == 1 else 1
+ pool = nn.Identity()
+ if stride > 1 or dilation > 1:
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+ return nn.Sequential(*[
+ pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)])
+
+
+class RegStage(nn.Module):
+ """Stage (sequence of blocks w/ the same output shape)."""
+
+ def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
+ block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None):
+ super(RegStage, self).__init__()
+ block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
+ first_dilation = 1 if dilation in (1, 2) else 2
+ for i in range(depth):
+ block_stride = stride if i == 0 else 1
+ block_in_chs = in_chs if i == 0 else out_chs
+ block_dilation = first_dilation if i == 0 else dilation
+ if drop_path_rates is not None and drop_path_rates[i] > 0.:
+ drop_path = DropPath(drop_path_rates[i])
+ else:
+ drop_path = None
+ if (block_in_chs != out_chs) or (block_stride != 1):
+ proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
+ else:
+ proj_block = None
+
+ name = "b{}".format(i + 1)
+ self.add_module(
+ name, block_fn(
+ block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
+ downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
+ )
+
+ def forward(self, x):
+ for block in self.children():
+ x = block(x)
+ return x
+
+
+class RegNet(nn.Module):
+ """RegNet model.
+
+ Paper: https://arxiv.org/abs/2003.13678
+ Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
+ """
+
+ def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
+ drop_path_rate=0., zero_init_last_bn=True):
+ super().__init__()
+ # TODO add drop block, drop path, anti-aliasing, custom bn/act args
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ assert output_stride in (8, 16, 32)
+
+ # Construct the stem
+ stem_width = cfg['stem_width']
+ self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2)
+ self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
+
+ # Construct the stages
+ prev_width = stem_width
+ curr_stride = 2
+ stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
+ se_ratio = cfg['se_ratio']
+ for i, stage_args in enumerate(stage_params):
+ stage_name = "s{}".format(i + 1)
+ self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio))
+ prev_width = stage_args['out_chs']
+ curr_stride *= stage_args['stride']
+ self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
+
+ # Construct the head
+ self.num_features = prev_width
+ self.head = ClassifierHead(
+ in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, mean=0.0, std=0.01)
+ nn.init.zeros_(m.bias)
+ if zero_init_last_bn:
+ for m in self.modules():
+ if hasattr(m, 'zero_init_last_bn'):
+ m.zero_init_last_bn()
+
+ def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
+ # Generate RegNet ws per block
+ w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
+ widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
+
+ # Convert to per stage format
+ stage_widths, stage_depths = np.unique(widths, return_counts=True)
+
+ # Use the same group width, bottleneck mult and stride for each stage
+ stage_groups = [cfg['group_w'] for _ in range(num_stages)]
+ stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
+ stage_strides = []
+ stage_dilations = []
+ net_stride = 2
+ dilation = 1
+ for _ in range(num_stages):
+ if net_stride >= output_stride:
+ dilation *= default_stride
+ stride = 1
+ else:
+ stride = default_stride
+ net_stride *= stride
+ stage_strides.append(stride)
+ stage_dilations.append(dilation)
+ stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
+
+ # Adjust the compatibility of ws and gws
+ stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
+ param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates']
+ stage_params = [
+ dict(zip(param_names, params)) for params in
+ zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
+ stage_dpr)]
+ return stage_params
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ for block in list(self.children())[:-1]:
+ x = block(x)
+ return x
+
+ def forward(self, x):
+ for block in self.children():
+ x = block(x)
+ return x
+
+
+def _filter_fn(state_dict):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ if 'model' in state_dict:
+ # For DeiT-trained regnety_160 pretrained model
+ state_dict = state_dict['model']
+ return state_dict
+
+
+def _create_regnet(variant, pretrained, **kwargs):
+ return build_model_with_cfg(
+ RegNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=model_cfgs[variant],
+ pretrained_filter_fn=_filter_fn,
+ **kwargs)
+
+
+@register_model
+def regnetx_002(pretrained=False, **kwargs):
+ """RegNetX-200MF"""
+ return _create_regnet('regnetx_002', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_004(pretrained=False, **kwargs):
+ """RegNetX-400MF"""
+ return _create_regnet('regnetx_004', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_006(pretrained=False, **kwargs):
+ """RegNetX-600MF"""
+ return _create_regnet('regnetx_006', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_008(pretrained=False, **kwargs):
+ """RegNetX-800MF"""
+ return _create_regnet('regnetx_008', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_016(pretrained=False, **kwargs):
+ """RegNetX-1.6GF"""
+ return _create_regnet('regnetx_016', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_032(pretrained=False, **kwargs):
+ """RegNetX-3.2GF"""
+ return _create_regnet('regnetx_032', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_040(pretrained=False, **kwargs):
+ """RegNetX-4.0GF"""
+ return _create_regnet('regnetx_040', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_064(pretrained=False, **kwargs):
+ """RegNetX-6.4GF"""
+ return _create_regnet('regnetx_064', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_080(pretrained=False, **kwargs):
+ """RegNetX-8.0GF"""
+ return _create_regnet('regnetx_080', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_120(pretrained=False, **kwargs):
+ """RegNetX-12GF"""
+ return _create_regnet('regnetx_120', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_160(pretrained=False, **kwargs):
+ """RegNetX-16GF"""
+ return _create_regnet('regnetx_160', pretrained, **kwargs)
+
+
+@register_model
+def regnetx_320(pretrained=False, **kwargs):
+ """RegNetX-32GF"""
+ return _create_regnet('regnetx_320', pretrained, **kwargs)
+
+
+@register_model
+def regnety_002(pretrained=False, **kwargs):
+ """RegNetY-200MF"""
+ return _create_regnet('regnety_002', pretrained, **kwargs)
+
+
+@register_model
+def regnety_004(pretrained=False, **kwargs):
+ """RegNetY-400MF"""
+ return _create_regnet('regnety_004', pretrained, **kwargs)
+
+
+@register_model
+def regnety_006(pretrained=False, **kwargs):
+ """RegNetY-600MF"""
+ return _create_regnet('regnety_006', pretrained, **kwargs)
+
+
+@register_model
+def regnety_008(pretrained=False, **kwargs):
+ """RegNetY-800MF"""
+ return _create_regnet('regnety_008', pretrained, **kwargs)
+
+
+@register_model
+def regnety_016(pretrained=False, **kwargs):
+ """RegNetY-1.6GF"""
+ return _create_regnet('regnety_016', pretrained, **kwargs)
+
+
+@register_model
+def regnety_032(pretrained=False, **kwargs):
+ """RegNetY-3.2GF"""
+ return _create_regnet('regnety_032', pretrained, **kwargs)
+
+
+@register_model
+def regnety_040(pretrained=False, **kwargs):
+ """RegNetY-4.0GF"""
+ return _create_regnet('regnety_040', pretrained, **kwargs)
+
+
+@register_model
+def regnety_064(pretrained=False, **kwargs):
+ """RegNetY-6.4GF"""
+ return _create_regnet('regnety_064', pretrained, **kwargs)
+
+
+@register_model
+def regnety_080(pretrained=False, **kwargs):
+ """RegNetY-8.0GF"""
+ return _create_regnet('regnety_080', pretrained, **kwargs)
+
+
+@register_model
+def regnety_120(pretrained=False, **kwargs):
+ """RegNetY-12GF"""
+ return _create_regnet('regnety_120', pretrained, **kwargs)
+
+
+@register_model
+def regnety_160(pretrained=False, **kwargs):
+ """RegNetY-16GF"""
+ return _create_regnet('regnety_160', pretrained, **kwargs)
+
+
+@register_model
+def regnety_320(pretrained=False, **kwargs):
+ """RegNetY-32GF"""
+ return _create_regnet('regnety_320', pretrained, **kwargs)
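Behind the RegNet entrypoints above, the width schedule is produced by `generate_regnet` and then collapsed into per-stage (width, depth) pairs inside `RegNet._get_stage_params`. A small sketch of that intermediate step, assuming the module-level `generate_regnet` and `model_cfgs` shown above are importable:
```
import numpy as np
from timm.models.regnet import generate_regnet, model_cfgs

cfg = model_cfgs['regnetx_002']  # w0=24, wa=36.44, wm=2.49, depth=13
widths, num_stages, _, _ = generate_regnet(cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'])

print(widths)  # per-block widths, quantized to multiples of q=8
# np.unique gives the distinct stage widths and how many blocks each stage gets,
# mirroring the per-stage conversion done in RegNet._get_stage_params.
stage_widths, stage_depths = np.unique(widths, return_counts=True)
print(num_stages, stage_widths, stage_depths)
```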
diff --git a/timm/models/res2net.py b/timm/models/res2net.py
new file mode 100644
index 0000000..282baba
--- /dev/null
+++ b/timm/models/res2net.py
@@ -0,0 +1,216 @@
+""" Res2Net and Res2NeXt
+Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/
+Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
+"""
+import math
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .registry import register_model
+from .resnet import ResNet
+
+__all__ = []
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv1', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'res2net50_26w_4s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'),
+ 'res2net50_48w_2s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'),
+ 'res2net50_14w_8s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth'),
+ 'res2net50_26w_6s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth'),
+ 'res2net50_26w_8s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth'),
+ 'res2net101_26w_4s': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth'),
+ 'res2next50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth'),
+}
+
+
+class Bottle2neck(nn.Module):
+ """ Res2Net/Res2NeXT Bottleneck
+ Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py
+ """
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
+ act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
+ super(Bottle2neck, self).__init__()
+ self.scale = scale
+ self.is_first = stride > 1 or downsample is not None
+ self.num_scales = max(1, scale - 1)
+ width = int(math.floor(planes * (base_width / 64.0))) * cardinality
+ self.width = width
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+
+ self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
+ self.bn1 = norm_layer(width * scale)
+
+ convs = []
+ bns = []
+ for i in range(self.num_scales):
+ convs.append(nn.Conv2d(
+ width, width, kernel_size=3, stride=stride, padding=first_dilation,
+ dilation=first_dilation, groups=cardinality, bias=False))
+ bns.append(norm_layer(width))
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ if self.is_first:
+ # FIXME this should probably have count_include_pad=False, but hurts original weights
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
+ else:
+ self.pool = None
+
+ self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
+ self.bn3 = norm_layer(outplanes)
+ self.se = attn_layer(outplanes) if attn_layer is not None else None
+
+ self.relu = act_layer(inplace=True)
+ self.downsample = downsample
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.bn3.weight)
+
+ def forward(self, x):
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ spx = torch.split(out, self.width, 1)
+ spo = []
+ sp = spx[0] # redundant, for torchscript
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
+ if i == 0 or self.is_first:
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = conv(sp)
+ sp = bn(sp)
+ sp = self.relu(sp)
+ spo.append(sp)
+ if self.scale > 1:
+ if self.pool is not None:
+ # self.is_first == True, None check for torchscript
+ spo.append(self.pool(spx[-1]))
+ else:
+ spo.append(spx[-1])
+ out = torch.cat(spo, 1)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.se is not None:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out += shortcut
+ out = self.relu(out)
+
+ return out
+
+
+def _create_res2net(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ResNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def res2net50_26w_4s(pretrained=False, **kwargs):
+ """Constructs a Res2Net-50 26w4s model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
+ return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
+
+
+@register_model
+def res2net101_26w_4s(pretrained=False, **kwargs):
+ """Constructs a Res2Net-101 26w4s model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
+ return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
+
+
+@register_model
+def res2net50_26w_6s(pretrained=False, **kwargs):
+ """Constructs a Res2Net-50 26w6s model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
+ return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
+
+
+@register_model
+def res2net50_26w_8s(pretrained=False, **kwargs):
+ """Constructs a Res2Net-50 26w8s model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
+ return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
+
+
+@register_model
+def res2net50_48w_2s(pretrained=False, **kwargs):
+ """Constructs a Res2Net-50 48w2s model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
+ return _create_res2net('res2net50_48w_2s', pretrained, **model_args)
+
+
+@register_model
+def res2net50_14w_8s(pretrained=False, **kwargs):
+ """Constructs a Res2Net-50 14w8s model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
+ return _create_res2net('res2net50_14w_8s', pretrained, **model_args)
+
+
+@register_model
+def res2next50(pretrained=False, **kwargs):
+ """Construct Res2NeXt-50 4s
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model_args = dict(
+ block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
+ return _create_res2net('res2next50', pretrained, **model_args)
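The distinctive piece of `Bottle2neck` above is the channel split in `forward`: the pointwise output is split into `scale` groups of `width` channels, and each 3x3 conv receives its own group plus the previous group's output, so later groups see progressively larger receptive fields. A toy, self-contained sketch of that data flow (stride-1, non-first block, no pooling branch):
```
import torch
import torch.nn as nn

width, scale = 4, 4
x = torch.randn(1, width * scale, 8, 8)

spx = torch.split(x, width, dim=1)  # `scale` chunks of `width` channels each
convs = nn.ModuleList([nn.Conv2d(width, width, 3, padding=1) for _ in range(scale - 1)])

sp, outs = None, []
for i, conv in enumerate(convs):
    sp = spx[i] if i == 0 else sp + spx[i]  # hierarchical residual connection between groups
    sp = conv(sp)
    outs.append(sp)
outs.append(spx[-1])          # last group passes through untouched when scale > 1
out = torch.cat(outs, dim=1)  # back to width * scale channels for the final 1x1 conv
print(out.shape)              # torch.Size([1, 16, 8, 8])
```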
diff --git a/timm/models/resnest.py b/timm/models/resnest.py
new file mode 100644
index 0000000..31eebd8
--- /dev/null
+++ b/timm/models/resnest.py
@@ -0,0 +1,237 @@
+""" ResNeSt Models
+
+Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
+
+Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang
+
+Modified for torchscript compat, and consistency with timm by Ross Wightman
+"""
+import torch
+from torch import nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import SplitAttn
+from .registry import register_model
+from .resnet import ResNet
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv1.0', 'classifier': 'fc',
+ **kwargs
+ }
+
+default_cfgs = {
+ 'resnest14d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'),
+ 'resnest26d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'),
+ 'resnest50d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth'),
+ 'resnest101e': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8)),
+ 'resnest200e': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth',
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.909, interpolation='bicubic'),
+ 'resnest269e': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth',
+ input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.928, interpolation='bicubic'),
+ 'resnest50d_4s2x40d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth',
+ interpolation='bicubic'),
+ 'resnest50d_1s4x24d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth',
+ interpolation='bicubic')
+}
+
+
+class ResNestBottleneck(nn.Module):
+ """ResNet Bottleneck
+ """
+ # pylint: disable=unused-argument
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
+ reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+ attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+ super(ResNestBottleneck, self).__init__()
+ assert reduce_first == 1 # not supported
+ assert attn_layer is None # not supported
+ assert aa_layer is None # TODO not yet supported
+ assert drop_path is None # TODO not yet supported
+
+ group_width = int(planes * (base_width / 64.)) * cardinality
+ first_dilation = first_dilation or dilation
+ if avd and (stride > 1 or is_first):
+ avd_stride = stride
+ stride = 1
+ else:
+ avd_stride = 0
+ self.radix = radix
+ self.drop_block = drop_block
+
+ self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
+ self.bn1 = norm_layer(group_width)
+ self.act1 = act_layer(inplace=True)
+ self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
+
+ if self.radix >= 1:
+ self.conv2 = SplitAttn(
+ group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
+ dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
+ self.bn2 = nn.Identity()
+ self.act2 = nn.Identity()
+ else:
+ self.conv2 = nn.Conv2d(
+ group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
+ dilation=first_dilation, groups=cardinality, bias=False)
+ self.bn2 = norm_layer(group_width)
+ self.act2 = act_layer(inplace=True)
+ self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
+
+ self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = norm_layer(planes*4)
+ self.act3 = act_layer(inplace=True)
+ self.downsample = downsample
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.bn3.weight)
+
+ def forward(self, x):
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ if self.drop_block is not None:
+ out = self.drop_block(out)
+ out = self.act1(out)
+
+ if self.avd_first is not None:
+ out = self.avd_first(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.drop_block is not None:
+ out = self.drop_block(out)
+ out = self.act2(out)
+
+ if self.avd_last is not None:
+ out = self.avd_last(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+ if self.drop_block is not None:
+ out = self.drop_block(out)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out += shortcut
+ out = self.act3(out)
+ return out
+
+
+def _create_resnest(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ResNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def resnest14d(pretrained=False, **kwargs):
+ """ ResNeSt-14d model. Weights ported from GluonCV.
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[1, 1, 1, 1],
+ stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
+ block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
+ return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest26d(pretrained=False, **kwargs):
+ """ ResNeSt-26d model. Weights ported from GluonCV.
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[2, 2, 2, 2],
+ stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
+ block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
+ return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest50d(pretrained=False, **kwargs):
+ """ ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
+ Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[3, 4, 6, 3],
+ stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
+ block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
+ return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest101e(pretrained=False, **kwargs):
+ """ ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
+ Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[3, 4, 23, 3],
+ stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
+ block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
+ return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest200e(pretrained=False, **kwargs):
+ """ ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
+ Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[3, 24, 36, 3],
+ stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
+ block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
+ return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest269e(pretrained=False, **kwargs):
+ """ ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
+ Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[3, 30, 48, 8],
+ stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
+ block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
+ return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest50d_4s2x40d(pretrained=False, **kwargs):
+ """ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[3, 4, 6, 3],
+ stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
+ block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
+ return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def resnest50d_1s4x24d(pretrained=False, **kwargs):
+ """ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
+ """
+ model_kwargs = dict(
+ block=ResNestBottleneck, layers=[3, 4, 6, 3],
+ stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
+ block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
+ return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
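As with the other families in this tree, the ResNeSt entrypoints build on the shared `ResNet` skeleton, passing the split-attention specifics (radix/avd/avd_first) through `block_args`. A quick smoke-test sketch, assuming the usual timm `ResNet` kwargs such as `num_classes` pass through unchanged:
```
import torch
from timm.models.resnest import resnest50d

model = resnest50d(pretrained=False, num_classes=10)  # no weight download
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([1, 10])
```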
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
new file mode 100644
index 0000000..bbcae9a
--- /dev/null
+++ b/timm/models/resnet.py
@@ -0,0 +1,1472 @@
+"""PyTorch ResNet
+
+This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
+additional dropout and dynamic global avg/max pool.
+
+ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
+Copyright 2020 Ross Wightman
+"""
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, create_classifier
+from .registry import register_model
+
+__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv1', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # ResNet and Wide ResNet
+ 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
+ 'resnet18d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnet34': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
+ 'resnet34d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnet26': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth',
+ interpolation='bicubic'),
+ 'resnet26d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnet26t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
+ 'resnet50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth',
+ interpolation='bicubic', crop_pct=0.95),
+ 'resnet50d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnet50t': _cfg(
+ url='',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnet101': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth',
+ interpolation='bicubic', crop_pct=0.95),
+ 'resnet101d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
+ crop_pct=1.0, test_input_size=(3, 320, 320)),
+ 'resnet152': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1h-dc400468.pth',
+ interpolation='bicubic', crop_pct=0.95),
+ 'resnet152d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
+ crop_pct=1.0, test_input_size=(3, 320, 320)),
+ 'resnet200': _cfg(url='', interpolation='bicubic'),
+ 'resnet200d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
+ crop_pct=1.0, test_input_size=(3, 320, 320)),
+ 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
+ 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
+ 'tv_resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
+ 'tv_resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
+ 'wide_resnet50_2': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth',
+ interpolation='bicubic'),
+ 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
+
+ # ResNets w/ alternative norm layers
+ 'resnet50_gn': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth',
+ crop_pct=0.94, interpolation='bicubic'),
+
+ # ResNeXt
+ 'resnext50_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth',
+ interpolation='bicubic', crop_pct=0.95),
+ 'resnext50d_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'resnext101_32x4d': _cfg(url=''),
+ 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
+ 'resnext101_64x4d': _cfg(url=''),
+ 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
+
+ # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags
+ # from https://github.com/facebookresearch/WSL-Images
+ # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
+ 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
+ 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
+ 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),
+ 'ig_resnext101_32x48d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth'),
+
+ # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
+ # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
+ 'ssl_resnet18': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth'),
+ 'ssl_resnet50': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth'),
+ 'ssl_resnext50_32x4d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth'),
+ 'ssl_resnext101_32x4d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth'),
+ 'ssl_resnext101_32x8d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth'),
+ 'ssl_resnext101_32x16d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth'),
+
+ # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
+ # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
+ 'swsl_resnet18': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth'),
+ 'swsl_resnet50': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth'),
+ 'swsl_resnext50_32x4d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth'),
+ 'swsl_resnext101_32x4d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth'),
+ 'swsl_resnext101_32x8d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'),
+ 'swsl_resnext101_32x16d': _cfg(
+ url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
+
+ # Squeeze-Excitation ResNets, to eventually replace the models in senet.py
+ 'seresnet18': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'seresnet34': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'seresnet50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth',
+ interpolation='bicubic'),
+ 'seresnet50t': _cfg(
+ url='',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'seresnet101': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'seresnet152': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'seresnet152d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
+ crop_pct=1.0, test_input_size=(3, 320, 320)
+ ),
+ 'seresnet200d': _cfg(
+ url='',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
+ 'seresnet269d': _cfg(
+ url='',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
+
+
+ # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
+ 'seresnext26d_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'seresnext26t_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'seresnext50_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth',
+ interpolation='bicubic'),
+ 'seresnext101_32x4d': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'seresnext101_32x8d': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'senet154': _cfg(
+ url='',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+
+ # Efficient Channel Attention ResNets
+ 'ecaresnet26t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
+ crop_pct=0.95, test_input_size=(3, 320, 320)),
+ 'ecaresnetlight': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
+ interpolation='bicubic'),
+ 'ecaresnet50d': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'ecaresnet50d_pruned': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'ecaresnet50t': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
+ crop_pct=0.95, test_input_size=(3, 320, 320)),
+ 'ecaresnet101d': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'ecaresnet101d_pruned': _cfg(
+ url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
+ interpolation='bicubic',
+ first_conv='conv1.0'),
+ 'ecaresnet200d': _cfg(
+ url='',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
+ 'ecaresnet269d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
+ interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10),
+ crop_pct=1.0, test_input_size=(3, 352, 352)),
+
+ # Efficient Channel Attention ResNeXts
+ 'ecaresnext26t_32x4d': _cfg(
+ url='',
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'ecaresnext50t_32x4d': _cfg(
+ url='',
+ interpolation='bicubic', first_conv='conv1.0'),
+
+ # ResNets with anti-aliasing blur pool
+ 'resnetblur18': _cfg(
+ interpolation='bicubic'),
+ 'resnetblur50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
+ interpolation='bicubic'),
+
+ # ResNet-RS models
+ 'resnetrs50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
+ input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnetrs101': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
+ input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnetrs152': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnetrs200': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs200_ema-623d2f59.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnetrs270': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
+ input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnetrs350': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
+ input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
+ interpolation='bicubic', first_conv='conv1.0'),
+ 'resnetrs420': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
+ interpolation='bicubic', first_conv='conv1.0'),
+}
+
+
+def get_padding(kernel_size, stride, dilation=1):
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding
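+# For reference: get_padding(3, 1) == 1 and get_padding(7, 2) == 3, i.e. the padding that keeps a
+# conv 'same'-sized up to its stride; downsample_conv() further below relies on this helper.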
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+ reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+ attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+ super(BasicBlock, self).__init__()
+
+ assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
+ assert base_width == 64, 'BasicBlock does not support changing base width'
+ first_planes = planes // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+ use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
+
+ self.conv1 = nn.Conv2d(
+ inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
+ dilation=first_dilation, bias=False)
+ self.bn1 = norm_layer(first_planes)
+ self.act1 = act_layer(inplace=True)
+ self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
+
+ self.conv2 = nn.Conv2d(
+ first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
+ self.bn2 = norm_layer(outplanes)
+
+ self.se = create_attn(attn_layer, outplanes)
+
+ self.act2 = act_layer(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.drop_block = drop_block
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.bn2.weight)
+
+ def forward(self, x):
+ shortcut = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act1(x)
+ if self.aa is not None:
+ x = self.aa(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+
+ if self.se is not None:
+ x = self.se(x)
+
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act2(x)
+
+ return x
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+ reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
+ attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+ super(Bottleneck, self).__init__()
+
+ width = int(math.floor(planes * (base_width / 64)) * cardinality)
+ first_planes = width // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+ use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
+
+ self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
+ self.bn1 = norm_layer(first_planes)
+ self.act1 = act_layer(inplace=True)
+
+ self.conv2 = nn.Conv2d(
+ first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
+ padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
+ self.bn2 = norm_layer(width)
+ self.act2 = act_layer(inplace=True)
+ self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
+
+ self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
+ self.bn3 = norm_layer(outplanes)
+
+ self.se = create_attn(attn_layer, outplanes)
+
+ self.act3 = act_layer(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.drop_block = drop_block
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.bn3.weight)
+
+ def forward(self, x):
+ shortcut = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act2(x)
+ if self.aa is not None:
+ x = self.aa(x)
+
+ x = self.conv3(x)
+ x = self.bn3(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+
+ if self.se is not None:
+ x = self.se(x)
+
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act3(x)
+
+ return x
+
+
+def downsample_conv(
+ in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+ norm_layer = norm_layer or nn.BatchNorm2d
+ kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
+ first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
+ p = get_padding(kernel_size, stride, first_dilation)
+
+ return nn.Sequential(*[
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False),
+ norm_layer(out_channels)
+ ])
+
+
+def downsample_avg(
+ in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+ norm_layer = norm_layer or nn.BatchNorm2d
+ avg_stride = stride if dilation == 1 else 1
+ if stride == 1 and dilation == 1:
+ pool = nn.Identity()
+ else:
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+
+ return nn.Sequential(*[
+ pool,
+ nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
+ norm_layer(out_channels)
+ ])
+
+
+def drop_blocks(drop_block_rate=0.):
+ return [
+ None, None,
+ DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
+ DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
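+# i.e. with a non-zero drop_block_rate, DropBlock regularization is only enabled for the last two
+# stages (layer3/layer4); the first two stages always receive None.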
+
+
+def make_blocks(
+ block_fn, channels, block_repeats, inplanes, reduce_first=1, output_stride=32,
+ down_kernel_size=1, avg_down=False, drop_block_rate=0., drop_path_rate=0., **kwargs):
+ stages = []
+ feature_info = []
+ net_num_blocks = sum(block_repeats)
+ net_block_idx = 0
+ net_stride = 4
+ dilation = prev_dilation = 1
+ for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+ stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
+ stride = 1 if stage_idx == 0 else 2
+ if net_stride >= output_stride:
+ dilation *= stride
+ stride = 1
+ else:
+ net_stride *= stride
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block_fn.expansion:
+ down_kwargs = dict(
+ in_channels=inplanes, out_channels=planes * block_fn.expansion, kernel_size=down_kernel_size,
+ stride=stride, dilation=dilation, first_dilation=prev_dilation, norm_layer=kwargs.get('norm_layer'))
+ downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
+
+ block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
+ blocks = []
+ for block_idx in range(num_blocks):
+ downsample = downsample if block_idx == 0 else None
+ stride = stride if block_idx == 0 else 1
+ block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
+ blocks.append(block_fn(
+ inplanes, planes, stride, downsample, first_dilation=prev_dilation,
+ drop_path=DropPath(block_dpr) if block_dpr > 0. else None, **block_kwargs))
+ prev_dilation = dilation
+ inplanes = planes * block_fn.expansion
+ net_block_idx += 1
+
+ stages.append((stage_name, nn.Sequential(*blocks)))
+ feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+ return stages, feature_info
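+# Note on the decay rule above: for a ResNet-50 schedule (block_repeats=[3, 4, 6, 3], 16 blocks) with
+# drop_path_rate=0.1, block_dpr ramps linearly from 0.0 at the first block to 0.1 at the last one.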
+
+
+class ResNet(nn.Module):
+ """ResNet / ResNeXt / SE-ResNeXt / SE-Net
+
+ This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
+ * have > 1 stride in the 3x3 conv layer of bottleneck
+ * have conv-bn-act ordering
+
+ This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
+ variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
+ 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
+
+ ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
+ * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
+ * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
+ * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
+ * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
+ * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
+ * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
+ * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
+
+ ResNeXt
+ * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
+ * same c, d, e, s variants as ResNet can be enabled
+
+ SE-ResNeXt
+ * normal - 7x7 stem, stem_width = 64
+ * same c, d, e, s variants as ResNet can be enabled
+
+ SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
+ reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
+
+ Parameters
+ ----------
+ block : Block
+ Class for the residual block. Options are BasicBlock, Bottleneck.
+ layers : list of int
+ Numbers of layers in each block
+ num_classes : int, default 1000
+ Number of classification classes.
+ in_chans : int, default 3
+ Number of input (color) channels.
+ cardinality : int, default 1
+ Number of convolution groups for 3x3 conv in Bottleneck.
+ base_width : int, default 64
+ Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
+ stem_width : int, default 64
+ Number of channels in stem convolutions
+ stem_type : str, default ''
+ The type of stem:
+ * '', default - a single 7x7 conv with a width of stem_width
+ * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
+ * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
+ block_reduce_first: int, default 1
+ Reduction factor for first convolution output width of residual blocks,
+ 1 for all archs except SENets, where it is 2.
+ down_kernel_size: int, default 1
+ Kernel size of residual block downsampling path, 1x1 for most archs, 3x3 for senets
+ avg_down : bool, default False
+ Whether to use average pooling for projection skip connection between stages/downsample.
+ output_stride : int, default 32
+ Set the output stride of the network, 32, 16, or 8. Typically used in segmentation.
+ act_layer : nn.Module, activation layer
+ norm_layer : nn.Module, normalization layer
+ aa_layer : nn.Module, anti-aliasing layer
+ drop_rate : float, default 0.
+ Dropout probability before classifier, for training
+ global_pool : str, default 'avg'
+ Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+ """
+
+ def __init__(self, block, layers, num_classes=1000, in_chans=3,
+ cardinality=1, base_width=64, stem_width=64, stem_type='', replace_stem_pool=False,
+ output_stride=32, block_reduce_first=1, down_kernel_size=1, avg_down=False,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_rate=0.0, drop_path_rate=0.,
+ drop_block_rate=0., global_pool='avg', zero_init_last_bn=True, block_args=None):
+ block_args = block_args or dict()
+ assert output_stride in (8, 16, 32)
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ super(ResNet, self).__init__()
+
+ # Stem
+ deep_stem = 'deep' in stem_type
+ inplanes = stem_width * 2 if deep_stem else 64
+ if deep_stem:
+ stem_chs = (stem_width, stem_width)
+ if 'tiered' in stem_type:
+ stem_chs = (3 * (stem_width // 4), stem_width)
+ self.conv1 = nn.Sequential(*[
+ nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
+ norm_layer(stem_chs[0]),
+ act_layer(inplace=True),
+ nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
+ norm_layer(stem_chs[1]),
+ act_layer(inplace=True),
+ nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)])
+ else:
+ self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = norm_layer(inplanes)
+ self.act1 = act_layer(inplace=True)
+ self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
+
+ # Stem Pooling
+ if replace_stem_pool:
+ self.maxpool = nn.Sequential(*filter(None, [
+ nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
+ aa_layer(channels=inplanes, stride=2) if aa_layer else None,
+ norm_layer(inplanes),
+ act_layer(inplace=True)
+ ]))
+ else:
+ if aa_layer is not None:
+ self.maxpool = nn.Sequential(*[
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+ aa_layer(channels=inplanes, stride=2)])
+ else:
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ # Feature Blocks
+ channels = [64, 128, 256, 512]
+ stage_modules, stage_feature_info = make_blocks(
+ block, channels, layers, inplanes, cardinality=cardinality, base_width=base_width,
+ output_stride=output_stride, reduce_first=block_reduce_first, avg_down=avg_down,
+ down_kernel_size=down_kernel_size, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer,
+ drop_block_rate=drop_block_rate, drop_path_rate=drop_path_rate, **block_args)
+ for stage in stage_modules:
+ self.add_module(*stage) # layer1, layer2, etc
+ self.feature_info.extend(stage_feature_info)
+
+ # Head (Pooling and Classifier)
+ self.num_features = 512 * block.expansion
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ self.init_weights(zero_init_last_bn=zero_init_last_bn)
+
+ def init_weights(self, zero_init_last_bn=True):
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ if zero_init_last_bn:
+ for m in self.modules():
+ if hasattr(m, 'zero_init_last_bn'):
+ m.zero_init_last_bn()
+
+ def get_classifier(self):
+ return self.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate:
+ x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+ x = self.fc(x)
+ return x
+
+
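+# Minimal usage sketch of the bare class (names from this file only): ResNet(Bottleneck, [3, 4, 6, 3])
+# builds a plain ResNet-50 trunk plus classifier, and forward() returns (N, num_classes) logits with the
+# default arguments; the registered variants below just wrap this constructor with their specific
+# stem/downsample/attention arguments.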
+def _create_resnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ResNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
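+# Each @register_model entry point below adds its name to timm's model registry, so, assuming the
+# usual upstream timm package layout, e.g. timm.create_model('resnet50d', pretrained=False) routes
+# through the resnet50d() entry point into _create_resnet() with the matching default_cfg above.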
+@register_model
+def resnet18(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model.
+ """
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
+ return _create_resnet('resnet18', pretrained, **model_args)
+
+
+@register_model
+def resnet18d(pretrained=False, **kwargs):
+ """Constructs a ResNet-18-D model.
+ """
+ model_args = dict(
+ block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet18d', pretrained, **model_args)
+
+
+@register_model
+def resnet34(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model.
+ """
+ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('resnet34', pretrained, **model_args)
+
+
+@register_model
+def resnet34d(pretrained=False, **kwargs):
+ """Constructs a ResNet-34-D model.
+ """
+ model_args = dict(
+ block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet34d', pretrained, **model_args)
+
+
+@register_model
+def resnet26(pretrained=False, **kwargs):
+ """Constructs a ResNet-26 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
+ return _create_resnet('resnet26', pretrained, **model_args)
+
+
+@register_model
+def resnet26t(pretrained=False, **kwargs):
+ """Constructs a ResNet-26-T model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
+ return _create_resnet('resnet26t', pretrained, **model_args)
+
+
+@register_model
+def resnet26d(pretrained=False, **kwargs):
+ """Constructs a ResNet-26-D model.
+ """
+ model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet26d', pretrained, **model_args)
+
+
+@register_model
+def resnet50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('resnet50', pretrained, **model_args)
+
+
+@register_model
+def resnet50d(pretrained=False, **kwargs):
+ """Constructs a ResNet-50-D model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet50d', pretrained, **model_args)
+
+
+@register_model
+def resnet50t(pretrained=False, **kwargs):
+ """Constructs a ResNet-50-T model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
+ return _create_resnet('resnet50t', pretrained, **model_args)
+
+
+@register_model
+def resnet101(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
+ return _create_resnet('resnet101', pretrained, **model_args)
+
+
+@register_model
+def resnet101d(pretrained=False, **kwargs):
+ """Constructs a ResNet-101-D model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet101d', pretrained, **model_args)
+
+
+@register_model
+def resnet152(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
+ return _create_resnet('resnet152', pretrained, **model_args)
+
+
+@register_model
+def resnet152d(pretrained=False, **kwargs):
+ """Constructs a ResNet-152-D model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet152d', pretrained, **model_args)
+
+
+@register_model
+def resnet200(pretrained=False, **kwargs):
+ """Constructs a ResNet-200 model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs)
+ return _create_resnet('resnet200', pretrained, **model_args)
+
+
+@register_model
+def resnet200d(pretrained=False, **kwargs):
+ """Constructs a ResNet-200-D model.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnet200d', pretrained, **model_args)
+
+
+@register_model
+def tv_resnet34(pretrained=False, **kwargs):
+ """Constructs a ResNet-34 model with original Torchvision weights.
+ """
+ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('tv_resnet34', pretrained, **model_args)
+
+
+@register_model
+def tv_resnet50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model with original Torchvision weights.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('tv_resnet50', pretrained, **model_args)
+
+
+@register_model
+def tv_resnet101(pretrained=False, **kwargs):
+ """Constructs a ResNet-101 model w/ Torchvision pretrained weights.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
+ return _create_resnet('tv_resnet101', pretrained, **model_args)
+
+
+@register_model
+def tv_resnet152(pretrained=False, **kwargs):
+ """Constructs a ResNet-152 model w/ Torchvision pretrained weights.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
+ return _create_resnet('tv_resnet152', pretrained, **model_args)
+
+
+@register_model
+def wide_resnet50_2(pretrained=False, **kwargs):
+ """Constructs a Wide ResNet-50-2 model.
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)
+ return _create_resnet('wide_resnet50_2', pretrained, **model_args)
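+# With base_width=128, the Bottleneck width computation works out to int(planes * 128 / 64) * 1
+# = 2 * planes, which is exactly the doubled bottleneck width described in the docstring above.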
+
+
+@register_model
+def wide_resnet101_2(pretrained=False, **kwargs):
+ """Constructs a Wide ResNet-101-2 model.
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs)
+ return _create_resnet('wide_resnet101_2', pretrained, **model_args)
+
+
+@register_model
+def resnet50_gn(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model w/ GroupNorm
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('resnet50_gn', pretrained, norm_layer=GroupNorm, **model_args)
+
+
+@register_model
+def resnext50_32x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt50-32x4d model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('resnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def resnext50d_32x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+ stem_width=32, stem_type='deep', avg_down=True, **kwargs)
+ return _create_resnet('resnext50d_32x4d', pretrained, **model_args)
+
+
+@register_model
+def resnext101_32x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt-101 32x4d model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('resnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def resnext101_32x8d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt-101 32x8d model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
+ return _create_resnet('resnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def resnext101_64x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt101-64x4d model.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
+ return _create_resnet('resnext101_64x4d', pretrained, **model_args)
+
+
+@register_model
+def tv_resnext50_32x4d(pretrained=False, **kwargs):
+ """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def ig_resnext101_32x8d(pretrained=True, **kwargs):
+ """Constructs a ResNeXt-101 32x8 model pre-trained on weakly-supervised data
+ and finetuned on ImageNet from Figure 5 in
+ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
+ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
+ return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def ig_resnext101_32x16d(pretrained=True, **kwargs):
+ """Constructs a ResNeXt-101 32x16 model pre-trained on weakly-supervised data
+ and finetuned on ImageNet from Figure 5 in
+ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
+ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
+ return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args)
+
+
+@register_model
+def ig_resnext101_32x32d(pretrained=True, **kwargs):
+ """Constructs a ResNeXt-101 32x32 model pre-trained on weakly-supervised data
+ and finetuned on ImageNet from Figure 5 in
+ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
+ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
+ return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args)
+
+
+@register_model
+def ig_resnext101_32x48d(pretrained=True, **kwargs):
+ """Constructs a ResNeXt-101 32x48 model pre-trained on weakly-supervised data
+ and finetuned on ImageNet from Figure 5 in
+ `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
+ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
+ return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args)
+
+
+@register_model
+def ssl_resnet18(pretrained=True, **kwargs):
+ """Constructs a semi-supervised ResNet-18 model pre-trained on YFCC100M dataset and finetuned on ImageNet
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
+ return _create_resnet('ssl_resnet18', pretrained, **model_args)
+
+
+@register_model
+def ssl_resnet50(pretrained=True, **kwargs):
+ """Constructs a semi-supervised ResNet-50 model pre-trained on YFCC100M dataset and finetuned on ImageNet
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('ssl_resnet50', pretrained, **model_args)
+
+
+@register_model
+def ssl_resnext50_32x4d(pretrained=True, **kwargs):
+ """Constructs a semi-supervised ResNeXt-50 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def ssl_resnext101_32x4d(pretrained=True, **kwargs):
+ """Constructs a semi-supervised ResNeXt-101 32x4 model pre-trained on YFCC100M dataset and finetuned on ImageNet
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def ssl_resnext101_32x8d(pretrained=True, **kwargs):
+ """Constructs a semi-supervised ResNeXt-101 32x8 model pre-trained on YFCC100M dataset and finetuned on ImageNet
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
+ return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def ssl_resnext101_32x16d(pretrained=True, **kwargs):
+ """Constructs a semi-supervised ResNeXt-101 32x16 model pre-trained on YFCC100M dataset and finetuned on ImageNet
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
+ return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args)
+
+
+@register_model
+def swsl_resnet18(pretrained=True, **kwargs):
+ """Constructs a semi-weakly supervised ResNet-18 model pre-trained on 1B weakly supervised
+ image dataset and finetuned on ImageNet.
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
+ return _create_resnet('swsl_resnet18', pretrained, **model_args)
+
+
+@register_model
+def swsl_resnet50(pretrained=True, **kwargs):
+ """Constructs a semi-weakly supervised ResNet-50 model pre-trained on 1B weakly supervised
+ image dataset and finetuned on ImageNet.
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
+ return _create_resnet('swsl_resnet50', pretrained, **model_args)
+
+
+@register_model
+def swsl_resnext50_32x4d(pretrained=True, **kwargs):
+ """Constructs a semi-weakly supervised ResNeXt-50 32x4 model pre-trained on 1B weakly supervised
+ image dataset and finetuned on ImageNet.
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def swsl_resnext101_32x4d(pretrained=True, **kwargs):
+ """Constructs a semi-weakly supervised ResNeXt-101 32x4 model pre-trained on 1B weakly supervised
+ image dataset and finetuned on ImageNet.
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
+ return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def swsl_resnext101_32x8d(pretrained=True, **kwargs):
+ """Constructs a semi-weakly supervised ResNeXt-101 32x8 model pre-trained on 1B weakly supervised
+ image dataset and finetuned on ImageNet.
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
+ return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def swsl_resnext101_32x16d(pretrained=True, **kwargs):
+ """Constructs a semi-weakly supervised ResNeXt-101 32x16 model pre-trained on 1B weakly supervised
+ image dataset and finetuned on ImageNet.
+ `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
+ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
+ return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
+
+
+@register_model
+def ecaresnet26t(pretrained=False, **kwargs):
+ """Constructs an ECA-ResNet-26-T model.
+ This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+ in the deep stem and ECA attn.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
+ stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet26t', pretrained, **model_args)
+
+
+@register_model
+def ecaresnet50d(pretrained=False, **kwargs):
+ """Constructs a ResNet-50-D model with eca.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet50d', pretrained, **model_args)
+
+
+@register_model
+def resnetrs50(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-50 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs50', pretrained, **model_args)
+
+
+@register_model
+def resnetrs101(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-101 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs101', pretrained, **model_args)
+
+
+@register_model
+def resnetrs152(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-152 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs152', pretrained, **model_args)
+
+
+@register_model
+def resnetrs200(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-200 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs200', pretrained, **model_args)
+
+
+@register_model
+def resnetrs270(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-270 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs270', pretrained, **model_args)
+
+
+@register_model
+def resnetrs350(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-350 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs350', pretrained, **model_args)
+
+
+@register_model
+def resnetrs420(pretrained=False, **kwargs):
+ """Constructs a ResNet-RS-420 model.
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
+ Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
+ """
+ attn_layer = partial(get_attn('se'), rd_ratio=0.25)
+ model_args = dict(
+ block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
+ avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
+ return _create_resnet('resnetrs420', pretrained, **model_args)
+
+
+@register_model
+def ecaresnet50d_pruned(pretrained=False, **kwargs):
+ """Constructs a ResNet-50-D model pruned with eca.
+ The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
+
+
+@register_model
+def ecaresnet50t(pretrained=False, **kwargs):
+ """Constructs an ECA-ResNet-50-T model.
+ Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32,
+ stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet50t', pretrained, **model_args)
+
+
+@register_model
+def ecaresnetlight(pretrained=False, **kwargs):
+ """Constructs a ResNet-50-D light model with eca.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnetlight', pretrained, **model_args)
+
+
+@register_model
+def ecaresnet101d(pretrained=False, **kwargs):
+ """Constructs a ResNet-101-D model with eca.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet101d', pretrained, **model_args)
+
+
+@register_model
+def ecaresnet101d_pruned(pretrained=False, **kwargs):
+ """Constructs a ResNet-101-D model pruned with eca.
+ The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
+
+
+@register_model
+def ecaresnet200d(pretrained=False, **kwargs):
+ """Constructs a ResNet-200-D model with ECA.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet200d', pretrained, **model_args)
+
+
+@register_model
+def ecaresnet269d(pretrained=False, **kwargs):
+ """Constructs a ResNet-269-D model with ECA.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnet269d', pretrained, **model_args)
+
+
+@register_model
+def ecaresnext26t_32x4d(pretrained=False, **kwargs):
+ """Constructs an ECA-ResNeXt-26-T model.
+ This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+ in the deep stem. This model replaces the SE module with the ECA module.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+ stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args)
+
+
+@register_model
+def ecaresnext50t_32x4d(pretrained=False, **kwargs):
+ """Constructs an ECA-ResNeXt-50-T model.
+ This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+ in the deep stem. This model replaces the SE module with the ECA module.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+ stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
+ return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args)
+
+
+@register_model
+def resnetblur18(pretrained=False, **kwargs):
+ """Constructs a ResNet-18 model with blur anti-aliasing
+ """
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
+ return _create_resnet('resnetblur18', pretrained, **model_args)
+
+
+@register_model
+def resnetblur50(pretrained=False, **kwargs):
+ """Constructs a ResNet-50 model with blur anti-aliasing
+ """
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
+ return _create_resnet('resnetblur50', pretrained, **model_args)
+
+
+@register_model
+def seresnet18(pretrained=False, **kwargs):
+ model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet18', pretrained, **model_args)
+
+
+@register_model
+def seresnet34(pretrained=False, **kwargs):
+ model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet34', pretrained, **model_args)
+
+
+@register_model
+def seresnet50(pretrained=False, **kwargs):
+ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet50', pretrained, **model_args)
+
+
+@register_model
+def seresnet50t(pretrained=False, **kwargs):
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet50t', pretrained, **model_args)
+
+
+@register_model
+def seresnet101(pretrained=False, **kwargs):
+ model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet101', pretrained, **model_args)
+
+
+@register_model
+def seresnet152(pretrained=False, **kwargs):
+ model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet152', pretrained, **model_args)
+
+
+@register_model
+def seresnet152d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet152d', pretrained, **model_args)
+
+
+@register_model
+def seresnet200d(pretrained=False, **kwargs):
+ """Constructs a ResNet-200-D model with SE attn.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet200d', pretrained, **model_args)
+
+
+@register_model
+def seresnet269d(pretrained=False, **kwargs):
+ """Constructs a ResNet-269-D model with SE attn.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnet269d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26d_32x4d(pretrained=False, **kwargs):
+ """Constructs a SE-ResNeXt-26-D model.
+ This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
+ combination of deep stem and avg_pool in downsample.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+ stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26t_32x4d(pretrained=False, **kwargs):
+ """Constructs a SE-ResNeXt-26-T model.
+ This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+ in the deep stem.
+ """
+ model_args = dict(
+ block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
+ stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext26tn_32x4d(pretrained=False, **kwargs):
+ """Constructs a SE-ResNeXt-26-T model.
+ NOTE I deprecated previous 't' model defs and replaced 't' with 'tn', this was the only tn model of note
+ so keeping this def for backwards compat with any uses out there. Old 't' model is lost.
+ """
+ return seresnext26t_32x4d(pretrained=pretrained, **kwargs)
+
+
+@register_model
+def seresnext50_32x4d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext101_32x4d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnext101_32x4d', pretrained, **model_args)
+
+
+@register_model
+def seresnext101_32x8d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
+ block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('seresnext101_32x8d', pretrained, **model_args)
+
+
+@register_model
+def senet154(pretrained=False, **kwargs):
+ model_args = dict(
+ block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
+ down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs)
+ return _create_resnet('senet154', pretrained, **model_args)
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
new file mode 100644
index 0000000..e38eaf5
--- /dev/null
+++ b/timm/models/resnetv2.py
@@ -0,0 +1,672 @@
+"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
+
+A PyTorch implementation of ResNetV2 adapted from the Google Big Transfer (BiT) source code
+at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
+been included here as pretrained models from their original .NPZ checkpoints.
+
+Additionally, supports non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
+extra padding support to allow porting of official Hybrid ResNet pretrained weights from
+https://github.com/google-research/vision_transformer
+
+Thanks to the Google team for the above two repositories and associated papers:
+* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
+* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
+* Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+
+Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
+"""
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict # pylint: disable=g-importing-member
+
+import torch
+import torch.nn as nn
+from functools import partial
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
+from .registry import register_model
+from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
+ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'stem.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
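+# e.g. _cfg(url='...', num_classes=21843) keeps the defaults above and only overrides num_classes,
+# which is how the imagenet-21k entries below describe their 21843-way classifier head.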
+
+
+default_cfgs = {
+ # pretrained on imagenet21k, finetuned on imagenet1k
+ 'resnetv2_50x1_bitm': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz',
+ input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+ 'resnetv2_50x3_bitm': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz',
+ input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+ 'resnetv2_101x1_bitm': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz',
+ input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+ 'resnetv2_101x3_bitm': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz',
+ input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+ 'resnetv2_152x2_bitm': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz',
+ input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0),
+ 'resnetv2_152x4_bitm': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz',
+ input_size=(3, 480, 480), pool_size=(15, 15), crop_pct=1.0), # only one at 480x480?
+
+ # trained on imagenet-21k
+ 'resnetv2_50x1_bitm_in21k': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
+ num_classes=21843),
+ 'resnetv2_50x3_bitm_in21k': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
+ num_classes=21843),
+ 'resnetv2_101x1_bitm_in21k': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
+ num_classes=21843),
+ 'resnetv2_101x3_bitm_in21k': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
+ num_classes=21843),
+ 'resnetv2_152x2_bitm_in21k': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
+ num_classes=21843),
+ 'resnetv2_152x4_bitm_in21k': _cfg(
+ url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
+ num_classes=21843),
+
+ 'resnetv2_50x1_bit_distilled': _cfg(
+ url='https://storage.googleapis.com/bit_models/distill/R50x1_224.npz',
+ interpolation='bicubic'),
+ 'resnetv2_152x2_bit_teacher': _cfg(
+ url='https://storage.googleapis.com/bit_models/distill/R152x2_T_224.npz',
+ interpolation='bicubic'),
+ 'resnetv2_152x2_bit_teacher_384': _cfg(
+ url='https://storage.googleapis.com/bit_models/distill/R152x2_T_384.npz',
+ input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, interpolation='bicubic'),
+
+ 'resnetv2_50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_50_a1h-000cdf49.pth',
+ interpolation='bicubic', crop_pct=0.95),
+ 'resnetv2_50d': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+ 'resnetv2_50t': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+ 'resnetv2_101': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetv2_101_a1h-5d01f016.pth',
+ interpolation='bicubic', crop_pct=0.95),
+ 'resnetv2_101d': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+ 'resnetv2_152': _cfg(
+ interpolation='bicubic'),
+ 'resnetv2_152d': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+
+ 'resnetv2_50d_gn': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+ 'resnetv2_50d_evob': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+ 'resnetv2_50d_evos': _cfg(
+ interpolation='bicubic', first_conv='stem.conv1'),
+}
+
+
+def make_div(v, divisor=8):
+ min_value = divisor
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+ if new_v < 0.9 * v:
+ new_v += divisor
+ return new_v
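+# e.g. make_div(256 * 0.25) == 64 and make_div(2048 * 0.25) == 512, so the default bottle_ratio of
+# 0.25 in the blocks below gives the usual 4x bottleneck reduction while keeping widths divisible by 8.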
+
+
+class PreActBottleneck(nn.Module):
+ """Pre-activation (v2) bottleneck block.
+
+ Follows the implementation of "Identity Mappings in Deep Residual Networks":
+ https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
+
+ Except it puts the stride on 3x3 conv when available.
+ """
+
+ def __init__(
+ self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
+ act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
+ super().__init__()
+ first_dilation = first_dilation or dilation
+ conv_layer = conv_layer or StdConv2d
+ norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
+ out_chs = out_chs or in_chs
+ mid_chs = make_div(out_chs * bottle_ratio)
+
+ if proj_layer is not None:
+ self.downsample = proj_layer(
+ in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
+ conv_layer=conv_layer, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ self.norm1 = norm_layer(in_chs)
+ self.conv1 = conv_layer(in_chs, mid_chs, 1)
+ self.norm2 = norm_layer(mid_chs)
+ self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+ self.norm3 = norm_layer(mid_chs)
+ self.conv3 = conv_layer(mid_chs, out_chs, 1)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+
+ def zero_init_last(self):
+ nn.init.zeros_(self.conv3.weight)
+
+ def forward(self, x):
+ x_preact = self.norm1(x)
+
+ # shortcut branch
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x_preact)
+
+ # residual branch
+ x = self.conv1(x_preact)
+ x = self.conv2(self.norm2(x))
+ x = self.conv3(self.norm3(x))
+ x = self.drop_path(x)
+ return x + shortcut
+
+
+class Bottleneck(nn.Module):
+ """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
+ """
+ def __init__(
+ self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
+ act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
+ super().__init__()
+ first_dilation = first_dilation or dilation
+ act_layer = act_layer or nn.ReLU
+ conv_layer = conv_layer or StdConv2d
+ norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
+ out_chs = out_chs or in_chs
+ mid_chs = make_div(out_chs * bottle_ratio)
+
+ if proj_layer is not None:
+ self.downsample = proj_layer(
+ in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
+ conv_layer=conv_layer, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ self.conv1 = conv_layer(in_chs, mid_chs, 1)
+ self.norm1 = norm_layer(mid_chs)
+ self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+ self.norm2 = norm_layer(mid_chs)
+ self.conv3 = conv_layer(mid_chs, out_chs, 1)
+ self.norm3 = norm_layer(out_chs, apply_act=False)
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+ self.act3 = act_layer(inplace=True)
+
+ def zero_init_last(self):
+ nn.init.zeros_(self.norm3.weight)
+
+ def forward(self, x):
+ # shortcut branch
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ # residual
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.conv3(x)
+ x = self.norm3(x)
+ x = self.drop_path(x)
+ x = self.act3(x + shortcut)
+ return x
+
+
+class DownsampleConv(nn.Module):
+ def __init__(
+ self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
+ conv_layer=None, norm_layer=None):
+ super(DownsampleConv, self).__init__()
+ self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
+ self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
+
+ def forward(self, x):
+ return self.norm(self.conv(x))
+
+
+class DownsampleAvg(nn.Module):
+ def __init__(
+ self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
+ preact=True, conv_layer=None, norm_layer=None):
+ """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
+ super(DownsampleAvg, self).__init__()
+ avg_stride = stride if dilation == 1 else 1
+ if stride > 1 or dilation > 1:
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+ else:
+ self.pool = nn.Identity()
+ self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
+ self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
+
+ def forward(self, x):
+ return self.norm(self.conv(self.pool(x)))
+
+
+class ResNetStage(nn.Module):
+ """ResNet Stage."""
+ def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
+ avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
+ act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
+ super(ResNetStage, self).__init__()
+ first_dilation = 1 if dilation in (1, 2) else 2
+ layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
+ proj_layer = DownsampleAvg if avg_down else DownsampleConv
+ prev_chs = in_chs
+ self.blocks = nn.Sequential()
+ for block_idx in range(depth):
+ drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
+ stride = stride if block_idx == 0 else 1
+ self.blocks.add_module(str(block_idx), block_fn(
+ prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
+ first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
+ **layer_kwargs, **block_kwargs))
+ prev_chs = out_chs
+ first_dilation = dilation
+ proj_layer = None
+
+ def forward(self, x):
+ x = self.blocks(x)
+ return x
+
+
+def is_stem_deep(stem_type):
+ return any([s in stem_type for s in ('deep', 'tiered')])
+
+
+def create_resnetv2_stem(
+ in_chs, out_chs=64, stem_type='', preact=True,
+ conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
+ stem = OrderedDict()
+ assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
+
+ # NOTE conv padding mode can be changed by overriding the conv_layer def
+ if is_stem_deep(stem_type):
+ # A 3 deep 3x3 conv stack as in ResNet V1D models
+ if 'tiered' in stem_type:
+ stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
+ else:
+ stem_chs = (out_chs // 2, out_chs // 2) # 'D' ResNets
+ stem['conv1'] = conv_layer(in_chs, stem_chs[0], kernel_size=3, stride=2)
+ stem['norm1'] = norm_layer(stem_chs[0])
+ stem['conv2'] = conv_layer(stem_chs[0], stem_chs[1], kernel_size=3, stride=1)
+ stem['norm2'] = norm_layer(stem_chs[1])
+ stem['conv3'] = conv_layer(stem_chs[1], out_chs, kernel_size=3, stride=1)
+ if not preact:
+ stem['norm3'] = norm_layer(out_chs)
+ else:
+ # The usual 7x7 stem conv
+ stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
+ if not preact:
+ stem['norm'] = norm_layer(out_chs)
+
+ if 'fixed' in stem_type:
+ # 'fixed' SAME padding approximation that is used in BiT models
+ stem['pad'] = nn.ConstantPad2d(1, 0.)
+ stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
+ elif 'same' in stem_type:
+ # full, input size based 'SAME' padding, used in ViT Hybrid model
+ stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
+ else:
+ # the usual PyTorch symmetric padding
+ stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ return nn.Sequential(stem)
+
+
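+# Editor's note: the sketch below is illustrative and not part of the original
+# timm source. Assuming this module's imports (torch, StdConv2d, GroupNormAct)
+# are available as above, it compares the stem variants built by
+# create_resnetv2_stem; all of them reduce spatial resolution by 4x.
+def _example_stem_variants():
+    x = torch.randn(1, 3, 224, 224)
+    shapes = {}
+    for stem_type in ('', 'fixed', 'deep'):
+        stem = create_resnetv2_stem(3, 64, stem_type=stem_type)
+        shapes[stem_type or 'default'] = tuple(stem(x).shape)  # expect (1, 64, 56, 56)
+    return shapes
+
+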
+class ResNetV2(nn.Module):
+ """Implementation of Pre-activation (v2) ResNet mode.
+ """
+
+ def __init__(
+ self, layers, channels=(256, 512, 1024, 2048),
+ num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
+ width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
+ act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
+ drop_rate=0., drop_path_rate=0., zero_init_last=False):
+ super().__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ wf = width_factor
+
+ self.feature_info = []
+ stem_chs = make_div(stem_chs * wf)
+ self.stem = create_resnetv2_stem(
+ in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+ stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
+ self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
+
+ prev_chs = stem_chs
+ curr_stride = 4
+ dilation = 1
+ block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
+ block_fn = PreActBottleneck if preact else Bottleneck
+ self.stages = nn.Sequential()
+ for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
+ out_chs = make_div(c * wf)
+ stride = 1 if stage_idx == 0 else 2
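+ # once the requested output_stride is reached, trade further striding for dilation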
+ if curr_stride >= output_stride:
+ dilation *= stride
+ stride = 1
+ stage = ResNetStage(
+ prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
+ act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
+ prev_chs = out_chs
+ curr_stride *= stride
+ self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
+ self.stages.add_module(str(stage_idx), stage)
+
+ self.num_features = prev_chs
+ self.norm = norm_layer(self.num_features) if preact else nn.Identity()
+ self.head = ClassifierHead(
+ self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+
+ self.init_weights(zero_init_last=zero_init_last)
+
+ def init_weights(self, zero_init_last=True):
+ named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix='resnet/'):
+ _load_weights(self, checkpoint_path, prefix)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.head = ClassifierHead(
+ self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.stages(x)
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
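+
+
+# Editor's note: illustrative sketch, not part of the original timm source. It
+# constructs a ResNetV2 directly (bypassing the registry) to show the classifier
+# output, the pre-logit feature map, and the reductions recorded in feature_info.
+def _example_resnetv2_usage():
+    model = ResNetV2(layers=[3, 4, 6, 3], num_classes=10)
+    x = torch.randn(2, 3, 224, 224)
+    logits = model(x)                   # expected shape (2, 10)
+    feats = model.forward_features(x)   # expected shape (2, 2048, 7, 7) at output_stride=32
+    reductions = [f['reduction'] for f in model.feature_info]  # [2, 4, 8, 16, 32]
+    return logits.shape, feats.shape, reductions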
+
+
+def _init_weights(module: nn.Module, name: str = '', zero_init_last=True):
+ if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
+ nn.init.normal_(module.weight, mean=0.0, std=0.01)
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+ elif zero_init_last and hasattr(module, 'zero_init_last'):
+ module.zero_init_last()
+
+
+@torch.no_grad()
+def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
+ import numpy as np
+
+ def t2p(conv_weights):
+ """Possibly convert HWIO to OIHW."""
+ if conv_weights.ndim == 4:
+ conv_weights = conv_weights.transpose([3, 2, 0, 1])
+ return torch.from_numpy(conv_weights)
+
+ weights = np.load(checkpoint_path)
+ stem_conv_w = adapt_input_conv(
+ model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
+ model.stem.conv.weight.copy_(stem_conv_w)
+ model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
+ model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
+ if isinstance(getattr(model.head, 'fc', None), nn.Conv2d) and \
+ model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+ model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
+ model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
+ for i, (sname, stage) in enumerate(model.stages.named_children()):
+ for j, (bname, block) in enumerate(stage.blocks.named_children()):
+ cname = 'standardized_conv2d'
+ block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
+ block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
+ block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
+ block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
+ block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
+ block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
+ block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
+ block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
+ block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
+ block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
+ if block.downsample is not None:
+ w = weights[f'{block_prefix}a/proj/{cname}/kernel']
+ block.downsample.conv.weight.copy_(t2p(w))
+
+
+def _create_resnetv2(variant, pretrained=False, **kwargs):
+ feature_cfg = dict(flatten_sequential=True)
+ return build_model_with_cfg(
+ ResNetV2, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=feature_cfg,
+ pretrained_custom_load='_bit' in variant,
+ **kwargs)
+
+
+def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
+ return _create_resnetv2(
+ variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
+
+
+@register_model
+def resnetv2_50x1_bitm(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
+
+
+@register_model
+def resnetv2_50x3_bitm(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
+
+
+@register_model
+def resnetv2_101x1_bitm(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
+
+
+@register_model
+def resnetv2_101x3_bitm(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
+
+
+@register_model
+def resnetv2_152x2_bitm(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+
+
+@register_model
+def resnetv2_152x4_bitm(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
+
+
+@register_model
+def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
+ layers=[3, 4, 6, 3], width_factor=1, **kwargs)
+
+
+@register_model
+def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
+ layers=[3, 4, 6, 3], width_factor=3, **kwargs)
+
+
+@register_model
+def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
+ layers=[3, 4, 23, 3], width_factor=1, **kwargs)
+
+
+@register_model
+def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
+ layers=[3, 4, 23, 3], width_factor=3, **kwargs)
+
+
+@register_model
+def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
+ layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+
+
+@register_model
+def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
+ return _create_resnetv2_bit(
+ 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
+ layers=[3, 8, 36, 3], width_factor=4, **kwargs)
+
+
+@register_model
+def resnetv2_50x1_bit_distilled(pretrained=False, **kwargs):
+ """ ResNetV2-50x1-BiT Distilled
+ Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+ """
+ return _create_resnetv2_bit(
+ 'resnetv2_50x1_bit_distilled', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
+
+
+@register_model
+def resnetv2_152x2_bit_teacher(pretrained=False, **kwargs):
+ """ ResNetV2-152x2-BiT Teacher
+ Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+ """
+ return _create_resnetv2_bit(
+ 'resnetv2_152x2_bit_teacher', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+
+
+@register_model
+def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
+ """ ResNetV2-152xx-BiT Teacher @ 384x384
+ Paper: Knowledge distillation: A good teacher is patient and consistent - https://arxiv.org/abs/2106.05237
+ """
+ return _create_resnetv2_bit(
+ 'resnetv2_152x2_bit_teacher_384', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
+
+
+@register_model
+def resnetv2_50(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_50', pretrained=pretrained,
+ layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+
+
+@register_model
+def resnetv2_50d(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_50d', pretrained=pretrained,
+ layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+ stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50t(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_50t', pretrained=pretrained,
+ layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+ stem_type='tiered', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_101(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_101', pretrained=pretrained,
+ layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+
+
+@register_model
+def resnetv2_101d(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_101d', pretrained=pretrained,
+ layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+ stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_152(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_152', pretrained=pretrained,
+ layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+
+
+@register_model
+def resnetv2_152d(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_152d', pretrained=pretrained,
+ layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
+ stem_type='deep', avg_down=True, **kwargs)
+
+
+# Experimental configs (may change / be removed)
+
+@register_model
+def resnetv2_50d_gn(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_50d_gn', pretrained=pretrained,
+ layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
+ stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50d_evob(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_50d_evob', pretrained=pretrained,
+ layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
+ stem_type='deep', avg_down=True, **kwargs)
+
+
+@register_model
+def resnetv2_50d_evos(pretrained=False, **kwargs):
+ return _create_resnetv2(
+ 'resnetv2_50d_evos', pretrained=pretrained,
+ layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
+ stem_type='deep', avg_down=True, **kwargs)
diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
new file mode 100644
index 0000000..f27ce5d
--- /dev/null
+++ b/timm/models/rexnet.py
@@ -0,0 +1,239 @@
+""" ReXNet
+
+A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` -
+https://arxiv.org/abs/2007.00992
+
+Adapted from original impl at https://github.com/clovaai/rexnet
+Copyright (c) 2020-present NAVER Corp. MIT license
+
+Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman
+Copyright 2020 Ross Wightman
+"""
+
+import torch
+import torch.nn as nn
+from functools import partial
+from math import ceil
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule
+from .registry import register_model
+from .efficientnet_builder import efficientnet_init_weights
+
+
+def _cfg(url=''):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv', 'classifier': 'head.fc',
+ }
+
+
+default_cfgs = dict(
+ rexnet_100=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_100-1b4dddf4.pth'),
+ rexnet_130=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_130-590d768e.pth'),
+ rexnet_150=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_150-bd1a6aa8.pth'),
+ rexnet_200=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rexnet/rexnetv1_200-8c0b7f2d.pth'),
+ rexnetr_100=_cfg(
+ url=''),
+ rexnetr_130=_cfg(
+ url=''),
+ rexnetr_150=_cfg(
+ url=''),
+ rexnetr_200=_cfg(
+ url=''),
+)
+
+SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
+
+
+class LinearBottleneck(nn.Module):
+ def __init__(self, in_chs, out_chs, stride, exp_ratio=1.0, se_ratio=0., ch_div=1,
+ act_layer='swish', dw_act_layer='relu6', drop_path=None):
+ super(LinearBottleneck, self).__init__()
+ self.use_shortcut = stride == 1 and in_chs <= out_chs
+ self.in_channels = in_chs
+ self.out_channels = out_chs
+
+ if exp_ratio != 1.:
+ dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
+ self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer)
+ else:
+ dw_chs = in_chs
+ self.conv_exp = None
+
+ self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
+ if se_ratio > 0:
+ self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
+ else:
+ self.se = None
+ self.act_dw = create_act_layer(dw_act_layer)
+
+ self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
+ self.drop_path = drop_path
+
+ def feat_channels(self, exp=False):
+ return self.conv_dw.out_channels if exp else self.out_channels
+
+ def forward(self, x):
+ shortcut = x
+ if self.conv_exp is not None:
+ x = self.conv_exp(x)
+ x = self.conv_dw(x)
+ if self.se is not None:
+ x = self.se(x)
+ x = self.act_dw(x)
+ x = self.conv_pwl(x)
+ if self.use_shortcut:
+ if self.drop_path is not None:
+ x = self.drop_path(x)
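+ # ReXNet-style partial residual: only the first in_channels channels get the
+ # shortcut added; the expanded remainder passes through unchanged.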
+ x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
+ return x
+
+
+def _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=0., ch_div=1):
+ layers = [1, 2, 2, 3, 3, 5]
+ strides = [1, 2, 2, 2, 1, 2]
+ layers = [ceil(element * depth_mult) for element in layers]
+ strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
+ exp_ratios = [1] * layers[0] + [6] * sum(layers[1:])
+ depth = sum(layers[:]) * 3
+ base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs
+
+ # Simple channel schedule: grow the target width linearly so that every block becomes an expansion layer.
+ out_chs_list = []
+ for i in range(depth // 3):
+ out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
+ base_chs += final_chs / (depth // 3 * 1.0)
+
+ se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:])
+
+ return list(zip(out_chs_list, exp_ratios, strides, se_ratios))
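+
+
+# Editor's note: illustrative sketch, not part of the original timm source. It
+# prints the per-block (out_chs, exp_ratio, stride, se_ratio) schedule produced
+# by _block_cfg for the default ReXNet-1.0x settings, making the linear channel
+# growth described above easy to inspect.
+def _example_block_cfg():
+    cfg = _block_cfg(width_mult=1.0, depth_mult=1.0, initial_chs=16, final_chs=180, se_ratio=1 / 12.)
+    for idx, (out_chs, exp_ratio, stride, se_ratio) in enumerate(cfg):
+        print(f'block {idx:02d}: out_chs={out_chs:3d} exp_ratio={exp_ratio} stride={stride} se_ratio={se_ratio:.3f}')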
+
+
+def _build_blocks(
+ block_cfg, prev_chs, width_mult, ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_path_rate=0.):
+ feat_chs = [prev_chs]
+ feature_info = []
+ curr_stride = 2
+ features = []
+ num_blocks = len(block_cfg)
+ for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
+ if stride > 1:
+ fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
+ feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
+ curr_stride *= stride
+ block_dpr = drop_path_rate * block_idx / (num_blocks - 1) # stochastic depth linear decay rule
+ drop_path = DropPath(block_dpr) if block_dpr > 0. else None
+ features.append(LinearBottleneck(
+ in_chs=prev_chs, out_chs=chs, exp_ratio=exp_ratio, stride=stride, se_ratio=se_ratio,
+ ch_div=ch_div, act_layer=act_layer, dw_act_layer=dw_act_layer, drop_path=drop_path))
+ prev_chs = chs
+ feat_chs += [features[-1].feat_channels()]
+ pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
+ feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
+ features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer))
+ return features, feature_info
+
+
+class ReXNetV1(nn.Module):
+ def __init__(self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32,
+ initial_chs=16, final_chs=180, width_mult=1.0, depth_mult=1.0, se_ratio=1/12.,
+ ch_div=1, act_layer='swish', dw_act_layer='relu6', drop_rate=0.2, drop_path_rate=0.):
+ super(ReXNetV1, self).__init__()
+ self.drop_rate = drop_rate
+ self.num_classes = num_classes
+
+ assert output_stride == 32 # FIXME support dilation
+ stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
+ stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
+ self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
+
+ block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
+ features, self.feature_info = _build_blocks(
+ block_cfg, stem_chs, width_mult, ch_div, act_layer, dw_act_layer, drop_path_rate)
+ self.num_features = features[-1].out_channels
+ self.features = nn.Sequential(*features)
+
+ self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)
+
+ efficientnet_init_weights(self)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.features(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_rexnet(variant, pretrained, **kwargs):
+ feature_cfg = dict(flatten_sequential=True)
+ return build_model_with_cfg(
+ ReXNetV1, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=feature_cfg,
+ **kwargs)
+
+
+@register_model
+def rexnet_100(pretrained=False, **kwargs):
+ """ReXNet V1 1.0x"""
+ return _create_rexnet('rexnet_100', pretrained, **kwargs)
+
+
+@register_model
+def rexnet_130(pretrained=False, **kwargs):
+ """ReXNet V1 1.3x"""
+ return _create_rexnet('rexnet_130', pretrained, width_mult=1.3, **kwargs)
+
+
+@register_model
+def rexnet_150(pretrained=False, **kwargs):
+ """ReXNet V1 1.5x"""
+ return _create_rexnet('rexnet_150', pretrained, width_mult=1.5, **kwargs)
+
+
+@register_model
+def rexnet_200(pretrained=False, **kwargs):
+ """ReXNet V1 2.0x"""
+ return _create_rexnet('rexnet_200', pretrained, width_mult=2.0, **kwargs)
+
+
+@register_model
+def rexnetr_100(pretrained=False, **kwargs):
+ """ReXNet V1 1.0x w/ rounded (mod 8) channels"""
+ return _create_rexnet('rexnetr_100', pretrained, ch_div=8, **kwargs)
+
+
+@register_model
+def rexnetr_130(pretrained=False, **kwargs):
+ """ReXNet V1 1.3x w/ rounded (mod 8) channels"""
+ return _create_rexnet('rexnetr_130', pretrained, width_mult=1.3, ch_div=8, **kwargs)
+
+
+@register_model
+def rexnetr_150(pretrained=False, **kwargs):
+ """ReXNet V1 1.5x w/ rounded (mod 8) channels"""
+ return _create_rexnet('rexnetr_150', pretrained, width_mult=1.5, ch_div=8, **kwargs)
+
+
+@register_model
+def rexnetr_200(pretrained=False, **kwargs):
+ """ReXNet V1 2.0x w/ rounded (mod 8) channels"""
+ return _create_rexnet('rexnetr_200', pretrained, width_mult=2.0, ch_div=8, **kwargs)
diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py
new file mode 100644
index 0000000..1f3379d
--- /dev/null
+++ b/timm/models/selecsls.py
@@ -0,0 +1,362 @@
+"""PyTorch SelecSLS Net example for ImageNet Classification
+License: CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/legalcode)
+Author: Dushyant Mehta (@mehtadushy)
+
+SelecSLS (core) Network Architecture as proposed in "XNect: Real-time Multi-person 3D
+Human Pose Estimation with a Single RGB Camera, Mehta et al."
+https://arxiv.org/abs/1907.00837
+
+Based on ResNet implementation in https://github.com/rwightman/pytorch-image-models
+and SelecSLS Net implementation in https://github.com/mehtadushy/SelecSLS-Pytorch
+"""
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import create_classifier
+from .registry import register_model
+
+__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (4, 4),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.0', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'selecsls42': _cfg(
+ url='',
+ interpolation='bicubic'),
+ 'selecsls42b': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls42b-8af30141.pth',
+ interpolation='bicubic'),
+ 'selecsls60': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60-bbf87526.pth',
+ interpolation='bicubic'),
+ 'selecsls60b': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-selecsls/selecsls60b-94e619b5.pth',
+ interpolation='bicubic'),
+ 'selecsls84': _cfg(
+ url='',
+ interpolation='bicubic'),
+}
+
+
+class SequentialList(nn.Sequential):
+
+ def __init__(self, *args):
+ super(SequentialList, self).__init__(*args)
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (List[torch.Tensor]) -> (List[torch.Tensor])
+ pass
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (torch.Tensor) -> (List[torch.Tensor])
+ pass
+
+ def forward(self, x) -> List[torch.Tensor]:
+ for module in self:
+ x = module(x)
+ return x
+
+
+class SelectSeq(nn.Module):
+ def __init__(self, mode='index', index=0):
+ super(SelectSeq, self).__init__()
+ self.mode = mode
+ self.index = index
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (List[torch.Tensor]) -> (torch.Tensor)
+ pass
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, x):
+ # type: (Tuple[torch.Tensor]) -> (torch.Tensor)
+ pass
+
+ def forward(self, x) -> torch.Tensor:
+ if self.mode == 'index':
+ return x[self.index]
+ else:
+ return torch.cat(x, dim=1)
+
+
+def conv_bn(in_chs, out_chs, k=3, stride=1, padding=None, dilation=1):
+ if padding is None:
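+ # default to symmetric 'same'-style padding for odd kernel sizes (k=1 -> 0, k=3 -> dilation)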
+ padding = ((stride - 1) + dilation * (k - 1)) // 2
+ return nn.Sequential(
+ nn.Conv2d(in_chs, out_chs, k, stride, padding=padding, dilation=dilation, bias=False),
+ nn.BatchNorm2d(out_chs),
+ nn.ReLU(inplace=True)
+ )
+
+
+class SelecSLSBlock(nn.Module):
+ def __init__(self, in_chs, skip_chs, mid_chs, out_chs, is_first, stride, dilation=1):
+ super(SelecSLSBlock, self).__init__()
+ self.stride = stride
+ self.is_first = is_first
+ assert stride in [1, 2]
+
+ # Process the input through a stack of 3x3/1x1 conv blocks; conv6 fuses the intermediate features (and the skip input when present)
+ self.conv1 = conv_bn(in_chs, mid_chs, 3, stride, dilation=dilation)
+ self.conv2 = conv_bn(mid_chs, mid_chs, 1)
+ self.conv3 = conv_bn(mid_chs, mid_chs // 2, 3)
+ self.conv4 = conv_bn(mid_chs // 2, mid_chs, 1)
+ self.conv5 = conv_bn(mid_chs, mid_chs // 2, 3)
+ self.conv6 = conv_bn(2 * mid_chs + (0 if is_first else skip_chs), out_chs, 1)
+
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
+ if not isinstance(x, list):
+ x = [x]
+ assert len(x) in [1, 2]
+
+ d1 = self.conv1(x[0])
+ d2 = self.conv3(self.conv2(d1))
+ d3 = self.conv5(self.conv4(d2))
+ if self.is_first:
+ out = self.conv6(torch.cat([d1, d2, d3], 1))
+ return [out, out]
+ else:
+ return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)), x[1]]
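+
+
+# Editor's note: illustrative sketch, not part of the original timm source. It
+# shows the list-carrying interface of SelecSLSBlock: the first block of a stage
+# returns [features, skip] (both the same tensor), and subsequent blocks thread
+# the skip tensor through unchanged while concatenating it into conv6.
+def _example_selecsls_block():
+    first = SelecSLSBlock(in_chs=32, skip_chs=0, mid_chs=64, out_chs=64, is_first=True, stride=2)
+    x = torch.randn(1, 32, 56, 56)
+    out, skip = first(x)          # out is (1, 64, 28, 28); skip is the same tensor
+    later = SelecSLSBlock(in_chs=64, skip_chs=64, mid_chs=64, out_chs=128, is_first=False, stride=1)
+    y, skip = later([out, skip])  # y is (1, 128, 28, 28); skip passes through untouched
+    return y.shape, skip.shape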
+
+
+class SelecSLS(nn.Module):
+ """SelecSLS42 / SelecSLS60 / SelecSLS84
+
+ Parameters
+ ----------
+ cfg : network config dictionary specifying block type, feature, and head args
+ num_classes : int, default 1000
+ Number of classification classes.
+ in_chans : int, default 3
+ Number of input (color) channels.
+ drop_rate : float, default 0.
+ Dropout probability before classifier, for training
+ global_pool : str, default 'avg'
+ Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
+ """
+
+ def __init__(self, cfg, num_classes=1000, in_chans=3, drop_rate=0.0, global_pool='avg'):
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ super(SelecSLS, self).__init__()
+
+ self.stem = conv_bn(in_chans, 32, stride=2)
+ self.features = SequentialList(*[cfg['block'](*block_args) for block_args in cfg['features']])
+ self.from_seq = SelectSeq()  # convert List[Tensor] -> Tensor in a module/TorchScript-compatible way
+ self.head = nn.Sequential(*[conv_bn(*conv_args) for conv_args in cfg['head']])
+ self.num_features = cfg['num_features']
+ self.feature_info = cfg['feature_info']
+
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+
+ def get_classifier(self):
+ return self.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.features(x)
+ x = self.head(self.from_seq(x))
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.fc(x)
+ return x
+
+
+def _create_selecsls(variant, pretrained, **kwargs):
+ cfg = {}
+ feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
+ if variant.startswith('selecsls42'):
+ cfg['block'] = SelecSLSBlock
+ # Define configuration of the network after the initial neck
+ cfg['features'] = [
+ # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
+ (32, 0, 64, 64, True, 2),
+ (64, 64, 64, 128, False, 1),
+ (128, 0, 144, 144, True, 2),
+ (144, 144, 144, 288, False, 1),
+ (288, 0, 304, 304, True, 2),
+ (304, 304, 304, 480, False, 1),
+ ]
+ feature_info.extend([
+ dict(num_chs=128, reduction=4, module='features.1'),
+ dict(num_chs=288, reduction=8, module='features.3'),
+ dict(num_chs=480, reduction=16, module='features.5'),
+ ])
+ # Head can be replaced with alternative configurations depending on the problem
+ feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
+ if variant == 'selecsls42b':
+ cfg['head'] = [
+ (480, 960, 3, 2),
+ (960, 1024, 3, 1),
+ (1024, 1280, 3, 2),
+ (1280, 1024, 1, 1),
+ ]
+ feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
+ cfg['num_features'] = 1024
+ else:
+ cfg['head'] = [
+ (480, 960, 3, 2),
+ (960, 1024, 3, 1),
+ (1024, 1024, 3, 2),
+ (1024, 1280, 1, 1),
+ ]
+ feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
+ cfg['num_features'] = 1280
+
+ elif variant.startswith('selecsls60'):
+ cfg['block'] = SelecSLSBlock
+ # Define configuration of the network after the initial neck
+ cfg['features'] = [
+ # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
+ (32, 0, 64, 64, True, 2),
+ (64, 64, 64, 128, False, 1),
+ (128, 0, 128, 128, True, 2),
+ (128, 128, 128, 128, False, 1),
+ (128, 128, 128, 288, False, 1),
+ (288, 0, 288, 288, True, 2),
+ (288, 288, 288, 288, False, 1),
+ (288, 288, 288, 288, False, 1),
+ (288, 288, 288, 416, False, 1),
+ ]
+ feature_info.extend([
+ dict(num_chs=128, reduction=4, module='features.1'),
+ dict(num_chs=288, reduction=8, module='features.4'),
+ dict(num_chs=416, reduction=16, module='features.8'),
+ ])
+ # Head can be replaced with alternative configurations depending on the problem
+ feature_info.append(dict(num_chs=1024, reduction=32, module='head.1'))
+ if variant == 'selecsls60b':
+ cfg['head'] = [
+ (416, 756, 3, 2),
+ (756, 1024, 3, 1),
+ (1024, 1280, 3, 2),
+ (1280, 1024, 1, 1),
+ ]
+ feature_info.append(dict(num_chs=1024, reduction=64, module='head.3'))
+ cfg['num_features'] = 1024
+ else:
+ cfg['head'] = [
+ (416, 756, 3, 2),
+ (756, 1024, 3, 1),
+ (1024, 1024, 3, 2),
+ (1024, 1280, 1, 1),
+ ]
+ feature_info.append(dict(num_chs=1280, reduction=64, module='head.3'))
+ cfg['num_features'] = 1280
+
+ elif variant == 'selecsls84':
+ cfg['block'] = SelecSLSBlock
+ # Define configuration of the network after the initial neck
+ cfg['features'] = [
+ # in_chs, skip_chs, mid_chs, out_chs, is_first, stride
+ (32, 0, 64, 64, True, 2),
+ (64, 64, 64, 144, False, 1),
+ (144, 0, 144, 144, True, 2),
+ (144, 144, 144, 144, False, 1),
+ (144, 144, 144, 144, False, 1),
+ (144, 144, 144, 144, False, 1),
+ (144, 144, 144, 304, False, 1),
+ (304, 0, 304, 304, True, 2),
+ (304, 304, 304, 304, False, 1),
+ (304, 304, 304, 304, False, 1),
+ (304, 304, 304, 304, False, 1),
+ (304, 304, 304, 304, False, 1),
+ (304, 304, 304, 512, False, 1),
+ ]
+ feature_info.extend([
+ dict(num_chs=144, reduction=4, module='features.1'),
+ dict(num_chs=304, reduction=8, module='features.6'),
+ dict(num_chs=512, reduction=16, module='features.12'),
+ ])
+ # Head can be replaced with alternative configurations depending on the problem
+ cfg['head'] = [
+ (512, 960, 3, 2),
+ (960, 1024, 3, 1),
+ (1024, 1024, 3, 2),
+ (1024, 1280, 3, 1),
+ ]
+ cfg['num_features'] = 1280
+ feature_info.extend([
+ dict(num_chs=1024, reduction=32, module='head.1'),
+ dict(num_chs=1280, reduction=64, module='head.3')
+ ])
+ else:
+ raise ValueError('Invalid net configuration ' + variant + ' !!!')
+ cfg['feature_info'] = feature_info
+
+ # this model can produce 6 feature levels by default; unlike most others, keep out_indices at 0-4 to avoid surprises
+ return build_model_with_cfg(
+ SelecSLS, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=cfg,
+ feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True),
+ **kwargs)
+
+
+@register_model
+def selecsls42(pretrained=False, **kwargs):
+ """Constructs a SelecSLS42 model.
+ """
+ return _create_selecsls('selecsls42', pretrained, **kwargs)
+
+
+@register_model
+def selecsls42b(pretrained=False, **kwargs):
+ """Constructs a SelecSLS42_B model.
+ """
+ return _create_selecsls('selecsls42b', pretrained, **kwargs)
+
+
+@register_model
+def selecsls60(pretrained=False, **kwargs):
+ """Constructs a SelecSLS60 model.
+ """
+ return _create_selecsls('selecsls60', pretrained, **kwargs)
+
+
+@register_model
+def selecsls60b(pretrained=False, **kwargs):
+ """Constructs a SelecSLS60_B model.
+ """
+ return _create_selecsls('selecsls60b', pretrained, **kwargs)
+
+
+@register_model
+def selecsls84(pretrained=False, **kwargs):
+ """Constructs a SelecSLS84 model.
+ """
+ return _create_selecsls('selecsls84', pretrained, **kwargs)
diff --git a/timm/models/senet.py b/timm/models/senet.py
new file mode 100644
index 0000000..3d0ba7b
--- /dev/null
+++ b/timm/models/senet.py
@@ -0,0 +1,467 @@
+"""
+SEResNet implementation from Cadene's pretrained models
+https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
+Additional credit to https://github.com/creafz
+
+Original model: https://github.com/hujie-frank/SENet
+
+ResNet code gently borrowed from
+https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
+
+FIXME I'm deprecating these models and moving them to ResNet as I don't want to maintain duplicate
+support for extras like dilation, switchable BN/activations, feature extraction, etc. that don't exist here.
+"""
+import math
+from collections import OrderedDict
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import create_classifier
+from .registry import register_model
+
+__all__ = ['SENet']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'legacy_senet154':
+ _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
+ 'legacy_seresnet18': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet18-4bb0ce65.pth',
+ interpolation='bicubic'),
+ 'legacy_seresnet34': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet34-a4004e63.pth'),
+ 'legacy_seresnet50': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
+ 'legacy_seresnet101': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
+ 'legacy_seresnet152': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
+ 'legacy_seresnext26_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
+ interpolation='bicubic'),
+ 'legacy_seresnext50_32x4d':
+ _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
+ 'legacy_seresnext101_32x4d':
+ _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'),
+}
+
+
+def _weight_init(m):
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+
+
+class SEModule(nn.Module):
+
+ def __init__(self, channels, reduction):
+ super(SEModule, self).__init__()
+ self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ module_input = x
+ x = x.mean((2, 3), keepdim=True)
+ x = self.fc1(x)
+ x = self.relu(x)
+ x = self.fc2(x)
+ x = self.sigmoid(x)
+ return module_input * x
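+
+
+# Editor's note: illustrative sketch, not part of the original timm source. The
+# SE module is a channel-gating block: squeeze (global average pool), excite
+# (two 1x1 convs with a sigmoid), then rescale the input channel-wise, so the
+# output shape matches the input.
+def _example_se_module():
+    import torch  # local import; this module itself only imports torch.nn / torch.nn.functional
+    se = SEModule(channels=64, reduction=16)
+    x = torch.randn(2, 64, 14, 14)
+    y = se(x)  # same shape as x, each channel scaled by a gate in (0, 1)
+    return y.shape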
+
+
+class Bottleneck(nn.Module):
+ """
+ Base class for bottlenecks that implements the `forward()` method.
+ """
+
+ def forward(self, x):
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = self.se_module(out) + shortcut
+ out = self.relu(out)
+
+ return out
+
+
+class SEBottleneck(Bottleneck):
+ """
+ Bottleneck for SENet154.
+ """
+ expansion = 4
+
+ def __init__(self, inplanes, planes, groups, reduction, stride=1,
+ downsample=None):
+ super(SEBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes * 2)
+ self.conv2 = nn.Conv2d(
+ planes * 2, planes * 4, kernel_size=3, stride=stride,
+ padding=1, groups=groups, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes * 4)
+ self.conv3 = nn.Conv2d(
+ planes * 4, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.se_module = SEModule(planes * 4, reduction=reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+
+class SEResNetBottleneck(Bottleneck):
+ """
+ ResNet bottleneck with a Squeeze-and-Excitation module. It follows the Caffe
+ implementation and uses `stride=stride` in `conv1` rather than in `conv2`
+ (the torchvision ResNet implementation places the stride in `conv2`).
+ """
+ expansion = 4
+
+ def __init__(self, inplanes, planes, groups, reduction, stride=1,
+ downsample=None):
+ super(SEResNetBottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, bias=False, stride=stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(
+ planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.se_module = SEModule(planes * 4, reduction=reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+
+class SEResNeXtBottleneck(Bottleneck):
+ """
+ ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
+ """
+ expansion = 4
+
+ def __init__(self, inplanes, planes, groups, reduction, stride=1,
+ downsample=None, base_width=4):
+ super(SEResNeXtBottleneck, self).__init__()
+ width = math.floor(planes * (base_width / 64)) * groups
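+ # e.g. for the 32x4d config with planes=64: floor(64 * 4 / 64) * 32 = 128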
+ self.conv1 = nn.Conv2d(
+ inplanes, width, kernel_size=1, bias=False, stride=1)
+ self.bn1 = nn.BatchNorm2d(width)
+ self.conv2 = nn.Conv2d(
+ width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False)
+ self.bn2 = nn.BatchNorm2d(width)
+ self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.se_module = SEModule(planes * 4, reduction=reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+
+class SEResNetBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
+ super(SEResNetBlock, self).__init__()
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(
+ planes, planes, kernel_size=3, padding=1, groups=groups, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.se_module = SEModule(planes, reduction=reduction)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = self.se_module(out) + shortcut
+ out = self.relu(out)
+
+ return out
+
+
+class SENet(nn.Module):
+
+ def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
+ in_chans=3, inplanes=64, input_3x3=False, downsample_kernel_size=1,
+ downsample_padding=0, num_classes=1000, global_pool='avg'):
+ """
+ Parameters
+ ----------
+ block (nn.Module): Bottleneck class.
+ - For SENet154: SEBottleneck
+ - For SE-ResNet models: SEResNetBottleneck
+ - For SE-ResNeXt models: SEResNeXtBottleneck
+ layers (list of ints): Number of residual blocks for 4 layers of the
+ network (layer1...layer4).
+ groups (int): Number of groups for the 3x3 convolution in each
+ bottleneck block.
+ - For SENet154: 64
+ - For SE-ResNet models: 1
+ - For SE-ResNeXt models: 32
+ reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
+ - For all models: 16
+ drop_rate (float): Drop probability for the Dropout layer before the
+ classifier. If 0. the Dropout layer is not used.
+ - For SENet154: 0.2
+ - For SE-ResNet models: 0.
+ - For SE-ResNeXt models: 0.
+ inplanes (int): Number of input channels for layer1.
+ - For SENet154: 128
+ - For SE-ResNet models: 64
+ - For SE-ResNeXt models: 64
+ input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
+ a single 7x7 convolution in layer0.
+ - For SENet154: True
+ - For SE-ResNet models: False
+ - For SE-ResNeXt models: False
+ downsample_kernel_size (int): Kernel size for downsampling convolutions
+ in layer2, layer3 and layer4.
+ - For SENet154: 3
+ - For SE-ResNet models: 1
+ - For SE-ResNeXt models: 1
+ downsample_padding (int): Padding for downsampling convolutions in
+ layer2, layer3 and layer4.
+ - For SENet154: 1
+ - For SE-ResNet models: 0
+ - For SE-ResNeXt models: 0
+ num_classes (int): Number of outputs in `last_linear` layer.
+ - For all models: 1000
+ """
+ super(SENet, self).__init__()
+ self.inplanes = inplanes
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ if input_3x3:
+ layer0_modules = [
+ ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)),
+ ('bn1', nn.BatchNorm2d(64)),
+ ('relu1', nn.ReLU(inplace=True)),
+ ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
+ ('bn2', nn.BatchNorm2d(64)),
+ ('relu2', nn.ReLU(inplace=True)),
+ ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)),
+ ('bn3', nn.BatchNorm2d(inplanes)),
+ ('relu3', nn.ReLU(inplace=True)),
+ ]
+ else:
+ layer0_modules = [
+ ('conv1', nn.Conv2d(
+ in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
+ ('bn1', nn.BatchNorm2d(inplanes)),
+ ('relu1', nn.ReLU(inplace=True)),
+ ]
+ self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
+ # To preserve compatibility with Caffe weights `ceil_mode=True` is used instead of `padding=1`.
+ self.pool0 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
+ self.feature_info = [dict(num_chs=inplanes, reduction=2, module='layer0')]
+ self.layer1 = self._make_layer(
+ block,
+ planes=64,
+ blocks=layers[0],
+ groups=groups,
+ reduction=reduction,
+ downsample_kernel_size=1,
+ downsample_padding=0
+ )
+ self.feature_info += [dict(num_chs=64 * block.expansion, reduction=4, module='layer1')]
+ self.layer2 = self._make_layer(
+ block,
+ planes=128,
+ blocks=layers[1],
+ stride=2,
+ groups=groups,
+ reduction=reduction,
+ downsample_kernel_size=downsample_kernel_size,
+ downsample_padding=downsample_padding
+ )
+ self.feature_info += [dict(num_chs=128 * block.expansion, reduction=8, module='layer2')]
+ self.layer3 = self._make_layer(
+ block,
+ planes=256,
+ blocks=layers[2],
+ stride=2,
+ groups=groups,
+ reduction=reduction,
+ downsample_kernel_size=downsample_kernel_size,
+ downsample_padding=downsample_padding
+ )
+ self.feature_info += [dict(num_chs=256 * block.expansion, reduction=16, module='layer3')]
+ self.layer4 = self._make_layer(
+ block,
+ planes=512,
+ blocks=layers[3],
+ stride=2,
+ groups=groups,
+ reduction=reduction,
+ downsample_kernel_size=downsample_kernel_size,
+ downsample_padding=downsample_padding
+ )
+ self.feature_info += [dict(num_chs=512 * block.expansion, reduction=32, module='layer4')]
+ self.num_features = 512 * block.expansion
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ for m in self.modules():
+ _weight_init(m)
+
+ def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
+ downsample_kernel_size=1, downsample_padding=0):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes, planes * block.expansion, kernel_size=downsample_kernel_size,
+ stride=stride, padding=downsample_padding, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = [block(self.inplanes, planes, groups, reduction, stride, downsample)]
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups, reduction))
+
+ return nn.Sequential(*layers)
+
+ def get_classifier(self):
+ return self.last_linear
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.last_linear = create_classifier(
+ self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x = self.layer0(x)
+ x = self.pool0(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ return x
+
+ def logits(self, x):
+ x = self.global_pool(x)
+ if self.drop_rate > 0.:
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
+ x = self.last_linear(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.logits(x)
+ return x
+
+
+def _create_senet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ SENet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def legacy_seresnet18(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNetBlock, layers=[2, 2, 2, 2], groups=1, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnet18', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnet34(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNetBlock, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnet34', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnet50(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnet50', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnet101(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnet101', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnet152(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnet152', pretrained, **model_args)
+
+
+@register_model
+def legacy_senet154(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, reduction=16,
+ downsample_kernel_size=3, downsample_padding=1, inplanes=128, input_3x3=True, **kwargs)
+ return _create_senet('legacy_senet154', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnext26_32x4d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNeXtBottleneck, layers=[2, 2, 2, 2], groups=32, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnext26_32x4d', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnext50_32x4d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnext50_32x4d', pretrained, **model_args)
+
+
+@register_model
+def legacy_seresnext101_32x4d(pretrained=False, **kwargs):
+ model_args = dict(
+ block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, reduction=16, **kwargs)
+ return _create_senet('legacy_seresnext101_32x4d', pretrained, **model_args)
diff --git a/timm/models/sknet.py b/timm/models/sknet.py
new file mode 100644
index 0000000..4dc2aa5
--- /dev/null
+++ b/timm/models/sknet.py
@@ -0,0 +1,215 @@
+""" Selective Kernel Networks (ResNet base)
+
+Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
+
+This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
+and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
+to the original paper with some modifications of my own to better balance param count vs accuracy.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+
+from torch import nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import SelectiveKernel, ConvBnAct, create_attn
+from .registry import register_model
+from .resnet import ResNet
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'conv1', 'classifier': 'fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'skresnet18': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet18_ra-4eec2804.pth'),
+ 'skresnet34': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnet34_ra-bdc0ccde.pth'),
+ 'skresnet50': _cfg(),
+ 'skresnet50d': _cfg(
+ first_conv='conv1.0'),
+ 'skresnext50_32x4d': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/skresnext50_ra-f40e40bf.pth'),
+}
+
+
+class SelectiveKernelBasic(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+ sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
+ norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
+ super(SelectiveKernelBasic, self).__init__()
+
+ sk_kwargs = sk_kwargs or {}
+ conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
+ assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
+ assert base_width == 64, 'BasicBlock does not support changing base width'
+ first_planes = planes // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+
+ self.conv1 = SelectiveKernel(
+ inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
+ conv_kwargs['act_layer'] = None
+ self.conv2 = ConvBnAct(
+ first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
+ self.se = create_attn(attn_layer, outplanes)
+ self.act = act_layer(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.drop_block = drop_block
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.conv2.bn.weight)
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ if self.se is not None:
+ x = self.se(x)
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act(x)
+ return x
+
+
+class SelectiveKernelBottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ cardinality=1, base_width=64, sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None,
+ drop_block=None, drop_path=None):
+ super(SelectiveKernelBottleneck, self).__init__()
+
+ sk_kwargs = sk_kwargs or {}
+ conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
+ width = int(math.floor(planes * (base_width / 64)) * cardinality)
+ first_planes = width // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+
+ self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
+ self.conv2 = SelectiveKernel(
+ first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
+ **conv_kwargs, **sk_kwargs)
+ conv_kwargs['act_layer'] = None
+ self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
+ self.se = create_attn(attn_layer, outplanes)
+ self.act = act_layer(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.drop_block = drop_block
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self):
+ nn.init.zeros_(self.conv3.bn.weight)
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ if self.se is not None:
+ x = self.se(x)
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act(x)
+ return x
+
+
+def _create_skresnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ ResNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+
+
+@register_model
+def skresnet18(pretrained=False, **kwargs):
+ """Constructs a Selective Kernel ResNet-18 model.
+
+    Unlike the configs in the Selective Kernel paper or "Compounding the Performance Improvements...", this
+    variation splits the input channels to the selective convolutions to keep the parameter count down.
+ """
+ sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
+ model_args = dict(
+ block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
+ zero_init_last_bn=False, **kwargs)
+ return _create_skresnet('skresnet18', pretrained, **model_args)
+
+
+@register_model
+def skresnet34(pretrained=False, **kwargs):
+ """Constructs a Selective Kernel ResNet-34 model.
+
+    Unlike the configs in the Selective Kernel paper or "Compounding the Performance Improvements...", this
+    variation splits the input channels to the selective convolutions to keep the parameter count down.
+ """
+ sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
+ model_args = dict(
+ block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
+ zero_init_last_bn=False, **kwargs)
+ return _create_skresnet('skresnet34', pretrained, **model_args)
+
+
+@register_model
+def skresnet50(pretrained=False, **kwargs):
+    """Constructs a Selective Kernel ResNet-50 model.
+
+    Unlike the configs in the Selective Kernel paper or "Compounding the Performance Improvements...", this
+    variation splits the input channels to the selective convolutions to keep the parameter count down.
+ """
+ sk_kwargs = dict(split_input=True)
+ model_args = dict(
+ block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
+ zero_init_last_bn=False, **kwargs)
+ return _create_skresnet('skresnet50', pretrained, **model_args)
+
+
+@register_model
+def skresnet50d(pretrained=False, **kwargs):
+    """Constructs a Selective Kernel ResNet-50-D model.
+
+    Unlike the configs in the Selective Kernel paper or "Compounding the Performance Improvements...", this
+    variation splits the input channels to the selective convolutions to keep the parameter count down.
+ """
+ sk_kwargs = dict(split_input=True)
+ model_args = dict(
+ block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
+ block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
+ return _create_skresnet('skresnet50d', pretrained, **model_args)
+
+
+@register_model
+def skresnext50_32x4d(pretrained=False, **kwargs):
+    """Constructs a Selective Kernel ResNeXt50-32x4d model. This should be equivalent to
+    the SKNet-50 model in the Selective Kernel paper.
+ """
+ sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
+ model_args = dict(
+ block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
+ block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
+ return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)
+
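+# Quick sanity-check sketch (not part of the original file; assumes this timm package is
+# importable and, for pretrained=True, that the weight URLs above are reachable):
+#   >>> import timm, torch
+#   >>> model = timm.create_model('skresnet18', pretrained=False)
+#   >>> model(torch.randn(1, 3, 224, 224)).shape
+#   torch.Size([1, 1000])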
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
new file mode 100644
index 0000000..9205790
--- /dev/null
+++ b/timm/models/swin_transformer.py
@@ -0,0 +1,656 @@
+""" Swin Transformer
+A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
+ - https://arxiv.org/pdf/2103.14030
+
+Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
+
+"""
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu
+# --------------------------------------------------------
+import logging
+import math
+from copy import deepcopy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_function
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
+from .layers import _assert
+from .registry import register_model
+from .vision_transformer import checkpoint_filter_fn, _init_vit_weights
+
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models (my experiments)
+ 'swin_base_patch4_window12_384': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'swin_base_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',
+ ),
+
+ 'swin_large_patch4_window12_384': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'swin_large_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',
+ ),
+
+ 'swin_small_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',
+ ),
+
+ 'swin_tiny_patch4_window7_224': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',
+ ),
+
+ 'swin_base_patch4_window12_384_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
+
+ 'swin_base_patch4_window7_224_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
+ num_classes=21841),
+
+ 'swin_large_patch4_window12_384_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, num_classes=21841),
+
+ 'swin_large_patch4_window7_224_in22k': _cfg(
+ url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
+ num_classes=21841),
+
+}
+
+
+def window_partition(x, window_size: int):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+@register_notrace_function # reason: int argument is a Proxy
+def window_reverse(windows, window_size: int, H: int, W: int):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
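+# Illustrative round trip (not part of the upstream file): window_partition and window_reverse
+# are exact inverses for a divisible grid, e.g. an 8x8 map split into 4x4 windows:
+#   >>> x = torch.randn(2, 8, 8, 96)
+#   >>> windows = window_partition(x, 4)              # (2 * 2 * 2, 4, 4, 96) -> (8, 4, 4, 96)
+#   >>> torch.equal(window_reverse(windows, 4, 8, 8), x)
+#   True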
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask: Optional[torch.Tensor] = None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
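+# Shape sketch (illustrative, not part of the upstream file): for a 7x7 window the bias table
+# holds (2*7-1)*(2*7-1)=169 relative offsets per head, and attention acts per window:
+#   >>> attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
+#   >>> attn.relative_position_bias_table.shape
+#   torch.Size([169, 3])
+#   >>> attn(torch.randn(8, 49, 96)).shape            # 8 windows of 7*7=49 tokens each
+#   torch.Size([8, 49, 96])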
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
+ attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x):
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ _assert(L == H * W, "input feature has wrong size")
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
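+# Shape sketch (illustrative, not part of the upstream file): a shifted block keeps the token
+# sequence shape, e.g. on a 56x56 grid of 96-dim tokens with window 7 and shift 3:
+#   >>> blk = SwinTransformerBlock(dim=96, input_resolution=(56, 56), num_heads=3, window_size=7, shift_size=3)
+#   >>> blk(torch.randn(2, 56 * 56, 96)).shape
+#   torch.Size([2, 3136, 96])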
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: B, H*W, C
+ """
+ H, W = self.input_resolution
+ B, L, C = x.shape
+ _assert(L == H * W, "input feature has wrong size")
+        _assert(H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even.")
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
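+# Shape sketch (illustrative, not part of the upstream file): merging halves each spatial side
+# and doubles the channel dim:
+#   >>> merge = PatchMerging(input_resolution=(56, 56), dim=96)
+#   >>> merge(torch.randn(2, 56 * 56, 96)).shape      # 56x56x96 -> 28x28x192
+#   torch.Size([2, 784, 192])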
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ for blk in self.blocks:
+ if not torch.jit.is_scripting() and self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x)
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+
+class SwinTransformer(nn.Module):
+ r""" Swin Transformer
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
+ embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
+ window_size=7, mlp_ratio=4., qkv_bias=True,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+ use_checkpoint=False, weight_init='', **kwargs):
+ super().__init__()
+
+ self.num_classes = num_classes
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ self.patch_grid = self.patch_embed.grid_size
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+ else:
+ self.absolute_pos_embed = None
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ layers = []
+ for i_layer in range(self.num_layers):
+ layers += [BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ input_resolution=(self.patch_grid[0] // (2 ** i_layer), self.patch_grid[1] // (2 ** i_layer)),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ ]
+ self.layers = nn.Sequential(*layers)
+
+ self.norm = norm_layer(self.num_features)
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+ if weight_init.startswith('jax'):
+ for n, m in self.named_modules():
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+ else:
+ self.apply(_init_vit_weights)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ if self.absolute_pos_embed is not None:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+ x = self.layers(x)
+ x = self.norm(x) # B L C
+ x = self.avgpool(x.transpose(1, 2)) # B C 1
+ x = torch.flatten(x, 1)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_swin_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+ if default_cfg is None:
+ default_cfg = deepcopy(default_cfgs[variant])
+ overlay_external_default_cfg(default_cfg, kwargs)
+ default_num_classes = default_cfg['num_classes']
+ default_img_size = default_cfg['input_size'][-2:]
+
+ num_classes = kwargs.pop('num_classes', default_num_classes)
+ img_size = kwargs.pop('img_size', default_img_size)
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ SwinTransformer, variant, pretrained,
+ default_cfg=default_cfg,
+ img_size=img_size,
+ num_classes=num_classes,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+
+ return model
+
+
+
+@register_model
+def swin_base_patch4_window12_384(pretrained=False, **kwargs):
+ """ Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+ return _create_swin_transformer('swin_base_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_base_patch4_window7_224(pretrained=False, **kwargs):
+ """ Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+ return _create_swin_transformer('swin_base_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_large_patch4_window12_384(pretrained=False, **kwargs):
+ """ Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+ return _create_swin_transformer('swin_large_patch4_window12_384', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_large_patch4_window7_224(pretrained=False, **kwargs):
+ """ Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+ return _create_swin_transformer('swin_large_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_small_patch4_window7_224(pretrained=False, **kwargs):
+ """ Swin-S @ 224x224, trained ImageNet-1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
+ return _create_swin_transformer('swin_small_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_tiny_patch4_window7_224(pretrained=False, **kwargs):
+ """ Swin-T @ 224x224, trained ImageNet-1k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
+ return _create_swin_transformer('swin_tiny_patch4_window7_224', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_base_patch4_window12_384_in22k(pretrained=False, **kwargs):
+ """ Swin-B @ 384x384, trained ImageNet-22k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+ return _create_swin_transformer('swin_base_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_base_patch4_window7_224_in22k(pretrained=False, **kwargs):
+ """ Swin-B @ 224x224, trained ImageNet-22k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
+ return _create_swin_transformer('swin_base_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_large_patch4_window12_384_in22k(pretrained=False, **kwargs):
+ """ Swin-L @ 384x384, trained ImageNet-22k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+ return _create_swin_transformer('swin_large_patch4_window12_384_in22k', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
+ """ Swin-L @ 224x224, trained ImageNet-22k
+ """
+ model_kwargs = dict(
+ patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
+ return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
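+
+# Minimal usage sketch (not part of the upstream file; assumes this timm package is importable
+# and, for pretrained=True, that the checkpoint URLs above are reachable):
+#   >>> import timm, torch
+#   >>> model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False)
+#   >>> model(torch.randn(1, 3, 224, 224)).shape
+#   torch.Size([1, 1000])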
diff --git a/timm/models/tnt.py b/timm/models/tnt.py
new file mode 100644
index 0000000..d52f9ce
--- /dev/null
+++ b/timm/models/tnt.py
@@ -0,0 +1,272 @@
+""" Transformer in Transformer (TNT) in PyTorch
+
+A PyTorch implementation of TNT as described in
+'Transformer in Transformer' - https://arxiv.org/abs/2103.00112
+
+The official mindspore code is released and available at
+https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
+"""
+import math
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.helpers import build_model_with_cfg
+from timm.models.layers import Mlp, DropPath, trunc_normal_
+from timm.models.layers.helpers import to_2tuple
+from timm.models.layers import _assert
+from timm.models.registry import register_model
+from timm.models.vision_transformer import resize_pos_embed
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'pixel_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'tnt_s_patch16_224': _cfg(
+ url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ ),
+ 'tnt_b_patch16_224': _cfg(
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ ),
+}
+
+
+class Attention(nn.Module):
+ """ Multi-Head Attention
+ """
+ def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ head_dim = hidden_dim // num_heads
+ self.head_dim = head_dim
+ self.scale = head_dim ** -0.5
+
+ self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias)
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop, inplace=True)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop, inplace=True)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+ v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ """ TNT Block
+ """
+ def __init__(self, dim, in_dim, num_pixel, num_heads=12, in_num_head=4, mlp_ratio=4.,
+ qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ # Inner transformer
+ self.norm_in = norm_layer(in_dim)
+ self.attn_in = Attention(
+ in_dim, in_dim, num_heads=in_num_head, qkv_bias=qkv_bias,
+ attn_drop=attn_drop, proj_drop=drop)
+
+ self.norm_mlp_in = norm_layer(in_dim)
+ self.mlp_in = Mlp(in_features=in_dim, hidden_features=int(in_dim * 4),
+ out_features=in_dim, act_layer=act_layer, drop=drop)
+
+ self.norm1_proj = norm_layer(in_dim)
+ self.proj = nn.Linear(in_dim * num_pixel, dim, bias=True)
+ # Outer transformer
+ self.norm_out = norm_layer(dim)
+ self.attn_out = Attention(
+ dim, dim, num_heads=num_heads, qkv_bias=qkv_bias,
+ attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm_mlp = norm_layer(dim)
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
+ out_features=dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, pixel_embed, patch_embed):
+ # inner
+ pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed)))
+ pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
+ # outer
+ B, N, C = patch_embed.size()
+ patch_embed = torch.cat(
+ [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
+ dim=1)
+ patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
+ patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
+ return pixel_embed, patch_embed
+
+
+class PixelEmbed(nn.Module):
+ """ Image to Pixel Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ # grid_size property necessary for resizing positional embedding
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ num_patches = (self.grid_size[0]) * (self.grid_size[1])
+ self.img_size = img_size
+ self.num_patches = num_patches
+ self.in_dim = in_dim
+ new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
+ self.new_patch_size = new_patch_size
+
+ self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
+ self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)
+
+ def forward(self, x, pixel_pos):
+ B, C, H, W = x.shape
+ _assert(H == self.img_size[0],
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+ _assert(W == self.img_size[1],
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
+ x = self.proj(x)
+ x = self.unfold(x)
+ x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
+ x = x + pixel_pos
+ x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
+ return x
+
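+# Shape sketch (illustrative, not part of the upstream file): with a 224x224 input, 16x16 patches
+# and stride 4, each of the 14*14=196 patches becomes 4*4=16 pixel tokens of dim 48:
+#   >>> pe = PixelEmbed(img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4)
+#   >>> pixel_pos = torch.zeros(1, 48, 4, 4)
+#   >>> pe(torch.randn(2, 3, 224, 224), pixel_pos).shape
+#   torch.Size([392, 16, 48])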
+
+class TNT(nn.Module):
+ """ Transformer in Transformer - https://arxiv.org/abs/2103.00112
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, in_dim=48, depth=12,
+ num_heads=12, in_num_head=4, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
+ drop_path_rate=0., norm_layer=nn.LayerNorm, first_stride=4):
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+
+ self.pixel_embed = PixelEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, in_dim=in_dim, stride=first_stride)
+ num_patches = self.pixel_embed.num_patches
+ self.num_patches = num_patches
+ new_patch_size = self.pixel_embed.new_patch_size
+ num_pixel = new_patch_size[0] * new_patch_size[1]
+
+ self.norm1_proj = norm_layer(num_pixel * in_dim)
+ self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
+ self.norm2_proj = norm_layer(embed_dim)
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+ self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ blocks = []
+ for i in range(depth):
+ blocks.append(Block(
+ dim=embed_dim, in_dim=in_dim, num_pixel=num_pixel, num_heads=num_heads, in_num_head=in_num_head,
+ mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dpr[i], norm_layer=norm_layer))
+ self.blocks = nn.ModuleList(blocks)
+ self.norm = norm_layer(embed_dim)
+
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.cls_token, std=.02)
+ trunc_normal_(self.patch_pos, std=.02)
+ trunc_normal_(self.pixel_pos, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'patch_pos', 'pixel_pos', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ pixel_embed = self.pixel_embed(x, self.pixel_pos)
+
+ patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
+ patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
+ patch_embed = patch_embed + self.patch_pos
+ patch_embed = self.pos_drop(patch_embed)
+
+ for blk in self.blocks:
+ pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
+
+ patch_embed = self.norm(patch_embed)
+ return patch_embed[:, 0]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def checkpoint_filter_fn(state_dict, model):
+    """ resize the patch position embedding from a checkpoint if it doesn't match the model's patch grid"""
+ if state_dict['patch_pos'].shape != model.patch_pos.shape:
+ state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
+ model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+ return state_dict
+
+
+def _create_tnt(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ TNT, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def tnt_s_patch16_224(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
+ qkv_bias=False, **kwargs)
+ model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
+ return model
+
+
+@register_model
+def tnt_b_patch16_224(pretrained=False, **kwargs):
+ model_cfg = dict(
+ patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
+ qkv_bias=False, **kwargs)
+ model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
+ return model
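+
+# Minimal usage sketch (not part of the upstream file; assumes this timm package is importable):
+#   >>> import timm, torch
+#   >>> model = timm.create_model('tnt_s_patch16_224', pretrained=False)
+#   >>> model(torch.randn(1, 3, 224, 224)).shape
+#   torch.Size([1, 1000])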
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
new file mode 100644
index 0000000..372bfb7
--- /dev/null
+++ b/timm/models/tresnet.py
@@ -0,0 +1,297 @@
+"""
+TResNet: High Performance GPU-Dedicated Architecture
+https://arxiv.org/pdf/2003.13630.pdf
+
+Original model: https://github.com/mrT23/TResNet
+
+"""
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from .helpers import build_model_with_cfg
+from .layers import SpaceToDepthModule, BlurPool2d, InplaceAbn, ClassifierHead, SEModule
+from .registry import register_model
+
+__all__ = ['tresnet_m', 'tresnet_l', 'tresnet_xl']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': (0, 0, 0), 'std': (1, 1, 1),
+ 'first_conv': 'body.conv1.0', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'tresnet_m': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_1k_miil_83_1.pth'),
+ 'tresnet_m_miil_in21k': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/tresnet_m_miil_in21k.pth', num_classes=11221),
+ 'tresnet_l': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
+ 'tresnet_xl': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
+ 'tresnet_m_448': _cfg(
+ input_size=(3, 448, 448), pool_size=(14, 14),
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
+ 'tresnet_l_448': _cfg(
+ input_size=(3, 448, 448), pool_size=(14, 14),
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
+ 'tresnet_xl_448': _cfg(
+ input_size=(3, 448, 448), pool_size=(14, 14),
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
+}
+
+
+def IABN2Float(module: nn.Module) -> nn.Module:
+    """If `module` is IABN, don't use half precision."""
+ if isinstance(module, InplaceAbn):
+ module.float()
+ for child in module.children():
+ IABN2Float(child)
+ return module
+
+
+def conv2d_iabn(ni, nf, stride, kernel_size=3, groups=1, act_layer="leaky_relu", act_param=1e-2):
+ return nn.Sequential(
+ nn.Conv2d(
+ ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False),
+ InplaceAbn(nf, act_layer=act_layer, act_param=act_param)
+ )
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, aa_layer=None):
+ super(BasicBlock, self).__init__()
+ if stride == 1:
+ self.conv1 = conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3)
+ else:
+ if aa_layer is None:
+ self.conv1 = conv2d_iabn(inplanes, planes, stride=2, act_param=1e-3)
+ else:
+ self.conv1 = nn.Sequential(
+ conv2d_iabn(inplanes, planes, stride=1, act_param=1e-3),
+ aa_layer(channels=planes, filt_size=3, stride=2))
+
+ self.conv2 = conv2d_iabn(planes, planes, stride=1, act_layer="identity")
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ rd_chs = max(planes * self.expansion // 4, 64)
+ self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
+
+ def forward(self, x):
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+ else:
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.conv2(out)
+
+ if self.se is not None:
+ out = self.se(out)
+
+ out += shortcut
+ out = self.relu(out)
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True,
+ act_layer="leaky_relu", aa_layer=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = conv2d_iabn(
+ inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer, act_param=1e-3)
+ if stride == 1:
+ self.conv2 = conv2d_iabn(
+ planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3)
+ else:
+ if aa_layer is None:
+ self.conv2 = conv2d_iabn(
+ planes, planes, kernel_size=3, stride=2, act_layer=act_layer, act_param=1e-3)
+ else:
+ self.conv2 = nn.Sequential(
+ conv2d_iabn(planes, planes, kernel_size=3, stride=1, act_layer=act_layer, act_param=1e-3),
+ aa_layer(channels=planes, filt_size=3, stride=2))
+
+ reduction_chs = max(planes * self.expansion // 8, 64)
+ self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
+
+ self.conv3 = conv2d_iabn(
+ planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+ else:
+ shortcut = x
+
+ out = self.conv1(x)
+ out = self.conv2(out)
+ if self.se is not None:
+ out = self.se(out)
+
+ out = self.conv3(out)
+ out = out + shortcut # no inplace
+ out = self.relu(out)
+
+ return out
+
+
+class TResNet(nn.Module):
+ def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, global_pool='fast', drop_rate=0.):
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ super(TResNet, self).__init__()
+
+ aa_layer = BlurPool2d
+
+        # TResNet stages
+ self.inplanes = int(64 * width_factor)
+ self.planes = int(64 * width_factor)
+ conv1 = conv2d_iabn(in_chans * 16, self.planes, stride=1, kernel_size=3)
+ layer1 = self._make_layer(
+ BasicBlock, self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer) # 56x56
+ layer2 = self._make_layer(
+ BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer) # 28x28
+ layer3 = self._make_layer(
+ Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer) # 14x14
+ layer4 = self._make_layer(
+ Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer) # 7x7
+
+ # body
+ self.body = nn.Sequential(OrderedDict([
+ ('SpaceToDepth', SpaceToDepthModule()),
+ ('conv1', conv1),
+ ('layer1', layer1),
+ ('layer2', layer2),
+ ('layer3', layer3),
+ ('layer4', layer4)]))
+
+ self.feature_info = [
+ dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D?
+ dict(num_chs=self.planes, reduction=4, module='body.layer1'),
+ dict(num_chs=self.planes * 2, reduction=8, module='body.layer2'),
+ dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
+ dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
+ ]
+
+ # head
+ self.num_features = (self.planes * 8) * Bottleneck.expansion
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+        # model initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, InplaceAbn):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ # residual connections special initialization
+ for m in self.modules():
+ if isinstance(m, BasicBlock):
+ m.conv2[1].weight = nn.Parameter(torch.zeros_like(m.conv2[1].weight)) # BN to zero
+ if isinstance(m, Bottleneck):
+ m.conv3[1].weight = nn.Parameter(torch.zeros_like(m.conv3[1].weight)) # BN to zero
+ if isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, 0.01)
+
+ def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ layers = []
+ if stride == 2:
+ # avg pooling before 1x1 conv
+ layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
+ layers += [conv2d_iabn(
+ self.inplanes, planes * block.expansion, kernel_size=1, stride=1, act_layer="identity")]
+ downsample = nn.Sequential(*layers)
+
+ layers = []
+ layers.append(block(
+ self.inplanes, planes, stride, downsample, use_se=use_se, aa_layer=aa_layer))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, use_se=use_se, aa_layer=aa_layer))
+ return nn.Sequential(*layers)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='fast'):
+ self.head = ClassifierHead(
+ self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ return self.body(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_tresnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ TResNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
+ **kwargs)
+
+
+@register_model
+def tresnet_m(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
+ return _create_tresnet('tresnet_m', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def tresnet_m_miil_in21k(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
+ return _create_tresnet('tresnet_m_miil_in21k', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def tresnet_l(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
+ return _create_tresnet('tresnet_l', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def tresnet_xl(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
+ return _create_tresnet('tresnet_xl', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def tresnet_m_448(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[3, 4, 11, 3], **kwargs)
+ return _create_tresnet('tresnet_m_448', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def tresnet_l_448(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[4, 5, 18, 3], width_factor=1.2, **kwargs)
+ return _create_tresnet('tresnet_l_448', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def tresnet_xl_448(pretrained=False, **kwargs):
+ model_kwargs = dict(layers=[4, 5, 24, 3], width_factor=1.3, **kwargs)
+ return _create_tresnet('tresnet_xl_448', pretrained=pretrained, **model_kwargs)
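+
+# Minimal usage sketch (not part of the upstream file; assumes this timm package is importable).
+# Note the SpaceToDepth stem: a 4x4 spatial block is folded into channels, which is why conv1
+# above takes in_chans * 16 input channels:
+#   >>> import timm, torch
+#   >>> model = timm.create_model('tresnet_m', pretrained=False)
+#   >>> model(torch.randn(1, 3, 224, 224)).shape
+#   torch.Size([1, 1000])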
diff --git a/timm/models/twins.py b/timm/models/twins.py
new file mode 100644
index 0000000..67a939d
--- /dev/null
+++ b/timm/models/twins.py
@@ -0,0 +1,424 @@
+""" Twins
+A PyTorch impl of: `Twins: Revisiting the Design of Spatial Attention in Vision Transformers`
+ - https://arxiv.org/pdf/2104.13840.pdf
+
+Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below
+
+"""
+# --------------------------------------------------------
+# Twins
+# Copyright (c) 2021 Meituan
+# Licensed under The Apache 2.0 License [see LICENSE for details]
+# Written by Xinjie Li, Xiangxiang Chu
+# --------------------------------------------------------
+import math
+from copy import deepcopy
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
+from .fx_features import register_notrace_module
+from .registry import register_model
+from .vision_transformer import Attention
+from .helpers import build_model_with_cfg
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'twins_pcpvt_small': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth',
+ ),
+ 'twins_pcpvt_base': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth',
+ ),
+ 'twins_pcpvt_large': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth',
+ ),
+ 'twins_svt_small': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth',
+ ),
+ 'twins_svt_base': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth',
+ ),
+ 'twins_svt_large': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth',
+ ),
+}
+
+Size_ = Tuple[int, int]
+
+
+@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
+class LocallyGroupedAttn(nn.Module):
+ """ LSA: self attention within a group
+ """
+ def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1):
+ assert ws != 1
+ super(LocallyGroupedAttn, self).__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.ws = ws
+
+ def forward(self, x, size: Size_):
+        # There are two implementations of this function: zero padding and masking. We observed no obvious
+        # difference between them. Either can be used; the padding version below is kept because it is simpler,
+        # although the masking implementation (the commented-out forward_mask below) is arguably more precise.
+ B, N, C = x.shape
+ H, W = size
+ x = x.view(B, H, W, C)
+ pad_l = pad_t = 0
+ pad_r = (self.ws - W % self.ws) % self.ws
+ pad_b = (self.ws - H % self.ws) % self.ws
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+ _h, _w = Hp // self.ws, Wp // self.ws
+ x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)
+ qkv = self.qkv(x).reshape(
+ B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
+ x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+ x = x.reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ # def forward_mask(self, x, size: Size_):
+ # B, N, C = x.shape
+ # H, W = size
+ # x = x.view(B, H, W, C)
+ # pad_l = pad_t = 0
+ # pad_r = (self.ws - W % self.ws) % self.ws
+ # pad_b = (self.ws - H % self.ws) % self.ws
+ # x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ # _, Hp, Wp, _ = x.shape
+ # _h, _w = Hp // self.ws, Wp // self.ws
+ # mask = torch.zeros((1, Hp, Wp), device=x.device)
+ # mask[:, -pad_b:, :].fill_(1)
+ # mask[:, :, -pad_r:].fill_(1)
+ #
+ # x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C
+ # mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h * _w, self.ws * self.ws)
+ # attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws
+ # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0))
+ # qkv = self.qkv(x).reshape(
+ # B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
+ # # n_h, B, _w*_h, nhead, ws*ws, dim
+ # q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head
+ # attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws
+ # attn = attn + attn_mask.unsqueeze(2)
+ # attn = attn.softmax(dim=-1)
+ # attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head
+ # attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
+ # x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
+ # if pad_r > 0 or pad_b > 0:
+ # x = x[:, :H, :W, :].contiguous()
+ # x = x.reshape(B, N, C)
+ # x = self.proj(x)
+ # x = self.proj_drop(x)
+ # return x
+
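+# Shape sketch (illustrative, not part of the upstream file): attention within 7x7 groups on a
+# 56x56 token grid preserves the sequence shape:
+#   >>> lga = LocallyGroupedAttn(dim=64, num_heads=1, ws=7)
+#   >>> lga(torch.randn(2, 56 * 56, 64), (56, 56)).shape
+#   torch.Size([2, 3136, 64])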
+
+class GlobalSubSampleAttn(nn.Module):
+    """ GSA: sub-sampled keys/values summarize the information of each spatial group for efficiency.
+ """
+ def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1):
+ super().__init__()
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.q = nn.Linear(dim, dim, bias=True)
+ self.kv = nn.Linear(dim, dim * 2, bias=True)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.sr_ratio = sr_ratio
+ if sr_ratio > 1:
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+ self.norm = nn.LayerNorm(dim)
+ else:
+ self.sr = None
+ self.norm = None
+
+ def forward(self, x, size: Size_):
+ B, N, C = x.shape
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+
+ if self.sr is not None:
+ x = x.permute(0, 2, 1).reshape(B, C, *size)
+ x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
+ x = self.norm(x)
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ k, v = kv[0], kv[1]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+
+ return x
+
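+# Shape sketch (illustrative, not part of the upstream file): with sr_ratio=8 the keys/values
+# come from a 7x7 summary of the 56x56 grid, while queries stay at full resolution:
+#   >>> gsa = GlobalSubSampleAttn(dim=64, num_heads=1, sr_ratio=8)
+#   >>> gsa(torch.randn(2, 56 * 56, 64), (56, 56)).shape
+#   torch.Size([2, 3136, 64])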
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ if ws is None:
+ self.attn = Attention(dim, num_heads, False, None, attn_drop, drop)
+ elif ws == 1:
+ self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio)
+ else:
+ self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x, size: Size_):
+ x = x + self.drop_path(self.attn(self.norm1(x), size))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PosConv(nn.Module):
+ # PEG from https://arxiv.org/abs/2102.10882
+ def __init__(self, in_chans, embed_dim=768, stride=1):
+ super(PosConv, self).__init__()
+ self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
+ self.stride = stride
+
+ def forward(self, x, size: Size_):
+ B, N, C = x.shape
+ cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
+ x = self.proj(cnn_feat_token)
+ if self.stride == 1:
+ x += cnn_feat_token
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+ def no_weight_decay(self):
+ return ['proj.%d.weight' % i for i in range(4)]
+
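+# Added note: PosConv is the PEG conditional positional encoding from the paper cited
+# above. It applies a 3x3 depthwise conv (groups=embed_dim) to the token map and, when
+# stride == 1, adds it back residually, so positional information comes from zero
+# padding rather than a learned absolute embedding. Rough data flow, with `size` = (H, W):
+#   (B, N, C) -> reshape to (B, C, H, W) -> depthwise 3x3 conv (+ residual) -> (B, N, C)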
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
+ f"img_size {img_size} should be divided by patch_size {patch_size}."
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+ self.num_patches = self.H * self.W
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = nn.LayerNorm(embed_dim)
+
+ def forward(self, x) -> Tuple[torch.Tensor, Size_]:
+ B, C, H, W = x.shape
+
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ out_size = (H // self.patch_size[0], W // self.patch_size[1])
+
+ return x, out_size
+
+
+class Twins(nn.Module):
+ """ Twins Vision Transfomer (Revisiting Spatial Attention)
+
+ Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
+ """
+ def __init__(
+ self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512),
+ num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None,
+ block_cls=Block):
+ super().__init__()
+ self.num_classes = num_classes
+ self.depths = depths
+ self.embed_dims = embed_dims
+ self.num_features = embed_dims[-1]
+
+ img_size = to_2tuple(img_size)
+ prev_chs = in_chans
+ self.patch_embeds = nn.ModuleList()
+ self.pos_drops = nn.ModuleList()
+ for i in range(len(depths)):
+ self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i]))
+ self.pos_drops.append(nn.Dropout(p=drop_rate))
+ prev_chs = embed_dims[i]
+ img_size = tuple(t // patch_size for t in img_size)
+ patch_size = 2
+
+ self.blocks = nn.ModuleList()
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+ cur = 0
+ for k in range(len(depths)):
+ _block = nn.ModuleList([block_cls(
+ dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k],
+ ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])])
+ self.blocks.append(_block)
+ cur += depths[k]
+
+ self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
+
+ self.norm = norm_layer(self.num_features)
+
+ # classification head
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ # init weights
+ self.apply(self._init_weights)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()])
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1.0)
+ m.bias.data.zero_()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ for i, (embed, drop, blocks, pos_blk) in enumerate(
+ zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
+ x, size = embed(x)
+ x = drop(x)
+ for j, blk in enumerate(blocks):
+ x = blk(x, size)
+ if j == 0:
+ x = pos_blk(x, size) # PEG here
+ if i < len(self.depths) - 1:
+ x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
+ x = self.norm(x)
+ return x.mean(dim=1) # GAP here
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+
+def _create_twins(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ Twins, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+ return model
+
+
+@register_model
+def twins_pcpvt_small(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
+ return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def twins_pcpvt_base(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
+ depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
+ return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def twins_pcpvt_large(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
+ depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
+ return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def twins_svt_small(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4],
+ depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
+ return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def twins_svt_base(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4],
+ depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
+ return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def twins_svt_large(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
+ depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
+ return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
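+# Added usage sketch, assuming the standard timm registry/factory API vendored in this repo:
+#   import timm, torch
+#   model = timm.create_model('twins_svt_small', pretrained=False)
+#   logits = model(torch.randn(1, 3, 224, 224))   # -> (1, 1000)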
diff --git a/timm/models/vgg.py b/timm/models/vgg.py
new file mode 100644
index 0000000..11f6d0e
--- /dev/null
+++ b/timm/models/vgg.py
@@ -0,0 +1,263 @@
+"""VGG
+
+Adapted from https://github.com/pytorch/vision 'vgg.py' (BSD-3-Clause) with a few changes for
+timm functionality.
+
+Copyright 2021 Ross Wightman
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Union, List, Dict, Any, cast
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .fx_features import register_notrace_module
+from .layers import ClassifierHead
+from .registry import register_model
+
+__all__ = [
+ 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
+ 'vgg19_bn', 'vgg19',
+]
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (1, 1),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'features.0', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ 'vgg11': _cfg(url='https://download.pytorch.org/models/vgg11-bbd30ac9.pth'),
+ 'vgg13': _cfg(url='https://download.pytorch.org/models/vgg13-c768596a.pth'),
+ 'vgg16': _cfg(url='https://download.pytorch.org/models/vgg16-397923af.pth'),
+ 'vgg19': _cfg(url='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'),
+ 'vgg11_bn': _cfg(url='https://download.pytorch.org/models/vgg11_bn-6002323d.pth'),
+ 'vgg13_bn': _cfg(url='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth'),
+ 'vgg16_bn': _cfg(url='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth'),
+ 'vgg19_bn': _cfg(url='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth'),
+}
+
+
+cfgs: Dict[str, List[Union[str, int]]] = {
+ 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
+ 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
+ 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
+}
+
+
+@register_notrace_module # reason: FX can't symbolically trace control flow in forward method
+class ConvMlp(nn.Module):
+
+ def __init__(self, in_features=512, out_features=4096, kernel_size=7, mlp_ratio=1.0,
+ drop_rate: float = 0.2, act_layer: nn.Module = None, conv_layer: nn.Module = None):
+ super(ConvMlp, self).__init__()
+ self.input_kernel_size = kernel_size
+ mid_features = int(out_features * mlp_ratio)
+ self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True)
+ self.act1 = act_layer(True)
+ self.drop = nn.Dropout(drop_rate)
+ self.fc2 = conv_layer(mid_features, out_features, 1, bias=True)
+ self.act2 = act_layer(True)
+
+ def forward(self, x):
+ if x.shape[-2] < self.input_kernel_size or x.shape[-1] < self.input_kernel_size:
+ # keep the input size >= 7x7
+ output_size = (max(self.input_kernel_size, x.shape[-2]), max(self.input_kernel_size, x.shape[-1]))
+ x = F.adaptive_avg_pool2d(x, output_size)
+ x = self.fc1(x)
+ x = self.act1(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.act2(x)
+ return x
+
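+# Added note: ConvMlp above reproduces the original VGG fully-connected classifier head
+# as convolutions: with mlp_ratio=1.0, fc1 is a 7x7 conv over the 512x7x7 feature map
+# (equivalent to Linear(512*7*7, 4096)) and fc2 is a 1x1 conv (equivalent to
+# Linear(4096, 4096)), which keeps the head usable on inputs larger than 224x224.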
+
+class VGG(nn.Module):
+
+ def __init__(
+ self,
+ cfg: List[Any],
+ num_classes: int = 1000,
+ in_chans: int = 3,
+ output_stride: int = 32,
+ mlp_ratio: float = 1.0,
+ act_layer: nn.Module = nn.ReLU,
+ conv_layer: nn.Module = nn.Conv2d,
+ norm_layer: nn.Module = None,
+ global_pool: str = 'avg',
+ drop_rate: float = 0.,
+ ) -> None:
+ super(VGG, self).__init__()
+ assert output_stride == 32
+ self.num_classes = num_classes
+ self.num_features = 4096
+ self.drop_rate = drop_rate
+ self.feature_info = []
+ prev_chs = in_chans
+ net_stride = 1
+ pool_layer = nn.MaxPool2d
+ layers: List[nn.Module] = []
+ for v in cfg:
+ last_idx = len(layers) - 1
+ if v == 'M':
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{last_idx}'))
+ layers += [pool_layer(kernel_size=2, stride=2)]
+ net_stride *= 2
+ else:
+ v = cast(int, v)
+ conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1)
+ if norm_layer is not None:
+ layers += [conv2d, norm_layer(v), act_layer(inplace=True)]
+ else:
+ layers += [conv2d, act_layer(inplace=True)]
+ prev_chs = v
+ self.features = nn.Sequential(*layers)
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}'))
+ self.pre_logits = ConvMlp(
+ prev_chs, self.num_features, 7, mlp_ratio=mlp_ratio,
+ drop_rate=drop_rate, act_layer=act_layer, conv_layer=conv_layer)
+ self.head = ClassifierHead(
+ self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+ self._initialize_weights()
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.head = ClassifierHead(
+ self.num_features, self.num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.features(x)
+ x = self.pre_logits(x)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
+ def _initialize_weights(self) -> None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.constant_(m.bias, 0)
+
+
+def _filter_fn(state_dict):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ for k, v in state_dict.items():
+ k_r = k
+ k_r = k_r.replace('classifier.0', 'pre_logits.fc1')
+ k_r = k_r.replace('classifier.3', 'pre_logits.fc2')
+ k_r = k_r.replace('classifier.6', 'head.fc')
+ if 'classifier.0.weight' in k:
+ v = v.reshape(-1, 512, 7, 7)
+ if 'classifier.3.weight' in k:
+ v = v.reshape(-1, 4096, 1, 1)
+ out_dict[k_r] = v
+ return out_dict
+
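+# Added example of the key remapping above: the torchvision checkpoint entry
+# 'classifier.0.weight' of shape (4096, 25088) becomes 'pre_logits.fc1.weight' of shape
+# (4096, 512, 7, 7), 'classifier.3.*' maps to 'pre_logits.fc2.*', and 'classifier.6.*'
+# maps to the 'head.fc' classifier.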
+
+def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
+ cfg = variant.split('_')[0]
+ # NOTE: VGG is one of the only models with stride==1 features, so indices are offset from other models
+ out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
+ model = build_model_with_cfg(
+ VGG, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=cfgs[cfg],
+ feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+ pretrained_filter_fn=_filter_fn,
+ **kwargs)
+ return model
+
+
+@register_model
+def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 11-layer model (configuration "A") from
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(**kwargs)
+ return _create_vgg('vgg11', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg11_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 11-layer model (configuration "A") with batch normalization
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
+ return _create_vgg('vgg11_bn', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 13-layer model (configuration "B")
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(**kwargs)
+ return _create_vgg('vgg13', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 13-layer model (configuration "B") with batch normalization
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
+ return _create_vgg('vgg13_bn', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 16-layer model (configuration "D")
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(**kwargs)
+ return _create_vgg('vgg16', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 16-layer model (configuration "D") with batch normalization
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
+ return _create_vgg('vgg16_bn', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 19-layer model (configuration "E")
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(**kwargs)
+ return _create_vgg('vgg19', pretrained=pretrained, **model_args)
+
+
+@register_model
+def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
+ r"""VGG 19-layer model (configuration 'E') with batch normalization
+ `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._
+ """
+ model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
+ return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
\ No newline at end of file
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
new file mode 100644
index 0000000..37284c9
--- /dev/null
+++ b/timm/models/visformer.py
@@ -0,0 +1,412 @@
+""" Visformer
+
+Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
+
+From original at https://github.com/danczs/Visformer
+
+"""
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg, overlay_external_default_cfg
+from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
+from .registry import register_model
+
+
+__all__ = ['Visformer']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.0', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ visformer_tiny=_cfg(),
+ visformer_small=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
+ ),
+)
+
+
+class SpatialMlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None,
+ act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ drop_probs = to_2tuple(drop)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.spatial_conv = spatial_conv
+ if self.spatial_conv:
+ if group < 2: # net setting
+ hidden_features = in_features * 5 // 6
+ else:
+ hidden_features = in_features * 2
+ self.hidden_features = hidden_features
+ self.group = group
+ self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
+ self.act1 = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ if self.spatial_conv:
+ self.conv2 = nn.Conv2d(
+ hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
+ self.act2 = act_layer()
+ else:
+ self.conv2 = None
+ self.act2 = None
+ self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
+ self.drop3 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.act1(x)
+ x = self.drop1(x)
+ if self.conv2 is not None:
+ x = self.conv2(x)
+ x = self.act2(x)
+ x = self.conv3(x)
+ x = self.drop3(x)
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = round(dim // num_heads * head_dim_ratio)
+ self.head_dim = head_dim
+ self.scale = head_dim ** -0.5
+ self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
+ q, k, v = x[0], x[1], x[2]
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
+ drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
+ group=8, attn_disabled=False, spatial_conv=False):
+ super().__init__()
+ self.spatial_conv = spatial_conv
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ if attn_disabled:
+ self.norm1 = None
+ self.attn = None
+ else:
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = SpatialMlp(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
+ group=group, spatial_conv=spatial_conv) # new setting
+
+ def forward(self, x):
+ if self.attn is not None:
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class Visformer(nn.Module):
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
+ depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+ norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
+ vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ self.num_classes = num_classes
+ self.embed_dim = embed_dim
+ self.init_channels = init_channels
+ self.img_size = img_size
+ self.vit_stem = vit_stem
+ self.conv_init = conv_init
+ if isinstance(depth, (list, tuple)):
+ self.stage_num1, self.stage_num2, self.stage_num3 = depth
+ depth = sum(depth)
+ else:
+ self.stage_num1 = self.stage_num3 = depth // 3
+ self.stage_num2 = depth - self.stage_num1 - self.stage_num3
+ self.pos_embed = pos_embed
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+ # stage 1
+ if self.vit_stem:
+ self.stem = None
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans,
+ embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
+ img_size = [x // patch_size for x in img_size]
+ else:
+ if self.init_channels is None:
+ self.stem = None
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
+ embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
+ img_size = [x // (patch_size // 2) for x in img_size]
+ else:
+ self.stem = nn.Sequential(
+ nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
+ nn.BatchNorm2d(self.init_channels),
+ nn.ReLU(inplace=True)
+ )
+ img_size = [x // 2 for x in img_size]
+ self.patch_embed1 = PatchEmbed(
+ img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
+ embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
+ img_size = [x // (patch_size // 4) for x in img_size]
+
+ if self.pos_embed:
+ if self.vit_stem:
+ self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
+ else:
+ self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ self.stage1 = nn.ModuleList([
+ Block(
+ dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
+ )
+ for i in range(self.stage_num1)
+ ])
+
+ # stage2
+ if not self.vit_stem:
+ self.patch_embed2 = PatchEmbed(
+ img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
+ embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
+ img_size = [x // (patch_size // 8) for x in img_size]
+ if self.pos_embed:
+ self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
+ self.stage2 = nn.ModuleList([
+ Block(
+ dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
+ )
+ for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
+ ])
+
+ # stage 3
+ if not self.vit_stem:
+ self.patch_embed3 = PatchEmbed(
+ img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
+ embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
+ img_size = [x // (patch_size // 8) for x in img_size]
+ if self.pos_embed:
+ self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
+ self.stage3 = nn.ModuleList([
+ Block(
+ dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
+ )
+ for i in range(self.stage_num1+self.stage_num2, depth)
+ ])
+
+ # head
+ self.num_features = embed_dim if self.vit_stem else embed_dim * 2
+ self.norm = norm_layer(self.num_features)
+ self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ # weights init
+ if self.pos_embed:
+ trunc_normal_(self.pos_embed1, std=0.02)
+ if not self.vit_stem:
+ trunc_normal_(self.pos_embed2, std=0.02)
+ trunc_normal_(self.pos_embed3, std=0.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+ elif isinstance(m, nn.Conv2d):
+ if self.conv_init:
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ else:
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0.)
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ if self.stem is not None:
+ x = self.stem(x)
+
+ # stage 1
+ x = self.patch_embed1(x)
+ if self.pos_embed:
+ x = x + self.pos_embed1
+ x = self.pos_drop(x)
+ for b in self.stage1:
+ x = b(x)
+
+ # stage 2
+ if not self.vit_stem:
+ x = self.patch_embed2(x)
+ if self.pos_embed:
+ x = x + self.pos_embed2
+ x = self.pos_drop(x)
+ for b in self.stage2:
+ x = b(x)
+
+ # stage3
+ if not self.vit_stem:
+ x = self.patch_embed3(x)
+ if self.pos_embed:
+ x = x + self.pos_embed3
+ x = self.pos_drop(x)
+ for b in self.stage3:
+ x = b(x)
+
+ x = self.norm(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ x = self.head(x)
+ return x
+
+
+def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+ model = build_model_with_cfg(
+ Visformer, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ **kwargs)
+ return model
+
+
+@register_model
+def visformer_tiny(pretrained=False, **kwargs):
+ model_cfg = dict(
+ init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
+ attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
+ embed_norm=nn.BatchNorm2d, **kwargs)
+ model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
+ return model
+
+
+@register_model
+def visformer_small(pretrained=False, **kwargs):
+ model_cfg = dict(
+ init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
+ attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
+ embed_norm=nn.BatchNorm2d, **kwargs)
+ model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
+ return model
+
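+# Added note: `attn_stage` and `spatial_conv` are per-stage flag strings, e.g.
+# attn_stage='011' disables self-attention in stage 1 (conv-only blocks) and enables it
+# in stages 2-3, while spatial_conv='100' adds the grouped 3x3 conv only to stage-1 MLPs.
+# Rough usage sketch, assuming the timm factory API:
+#   model = timm.create_model('visformer_small', pretrained=False)
+#   logits = model(torch.randn(1, 3, 224, 224))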
+
+# @register_model
+# def visformer_net1(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
+# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+#
+#
+# @register_model
+# def visformer_net2(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
+# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+#
+#
+# @register_model
+# def visformer_net3(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
+# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+#
+#
+# @register_model
+# def visformer_net4(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
+# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+#
+#
+# @register_model
+# def visformer_net5(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
+# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+#
+#
+# @register_model
+# def visformer_net6(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
+# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+#
+#
+# @register_model
+# def visformer_net7(pretrained=False, **kwargs):
+# model = Visformer(
+# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
+# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
+# model.default_cfg = _cfg()
+# return model
+
+
+
+
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
new file mode 100644
index 0000000..3db6364
--- /dev/null
+++ b/timm/models/vision_transformer.py
@@ -0,0 +1,989 @@
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implementation of Vision Transformers as described in:
+
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
+ - https://arxiv.org/abs/2010.11929
+
+`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
+ - https://arxiv.org/abs/2106.10270
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+import logging
+from functools import partial
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
+from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
+from .registry import register_model
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models (weights from official Google JAX impl)
+ 'vit_tiny_patch16_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
+ 'vit_tiny_patch16_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_small_patch32_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
+ 'vit_small_patch32_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_small_patch16_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
+ 'vit_small_patch16_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_base_patch32_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
+ 'vit_base_patch32_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_base_patch16_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
+ 'vit_base_patch16_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_base_patch8_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
+ 'vit_large_patch32_224': _cfg(
+ url='', # no official model weights for this combo, only for in21k
+ ),
+ 'vit_large_patch32_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_large_patch16_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
+ 'vit_large_patch16_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ # patch models, imagenet21k (weights from official Google JAX impl)
+ 'vit_tiny_patch16_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
+ num_classes=21843),
+ 'vit_small_patch32_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
+ num_classes=21843),
+ 'vit_small_patch16_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
+ num_classes=21843),
+ 'vit_base_patch32_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
+ num_classes=21843),
+ 'vit_base_patch16_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+ num_classes=21843),
+ 'vit_base_patch8_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
+ num_classes=21843),
+ 'vit_large_patch32_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+ num_classes=21843),
+ 'vit_large_patch16_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
+ num_classes=21843),
+ 'vit_huge_patch14_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
+ hf_hub='timm/vit_huge_patch14_224_in21k',
+ num_classes=21843),
+
+ # SAM trained models (https://arxiv.org/abs/2106.01548)
+ 'vit_base_patch32_sam_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
+ 'vit_base_patch16_sam_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
+
+ # deit models (FB weights)
+ 'deit_tiny_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'deit_small_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'deit_base_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'deit_base_patch16_384': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit_tiny_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
+ 'deit_small_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
+ 'deit_base_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
+ 'deit_base_distilled_patch16_384': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
+ classifier=('head', 'head_dist')),
+
+ # ViT ImageNet-21K-P pretraining by MILL
+ 'vit_base_patch16_224_miil_in21k': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
+ ),
+ 'vit_base_patch16_224_miil': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
+ '/vit_base_patch16_224_1k_miil_84_4.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
+ ),
+}
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+ - https://arxiv.org/abs/2012.12877
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+ act_layer=None, weight_init=''):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ distilled (bool): model includes a distillation token and head as in DeiT models
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ weight_init: (str): weight init scheme
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 2 if distilled else 1
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+
+ self.patch_embed = embed_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.Sequential(*[
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Representation layer
+ if representation_size and not distilled:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ # Classifier head(s)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ self.init_weights(weight_init)
+
+ def init_weights(self, mode=''):
+ assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
+ trunc_normal_(self.pos_embed, std=.02)
+ if self.dist_token is not None:
+ trunc_normal_(self.dist_token, std=.02)
+ if mode.startswith('jax'):
+ # leave cls token as zeros to match jax impl
+ named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
+ else:
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(_init_vit_weights)
+
+ def _init_weights(self, m):
+ # this fn left here for compat with downstream users
+ _init_vit_weights(m)
+
+ @torch.jit.ignore()
+ def load_pretrained(self, checkpoint_path, prefix=''):
+ _load_weights(self, checkpoint_path, prefix)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token', 'dist_token'}
+
+ def get_classifier(self):
+ if self.dist_token is None:
+ return self.head
+ else:
+ return self.head, self.head_dist
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ if self.num_tokens == 2:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x, block_layers=[], get_tokens=False, local_id = [], side_length = 7):
+ x = self.patch_embed(x)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ if self.dist_token is None:
+ x = torch.cat((cls_token, x), dim=1)
+ else:
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+
+ if local_id != []:
+ if x.shape[0] != 1:
+ print('Please enter one image at a time!')
+ x = x[:,1:,:]  # remove the class token
+ B, L, C = x.shape[0], x.shape[1], x.shape[2]
+ S = int(math.sqrt(L))
+ if S * S != L:
+ print('Not a square!')
+ x = x.reshape(B, S, S, C)
+
+ h_S = int(side_length / 2)
+ # print('h_S', h_S)
+ local_x_list = []
+ for id in local_id:
+ row_id = int(id / S)
+ column_id = id % S
+ # if row_id - h_S >= 0 and row_id + h_S < S and column_id - h_S >= 0 and column_id + h_S < S:
+ # NOTE: window bounds reconstructed from the clamp logic (side_length x side_length window)
+ row_1 = row_id - h_S
+ row_2 = row_id + h_S + 1
+ column_1 = column_id - h_S
+ column_2 = column_id + h_S + 1
+ if row_1 < 0:
+ row_2 = row_2 - row_1
+ row_1 = 0
+ if row_2 > S:
+ row_1 = row_1 + S - row_2
+ row_2 = S
+ if column_1 < 0:
+ column_2 = column_2 - column_1
+ column_1 = 0
+ if column_2 > S:
+ column_1 = column_1 + S - column_2
+ column_2 = S
+ local_x = x[:,row_1 : row_2 , column_1 : column_2 , :]
+
+ # print((row_id, column_id), local_x.shape)
+
+
+ local_x = local_x.flatten(1,2)
+ local_x_list.append(local_x)
+
+ local_x_list = torch.cat(local_x_list, 0)
+ x = local_x_list
+
+ # print('test:', x.shape)
+ # x = x[:,50:,:]
+ if get_tokens:
+ # print(len(self.blocks))
+ x_list = []
+ for block_id, block in enumerate(self.blocks):
+ x = block(x)
+ if block_id in block_layers:
+ if local_id == []:
+ # print(self.norm)
+ x_list.append(self.norm(x[:,1:,:]))
+ else :
+ x_list.append(self.norm(x))
+ if block_id == block_layers[-1]:
+ return x_list
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ if self.dist_token is None:
+ return self.pre_logits(x[:, 0])
+ else:
+ return x[:, 0], x[:, 1]
+
+ def forward(self, x, block_layers=[], get_tokens=False, local_id=[], side_length=7):
+ x = self.forward_features(x, get_tokens=get_tokens, block_layers=block_layers, local_id=local_id, side_length=side_length)
+ if not get_tokens:
+ if self.head_dist is not None:
+ x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
+ if self.training and not torch.jit.is_scripting():
+ # during training, return both classifier predictions for the distillation loss
+ return x, x_dist
+ else:
+ # during inference, return the average of both classifier predictions
+ return (x + x_dist) / 2
+ else:
+ x = self.head(x)
+ return x
+
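+# Added usage sketch for the token-extraction options above (shapes assume the
+# vit_base_patch16_224 entry point registered later in this file, i.e. a 14x14 token grid):
+#   feats = model(img, get_tokens=True, block_layers=[3, 7, 11])
+#     -> list of 3 tensors, each (B, 196, 768): normalized patch tokens, cls token dropped
+#   local = model(img, get_tokens=True, block_layers=[11], local_id=[0], side_length=7)
+#     -> list with one (len(local_id), 49, 768) tensor: a 7x7 token window clamped inside
+#        the grid around token index 0 (only a single image is supported in this mode)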
+
+def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
+ """ ViT weight initialization
+ * When called without n, head_bias, jax_impl args it will behave exactly the same
+ as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+ """
+ if isinstance(module, nn.Linear):
+ if name.startswith('head'):
+ nn.init.zeros_(module.weight)
+ nn.init.constant_(module.bias, head_bias)
+ elif name.startswith('pre_logits'):
+ lecun_normal_(module.weight)
+ nn.init.zeros_(module.bias)
+ else:
+ if jax_impl:
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ if 'mlp' in name:
+ nn.init.normal_(module.bias, std=1e-6)
+ else:
+ nn.init.zeros_(module.bias)
+ else:
+ trunc_normal_(module.weight, std=.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif jax_impl and isinstance(module, nn.Conv2d):
+ # NOTE conv was left to pytorch default in my original init
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
+ nn.init.zeros_(module.bias)
+ nn.init.ones_(module.weight)
+
+
+@torch.no_grad()
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+ """ Load weights from .npz checkpoints for official Google Brain Flax implementation
+ """
+ import numpy as np
+
+ def _n2p(w, t=True):
+ if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
+ w = w.flatten()
+ if t:
+ if w.ndim == 4:
+ w = w.transpose([3, 2, 0, 1])
+ elif w.ndim == 3:
+ w = w.transpose([2, 0, 1])
+ elif w.ndim == 2:
+ w = w.transpose([1, 0])
+ return torch.from_numpy(w)
+
+ w = np.load(checkpoint_path)
+ if not prefix and 'opt/target/embedding/kernel' in w:
+ prefix = 'opt/target/'
+
+ if hasattr(model.patch_embed, 'backbone'):
+ # hybrid
+ backbone = model.patch_embed.backbone
+ stem_only = not hasattr(backbone, 'stem')
+ stem = backbone if stem_only else backbone.stem
+ stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
+ stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
+ stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
+ if not stem_only:
+ for i, stage in enumerate(backbone.stages):
+ for j, block in enumerate(stage.blocks):
+ bp = f'{prefix}block{i + 1}/unit{j + 1}/'
+ for r in range(3):
+ getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
+ getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
+ getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
+ if block.downsample is not None:
+ block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
+ block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
+ block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
+ embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
+ else:
+ embed_conv_w = adapt_input_conv(
+ model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
+ model.patch_embed.proj.weight.copy_(embed_conv_w)
+ model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
+ model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
+ pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
+ if pos_embed_w.shape != model.pos_embed.shape:
+ pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
+ pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ model.pos_embed.copy_(pos_embed_w)
+ model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
+ model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
+ if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+ model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
+ model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
+ if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
+ model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
+ model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
+ for i, block in enumerate(model.blocks.children()):
+ block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+ mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
+ block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
+ block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+ block.attn.qkv.weight.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+ block.attn.qkv.bias.copy_(torch.cat([
+ _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
+ block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
+ block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
+ for r in range(2):
+ getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
+ getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
+ block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
+ block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
+ ntok_new = posemb_new.shape[1]
+ if num_tokens:
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+ ntok_new -= num_tokens
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ if not len(gs_new): # backwards compatibility
+ gs_new = [int(math.sqrt(ntok_new))] * 2
+ assert len(gs_new) >= 2
+ _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
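+# Added worked example: when resizing 224/16 pretrained weights for a 384/16 model,
+# posemb holds 1 class token + 14*14 = 196 grid positions while posemb_new expects
+# 1 + 24*24 = 577, so the 14x14 grid is bicubically interpolated to 24x24 and the
+# class-token embedding is carried over unchanged.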
+
+def checkpoint_filter_fn(state_dict, model):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ if 'model' in state_dict:
+ # For deit models
+ state_dict = state_dict['model']
+ for k, v in state_dict.items():
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
+ # For old models that I trained prior to conv based patchification
+ O, I, H, W = model.patch_embed.proj.weight.shape
+ v = v.reshape(O, -1, H, W)
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
+ # To resize pos embedding when using model at different size from pretrained weights
+ v = resize_pos_embed(
+ v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+ out_dict[k] = v
+ return out_dict
+
+
+def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+ default_cfg = default_cfg or default_cfgs[variant]
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ # NOTE this extra code to support handling of repr size for in21k pretrained models
+ default_num_classes = default_cfg['num_classes']
+ num_classes = kwargs.get('num_classes', default_num_classes)
+ repr_size = kwargs.pop('representation_size', None)
+ if repr_size is not None and num_classes != default_num_classes:
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
+ # but it seems better than doing nothing by default when fine-tuning. Perhaps a better interface?
+ _logger.warning("Removing representation layer for fine-tuning.")
+ repr_size = None
+
+ model = build_model_with_cfg(
+ VisionTransformer, variant, pretrained,
+ default_cfg=default_cfg,
+ representation_size=repr_size,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ pretrained_custom_load='npz' in default_cfg['url'],
+ **kwargs)
+ return model
+
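+
+# Illustrative sketch, not part of the upstream timm code: the @register_model entrypoints defined
+# below are normally reached through timm.create_model() rather than called directly. The lazy
+# import assumes this file is used inside the timm package it was copied from; never called here.
+def _example_create_registered_vit(variant='vit_base_patch16_224', pretrained=False, **kwargs):
+    import timm
+    return timm.create_model(variant, pretrained=pretrained, **kwargs)
+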
+
+@register_model
+def vit_tiny_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Tiny (Vit-Ti/16)
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_tiny_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Tiny (Vit-Ti/16) @ 384x384.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/32)
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/32) at 384x384.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/16)
+ NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/16)
+ NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch8_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_384(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_224(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_384(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_sam_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
+ """
+ # NOTE original SAM weights release worked with representation_size=768
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_sam_224(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
+ """
+ # NOTE original SAM weights release worked with representation_size=768
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Tiny (Vit-Ti/16).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/16)
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Small (ViT-S/16)
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(
+ patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
+ """
+ model_kwargs = dict(
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
+ model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_small_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_base_patch16_224(pretrained=False, **kwargs):
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs):
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer(
+ 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer(
+ 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(
+ 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(
+ 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_miil(pretrained=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
+ return model
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
new file mode 100644
index 0000000..d5f0a53
--- /dev/null
+++ b/timm/models/vision_transformer_hybrid.py
@@ -0,0 +1,363 @@
+""" Hybrid Vision Transformer (ViT) in PyTorch
+
+A PyTorch implementation of the Hybrid Vision Transformers as described in:
+
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
+ - https://arxiv.org/abs/2010.11929
+
+`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
+ - https://arxiv.org/abs/2106.TODO
+
+NOTE These hybrid model definitions depend on code in vision_transformer.py.
+They were moved here to keep file sizes sane.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from copy import deepcopy
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .layers import StdConv2dSame, StdConv2d, to_2tuple
+from .resnet import resnet26d, resnet50d
+from .resnetv2 import ResNetV2, create_resnetv2_stem
+from .registry import register_model
+from timm.models.vision_transformer import _create_vision_transformer
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # hybrid in-1k models (weights from official JAX impl where they exist)
+ 'vit_tiny_r_s16_p8_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+ first_conv='patch_embed.backbone.conv'),
+ 'vit_tiny_r_s16_p8_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+ first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_small_r26_s32_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
+ ),
+ 'vit_small_r26_s32_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_base_r26_s32_224': _cfg(),
+ 'vit_base_r50_s16_224': _cfg(),
+ 'vit_base_r50_s16_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_large_r50_s32_224': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'
+ ),
+ 'vit_large_r50_s32_384': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/'
+ 'R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
+ input_size=(3, 384, 384), crop_pct=1.0
+ ),
+
+ # hybrid in-21k models (weights from official Google JAX impl where they exist)
+ 'vit_tiny_r_s16_p8_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
+ num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv'),
+ 'vit_small_r26_s32_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
+ num_classes=21843, crop_pct=0.9),
+ 'vit_base_r50_s16_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
+ num_classes=21843, crop_pct=0.9),
+ 'vit_large_r50_s32_224_in21k': _cfg(
+ url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
+ num_classes=21843, crop_pct=0.9),
+
+ # hybrid models (using timm resnet backbones)
+ 'vit_small_resnet26d_224': _cfg(
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+ 'vit_small_resnet50d_s16_224': _cfg(
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+ 'vit_base_resnet26d_224': _cfg(
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+ 'vit_base_resnet50d_224': _cfg(
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
+}
+
+
+class HybridEmbed(nn.Module):
+ """ CNN Feature Map Embedding
+ Extract feature map from CNN, flatten, project to embedding dim.
+ """
+ def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
+ super().__init__()
+ assert isinstance(backbone, nn.Module)
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.backbone = backbone
+ if feature_size is None:
+ with torch.no_grad():
+ # NOTE Most reliable way of determining output dims is to run forward pass
+ training = backbone.training
+ if training:
+ backbone.eval()
+ o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
+ if isinstance(o, (list, tuple)):
+ o = o[-1] # last feature if backbone outputs list/tuple of features
+ feature_size = o.shape[-2:]
+ feature_dim = o.shape[1]
+ backbone.train(training)
+ else:
+ feature_size = to_2tuple(feature_size)
+ if hasattr(self.backbone, 'feature_info'):
+ feature_dim = self.backbone.feature_info.channels()[-1]
+ else:
+ feature_dim = self.backbone.num_features
+ assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+ self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
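+
+# Illustrative sketch, not part of the upstream timm code: a minimal demonstration of what
+# HybridEmbed does -- run a CNN backbone, then 1x1-project its last feature map into a sequence of
+# patch tokens. The tiny nn.Sequential backbone is a hypothetical stand-in; never called here.
+def _example_hybrid_embed():
+    toy_backbone = nn.Sequential(
+        nn.Conv2d(3, 32, 3, stride=2, padding=1), nn.ReLU(),
+        nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
+    )  # overall stride 4 -> 56x56 feature map for a 224x224 input
+    embed = HybridEmbed(toy_backbone, img_size=224, patch_size=1, embed_dim=768)
+    tokens = embed(torch.randn(1, 3, 224, 224))
+    return tokens.shape  # torch.Size([1, 3136, 768]) == (batch, 56 * 56 tokens, embed_dim)
+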
+
+def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
+ embed_layer = partial(HybridEmbed, backbone=backbone)
+ kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
+ return _create_vision_transformer(
+ variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
+
+
+def _resnetv2(layers=(3, 4, 9), **kwargs):
+ """ ResNet-V2 backbone helper"""
+ padding_same = kwargs.get('padding_same', True)
+ stem_type = 'same' if padding_same else ''
+ conv_layer = partial(StdConv2dSame, eps=1e-8) if padding_same else partial(StdConv2d, eps=1e-8)
+ if len(layers):
+ backbone = ResNetV2(
+ layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
+ preact=False, stem_type=stem_type, conv_layer=conv_layer)
+ else:
+ backbone = create_resnetv2_stem(
+ kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
+ return backbone
+
+
+@register_model
+def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
+ """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
+ """
+ backbone = _resnetv2(layers=(), **kwargs)
+ model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
+ """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
+ """
+ backbone = _resnetv2(layers=(), **kwargs)
+ model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_r26_s32_224(pretrained=False, **kwargs):
+ """ R26+ViT-S/S32 hybrid.
+ """
+ backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+ model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_r26_s32_384(pretrained=False, **kwargs):
+ """ R26+ViT-S/S32 hybrid.
+ """
+ backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+ model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_r26_s32_224(pretrained=False, **kwargs):
+ """ R26+ViT-B/S32 hybrid.
+ """
+ backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_r50_s16_224(pretrained=False, **kwargs):
+ """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
+ """
+ backbone = _resnetv2((3, 4, 9), **kwargs)
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_r50_s16_384(pretrained=False, **kwargs):
+ """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ backbone = _resnetv2((3, 4, 9), **kwargs)
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_resnet50_384(pretrained=False, **kwargs):
+ # DEPRECATED this is forwarding to model def above for backwards compatibility
+ return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
+
+
+@register_model
+def vit_large_r50_s32_224(pretrained=False, **kwargs):
+ """ R50+ViT-L/S32 hybrid.
+ """
+ backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+ model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_r50_s32_384(pretrained=False, **kwargs):
+ """ R50+ViT-L/S32 hybrid.
+ """
+ backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+ model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
+ """ R+ViT-Ti/S16 w/ 8x8 patch hybrid. ImageNet-21k.
+ """
+ backbone = _resnetv2(layers=(), **kwargs)
+ model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_tiny_r_s16_p8_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
+ """ R26+ViT-S/S32 hybrid. ImageNet-21k.
+ """
+ backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+ model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
+ """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+ # DEPRECATED this is forwarding to model def above for backwards compatibility
+ return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
+
+
+@register_model
+def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
+ """ R50+ViT-L/S32 hybrid. ImageNet-21k.
+ """
+ backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+ model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_large_r50_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_resnet26d_224(pretrained=False, **kwargs):
+ """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
+ """
+ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+ model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
+ """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
+ """
+ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
+ model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_resnet26d_224(pretrained=False, **kwargs):
+ """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
+ """
+ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_resnet50d_224(pretrained=False, **kwargs):
+ """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
+ """
+ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+ model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer_hybrid(
+ 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
+ return model
\ No newline at end of file
diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py
new file mode 100644
index 0000000..ec5b3e8
--- /dev/null
+++ b/timm/models/vovnet.py
@@ -0,0 +1,406 @@
+""" VoVNet (V1 & V2)
+
+Papers:
+* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
+* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+
+Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
+https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
+for some reference, rewrote most of the code.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .registry import register_model
+from .helpers import build_model_with_cfg
+from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\
+ create_attn, create_norm_act, get_norm_act_layer
+
+
+# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
+# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
+model_cfgs = dict(
+ vovnet39a=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=5,
+ block_per_stage=[1, 1, 2, 2],
+ residual=False,
+ depthwise=False,
+ attn='',
+ ),
+ vovnet57a=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=5,
+ block_per_stage=[1, 1, 4, 3],
+ residual=False,
+ depthwise=False,
+ attn='',
+
+ ),
+ ese_vovnet19b_slim_dw=dict(
+ stem_chs=[64, 64, 64],
+ stage_conv_chs=[64, 80, 96, 112],
+ stage_out_chs=[112, 256, 384, 512],
+ layer_per_block=3,
+ block_per_stage=[1, 1, 1, 1],
+ residual=True,
+ depthwise=True,
+ attn='ese',
+
+ ),
+ ese_vovnet19b_dw=dict(
+ stem_chs=[64, 64, 64],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=3,
+ block_per_stage=[1, 1, 1, 1],
+ residual=True,
+ depthwise=True,
+ attn='ese',
+ ),
+ ese_vovnet19b_slim=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[64, 80, 96, 112],
+ stage_out_chs=[112, 256, 384, 512],
+ layer_per_block=3,
+ block_per_stage=[1, 1, 1, 1],
+ residual=True,
+ depthwise=False,
+ attn='ese',
+ ),
+ ese_vovnet19b=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=3,
+ block_per_stage=[1, 1, 1, 1],
+ residual=True,
+ depthwise=False,
+ attn='ese',
+
+ ),
+ ese_vovnet39b=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=5,
+ block_per_stage=[1, 1, 2, 2],
+ residual=True,
+ depthwise=False,
+ attn='ese',
+ ),
+ ese_vovnet57b=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=5,
+ block_per_stage=[1, 1, 4, 3],
+ residual=True,
+ depthwise=False,
+ attn='ese',
+
+ ),
+ ese_vovnet99b=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=5,
+ block_per_stage=[1, 3, 9, 3],
+ residual=True,
+ depthwise=False,
+ attn='ese',
+ ),
+ eca_vovnet39b=dict(
+ stem_chs=[64, 64, 128],
+ stage_conv_chs=[128, 160, 192, 224],
+ stage_out_chs=[256, 512, 768, 1024],
+ layer_per_block=5,
+ block_per_stage=[1, 1, 2, 2],
+ residual=True,
+ depthwise=False,
+ attn='eca',
+ ),
+)
+model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
+model_cfgs['ese_vovnet99b_iabn'] = model_cfgs['ese_vovnet99b']
+
+
+def _cfg(url=''):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
+ }
+
+
+default_cfgs = dict(
+ vovnet39a=_cfg(url=''),
+ vovnet57a=_cfg(url=''),
+ ese_vovnet19b_slim_dw=_cfg(url=''),
+ ese_vovnet19b_dw=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
+ ese_vovnet19b_slim=_cfg(url=''),
+ ese_vovnet39b=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
+ ese_vovnet57b=_cfg(url=''),
+ ese_vovnet99b=_cfg(url=''),
+ eca_vovnet39b=_cfg(url=''),
+ ese_vovnet39b_evos=_cfg(url=''),
+ ese_vovnet99b_iabn=_cfg(url=''),
+)
+
+
+class SequentialAppendList(nn.Sequential):
+ def __init__(self, *args):
+ super(SequentialAppendList, self).__init__(*args)
+
+ def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
+ for i, module in enumerate(self):
+ if i == 0:
+ concat_list.append(module(x))
+ else:
+ concat_list.append(module(concat_list[-1]))
+ x = torch.cat(concat_list, dim=1)
+ return x
+
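+
+# Illustrative sketch, not part of the upstream timm code: SequentialAppendList implements the
+# One-Shot Aggregation pattern -- every intermediate conv output is kept and concatenated once at
+# the end. A minimal shape check with hypothetical 3x3 convs; never called here.
+def _example_osa_aggregation():
+    convs = SequentialAppendList(
+        nn.Conv2d(64, 32, 3, padding=1),
+        nn.Conv2d(32, 32, 3, padding=1),
+        nn.Conv2d(32, 32, 3, padding=1),
+    )
+    x = torch.randn(1, 64, 8, 8)
+    out = convs(x, [x])  # seed the concat list with the block input, as OsaBlock.forward does
+    return out.shape  # torch.Size([1, 160, 8, 8]): 64 input chs + 3 * 32 aggregated chs
+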
+
+class OsaBlock(nn.Module):
+
+ def __init__(self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
+ depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
+ super(OsaBlock, self).__init__()
+
+ self.residual = residual
+ self.depthwise = depthwise
+ conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer)
+
+ next_in_chs = in_chs
+ if self.depthwise and next_in_chs != mid_chs:
+ assert not residual
+ self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs)
+ else:
+ self.conv_reduction = None
+
+ mid_convs = []
+ for i in range(layer_per_block):
+ if self.depthwise:
+ conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs)
+ else:
+ conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs)
+ next_in_chs = mid_chs
+ mid_convs.append(conv)
+ self.conv_mid = SequentialAppendList(*mid_convs)
+
+ # feature aggregation
+ next_in_chs = in_chs + layer_per_block * mid_chs
+ self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs)
+
+ if attn:
+ self.attn = create_attn(attn, out_chs)
+ else:
+ self.attn = None
+
+ self.drop_path = drop_path
+
+ def forward(self, x):
+ output = [x]
+ if self.conv_reduction is not None:
+ x = self.conv_reduction(x)
+ x = self.conv_mid(x, output)
+ x = self.conv_concat(x)
+ if self.attn is not None:
+ x = self.attn(x)
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+ if self.residual:
+ x = x + output[0]
+ return x
+
+
+class OsaStage(nn.Module):
+
+ def __init__(self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
+ residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
+ drop_path_rates=None):
+ super(OsaStage, self).__init__()
+
+ if downsample:
+ self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
+ else:
+ self.pool = None
+
+ blocks = []
+ for i in range(block_per_stage):
+ last_block = i == block_per_stage - 1
+ if drop_path_rates is not None and drop_path_rates[i] > 0.:
+ drop_path = DropPath(drop_path_rates[i])
+ else:
+ drop_path = None
+ blocks += [OsaBlock(
+ in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
+ attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
+ ]
+ in_chs = out_chs
+ self.blocks = nn.Sequential(*blocks)
+
+ def forward(self, x):
+ if self.pool is not None:
+ x = self.pool(x)
+ x = self.blocks(x)
+ return x
+
+
+class VovNet(nn.Module):
+
+ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
+ output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
+ """ VovNet (v2)
+ """
+ super(VovNet, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ assert stem_stride in (4, 2)
+ assert output_stride == 32 # FIXME support dilation
+
+ stem_chs = cfg["stem_chs"]
+ stage_conv_chs = cfg["stage_conv_chs"]
+ stage_out_chs = cfg["stage_out_chs"]
+ block_per_stage = cfg["block_per_stage"]
+ layer_per_block = cfg["layer_per_block"]
+ conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer)
+
+ # Stem module
+ last_stem_stride = stem_stride // 2
+ conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct
+ self.stem = nn.Sequential(*[
+ ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
+ conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
+ conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
+ ])
+ self.feature_info = [dict(
+ num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')]
+ current_stride = stem_stride
+
+ # OSA stages
+ stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage)
+ in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
+ stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
+ stages = []
+ for i in range(4): # num_stages
+ downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
+ stages += [OsaStage(
+ in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
+ downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
+ ]
+ self.num_features = stage_out_chs[i]
+ current_stride *= 2 if downsample else 1
+ self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
+
+ self.stages = nn.Sequential(*stages)
+
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1.)
+ nn.init.constant_(m.bias, 0.)
+ elif isinstance(m, nn.Linear):
+ nn.init.zeros_(m.bias)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ return self.stages(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ return self.head(x)
+
+
+def _create_vovnet(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ VovNet, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ model_cfg=model_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True),
+ **kwargs)
+
+
+@register_model
+def vovnet39a(pretrained=False, **kwargs):
+ return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def vovnet57a(pretrained=False, **kwargs):
+ return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ese_vovnet19b_slim_dw(pretrained=False, **kwargs):
+ return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ese_vovnet19b_dw(pretrained=False, **kwargs):
+ return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ese_vovnet19b_slim(pretrained=False, **kwargs):
+ return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ese_vovnet39b(pretrained=False, **kwargs):
+ return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ese_vovnet57b(pretrained=False, **kwargs):
+ return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def ese_vovnet99b(pretrained=False, **kwargs):
+ return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_vovnet39b(pretrained=False, **kwargs):
+ return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)
+
+
+# Experimental Models
+
+@register_model
+def ese_vovnet39b_evos(pretrained=False, **kwargs):
+ def norm_act_fn(num_features, **nkwargs):
+ return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs)
+ return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
+
+
+@register_model
+def ese_vovnet99b_iabn(pretrained=False, **kwargs):
+ norm_layer = get_norm_act_layer('iabn')
+ return _create_vovnet(
+ 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)
diff --git a/timm/models/xception.py b/timm/models/xception.py
new file mode 100644
index 0000000..86f558c
--- /dev/null
+++ b/timm/models/xception.py
@@ -0,0 +1,232 @@
+"""
+Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
+
+@author: tstandley
+Adapted by cadene
+
+Creates an Xception Model as defined in:
+
+Francois Chollet
+Xception: Deep Learning with Depthwise Separable Convolutions
+https://arxiv.org/pdf/1610.02357.pdf
+
+These weights were ported from the Keras implementation. They achieve the following performance on the validation set:
+
+Loss:0.9173 Prec@1:78.892 Prec@5:94.292
+
+REMEMBER to set your image size to 3x299x299 for both test and validation
+
+normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
+ std=[0.5, 0.5, 0.5])
+
+The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+"""
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .helpers import build_model_with_cfg
+from .layers import create_classifier
+from .registry import register_model
+
+__all__ = ['Xception']
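+
+
+# Illustrative sketch, not part of the upstream timm code: the evaluation preprocessing described in
+# the module docstring -- resize the short side to 333, center-crop to 299x299, then normalize with
+# mean/std of 0.5 -- written out with torchvision, which is assumed to be installed; never called here.
+def _example_eval_transform():
+    from torchvision import transforms  # local import keeps this optional sketch dependency-free at import time
+    return transforms.Compose([
+        transforms.Resize(333),
+        transforms.CenterCrop(299),
+        transforms.ToTensor(),
+        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    ])
+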
+
+default_cfgs = {
+ 'xception': {
+ 'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
+ 'input_size': (3, 299, 299),
+ 'pool_size': (10, 10),
+ 'crop_pct': 0.8975,
+ 'interpolation': 'bicubic',
+ 'mean': (0.5, 0.5, 0.5),
+ 'std': (0.5, 0.5, 0.5),
+ 'num_classes': 1000,
+ 'first_conv': 'conv1',
+ 'classifier': 'fc'
+ # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
+ }
+}
+
+
+class SeparableConv2d(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
+ super(SeparableConv2d, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False)
+ self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=False)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.pointwise(x)
+ return x
+
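+
+# Illustrative sketch, not part of the upstream timm code: the point of the depthwise separable
+# factorization above is the parameter saving over a dense KxK convolution. A quick count for a
+# bias-free 3x3 conv mapping 256 -> 256 channels; never called here.
+def _example_separable_param_count():
+    dense = 256 * 256 * 3 * 3                    # full 3x3 conv: 589,824 weights
+    separable = 256 * 3 * 3 + 256 * 256 * 1 * 1  # depthwise + pointwise: 2,304 + 65,536 = 67,840
+    return dense, separable                      # roughly an 8.7x reduction
+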
+
+class Block(nn.Module):
+ def __init__(self, in_channels, out_channels, reps, strides=1, start_with_relu=True, grow_first=True):
+ super(Block, self).__init__()
+
+ if out_channels != in_channels or strides != 1:
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, stride=strides, bias=False)
+ self.skipbn = nn.BatchNorm2d(out_channels)
+ else:
+ self.skip = None
+
+ rep = []
+ for i in range(reps):
+ if grow_first:
+ inc = in_channels if i == 0 else out_channels
+ outc = out_channels
+ else:
+ inc = in_channels
+ outc = in_channels if i < (reps - 1) else out_channels
+ rep.append(nn.ReLU(inplace=True))
+ rep.append(SeparableConv2d(inc, outc, 3, stride=1, padding=1))
+ rep.append(nn.BatchNorm2d(outc))
+
+ if not start_with_relu:
+ rep = rep[1:]
+ else:
+ rep[0] = nn.ReLU(inplace=False)
+
+ if strides != 1:
+ rep.append(nn.MaxPool2d(3, strides, 1))
+ self.rep = nn.Sequential(*rep)
+
+ def forward(self, inp):
+ x = self.rep(inp)
+
+ if self.skip is not None:
+ skip = self.skip(inp)
+ skip = self.skipbn(skip)
+ else:
+ skip = inp
+
+ x += skip
+ return x
+
+
+class Xception(nn.Module):
+ """
+ Xception optimized for the ImageNet dataset, as specified in
+ https://arxiv.org/pdf/1610.02357.pdf
+ """
+
+ def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
+ """ Constructor
+ Args:
+ num_classes: number of classes
+ """
+ super(Xception, self).__init__()
+ self.drop_rate = drop_rate
+ self.global_pool = global_pool
+ self.num_classes = num_classes
+ self.num_features = 2048
+
+ self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
+ self.bn1 = nn.BatchNorm2d(32)
+ self.act1 = nn.ReLU(inplace=True)
+
+ self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
+ self.bn2 = nn.BatchNorm2d(64)
+ self.act2 = nn.ReLU(inplace=True)
+
+ self.block1 = Block(64, 128, 2, 2, start_with_relu=False)
+ self.block2 = Block(128, 256, 2, 2)
+ self.block3 = Block(256, 728, 2, 2)
+
+ self.block4 = Block(728, 728, 3, 1)
+ self.block5 = Block(728, 728, 3, 1)
+ self.block6 = Block(728, 728, 3, 1)
+ self.block7 = Block(728, 728, 3, 1)
+
+ self.block8 = Block(728, 728, 3, 1)
+ self.block9 = Block(728, 728, 3, 1)
+ self.block10 = Block(728, 728, 3, 1)
+ self.block11 = Block(728, 728, 3, 1)
+
+ self.block12 = Block(728, 1024, 2, 2, grow_first=False)
+
+ self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
+ self.bn3 = nn.BatchNorm2d(1536)
+ self.act3 = nn.ReLU(inplace=True)
+
+ self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
+ self.bn4 = nn.BatchNorm2d(self.num_features)
+ self.act4 = nn.ReLU(inplace=True)
+ self.feature_info = [
+ dict(num_chs=64, reduction=2, module='act2'),
+ dict(num_chs=128, reduction=4, module='block2.rep.0'),
+ dict(num_chs=256, reduction=8, module='block3.rep.0'),
+ dict(num_chs=728, reduction=16, module='block12.rep.0'),
+ dict(num_chs=2048, reduction=32, module='act4'),
+ ]
+
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ # #------- init weights --------
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def get_classifier(self):
+ return self.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ x = self.act2(x)
+
+ x = self.block1(x)
+ x = self.block2(x)
+ x = self.block3(x)
+ x = self.block4(x)
+ x = self.block5(x)
+ x = self.block6(x)
+ x = self.block7(x)
+ x = self.block8(x)
+ x = self.block9(x)
+ x = self.block10(x)
+ x = self.block11(x)
+ x = self.block12(x)
+
+ x = self.conv3(x)
+ x = self.bn3(x)
+ x = self.act3(x)
+
+ x = self.conv4(x)
+ x = self.bn4(x)
+ x = self.act4(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ if self.drop_rate:
+ x = F.dropout(x, self.drop_rate, training=self.training)  # dropout is not in-place; keep the returned tensor
+ x = self.fc(x)
+ return x
+
+
+def _xception(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ Xception, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(feature_cls='hook'),
+ **kwargs)
+
+
+@register_model
+def xception(pretrained=False, **kwargs):
+ return _xception('xception', pretrained=pretrained, **kwargs)
diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py
new file mode 100644
index 0000000..ea7f5c0
--- /dev/null
+++ b/timm/models/xception_aligned.py
@@ -0,0 +1,238 @@
+"""Pytorch impl of Aligned Xception 41, 65, 71
+
+This is a correct, from-scratch implementation of the Aligned Xception (DeepLab) models, compatible with the TF weights at
+https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/model_zoo.md
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from functools import partial
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import build_model_with_cfg
+from .layers import ClassifierHead, ConvBnAct, create_conv2d
+from .layers.helpers import to_3tuple
+from .registry import register_model
+
+__all__ = ['XceptionAligned']
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (10, 10),
+ 'crop_pct': 0.903, 'interpolation': 'bicubic',
+ 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+ 'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = dict(
+ xception41=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
+ xception65=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
+ xception71=_cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
+)
+
+
+class SeparableConv2d(nn.Module):
+ def __init__(
+ self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
+ super(SeparableConv2d, self).__init__()
+ self.kernel_size = kernel_size
+ self.dilation = dilation
+
+ # depthwise convolution
+ self.conv_dw = create_conv2d(
+ inplanes, inplanes, kernel_size, stride=stride,
+ padding=padding, dilation=dilation, depthwise=True)
+ self.bn_dw = norm_layer(inplanes)
+ if act_layer is not None:
+ self.act_dw = act_layer(inplace=True)
+ else:
+ self.act_dw = None
+
+ # pointwise convolution
+ self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
+ self.bn_pw = norm_layer(planes)
+ if act_layer is not None:
+ self.act_pw = act_layer(inplace=True)
+ else:
+ self.act_pw = None
+
+ def forward(self, x):
+ x = self.conv_dw(x)
+ x = self.bn_dw(x)
+ if self.act_dw is not None:
+ x = self.act_dw(x)
+ x = self.conv_pw(x)
+ x = self.bn_pw(x)
+ if self.act_pw is not None:
+ x = self.act_pw(x)
+ return x
+
+
+class XceptionModule(nn.Module):
+ def __init__(
+ self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
+ start_with_relu=True, no_skip=False, act_layer=nn.ReLU, norm_layer=None):
+ super(XceptionModule, self).__init__()
+ out_chs = to_3tuple(out_chs)
+ self.in_channels = in_chs
+ self.out_channels = out_chs[-1]
+ self.no_skip = no_skip
+ if not no_skip and (self.out_channels != self.in_channels or stride != 1):
+ self.shortcut = ConvBnAct(
+ in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None)
+ else:
+ self.shortcut = None
+
+ separable_act_layer = None if start_with_relu else act_layer
+ self.stack = nn.Sequential()
+ for i in range(3):
+ if start_with_relu:
+ self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
+ self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
+ in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
+ act_layer=separable_act_layer, norm_layer=norm_layer))
+ in_chs = out_chs[i]
+
+ def forward(self, x):
+ skip = x
+ x = self.stack(x)
+ if self.shortcut is not None:
+ skip = self.shortcut(skip)
+ if not self.no_skip:
+ x = x + skip
+ return x
+
+
+class XceptionAligned(nn.Module):
+ """Modified Aligned Xception
+ """
+
+ def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
+ act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
+ super(XceptionAligned, self).__init__()
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ assert output_stride in (8, 16, 32)
+
+ layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
+ self.stem = nn.Sequential(*[
+ ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
+ ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args)
+ ])
+
+ curr_dilation = 1
+ curr_stride = 2
+ self.feature_info = []
+ self.blocks = nn.Sequential()
+ for i, b in enumerate(block_cfg):
+ b['dilation'] = curr_dilation
+ if b['stride'] > 1:
+ self.feature_info += [dict(
+ num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')]
+ next_stride = curr_stride * b['stride']
+ if next_stride > output_stride:
+ curr_dilation *= b['stride']
+ b['stride'] = 1
+ else:
+ curr_stride = next_stride
+ self.blocks.add_module(str(i), XceptionModule(**b, **layer_args))
+ self.num_features = self.blocks[-1].out_channels
+
+ self.feature_info += [dict(
+ num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
+
+ self.head = ClassifierHead(
+ in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+ def get_classifier(self):
+ return self.head.fc
+
+ def reset_classifier(self, num_classes, global_pool='avg'):
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ x = self.blocks(x)
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
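+# Editor's sketch (not in the upstream file): how `output_stride` converts stride into
+# dilation. With output_stride=8 the third strided block keeps stride=1 and doubles its
+# dilation instead, so the feature map stops shrinking at 1/8 resolution.
+# `_output_stride_example` is a hypothetical helper that is defined but never called.
+def _output_stride_example():
+    import torch
+    block_cfg = [
+        dict(in_chs=64, out_chs=128, stride=2),   # overall stride 4 (stem is 2)
+        dict(in_chs=128, out_chs=256, stride=2),  # overall stride 8
+        dict(in_chs=256, out_chs=728, stride=2),  # would exceed 8 -> stride 1, dilation 2
+    ]
+    model = XceptionAligned(block_cfg, num_classes=0, output_stride=8)
+    feats = model.forward_features(torch.randn(1, 3, 224, 224))
+    return feats.shape  # expected: torch.Size([1, 728, 28, 28]), i.e. 224 / 8
+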
+
+def _xception(variant, pretrained=False, **kwargs):
+ return build_model_with_cfg(
+ XceptionAligned, variant, pretrained,
+ default_cfg=default_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
+ **kwargs)
+
+
+@register_model
+def xception41(pretrained=False, **kwargs):
+ """ Modified Aligned Xception-41
+ """
+ block_cfg = [
+ # entry flow
+ dict(in_chs=64, out_chs=128, stride=2),
+ dict(in_chs=128, out_chs=256, stride=2),
+ dict(in_chs=256, out_chs=728, stride=2),
+ # middle flow
+ *([dict(in_chs=728, out_chs=728, stride=1)] * 8),
+ # exit flow
+ dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+ dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+ ]
+ model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
+ return _xception('xception41', pretrained=pretrained, **model_args)
+
+
+@register_model
+def xception65(pretrained=False, **kwargs):
+ """ Modified Aligned Xception-65
+ """
+ block_cfg = [
+ # entry flow
+ dict(in_chs=64, out_chs=128, stride=2),
+ dict(in_chs=128, out_chs=256, stride=2),
+ dict(in_chs=256, out_chs=728, stride=2),
+ # middle flow
+ *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
+ # exit flow
+ dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+ dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+ ]
+ model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
+ return _xception('xception65', pretrained=pretrained, **model_args)
+
+
+@register_model
+def xception71(pretrained=False, **kwargs):
+ """ Modified Aligned Xception-71
+ """
+ block_cfg = [
+ # entry flow
+ dict(in_chs=64, out_chs=128, stride=2),
+ dict(in_chs=128, out_chs=256, stride=1),
+ dict(in_chs=256, out_chs=256, stride=2),
+ dict(in_chs=256, out_chs=728, stride=1),
+ dict(in_chs=728, out_chs=728, stride=2),
+ # middle flow
+ *([dict(in_chs=728, out_chs=728, stride=1)] * 16),
+ # exit flow
+ dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
+ dict(in_chs=1024, out_chs=(1536, 1536, 2048), stride=1, no_skip=True, start_with_relu=False),
+ ]
+ model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
+ return _xception('xception71', pretrained=pretrained, **model_args)
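+
+
+# Editor's sketch (not in the upstream file): minimal usage of the registered variants
+# above. `_xception_usage_example` is a hypothetical helper that is defined but never called.
+def _xception_usage_example():
+    import torch
+    model = xception41(pretrained=False)  # pretrained=True would load the URL in default_cfgs
+    logits = model(torch.randn(1, 3, 299, 299))
+    return logits.shape  # expected: torch.Size([1, 1000])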
diff --git a/timm/models/xcit.py b/timm/models/xcit.py
new file mode 100644
index 0000000..ac5e802
--- /dev/null
+++ b/timm/models/xcit.py
@@ -0,0 +1,812 @@
+""" Cross-Covariance Image Transformer (XCiT) in PyTorch
+
+Same as the official implementation, with some minor adaptations.
+ - https://github.com/facebookresearch/xcit/blob/master/xcit.py
+
+Paper:
+ - https://arxiv.org/abs/2106.09681
+"""
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .vision_transformer import Mlp
+from .registry import register_model
+from .layers import DropPath, trunc_normal_, to_2tuple
+from .cait import ClassAttn
+from .fx_features import register_notrace_module
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj.0.0', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # Patch size 16
+ 'xcit_nano_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224.pth'),
+ 'xcit_nano_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_224_dist.pth'),
+ 'xcit_nano_12_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p16_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_tiny_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224.pth'),
+ 'xcit_tiny_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_224_dist.pth'),
+ 'xcit_tiny_12_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p16_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_tiny_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224.pth'),
+ 'xcit_tiny_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_224_dist.pth'),
+ 'xcit_tiny_24_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p16_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_small_12_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224.pth'),
+ 'xcit_small_12_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_224_dist.pth'),
+ 'xcit_small_12_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p16_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_small_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224.pth'),
+ 'xcit_small_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_224_dist.pth'),
+ 'xcit_small_24_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p16_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_medium_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224.pth'),
+ 'xcit_medium_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_224_dist.pth'),
+ 'xcit_medium_24_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p16_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_large_24_p16_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224.pth'),
+ 'xcit_large_24_p16_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_224_dist.pth'),
+ 'xcit_large_24_p16_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p16_384_dist.pth', input_size=(3, 384, 384)),
+
+ # Patch size 8
+ 'xcit_nano_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224.pth'),
+ 'xcit_nano_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_224_dist.pth'),
+ 'xcit_nano_12_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_nano_12_p8_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_tiny_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224.pth'),
+ 'xcit_tiny_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_224_dist.pth'),
+ 'xcit_tiny_12_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_12_p8_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_tiny_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224.pth'),
+ 'xcit_tiny_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_224_dist.pth'),
+ 'xcit_tiny_24_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_tiny_24_p8_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_small_12_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224.pth'),
+ 'xcit_small_12_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_224_dist.pth'),
+ 'xcit_small_12_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_small_12_p8_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_small_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224.pth'),
+ 'xcit_small_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_224_dist.pth'),
+ 'xcit_small_24_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_small_24_p8_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_medium_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224.pth'),
+ 'xcit_medium_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_224_dist.pth'),
+ 'xcit_medium_24_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_medium_24_p8_384_dist.pth', input_size=(3, 384, 384)),
+ 'xcit_large_24_p8_224': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224.pth'),
+ 'xcit_large_24_p8_224_dist': _cfg(url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_224_dist.pth'),
+ 'xcit_large_24_p8_384_dist': _cfg(
+ url='https://dl.fbaipublicfiles.com/xcit/xcit_large_24_p8_384_dist.pth', input_size=(3, 384, 384)),
+}
+
+
+@register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
+class PositionalEncodingFourier(nn.Module):
+ """
+ Positional encoding relying on a Fourier kernel matching the one used in the "Attention Is All You Need" paper.
+ Based on the official XCiT code
+ - https://github.com/facebookresearch/xcit/blob/master/xcit.py
+ """
+
+ def __init__(self, hidden_dim=32, dim=768, temperature=10000):
+ super().__init__()
+ self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
+ self.scale = 2 * math.pi
+ self.temperature = temperature
+ self.hidden_dim = hidden_dim
+ self.dim = dim
+ self.eps = 1e-6
+
+ def forward(self, B: int, H: int, W: int):
+ device = self.token_projection.weight.device
+ y_embed = torch.arange(1, H+1, dtype=torch.float32, device=device).unsqueeze(1).repeat(1, 1, W)
+ x_embed = torch.arange(1, W+1, dtype=torch.float32, device=device).repeat(1, H, 1)
+ y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
+ dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=device)
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack([pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()], dim=4).flatten(3)
+ pos_y = torch.stack([pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()], dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ pos = self.token_projection(pos)
+ return pos.repeat(B, 1, 1, 1) # (B, C, H, W)
+
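+# Editor's sketch (not in the upstream file): the module returns a (B, dim, H, W) grid of
+# sine/cosine features projected by a 1x1 conv; callers flatten it to (B, N, dim) before
+# adding it to the patch tokens. `_pos_embed_example` is hypothetical and never called.
+def _pos_embed_example():
+    pos_embed = PositionalEncodingFourier(hidden_dim=32, dim=768)
+    pos = pos_embed(B=2, H=14, W=14)
+    return pos.shape  # expected: torch.Size([2, 768, 14, 14])
+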
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution + batch norm"""
+ return torch.nn.Sequential(
+ nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
+ nn.BatchNorm2d(out_planes)
+ )
+
+
+class ConvPatchEmbed(nn.Module):
+ """Image to Patch Embedding using multiple convolutional layers"""
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, act_layer=nn.GELU):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+
+ if patch_size == 16:
+ self.proj = torch.nn.Sequential(
+ conv3x3(in_chans, embed_dim // 8, 2),
+ act_layer(),
+ conv3x3(embed_dim // 8, embed_dim // 4, 2),
+ act_layer(),
+ conv3x3(embed_dim // 4, embed_dim // 2, 2),
+ act_layer(),
+ conv3x3(embed_dim // 2, embed_dim, 2),
+ )
+ elif patch_size == 8:
+ self.proj = torch.nn.Sequential(
+ conv3x3(in_chans, embed_dim // 4, 2),
+ act_layer(),
+ conv3x3(embed_dim // 4, embed_dim // 2, 2),
+ act_layer(),
+ conv3x3(embed_dim // 2, embed_dim, 2),
+ )
+ else:
+ raise ValueError('For convolutional projection, patch size has to be in [8, 16]')
+
+ def forward(self, x):
+ x = self.proj(x)
+ Hp, Wp = x.shape[2], x.shape[3]
+ x = x.flatten(2).transpose(1, 2) # (B, N, C)
+ return x, (Hp, Wp)
+
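+# Editor's sketch (not in the upstream file): with patch_size=16 the four stride-2 conv3x3
+# stages reduce H and W by 16, so a 224x224 image yields 14 * 14 = 196 tokens.
+# `_patch_embed_example` is a hypothetical helper that is defined but never called.
+def _patch_embed_example():
+    import torch
+    embed = ConvPatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
+    tokens, (Hp, Wp) = embed(torch.randn(2, 3, 224, 224))
+    return tokens.shape, (Hp, Wp)  # expected: torch.Size([2, 196, 768]) and (14, 14)
+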
+
+class LPI(nn.Module):
+ """
+ Local Patch Interaction module that allows explicit communication between tokens in 3x3 windows to augment the
+ implicit communication performed by the block diagonal scatter attention. Implemented using 2 layers of separable
+ 3x3 convolutions with GeLU and BatchNorm2d
+ """
+
+ def __init__(self, in_features, out_features=None, act_layer=nn.GELU, kernel_size=3):
+ super().__init__()
+ out_features = out_features or in_features
+
+ padding = kernel_size // 2
+
+ self.conv1 = torch.nn.Conv2d(
+ in_features, in_features, kernel_size=kernel_size, padding=padding, groups=in_features)
+ self.act = act_layer()
+ self.bn = nn.BatchNorm2d(in_features)
+ self.conv2 = torch.nn.Conv2d(
+ in_features, out_features, kernel_size=kernel_size, padding=padding, groups=out_features)
+
+ def forward(self, x, H: int, W: int):
+ B, N, C = x.shape
+ x = x.permute(0, 2, 1).reshape(B, C, H, W)
+ x = self.conv1(x)
+ x = self.act(x)
+ x = self.bn(x)
+ x = self.conv2(x)
+ x = x.reshape(B, C, N).permute(0, 2, 1)
+ return x
+
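+# Editor's sketch (not in the upstream file): LPI folds the (B, N, C) tokens back to a
+# (B, C, H, W) grid so the two depthwise 3x3 convolutions can mix neighbouring patches,
+# then flattens again, leaving the token shape unchanged. `_lpi_example` is a hypothetical
+# helper that is defined but never called.
+def _lpi_example():
+    import torch
+    lpi = LPI(in_features=192)
+    tokens = torch.randn(2, 14 * 14, 192)
+    return lpi(tokens, H=14, W=14).shape  # expected: torch.Size([2, 196, 192])
+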
+
+class ClassAttentionBlock(nn.Module):
+ """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239"""
+
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1., tokens_norm=False):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+
+ self.attn = ClassAttn(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+
+ if eta is not None: # LayerScale Initialization (no layerscale when None)
+ self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+ self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+ else:
+ self.gamma1, self.gamma2 = 1.0, 1.0
+
+ # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
+ self.tokens_norm = tokens_norm
+
+ def forward(self, x):
+ x_norm1 = self.norm1(x)
+ x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
+ x = x + self.drop_path(self.gamma1 * x_attn)
+ if self.tokens_norm:
+ x = self.norm2(x)
+ else:
+ x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
+ x_res = x
+ cls_token = x[:, 0:1]
+ cls_token = self.gamma2 * self.mlp(cls_token)
+ x = torch.cat([cls_token, x[:, 1:]], dim=1)
+ x = x_res + self.drop_path(x)
+ return x
+
+
+class XCA(nn.Module):
+ """ Cross-Covariance Attention (XCA)
+ Operation where the channels are updated using a weighted sum. The weights are obtained from the (softmax
+ normalized) Cross-covariance matrix (Q^T \\cdot K \\in d_h \\times d_h)
+ """
+
+ def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.num_heads = num_heads
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+
+ # Paper section 3.2 l2-Normalization and temperature scaling
+ q = torch.nn.functional.normalize(q, dim=-1)
+ k = torch.nn.functional.normalize(k, dim=-1)
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ # (B, H, C', N), permute -> (B, N, H, C')
+ x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'temperature'}
+
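+# Editor's sketch (not in the upstream file): in XCA the softmax attention map has shape
+# (B, num_heads, C', C') with C' = C // num_heads, so its cost scales with the channel
+# dimension rather than with the token count N, while the output keeps the usual (B, N, C)
+# token shape. `_xca_example` is a hypothetical helper that is defined but never called.
+def _xca_example():
+    import torch
+    xca = XCA(dim=192, num_heads=4)  # per-head channel dim C' = 48, attention is 48x48
+    tokens = torch.randn(2, 196, 192)
+    return xca(tokens).shape  # expected: torch.Size([2, 196, 192])
+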
+
+class XCABlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=1.):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.norm3 = norm_layer(dim)
+ self.local_mp = LPI(in_features=dim, act_layer=act_layer)
+
+ self.norm2 = norm_layer(dim)
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
+
+ self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+ self.gamma3 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+ self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
+
+ def forward(self, x, H: int, W: int):
+ x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
+ # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights
+ # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
+ x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W))
+ x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
+ return x
+
+
+class XCiT(nn.Module):
+ """
+ Based on timm and DeiT code bases
+ https://github.com/rwightman/pytorch-image-models/tree/master/timm
+ https://github.com/facebookresearch/deit/
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+ act_layer=None, norm_layer=None, cls_attn_layers=2, use_pos_embed=True, eta=1., tokens_norm=False):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ drop_rate (float): dropout rate after positional embedding, and in XCA/CA projection + MLP
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate (constant across all layers)
+ norm_layer: (nn.Module): normalization layer
+ cls_attn_layers: (int) Depth of Class attention layers
+ use_pos_embed: (bool) whether to use positional encoding
+ eta: (float) layerscale initialization value
+ tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA
+
+ Notes:
+ - Although `norm_layer` is user specifiable, there are hard-coded `BatchNorm2d`s in the local patch
+ interaction (class LPI) and the patch embedding (class ConvPatchEmbed)
+ """
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ assert (img_size[0] % patch_size == 0) and (img_size[1] % patch_size == 0), \
+ '`patch_size` should divide image dimensions evenly'
+
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+
+ self.patch_embed = ConvPatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, act_layer=act_layer)
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.use_pos_embed = use_pos_embed
+ if use_pos_embed:
+ self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ self.blocks = nn.ModuleList([
+ XCABlock(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+ attn_drop=attn_drop_rate, drop_path=drop_path_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta)
+ for _ in range(depth)])
+
+ self.cls_attn_blocks = nn.ModuleList([
+ ClassAttentionBlock(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+ attn_drop=attn_drop_rate, act_layer=act_layer, norm_layer=norm_layer, eta=eta, tokens_norm=tokens_norm)
+ for _ in range(cls_attn_layers)])
+
+ # Classifier head
+ self.norm = norm_layer(embed_dim)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ # Init weights
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ # x is (B, N, C). (Hp, Wp) is (height in units of patches, width in units of patches)
+ x, (Hp, Wp) = self.patch_embed(x)
+
+ if self.use_pos_embed:
+ # `pos_embed` (B, C, Hp, Wp), reshape -> (B, C, N), permute -> (B, N, C)
+ pos_encoding = self.pos_embed(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
+ x = x + pos_encoding
+
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x, Hp, Wp)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ for blk in self.cls_attn_blocks:
+ x = blk(x)
+
+ x = self.norm(x)[:, 0]
+ return x
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.head(x)
+ return x
+
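+# Editor's sketch (not in the upstream file): building a small XCiT directly, roughly the
+# nano-12/p16 configuration registered below, and running a forward pass.
+# `_xcit_example` is a hypothetical helper that is defined but never called.
+def _xcit_example():
+    import torch
+    model = XCiT(img_size=224, patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0)
+    logits = model(torch.randn(1, 3, 224, 224))
+    return logits.shape  # expected: torch.Size([1, 1000])
+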
+
+def checkpoint_filter_fn(state_dict, model):
+ if 'model' in state_dict:
+ state_dict = state_dict['model']
+ # For consistency with timm's transformer models while being compatible with official weights source we rename
+ # pos_embeder to pos_embed. Also account for use_pos_embed == False
+ use_pos_embed = getattr(model, 'pos_embed', None) is not None
+ pos_embed_keys = [k for k in state_dict if k.startswith('pos_embed')]
+ for k in pos_embed_keys:
+ if use_pos_embed:
+ state_dict[k.replace('pos_embeder.', 'pos_embed.')] = state_dict.pop(k)
+ else:
+ del state_dict[k]
+ # timm's implementation of class attention in CaiT is slightly more efficient as it does not compute query vectors
+ # for all tokens, just the class token. To use official weights source we must split qkv into q, k, v
+ if 'cls_attn_blocks.0.attn.qkv.weight' in state_dict and 'cls_attn_blocks.0.attn.q.weight' in model.state_dict():
+ num_ca_blocks = len(model.cls_attn_blocks)
+ for i in range(num_ca_blocks):
+ qkv_weight = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.weight')
+ qkv_weight = qkv_weight.reshape(3, -1, qkv_weight.shape[-1])
+ for j, subscript in enumerate('qkv'):
+ state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.weight'] = qkv_weight[j]
+ qkv_bias = state_dict.pop(f'cls_attn_blocks.{i}.attn.qkv.bias', None)
+ if qkv_bias is not None:
+ qkv_bias = qkv_bias.reshape(3, -1)
+ for j, subscript in enumerate('qkv'):
+ state_dict[f'cls_attn_blocks.{i}.attn.{subscript}.bias'] = qkv_bias[j]
+ return state_dict
+
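+# Editor's sketch (not in the upstream file): the qkv split performed above. A fused
+# (3 * dim, dim) qkv weight is viewed as (3, dim, dim) and its slices become the separate
+# q/k/v weights expected by timm's ClassAttn. `_split_qkv_example` is a hypothetical
+# helper that is defined but never called.
+def _split_qkv_example():
+    import torch
+    dim = 192
+    qkv_weight = torch.randn(3 * dim, dim)
+    q_w, k_w, v_w = qkv_weight.reshape(3, dim, dim).unbind(0)
+    return q_w.shape, k_w.shape, v_w.shape  # each expected: torch.Size([192, 192])
+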
+
+def _create_xcit(variant, pretrained=False, default_cfg=None, **kwargs):
+ default_cfg = default_cfg or default_cfgs[variant]
+ model = build_model_with_cfg(
+ XCiT, variant, pretrained, default_cfg=default_cfg, pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
+ return model
+
+
+@register_model
+def xcit_nano_12_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
+ model = _create_xcit('xcit_nano_12_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_nano_12_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
+ model = _create_xcit('xcit_nano_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_nano_12_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, img_size=384, **kwargs)
+ model = _create_xcit('xcit_nano_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_12_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_12_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_12_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_12_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_12_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_12_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_12_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_12_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_12_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_12_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_24_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_24_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_24_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_24_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_24_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_24_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_24_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_24_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_medium_24_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_medium_24_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_medium_24_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_medium_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_medium_24_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_medium_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_large_24_p16_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_large_24_p16_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_large_24_p16_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_large_24_p16_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_large_24_p16_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_large_24_p16_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+# Patch size 8x8 models
+@register_model
+def xcit_nano_12_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
+ model = _create_xcit('xcit_nano_12_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_nano_12_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
+ model = _create_xcit('xcit_nano_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_nano_12_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=128, depth=12, num_heads=4, eta=1.0, tokens_norm=False, **kwargs)
+ model = _create_xcit('xcit_nano_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_12_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_12_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_12_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_12_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=192, depth=12, num_heads=4, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_12_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_12_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_12_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_12_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_12_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=384, depth=12, num_heads=8, eta=1.0, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_12_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_24_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_24_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_24_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_tiny_24_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=192, depth=24, num_heads=4, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_tiny_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_24_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_24_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_24_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_small_24_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=384, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_small_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_medium_24_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_medium_24_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_medium_24_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_medium_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_medium_24_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=512, depth=24, num_heads=8, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_medium_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_large_24_p8_224(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_large_24_p8_224', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_large_24_p8_224_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_large_24_p8_224_dist', pretrained=pretrained, **model_kwargs)
+ return model
+
+
+@register_model
+def xcit_large_24_p8_384_dist(pretrained=False, **kwargs):
+ model_kwargs = dict(
+ patch_size=8, embed_dim=768, depth=24, num_heads=16, eta=1e-5, tokens_norm=True, **kwargs)
+ model = _create_xcit('xcit_large_24_p8_384_dist', pretrained=pretrained, **model_kwargs)
+ return model
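+
+
+# Editor's sketch (not in the upstream file): minimal usage of the registered factories
+# above; the distilled ("_dist") and 384-resolution variants are created the same way.
+# `_xcit_factory_usage` is a hypothetical helper that is defined but never called.
+def _xcit_factory_usage():
+    import torch
+    model = xcit_nano_12_p16_224(pretrained=False)  # pretrained=True downloads the cfg URL
+    logits = model(torch.randn(1, 3, 224, 224))
+    return logits.shape  # expected: torch.Size([1, 1000])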
diff --git a/timm/optim/__init__.py b/timm/optim/__init__.py
new file mode 100644
index 0000000..7ee4958
--- /dev/null
+++ b/timm/optim/__init__.py
@@ -0,0 +1,15 @@
+from .adabelief import AdaBelief
+from .adafactor import Adafactor
+from .adahessian import Adahessian
+from .adamp import AdamP
+from .adamw import AdamW
+from .lamb import Lamb
+from .lars import Lars
+from .lookahead import Lookahead
+from .madgrad import MADGRAD
+from .nadam import Nadam
+from .nvnovograd import NvNovoGrad
+from .radam import RAdam
+from .rmsprop_tf import RMSpropTF
+from .sgdp import SGDP
+from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
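+
+
+# Editor's sketch (not in the upstream file): typical use of the factory exported above.
+# It assumes create_optimizer_v2 accepts a model (or parameters) plus an optimizer name
+# and lr/weight_decay keywords, as in recent timm releases; `_optim_factory_example` is a
+# hypothetical helper that is defined but never called.
+def _optim_factory_example():
+    import torch.nn as nn
+    model = nn.Linear(10, 2)
+    return create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)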
diff --git a/timm/optim/adabelief.py b/timm/optim/adabelief.py
new file mode 100644
index 0000000..951d715
--- /dev/null
+++ b/timm/optim/adabelief.py
@@ -0,0 +1,201 @@
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AdaBelief(Optimizer):
+ r"""Implements AdaBelief algorithm. Modified from Adam in PyTorch
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-16)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ (default: False)
+ decoupled_decay (boolean, optional): (default: True) If set as True, then
+ the optimizer uses decoupled weight decay as in AdamW
+ fixed_decay (boolean, optional): (default: False) This is used when decoupled_decay
+ is set as True.
+ When fixed_decay == True, the weight decay is performed as
+ $W_{new} = W_{old} - W_{old} \times decay$.
+ When fixed_decay == False, the weight decay is performed as
+ $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the
+ weight decay ratio decreases with learning rate (lr).
+ rectify (boolean, optional): (default: True) If set as True, then perform the rectified
+ update similar to RAdam
+ degenerated_to_sgd (boolean, optional): (default: True) If set as True, then perform SGD update
+ when variance of gradient is high
+ reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020
+
+ For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer
+ For example train/args for EfficientNet see these gists
+ - link to train_script: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037
+ - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3
+ """
+
+ def __init__(
+ self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False,
+ decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True):
+
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+ if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+ for param in params:
+ if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+ param['buffer'] = [[None, None, None] for _ in range(10)]
+
+ defaults = dict(
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad,
+ degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify,
+ fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)])
+ super(AdaBelief, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(AdaBelief, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def reset(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ state = self.state[p]
+ amsgrad = group['amsgrad']
+
+ # State initialization
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+
+ # Exponential moving average of squared gradient values
+ state['exp_avg_var'] = torch.zeros_like(p)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_var'] = torch.zeros_like(p)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.dtype in {torch.float16, torch.bfloat16}:
+ grad = grad.float()
+ if grad.is_sparse:
+ raise RuntimeError(
+ 'AdaBelief does not support sparse gradients, please consider SparseAdam instead')
+
+ p_fp32 = p
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p_fp32 = p_fp32.float()
+
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ state = self.state[p]
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p_fp32)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_var'] = torch.zeros_like(p_fp32)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_var'] = torch.zeros_like(p_fp32)
+
+ # perform weight decay, check if decoupled weight decay
+ if group['decoupled_decay']:
+ if not group['fixed_decay']:
+ p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
+ else:
+ p_fp32.mul_(1.0 - group['weight_decay'])
+ else:
+ if group['weight_decay'] != 0:
+ grad.add_(p_fp32, alpha=group['weight_decay'])
+
+ # get current state variable
+ exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']
+
+ state['step'] += 1
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+
+ # Update first and second moment running average
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ grad_residual = grad - exp_avg
+ exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
+
+ if amsgrad:
+ max_exp_avg_var = state['max_exp_avg_var']
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var)
+
+ # Use the max. for normalizing running avg. of gradient
+ denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+ else:
+ denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+
+ # update
+ if not group['rectify']:
+ # Default update
+ step_size = group['lr'] / bias_correction1
+ p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+ else:
+ # Rectified update, forked from RAdam
+ buffered = group['buffer'][int(state['step'] % 10)]
+ if state['step'] == buffered[0]:
+ num_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ num_sma_max = 2 / (1 - beta2) - 1
+ num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = num_sma
+
+ # more conservative since it's an approximated value
+ if num_sma >= 5:
+ step_size = math.sqrt(
+ (1 - beta2_t) *
+ (num_sma - 4) / (num_sma_max - 4) *
+ (num_sma - 2) / num_sma *
+ num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
+ elif group['degenerated_to_sgd']:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ else:
+ step_size = -1
+ buffered[2] = step_size
+
+ if num_sma >= 5:
+ denom = exp_avg_var.sqrt().add_(group['eps'])
+ p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+ elif step_size > 0:
+ p_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p.copy_(p_fp32)
+
+ return loss
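+
+
+# Editor's sketch (not in the upstream file): minimal usage. With the default
+# decoupled_decay=True the weight decay is applied AdamW-style, i.e. directly to the
+# weights rather than added to the gradient. `_adabelief_example` is a hypothetical
+# helper that is defined but never called.
+def _adabelief_example():
+    import torch.nn as nn
+    model = nn.Linear(10, 2)
+    optimizer = AdaBelief(model.parameters(), lr=1e-3, weight_decay=1e-2, rectify=True)
+    # per step: loss.backward(); optimizer.step(); optimizer.zero_grad()
+    return optimizer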
diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py
new file mode 100644
index 0000000..0605743
--- /dev/null
+++ b/timm/optim/adafactor.py
@@ -0,0 +1,167 @@
+""" Adafactor Optimizer
+
+Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
+
+Original header/copyright below.
+
+"""
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import torch
+import math
+
+
+class Adafactor(torch.optim.Optimizer):
+ """Implements Adafactor algorithm.
+ This implementation is based on: `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
+ (see https://arxiv.org/abs/1804.04235)
+
+ Note that this optimizer internally adjusts the learning rate depending on the
+ *scale_parameter*, *relative_step* and *warmup_init* options.
+
+ To use a manual (external) learning rate schedule, pass an explicit `lr` (which disables the
+ internal relative-step schedule in this implementation) and set `scale_parameter=False`.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+ lr (float, optional): external learning rate (default: None)
+ eps (tuple[float, float]): regularization constants for square gradient
+ and parameter scale respectively (default: (1e-30, 1e-3))
+ clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0)
+ decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8)
+ betas (tuple[float, float], optional): only betas[0] (beta1) is used, as the coefficient for computing the running average of the gradient (default: None)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
+ warmup_init (bool): time-dependent learning rate computation depends on
+ whether warm-up initialization is being used (default: False)
+ """
+
+ def __init__(self, params, lr=None, eps=1e-30, eps_scale=1e-3, clip_threshold=1.0,
+ decay_rate=-0.8, betas=None, weight_decay=0.0, scale_parameter=True, warmup_init=False):
+ relative_step = not lr
+ if warmup_init and not relative_step:
+ raise ValueError('warmup_init requires relative_step=True')
+
+ beta1 = None if betas is None else betas[0] # make it compat with standard betas arg
+ defaults = dict(lr=lr, eps=eps, eps_scale=eps_scale, clip_threshold=clip_threshold, decay_rate=decay_rate,
+ beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
+ relative_step=relative_step, warmup_init=warmup_init)
+ super(Adafactor, self).__init__(params, defaults)
+
+ @staticmethod
+ def _get_lr(param_group, param_state):
+ if param_group['relative_step']:
+ min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
+ lr_t = min(min_step, 1.0 / math.sqrt(param_state['step']))
+ param_scale = 1.0
+ if param_group['scale_parameter']:
+ param_scale = max(param_group['eps_scale'], param_state['RMS'])
+ param_group['lr'] = lr_t * param_scale
+ return param_group['lr']
+
+ @staticmethod
+ def _get_options(param_group, param_shape):
+ factored = len(param_shape) >= 2
+ use_first_moment = param_group['beta1'] is not None
+ return factored, use_first_moment
+
+ @staticmethod
+ def _rms(tensor):
+ return tensor.norm(2) / (tensor.numel() ** 0.5)
+
+ def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
+ r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
+ c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
+ return torch.mul(r_factor, c_factor)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.dtype in {torch.float16, torch.bfloat16}:
+ grad = grad.float()
+ if grad.is_sparse:
+ raise RuntimeError('Adafactor does not support sparse gradients.')
+
+ state = self.state[p]
+
+ factored, use_first_moment = self._get_options(group, grad.shape)
+ # State Initialization
+ if len(state) == 0:
+ state['step'] = 0
+
+ if use_first_moment:
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(grad)
+ if factored:
+ state['exp_avg_sq_row'] = torch.zeros(grad.shape[:-1]).to(grad)
+ state['exp_avg_sq_col'] = torch.zeros(grad.shape[:-2] + grad.shape[-1:]).to(grad)
+ else:
+ state['exp_avg_sq'] = torch.zeros_like(grad)
+
+ state['RMS'] = 0
+ else:
+ if use_first_moment:
+ state['exp_avg'] = state['exp_avg'].to(grad)
+ if factored:
+ state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
+ state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
+ else:
+ state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
+
+ p_fp32 = p
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p_fp32 = p_fp32.float()
+
+ state['step'] += 1
+ state['RMS'] = self._rms(p_fp32)
+ lr_t = self._get_lr(group, state)
+
+ beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
+ update = grad ** 2 + group['eps']
+ if factored:
+ exp_avg_sq_row = state['exp_avg_sq_row']
+ exp_avg_sq_col = state['exp_avg_sq_col']
+
+ exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
+ exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
+
+ # Approximation of exponential moving average of square of gradient
+ update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
+ update.mul_(grad)
+ else:
+ exp_avg_sq = state['exp_avg_sq']
+
+ exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
+ update = exp_avg_sq.rsqrt().mul_(grad)
+
+ update.div_((self._rms(update) / group['clip_threshold']).clamp_(min=1.0))
+ update.mul_(lr_t)
+
+ if use_first_moment:
+ exp_avg = state['exp_avg']
+ exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
+ update = exp_avg
+
+ if group['weight_decay'] != 0:
+ p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t)
+
+ p_fp32.add_(-update)
+ if p.dtype in {torch.float16, torch.bfloat16}:
+ p.copy_(p_fp32)
+
+ return loss
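+
+
+# Editor's sketch (not in the upstream file): minimal usage. Passing lr=None keeps the
+# internal relative-step schedule; to drive the optimizer from an external scheduler,
+# pass an explicit lr and set scale_parameter=False instead (see the class docstring).
+# `_adafactor_example` is a hypothetical helper that is defined but never called.
+def _adafactor_example():
+    import torch.nn as nn
+    model = nn.Linear(10, 2)
+    return Adafactor(model.parameters(), lr=None, scale_parameter=True, warmup_init=True)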
diff --git a/timm/optim/adahessian.py b/timm/optim/adahessian.py
new file mode 100644
index 0000000..985c67c
--- /dev/null
+++ b/timm/optim/adahessian.py
@@ -0,0 +1,156 @@
+""" AdaHessian Optimizer
+
+Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
+Originally licensed MIT, Copyright 2020, David Samuel
+"""
+import torch
+
+
+class Adahessian(torch.optim.Optimizer):
+ """
+ Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups
+ lr (float, optional): learning rate (default: 0.1)
+ betas ((float, float), optional): coefficients used for computing running averages of gradient and the
+ squared hessian trace (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
+ hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
+ update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
+ (to save time) (default: 1)
+ n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
+ """
+
+ def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
+ hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
+ if not 0.0 <= lr:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if not 0.0 <= eps:
+ raise ValueError(f"Invalid epsilon value: {eps}")
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
+ if not 0.0 <= hessian_power <= 1.0:
+ raise ValueError(f"Invalid Hessian power value: {hessian_power}")
+
+ self.n_samples = n_samples
+ self.update_each = update_each
+ self.avg_conv_kernel = avg_conv_kernel
+
+ # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
+ self.seed = 2147483647
+ self.generator = torch.Generator().manual_seed(self.seed)
+
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
+ super(Adahessian, self).__init__(params, defaults)
+
+ for p in self.get_params():
+ p.hess = 0.0
+ self.state[p]["hessian step"] = 0
+
+ @property
+ def is_second_order(self):
+ return True
+
+ def get_params(self):
+ """
+ Gets all parameters in all param_groups with gradients
+ """
+
+ return (p for group in self.param_groups for p in group['params'] if p.requires_grad)
+
+ def zero_hessian(self):
+ """
+ Zeros out the accumulated hessian traces.
+ """
+
+ for p in self.get_params():
+ if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
+ p.hess.zero_()
+
+ @torch.no_grad()
+ def set_hessian(self):
+ """
+ Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
+ """
+
+ params = []
+ for p in filter(lambda p: p.grad is not None, self.get_params()):
+ if self.state[p]["hessian step"] % self.update_each == 0: # compute the trace only each `update_each` step
+ params.append(p)
+ self.state[p]["hessian step"] += 1
+
+ if len(params) == 0:
+ return
+
+ if self.generator.device != params[0].device: # hackish way of casting the generator to the right device
+ self.generator = torch.Generator(params[0].device).manual_seed(self.seed)
+
+ grads = [p.grad for p in params]
+
+ for i in range(self.n_samples):
+ # Rademacher distribution {-1.0, 1.0}
+ zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
+ h_zs = torch.autograd.grad(
+ grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
+ for h_z, z, p in zip(h_zs, zs, params):
+ p.hess += h_z * z / self.n_samples # approximate the expected values of z*(H@z)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """
+ Performs a single optimization step.
+ Arguments:
+ closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
+ """
+
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ self.zero_hessian()
+ self.set_hessian()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None or p.hess is None:
+ continue
+
+ if self.avg_conv_kernel and p.dim() == 4:
+ p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()
+
+ # Perform correct stepweight decay as in AdamW
+ p.mul_(1 - group['lr'] * group['weight_decay'])
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 1:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of Hessian diagonal square values
+ state['exp_hessian_diag_sq'] = torch.zeros_like(p)
+
+ exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
+ beta1, beta2 = group['betas']
+ state['step'] += 1
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
+ exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)
+
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+
+ k = group['hessian_power']
+ denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])
+
+ # make update
+ step_size = group['lr'] / bias_correction1
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ return loss
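
Adahessian is second-order: `set_hessian()` takes Hessian-vector products against the existing gradient graph, so the backward pass must keep that graph alive. A minimal usage sketch (not part of the diff; the import path simply mirrors the file added above, and all other hyper-parameters are left at the constructor defaults):

```
import torch
import torch.nn as nn
from timm.optim.adahessian import Adahessian

model = nn.Linear(10, 2)
optimizer = Adahessian(model.parameters(), lr=0.1)

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = nn.functional.mse_loss(model(x), y)
# keep the graph so set_hessian() can call torch.autograd.grad on the gradients
loss.backward(create_graph=True)
optimizer.step()
optimizer.zero_grad()
```
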
diff --git a/timm/optim/adamp.py b/timm/optim/adamp.py
new file mode 100644
index 0000000..ee18763
--- /dev/null
+++ b/timm/optim/adamp.py
@@ -0,0 +1,105 @@
+"""
+AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
+
+Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
+Code: https://github.com/clovaai/AdamP
+
+Copyright (c) 2020-present NAVER Corp.
+MIT license
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.optim.optimizer import Optimizer
+import math
+
+
+def _channel_view(x) -> torch.Tensor:
+ return x.reshape(x.size(0), -1)
+
+
+def _layer_view(x) -> torch.Tensor:
+ return x.reshape(1, -1)
+
+
+def projection(p, grad, perturb, delta: float, wd_ratio: float, eps: float):
+ wd = 1.
+ expand_size = (-1,) + (1,) * (len(p.shape) - 1)
+ for view_func in [_channel_view, _layer_view]:
+ param_view = view_func(p)
+ grad_view = view_func(grad)
+ cosine_sim = F.cosine_similarity(grad_view, param_view, dim=1, eps=eps).abs_()
+
+ # FIXME this is a problem for PyTorch XLA
+ if cosine_sim.max() < delta / math.sqrt(param_view.size(1)):
+ p_n = p / param_view.norm(p=2, dim=1).add_(eps).reshape(expand_size)
+ perturb -= p_n * view_func(p_n * perturb).sum(dim=1).reshape(expand_size)
+ wd = wd_ratio
+ return perturb, wd
+
+ return perturb, wd
+
+
+class AdamP(Optimizer):
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
+ defaults = dict(
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+ delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
+ super(AdamP, self).__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ grad = p.grad
+ beta1, beta2 = group['betas']
+ nesterov = group['nesterov']
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ # Adam
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+ state['step'] += 1
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+ step_size = group['lr'] / bias_correction1
+
+ if nesterov:
+ perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
+ else:
+ perturb = exp_avg / denom
+
+ # Projection
+ wd_ratio = 1.
+ if len(p.shape) > 1:
+ perturb, wd_ratio = projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
+
+ # Weight decay
+ if group['weight_decay'] > 0:
+ p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio)
+
+ # Step
+ p.add_(perturb, alpha=-step_size)
+
+ return loss
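
The projection above removes the radial, norm-increasing component of the update for weights that look scale-invariant (gradient nearly orthogonal to the weight), and scales weight decay by `wd_ratio` in that case. A minimal usage sketch; `delta` and `wd_ratio` are the defaults, `weight_decay` is set explicitly so the projection path matters:

```
import torch
import torch.nn as nn
from timm.optim.adamp import AdamP

model = nn.Linear(10, 10)
optimizer = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2, delta=0.1, wd_ratio=0.1)

loss = model(torch.randn(8, 10)).pow(2).mean()
loss.backward()
optimizer.step()      # Adam update, projected for scale-invariant (2d+) weights
optimizer.zero_grad()
```
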
diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py
new file mode 100644
index 0000000..66478bc
--- /dev/null
+++ b/timm/optim/adamw.py
@@ -0,0 +1,122 @@
+""" AdamW Optimizer
+Impl copied from PyTorch master
+
+NOTE: The builtin optim.AdamW is used by the factory; this impl only serves as a Python-based reference and will be
+removed someday
+"""
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class AdamW(Optimizer):
+ r"""Implements AdamW algorithm.
+
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ (default: False)
+
+ .. _Adam\: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=1e-2, amsgrad=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad)
+ super(AdamW, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(AdamW, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ # Perform stepweight decay
+ p.data.mul_(1 - group['lr'] * group['weight_decay'])
+
+ # Perform optimization step
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+ amsgrad = group['amsgrad']
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ if amsgrad:
+ max_exp_avg_sq = state['max_exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+ else:
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+
+ step_size = group['lr'] / bias_correction1
+
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ return loss
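
The defining difference from plain Adam is at the top of the loop: weight decay multiplies the parameter directly (`p.data.mul_(1 - lr * wd)`) instead of being folded into the gradient. A minimal sketch of using this reference class directly (the factory below prefers `torch.optim.AdamW`):

```
import torch
import torch.nn as nn
from timm.optim.adamw import AdamW

model = nn.Linear(10, 1)
optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

x, y = torch.randn(16, 10), torch.randn(16, 1)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()      # decay shrinks p before the Adam step; the gradient is untouched
optimizer.zero_grad()
```
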
diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py
new file mode 100644
index 0000000..12c7c49
--- /dev/null
+++ b/timm/optim/lamb.py
@@ -0,0 +1,192 @@
+""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb
+
+This optimizer code was adapted from the following (starting with latest)
+* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py
+* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+* https://github.com/cybertronai/pytorch-lamb
+
+Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is
+similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX.
+
+In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU.
+
+Original copyrights for above sources are below.
+
+Modifications Copyright 2021 Ross Wightman
+"""
+# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.
+
+# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# MIT License
+#
+# Copyright (c) 2019 cybertronai
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import math
+
+import torch
+from torch.optim import Optimizer
+
+
+class Lamb(Optimizer):
+ """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
+ reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+
+ LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
+ lr (float, optional): learning rate. (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its norm. (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability. (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ grad_averaging (bool, optional): whether apply (1-beta2) to grad when
+ calculating running averages of gradient. (default: True)
+ max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0)
+ trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+        always_adapt (boolean, optional): apply the layer-wise trust ratio even when
+            the group's weight_decay is 0.0 (default: False)
+
+ .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
+ https://arxiv.org/abs/1904.00962
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-6,
+ weight_decay=0.01, grad_averaging=True, max_grad_norm=1.0, trust_clip=False, always_adapt=False):
+ defaults = dict(
+ lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
+ grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
+ trust_clip=trust_clip, always_adapt=always_adapt)
+ super().__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ device = self.param_groups[0]['params'][0].device
+ one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
+ global_grad_norm = torch.zeros(1, device=device)
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+ global_grad_norm.add_(grad.pow(2).sum())
+
+ global_grad_norm = torch.sqrt(global_grad_norm)
+ # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes
+ # scalar types properly https://github.com/pytorch/pytorch/issues/9190
+ max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
+ clip_global_grad_norm = torch.where(
+ global_grad_norm > max_grad_norm,
+ global_grad_norm / max_grad_norm,
+ one_tensor)
+
+ for group in self.param_groups:
+ bias_correction = 1 if group['bias_correction'] else 0
+ beta1, beta2 = group['betas']
+ grad_averaging = 1 if group['grad_averaging'] else 0
+ beta3 = 1 - beta1 if grad_averaging else 1.0
+
+ # assume same step across group now to simplify things
+            # per-parameter step could easily be supported by making it a tensor, or by passing a list into a kernel
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ if bias_correction:
+ bias_correction1 = 1 - beta1 ** group['step']
+ bias_correction2 = 1 - beta2 ** group['step']
+ else:
+ bias_correction1, bias_correction2 = 1.0, 1.0
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.div_(clip_global_grad_norm)
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+                    # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t
+
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+ update = (exp_avg / bias_correction1).div_(denom)
+
+ weight_decay = group['weight_decay']
+ if weight_decay != 0:
+ update.add_(p, alpha=weight_decay)
+
+ if weight_decay != 0 or group['always_adapt']:
+ # Layer-wise LR adaptation. By default, skip adaptation on parameters that are
+ # excluded from weight decay, unless always_adapt == True, then always enabled.
+ w_norm = p.norm(2.0)
+ g_norm = update.norm(2.0)
+ # FIXME nested where required since logical and/or not working in PT XLA
+ trust_ratio = torch.where(
+ w_norm > 0,
+ torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
+ one_tensor,
+ )
+ if group['trust_clip']:
+ # LAMBC trust clipping, upper bound fixed at one
+ trust_ratio = torch.minimum(trust_ratio, one_tensor)
+ update.mul_(trust_ratio)
+
+ p.add_(update, alpha=-group['lr'])
+
+ return loss
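
Two things distinguish this Lamb from Adam in practice: every gradient is pre-scaled by the global-norm clip, and each parameter's Adam-style update is rescaled by the layer-wise trust ratio ||w|| / ||update||. A minimal usage sketch; the hyper-parameters are illustrative, not a recommendation:

```
import torch
from timm.optim.lamb import Lamb

model = torch.nn.Linear(10, 10)
optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01, trust_clip=False)

loss = model(torch.randn(32, 10)).pow(2).mean()
loss.backward()
optimizer.step()      # update scaled per layer by ||w|| / ||update||, capped at 1 if trust_clip=True
optimizer.zero_grad()
```
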
diff --git a/timm/optim/lars.py b/timm/optim/lars.py
new file mode 100644
index 0000000..98198e6
--- /dev/null
+++ b/timm/optim/lars.py
@@ -0,0 +1,135 @@
+""" PyTorch LARS / LARC Optimizer
+
+An implementation of LARS (SGD) + LARC in PyTorch
+
+Based on:
+ * PyTorch SGD: https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
+ * NVIDIA APEX LARC: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
+
+Additional cleanup and modifications to properly support PyTorch XLA.
+
+Copyright 2021 Ross Wightman
+"""
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Lars(Optimizer):
+ """ LARS for PyTorch
+
+ Paper: `Large batch training of Convolutional Networks` - https://arxiv.org/pdf/1708.03888.pdf
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
+ lr (float, optional): learning rate (default: 1.0).
+ momentum (float, optional): momentum factor (default: 0)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ dampening (float, optional): dampening for momentum (default: 0)
+ nesterov (bool, optional): enables Nesterov momentum (default: False)
+ trust_coeff (float): trust coefficient for computing adaptive lr / trust_ratio (default: 0.001)
+ eps (float): eps for division denominator (default: 1e-8)
+ trust_clip (bool): enable LARC trust ratio clipping (default: False)
+ always_adapt (bool): always apply LARS LR adapt, otherwise only when group weight_decay != 0 (default: False)
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1.0,
+ momentum=0,
+ dampening=0,
+ weight_decay=0,
+ nesterov=False,
+ trust_coeff=0.001,
+ eps=1e-8,
+ trust_clip=False,
+ always_adapt=False,
+ ):
+ if lr < 0.0:
+ raise ValueError(f"Invalid learning rate: {lr}")
+ if momentum < 0.0:
+ raise ValueError(f"Invalid momentum value: {momentum}")
+ if weight_decay < 0.0:
+ raise ValueError(f"Invalid weight_decay value: {weight_decay}")
+ if nesterov and (momentum <= 0 or dampening != 0):
+ raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+
+ defaults = dict(
+ lr=lr,
+ momentum=momentum,
+ dampening=dampening,
+ weight_decay=weight_decay,
+ nesterov=nesterov,
+ trust_coeff=trust_coeff,
+ eps=eps,
+ trust_clip=trust_clip,
+ always_adapt=always_adapt,
+ )
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault("nesterov", False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ device = self.param_groups[0]['params'][0].device
+ one_tensor = torch.tensor(1.0, device=device) # because torch.where doesn't handle scalars correctly
+
+ for group in self.param_groups:
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+ trust_coeff = group['trust_coeff']
+ eps = group['eps']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+
+ # apply LARS LR adaptation, LARC clipping, weight decay
+ # ref: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
+ if weight_decay != 0 or group['always_adapt']:
+ w_norm = p.norm(2.0)
+ g_norm = grad.norm(2.0)
+ trust_ratio = trust_coeff * w_norm / (g_norm + w_norm * weight_decay + eps)
+ # FIXME nested where required since logical and/or not working in PT XLA
+ trust_ratio = torch.where(
+ w_norm > 0,
+ torch.where(g_norm > 0, trust_ratio, one_tensor),
+ one_tensor,
+ )
+ if group['trust_clip']:
+ trust_ratio = torch.minimum(trust_ratio / group['lr'], one_tensor)
+                    grad.add_(p, alpha=weight_decay)  # in-place; the out-of-place add would silently drop weight decay
+ grad.mul_(trust_ratio)
+
+ # apply SGD update https://github.com/pytorch/pytorch/blob/1.7/torch/optim/sgd.py#L100
+ if momentum != 0:
+ param_state = self.state[p]
+ if 'momentum_buffer' not in param_state:
+ buf = param_state['momentum_buffer'] = torch.clone(grad).detach()
+ else:
+ buf = param_state['momentum_buffer']
+ buf.mul_(momentum).add_(grad, alpha=1. - dampening)
+ if nesterov:
+ grad = grad.add(buf, alpha=momentum)
+ else:
+ grad = buf
+
+ p.add_(grad, alpha=-group['lr'])
+
+ return loss
\ No newline at end of file
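
LARS keeps the SGD update rule but rescales each layer's (decayed) gradient by `trust_coeff * ||w|| / (||g|| + wd * ||w|| + eps)` before the momentum step; the `larc`/`nlarc` names in the factory below just set `trust_clip`/`nesterov` on this same class. A minimal usage sketch with illustrative hyper-parameters:

```
import torch
from timm.optim.lars import Lars

model = torch.nn.Linear(10, 10)
optimizer = Lars(model.parameters(), lr=1.0, momentum=0.9, weight_decay=1e-4, trust_coeff=0.001)

loss = model(torch.randn(32, 10)).pow(2).mean()
loss.backward()
optimizer.step()      # per-layer trust ratio scales the gradient before the SGD momentum update
optimizer.zero_grad()
```
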
diff --git a/timm/optim/lookahead.py b/timm/optim/lookahead.py
new file mode 100644
index 0000000..462c3ac
--- /dev/null
+++ b/timm/optim/lookahead.py
@@ -0,0 +1,61 @@
+""" Lookahead Optimizer Wrapper.
+Implementation modified from: https://github.com/alphadl/lookahead.pytorch
+Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch.optim.optimizer import Optimizer
+from collections import defaultdict
+
+
+class Lookahead(Optimizer):
+ def __init__(self, base_optimizer, alpha=0.5, k=6):
+ # NOTE super().__init__() not called on purpose
+ if not 0.0 <= alpha <= 1.0:
+ raise ValueError(f'Invalid slow update rate: {alpha}')
+ if not 1 <= k:
+ raise ValueError(f'Invalid lookahead steps: {k}')
+ defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
+ self._base_optimizer = base_optimizer
+ self.param_groups = base_optimizer.param_groups
+ self.defaults = base_optimizer.defaults
+ self.defaults.update(defaults)
+ self.state = defaultdict(dict)
+ # manually add our defaults to the param groups
+ for name, default in defaults.items():
+ for group in self._base_optimizer.param_groups:
+ group.setdefault(name, default)
+
+ @torch.no_grad()
+ def update_slow(self, group):
+ for fast_p in group["params"]:
+ if fast_p.grad is None:
+ continue
+ param_state = self._base_optimizer.state[fast_p]
+ if 'lookahead_slow_buff' not in param_state:
+ param_state['lookahead_slow_buff'] = torch.empty_like(fast_p)
+ param_state['lookahead_slow_buff'].copy_(fast_p)
+ slow = param_state['lookahead_slow_buff']
+ slow.add_(fast_p - slow, alpha=group['lookahead_alpha'])
+ fast_p.copy_(slow)
+
+ def sync_lookahead(self):
+ for group in self._base_optimizer.param_groups:
+ self.update_slow(group)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ loss = self._base_optimizer.step(closure)
+ for group in self._base_optimizer.param_groups:
+ group['lookahead_step'] += 1
+ if group['lookahead_step'] % group['lookahead_k'] == 0:
+ self.update_slow(group)
+ return loss
+
+ def state_dict(self):
+ return self._base_optimizer.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._base_optimizer.load_state_dict(state_dict)
+ self.param_groups = self._base_optimizer.param_groups
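
Lookahead is a wrapper, not a standalone optimizer: the wrapped ("fast") optimizer runs `k` steps, then the slow weights are pulled a fraction `alpha` toward them. A minimal sketch; gradients are cleared on the base optimizer here, since this wrapper deliberately skips `Optimizer.__init__`:

```
import torch
from timm.optim.lookahead import Lookahead

model = torch.nn.Linear(10, 10)
base = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer = Lookahead(base, alpha=0.5, k=6)   # slow weights synced every 6 fast steps

for _ in range(12):
    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    base.zero_grad()

optimizer.sync_lookahead()   # copy slow weights into the model before eval / checkpointing
```
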
diff --git a/timm/optim/madgrad.py b/timm/optim/madgrad.py
new file mode 100644
index 0000000..a76713b
--- /dev/null
+++ b/timm/optim/madgrad.py
@@ -0,0 +1,184 @@
+""" PyTorch MADGRAD optimizer
+
+MADGRAD: https://arxiv.org/abs/2101.11075
+
+Code from: https://github.com/facebookresearch/madgrad
+"""
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+import torch
+import torch.optim
+
+if TYPE_CHECKING:
+ from torch.optim.optimizer import _params_t
+else:
+ _params_t = Any
+
+
+class MADGRAD(torch.optim.Optimizer):
+ """
+ MADGRAD_: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
+ Optimization.
+
+ .. _MADGRAD: https://arxiv.org/abs/2101.11075
+
+    MADGRAD is a general purpose optimizer that can be used in place of SGD or
+    Adam; it may converge faster and generalize better. Currently GPU-only.
+ Typically, the same learning rate schedule that is used for SGD or Adam may
+ be used. The overall learning rate is not comparable to either method and
+ should be determined by a hyper-parameter sweep.
+
+ MADGRAD requires less weight decay than other methods, often as little as
+ zero. Momentum values used for SGD or Adam's beta1 should work here also.
+
+ On sparse problems both weight_decay and momentum should be set to 0.
+
+ Arguments:
+ params (iterable):
+ Iterable of parameters to optimize or dicts defining parameter groups.
+ lr (float):
+ Learning rate (default: 1e-2).
+ momentum (float):
+ Momentum value in the range [0,1) (default: 0.9).
+ weight_decay (float):
+ Weight decay, i.e. a L2 penalty (default: 0).
+ eps (float):
+ Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-6).
+ """
+
+ def __init__(
+ self,
+ params: _params_t,
+ lr: float = 1e-2,
+ momentum: float = 0.9,
+ weight_decay: float = 0,
+ eps: float = 1e-6,
+ decoupled_decay: bool = False,
+ ):
+ if momentum < 0 or momentum >= 1:
+ raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
+ if lr <= 0:
+ raise ValueError(f"Learning rate {lr} must be positive")
+ if weight_decay < 0:
+ raise ValueError(f"Weight decay {weight_decay} must be non-negative")
+ if eps < 0:
+ raise ValueError(f"Eps must be non-negative")
+
+ defaults = dict(
+ lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay, decoupled_decay=decoupled_decay)
+ super().__init__(params, defaults)
+
+ @property
+ def supports_memory_efficient_fp16(self) -> bool:
+ return False
+
+ @property
+ def supports_flat_params(self) -> bool:
+ return True
+
+ @torch.no_grad()
+ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ eps = group['eps']
+ lr = group['lr'] + eps
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ ck = 1 - momentum
+
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if momentum != 0.0 and grad.is_sparse:
+ raise RuntimeError("momentum != 0 is not compatible with sparse gradients")
+
+ state = self.state[p]
+ if len(state) == 0:
+ state['step'] = 0
+ state['grad_sum_sq'] = torch.zeros_like(p)
+ state['s'] = torch.zeros_like(p)
+ if momentum != 0:
+ state['x0'] = torch.clone(p).detach()
+
+ state['step'] += 1
+ grad_sum_sq = state['grad_sum_sq']
+ s = state['s']
+ lamb = lr * math.sqrt(state['step'])
+
+ # Apply weight decay
+ if weight_decay != 0:
+ if group['decoupled_decay']:
+ p.mul_(1.0 - group['lr'] * weight_decay)
+ else:
+ if grad.is_sparse:
+ raise RuntimeError("weight_decay option is not compatible with sparse gradients")
+ grad.add_(p, alpha=weight_decay)
+
+ if grad.is_sparse:
+ grad = grad.coalesce()
+ grad_val = grad._values()
+
+ p_masked = p.sparse_mask(grad)
+ grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
+ s_masked = s.sparse_mask(grad)
+
+ # Compute x_0 from other known quantities
+ rms_masked_vals = grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
+ x0_masked_vals = p_masked._values().addcdiv(s_masked._values(), rms_masked_vals, value=1)
+
+ # Dense + sparse op
+ grad_sq = grad * grad
+ grad_sum_sq.add_(grad_sq, alpha=lamb)
+ grad_sum_sq_masked.add_(grad_sq, alpha=lamb)
+
+ rms_masked_vals = grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
+
+ s.add_(grad, alpha=lamb)
+ s_masked._values().add_(grad_val, alpha=lamb)
+
+ # update masked copy of p
+ p_kp1_masked_vals = x0_masked_vals.addcdiv(s_masked._values(), rms_masked_vals, value=-1)
+ # Copy updated masked p to dense p using an add operation
+ p_masked._values().add_(p_kp1_masked_vals, alpha=-1)
+ p.add_(p_masked, alpha=-1)
+ else:
+ if momentum == 0:
+ # Compute x_0 from other known quantities
+ rms = grad_sum_sq.pow(1 / 3).add_(eps)
+ x0 = p.addcdiv(s, rms, value=1)
+ else:
+ x0 = state['x0']
+
+ # Accumulate second moments
+ grad_sum_sq.addcmul_(grad, grad, value=lamb)
+ rms = grad_sum_sq.pow(1 / 3).add_(eps)
+
+ # Update s
+ s.add_(grad, alpha=lamb)
+
+ # Step
+ if momentum == 0:
+ p.copy_(x0.addcdiv(s, rms, value=-1))
+ else:
+ z = x0.addcdiv(s, rms, value=-1)
+
+ # p is a moving average of z
+ p.mul_(1 - ck).add_(z, alpha=ck)
+
+ return loss
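
A minimal usage sketch. As the docstring notes, MADGRAD's learning rate is not comparable to SGD/Adam values and usually wants little or no weight decay; `decoupled_decay=True` is what the factory's `madgradw` name selects:

```
import torch
from timm.optim.madgrad import MADGRAD

model = torch.nn.Linear(10, 10)
optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0)

loss = model(torch.randn(32, 10)).pow(2).mean()
loss.backward()
optimizer.step()      # dual-averaging update; with momentum, p is a moving average of the iterate z
optimizer.zero_grad()
```
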
diff --git a/timm/optim/nadam.py b/timm/optim/nadam.py
new file mode 100644
index 0000000..6268d5d
--- /dev/null
+++ b/timm/optim/nadam.py
@@ -0,0 +1,92 @@
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Nadam(Optimizer):
+ """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
+
+ It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 2e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ schedule_decay (float, optional): momentum schedule decay (default: 4e-3)
+
+ __ http://cs229.stanford.edu/proj2015/054_report.pdf
+ __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
+
+ Originally taken from: https://github.com/pytorch/pytorch/pull/1408
+ NOTE: Has potential issues but does work well on some problems.
+ """
+
+ def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=0, schedule_decay=4e-3):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ defaults = dict(
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, schedule_decay=schedule_decay)
+ super(Nadam, self).__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['m_schedule'] = 1.
+ state['exp_avg'] = torch.zeros_like(p)
+ state['exp_avg_sq'] = torch.zeros_like(p)
+
+ # Warming momentum schedule
+ m_schedule = state['m_schedule']
+ schedule_decay = group['schedule_decay']
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+ eps = group['eps']
+ state['step'] += 1
+ t = state['step']
+ bias_correction2 = 1 - beta2 ** t
+
+ if group['weight_decay'] != 0:
+ grad = grad.add(p, alpha=group['weight_decay'])
+
+ momentum_cache_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
+ momentum_cache_t_1 = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
+ m_schedule_new = m_schedule * momentum_cache_t
+ m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
+ state['m_schedule'] = m_schedule_new
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
+
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+ p.addcdiv_(grad, denom, value=-group['lr'] * (1. - momentum_cache_t) / (1. - m_schedule_new))
+ p.addcdiv_(exp_avg, denom, value=-group['lr'] * momentum_cache_t_1 / (1. - m_schedule_next))
+
+ return loss
diff --git a/timm/optim/nvnovograd.py b/timm/optim/nvnovograd.py
new file mode 100644
index 0000000..fda3f4a
--- /dev/null
+++ b/timm/optim/nvnovograd.py
@@ -0,0 +1,120 @@
+""" Nvidia NovoGrad Optimizer.
+Original impl by Nvidia from Jasper example:
+ - https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper
+Paper: `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks`
+ - https://arxiv.org/abs/1905.11286
+"""
+
+import torch
+from torch.optim.optimizer import Optimizer
+import math
+
+
+class NvNovoGrad(Optimizer):
+ """
+ Implements Novograd algorithm.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.95, 0.98))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ grad_averaging: gradient averaging
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ (default: False)
+ """
+
+ def __init__(self, params, lr=1e-3, betas=(0.95, 0.98), eps=1e-8,
+ weight_decay=0, grad_averaging=False, amsgrad=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay,
+ grad_averaging=grad_averaging,
+ amsgrad=amsgrad)
+
+ super(NvNovoGrad, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(NvNovoGrad, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError('Sparse gradients are not supported.')
+ amsgrad = group['amsgrad']
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ if amsgrad:
+ max_exp_avg_sq = state['max_exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+
+ norm = torch.sum(torch.pow(grad, 2))
+
+ if exp_avg_sq == 0:
+ exp_avg_sq.copy_(norm)
+ else:
+ exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
+
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+ else:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+ grad.div_(denom)
+ if group['weight_decay'] != 0:
+ grad.add_(p, alpha=group['weight_decay'])
+ if group['grad_averaging']:
+ grad.mul_(1 - beta1)
+ exp_avg.mul_(beta1).add_(grad)
+
+ p.add_(exp_avg, alpha=-group['lr'])
+
+ return loss
diff --git a/timm/optim/optim_factory.py b/timm/optim/optim_factory.py
new file mode 100644
index 0000000..e174915
--- /dev/null
+++ b/timm/optim/optim_factory.py
@@ -0,0 +1,217 @@
+""" Optimizer Factory w/ Custom Weight Decay
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from .adabelief import AdaBelief
+from .adafactor import Adafactor
+from .adahessian import Adahessian
+from .adamp import AdamP
+from .lamb import Lamb
+from .lars import Lars
+from .lookahead import Lookahead
+from .madgrad import MADGRAD
+from .nadam import Nadam
+from .nvnovograd import NvNovoGrad
+from .radam import RAdam
+from .rmsprop_tf import RMSpropTF
+from .sgdp import SGDP
+
+try:
+ from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
+ has_apex = True
+except ImportError:
+ has_apex = False
+
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+ decay = []
+ no_decay = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+ no_decay.append(param)
+ else:
+ decay.append(param)
+ return [
+ {'params': no_decay, 'weight_decay': 0.},
+ {'params': decay, 'weight_decay': weight_decay}]
+
+
+def optimizer_kwargs(cfg):
+ """ cfg/argparse to kwargs helper
+ Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
+ """
+ kwargs = dict(
+ opt=cfg.opt,
+ lr=cfg.lr,
+ weight_decay=cfg.weight_decay,
+ momentum=cfg.momentum)
+ if getattr(cfg, 'opt_eps', None) is not None:
+ kwargs['eps'] = cfg.opt_eps
+ if getattr(cfg, 'opt_betas', None) is not None:
+ kwargs['betas'] = cfg.opt_betas
+ if getattr(cfg, 'opt_args', None) is not None:
+ kwargs.update(cfg.opt_args)
+ return kwargs
+
+
+def create_optimizer(args, model, filter_bias_and_bn=True):
+ """ Legacy optimizer factory for backwards compatibility.
+ NOTE: Use create_optimizer_v2 for new code.
+ """
+ return create_optimizer_v2(
+ model,
+ **optimizer_kwargs(cfg=args),
+ filter_bias_and_bn=filter_bias_and_bn,
+ )
+
+
+def create_optimizer_v2(
+ model_or_params,
+ opt: str = 'sgd',
+ lr: Optional[float] = None,
+ weight_decay: float = 0.,
+ momentum: float = 0.9,
+ filter_bias_and_bn: bool = True,
+ **kwargs):
+ """ Create an optimizer.
+
+ TODO currently the model is passed in and all parameters are selected for optimization.
+    For more general use, an interface that allows selection of parameters to optimize and lr groups is needed, one of:
+ * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
+ * expose the parameters interface and leave it up to caller
+
+ Args:
+ model_or_params (nn.Module): model containing parameters to optimize
+ opt: name of optimizer to create
+ lr: initial learning rate
+ weight_decay: weight decay to apply in optimizer
+ momentum: momentum for momentum based optimizers (others may use betas via kwargs)
+ filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
+ **kwargs: extra optimizer specific kwargs to pass through
+
+ Returns:
+ Optimizer
+ """
+ if isinstance(model_or_params, nn.Module):
+ # a model was passed in, extract parameters and add weight decays to appropriate layers
+ if weight_decay and filter_bias_and_bn:
+ skip = {}
+ if hasattr(model_or_params, 'no_weight_decay'):
+ skip = model_or_params.no_weight_decay()
+ parameters = add_weight_decay(model_or_params, weight_decay, skip)
+ weight_decay = 0.
+ else:
+ parameters = model_or_params.parameters()
+ else:
+ # iterable of parameters or param groups passed in
+ parameters = model_or_params
+
+ opt_lower = opt.lower()
+ opt_split = opt_lower.split('_')
+ opt_lower = opt_split[-1]
+ if 'fused' in opt_lower:
+ assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
+
+ opt_args = dict(weight_decay=weight_decay, **kwargs)
+ if lr is not None:
+ opt_args.setdefault('lr', lr)
+
+ # basic SGD & related
+ if opt_lower == 'sgd' or opt_lower == 'nesterov':
+ # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'momentum':
+ opt_args.pop('eps', None)
+ optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'sgdp':
+ optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
+
+ # adaptive
+ elif opt_lower == 'adam':
+ optimizer = optim.Adam(parameters, **opt_args)
+ elif opt_lower == 'adamw':
+ optimizer = optim.AdamW(parameters, **opt_args)
+ elif opt_lower == 'adamp':
+ optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
+ elif opt_lower == 'nadam':
+ try:
+ # NOTE PyTorch >= 1.10 should have native NAdam
+            optimizer = optim.NAdam(parameters, **opt_args)
+ except AttributeError:
+ optimizer = Nadam(parameters, **opt_args)
+ elif opt_lower == 'radam':
+ optimizer = RAdam(parameters, **opt_args)
+ elif opt_lower == 'adamax':
+ optimizer = optim.Adamax(parameters, **opt_args)
+ elif opt_lower == 'adabelief':
+ optimizer = AdaBelief(parameters, rectify=False, **opt_args)
+ elif opt_lower == 'radabelief':
+ optimizer = AdaBelief(parameters, rectify=True, **opt_args)
+ elif opt_lower == 'adadelta':
+ optimizer = optim.Adadelta(parameters, **opt_args)
+ elif opt_lower == 'adagrad':
+ opt_args.setdefault('eps', 1e-8)
+ optimizer = optim.Adagrad(parameters, **opt_args)
+ elif opt_lower == 'adafactor':
+ optimizer = Adafactor(parameters, **opt_args)
+ elif opt_lower == 'lamb':
+ optimizer = Lamb(parameters, **opt_args)
+ elif opt_lower == 'lambc':
+ optimizer = Lamb(parameters, trust_clip=True, **opt_args)
+ elif opt_lower == 'larc':
+ optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
+ elif opt_lower == 'lars':
+ optimizer = Lars(parameters, momentum=momentum, **opt_args)
+ elif opt_lower == 'nlarc':
+ optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
+ elif opt_lower == 'nlars':
+ optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'madgrad':
+ optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
+ elif opt_lower == 'madgradw':
+ optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
+ elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
+ optimizer = NvNovoGrad(parameters, **opt_args)
+ elif opt_lower == 'rmsprop':
+ optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
+ elif opt_lower == 'rmsproptf':
+ optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
+
+ # second order
+ elif opt_lower == 'adahessian':
+ optimizer = Adahessian(parameters, **opt_args)
+
+ # NVIDIA fused optimizers, require APEX to be installed
+ elif opt_lower == 'fusedsgd':
+ opt_args.pop('eps', None)
+ optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
+ elif opt_lower == 'fusedmomentum':
+ opt_args.pop('eps', None)
+ optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
+ elif opt_lower == 'fusedadam':
+ optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
+ elif opt_lower == 'fusedadamw':
+ optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
+ elif opt_lower == 'fusedlamb':
+ optimizer = FusedLAMB(parameters, **opt_args)
+ elif opt_lower == 'fusednovograd':
+ opt_args.setdefault('betas', (0.95, 0.98))
+ optimizer = FusedNovoGrad(parameters, **opt_args)
+
+ else:
+        raise ValueError(f"Invalid optimizer: {opt_lower}")
+
+ if len(opt_split) > 1:
+ if opt_split[0] == 'lookahead':
+ optimizer = Lookahead(optimizer)
+
+ return optimizer
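
How the factory is meant to be driven, as a sketch: the `opt` string picks the implementation (optionally with a `lookahead_` prefix), and `filter_bias_and_bn=True` routes biases and 1-d (norm) parameters into a `weight_decay=0` group via `add_weight_decay`. The model and hyper-parameter values below are illustrative only:

```
import torch.nn as nn
from timm.optim.optim_factory import create_optimizer_v2

model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

opt = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.05)
print([g['weight_decay'] for g in opt.param_groups])   # [0.0, 0.05]

# prefixing with 'lookahead_' wraps whichever optimizer the suffix selects
opt_la = create_optimizer_v2(model, opt='lookahead_lamb', lr=1e-3, weight_decay=0.02)
```
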
diff --git a/timm/optim/radam.py b/timm/optim/radam.py
new file mode 100644
index 0000000..eb8d22e
--- /dev/null
+++ b/timm/optim/radam.py
@@ -0,0 +1,89 @@
+"""RAdam Optimizer.
+Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
+Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
+"""
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class RAdam(Optimizer):
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+ defaults = dict(
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+ buffer=[[None, None, None] for _ in range(10)])
+ super(RAdam, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(RAdam, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.float()
+ if grad.is_sparse:
+ raise RuntimeError('RAdam does not support sparse gradients')
+
+ p_fp32 = p.float()
+
+ state = self.state[p]
+
+ if len(state) == 0:
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p_fp32)
+ state['exp_avg_sq'] = torch.zeros_like(p_fp32)
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_fp32)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+ state['step'] += 1
+ buffered = group['buffer'][int(state['step'] % 10)]
+ if state['step'] == buffered[0]:
+ num_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ num_sma_max = 2 / (1 - beta2) - 1
+ num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = num_sma
+
+ # more conservative since it's an approximated value
+ if num_sma >= 5:
+ step_size = group['lr'] * math.sqrt(
+ (1 - beta2_t) *
+ (num_sma - 4) / (num_sma_max - 4) *
+ (num_sma - 2) / num_sma *
+ num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step'])
+ else:
+ step_size = group['lr'] / (1 - beta1 ** state['step'])
+ buffered[2] = step_size
+
+ if group['weight_decay'] != 0:
+ p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * group['lr'])
+
+ # more conservative since it's an approximated value
+ if num_sma >= 5:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+ p_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+ else:
+ p_fp32.add_(exp_avg, alpha=-step_size)
+
+ p.copy_(p_fp32)
+
+ return loss
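
A minimal usage sketch. The point of the buffered `num_sma` bookkeeping above is that early steps (`num_sma < 5`), where the variance of the adaptive learning rate is not yet well-behaved, fall back to a plain momentum update without the `exp_avg_sq` denominator:

```
import torch
from timm.optim.radam import RAdam

model = torch.nn.Linear(10, 10)
optimizer = RAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(32, 10)).pow(2).mean()
loss.backward()
optimizer.step()      # first few steps skip the adaptive denominator (rectification warmup)
optimizer.zero_grad()
```
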
diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py
new file mode 100644
index 0000000..0817887
--- /dev/null
+++ b/timm/optim/rmsprop_tf.py
@@ -0,0 +1,139 @@
+""" RMSProp modified to behave like Tensorflow impl
+
+Originally cut & paste from PyTorch RMSProp
+https://github.com/pytorch/pytorch/blob/063946d2b3f3f1e953a2a3b54e0b34f1393de295/torch/optim/rmsprop.py
+Licensed under BSD-Clause 3 (ish), https://github.com/pytorch/pytorch/blob/master/LICENSE
+
+Modifications Copyright 2021 Ross Wightman
+"""
+
+import torch
+from torch.optim import Optimizer
+
+
+class RMSpropTF(Optimizer):
+ """Implements RMSprop algorithm (TensorFlow style epsilon)
+
+ NOTE: This is a direct cut-and-paste of PyTorch RMSprop with eps applied before sqrt
+ and a few other modifications to closer match Tensorflow for matching hyper-params.
+
+ Noteworthy changes include:
+ 1. Epsilon applied inside square-root
+ 2. square_avg initialized to ones
+ 3. LR scaling of update accumulated in momentum buffer
+
+ Proposed by G. Hinton in his
+    `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
+
+ The centered version first appears in `Generating Sequences
+    With Recurrent Neural Networks <https://arxiv.org/abs/1308.0850>`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-2)
+ momentum (float, optional): momentum factor (default: 0)
+ alpha (float, optional): smoothing (decay) constant (default: 0.9)
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-10)
+ centered (bool, optional) : if ``True``, compute the centered RMSProp,
+ the gradient is normalized by an estimation of its variance
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101
+ lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer
+ update as per defaults in Tensorflow
+
+ """
+
+ def __init__(self, params, lr=1e-2, alpha=0.9, eps=1e-10, weight_decay=0, momentum=0., centered=False,
+ decoupled_decay=False, lr_in_momentum=True):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= momentum:
+ raise ValueError("Invalid momentum value: {}".format(momentum))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= alpha:
+ raise ValueError("Invalid alpha value: {}".format(alpha))
+
+ defaults = dict(
+ lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay,
+ decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum)
+ super(RMSpropTF, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(RMSpropTF, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('momentum', 0)
+ group.setdefault('centered', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError('RMSprop does not support sparse gradients')
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ state['square_avg'] = torch.ones_like(p) # PyTorch inits to zero
+ if group['momentum'] > 0:
+ state['momentum_buffer'] = torch.zeros_like(p)
+ if group['centered']:
+ state['grad_avg'] = torch.zeros_like(p)
+
+ square_avg = state['square_avg']
+ one_minus_alpha = 1. - group['alpha']
+
+ state['step'] += 1
+
+ if group['weight_decay'] != 0:
+ if group['decoupled_decay']:
+ p.mul_(1. - group['lr'] * group['weight_decay'])
+ else:
+ grad = grad.add(p, alpha=group['weight_decay'])
+
+ # Tensorflow order of ops for updating squared avg
+ square_avg.add_(grad.pow(2) - square_avg, alpha=one_minus_alpha)
+ # square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) # PyTorch original
+
+ if group['centered']:
+ grad_avg = state['grad_avg']
+ grad_avg.add_(grad - grad_avg, alpha=one_minus_alpha)
+ avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add(group['eps']).sqrt_() # eps in sqrt
+ # grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) # PyTorch original
+ else:
+ avg = square_avg.add(group['eps']).sqrt_() # eps moved in sqrt
+
+ if group['momentum'] > 0:
+ buf = state['momentum_buffer']
+ # Tensorflow accumulates the LR scaling in the momentum buffer
+ if group['lr_in_momentum']:
+ buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr'])
+ p.add_(-buf)
+ else:
+ # PyTorch scales the param update by LR
+ buf.mul_(group['momentum']).addcdiv_(grad, avg)
+ p.add_(buf, alpha=-group['lr'])
+ else:
+ p.addcdiv_(grad, avg, value=-group['lr'])
+
+ return loss
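
Because eps goes inside the square root and `square_avg` starts at ones, hyper-parameters tuned for TensorFlow RMSProp (notably a much larger eps than PyTorch's default) carry over directly. A minimal sketch; the values below are illustrative, not a prescribed recipe:

```
import torch
from timm.optim.rmsprop_tf import RMSpropTF

model = torch.nn.Linear(10, 10)
optimizer = RMSpropTF(model.parameters(), lr=0.016, alpha=0.9, eps=1e-3,
                      momentum=0.9, weight_decay=1e-5)

loss = model(torch.randn(32, 10)).pow(2).mean()
loss.backward()
optimizer.step()      # lr scaling lives in the momentum buffer (lr_in_momentum=True, the TF default)
optimizer.zero_grad()
```
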
diff --git a/timm/optim/sgdp.py b/timm/optim/sgdp.py
new file mode 100644
index 0000000..baf05fa
--- /dev/null
+++ b/timm/optim/sgdp.py
@@ -0,0 +1,70 @@
+"""
+SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
+
+Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
+Code: https://github.com/clovaai/AdamP
+
+Copyright (c) 2020-present NAVER Corp.
+MIT license
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.optim.optimizer import Optimizer, required
+import math
+
+from .adamp import projection
+
+
+class SGDP(Optimizer):
+ def __init__(self, params, lr=required, momentum=0, dampening=0,
+ weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
+ defaults = dict(
+ lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
+ nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
+ super(SGDP, self).__init__(params, defaults)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ weight_decay = group['weight_decay']
+ momentum = group['momentum']
+ dampening = group['dampening']
+ nesterov = group['nesterov']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['momentum'] = torch.zeros_like(p)
+
+ # SGD
+ buf = state['momentum']
+ buf.mul_(momentum).add_(grad, alpha=1. - dampening)
+ if nesterov:
+ d_p = grad + momentum * buf
+ else:
+ d_p = buf
+
+ # Projection
+ wd_ratio = 1.
+ if len(p.shape) > 1:
+ d_p, wd_ratio = projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
+
+ # Weight decay
+ if weight_decay != 0:
+ p.mul_(1. - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
+
+ # Step
+ p.add_(d_p, alpha=-group['lr'])
+
+ return loss
diff --git a/timm/scheduler/__init__.py b/timm/scheduler/__init__.py
new file mode 100644
index 0000000..f1961b8
--- /dev/null
+++ b/timm/scheduler/__init__.py
@@ -0,0 +1,8 @@
+from .cosine_lr import CosineLRScheduler
+from .multistep_lr import MultiStepLRScheduler
+from .plateau_lr import PlateauLRScheduler
+from .poly_lr import PolyLRScheduler
+from .step_lr import StepLRScheduler
+from .tanh_lr import TanhLRScheduler
+
+from .scheduler_factory import create_scheduler
diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py
new file mode 100644
index 0000000..84ee349
--- /dev/null
+++ b/timm/scheduler/cosine_lr.py
@@ -0,0 +1,119 @@
+""" Cosine Scheduler
+
+Cosine LR schedule with warmup, cycle/restarts, noise, k-decay.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import logging
+import math
+import numpy as np
+import torch
+
+from .scheduler import Scheduler
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CosineLRScheduler(Scheduler):
+ """
+ Cosine decay with restarts.
+ This is described in the paper https://arxiv.org/abs/1608.03983.
+
+ Inspiration from
+ https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
+
+ k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
+ """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ t_initial: int,
+ lr_min: float = 0.,
+ cycle_mul: float = 1.,
+ cycle_decay: float = 1.,
+ cycle_limit: int = 1,
+ warmup_t=0,
+ warmup_lr_init=0,
+ warmup_prefix=False,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ k_decay=1.0,
+ initialize=True) -> None:
+ super().__init__(
+ optimizer, param_group_field="lr",
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ assert t_initial > 0
+ assert lr_min >= 0
+ if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
+ _logger.warning("Cosine annealing scheduler will have no effect on the learning "
+ "rate since t_initial = t_mul = eta_mul = 1.")
+ self.t_initial = t_initial
+ self.lr_min = lr_min
+ self.cycle_mul = cycle_mul
+ self.cycle_decay = cycle_decay
+ self.cycle_limit = cycle_limit
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.warmup_prefix = warmup_prefix
+ self.t_in_epochs = t_in_epochs
+ self.k_decay = k_decay
+ if self.warmup_t:
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+ else:
+ if self.warmup_prefix:
+ t = t - self.warmup_t
+
+ if self.cycle_mul != 1:
+ i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
+ t_i = self.cycle_mul ** i * self.t_initial
+ t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
+ else:
+ i = t // self.t_initial
+ t_i = self.t_initial
+ t_curr = t - (self.t_initial * i)
+
+ gamma = self.cycle_decay ** i
+ lr_max_values = [v * gamma for v in self.base_values]
+ k = self.k_decay
+
+ if i < self.cycle_limit:
+ lrs = [
+ self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
+ for lr_max in lr_max_values
+ ]
+ else:
+ lrs = [self.lr_min for _ in self.base_values]
+
+ return lrs
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
+
+ def get_cycle_length(self, cycles=0):
+ cycles = max(1, cycles or self.cycle_limit)
+ if self.cycle_mul == 1.0:
+ return self.t_initial * cycles
+ else:
+ return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
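
A minimal usage sketch of the scheduler above, with illustrative hyper-parameter values and a placeholder model (not taken from this repository's training scripts): the training loop owns the epoch counter and passes it in explicitly.

```python
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# 5 warmup epochs, then cosine decay towards lr_min over t_initial epochs.
scheduler = CosineLRScheduler(
    optimizer, t_initial=90, lr_min=1e-5,
    warmup_t=5, warmup_lr_init=1e-4, t_in_epochs=True)

for epoch in range(90):
    # ... train and validate for one epoch ...
    scheduler.step(epoch + 1)  # recompute the LR for the next epoch
    print(epoch, optimizer.param_groups[0]['lr'])
```
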
diff --git a/timm/scheduler/multistep_lr.py b/timm/scheduler/multistep_lr.py
new file mode 100644
index 0000000..a5d5fe1
--- /dev/null
+++ b/timm/scheduler/multistep_lr.py
@@ -0,0 +1,65 @@
+""" MultiStep LR Scheduler
+
+Basic multi step LR schedule with warmup, noise.
+"""
+import torch
+import bisect
+from timm.scheduler.scheduler import Scheduler
+from typing import List
+
+class MultiStepLRScheduler(Scheduler):
+    """ MultiStep LR schedule with warmup and optional LR noise.
+    """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ decay_t: List[int],
+ decay_rate: float = 1.,
+ warmup_t=0,
+ warmup_lr_init=0,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ initialize=True,
+ ) -> None:
+ super().__init__(
+ optimizer, param_group_field="lr",
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ self.decay_t = decay_t
+ self.decay_rate = decay_rate
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.t_in_epochs = t_in_epochs
+ if self.warmup_t:
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+
+ def get_curr_decay_steps(self, t):
+ # find where in the array t goes,
+ # assumes self.decay_t is sorted
+ return bisect.bisect_right(self.decay_t, t+1)
+
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+ else:
+ lrs = [v * (self.decay_rate ** self.get_curr_decay_steps(t)) for v in self.base_values]
+ return lrs
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
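
A minimal sketch of how the milestone lookup behaves (values are illustrative). Because `get_curr_decay_steps` uses `bisect_right(decay_t, t + 1)`, the decayed rate is already in place when `step` is called with the milestone epoch.

```python
import torch
from timm.scheduler.multistep_lr import MultiStepLRScheduler

optimizer = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
scheduler = MultiStepLRScheduler(optimizer, decay_t=[30, 60], decay_rate=0.1)

scheduler.step(10)
print(optimizer.param_groups[0]['lr'])  # 0.1    (before the first milestone)
scheduler.step(30)
print(optimizer.param_groups[0]['lr'])  # ~0.01  (first milestone applied)
scheduler.step(60)
print(optimizer.param_groups[0]['lr'])  # ~0.001 (second milestone applied)
```
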
diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py
new file mode 100644
index 0000000..4f2cacb
--- /dev/null
+++ b/timm/scheduler/plateau_lr.py
@@ -0,0 +1,113 @@
+""" Plateau Scheduler
+
+Adapts PyTorch plateau scheduler and allows application of noise, warmup.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+
+from .scheduler import Scheduler
+
+
+class PlateauLRScheduler(Scheduler):
+ """Decay the LR by a factor every time the validation loss plateaus."""
+
+ def __init__(self,
+ optimizer,
+ decay_rate=0.1,
+ patience_t=10,
+ verbose=True,
+ threshold=1e-4,
+ cooldown_t=0,
+ warmup_t=0,
+ warmup_lr_init=0,
+ lr_min=0,
+ mode='max',
+ noise_range_t=None,
+ noise_type='normal',
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=None,
+ initialize=True,
+ ):
+ super().__init__(optimizer, 'lr', initialize=initialize)
+
+ self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ self.optimizer,
+ patience=patience_t,
+ factor=decay_rate,
+ verbose=verbose,
+ threshold=threshold,
+ cooldown=cooldown_t,
+ mode=mode,
+ min_lr=lr_min
+ )
+
+ self.noise_range = noise_range_t
+ self.noise_pct = noise_pct
+ self.noise_type = noise_type
+ self.noise_std = noise_std
+ self.noise_seed = noise_seed if noise_seed is not None else 42
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ if self.warmup_t:
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+ self.restore_lr = None
+
+ def state_dict(self):
+ return {
+ 'best': self.lr_scheduler.best,
+ 'last_epoch': self.lr_scheduler.last_epoch,
+ }
+
+ def load_state_dict(self, state_dict):
+ self.lr_scheduler.best = state_dict['best']
+ if 'last_epoch' in state_dict:
+ self.lr_scheduler.last_epoch = state_dict['last_epoch']
+
+ # override the base class step fn completely
+ def step(self, epoch, metric=None):
+ if epoch <= self.warmup_t:
+ lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
+ super().update_groups(lrs)
+ else:
+ if self.restore_lr is not None:
+ # restore actual LR from before our last noise perturbation before stepping base
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ param_group['lr'] = self.restore_lr[i]
+ self.restore_lr = None
+
+ self.lr_scheduler.step(metric, epoch) # step the base scheduler
+
+ if self.noise_range is not None:
+ if isinstance(self.noise_range, (list, tuple)):
+ apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
+ else:
+ apply_noise = epoch >= self.noise_range
+ if apply_noise:
+ self._apply_noise(epoch)
+
+ def _apply_noise(self, epoch):
+ g = torch.Generator()
+ g.manual_seed(self.noise_seed + epoch)
+ if self.noise_type == 'normal':
+ while True:
+ # resample if noise out of percent limit, brute force but shouldn't spin much
+ noise = torch.randn(1, generator=g).item()
+ if abs(noise) < self.noise_pct:
+ break
+ else:
+ noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
+
+ # apply the noise on top of previous LR, cache the old value so we can restore for normal
+ # stepping of base scheduler
+ restore_lr = []
+ for i, param_group in enumerate(self.optimizer.param_groups):
+ old_lr = float(param_group['lr'])
+ restore_lr.append(old_lr)
+ new_lr = old_lr + old_lr * noise
+ param_group['lr'] = new_lr
+ self.restore_lr = restore_lr
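
A minimal sketch of driving the plateau scheduler; unlike the other schedulers it should be given a validation metric on every `step` call once warmup is over. The model, metric and hyper-parameter values below are placeholders.

```python
import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# mode='min' because the tracked metric is a loss; halve the LR after 3 stagnant epochs.
scheduler = PlateauLRScheduler(
    optimizer, decay_rate=0.5, patience_t=3, mode='min',
    warmup_t=2, warmup_lr_init=1e-3, verbose=False)

for epoch in range(20):
    val_loss = 1.0  # stand-in for a real validation loss; constant here, so the LR plateaus
    scheduler.step(epoch + 1, metric=val_loss)
print(optimizer.param_groups[0]['lr'])
```
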
diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py
new file mode 100644
index 0000000..9c351be
--- /dev/null
+++ b/timm/scheduler/poly_lr.py
@@ -0,0 +1,116 @@
+""" Polynomial Scheduler
+
+Polynomial LR schedule with warmup, noise.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+import logging
+
+import torch
+
+from .scheduler import Scheduler
+
+
+_logger = logging.getLogger(__name__)
+
+
+class PolyLRScheduler(Scheduler):
+ """ Polynomial LR Scheduler w/ warmup, noise, and k-decay
+
+ k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
+ """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ t_initial: int,
+ power: float = 0.5,
+ lr_min: float = 0.,
+ cycle_mul: float = 1.,
+ cycle_decay: float = 1.,
+ cycle_limit: int = 1,
+ warmup_t=0,
+ warmup_lr_init=0,
+ warmup_prefix=False,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ k_decay=1.0,
+ initialize=True) -> None:
+ super().__init__(
+ optimizer, param_group_field="lr",
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ assert t_initial > 0
+ assert lr_min >= 0
+ if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
+            _logger.warning("Polynomial LR scheduler will have no effect on the learning "
+                            "rate since t_initial = cycle_mul = cycle_decay = 1.")
+ self.t_initial = t_initial
+ self.power = power
+ self.lr_min = lr_min
+ self.cycle_mul = cycle_mul
+ self.cycle_decay = cycle_decay
+ self.cycle_limit = cycle_limit
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.warmup_prefix = warmup_prefix
+ self.t_in_epochs = t_in_epochs
+ self.k_decay = k_decay
+ if self.warmup_t:
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+ else:
+ if self.warmup_prefix:
+ t = t - self.warmup_t
+
+ if self.cycle_mul != 1:
+ i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
+ t_i = self.cycle_mul ** i * self.t_initial
+ t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
+ else:
+ i = t // self.t_initial
+ t_i = self.t_initial
+ t_curr = t - (self.t_initial * i)
+
+ gamma = self.cycle_decay ** i
+ lr_max_values = [v * gamma for v in self.base_values]
+ k = self.k_decay
+
+ if i < self.cycle_limit:
+ lrs = [
+ self.lr_min + (lr_max - self.lr_min) * (1 - t_curr ** k / t_i ** k) ** self.power
+ for lr_max in lr_max_values
+ ]
+ else:
+ lrs = [self.lr_min for _ in self.base_values]
+
+ return lrs
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
+
+ def get_cycle_length(self, cycles=0):
+ cycles = max(1, cycles or self.cycle_limit)
+ if self.cycle_mul == 1.0:
+ return self.t_initial * cycles
+ else:
+ return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
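
A minimal sketch showing the effect of `power` (illustrative values); with `power=1.0` the schedule is a linear decay from the base LR to `lr_min` over `t_initial` steps.

```python
import torch
from timm.scheduler.poly_lr import PolyLRScheduler

optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
scheduler = PolyLRScheduler(optimizer, t_initial=100, power=1.0, lr_min=0.0)

scheduler.step(50)   # halfway through the single cycle
print(optimizer.param_groups[0]['lr'])  # ~0.05 for a linear (power=1.0) schedule
```
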
diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py
new file mode 100644
index 0000000..21d5150
--- /dev/null
+++ b/timm/scheduler/scheduler.py
@@ -0,0 +1,105 @@
+from typing import Dict, Any
+
+import torch
+
+
+class Scheduler:
+ """ Parameter Scheduler Base Class
+ A scheduler base class that can be used to schedule any optimizer parameter groups.
+
+ Unlike the builtin PyTorch schedulers, this is intended to be consistently called
+ * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
+ * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
+
+ The schedulers built on this should try to remain as stateless as possible (for simplicity).
+
+ This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
+ and -1 values for special behaviour. All epoch and update counts must be tracked in the training
+ code and explicitly passed in to the schedulers on the corresponding step or step_update call.
+
+ Based on ideas from:
+ * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
+ * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
+ """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ param_group_field: str,
+ noise_range_t=None,
+ noise_type='normal',
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=None,
+ initialize: bool = True) -> None:
+ self.optimizer = optimizer
+ self.param_group_field = param_group_field
+ self._initial_param_group_field = f"initial_{param_group_field}"
+ if initialize:
+ for i, group in enumerate(self.optimizer.param_groups):
+ if param_group_field not in group:
+ raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
+ group.setdefault(self._initial_param_group_field, group[param_group_field])
+ else:
+ for i, group in enumerate(self.optimizer.param_groups):
+ if self._initial_param_group_field not in group:
+ raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
+ self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
+ self.metric = None # any point to having this for all?
+ self.noise_range_t = noise_range_t
+ self.noise_pct = noise_pct
+ self.noise_type = noise_type
+ self.noise_std = noise_std
+ self.noise_seed = noise_seed if noise_seed is not None else 42
+ self.update_groups(self.base_values)
+
+ def state_dict(self) -> Dict[str, Any]:
+ return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ self.__dict__.update(state_dict)
+
+ def get_epoch_values(self, epoch: int):
+ return None
+
+ def get_update_values(self, num_updates: int):
+ return None
+
+ def step(self, epoch: int, metric: float = None) -> None:
+ self.metric = metric
+ values = self.get_epoch_values(epoch)
+ if values is not None:
+ values = self._add_noise(values, epoch)
+ self.update_groups(values)
+
+ def step_update(self, num_updates: int, metric: float = None):
+ self.metric = metric
+ values = self.get_update_values(num_updates)
+ if values is not None:
+ values = self._add_noise(values, num_updates)
+ self.update_groups(values)
+
+ def update_groups(self, values):
+ if not isinstance(values, (list, tuple)):
+ values = [values] * len(self.optimizer.param_groups)
+ for param_group, value in zip(self.optimizer.param_groups, values):
+ param_group[self.param_group_field] = value
+
+ def _add_noise(self, lrs, t):
+ if self.noise_range_t is not None:
+ if isinstance(self.noise_range_t, (list, tuple)):
+ apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
+ else:
+ apply_noise = t >= self.noise_range_t
+ if apply_noise:
+ g = torch.Generator()
+ g.manual_seed(self.noise_seed + t)
+ if self.noise_type == 'normal':
+ while True:
+ # resample if noise out of percent limit, brute force but shouldn't spin much
+ noise = torch.randn(1, generator=g).item()
+ if abs(noise) < self.noise_pct:
+ break
+ else:
+ noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
+ lrs = [v + v * noise for v in lrs]
+ return lrs
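
The calling contract in the docstring above means the training loop owns both counters. A minimal sketch (placeholder loop sizes) of scheduling in units of optimizer updates rather than epochs:

```python
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

updates_per_epoch, num_epochs = 100, 30
scheduler = CosineLRScheduler(
    optimizer, t_initial=updates_per_epoch * num_epochs, lr_min=1e-5, t_in_epochs=False)

num_updates = 0
for epoch in range(num_epochs):
    for _ in range(updates_per_epoch):      # stand-in for iterating a data loader
        optimizer.step()
        num_updates += 1
        scheduler.step_update(num_updates)  # per-update LR since t_in_epochs=False
    scheduler.step(epoch + 1)               # leaves the LR alone in this mode
```
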
diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
new file mode 100644
index 0000000..72a979c
--- /dev/null
+++ b/timm/scheduler/scheduler_factory.py
@@ -0,0 +1,107 @@
+""" Scheduler Factory
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+from .cosine_lr import CosineLRScheduler
+from .multistep_lr import MultiStepLRScheduler
+from .plateau_lr import PlateauLRScheduler
+from .poly_lr import PolyLRScheduler
+from .step_lr import StepLRScheduler
+from .tanh_lr import TanhLRScheduler
+
+
+def create_scheduler(args, optimizer):
+ num_epochs = args.epochs
+
+ if getattr(args, 'lr_noise', None) is not None:
+ lr_noise = getattr(args, 'lr_noise')
+ if isinstance(lr_noise, (list, tuple)):
+ noise_range = [n * num_epochs for n in lr_noise]
+ if len(noise_range) == 1:
+ noise_range = noise_range[0]
+ else:
+ noise_range = lr_noise * num_epochs
+ else:
+ noise_range = None
+ noise_args = dict(
+ noise_range_t=noise_range,
+ noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+ noise_std=getattr(args, 'lr_noise_std', 1.),
+ noise_seed=getattr(args, 'seed', 42),
+ )
+ cycle_args = dict(
+ cycle_mul=getattr(args, 'lr_cycle_mul', 1.),
+ cycle_decay=getattr(args, 'lr_cycle_decay', 0.1),
+ cycle_limit=getattr(args, 'lr_cycle_limit', 1),
+ )
+
+ lr_scheduler = None
+ if args.sched == 'cosine':
+ lr_scheduler = CosineLRScheduler(
+ optimizer,
+ t_initial=num_epochs,
+ lr_min=args.min_lr,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ k_decay=getattr(args, 'lr_k_decay', 1.0),
+ **cycle_args,
+ **noise_args,
+ )
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
+ elif args.sched == 'tanh':
+ lr_scheduler = TanhLRScheduler(
+ optimizer,
+ t_initial=num_epochs,
+ lr_min=args.min_lr,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ t_in_epochs=True,
+ **cycle_args,
+ **noise_args,
+ )
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
+ elif args.sched == 'step':
+ lr_scheduler = StepLRScheduler(
+ optimizer,
+ decay_t=args.decay_epochs,
+ decay_rate=args.decay_rate,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ **noise_args,
+ )
+ elif args.sched == 'multistep':
+ lr_scheduler = MultiStepLRScheduler(
+ optimizer,
+ decay_t=args.decay_epochs,
+ decay_rate=args.decay_rate,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ **noise_args,
+ )
+ elif args.sched == 'plateau':
+ mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
+ lr_scheduler = PlateauLRScheduler(
+ optimizer,
+ decay_rate=args.decay_rate,
+ patience_t=args.patience_epochs,
+ lr_min=args.min_lr,
+ mode=mode,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ cooldown_t=0,
+ **noise_args,
+ )
+ elif args.sched == 'poly':
+ lr_scheduler = PolyLRScheduler(
+ optimizer,
+ power=args.decay_rate, # overloading 'decay_rate' as polynomial power
+ t_initial=num_epochs,
+ lr_min=args.min_lr,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ k_decay=getattr(args, 'lr_k_decay', 1.0),
+ **cycle_args,
+ **noise_args,
+ )
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
+
+ return lr_scheduler, num_epochs
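
A minimal sketch of driving the factory without argparse; the attribute names mirror those read in `create_scheduler`, the values are illustrative, and `SimpleNamespace` stands in for the parsed training arguments.

```python
from types import SimpleNamespace

import torch
from timm.scheduler.scheduler_factory import create_scheduler

args = SimpleNamespace(
    sched='cosine', epochs=90, min_lr=1e-5,
    warmup_lr=1e-4, warmup_epochs=5, cooldown_epochs=10)

optimizer = torch.optim.SGD(torch.nn.Linear(4, 4).parameters(), lr=0.1)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
print(num_epochs)  # 100: one 90-epoch cosine cycle plus 10 cooldown epochs
```
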
diff --git a/timm/scheduler/step_lr.py b/timm/scheduler/step_lr.py
new file mode 100644
index 0000000..f797e1a
--- /dev/null
+++ b/timm/scheduler/step_lr.py
@@ -0,0 +1,63 @@
+""" Step Scheduler
+
+Basic step LR schedule with warmup, noise.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+import torch
+
+from .scheduler import Scheduler
+
+
+class StepLRScheduler(Scheduler):
+    """ Step LR schedule with warmup and optional LR noise.
+    """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ decay_t: float,
+ decay_rate: float = 1.,
+ warmup_t=0,
+ warmup_lr_init=0,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ initialize=True,
+ ) -> None:
+ super().__init__(
+ optimizer, param_group_field="lr",
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ self.decay_t = decay_t
+ self.decay_rate = decay_rate
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.t_in_epochs = t_in_epochs
+ if self.warmup_t:
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+ else:
+ lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
+ return lrs
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
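
A minimal sketch (illustrative values): with `decay_t=30` and `decay_rate=0.1` the LR is divided by ten every 30 epochs after any warmup.

```python
import torch
from timm.scheduler.step_lr import StepLRScheduler

optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
scheduler = StepLRScheduler(optimizer, decay_t=30, decay_rate=0.1)

for t in (0, 29, 30, 60):
    scheduler.step(t)
    print(t, optimizer.param_groups[0]['lr'])  # 0.1, 0.1, ~0.01, ~0.001
```
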
diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py
new file mode 100644
index 0000000..f2d3c9c
--- /dev/null
+++ b/timm/scheduler/tanh_lr.py
@@ -0,0 +1,117 @@
+""" TanH Scheduler
+
+TanH schedule with warmup, cycle/restarts, noise.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import logging
+import math
+import numpy as np
+import torch
+
+from .scheduler import Scheduler
+
+
+_logger = logging.getLogger(__name__)
+
+
+class TanhLRScheduler(Scheduler):
+ """
+    Hyperbolic-Tangent decay with restarts.
+ This is described in the paper https://arxiv.org/abs/1806.01593
+ """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ t_initial: int,
+ lb: float = -7.,
+ ub: float = 3.,
+ lr_min: float = 0.,
+ cycle_mul: float = 1.,
+ cycle_decay: float = 1.,
+ cycle_limit: int = 1,
+ warmup_t=0,
+ warmup_lr_init=0,
+ warmup_prefix=False,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ initialize=True) -> None:
+ super().__init__(
+ optimizer, param_group_field="lr",
+ noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ assert t_initial > 0
+ assert lr_min >= 0
+ assert lb < ub
+ assert cycle_limit >= 0
+ assert warmup_t >= 0
+ assert warmup_lr_init >= 0
+ self.lb = lb
+ self.ub = ub
+ self.t_initial = t_initial
+ self.lr_min = lr_min
+ self.cycle_mul = cycle_mul
+ self.cycle_decay = cycle_decay
+ self.cycle_limit = cycle_limit
+ self.warmup_t = warmup_t
+ self.warmup_lr_init = warmup_lr_init
+ self.warmup_prefix = warmup_prefix
+ self.t_in_epochs = t_in_epochs
+ if self.warmup_t:
+ t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
+ self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
+ super().update_groups(self.warmup_lr_init)
+ else:
+ self.warmup_steps = [1 for _ in self.base_values]
+
+ def _get_lr(self, t):
+ if t < self.warmup_t:
+ lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
+ else:
+ if self.warmup_prefix:
+ t = t - self.warmup_t
+
+ if self.cycle_mul != 1:
+ i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
+ t_i = self.cycle_mul ** i * self.t_initial
+ t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
+ else:
+ i = t // self.t_initial
+ t_i = self.t_initial
+ t_curr = t - (self.t_initial * i)
+
+ if i < self.cycle_limit:
+ gamma = self.cycle_decay ** i
+ lr_max_values = [v * gamma for v in self.base_values]
+
+ tr = t_curr / t_i
+ lrs = [
+ self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 - math.tanh(self.lb * (1. - tr) + self.ub * tr))
+ for lr_max in lr_max_values
+ ]
+ else:
+ lrs = [self.lr_min for _ in self.base_values]
+ return lrs
+
+ def get_epoch_values(self, epoch: int):
+ if self.t_in_epochs:
+ return self._get_lr(epoch)
+ else:
+ return None
+
+ def get_update_values(self, num_updates: int):
+ if not self.t_in_epochs:
+ return self._get_lr(num_updates)
+ else:
+ return None
+
+ def get_cycle_length(self, cycles=0):
+ cycles = max(1, cycles or self.cycle_limit)
+ if self.cycle_mul == 1.0:
+ return self.t_initial * cycles
+ else:
+ return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py
new file mode 100644
index 0000000..11de9c9
--- /dev/null
+++ b/timm/utils/__init__.py
@@ -0,0 +1,13 @@
+from .agc import adaptive_clip_grad
+from .checkpoint_saver import CheckpointSaver
+from .clip_grad import dispatch_clip_grad
+from .cuda import ApexScaler, NativeScaler
+from .distributed import distribute_bn, reduce_tensor
+from .jit import set_jit_legacy
+from .log import setup_default_logging, FormatterNoInfo
+from .metrics import AverageMeter, accuracy
+from .misc import natural_key, add_bool_arg
+from .model import unwrap_model, get_state_dict, freeze, unfreeze
+from .model_ema import ModelEma, ModelEmaV2
+from .random import random_seed
+from .summary import update_summary, get_outdir
diff --git a/timm/utils/__pycache__/__init__.cpython-36.pyc b/timm/utils/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..16b9809
Binary files /dev/null and b/timm/utils/__pycache__/__init__.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/agc.cpython-36.pyc b/timm/utils/__pycache__/agc.cpython-36.pyc
new file mode 100644
index 0000000..657b32d
Binary files /dev/null and b/timm/utils/__pycache__/agc.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/checkpoint_saver.cpython-36.pyc b/timm/utils/__pycache__/checkpoint_saver.cpython-36.pyc
new file mode 100644
index 0000000..2af661c
Binary files /dev/null and b/timm/utils/__pycache__/checkpoint_saver.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/clip_grad.cpython-36.pyc b/timm/utils/__pycache__/clip_grad.cpython-36.pyc
new file mode 100644
index 0000000..475c207
Binary files /dev/null and b/timm/utils/__pycache__/clip_grad.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/cuda.cpython-36.pyc b/timm/utils/__pycache__/cuda.cpython-36.pyc
new file mode 100644
index 0000000..bf9fe8d
Binary files /dev/null and b/timm/utils/__pycache__/cuda.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/distributed.cpython-36.pyc b/timm/utils/__pycache__/distributed.cpython-36.pyc
new file mode 100644
index 0000000..30caa73
Binary files /dev/null and b/timm/utils/__pycache__/distributed.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/jit.cpython-36.pyc b/timm/utils/__pycache__/jit.cpython-36.pyc
new file mode 100644
index 0000000..ae89166
Binary files /dev/null and b/timm/utils/__pycache__/jit.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/log.cpython-36.pyc b/timm/utils/__pycache__/log.cpython-36.pyc
new file mode 100644
index 0000000..9afd07b
Binary files /dev/null and b/timm/utils/__pycache__/log.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/metrics.cpython-36.pyc b/timm/utils/__pycache__/metrics.cpython-36.pyc
new file mode 100644
index 0000000..6cc6e8b
Binary files /dev/null and b/timm/utils/__pycache__/metrics.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/misc.cpython-36.pyc b/timm/utils/__pycache__/misc.cpython-36.pyc
new file mode 100644
index 0000000..c530081
Binary files /dev/null and b/timm/utils/__pycache__/misc.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/model.cpython-36.pyc b/timm/utils/__pycache__/model.cpython-36.pyc
new file mode 100644
index 0000000..443a1ba
Binary files /dev/null and b/timm/utils/__pycache__/model.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/model_ema.cpython-36.pyc b/timm/utils/__pycache__/model_ema.cpython-36.pyc
new file mode 100644
index 0000000..3582048
Binary files /dev/null and b/timm/utils/__pycache__/model_ema.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/random.cpython-36.pyc b/timm/utils/__pycache__/random.cpython-36.pyc
new file mode 100644
index 0000000..f36798a
Binary files /dev/null and b/timm/utils/__pycache__/random.cpython-36.pyc differ
diff --git a/timm/utils/__pycache__/summary.cpython-36.pyc b/timm/utils/__pycache__/summary.cpython-36.pyc
new file mode 100644
index 0000000..95657c2
Binary files /dev/null and b/timm/utils/__pycache__/summary.cpython-36.pyc differ
diff --git a/timm/utils/agc.py b/timm/utils/agc.py
new file mode 100644
index 0000000..f514017
--- /dev/null
+++ b/timm/utils/agc.py
@@ -0,0 +1,42 @@
+""" Adaptive Gradient Clipping
+
+An impl of AGC, as per (https://arxiv.org/abs/2102.06171):
+
+@article{brock2021high,
+ author={Andrew Brock and Soham De and Samuel L. Smith and Karen Simonyan},
+ title={High-Performance Large-Scale Image Recognition Without Normalization},
+  journal={arXiv preprint arXiv:2102.06171},
+ year={2021}
+}
+
+Code references:
+ * Official JAX impl (paper authors): https://github.com/deepmind/deepmind-research/tree/master/nfnets
+ * Phil Wang's PyTorch gist: https://gist.github.com/lucidrains/0d6560077edac419ab5d3aa29e674d5c
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import torch
+
+
+def unitwise_norm(x, norm_type=2.0):
+ if x.ndim <= 1:
+ return x.norm(norm_type)
+ else:
+        # works for nn.ConvNd and nn.Linear where output dim is first in the kernel/weight tensor
+ # might need special cases for other weights (possibly MHA) where this may not be true
+ return x.norm(norm_type, dim=tuple(range(1, x.ndim)), keepdim=True)
+
+
+def adaptive_clip_grad(parameters, clip_factor=0.01, eps=1e-3, norm_type=2.0):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ for p in parameters:
+ if p.grad is None:
+ continue
+ p_data = p.detach()
+ g_data = p.grad.detach()
+ max_norm = unitwise_norm(p_data, norm_type=norm_type).clamp_(min=eps).mul_(clip_factor)
+ grad_norm = unitwise_norm(g_data, norm_type=norm_type)
+ clipped_grad = g_data * (max_norm / grad_norm.clamp(min=1e-6))
+ new_grads = torch.where(grad_norm < max_norm, g_data, clipped_grad)
+ p.grad.detach().copy_(new_grads)
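
A minimal sketch of where adaptive clipping sits in a training step, between `backward()` and `optimizer.step()`; the model, loss and `clip_factor` value are placeholders.

```python
import torch
from timm.utils.agc import adaptive_clip_grad

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# Clip each parameter's gradient relative to that parameter's own unit-wise norm.
adaptive_clip_grad(model.parameters(), clip_factor=0.01)
optimizer.step()
```
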
diff --git a/timm/utils/checkpoint_saver.py b/timm/utils/checkpoint_saver.py
new file mode 100644
index 0000000..6aad74e
--- /dev/null
+++ b/timm/utils/checkpoint_saver.py
@@ -0,0 +1,150 @@
+""" Checkpoint Saver
+
+Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+import glob
+import operator
+import os
+import logging
+
+import torch
+
+from .model import unwrap_model, get_state_dict
+
+
+_logger = logging.getLogger(__name__)
+
+
+class CheckpointSaver:
+ def __init__(
+ self,
+ model,
+ optimizer,
+ args=None,
+ model_ema=None,
+ amp_scaler=None,
+ checkpoint_prefix='checkpoint',
+ recovery_prefix='recovery',
+ checkpoint_dir='',
+ recovery_dir='',
+ decreasing=False,
+ max_history=10,
+ unwrap_fn=unwrap_model):
+
+ # objects to save state_dicts of
+ self.model = model
+ self.optimizer = optimizer
+ self.args = args
+ self.model_ema = model_ema
+ self.amp_scaler = amp_scaler
+
+ # state
+        self.checkpoint_files = []  # (filename, metric) tuples ordered from best to worst
+ self.best_epoch = None
+ self.best_metric = None
+ self.curr_recovery_file = ''
+ self.last_recovery_file = ''
+
+ # config
+ self.checkpoint_dir = checkpoint_dir
+ self.recovery_dir = recovery_dir
+ self.save_prefix = checkpoint_prefix
+ self.recovery_prefix = recovery_prefix
+ self.extension = '.pth.tar'
+ self.decreasing = decreasing # a lower metric is better if True
+ self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
+ self.max_history = max_history
+ self.unwrap_fn = unwrap_fn
+ assert self.max_history >= 1
+
+ def save_checkpoint(self, epoch, metric=None):
+ assert epoch >= 0
+ tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
+ last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
+ self._save(tmp_save_path, epoch, metric)
+ if os.path.exists(last_save_path):
+ os.unlink(last_save_path) # required for Windows support.
+ os.rename(tmp_save_path, last_save_path)
+ worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
+ if (len(self.checkpoint_files) < self.max_history
+ or metric is None or self.cmp(metric, worst_file[1])):
+ if len(self.checkpoint_files) >= self.max_history:
+ self._cleanup_checkpoints(1)
+ filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
+ save_path = os.path.join(self.checkpoint_dir, filename)
+ os.link(last_save_path, save_path)
+ self.checkpoint_files.append((save_path, metric))
+ self.checkpoint_files = sorted(
+ self.checkpoint_files, key=lambda x: x[1],
+ reverse=not self.decreasing) # sort in descending order if a lower metric is not better
+
+ checkpoints_str = "Current checkpoints:\n"
+ for c in self.checkpoint_files:
+ checkpoints_str += ' {}\n'.format(c)
+ _logger.info(checkpoints_str)
+
+ if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
+ self.best_epoch = epoch
+ self.best_metric = metric
+ best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
+ if os.path.exists(best_save_path):
+ os.unlink(best_save_path)
+ os.link(last_save_path, best_save_path)
+
+ return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
+
+ def _save(self, save_path, epoch, metric=None):
+ save_state = {
+ 'epoch': epoch,
+ 'arch': type(self.model).__name__.lower(),
+ 'state_dict': get_state_dict(self.model, self.unwrap_fn),
+ 'optimizer': self.optimizer.state_dict(),
+ 'version': 2, # version < 2 increments epoch before save
+ }
+ if self.args is not None:
+ save_state['arch'] = self.args.model
+ save_state['args'] = self.args
+ if self.amp_scaler is not None:
+ save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
+ if self.model_ema is not None:
+ save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
+ if metric is not None:
+ save_state['metric'] = metric
+ torch.save(save_state, save_path)
+
+ def _cleanup_checkpoints(self, trim=0):
+ trim = min(len(self.checkpoint_files), trim)
+ delete_index = self.max_history - trim
+ if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
+ return
+ to_delete = self.checkpoint_files[delete_index:]
+ for d in to_delete:
+ try:
+ _logger.debug("Cleaning checkpoint: {}".format(d))
+ os.remove(d[0])
+ except Exception as e:
+ _logger.error("Exception '{}' while deleting checkpoint".format(e))
+ self.checkpoint_files = self.checkpoint_files[:delete_index]
+
+ def save_recovery(self, epoch, batch_idx=0):
+ assert epoch >= 0
+ filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
+ save_path = os.path.join(self.recovery_dir, filename)
+ self._save(save_path, epoch)
+ if os.path.exists(self.last_recovery_file):
+ try:
+ _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
+ os.remove(self.last_recovery_file)
+ except Exception as e:
+ _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
+ self.last_recovery_file = self.curr_recovery_file
+ self.curr_recovery_file = save_path
+
+ def find_recovery(self):
+ recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
+ files = glob.glob(recovery_path + '*' + self.extension)
+ files = sorted(files)
+ return files[0] if len(files) else ''
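
A minimal sketch of the saver's epoch loop; the metric, directories and `max_history` value are placeholders, and `decreasing=False` matches a higher-is-better metric such as top-1 accuracy.

```python
import os

import torch
from timm.utils import CheckpointSaver

os.makedirs('./output', exist_ok=True)
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
saver = CheckpointSaver(
    model, optimizer, checkpoint_dir='./output', recovery_dir='./output',
    decreasing=False, max_history=3)

for epoch in range(5):
    # ... train and validate ...
    top1 = 50.0 + epoch  # stand-in for a real eval metric
    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=top1)
```
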
diff --git a/timm/utils/clip_grad.py b/timm/utils/clip_grad.py
new file mode 100644
index 0000000..7eb4069
--- /dev/null
+++ b/timm/utils/clip_grad.py
@@ -0,0 +1,23 @@
+import torch
+
+from timm.utils.agc import adaptive_clip_grad
+
+
+def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
+ """ Dispatch to gradient clipping method
+
+ Args:
+ parameters (Iterable): model parameters to clip
+ value (float): clipping value/factor/norm, mode dependant
+ mode (str): clipping mode, one of 'norm', 'value', 'agc'
+ norm_type (float): p-norm, default 2.0
+ """
+ if mode == 'norm':
+ torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
+ elif mode == 'value':
+ torch.nn.utils.clip_grad_value_(parameters, value)
+ elif mode == 'agc':
+ adaptive_clip_grad(parameters, value, norm_type=norm_type)
+ else:
+ assert False, f"Unknown clip mode ({mode})."
+
diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py
new file mode 100644
index 0000000..9e7bddf
--- /dev/null
+++ b/timm/utils/cuda.py
@@ -0,0 +1,55 @@
+""" CUDA / AMP utils
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+
+try:
+ from apex import amp
+ has_apex = True
+except ImportError:
+ amp = None
+ has_apex = False
+
+from .clip_grad import dispatch_clip_grad
+
+
+class ApexScaler:
+ state_dict_key = "amp"
+
+ def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward(create_graph=create_graph)
+ if clip_grad is not None:
+ dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode)
+ optimizer.step()
+
+ def state_dict(self):
+ if 'state_dict' in amp.__dict__:
+ return amp.state_dict()
+
+ def load_state_dict(self, state_dict):
+ if 'load_state_dict' in amp.__dict__:
+ amp.load_state_dict(state_dict)
+
+
+class NativeScaler:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
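
A minimal sketch of `NativeScaler` in a mixed-precision step; the model, data and `clip_grad` value are placeholders, and the scaler quietly disables itself when CUDA is unavailable.

```python
import torch
from timm.utils.cuda import NativeScaler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

x, y = torch.randn(8, 16, device=device), torch.randn(8, 4, device=device)
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
    loss = torch.nn.functional.mse_loss(model(x), y)
# Scales the loss, backprops, optionally clips, then steps and updates the scaler.
loss_scaler(loss, optimizer, clip_grad=5.0, parameters=model.parameters())
```
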
diff --git a/timm/utils/distributed.py b/timm/utils/distributed.py
new file mode 100644
index 0000000..3c5dba8
--- /dev/null
+++ b/timm/utils/distributed.py
@@ -0,0 +1,28 @@
+""" Distributed training/validation utils
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+from torch import distributed as dist
+
+from .model import unwrap_model
+
+
+def reduce_tensor(tensor, n):
+ rt = tensor.clone()
+ dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+ rt /= n
+ return rt
+
+
+def distribute_bn(model, world_size, reduce=False):
+ # ensure every node has the same running bn stats
+ for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
+ if ('running_mean' in bn_name) or ('running_var' in bn_name):
+ if reduce:
+ # average bn stats across whole group
+ torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+ bn_buf /= float(world_size)
+ else:
+ # broadcast bn stats from rank 0 to whole group
+ torch.distributed.broadcast(bn_buf, 0)
diff --git a/timm/utils/jit.py b/timm/utils/jit.py
new file mode 100644
index 0000000..185ab7a
--- /dev/null
+++ b/timm/utils/jit.py
@@ -0,0 +1,18 @@
+""" JIT scripting/tracing utils
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import torch
+
+
+def set_jit_legacy():
+ """ Set JIT executor to legacy w/ support for op fusion
+ This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
+    in the JIT executor. These APIs are not supported, so they could change.
+ """
+ #
+ assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
+ torch._C._jit_set_profiling_executor(False)
+ torch._C._jit_set_profiling_mode(False)
+ torch._C._jit_override_can_fuse_on_gpu(True)
+ #torch._C._jit_set_texpr_fuser_enabled(True)
diff --git a/timm/utils/log.py b/timm/utils/log.py
new file mode 100644
index 0000000..c99469e
--- /dev/null
+++ b/timm/utils/log.py
@@ -0,0 +1,28 @@
+""" Logging helpers
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import logging
+import logging.handlers
+
+
+class FormatterNoInfo(logging.Formatter):
+ def __init__(self, fmt='%(levelname)s: %(message)s'):
+ logging.Formatter.__init__(self, fmt)
+
+ def format(self, record):
+ if record.levelno == logging.INFO:
+ return str(record.getMessage())
+ return logging.Formatter.format(self, record)
+
+
+def setup_default_logging(default_level=logging.INFO, log_path=''):
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(FormatterNoInfo())
+ logging.root.addHandler(console_handler)
+ logging.root.setLevel(default_level)
+ if log_path:
+ file_handler = logging.handlers.RotatingFileHandler(log_path, maxBytes=(1024 ** 2 * 2), backupCount=3)
+ file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s")
+ file_handler.setFormatter(file_formatter)
+ logging.root.addHandler(file_handler)
diff --git a/timm/utils/metrics.py b/timm/utils/metrics.py
new file mode 100644
index 0000000..9fdbe13
--- /dev/null
+++ b/timm/utils/metrics.py
@@ -0,0 +1,32 @@
+""" Eval metrics and related
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+
+class AverageMeter:
+ """Computes and stores the average and current value"""
+ def __init__(self):
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ maxk = min(max(topk), output.size()[1])
+ batch_size = target.size(0)
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+ return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
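
A minimal sketch of the two helpers together, with random tensors standing in for real model outputs and labels.

```python
import torch
from timm.utils.metrics import AverageMeter, accuracy

top1_meter = AverageMeter()
output = torch.randn(32, 1000)          # stand-in for model logits
target = torch.randint(0, 1000, (32,))  # stand-in for labels
top1, top5 = accuracy(output, target, topk=(1, 5))
top1_meter.update(top1.item(), n=output.size(0))
print(top1_meter.avg)
```
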
diff --git a/timm/utils/misc.py b/timm/utils/misc.py
new file mode 100644
index 0000000..39c0097
--- /dev/null
+++ b/timm/utils/misc.py
@@ -0,0 +1,18 @@
+""" Misc utils
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import re
+
+
+def natural_key(string_):
+ """See http://www.codinghorror.com/blog/archives/001018.html"""
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def add_bool_arg(parser, name, default=False, help=''):
+ dest_name = name.replace('-', '_')
+ group = parser.add_mutually_exclusive_group(required=False)
+ group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
+ group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
+ parser.set_defaults(**{dest_name: default})
diff --git a/timm/utils/model.py b/timm/utils/model.py
new file mode 100644
index 0000000..b95c453
--- /dev/null
+++ b/timm/utils/model.py
@@ -0,0 +1,273 @@
+""" Model / state_dict utils
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import fnmatch
+
+import torch
+from torchvision.ops.misc import FrozenBatchNorm2d
+
+from .model_ema import ModelEma
+
+
+def unwrap_model(model):
+ if isinstance(model, ModelEma):
+ return unwrap_model(model.ema)
+ else:
+ return model.module if hasattr(model, 'module') else model
+
+
+def get_state_dict(model, unwrap_fn=unwrap_model):
+ return unwrap_fn(model).state_dict()
+
+
+def avg_sq_ch_mean(model, input, output):
+ """ calculate average channel square mean of output activations
+ """
+ return torch.mean(output.mean(axis=[0, 2, 3]) ** 2).item()
+
+
+def avg_ch_var(model, input, output):
+ """ calculate average channel variance of output activations
+ """
+ return torch.mean(output.var(axis=[0, 2, 3])).item()
+
+
+def avg_ch_var_residual(model, input, output):
+ """ calculate average channel variance of output activations
+ """
+ return torch.mean(output.var(axis=[0, 2, 3])).item()
+
+
+class ActivationStatsHook:
+ """Iterates through each of `model`'s modules and matches modules using unix pattern
+ matching based on `hook_fn_locs` and registers `hook_fn` to the module if there is
+ a match.
+
+ Arguments:
+ model (nn.Module): model from which we will extract the activation stats
+ hook_fn_locs (List[str]): List of `hook_fn` locations based on Unix type string
+ matching with the name of model's modules.
+ hook_fns (List[Callable]): List of hook functions to be registered at every
+ module in `layer_names`.
+
+ Inspiration from https://docs.fast.ai/callback.hook.html.
+
+ Refer to https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950 for an example
+    on how to plot Signal Propagation Plots using `ActivationStatsHook`.
+ """
+
+ def __init__(self, model, hook_fn_locs, hook_fns):
+ self.model = model
+ self.hook_fn_locs = hook_fn_locs
+ self.hook_fns = hook_fns
+ if len(hook_fn_locs) != len(hook_fns):
+            raise ValueError("Please provide `hook_fns` for each `hook_fn_locs`, "
+                             "their lengths are different.")
+ self.stats = dict((hook_fn.__name__, []) for hook_fn in hook_fns)
+ for hook_fn_loc, hook_fn in zip(hook_fn_locs, hook_fns):
+ self.register_hook(hook_fn_loc, hook_fn)
+
+ def _create_hook(self, hook_fn):
+ def append_activation_stats(module, input, output):
+ out = hook_fn(module, input, output)
+ self.stats[hook_fn.__name__].append(out)
+
+ return append_activation_stats
+
+ def register_hook(self, hook_fn_loc, hook_fn):
+ for name, module in self.model.named_modules():
+ if not fnmatch.fnmatch(name, hook_fn_loc):
+ continue
+ module.register_forward_hook(self._create_hook(hook_fn))
+
+
+def extract_spp_stats(
+ model,
+ hook_fn_locs,
+ hook_fns,
+ input_shape=[8, 3, 224, 224]):
+ """Extract average square channel mean and variance of activations during
+    forward pass to plot Signal Propagation Plots (SPP).
+
+ Paper: https://arxiv.org/abs/2101.08692
+
+ Example Usage: https://gist.github.com/amaarora/6e56942fcb46e67ba203f3009b30d950
+ """
+ x = torch.normal(0., 1., input_shape)
+ hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
+ _ = model(x)
+ return hook.stats
+
+
+def freeze_batch_norm_2d(module):
+ """
+ Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
+ itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
+ returned. Otherwise, the module is walked recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+ res = FrozenBatchNorm2d(module.num_features)
+ res.num_features = module.num_features
+ res.affine = module.affine
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for name, child in module.named_children():
+ new_child = freeze_batch_norm_2d(child)
+ if new_child is not child:
+ res.add_module(name, new_child)
+ return res
+
+
+def unfreeze_batch_norm_2d(module):
+ """
+    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
+ of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
+ recursively and submodules are converted in place.
+
+ Args:
+ module (torch.nn.Module): Any PyTorch module.
+
+ Returns:
+ torch.nn.Module: Resulting module
+
+ Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
+ """
+ res = module
+ if isinstance(module, FrozenBatchNorm2d):
+ res = torch.nn.BatchNorm2d(module.num_features)
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for name, child in module.named_children():
+ new_child = unfreeze_batch_norm_2d(child)
+ if new_child is not child:
+ res.add_module(name, new_child)
+ return res
+
+
+def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
+ """
+ Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
+ done in place.
+ Args:
+ root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
+ submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as
+ named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
+ means that the whole root module will be (un)frozen. Defaults to []
+ include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers.
+ Defaults to `True`.
+        mode (str): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`.
+ """
+ assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
+
+ if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+ # Raise assertion here because we can't convert it in place
+ raise AssertionError(
+ "You have provided a batch norm layer as the `root module`. Please use "
+ "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.")
+
+ if isinstance(submodules, str):
+ submodules = [submodules]
+
+ named_modules = submodules
+ submodules = [root_module.get_submodule(m) for m in submodules]
+
+ if not len(submodules):
+ named_modules, submodules = list(zip(*root_module.named_children()))
+
+ for n, m in zip(named_modules, submodules):
+ # (Un)freeze parameters
+ for p in m.parameters():
+ p.requires_grad = False if mode == 'freeze' else True
+ if include_bn_running_stats:
+ # Helper to add submodule specified as a named_module
+ def _add_submodule(module, name, submodule):
+ split = name.rsplit('.', 1)
+ if len(split) > 1:
+ module.get_submodule(split[0]).add_module(split[1], submodule)
+ else:
+ module.add_module(name, submodule)
+
+ # Freeze batch norm
+ if mode == 'freeze':
+ res = freeze_batch_norm_2d(m)
+                # It's possible that `m` is a type of BatchNorm in itself, in which case `freeze_batch_norm_2d` won't
+ # convert it in place, but will return the converted result. In this case `res` holds the converted
+ # result and we may try to re-assign the named module
+ if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+ _add_submodule(root_module, n, res)
+ # Unfreeze batch norm
+ else:
+ res = unfreeze_batch_norm_2d(m)
+ # Ditto. See note above in mode == 'freeze' branch
+ if isinstance(m, FrozenBatchNorm2d):
+ _add_submodule(root_module, n, res)
+
+
+def freeze(root_module, submodules=[], include_bn_running_stats=True):
+ """
+ Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
+ Args:
+ root_module (nn.Module): Root module relative to which `submodules` are referenced.
+ submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
+ named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
+ means that the whole root module will be frozen. Defaults to `[]`.
+ include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and
+ `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning,
+            it's good practice to freeze batch norm stats. And note that these are different from the affine parameters
+ which are just normal PyTorch parameters. Defaults to `True`.
+
+ Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`.
+
+ Examples::
+
+ >>> model = timm.create_model('resnet18')
+ >>> # Freeze up to and including layer2
+ >>> submodules = [n for n, _ in model.named_children()]
+ >>> print(submodules)
+ ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc']
+ >>> freeze(model, submodules[:submodules.index('layer2') + 1])
+ >>> # Check for yourself that it works as expected
+ >>> print(model.layer2[0].conv1.weight.requires_grad)
+ False
+ >>> print(model.layer3[0].conv1.weight.requires_grad)
+ True
+ >>> # Unfreeze
+ >>> unfreeze(model)
+ """
+ _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze")
+
+
+def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
+ """
+ Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.
+ Args:
+ root_module (nn.Module): Root module relative to which `submodules` are referenced.
+ submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided
+ as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
+ list means that the whole root module will be unfrozen. Defaults to `[]`.
+ include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers.
+ These will be converted to `BatchNorm2d` in place. Defaults to `True`.
+
+ See example in docstring for `freeze`.
+ """
+ _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
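
A minimal sketch of collecting signal-propagation statistics with the hook helpers above; the stand-in network, module-name patterns and input shape are placeholders, not values used by this repository.

```python
import torch
from timm.utils.model import avg_sq_ch_mean, avg_ch_var, extract_spp_stats

# A small stand-in network; any module producing 4-D conv outputs works for these hooks.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3, stride=2, padding=1), torch.nn.ReLU(),
    torch.nn.Conv2d(8, 16, 3, stride=2, padding=1), torch.nn.ReLU())

stats = extract_spp_stats(
    model,
    hook_fn_locs=['0', '2'],        # module names matched with fnmatch, one per hook fn
    hook_fns=[avg_sq_ch_mean, avg_ch_var])
print(stats)  # {'avg_sq_ch_mean': [...], 'avg_ch_var': [...]}
```
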
diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py
new file mode 100644
index 0000000..073d5c5
--- /dev/null
+++ b/timm/utils/model_ema.py
@@ -0,0 +1,126 @@
+""" Exponential Moving Average (EMA) of model updates
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import logging
+from collections import OrderedDict
+from copy import deepcopy
+
+import torch
+import torch.nn as nn
+
+_logger = logging.getLogger(__name__)
+
+
+class ModelEma:
+ """ Model Exponential Moving Average (DEPRECATED)
+
+ Keep a moving average of everything in the model state_dict (parameters and buffers).
+ This version is deprecated, it does not work with scripted models. Will be removed eventually.
+
+ This is intended to allow functionality like
+ https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+ A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc. that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 require EMA
+ smoothing of weights to match results. Pay attention to the decay constant you are using
+ relative to your update count per epoch.
+
+ To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+ disable validation of the EMA weights. Validation will have to be done manually in a separate
+ process, or after the training stops converging.
+
+    This class is sensitive to where it is initialized in the sequence of model init,
+ GPU assignment and distributed training wrappers.
+ """
+ def __init__(self, model, decay=0.9999, device='', resume=''):
+ # make a copy of the model for accumulating moving average of weights
+ self.ema = deepcopy(model)
+ self.ema.eval()
+ self.decay = decay
+ self.device = device # perform ema on different device from model if set
+ if device:
+ self.ema.to(device=device)
+ self.ema_has_module = hasattr(self.ema, 'module')
+ if resume:
+ self._load_checkpoint(resume)
+ for p in self.ema.parameters():
+ p.requires_grad_(False)
+
+ def _load_checkpoint(self, checkpoint_path):
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ assert isinstance(checkpoint, dict)
+ if 'state_dict_ema' in checkpoint:
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint['state_dict_ema'].items():
+ # ema model may have been wrapped by DataParallel, and need module prefix
+ if self.ema_has_module:
+ name = 'module.' + k if not k.startswith('module') else k
+ else:
+ name = k
+ new_state_dict[name] = v
+ self.ema.load_state_dict(new_state_dict)
+ _logger.info("Loaded state_dict_ema")
+ else:
+ _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+
+ def update(self, model):
+ # correct a mismatch in state dict keys
+ needs_module = hasattr(model, 'module') and not self.ema_has_module
+ with torch.no_grad():
+ msd = model.state_dict()
+ for k, ema_v in self.ema.state_dict().items():
+ if needs_module:
+ k = 'module.' + k
+ model_v = msd[k].detach()
+ if self.device:
+ model_v = model_v.to(device=self.device)
+ ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+
+
+class ModelEmaV2(nn.Module):
+ """ Model Exponential Moving Average V2
+
+ Keep a moving average of everything in the model state_dict (parameters and buffers).
+ V2 of this module is simpler, it does not match params/buffers based on name but simply
+ iterates in order. It works with torchscript (JIT of full model).
+
+ This is intended to allow functionality like
+ https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+ A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc. that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 require EMA
+ smoothing of weights to match results. Pay attention to the decay constant you are using
+ relative to your update count per epoch.
+
+ To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+ disable validation of the EMA weights. Validation will have to be done manually in a separate
+ process, or after the training stops converging.
+
+    This class is sensitive to where it is initialized in the sequence of model init,
+ GPU assignment and distributed training wrappers.
+ """
+ def __init__(self, model, decay=0.9999, device=None):
+ super(ModelEmaV2, self).__init__()
+ # make a copy of the model for accumulating moving average of weights
+ self.module = deepcopy(model)
+ self.module.eval()
+ self.decay = decay
+ self.device = device # perform ema on different device from model if set
+ if self.device is not None:
+ self.module.to(device=device)
+
+ def _update(self, model, update_fn):
+ with torch.no_grad():
+ for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+ if self.device is not None:
+ model_v = model_v.to(device=self.device)
+ ema_v.copy_(update_fn(ema_v, model_v))
+
+ def update(self, model):
+ self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+ def set(self, model):
+ self._update(model, update_fn=lambda e, m: m)
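
A minimal sketch of `ModelEmaV2` in a training loop; the decay value and toy loop are placeholders, and only the smoothed copy in `ema.module` would be used for evaluation or checkpointing.

```python
import torch
from timm.utils.model_ema import ModelEmaV2

model = torch.nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = ModelEmaV2(model, decay=0.999)  # EMA copy kept on the same device here

for step in range(100):
    x, y = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # EMA of parameters and buffers after every optimizer update

# evaluate with the smoothed weights
with torch.no_grad():
    preds = ema.module(torch.randn(2, 16))
```
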
diff --git a/timm/utils/random.py b/timm/utils/random.py
new file mode 100644
index 0000000..a967998
--- /dev/null
+++ b/timm/utils/random.py
@@ -0,0 +1,9 @@
+import random
+import numpy as np
+import torch
+
+
+def random_seed(seed=42, rank=0):
+ torch.manual_seed(seed + rank)
+ np.random.seed(seed + rank)
+ random.seed(seed + rank)
diff --git a/timm/utils/summary.py b/timm/utils/summary.py
new file mode 100644
index 0000000..9f5af9a
--- /dev/null
+++ b/timm/utils/summary.py
@@ -0,0 +1,39 @@
+""" Summary utilities
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import csv
+import os
+from collections import OrderedDict
+try:
+ import wandb
+except ImportError:
+ pass
+
+def get_outdir(path, *paths, inc=False):
+ outdir = os.path.join(path, *paths)
+ if not os.path.exists(outdir):
+ os.makedirs(outdir)
+ elif inc:
+ count = 1
+ outdir_inc = outdir + '-' + str(count)
+ while os.path.exists(outdir_inc):
+ count = count + 1
+ outdir_inc = outdir + '-' + str(count)
+ assert count < 100
+ outdir = outdir_inc
+ os.makedirs(outdir)
+ return outdir
+
+
+def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False, log_wandb=False):
+ rowd = OrderedDict(epoch=epoch)
+ rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
+ rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
+ if log_wandb:
+ wandb.log(rowd)
+ with open(filename, mode='a') as cf:
+ dw = csv.DictWriter(cf, fieldnames=rowd.keys())
+ if write_header: # first iteration (epoch == 1 can't be used)
+ dw.writeheader()
+ dw.writerow(rowd)
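+
+
+if __name__ == '__main__':  # editor's illustration: a small, hedged sketch of the two helpers above
+    # Not part of the original module; the metric values are made up, and in a real run they would
+    # come from the training/validation loops. log_wandb=True additionally requires the optional
+    # `wandb` package to be installed.
+    outdir = get_outdir('./output', 'demo', inc=True)          # ./output/demo, ./output/demo-1, ...
+    summary_csv = os.path.join(outdir, 'summary.csv')
+    first_row = True
+    for epoch in range(1, 4):
+        train_metrics = OrderedDict(loss=1.0 / epoch)
+        eval_metrics = OrderedDict(loss=1.2 / epoch, top1=70.0 + epoch)
+        update_summary(epoch, train_metrics, eval_metrics, summary_csv,
+                       write_header=first_row, log_wandb=False)
+        first_row = False
+    print('wrote', summary_csv)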
diff --git a/timm/version.py b/timm/version.py
new file mode 100644
index 0000000..2b8877c
--- /dev/null
+++ b/timm/version.py
@@ -0,0 +1 @@
+__version__ = '0.5.0'
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..cfd728a
--- /dev/null
+++ b/train.py
@@ -0,0 +1,77 @@
+import time
+import torch
+from options.train_options import TrainOptions
+from data import create_dataset
+from models import create_model
+from util.visualizer import Visualizer
+
+
+if __name__ == '__main__':
+ opt = TrainOptions().parse() # get training options
+ dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
+ dataset_size = len(dataset) # get the number of images in the dataset.
+
+ model = create_model(opt) # create a model given opt.model and other options
+ print('The number of training images = %d' % dataset_size)
+
+ visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
+ opt.visualizer = visualizer
+ total_iters = 0 # the total number of training iterations
+
+ optimize_time = 0.1
+
+ times = []
+    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
+ epoch_start_time = time.time() # timer for entire epoch
+ iter_data_time = time.time() # timer for data loading per iteration
+ epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
+ visualizer.reset() # reset the visualizer: make sure it saves the results to HTML at least once every epoch
+
+ dataset.set_epoch(epoch)
+ for i, data in enumerate(dataset): # inner loop within one epoch
+ iter_start_time = time.time() # timer for computation per iteration
+ if total_iters % opt.print_freq == 0:
+ t_data = iter_start_time - iter_data_time
+
+ batch_size = data["A0"].size(0)
+ total_iters += batch_size
+ epoch_iter += batch_size
+ if len(opt.gpu_ids) > 0:
+ torch.cuda.synchronize()
+ optimize_start_time = time.time()
+ if epoch == opt.epoch_count and i == 0:
+ # model.data_dependent_initialize(data)
+ model.setup(opt) # regular setup: load and print networks; create schedulers
+ model.parallelize()
+ model.set_input(data) # unpack data from dataset and apply preprocessing
+ model.optimize_parameters() # calculate loss functions, get gradients, update network weights
+ if len(opt.gpu_ids) > 0:
+ torch.cuda.synchronize()
+ optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
+
+            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to an HTML file
+ save_result = total_iters % opt.update_html_freq == 0
+ model.compute_visuals()
+ visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+
+ if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
+ losses = model.get_current_losses()
+ visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
+ if opt.display_id is None or opt.display_id > 0:
+ visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
+
+            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
+ print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
+ print(opt.name) # it's useful to occasionally show the experiment name on console
+ save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
+ model.save_networks(save_suffix)
+
+ iter_data_time = time.time()
+
+        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
+ print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
+ model.save_networks('latest')
+ model.save_networks(epoch)
+
+ print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time))
+ model.update_learning_rate() # update learning rates at the end of every epoch.
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000..718f8f6
--- /dev/null
+++ b/util/__init__.py
@@ -0,0 +1,2 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from util import *
diff --git a/util/__pycache__/__init__.cpython-36.pyc b/util/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..2515414
Binary files /dev/null and b/util/__pycache__/__init__.cpython-36.pyc differ
diff --git a/util/__pycache__/html.cpython-36.pyc b/util/__pycache__/html.cpython-36.pyc
new file mode 100644
index 0000000..6fec72e
Binary files /dev/null and b/util/__pycache__/html.cpython-36.pyc differ
diff --git a/util/__pycache__/util.cpython-36.pyc b/util/__pycache__/util.cpython-36.pyc
new file mode 100644
index 0000000..c28b8b4
Binary files /dev/null and b/util/__pycache__/util.cpython-36.pyc differ
diff --git a/util/__pycache__/visualizer.cpython-36.pyc b/util/__pycache__/visualizer.cpython-36.pyc
new file mode 100644
index 0000000..01a528d
Binary files /dev/null and b/util/__pycache__/visualizer.cpython-36.pyc differ
diff --git a/util/get_data.py b/util/get_data.py
new file mode 100644
index 0000000..97edc3c
--- /dev/null
+++ b/util/get_data.py
@@ -0,0 +1,110 @@
+from __future__ import print_function
+import os
+import tarfile
+import requests
+from warnings import warn
+from zipfile import ZipFile
+from bs4 import BeautifulSoup
+from os.path import abspath, isdir, join, basename
+
+
+class GetData(object):
+ """A Python script for downloading CycleGAN or pix2pix datasets.
+
+ Parameters:
+ technique (str) -- One of: 'cyclegan' or 'pix2pix'.
+ verbose (bool) -- If True, print additional information.
+
+ Examples:
+ >>> from util.get_data import GetData
+ >>> gd = GetData(technique='cyclegan')
+ >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
+
+    Alternatively, you can use the bash scripts 'scripts/download_pix2pix_model.sh'
+    and 'scripts/download_cyclegan_model.sh'.
+ """
+
+ def __init__(self, technique='cyclegan', verbose=True):
+ url_dict = {
+ 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
+ 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
+ }
+ self.url = url_dict.get(technique.lower())
+ self._verbose = verbose
+
+ def _print(self, text):
+ if self._verbose:
+ print(text)
+
+ @staticmethod
+ def _get_options(r):
+ soup = BeautifulSoup(r.text, 'lxml')
+ options = [h.text for h in soup.find_all('a', href=True)
+ if h.text.endswith(('.zip', 'tar.gz'))]
+ return options
+
+ def _present_options(self):
+ r = requests.get(self.url)
+ options = self._get_options(r)
+ print('Options:\n')
+ for i, o in enumerate(options):
+ print("{0}: {1}".format(i, o))
+ choice = input("\nPlease enter the number of the "
+ "dataset above you wish to download:")
+ return options[int(choice)]
+
+ def _download_data(self, dataset_url, save_path):
+ if not isdir(save_path):
+ os.makedirs(save_path)
+
+ base = basename(dataset_url)
+ temp_save_path = join(save_path, base)
+
+ with open(temp_save_path, "wb") as f:
+ r = requests.get(dataset_url)
+ f.write(r.content)
+
+ if base.endswith('.tar.gz'):
+ obj = tarfile.open(temp_save_path)
+ elif base.endswith('.zip'):
+ obj = ZipFile(temp_save_path, 'r')
+ else:
+ raise ValueError("Unknown File Type: {0}.".format(base))
+
+ self._print("Unpacking Data...")
+ obj.extractall(save_path)
+ obj.close()
+ os.remove(temp_save_path)
+
+ def get(self, save_path, dataset=None):
+ """
+
+ Download a dataset.
+
+ Parameters:
+ save_path (str) -- A directory to save the data to.
+ dataset (str) -- (optional). A specific dataset to download.
+ Note: this must include the file extension.
+ If None, options will be presented for you
+ to choose from.
+
+ Returns:
+ save_path_full (str) -- the absolute path to the downloaded data.
+
+ """
+ if dataset is None:
+ selected_dataset = self._present_options()
+ else:
+ selected_dataset = dataset
+
+ save_path_full = join(save_path, selected_dataset.split('.')[0])
+
+ if isdir(save_path_full):
+ warn("\n'{0}' already exists. Voiding Download.".format(
+ save_path_full))
+ else:
+ self._print('Downloading Data...')
+ url = "{0}/{1}".format(self.url, selected_dataset)
+ self._download_data(url, save_path=save_path)
+
+ return abspath(save_path_full)
diff --git a/util/html.py b/util/html.py
new file mode 100644
index 0000000..cc3262a
--- /dev/null
+++ b/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+ """This HTML class allows us to save images and write texts into a single HTML file.
+
+    It consists of functions such as <add_header> (add a text header to the HTML file),
+    <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+ """
+
+ def __init__(self, web_dir, title, refresh=0):
+        """Initialize the HTML class
+
+        Parameters:
+            web_dir (str) -- a directory that stores the webpage. The HTML file will be created at <web_dir>/index.html; images will be saved under <web_dir>/images/
+            title (str)   -- the webpage name
+            refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+        """
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        os.makedirs(self.web_dir, exist_ok=True)
+        os.makedirs(self.img_dir, exist_ok=True)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ """Return the directory that stores images"""
+ return self.img_dir
+
+ def add_header(self, text):
+ """Insert a header to the HTML file
+
+ Parameters:
+ text (str) -- the header text
+ """
+ with self.doc:
+ h3(text)
+
+ def add_images(self, ims, txts, links, width=400):
+ """add images to the HTML file
+
+ Parameters:
+ ims (str list) -- a list of image paths
+ txts (str list) -- a list of image names shown on the website
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+ """
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
+ self.doc.add(self.t)
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+        """save the current content to the HTML file"""
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__': # we show an example usage here.
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims, txts, links = [], [], []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100644
index 0000000..6d086f8
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,54 @@
+import random
+import torch
+
+
+class ImagePool():
+ """This class implements an image buffer that stores previously generated images.
+
+ This buffer enables us to update discriminators using a history of generated images
+ rather than the ones produced by the latest generators.
+ """
+
+ def __init__(self, pool_size):
+ """Initialize the ImagePool class
+
+ Parameters:
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0: # create an empty pool
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ """Return an image from the pool.
+
+ Parameters:
+ images: the latest generated images from the generator
+
+ Returns images from the buffer.
+
+        For each input image: if the buffer is not yet full, the image is stored and returned as-is.
+        Otherwise, with probability 0.5 a previously stored image is returned and replaced by the
+        current image, and with probability 0.5 the current image is returned unchanged.
+ """
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
+ return images
+ return_images = []
+ for image in images:
+ image = torch.unsqueeze(image.data, 0)
+ if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else: # by another 50% chance, the buffer will return the current image
+ return_images.append(image)
+ return_images = torch.cat(return_images, 0) # collect all the images and return
+ return return_images
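+
+
+if __name__ == '__main__':  # editor's illustration: a minimal, hedged sketch of how ImagePool is typically used
+    # Not part of the original module. In CycleGAN-style training the discriminator is updated on a mix
+    # of freshly generated images and images replayed from this buffer; the random batch below stands in
+    # for generator output.
+    pool = ImagePool(pool_size=50)
+    for step in range(5):
+        fake_batch = torch.randn(4, 3, 8, 8)          # pretend these came from the generator
+        fake_for_D = pool.query(fake_batch)           # fresh or replayed images (replay starts once the pool is full)
+        assert fake_for_D.shape == fake_batch.shape   # query() returns one image per input image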
diff --git a/util/util.py b/util/util.py
new file mode 100644
index 0000000..5702d37
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,166 @@
+"""This module contains simple helper functions """
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import os
+import importlib
+import argparse
+from argparse import Namespace
+import torchvision
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
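+
+# Editor's note (illustration, not part of the original options code): str2bool is meant to be passed
+# as the `type=` of an argparse flag so that `--my_flag`, `--my_flag yes`, `--my_flag False`, etc. all
+# parse cleanly to a Python bool. The flag name below is made up; a hedged sketch of the registration:
+#
+#   parser = argparse.ArgumentParser()
+#   parser.add_argument('--my_flag', type=str2bool, nargs='?', const=True, default=False)
+#   opts = parser.parse_args(['--my_flag', 'yes'])   # opts.my_flag == True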
+
+
+def copyconf(default_opt, **kwargs):
+ conf = Namespace(**vars(default_opt))
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+ return conf
+
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace('_', '').lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
+
+ return cls
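+
+# Editor's note (illustration): the lookup above is case-insensitive and ignores underscores, which is
+# how an option string such as a dataset or model name gets mapped to its class. For example, against
+# this repository's own util.image_pool module:
+#
+#   cls = find_class_in_module('image_pool', 'util.image_pool')   # matches class ImagePool
+#   pool = cls(pool_size=50)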
+
+
+def tensor2im(input_image, imtype=np.uint8):
+    """Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
+ if image_numpy.shape[0] == 1: # grayscale to RGB
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+    """Calculate and print the mean absolute value of the gradients
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+        image_numpy (numpy array) -- input numpy array
+        image_path (str) -- the path of the image
+        aspect_ratio (float) -- stretch the saved image by this ratio; 1.0 keeps the original shape
+
+ image_pil = Image.fromarray(image_numpy)
+ h, w, _ = image_numpy.shape
+
+ if aspect_ratio is None:
+ pass
+ elif aspect_ratio > 1.0:
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+ elif aspect_ratio < 1.0:
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+        val (bool) -- whether to print the values of the numpy array
+        shp (bool) -- whether to print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it didn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i, :1]
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+ one_np = one_np[:, :, 0]
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+ resized_t = torch.from_numpy(np.array(one_image)).long()
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.BICUBIC):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i:i + 1]
+        one_image = Image.fromarray(tensor2im(one_t)).resize(size, mode)  # honor the requested resampling mode
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
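+
+
+if __name__ == '__main__':  # editor's illustration: a short, hedged sketch of the tensor/image helpers above
+    # Not part of the original module. A fake generator output in [-1, 1] (the range tensor2im assumes)
+    # is converted to an HxWx3 uint8 array and written to disk; the output path is arbitrary.
+    fake = torch.rand(1, 3, 64, 64) * 2 - 1           # N x C x H x W in [-1, 1]
+    arr = tensor2im(fake)                             # -> 64 x 64 x 3 uint8 array in [0, 255]
+    mkdir('./tmp_demo')
+    save_image(arr, './tmp_demo/fake.png')
+    print_numpy(arr, val=True, shp=True)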
diff --git a/util/visualizer.py b/util/visualizer.py
new file mode 100644
index 0000000..c17f2c6
--- /dev/null
+++ b/util/visualizer.py
@@ -0,0 +1,242 @@
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+
+if sys.version_info[0] == 2:
+ VisdomExceptionBase = Exception
+else:
+ VisdomExceptionBase = ConnectionError
+
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ """Save images to the disk.
+
+ Parameters:
+        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+ image_path (str) -- the string is used to create image paths
+ aspect_ratio (float) -- the aspect ratio of saved images
+ width (int) -- the images will be resized to width x width
+
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+ """
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ image_name = '%s/%s.png' % (label, name)
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=width)
+
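+# Editor's note (illustration, not code used by the training scripts): save_images() pairs naturally
+# with util/html.py. A hedged sketch, with made-up tensors and paths and assuming torch is imported,
+# of dumping one result page:
+#
+#   from collections import OrderedDict
+#   webpage = html.HTML('./results_demo', 'demo page')
+#   visuals = OrderedDict(real_A=torch.rand(1, 3, 64, 64) * 2 - 1,
+#                         fake_B=torch.rand(1, 3, 64, 64) * 2 - 1)
+#   save_images(webpage, visuals, ['./datasets/demo/frame_0001.png'], width=256)
+#   webpage.save()                      # writes ./results_demo/index.html plus the saved png files
+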
+
+class Visualizer():
+ """This class includes several functions that can display/save images and print/save logging information.
+
+ It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+ """
+
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: connect to a visdom server
+        Step 3: create an HTML object for saving HTML files
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ if opt.display_id is None:
+ self.display_id = np.random.randint(100000) * 10 # just a random display id
+ else:
+ self.display_id = opt.display_id
+ self.use_html = opt.isTrain and not opt.no_html
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.port = opt.display_port
+ self.saved = False
+        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
+ import visdom
+ self.plot_data = {}
+ self.ncols = opt.display_ncols
+ if "tensorboard_base_url" not in os.environ:
+ self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
+ else:
+ self.vis = visdom.Visdom(port=2004,
+ base_url=os.environ['tensorboard_base_url'] + '/visdom')
+ if not self.vis.check_connection():
+ self.create_visdom_connections()
+
+        if self.use_html:  # create an HTML object at <checkpoints_dir>/<name>/web/; images will be saved under <checkpoints_dir>/<name>/web/images/
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ def reset(self):
+ """Reset the self.saved status"""
+ self.saved = False
+
+ def create_visdom_connections(self):
+        """If the program could not connect to the Visdom server, this function will start a new server at port <self.port>"""
+ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
+ print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
+ print('Command: %s' % cmd)
+ Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
+
+ def display_current_results(self, visuals, epoch, save_result):
+ """Display current results on visdom; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ epoch (int) - - the current epoch
+ save_result (bool) - - if save the current results to an HTML file
+ """
+ if self.display_id > 0: # show images in the browser using visdom
+ ncols = self.ncols
+ if ncols > 0: # show all the images in one visdom panel
+ ncols = min(ncols, len(visuals))
+ h, w = next(iter(visuals.values())).shape[:2]
+                table_css = """<style>
+                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
+                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
+                        </style>""" % (w, h)  # create a table css
+ # create a table of images.
+ title = self.name
+ label_html = ''
+ label_html_row = ''
+ images = []
+ idx = 0
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+                    label_html_row += '<td>%s</td>' % label
+ images.append(image_numpy.transpose([2, 0, 1]))
+ idx += 1
+ if idx % ncols == 0:
+                        label_html += '<tr>%s</tr>' % label_html_row
+                        label_html_row = ''
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
+ while idx % ncols != 0:
+ images.append(white_image)
+                    label_html_row += '<td></td>'
+ idx += 1
+ if label_html_row != '':
+                    label_html += '<tr>%s</tr>' % label_html_row
+ try:
+ self.vis.images(images, ncols, 2, self.display_id + 1,
+ None, dict(title=title + ' images'))
+                    label_html = '<table>%s</table>' % label_html
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
+ opts=dict(title=title + ' labels'))
+ except VisdomExceptionBase:
+ self.create_visdom_connections()
+
+ else: # show each image in a separate visdom panel;
+ idx = 1
+ try:
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ self.vis.image(
+ image_numpy.transpose([2, 0, 1]),
+ self.display_id + idx,
+ None,
+ dict(title=label)
+ )
+ idx += 1
+ except VisdomExceptionBase:
+ self.create_visdom_connections()
+
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
+ self.saved = True
+ # save images to the disk
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims, txts, links = [], [], []
+
+                for label, image in visuals.items():
+                    image_numpy = util.tensor2im(image)
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ def plot_current_losses(self, epoch, counter_ratio, losses):
+ """display the current losses on visdom display: dictionary of error labels and values
+
+ Parameters:
+ epoch (int) -- current epoch
+ counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ """
+ if len(losses) == 0:
+ return
+
+ plot_name = '_'.join(list(losses.keys()))
+
+ if plot_name not in self.plot_data:
+ self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}
+
+ plot_data = self.plot_data[plot_name]
+ plot_id = list(self.plot_data.keys()).index(plot_name)
+
+ plot_data['X'].append(epoch + counter_ratio)
+ plot_data['Y'].append([losses[k] for k in plot_data['legend']])
+ try:
+ self.vis.line(
+ X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
+ Y=np.array(plot_data['Y']),
+ opts={
+ 'title': self.name,
+ 'legend': plot_data['legend'],
+ 'xlabel': 'epoch',
+ 'ylabel': 'loss'},
+ win=self.display_id - plot_id)
+ except VisdomExceptionBase:
+ self.create_visdom_connections()
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message