Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge pull request #323 from rwightman/imagenet21k_datasets_more

BiT (Big Transfer) ResNetV2 models, Official ViT Hybrid R50 weights, VIT IN21K weights updated w/ repr layer, ImageNet21k and dataset / parser refactor

Commit 9a38416fbd

README.md (19 lines changed)
@@ -2,6 +2,19 @@

## What's New

### Jan 25, 2021
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre-logits) support
  * NOTE: ImageNet-21k classifier heads were zeroed in the original weights, so they are only useful for transfer learning
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset-specific parser classes (see the usage sketch after this list)
* Add TensorFlow Datasets (TFDS) wrapper to allow use of TFDS image classification sets with the train script
  * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
* Add improved .tar dataset parser that reads images from a .tar, a folder of .tar files, or a .tar within a .tar
  * Run validation on full ImageNet-21k directly from a tar w/ a BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
* Models in this update should be stable, with the possible exception of ViT/BiT; there is a possibility of some regressions in the train/val scripts and dataset handling
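
The refactored pieces above compose roughly as follows; a minimal sketch (paths and hyper-parameters are illustrative, exact factory signatures are as shown in the file diffs further down this page):

```python
# Sketch: build a dataset + loader with the refactored timm.data factories (illustrative values)
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.models import create_model

model = create_model('resnet50', pretrained=True, num_classes=37)
config = resolve_data_config({}, model=model)

# 'tfds/...' routes to the TFDS-backed IterableImageDataset; a plain folder or .tar path
# routes to ImageDataset via the new parser factory
dataset = create_dataset('tfds/oxford_iiit_pet', root='/data/tfds', split='test', batch_size=256)
loader = create_loader(dataset, input_size=config['input_size'], batch_size=256, is_training=False)

for images, targets in loader:
    ...  # validation / fine-tuning steps
```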

### Jan 3, 2021
* Add SE-ResNet-152D weights
  * 256x256 val, 0.94 crop top-1 - 83.75

@@ -130,7 +143,9 @@ All model architecture families include variants with pretrained weights. The ar

A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).

* Big Transfer ResNetV2 (BiT) - https://arxiv.org/abs/1912.11370
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
* DenseNet - https://arxiv.org/abs/1608.06993
* DLA - https://arxiv.org/abs/1707.06484
* DPN (Dual-Path Network) - https://arxiv.org/abs/1707.01629

@@ -242,6 +257,10 @@ One of the greatest assets of PyTorch is the community and their contributions.

* Albumentations - https://github.com/albumentations-team/albumentations
* Kornia - https://github.com/kornia/kornia

### Knowledge Distillation
* RepDistiller - https://github.com/HobbitLong/RepDistiller
* torchdistill - https://github.com/yoshitomo-matsubara/torchdistill

### Metric Learning
* PyTorch Metric Learning - https://github.com/KevinMusgrave/pytorch-metric-learning

@@ -10,6 +10,10 @@ Most included models have pretrained weights. The weights are either:

The validation results for the pretrained weights can be found [here](results.md)

## Big Transfer ResNetV2 (BiT) [[resnetv2.py](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/resnetv2.py)]
* Paper: `Big Transfer (BiT): General Visual Representation Learning` - https://arxiv.org/abs/1912.11370
* Reference code: https://github.com/google-research/big_transfer

## Cross-Stage Partial Networks [[cspnet.py](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/cspnet.py)]
* Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
* Reference impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

@@ -13,7 +13,7 @@ import numpy as np
import torch

from timm.models import create_model, apply_test_time_pool
from timm.data import Dataset, create_loader, resolve_data_config
from timm.data import ImageDataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging

torch.backends.cudnn.benchmark = True

@@ -83,7 +83,7 @@ def main():
    model = model.cuda()

    loader = create_loader(
        Dataset(args.data),
        ImageDataset(args.data),
        input_size=config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=True,
results/imagenet21k_goog_synsets.txt (new file, 21843 lines; diff suppressed because it is too large)
@ -13,11 +13,16 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
torch._C._jit_set_profiling_executor(True)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
# transformer models don't support many of the spatial / feature based model functionalities
|
||||
NON_STD_FILTERS = ['vit_*']
|
||||
|
||||
# exclude models that cause specific test failures
|
||||
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
|
||||
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
|
||||
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', 'vit_*']
|
||||
EXCLUDE_FILTERS = ['*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm'] + NON_STD_FILTERS
|
||||
else:
|
||||
EXCLUDE_FILTERS = ['vit_*']
|
||||
EXCLUDE_FILTERS = NON_STD_FILTERS
|
||||
|
||||
MAX_FWD_SIZE = 384
|
||||
MAX_BWD_SIZE = 128
|
||||
MAX_FWD_FEAT_SIZE = 448
|
||||
@ -68,7 +73,7 @@ def test_model_backward(model_name, batch_size):
|
||||
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=['vit_*']))
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_default_cfgs(model_name, batch_size):
|
||||
"""Run a single forward pass with each model"""
|
||||
@ -121,7 +126,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
|
||||
create_model(model_name, pretrained=True, in_chans=in_chans)
|
||||
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=['vit_*']))
|
||||
@pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_features_pretrained(model_name, batch_size):
|
||||
"""Create that pretrained weights load when features_only==True."""
|
||||
|
@ -1,10 +1,12 @@
|
||||
from .constants import *
|
||||
from .config import resolve_data_config
|
||||
from .dataset import Dataset, DatasetTar, AugMixDataset
|
||||
from .transforms import *
|
||||
from .loader import create_loader
|
||||
from .transforms_factory import create_transform
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
from .config import resolve_data_config
|
||||
from .constants import *
|
||||
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||
from .dataset_factory import create_dataset
|
||||
from .loader import create_loader
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .parsers import create_parser
|
||||
from .real_labels import RealLabelsImagenet
|
||||
from .transforms import *
|
||||
from .transforms_factory import create_transform
|
@ -3,172 +3,106 @@
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch.utils.data as data
|
||||
|
||||
import os
|
||||
import re
|
||||
import torch
|
||||
import tarfile
|
||||
import logging
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .parsers import create_parser
|
||||
|
||||
IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def natural_key(string_):
|
||||
"""See http://www.codinghorror.com/blog/archives/001018.html"""
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
_ERROR_RETRY = 50
|
||||
|
||||
|
||||
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
|
||||
labels = []
|
||||
filenames = []
|
||||
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
||||
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
|
||||
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
|
||||
for f in files:
|
||||
base, ext = os.path.splitext(f)
|
||||
if ext.lower() in types:
|
||||
filenames.append(os.path.join(root, f))
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
# building class index
|
||||
unique_labels = set(labels)
|
||||
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
||||
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
|
||||
if sort:
|
||||
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
|
||||
return images_and_targets, class_to_idx
|
||||
|
||||
|
||||
def load_class_map(filename, root=''):
|
||||
class_map_path = filename
|
||||
if not os.path.exists(class_map_path):
|
||||
class_map_path = os.path.join(root, filename)
|
||||
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
|
||||
class_map_ext = os.path.splitext(filename)[-1].lower()
|
||||
if class_map_ext == '.txt':
|
||||
with open(class_map_path) as f:
|
||||
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
|
||||
else:
|
||||
assert False, 'Unsupported class map extension'
|
||||
return class_to_idx
|
||||
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
class ImageDataset(data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
parser=None,
|
||||
class_map='',
|
||||
load_bytes=False,
|
||||
transform=None,
|
||||
class_map=''):
|
||||
|
||||
class_to_idx = None
|
||||
if class_map:
|
||||
class_to_idx = load_class_map(class_map, root)
|
||||
images, class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
||||
if len(images) == 0:
|
||||
raise RuntimeError(f'Found 0 images in subfolders of {root}. '
|
||||
f'Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
|
||||
self.root = root
|
||||
self.samples = images
|
||||
self.imgs = self.samples # torchvision ImageFolder compat
|
||||
self.class_to_idx = class_to_idx
|
||||
):
|
||||
if parser is None or isinstance(parser, str):
|
||||
parser = create_parser(parser or '', root=root, class_map=class_map)
|
||||
self.parser = parser
|
||||
self.load_bytes = load_bytes
|
||||
self.transform = transform
|
||||
self._consecutive_errors = 0
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target = self.samples[index]
|
||||
img = open(path, 'rb').read() if self.load_bytes else Image.open(path).convert('RGB')
|
||||
img, target = self.parser[index]
|
||||
try:
|
||||
img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
|
||||
except Exception as e:
|
||||
_logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
|
||||
self._consecutive_errors += 1
|
||||
if self._consecutive_errors < _ERROR_RETRY:
|
||||
return self.__getitem__((index + 1) % len(self.parser))
|
||||
else:
|
||||
raise e
|
||||
self._consecutive_errors = 0
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if target is None:
|
||||
target = torch.zeros(1).long()
|
||||
target = torch.tensor(-1, dtype=torch.long)
|
||||
return img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
return len(self.parser)
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0]
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
elif not absolute:
|
||||
filename = os.path.relpath(filename, self.root)
|
||||
return filename
|
||||
return self.parser.filename(index, basename, absolute)
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
fn = lambda x: x
|
||||
if basename:
|
||||
fn = os.path.basename
|
||||
elif not absolute:
|
||||
fn = lambda x: os.path.relpath(x, self.root)
|
||||
return [fn(x[0]) for x in self.samples]
|
||||
return self.parser.filenames(basename, absolute)
|
||||
|
||||
|
||||
def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
|
||||
files = []
|
||||
labels = []
|
||||
for ti in tarfile.getmembers():
|
||||
if not ti.isfile():
|
||||
continue
|
||||
dirname, basename = os.path.split(ti.path)
|
||||
label = os.path.basename(dirname)
|
||||
ext = os.path.splitext(basename)[1]
|
||||
if ext.lower() in IMG_EXTENSIONS:
|
||||
files.append(ti)
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
unique_labels = set(labels)
|
||||
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
||||
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
|
||||
if sort:
|
||||
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
|
||||
return tarinfo_and_targets, class_to_idx
|
||||
class IterableImageDataset(data.IterableDataset):
|
||||
|
||||
|
||||
class DatasetTar(data.Dataset):
|
||||
|
||||
def __init__(self, root, load_bytes=False, transform=None, class_map=''):
|
||||
|
||||
class_to_idx = None
|
||||
if class_map:
|
||||
class_to_idx = load_class_map(class_map, root)
|
||||
assert os.path.isfile(root)
|
||||
self.root = root
|
||||
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
|
||||
self.samples, self.class_to_idx = _extract_tar_info(tf, class_to_idx)
|
||||
self.imgs = self.samples
|
||||
self.tarfile = None # lazy init in __getitem__
|
||||
self.load_bytes = load_bytes
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
parser=None,
|
||||
split='train',
|
||||
is_training=False,
|
||||
batch_size=None,
|
||||
class_map='',
|
||||
load_bytes=False,
|
||||
transform=None,
|
||||
):
|
||||
assert parser is not None
|
||||
if isinstance(parser, str):
|
||||
self.parser = create_parser(
|
||||
parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
|
||||
else:
|
||||
self.parser = parser
|
||||
self.transform = transform
|
||||
self._consecutive_errors = 0
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.tarfile is None:
|
||||
self.tarfile = tarfile.open(self.root)
|
||||
tarinfo, target = self.samples[index]
|
||||
iob = self.tarfile.extractfile(tarinfo)
|
||||
img = iob.read() if self.load_bytes else Image.open(iob).convert('RGB')
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if target is None:
|
||||
target = torch.zeros(1).long()
|
||||
return img, target
|
||||
def __iter__(self):
|
||||
for img, target in self.parser:
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if target is None:
|
||||
target = torch.tensor(-1, dtype=torch.long)
|
||||
yield img, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
if hasattr(self.parser, '__len__'):
|
||||
return len(self.parser)
|
||||
else:
|
||||
return 0
|
||||
|
||||
def filename(self, index, basename=False):
|
||||
filename = self.samples[index][0].name
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
return filename
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
assert False, 'Filename lookup by index not supported, use filenames().'
|
||||
|
||||
def filenames(self, basename=False):
|
||||
fn = os.path.basename if basename else lambda x: x
|
||||
return [fn(x[0].name) for x in self.samples]
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return self.parser.filenames(basename, absolute)
|
||||
|
||||
|
||||
class AugMixDataset(torch.utils.data.Dataset):
|
||||
|
timm/data/dataset_factory.py (new file, 29 lines)
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from .dataset import IterableImageDataset, ImageDataset
|
||||
|
||||
|
||||
def _search_split(root, split):
|
||||
# look for sub-folder with name of split in root and use that if it exists
|
||||
split_name = split.split('[')[0]
|
||||
try_root = os.path.join(root, split_name)
|
||||
if os.path.exists(try_root):
|
||||
return try_root
|
||||
if split_name == 'validation':
|
||||
try_root = os.path.join(root, 'val')
|
||||
if os.path.exists(try_root):
|
||||
return try_root
|
||||
return root
|
||||
|
||||
|
||||
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
|
||||
name = name.lower()
|
||||
if name.startswith('tfds'):
|
||||
ds = IterableImageDataset(
|
||||
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
|
||||
else:
|
||||
# FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
|
||||
if search_split and os.path.isdir(root):
|
||||
root = _search_split(root, split)
|
||||
ds = ImageDataset(root, parser=name, **kwargs)
|
||||
return ds
|
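
For the non-TFDS path, `create_dataset` defers to the parser factory and `_search_split` only remaps the root when a matching split sub-folder exists. A small sketch of the intended behaviour (the directory layout is hypothetical):

```python
# Sketch of create_dataset root/split handling (hypothetical layout):
#
#   /data/imagenet/
#       train/ ...
#       val/   ...
#
# split='train'      -> root becomes /data/imagenet/train  (sub-folder found)
# split='validation' -> root becomes /data/imagenet/val    ('validation' falls back to 'val')
# split='test'       -> root stays   /data/imagenet        (no matching sub-folder)
from timm.data import create_dataset

train_ds = create_dataset('', root='/data/imagenet', split='train', is_training=True)
val_ds = create_dataset('', root='/data/imagenet', split='validation')
```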
@ -153,7 +153,8 @@ def create_loader(
|
||||
pin_memory=False,
|
||||
fp16=False,
|
||||
tf_preprocessing=False,
|
||||
use_multi_epochs_loader=False
|
||||
use_multi_epochs_loader=False,
|
||||
persistent_workers=True,
|
||||
):
|
||||
re_num_splits = 0
|
||||
if re_split:
|
||||
@ -183,7 +184,7 @@ def create_loader(
|
||||
)
|
||||
|
||||
sampler = None
|
||||
if distributed:
|
||||
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
|
||||
if is_training:
|
||||
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
else:
|
||||
@ -199,16 +200,20 @@ def create_loader(
|
||||
if use_multi_epochs_loader:
|
||||
loader_class = MultiEpochsDataLoader
|
||||
|
||||
loader = loader_class(
|
||||
dataset,
|
||||
loader_args = dict(
|
||||
batch_size=batch_size,
|
||||
shuffle=sampler is None and is_training,
|
||||
shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
|
||||
num_workers=num_workers,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
pin_memory=pin_memory,
|
||||
drop_last=is_training,
|
||||
)
|
||||
persistent_workers=persistent_workers)
|
||||
try:
|
||||
loader = loader_class(dataset, **loader_args)
|
||||
except TypeError as e:
|
||||
loader_args.pop('persistent_workers') # only in Pytorch 1.7+
|
||||
loader = loader_class(dataset, **loader_args)
|
||||
if use_prefetcher:
|
||||
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
|
||||
loader = PrefetchLoader(
|
||||
|
timm/data/parsers/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
|
||||
from .parser_factory import create_parser
|
timm/data/parsers/class_map.py (new file, 16 lines)
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
|
||||
def load_class_map(filename, root=''):
|
||||
class_map_path = filename
|
||||
if not os.path.exists(class_map_path):
|
||||
class_map_path = os.path.join(root, filename)
|
||||
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
|
||||
class_map_ext = os.path.splitext(filename)[-1].lower()
|
||||
if class_map_ext == '.txt':
|
||||
with open(class_map_path) as f:
|
||||
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
|
||||
else:
|
||||
assert False, 'Unsupported class map extension'
|
||||
return class_to_idx
|
||||
|
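
`load_class_map` expects a plain text file with one class name per line; the line index becomes the class index. For example (file contents are illustrative):

```python
# Sketch: given a class_map.txt containing
#   n01440764
#   n01443537
#   n01484850
# load_class_map returns {'n01440764': 0, 'n01443537': 1, 'n01484850': 2}
from timm.data.parsers.class_map import load_class_map

class_to_idx = load_class_map('class_map.txt', root='/data')
```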
timm/data/parsers/constants.py (new file, 1 line)
@@ -0,0 +1 @@
|
||||
IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
|
timm/data/parsers/parser.py (new file, 17 lines)
@@ -0,0 +1,17 @@
|
||||
from abc import abstractmethod
|
||||
|
||||
|
||||
class Parser:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
pass
|
||||
|
||||
def filename(self, index, basename=False, absolute=False):
|
||||
return self._filename(index, basename=basename, absolute=absolute)
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
return [self._filename(index, basename=basename, absolute=absolute) for index in range(len(self))]
|
||||
|
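
The `Parser` base class only fixes the filename interface; concrete parsers (see `ParserImageFolder` below) are additionally expected to implement `__getitem__` returning an `(open file or file-like object, target)` pair and `__len__`. A minimal hypothetical subclass, as a sketch:

```python
# Sketch: a custom parser for a flat, unlabeled folder of images (hypothetical)
import os

from timm.data.parsers.parser import Parser


class ParserFlatFolder(Parser):
    def __init__(self, root):
        super().__init__()
        self.root = root
        self.samples = sorted(
            os.path.join(root, f) for f in os.listdir(root)
            if os.path.splitext(f)[1].lower() in ('.png', '.jpg', '.jpeg'))

    def __getitem__(self, index):
        # (file-like, target); a None target is replaced with a dummy label by ImageDataset
        return open(self.samples[index], 'rb'), None

    def __len__(self):
        return len(self.samples)

    def _filename(self, index, basename=False, absolute=False):
        filename = self.samples[index]
        if basename:
            filename = os.path.basename(filename)
        elif not absolute:
            filename = os.path.relpath(filename, self.root)
        return filename
```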
timm/data/parsers/parser_factory.py (new file, 29 lines)
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from .parser_image_folder import ParserImageFolder
|
||||
from .parser_image_tar import ParserImageTar
|
||||
from .parser_image_in_tar import ParserImageInTar
|
||||
|
||||
|
||||
def create_parser(name, root, split='train', **kwargs):
|
||||
name = name.lower()
|
||||
name = name.split('/', 2)
|
||||
prefix = ''
|
||||
if len(name) > 1:
|
||||
prefix = name[0]
|
||||
name = name[-1]
|
||||
|
||||
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
|
||||
# explicitly select other options shortly
|
||||
if prefix == 'tfds':
|
||||
from .parser_tfds import ParserTfds # defer tensorflow import
|
||||
parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
|
||||
else:
|
||||
assert os.path.exists(root)
|
||||
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
|
||||
# FIXME support split here, in parser?
|
||||
if os.path.isfile(root) and os.path.splitext(root)[1] == '.tar':
|
||||
parser = ParserImageInTar(root, **kwargs)
|
||||
else:
|
||||
parser = ParserImageFolder(root, **kwargs)
|
||||
return parser
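
A short sketch of how the factory resolves the different parser types (paths and dataset names are illustrative):

```python
# Sketch: create_parser dispatch (illustrative paths)
from timm.data.parsers import create_parser

# folder of class sub-folders -> ParserImageFolder
folder_parser = create_parser('', root='/data/imagenet/train')

# a .tar file path -> ParserImageInTar (which also handles folders of tars and
# tars-of-tars when constructed directly)
tar_parser = create_parser('', root='/data/fall11_whole.tar')

# 'tfds/' prefix -> ParserTfds (TensorFlow Datasets backend)
tfds_parser = create_parser('tfds/oxford_iiit_pet', root='/data/tfds', split='train')
```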
timm/data/parsers/parser_image_folder.py (new file, 69 lines)
@@ -0,0 +1,69 @@
|
||||
""" A dataset parser that reads images from folders
|
||||
|
||||
Folders are scanned recursively to find image files. Labels are based
|
||||
on the folder hierarchy, just leaf folders by default.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
|
||||
|
||||
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
|
||||
labels = []
|
||||
filenames = []
|
||||
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
|
||||
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
|
||||
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
|
||||
for f in files:
|
||||
base, ext = os.path.splitext(f)
|
||||
if ext.lower() in types:
|
||||
filenames.append(os.path.join(root, f))
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
# building class index
|
||||
unique_labels = set(labels)
|
||||
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
||||
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
|
||||
if sort:
|
||||
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
|
||||
return images_and_targets, class_to_idx
|
||||
|
||||
|
||||
class ParserImageFolder(Parser):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
root,
|
||||
class_map=''):
|
||||
super().__init__()
|
||||
|
||||
self.root = root
|
||||
class_to_idx = None
|
||||
if class_map:
|
||||
class_to_idx = load_class_map(class_map, root)
|
||||
self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
|
||||
if len(self.samples) == 0:
|
||||
raise RuntimeError(
|
||||
f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
|
||||
|
||||
def __getitem__(self, index):
|
||||
path, target = self.samples[index]
|
||||
return open(path, 'rb'), target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0]
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
elif not absolute:
|
||||
filename = os.path.relpath(filename, self.root)
|
||||
return filename
|
timm/data/parsers/parser_image_in_tar.py (new file, 222 lines)
@@ -0,0 +1,222 @@
|
||||
""" A dataset parser that reads tarfile based datasets
|
||||
|
||||
This parser can read and extract image samples from:
|
||||
* a single tar of image files
|
||||
* a folder of multiple tarfiles containing imagefiles
|
||||
* a tar of tars containing image files
|
||||
|
||||
Labels are based on the combined folder and/or tar name structure.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import tarfile
|
||||
import pickle
|
||||
import logging
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
from typing import List, Dict
|
||||
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
|
||||
|
||||
|
||||
class TarState:
|
||||
|
||||
def __init__(self, tf: tarfile.TarFile = None, ti: tarfile.TarInfo = None):
|
||||
self.tf: tarfile.TarFile = tf
|
||||
self.ti: tarfile.TarInfo = ti
|
||||
self.children: Dict[str, TarState] = {} # child states (tars within tars)
|
||||
|
||||
def reset(self):
|
||||
self.tf = None
|
||||
|
||||
|
||||
def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
|
||||
sample_count = 0
|
||||
for i, ti in enumerate(tf):
|
||||
if not ti.isfile():
|
||||
continue
|
||||
dirname, basename = os.path.split(ti.path)
|
||||
name, ext = os.path.splitext(basename)
|
||||
ext = ext.lower()
|
||||
if ext == '.tar':
|
||||
with tarfile.open(fileobj=tf.extractfile(ti), mode='r|') as ctf:
|
||||
child_info = dict(
|
||||
name=ti.name, path=os.path.join(parent_info['path'], name), ti=ti, children=[], samples=[])
|
||||
sample_count += _extract_tarinfo(ctf, child_info, extensions=extensions)
|
||||
_logger.debug(f'{i}/?. Extracted child tarinfos from {ti.name}. {len(child_info["samples"])} images.')
|
||||
parent_info['children'].append(child_info)
|
||||
elif ext in extensions:
|
||||
parent_info['samples'].append(ti)
|
||||
sample_count += 1
|
||||
return sample_count
|
||||
|
||||
|
||||
def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
|
||||
root_is_tar = False
|
||||
if os.path.isfile(root):
|
||||
assert os.path.splitext(root)[-1].lower() == '.tar'
|
||||
tar_filenames = [root]
|
||||
root, root_name = os.path.split(root)
|
||||
root_name = os.path.splitext(root_name)[0]
|
||||
root_is_tar = True
|
||||
else:
|
||||
root_name = root.strip(os.path.sep).split(os.path.sep)[-1]
|
||||
tar_filenames = glob(os.path.join(root, '*.tar'), recursive=True)
|
||||
num_tars = len(tar_filenames)
|
||||
tar_bytes = sum([os.path.getsize(f) for f in tar_filenames])
|
||||
assert num_tars, f'No .tar files found at specified path ({root}).'
|
||||
|
||||
_logger.info(f'Scanning {tar_bytes/1024**2:.2f}MB of tar files...')
|
||||
info = dict(tartrees=[])
|
||||
cache_path = ''
|
||||
if cache_tarinfo is None:
|
||||
cache_tarinfo = True if tar_bytes > 10*1024**3 else False # FIXME magic number, 10GB
|
||||
if cache_tarinfo:
|
||||
cache_filename = '_' + root_name + CACHE_FILENAME_SUFFIX
|
||||
cache_path = os.path.join(root, cache_filename)
|
||||
if os.path.exists(cache_path):
|
||||
_logger.info(f'Reading tar info from cache file {cache_path}.')
|
||||
with open(cache_path, 'rb') as pf:
|
||||
info = pickle.load(pf)
|
||||
assert len(info['tartrees']) == num_tars, "Cached tartree len doesn't match number of tarfiles"
|
||||
else:
|
||||
for i, fn in enumerate(tar_filenames):
|
||||
path = '' if root_is_tar else os.path.splitext(os.path.basename(fn))[0]
|
||||
with tarfile.open(fn, mode='r|') as tf: # tarinfo scans done in streaming mode
|
||||
parent_info = dict(name=os.path.relpath(fn, root), path=path, ti=None, children=[], samples=[])
|
||||
num_samples = _extract_tarinfo(tf, parent_info, extensions=extensions)
|
||||
num_children = len(parent_info["children"])
|
||||
_logger.debug(
|
||||
f'{i}/{num_tars}. Extracted tarinfos from {fn}. {num_children} children, {num_samples} samples.')
|
||||
info['tartrees'].append(parent_info)
|
||||
if cache_path:
|
||||
_logger.info(f'Writing tar info to cache file {cache_path}.')
|
||||
with open(cache_path, 'wb') as pf:
|
||||
pickle.dump(info, pf)
|
||||
|
||||
samples = []
|
||||
labels = []
|
||||
build_class_map = False
|
||||
if class_name_to_idx is None:
|
||||
build_class_map = True
|
||||
|
||||
# Flatten tartree info into lists of samples and targets w/ targets based on label id via
|
||||
# class map arg or from unique paths.
|
||||
# NOTE: currently only flattening up to two-levels, filesystem .tars and then one level of sub-tar children
|
||||
# this covers my current use cases and keeps things a little easier to test for now.
|
||||
tarfiles = []
|
||||
|
||||
def _label_from_paths(*path, leaf_only=True):
|
||||
path = os.path.join(*path).strip(os.path.sep)
|
||||
return path.split(os.path.sep)[-1] if leaf_only else path.replace(os.path.sep, '_')
|
||||
|
||||
def _add_samples(info, fn):
|
||||
added = 0
|
||||
for s in info['samples']:
|
||||
label = _label_from_paths(info['path'], os.path.dirname(s.path))
|
||||
if not build_class_map and label not in class_name_to_idx:
|
||||
continue
|
||||
samples.append((s, fn, info['ti']))
|
||||
labels.append(label)
|
||||
added += 1
|
||||
return added
|
||||
|
||||
_logger.info(f'Collecting samples and building tar states.')
|
||||
for parent_info in info['tartrees']:
|
||||
# if tartree has children, we assume all samples are at the child level
|
||||
tar_name = None if root_is_tar else parent_info['name']
|
||||
tar_state = TarState()
|
||||
parent_added = 0
|
||||
for child_info in parent_info['children']:
|
||||
child_added = _add_samples(child_info, fn=tar_name)
|
||||
if child_added:
|
||||
tar_state.children[child_info['name']] = TarState(ti=child_info['ti'])
|
||||
parent_added += child_added
|
||||
parent_added += _add_samples(parent_info, fn=tar_name)
|
||||
if parent_added:
|
||||
tarfiles.append((tar_name, tar_state))
|
||||
del info
|
||||
|
||||
if build_class_map:
|
||||
# build class index
|
||||
sorted_labels = list(sorted(set(labels), key=natural_key))
|
||||
class_name_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
|
||||
_logger.info(f'Mapping targets and sorting samples.')
|
||||
samples_and_targets = [(s, class_name_to_idx[l]) for s, l in zip(samples, labels) if l in class_name_to_idx]
|
||||
if sort:
|
||||
samples_and_targets = sorted(samples_and_targets, key=lambda k: natural_key(k[0][0].path))
|
||||
samples, targets = zip(*samples_and_targets)
|
||||
samples = np.array(samples)
|
||||
targets = np.array(targets)
|
||||
_logger.info(f'Finished processing {len(samples)} samples across {len(tarfiles)} tar files.')
|
||||
return samples, targets, class_name_to_idx, tarfiles
|
||||
|
||||
|
||||
class ParserImageInTar(Parser):
|
||||
""" Multi-tarfile dataset parser where there is one .tar file per class
|
||||
"""
|
||||
|
||||
def __init__(self, root, class_map='', cache_tarfiles=True, cache_tarinfo=None):
|
||||
super().__init__()
|
||||
|
||||
class_name_to_idx = None
|
||||
if class_map:
|
||||
class_name_to_idx = load_class_map(class_map, root)
|
||||
self.root = root
|
||||
self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
|
||||
self.root,
|
||||
class_name_to_idx=class_name_to_idx,
|
||||
cache_tarinfo=cache_tarinfo,
|
||||
extensions=IMG_EXTENSIONS)
|
||||
self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
|
||||
if len(tarfiles) == 1 and tarfiles[0][0] is None:
|
||||
self.root_is_tar = True
|
||||
self.tar_state = tarfiles[0][1]
|
||||
else:
|
||||
self.root_is_tar = False
|
||||
self.tar_state = dict(tarfiles)
|
||||
self.cache_tarfiles = cache_tarfiles
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
target = self.targets[index]
|
||||
sample_ti, parent_fn, child_ti = sample
|
||||
parent_abs = os.path.join(self.root, parent_fn) if parent_fn else self.root
|
||||
|
||||
tf = None
|
||||
cache_state = None
|
||||
if self.cache_tarfiles:
|
||||
cache_state = self.tar_state if self.root_is_tar else self.tar_state[parent_fn]
|
||||
tf = cache_state.tf
|
||||
if tf is None:
|
||||
tf = tarfile.open(parent_abs)
|
||||
if self.cache_tarfiles:
|
||||
cache_state.tf = tf
|
||||
if child_ti is not None:
|
||||
ctf = cache_state.children[child_ti.name].tf if self.cache_tarfiles else None
|
||||
if ctf is None:
|
||||
ctf = tarfile.open(fileobj=tf.extractfile(child_ti))
|
||||
if self.cache_tarfiles:
|
||||
cache_state.children[child_ti.name].tf = ctf
|
||||
tf = ctf
|
||||
|
||||
return tf.extractfile(sample_ti), target
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0].name
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
return filename
|
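
To make the supported layouts concrete, a sketch of how `ParserImageInTar` derives labels (file and class names are illustrative; by default the label is the leaf component of the combined tar-name / inner-folder path):

```python
# Sketch: layouts ParserImageInTar is designed to handle (illustrative names)
#
# 1) single tar containing class folders:
#      dataset.tar -> n01440764/img1.jpg                  label 'n01440764'
# 2) folder of per-class tars:
#      root/n01440764.tar -> img1.jpg                     label 'n01440764'
# 3) tar of tars (e.g. the ImageNet-21k fall11_whole.tar):
#      fall11_whole.tar -> n01440764.tar -> img1.jpg      label 'n01440764'
from timm.data.parsers.parser_image_in_tar import ParserImageInTar

parser = ParserImageInTar('/data/fall11_whole.tar')
fileobj, target = parser[0]  # file-like object for the image, integer class index
```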
timm/data/parsers/parser_image_tar.py (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
""" A dataset parser that reads single tarfile based datasets
|
||||
|
||||
This parser can read datasets consisting of a single tarfile containing images.
I am planning to deprecate it in favour of ParserImageInTar.
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
from .parser import Parser
|
||||
from .class_map import load_class_map
|
||||
from .constants import IMG_EXTENSIONS
|
||||
from timm.utils.misc import natural_key
|
||||
|
||||
|
||||
def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
|
||||
files = []
|
||||
labels = []
|
||||
for ti in tarfile.getmembers():
|
||||
if not ti.isfile():
|
||||
continue
|
||||
dirname, basename = os.path.split(ti.path)
|
||||
label = os.path.basename(dirname)
|
||||
ext = os.path.splitext(basename)[1]
|
||||
if ext.lower() in IMG_EXTENSIONS:
|
||||
files.append(ti)
|
||||
labels.append(label)
|
||||
if class_to_idx is None:
|
||||
unique_labels = set(labels)
|
||||
sorted_labels = list(sorted(unique_labels, key=natural_key))
|
||||
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
|
||||
tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels) if l in class_to_idx]
|
||||
if sort:
|
||||
tarinfo_and_targets = sorted(tarinfo_and_targets, key=lambda k: natural_key(k[0].path))
|
||||
return tarinfo_and_targets, class_to_idx
|
||||
|
||||
|
||||
class ParserImageTar(Parser):
|
||||
""" Single tarfile dataset where classes are mapped to folders within tar
|
||||
NOTE: This class is being deprecated in favour of the more capable ParserImageInTar that can
|
||||
operate on folders of tars or tars in tars.
|
||||
"""
|
||||
def __init__(self, root, class_map=''):
|
||||
super().__init__()
|
||||
|
||||
class_to_idx = None
|
||||
if class_map:
|
||||
class_to_idx = load_class_map(class_map, root)
|
||||
assert os.path.isfile(root)
|
||||
self.root = root
|
||||
|
||||
with tarfile.open(root) as tf: # cannot keep this open across processes, reopen later
|
||||
self.samples, self.class_to_idx = extract_tarinfo(tf, class_to_idx)
|
||||
self.imgs = self.samples
|
||||
self.tarfile = None # lazy init in __getitem__
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.tarfile is None:
|
||||
self.tarfile = tarfile.open(self.root)
|
||||
tarinfo, target = self.samples[index]
|
||||
fileobj = self.tarfile.extractfile(tarinfo)
|
||||
return fileobj, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
filename = self.samples[index][0].name
|
||||
if basename:
|
||||
filename = os.path.basename(filename)
|
||||
return filename
|
timm/data/parsers/parser_tfds.py (new file, 201 lines)
@@ -0,0 +1,201 @@
|
||||
""" Dataset parser interface that wraps TFDS datasets
|
||||
|
||||
Wraps many (most?) TFDS image-classification datasets
|
||||
from https://github.com/tensorflow/datasets
|
||||
https://www.tensorflow.org/datasets/catalog/overview#image_classification
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
import io
|
||||
import math
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL import Image
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
|
||||
import tensorflow_datasets as tfds
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
|
||||
exit(1)
|
||||
from .parser import Parser
|
||||
|
||||
|
||||
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
|
||||
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
|
||||
PREFETCH_SIZE = 4096 # samples to prefetch
|
||||
|
||||
|
||||
class ParserTfds(Parser):
|
||||
""" Wrap Tensorflow Datasets for use in PyTorch
|
||||
|
||||
There are several things to be aware of:
|
||||
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
|
||||
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
|
||||
https://github.com/pytorch/pytorch/issues/33413
|
||||
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
|
||||
from each worker could be a different size. For training this is worked around by option above, for
|
||||
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
|
||||
across replicas are of same size. This will slightly alter the results, distributed validation will not be
|
||||
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
|
||||
since there are up to N * J extra samples with IterableDatasets.
|
||||
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
|
||||
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
|
||||
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
|
||||
benefit of distributed training or fast dataloading should be much less for small datasets.
|
||||
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
|
||||
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
|
||||
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
|
||||
components.
|
||||
|
||||
"""
|
||||
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
|
||||
super().__init__()
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.shuffle = shuffle
|
||||
self.is_training = is_training
|
||||
if self.is_training:
|
||||
assert batch_size is not None,\
|
||||
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.builder = tfds.builder(name, data_dir=root)
|
||||
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
|
||||
# download_and_prepare() by default here as it's caused issues generating unwanted paths.
|
||||
self.num_samples = self.builder.info.splits[split].num_examples
|
||||
self.ds = None # initialized lazily on each dataloader worker process
|
||||
|
||||
self.worker_info = None
|
||||
self.dist_rank = 0
|
||||
self.dist_num_replicas = 1
|
||||
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
|
||||
self.dist_rank = dist.get_rank()
|
||||
self.dist_num_replicas = dist.get_world_size()
|
||||
|
||||
def _lazy_init(self):
|
||||
""" Lazily initialize the dataset.
|
||||
|
||||
This is necessary to init the Tensorflow dataset pipeline in the (dataloader) process that
|
||||
will be using the dataset instance. The __init__ method is called on the main process,
|
||||
this will be called in a dataloader worker process.
|
||||
|
||||
NOTE: There will be problems if you try to re-use this dataset across different loader/worker
|
||||
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
|
||||
before it is passed to dataloader.
|
||||
"""
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
|
||||
# setup input context to split dataset across distributed processes
|
||||
split = self.split
|
||||
num_workers = 1
|
||||
if worker_info is not None:
|
||||
self.worker_info = worker_info
|
||||
num_workers = worker_info.num_workers
|
||||
worker_id = worker_info.id
|
||||
|
||||
# FIXME I need to spend more time figuring out the best way to distribute/split data across
|
||||
# combo of distributed replicas + dataloader worker processes
|
||||
"""
|
||||
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
|
||||
My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
|
||||
between the splits each iteration, but that understanding could be wrong.
|
||||
Possible split options include:
|
||||
* InputContext for both distributed & worker processes (current)
|
||||
* InputContext for distributed and sub-splits for worker processes
|
||||
* sub-splits for both
|
||||
"""
|
||||
# split_size = self.num_samples // num_workers
|
||||
# start = worker_id * split_size
|
||||
# if worker_id == num_workers - 1:
|
||||
# split = split + '[{}:]'.format(start)
|
||||
# else:
|
||||
# split = split + '[{}:{}]'.format(start, start + split_size)
|
||||
|
||||
input_context = tf.distribute.InputContext(
|
||||
num_input_pipelines=self.dist_num_replicas * num_workers,
|
||||
input_pipeline_id=self.dist_rank * num_workers + worker_id,
|
||||
num_replicas_in_sync=self.dist_num_replicas # FIXME does this have any impact?
|
||||
)
|
||||
|
||||
read_config = tfds.ReadConfig(input_context=input_context)
|
||||
ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
|
||||
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
|
||||
ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
|
||||
ds.options().experimental_threading.max_intra_op_parallelism = 1
|
||||
if self.is_training:
|
||||
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
|
||||
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
|
||||
ds = ds.repeat() # allow wrap around and break iteration manually
|
||||
if self.shuffle:
|
||||
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
|
||||
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
|
||||
self.ds = tfds.as_numpy(ds)
|
||||
|
||||
def __iter__(self):
|
||||
if self.ds is None:
|
||||
self._lazy_init()
|
||||
# compute a rounded up sample count that is used to:
|
||||
# 1. make batches even cross workers & replicas in distributed validation.
|
||||
# This adds extra samples and will slightly alter validation results.
|
||||
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
|
||||
# batches are produced (underlying tfds iter wraps around)
|
||||
target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
|
||||
if self.is_training:
|
||||
# round up to nearest batch_size per worker-replica
|
||||
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
|
||||
sample_count = 0
|
||||
for sample in self.ds:
|
||||
img = Image.fromarray(sample['image'], mode='RGB')
|
||||
yield img, sample['label']
|
||||
sample_count += 1
|
||||
if self.is_training and sample_count >= target_sample_count:
|
||||
# Need to break out of loop when repeat() is enabled for training w/ oversampling
|
||||
# this results in extra samples per epoch but seems more desirable than dropping
|
||||
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
|
||||
break
|
||||
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
|
||||
# Validation batch padding only done for distributed training where results are reduced across nodes.
|
||||
# For single process case, it won't matter if workers return different batch sizes.
|
||||
# FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
|
||||
assert target_sample_count - sample_count == 1 # should only be off by 1 or sharding is not optimal
|
||||
yield img, sample['label'] # yield prev sample again
|
||||
sample_count += 1
|
||||
|
||||
@property
|
||||
def _num_workers(self):
|
||||
return 1 if self.worker_info is None else self.worker_info.num_workers
|
||||
|
||||
@property
|
||||
def _num_pipelines(self):
|
||||
return self._num_workers * self.dist_num_replicas
|
||||
|
||||
def __len__(self):
|
||||
# this is just an estimate and does not factor in extra samples added to pad batches based on
|
||||
# complete worker & replica info (not available until init in dataloader).
|
||||
return math.ceil(self.num_samples / self.dist_num_replicas)
|
||||
|
||||
def _filename(self, index, basename=False, absolute=False):
|
||||
assert False, "Not supported" # no random access to samples
|
||||
|
||||
def filenames(self, basename=False, absolute=False):
|
||||
""" Return all filenames in dataset, overrides base"""
|
||||
if self.ds is None:
|
||||
self._lazy_init()
|
||||
names = []
|
||||
for sample in self.ds:
|
||||
if len(names) > self.num_samples:
|
||||
break # safety for ds.repeat() case
|
||||
if 'file_name' in sample:
|
||||
name = sample['file_name']
|
||||
elif 'filename' in sample:
|
||||
name = sample['filename']
|
||||
elif 'id' in sample:
|
||||
name = sample['id']
|
||||
else:
|
||||
assert False, "No supported name field present"
|
||||
names.append(name)
|
||||
return names
|
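
A worked example of the sample-count logic above, under assumed numbers (2 distributed replicas, 2 dataloader workers each, a 3669-sample test split):

```python
# Sketch: how ParserTfds sizes each worker/replica pipeline (assumed numbers)
import math

num_samples = 3669       # e.g. a TFDS test split
dist_num_replicas = 2    # distributed processes
num_workers = 2          # dataloader workers per process
batch_size = 256

num_pipelines = dist_num_replicas * num_workers              # 4 independent iterators
target = math.ceil(num_samples / num_pipelines)              # 918 samples per pipeline (validation)
train_target = math.ceil(target / batch_size) * batch_size   # 1024, rounded to full batches (training)

# validation: a pipeline whose shard comes up short re-yields its last sample so the
# reduced batches stay the same size across replicas
# training: the underlying tfds iterator repeats and iteration stops after train_target samples
```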
@ -16,6 +16,7 @@ from .regnet import *
|
||||
from .res2net import *
|
||||
from .resnest import *
|
||||
from .resnet import *
|
||||
from .resnetv2 import *
|
||||
from .rexnet import *
|
||||
from .selecsls import *
|
||||
from .senet import *
|
||||
|
@ -6,8 +6,6 @@ from .layers import set_layer_config
|
||||
def create_model(
|
||||
model_name,
|
||||
pretrained=False,
|
||||
num_classes=1000,
|
||||
in_chans=3,
|
||||
checkpoint_path='',
|
||||
scriptable=None,
|
||||
exportable=None,
|
||||
@ -18,8 +16,6 @@ def create_model(
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
num_classes (int): number of classes for final fully connected layer (default: 1000)
|
||||
in_chans (int): number of input channels / colors (default: 3)
|
||||
checkpoint_path (str): path of checkpoint to load after model is initialized
|
||||
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
|
||||
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
|
||||
@ -30,7 +26,7 @@ def create_model(
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are model specific
|
||||
"""
|
||||
model_args = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
model_args = dict(pretrained=pretrained)
|
||||
|
||||
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
|
||||
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
|
||||
|
@ -11,7 +11,11 @@ from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
||||
from .layers import Conv2dSame, Linear
|
||||
@ -88,15 +92,70 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True):
|
||||
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
|
||||
r"""Loads a custom (read non .pth) weight file
|
||||
|
||||
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
|
||||
a passed-in custom load fn, or the model's `load_pretrained` member fn.
|
||||
|
||||
If the object is already present in `model_dir`, it's deserialized and returned.
|
||||
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
|
||||
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
|
||||
|
||||
Args:
|
||||
model: The instantiated model to load weights into
|
||||
cfg (dict): Default pretrained model cfg
|
||||
load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
|
||||
'load_pretrained' on the model will be called if it exists
|
||||
progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
|
||||
check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
|
||||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
|
||||
digits of the SHA256 hash of the contents of the file. The hash is used to
|
||||
ensure unique names and to verify the contents of the file. Default: False
|
||||
"""
|
||||
if cfg is None:
|
||||
cfg = getattr(model, 'default_cfg')
|
||||
if cfg is None or 'url' not in cfg or not cfg['url']:
|
||||
_logger.warning("Pretrained model URL is invalid, using random initialization.")
|
||||
_logger.warning("Pretrained model URL does not exist, using random initialization.")
|
||||
return
|
||||
url = cfg['url']
|
||||
|
||||
# Issue warning to move data if old env is set
|
||||
if os.getenv('TORCH_MODEL_ZOO'):
|
||||
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
||||
|
||||
hub_dir = get_dir()
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints')
|
||||
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
parts = urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(model_dir, filename)
|
||||
if not os.path.exists(cached_file):
|
||||
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
if check_hash:
|
||||
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
|
||||
hash_prefix = r.group(1) if r else None
|
||||
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
||||
|
||||
if load_fn is not None:
|
||||
load_fn(model, cached_file)
|
||||
elif hasattr(model, 'load_pretrained'):
|
||||
model.load_pretrained(cached_file)
|
||||
else:
|
||||
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
|
||||
|
||||
|
||||
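
A sketch of how `load_custom_pretrained` might be used, e.g. for the BiT `.npz` checkpoints referenced in this PR (the `load_npz_weights` helper below is hypothetical; the ResNetV2/BiT models in this commit provide a `load_pretrained` member that fills the same role):

```python
# Sketch: loading non-.pth pretrained weights (load_npz_weights is a hypothetical helper)
import numpy as np
from timm.models import create_model
from timm.models.helpers import load_custom_pretrained


def load_npz_weights(model, checkpoint_path):
    weights = np.load(checkpoint_path)
    # ... copy arrays from `weights` into the model's parameters here ...


model = create_model('resnetv2_50x1_bitm', pretrained=False)

# downloads the default_cfg URL into the torch hub cache, then calls the custom loader
load_custom_pretrained(model, load_fn=load_npz_weights)

# or, with no load_fn, fall back to the model's own load_pretrained() member
load_custom_pretrained(model)
```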
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
|
||||
if cfg is None:
|
||||
cfg = getattr(model, 'default_cfg')
|
||||
if cfg is None or 'url' not in cfg or not cfg['url']:
|
||||
_logger.warning("Pretrained model URL does not exist, using random initialization.")
|
||||
return
|
||||
|
||||
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
|
||||
|
||||
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
|
||||
if filter_fn is not None:
|
||||
state_dict = filter_fn(state_dict)
|
||||
|
||||
@ -139,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
||||
|
||||
classifier_name = cfg['classifier']
|
||||
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
||||
# FIXME this special case is problematic as number of pretrained weight sources increases
|
||||
# special case for imagenet trained models with extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
||||
@ -269,6 +329,7 @@ def build_model_with_cfg(
|
||||
feature_cfg: dict = None,
|
||||
pretrained_strict: bool = True,
|
||||
pretrained_filter_fn: Callable = None,
|
||||
pretrained_custom_load: bool = False,
|
||||
**kwargs):
|
||||
pruned = kwargs.pop('pruned', False)
|
||||
features = False
|
||||
@ -289,10 +350,13 @@ def build_model_with_cfg(
|
||||
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
||||
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
|
||||
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
||||
if pretrained_custom_load:
|
||||
load_custom_pretrained(model)
|
||||
else:
|
||||
load_pretrained(
|
||||
model,
|
||||
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
|
||||
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
|
||||
|
||||
if features:
|
||||
feature_cls = FeatureListNet
|
||||
|
@ -7,7 +7,7 @@ from .classifier import ClassifierHead, create_classifier
|
||||
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
|
||||
set_layer_config
|
||||
from .conv2d_same import Conv2dSame
|
||||
from .conv2d_same import Conv2dSame, conv2d_same
|
||||
from .conv_bn_act import ConvBnAct
|
||||
from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||
from .create_attn import create_attn
|
||||
@ -20,8 +20,8 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
|
||||
from .inplace_abn import InplaceAbn
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .norm_act import BatchNormAct2d
|
||||
from .padding import get_padding
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct
|
||||
from .padding import get_padding, get_same_padding, pad_same
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
from .se import SEModule
|
||||
from .selective_kernel import SelectiveKernelConv
|
||||
|
@ -9,31 +9,43 @@ from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from .linear import Linear
|
||||
|
||||
|
||||
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
flatten = not use_conv # flatten when we use a Linear layer after pooling
|
||||
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
|
||||
if not pool_type:
|
||||
assert num_classes == 0 or use_conv,\
|
||||
'Pooling can only be disabled if classifier is also removed or conv classifier is used'
|
||||
flatten = False # disable flattening if pooling is pass-through (no pooling)
|
||||
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten)
|
||||
flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
|
||||
global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
|
||||
num_pooled_features = num_features * global_pool.feat_mult()
|
||||
return global_pool, num_pooled_features
|
||||
|
||||
|
||||
def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
if num_classes <= 0:
|
||||
fc = nn.Identity() # pass-through (no classifier)
|
||||
elif use_conv:
|
||||
fc = nn.Conv2d(num_pooled_features, num_classes, 1, bias=True)
|
||||
fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
|
||||
else:
|
||||
# NOTE: using my Linear wrapper that fixes AMP + torchscript casting issue
|
||||
fc = Linear(num_pooled_features, num_classes, bias=True)
|
||||
fc = Linear(num_features, num_classes, bias=True)
|
||||
return fc
|
||||
|
||||
|
||||
def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
|
||||
global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
|
||||
fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
||||
return global_pool, fc
|
||||
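For reference, a quick sketch of the refactored helper in both modes (a 2048-feature backbone is assumed purely for illustration):

import torch
from timm.models.layers import create_classifier

x = torch.randn(2, 2048, 7, 7)
pool, fc = create_classifier(2048, 1000, pool_type='avg', use_conv=False)
out = fc(pool(x))    # pooling flattens, Linear head -> (2, 1000)
pool, fc = create_classifier(2048, 1000, pool_type='avg', use_conv=True)
out = fc(pool(x))    # pooling stays 4D, 1x1 Conv2d head -> (2, 1000, 1, 1)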
|
||||
|
||||
class ClassifierHead(nn.Module):
|
||||
"""Classifier head w/ configurable global pooling and dropout."""
|
||||
|
||||
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0.):
|
||||
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
|
||||
super(ClassifierHead, self).__init__()
|
||||
self.drop_rate = drop_rate
|
||||
self.global_pool, self.fc = create_classifier(in_chs, num_classes, pool_type=pool_type)
|
||||
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
|
||||
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
|
||||
self.flatten_after_fc = use_conv and pool_type
|
||||
|
||||
def forward(self, x):
|
||||
x = self.global_pool(x)
|
||||
|
@ -68,8 +68,8 @@ class BatchNormAct2d(nn.BatchNorm2d):
|
||||
|
||||
|
||||
class GroupNormAct(nn.GroupNorm):
|
||||
|
||||
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True,
|
||||
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
|
||||
def __init__(self, num_channels, num_groups, eps=1e-5, affine=True,
|
||||
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
|
||||
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
|
||||
if isinstance(act_layer, str):
|
||||
|
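With the flipped argument order the channel count is positional, like other norm layers, so the group count can be bound once and the layer swapped in wherever a norm_layer(num_channels) callable is expected; a brief sketch:

from functools import partial
from timm.models.layers import GroupNormAct

norm = GroupNormAct(64, 32)                         # 64 channels, 32 groups (channels first now)
norm_layer = partial(GroupNormAct, num_groups=32)   # bind the fixed arg, use like BatchNormAct2d
norm = norm_layer(256)                              # 256 channels, 32 groups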
@ -403,7 +403,7 @@ class ReductionCell1(nn.Module):
|
||||
class NASNetALarge(nn.Module):
|
||||
"""NASNetALarge (6 @ 4032) """
|
||||
|
||||
def __init__(self, num_classes=1000, in_chans=1, stem_size=96, channel_multiplier=2,
|
||||
def __init__(self, num_classes=1000, in_chans=3, stem_size=96, channel_multiplier=2,
|
||||
num_features=4032, output_stride=32, drop_rate=0., global_pool='avg', pad_type='same'):
|
||||
super(NASNetALarge, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
|
@ -162,6 +162,12 @@ default_cfgs = {
|
||||
'seresnet152d_320': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
|
||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), crop_pct=1.0, pool_size=(10, 10)),
|
||||
'seresnet200d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||
'seresnet269d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||
|
||||
|
||||
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
|
||||
@ -216,6 +222,12 @@ default_cfgs = {
|
||||
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
|
||||
interpolation='bicubic',
|
||||
first_conv='conv1.0'),
|
||||
'ecaresnet200d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||
'ecaresnet269d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
|
||||
|
||||
# Efficient Channel Attention ResNeXts
|
||||
'ecaresnext26tn_32x4d': _cfg(
|
||||
@ -1123,6 +1135,26 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
|
||||
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def ecaresnet200d(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-200-D model with ECA.
|
||||
"""
|
||||
model_args = dict(
|
||||
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||
block_args=dict(attn_layer='eca'), **kwargs)
|
||||
return _create_resnet('ecaresnet200d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def ecaresnet269d(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-269-D model with ECA.
|
||||
"""
|
||||
model_args = dict(
|
||||
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
|
||||
block_args=dict(attn_layer='eca'), **kwargs)
|
||||
return _create_resnet('ecaresnet269d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def ecaresnext26tn_32x4d(pretrained=False, **kwargs):
|
||||
"""Constructs an ECA-ResNeXt-26-TN model.
|
||||
@ -1198,6 +1230,26 @@ def seresnet152d(pretrained=False, **kwargs):
|
||||
return _create_resnet('seresnet152d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet200d(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-200-D model with SE attn.
|
||||
"""
|
||||
model_args = dict(
|
||||
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
|
||||
block_args=dict(attn_layer='se'), **kwargs)
|
||||
return _create_resnet('seresnet200d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet269d(pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-269-D model with SE attn.
|
||||
"""
|
||||
model_args = dict(
|
||||
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
|
||||
block_args=dict(attn_layer='se'), **kwargs)
|
||||
return _create_resnet('seresnet269d', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet152d_320(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
|
593
timm/models/resnetv2.py
Normal file
@ -0,0 +1,593 @@
|
||||
"""Pre-Activation ResNet v2 with GroupNorm and Weight Standardization.
|
||||
|
||||
A PyTorch implementation of ResNetV2 adapted from the Google Big Transfer (BiT) source code
|
||||
at https://github.com/google-research/big_transfer to match timm interfaces. The BiT weights have
|
||||
been included here as pretrained models from their original .NPZ checkpoints.
|
||||
|
||||
Additionally, supports a non pre-activation bottleneck for use as a backbone for Vision Transformers (ViT) and
|
||||
extra padding support to allow porting of official Hybrid ResNet pretrained weights from
|
||||
https://github.com/google-research/vision_transformer
|
||||
|
||||
Thanks to the Google team for the above two repositories and associated papers:
|
||||
* Big Transfer (BiT): General Visual Representation Learning - https://arxiv.org/abs/1912.11370
|
||||
* An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale - https://arxiv.org/abs/2010.11929
|
||||
|
||||
Original copyright of Google code below, modifications by Ross Wightman, Copyright 2020.
|
||||
"""
|
||||
# Copyright 2020 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import OrderedDict # pylint: disable=g-importing-member
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 480, 480), 'pool_size': (7, 7),
|
||||
'crop_pct': 1.0, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'stem.conv', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# pretrained on imagenet21k, finetuned on imagenet1k
|
||||
'resnetv2_50x1_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz'),
|
||||
'resnetv2_50x3_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3-ILSVRC2012.npz'),
|
||||
'resnetv2_101x1_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1-ILSVRC2012.npz'),
|
||||
'resnetv2_101x3_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3-ILSVRC2012.npz'),
|
||||
'resnetv2_152x2_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2-ILSVRC2012.npz'),
|
||||
'resnetv2_152x4_bitm': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4-ILSVRC2012.npz'),
|
||||
|
||||
# trained on imagenet-21k
|
||||
'resnetv2_50x1_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_50x3_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_101x1_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_101x3_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_152x2_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz',
|
||||
num_classes=21843),
|
||||
'resnetv2_152x4_bitm_in21k': _cfg(
|
||||
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
|
||||
num_classes=21843),
|
||||
|
||||
|
||||
# trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now
|
||||
# 'resnetv2_50x1_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'),
|
||||
# 'resnetv2_50x3_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
|
||||
# 'resnetv2_101x1_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
|
||||
# 'resnetv2_101x3_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
|
||||
# 'resnetv2_152x2_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
|
||||
# 'resnetv2_152x4_bits': _cfg(
|
||||
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
|
||||
}
|
||||
|
||||
|
||||
def make_div(v, divisor=8):
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
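make_div rounds a (possibly width-scaled) channel count to the nearest multiple of the divisor, never going below the divisor and never losing more than ~10% of the requested value; a few hand-checked examples:

assert make_div(62) == 64     # nearest multiple of 8
assert make_div(33) == 32     # rounds down when that is closer
assert make_div(4) == 8       # clamped to the divisor itself
assert make_div(100) == 104   # ties round up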
|
||||
|
||||
class StdConv2d(nn.Conv2d):
|
||||
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
in_channel, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=bias, groups=groups)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight
|
||||
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
w = (w - m) / (torch.sqrt(v) + self.eps)
|
||||
x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
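The stored weight parameter is left untouched; standardization is applied on the fly at every forward, which is what lets GroupNorm stand in for BatchNorm in the BiT models. A quick sanity check of the effective weights (approximate because of the eps term):

import torch
conv = StdConv2d(3, 8, kernel_size=3)
v, m = torch.var_mean(conv.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (conv.weight - m) / (torch.sqrt(v) + conv.eps)
print(w_std.mean(dim=[1, 2, 3]))                   # ~0 per output filter
print(w_std.var(dim=[1, 2, 3], unbiased=False))    # ~1 per output filter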
|
||||
|
||||
class StdConv2dSame(nn.Conv2d):
|
||||
"""StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
|
||||
"""
|
||||
def __init__(
|
||||
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
|
||||
padding = get_padding(kernel_size, stride, dilation)
|
||||
super().__init__(
|
||||
in_channel, out_channels, kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation, bias=bias, groups=groups)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
w = self.weight
|
||||
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
|
||||
w = (w - m) / (torch.sqrt(v) + self.eps)
|
||||
x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
return x
|
||||
|
||||
|
||||
def tf2th(conv_weights):
|
||||
"""Possibly convert HWIO to OIHW."""
|
||||
if conv_weights.ndim == 4:
|
||||
conv_weights = conv_weights.transpose([3, 2, 0, 1])
|
||||
return torch.from_numpy(conv_weights)
|
||||
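tf2th is just a layout shuffle for the .npz checkpoints: TF/JAX stores conv kernels as HWIO while PyTorch expects OIHW. For example:

import numpy as np
hwio = np.zeros((3, 3, 64, 256), dtype=np.float32)   # H, W, in, out (TF layout)
oihw = tf2th(hwio)
assert tuple(oihw.shape) == (256, 64, 3, 3)          # out, in, H, W (PyTorch layout)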
|
||||
|
||||
class PreActBottleneck(nn.Module):
|
||||
"""Pre-activation (v2) bottleneck block.
|
||||
|
||||
Follows the implementation of "Identity Mappings in Deep Residual Networks":
|
||||
https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
|
||||
|
||||
Except it puts the stride on 3x3 conv when available.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
|
||||
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
|
||||
super().__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
conv_layer = conv_layer or StdConv2d
|
||||
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
|
||||
out_chs = out_chs or in_chs
|
||||
mid_chs = make_div(out_chs * bottle_ratio)
|
||||
|
||||
if proj_layer is not None:
|
||||
self.downsample = proj_layer(
|
||||
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
|
||||
conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.norm1 = norm_layer(in_chs)
|
||||
self.conv1 = conv_layer(in_chs, mid_chs, 1)
|
||||
self.norm2 = norm_layer(mid_chs)
|
||||
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
|
||||
self.norm3 = norm_layer(mid_chs)
|
||||
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x_preact = self.norm1(x)
|
||||
|
||||
# shortcut branch
|
||||
shortcut = x
|
||||
if self.downsample is not None:
|
||||
shortcut = self.downsample(x_preact)
|
||||
|
||||
# residual branch
|
||||
x = self.conv1(x_preact)
|
||||
x = self.conv2(self.norm2(x))
|
||||
x = self.conv3(self.norm3(x))
|
||||
x = self.drop_path(x)
|
||||
return x + shortcut
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
|
||||
"""
|
||||
def __init__(
|
||||
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
|
||||
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
|
||||
super().__init__()
|
||||
first_dilation = first_dilation or dilation
|
||||
act_layer = act_layer or nn.ReLU
|
||||
conv_layer = conv_layer or StdConv2d
|
||||
norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
|
||||
out_chs = out_chs or in_chs
|
||||
mid_chs = make_div(out_chs * bottle_ratio)
|
||||
|
||||
if proj_layer is not None:
|
||||
self.downsample = proj_layer(
|
||||
in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
|
||||
conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
self.conv1 = conv_layer(in_chs, mid_chs, 1)
|
||||
self.norm1 = norm_layer(mid_chs)
|
||||
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
|
||||
self.norm2 = norm_layer(mid_chs)
|
||||
self.conv3 = conv_layer(mid_chs, out_chs, 1)
|
||||
self.norm3 = norm_layer(out_chs, apply_act=False)
|
||||
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
|
||||
self.act3 = act_layer(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
# shortcut branch
|
||||
shortcut = x
|
||||
if self.downsample is not None:
|
||||
shortcut = self.downsample(x)
|
||||
|
||||
# residual
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.norm3(x)
|
||||
x = self.drop_path(x)
|
||||
x = self.act3(x + shortcut)
|
||||
return x
|
||||
|
||||
|
||||
class DownsampleConv(nn.Module):
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
|
||||
conv_layer=None, norm_layer=None):
|
||||
super(DownsampleConv, self).__init__()
|
||||
self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
|
||||
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.conv(x))
|
||||
|
||||
|
||||
class DownsampleAvg(nn.Module):
|
||||
def __init__(
|
||||
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
|
||||
preact=True, conv_layer=None, norm_layer=None):
|
||||
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
|
||||
super(DownsampleAvg, self).__init__()
|
||||
avg_stride = stride if dilation == 1 else 1
|
||||
if stride > 1 or dilation > 1:
|
||||
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
||||
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
||||
else:
|
||||
self.pool = nn.Identity()
|
||||
self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
|
||||
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.conv(self.pool(x)))
|
||||
|
||||
|
||||
class ResNetStage(nn.Module):
|
||||
"""ResNet Stage."""
|
||||
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
|
||||
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
|
||||
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
|
||||
super(ResNetStage, self).__init__()
|
||||
first_dilation = 1 if dilation in (1, 2) else 2
|
||||
layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
proj_layer = DownsampleAvg if avg_down else DownsampleConv
|
||||
prev_chs = in_chs
|
||||
self.blocks = nn.Sequential()
|
||||
for block_idx in range(depth):
|
||||
drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
|
||||
stride = stride if block_idx == 0 else 1
|
||||
self.blocks.add_module(str(block_idx), block_fn(
|
||||
prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
|
||||
first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
|
||||
**layer_kwargs, **block_kwargs))
|
||||
prev_chs = out_chs
|
||||
first_dilation = dilation
|
||||
proj_layer = None
|
||||
|
||||
def forward(self, x):
|
||||
x = self.blocks(x)
|
||||
return x
|
||||
|
||||
|
||||
def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
|
||||
stem = OrderedDict()
|
||||
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
|
||||
|
||||
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
||||
if 'deep' in stem_type:
|
||||
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
||||
mid_chs = out_chs // 2
|
||||
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
|
||||
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
|
||||
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
|
||||
else:
|
||||
# The usual 7x7 stem conv
|
||||
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
|
||||
|
||||
if not preact:
|
||||
stem['norm'] = norm_layer(out_chs)
|
||||
|
||||
if 'fixed' in stem_type:
|
||||
# 'fixed' SAME padding approximation that is used in BiT models
|
||||
stem['pad'] = nn.ConstantPad2d(1, 0.)
|
||||
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
|
||||
elif 'same' in stem_type:
|
||||
# full, input size based 'SAME' padding, used in ViT Hybrid model
|
||||
stem['pool'] = create_pool2d('max', kernel_size=3, stride=2, padding='same')
|
||||
else:
|
||||
# the usual PyTorch symmetric padding
|
||||
stem['pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
return nn.Sequential(stem)
|
||||
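A short sketch of the two non-default padding variants, configured the way the notes above describe them being used ('fixed' for the BiT weights, 'same' with StdConv2dSame for the ViT hybrid backbone); the exact layer pairing is an assumption for illustration:

from functools import partial

# BiT: 7x7 conv, then explicit pad + unpadded 3x3/2 max pool ('fixed' SAME approximation)
bit_stem = create_stem(3, 64, stem_type='fixed', preact=True,
                       conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32))
# ViT hybrid: input-size dependent 'SAME' padding throughout
vit_stem = create_stem(3, 64, stem_type='same', preact=False,
                       conv_layer=StdConv2dSame, norm_layer=partial(GroupNormAct, num_groups=32))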
|
||||
|
||||
class ResNetV2(nn.Module):
|
||||
"""Implementation of Pre-activation (v2) ResNet mode.
|
||||
"""
|
||||
|
||||
def __init__(self, layers, channels=(256, 512, 1024, 2048),
|
||||
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
|
||||
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
|
||||
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
|
||||
drop_rate=0., drop_path_rate=0.):
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
wf = width_factor
|
||||
|
||||
self.feature_info = []
|
||||
stem_chs = make_div(stem_chs * wf)
|
||||
self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
|
||||
# NOTE no reduction 2 feature if preact
|
||||
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
|
||||
|
||||
prev_chs = stem_chs
|
||||
curr_stride = 4
|
||||
dilation = 1
|
||||
block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
|
||||
block_fn = PreActBottleneck if preact else Bottleneck
|
||||
self.stages = nn.Sequential()
|
||||
for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
|
||||
out_chs = make_div(c * wf)
|
||||
stride = 1 if stage_idx == 0 else 2
|
||||
if curr_stride >= output_stride:
|
||||
dilation *= stride
|
||||
stride = 1
|
||||
stage = ResNetStage(
|
||||
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
|
||||
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
|
||||
prev_chs = out_chs
|
||||
curr_stride *= stride
|
||||
feat_name = f'stages.{stage_idx}'
|
||||
if preact:
|
||||
feat_name = f'stages.{stage_idx + 1}.blocks.0.norm1' if (stage_idx + 1) != len(channels) else 'norm'
|
||||
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=feat_name)]
|
||||
self.stages.add_module(str(stage_idx), stage)
|
||||
|
||||
self.num_features = prev_chs
|
||||
self.norm = norm_layer(self.num_features) if preact else nn.Identity()
|
||||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
for n, m in self.named_modules():
|
||||
if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
|
||||
nn.init.normal_(m.weight, mean=0.0, std=0.01)
|
||||
nn.init.zeros_(m.bias)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.head = ClassifierHead(
|
||||
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.stem(x)
|
||||
x = self.stages(x)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
if not self.head.global_pool.is_identity():
|
||||
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
|
||||
return x
|
||||
|
||||
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
|
||||
import numpy as np
|
||||
weights = np.load(checkpoint_path)
|
||||
with torch.no_grad():
|
||||
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])
|
||||
if self.stem.conv.weight.shape[1] == 1:
|
||||
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True))
|
||||
# FIXME handle > 3 in_chans?
|
||||
else:
|
||||
self.stem.conv.weight.copy_(stem_conv_w)
|
||||
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
|
||||
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
|
||||
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
|
||||
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
|
||||
for i, (sname, stage) in enumerate(self.stages.named_children()):
|
||||
for j, (bname, block) in enumerate(stage.blocks.named_children()):
|
||||
convname = 'standardized_conv2d'
|
||||
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
|
||||
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
|
||||
block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel']))
|
||||
block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel']))
|
||||
block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma']))
|
||||
block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma']))
|
||||
block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma']))
|
||||
block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
|
||||
block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta']))
|
||||
block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta']))
|
||||
if block.downsample is not None:
|
||||
w = weights[f'{block_prefix}a/proj/{convname}/kernel']
|
||||
block.downsample.conv.weight.copy_(tf2th(w))
|
||||
|
||||
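load_pretrained can also be called directly on a locally downloaded BiT checkpoint; a hedged sketch (the path is illustrative, any of the .npz URLs in default_cfgs above will do):

model = ResNetV2(layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed')   # R50x1 config
model.load_pretrained('/path/to/BiT-M-R50x1-ILSVRC2012.npz', prefix='resnet/')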
|
||||
def _create_resnetv2(variant, pretrained=False, **kwargs):
|
||||
# FIXME feature map extraction is not setup properly for pre-activation mode right now
|
||||
preact = kwargs.get('preact', True)
|
||||
feature_cfg = dict(flatten_sequential=True)
|
||||
if preact:
|
||||
feature_cfg['feature_cls'] = 'hook'
|
||||
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for preact
|
||||
|
||||
return build_model_with_cfg(
|
||||
ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
|
||||
feature_cfg=feature_cfg, **kwargs)
|
||||
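The factories below go through the model registry, so end users can stick to the usual entrypoint; for instance (weights are pulled from the .npz URLs in default_cfgs):

import timm
model = timm.create_model('resnetv2_101x1_bitm', pretrained=True)
model.reset_classifier(num_classes=10)   # transfer learning: fresh 1x1 conv head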
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x1_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x1_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x3_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x3_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x1_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x1_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x3_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x3_bitm', pretrained=pretrained,
|
||||
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x2_bitm', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x4_bitm(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x4_bitm', pretrained=pretrained,
|
||||
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
|
||||
return _create_resnetv2(
|
||||
'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
|
||||
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
|
||||
|
||||
# NOTE the 'S' (ImageNet-1k) versions of the model weights aren't as interesting as the original 21k or the 21k->1k transfer 'M' weights.
|
||||
|
||||
# @register_model
|
||||
# def resnetv2_50x1_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50x1_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_50x3_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_101x1_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_101x3_bits', pretrained=pretrained,
|
||||
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_152x2_bits', pretrained=pretrained,
|
||||
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
|
||||
#
|
||||
#
|
||||
# @register_model
|
||||
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
|
||||
# return _create_resnetv2(
|
||||
# 'resnetv2_152x4_bits', pretrained=pretrained,
|
||||
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
|
||||
#
|
@ -5,12 +5,6 @@ A PyTorch implement of Vision Transformers as described in
|
||||
|
||||
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
||||
|
||||
Status/TODO:
|
||||
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
|
||||
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
|
||||
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
|
||||
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
|
||||
|
||||
Acknowledgments:
|
||||
* The paper authors for releasing code and weights, thanks!
|
||||
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
||||
@ -18,18 +12,29 @@ for some einops/einsum fun
|
||||
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
||||
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
||||
|
||||
DeiT model defs and weights from https://github.com/facebookresearch/deit,
|
||||
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import logging
|
||||
from functools import partial
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import load_pretrained
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_
|
||||
from .resnet import resnet26d, resnet50d
|
||||
from .resnetv2 import ResNetV2, StdConv2dSame
|
||||
from .registry import register_model
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
@ -43,14 +48,19 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
# patch models
|
||||
# patch models (my experiments)
|
||||
'vit_small_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
||||
),
|
||||
|
||||
# patch models (weights ported from official Google JAX impl)
|
||||
'vit_base_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
|
||||
),
|
||||
'vit_base_patch32_224': _cfg(
|
||||
url='', # no official model weights for this combo, only for in21k
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_base_patch16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
@ -60,19 +70,66 @@ default_cfgs = {
|
||||
'vit_large_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_large_patch32_224': _cfg(
|
||||
url='', # no official model weights for this combo, only for in21k
|
||||
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_large_patch16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_large_patch32_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_huge_patch16_224': _cfg(),
|
||||
'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
|
||||
# hybrid models
|
||||
|
||||
# patch models, imagenet21k (weights ported from official Google JAX impl)
|
||||
'vit_base_patch16_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_base_patch32_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_large_patch16_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_large_patch32_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
'vit_huge_patch14_224_in21k': _cfg(
|
||||
url='', # FIXME I have weights for this but > 2GB limit for github release binaries
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
||||
|
||||
# hybrid models (weights ported from official Google JAX impl)
|
||||
'vit_base_resnet50_224_in21k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
|
||||
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
|
||||
'vit_base_resnet50_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
|
||||
|
||||
# hybrid models (my experiments)
|
||||
'vit_small_resnet26d_224': _cfg(),
|
||||
'vit_small_resnet50d_s3_224': _cfg(),
|
||||
'vit_base_resnet26d_224': _cfg(),
|
||||
'vit_base_resnet50d_224': _cfg(),
|
||||
|
||||
# deit models (FB weights)
|
||||
'vit_deit_tiny_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
||||
'vit_deit_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
||||
'vit_deit_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
|
||||
'vit_deit_base_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'vit_deit_tiny_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
|
||||
'vit_deit_small_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
|
||||
'vit_deit_base_distilled_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
|
||||
'vit_deit_base_distilled_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
}
|
||||
|
||||
|
||||
@ -184,32 +241,61 @@ class HybridEmbed(nn.Module):
|
||||
training = backbone.training
|
||||
if training:
|
||||
backbone.eval()
|
||||
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
||||
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
|
||||
if isinstance(o, (list, tuple)):
|
||||
o = o[-1] # last feature if backbone outputs list/tuple of features
|
||||
feature_size = o.shape[-2:]
|
||||
feature_dim = o.shape[1]
|
||||
backbone.train(training)
|
||||
else:
|
||||
feature_size = to_2tuple(feature_size)
|
||||
feature_dim = self.backbone.feature_info.channels()[-1]
|
||||
if hasattr(self.backbone, 'feature_info'):
|
||||
feature_dim = self.backbone.feature_info.channels()[-1]
|
||||
else:
|
||||
feature_dim = self.backbone.num_features
|
||||
self.num_patches = feature_size[0] * feature_size[1]
|
||||
self.proj = nn.Linear(feature_dim, embed_dim)
|
||||
self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.backbone(x)[-1]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.proj(x)
|
||||
x = self.backbone(x)
|
||||
if isinstance(x, (list, tuple)):
|
||||
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
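With the projection now a 1x1 conv, any backbone that returns a 2D feature map (or a list/tuple ending in one) can be dropped in. A hedged sketch mirroring the R50+ViT-B/16 hybrid definitions (the backbone arguments follow those defs and are assumptions as far as this excerpt goes):

import torch
from timm.models.resnetv2 import ResNetV2, StdConv2dSame

backbone = ResNetV2(layers=(3, 4, 9), num_classes=0, global_pool='', preact=False,
                    stem_type='same', conv_layer=StdConv2dSame)
embed = HybridEmbed(backbone, img_size=224, embed_dim=768)
tokens = embed(torch.zeros(1, 3, 224, 224))   # (1, 196, 768), i.e. a 14x14 patch grid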
|
||||
class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
""" Vision Transformer
|
||||
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
||||
https://arxiv.org/abs/2010.11929
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
||||
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
if hybrid_backbone is not None:
|
||||
self.patch_embed = HybridEmbed(
|
||||
@ -231,12 +317,18 @@ class VisionTransformer(nn.Module):
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
|
||||
#self.repr = nn.Linear(embed_dim, representation_size)
|
||||
#self.repr_act = nn.Tanh()
|
||||
# Representation layer
|
||||
if representation_size:
|
||||
self.num_features = representation_size
|
||||
self.pre_logits = nn.Sequential(OrderedDict([
|
||||
('fc', nn.Linear(embed_dim, representation_size)),
|
||||
('act', nn.Tanh())
|
||||
]))
|
||||
else:
|
||||
self.pre_logits = nn.Identity()
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
@ -274,8 +366,9 @@ class VisionTransformer(nn.Module):
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
return x[:, 0]
|
||||
x = self.norm(x)[:, 0]
|
||||
x = self.pre_logits(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
@ -283,146 +376,412 @@ class VisionTransformer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def _conv_filter(state_dict, patch_size=16):
|
||||
class DistilledVisionTransformer(VisionTransformer):
|
||||
""" Vision Transformer with distillation token.
|
||||
|
||||
Paper: `Training data-efficient image transformers & distillation through attention` -
|
||||
https://arxiv.org/abs/2012.12877
|
||||
|
||||
This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
||||
num_patches = self.patch_embed.num_patches
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
||||
|
||||
trunc_normal_(self.dist_token, std=.02)
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
self.head_dist.apply(self._init_weights)
|
||||
|
||||
def forward_features(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
dist_token = self.dist_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
||||
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x)
|
||||
return x[:, 0], x[:, 1]
|
||||
|
||||
def forward(self, x):
|
||||
x, x_dist = self.forward_features(x)
|
||||
x = self.head(x)
|
||||
x_dist = self.head_dist(x_dist)
|
||||
if self.training:
|
||||
return x, x_dist
|
||||
else:
|
||||
# during inference, return the average of both classifier predictions
|
||||
return (x + x_dist) / 2
|
||||
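The distilled variants therefore behave differently in train and eval mode; a small usage sketch (model name taken from the deit cfgs above):

import torch
import timm

model = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=True)
model.eval()
logits = model(torch.randn(1, 3, 224, 224))               # (1, 1000), average of cls and dist heads
model.train()
cls_out, dist_out = model(torch.randn(1, 3, 224, 224))    # separate heads for the distillation loss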
|
||||
|
||||
def resize_pos_embed(posemb, posemb_new):
|
||||
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
||||
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
||||
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if True:
|
||||
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
|
||||
ntok_new -= 1
|
||||
else:
|
||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
gs_new = int(math.sqrt(ntok_new))
|
||||
_logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
|
||||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
return posemb
|
||||
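For example, adapting a ViT-B/16 checkpoint from 224 (14x14 grid plus class token) to 384 (24x24 grid): the grid portion is bilinearly interpolated and the class-token embedding is carried over unchanged.

import torch

posemb_224 = torch.randn(1, 1 + 14 * 14, 768)   # pretrained pos_embed, 224 / 16 = 14
posemb_384 = torch.zeros(1, 1 + 24 * 24, 768)   # target shape, 384 / 16 = 24
resized = resize_pos_embed(posemb_224, posemb_384)
assert resized.shape == (1, 577, 768)
assert torch.equal(resized[:, :1], posemb_224[:, :1])   # class token untouched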
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k:
|
||||
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
||||
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
||||
# For old models that I trained prior to conv based patchification
|
||||
O, I, H, W = model.patch_embed.proj.weight.shape
|
||||
v = v.reshape(O, -1, H, W)
|
||||
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
||||
# To resize pos embedding when using model at different size from pretrained weights
|
||||
v = resize_pos_embed(v, model.pos_embed)
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch16_224(pretrained=False, **kwargs):
|
||||
if pretrained:
|
||||
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
||||
kwargs.setdefault('qk_scale', 768 ** -0.5)
|
||||
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_small_patch16_224']
|
||||
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
|
||||
default_cfg = default_cfgs[variant]
|
||||
default_num_classes = default_cfg['num_classes']
|
||||
default_img_size = default_cfg['input_size'][-1]
|
||||
|
||||
num_classes = kwargs.pop('num_classes', default_num_classes)
|
||||
img_size = kwargs.pop('img_size', default_img_size)
|
||||
repr_size = kwargs.pop('representation_size', None)
|
||||
if repr_size is not None and num_classes != default_num_classes:
|
||||
# Remove representation layer if fine-tuning. This may not always be the desired action,
|
||||
# but it seems better than doing nothing by default when fine-tuning. Perhaps a better interface?
|
||||
_logger.warning("Removing representation layer for fine-tuning.")
|
||||
repr_size = None
|
||||
|
||||
model_cls = DistilledVisionTransformer if distilled else VisionTransformer
|
||||
model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
||||
model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
|
||||
filter_fn=partial(checkpoint_filter_fn, model=model))
|
||||
return model
|
||||
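One consequence for the new in21k weights: the pre-logits representation layer is kept when a model is created with its default 21843-class head, and removed (with the warning above) when a different num_classes is requested for fine-tuning; roughly:

import timm

m = timm.create_model('vit_base_patch16_224_in21k', pretrained=True)
print(type(m.pre_logits))    # nn.Sequential (fc + tanh), representation layer active

m = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=10)
print(type(m.pre_logits))    # nn.Identity, representation layer dropped for fine-tuning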
|
||||
|
||||
@register_model
|
||||
def vit_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" My custom 'small' ViT model. Depth=8, heads=8= mlp_ratio=3."""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
|
||||
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
|
||||
if pretrained:
|
||||
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
||||
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
|
||||
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_224(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
||||
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_224(pretrained=False, **kwargs):
|
||||
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_patch16_384']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
||||
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_patch32_384']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
||||
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_224(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_large_patch16_224']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
||||
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch32_224(pretrained=False, **kwargs):
|
||||
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_large_patch16_384']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
||||
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch32_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_large_patch32_384']
|
||||
if pretrained:
|
||||
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
|
||||
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
||||
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||
"""
|
||||
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
|
||||
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
def vit_huge_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(patch_size=16, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
    model.default_cfg = default_cfgs['vit_huge_patch16_224']

def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch32_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=32, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, **kwargs)
    model.default_cfg = default_cfgs['vit_huge_patch32_384']

def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
    """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model_kwargs = dict(
        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
    model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model_kwargs = dict(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
    model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
    return model

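# --- Editorial usage sketch (not part of the original diff) ---
# The *_in21k entrypoints above set representation_size, keeping the pre-logits layer; they are
# generally fine-tuned on a downstream dataset rather than used as-is. The class count below is a
# hypothetical example value:
#
#   import timm
#   model = timm.create_model('vit_base_patch16_224_in21k', pretrained=True, num_classes=37)
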
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    model_kwargs = dict(
        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
        representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
    return model

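# Editorial note: the layers=(3, 4, 9) config above appears to fold the usual 4th ResNet stage into
# stage 3, so the backbone emits a stride-16 feature map; the ViT then tokenizes that feature map
# (effectively one token per spatial position) instead of raw 16x16 pixel patches.
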
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
    return model

@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
    pretrained_backbone = kwargs.get('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
    model = VisionTransformer(
        img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model.default_cfg = default_cfgs['vit_small_resnet26d_224']
    """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
    """
    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
    pretrained_backbone = kwargs.get('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3])
    model = VisionTransformer(
        img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model.default_cfg = default_cfgs['vit_small_resnet50d_s3_224']
    """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs):
    pretrained_backbone = kwargs.get('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
    model = VisionTransformer(
        img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
    model.default_cfg = default_cfgs['vit_base_resnet26d_224']
    """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
    """
    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs):
    pretrained_backbone = kwargs.get('pretrained_backbone', True)  # default to True for now, for testing
    backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4])
    model = VisionTransformer(
        img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
    model.default_cfg = default_cfgs['vit_base_resnet50d_224']
    """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
    return model

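# --- Editorial usage sketch (not part of the original diff) ---
# The custom hybrids above wrap a timm backbone built with features_only=True; out_indices=[4]
# keeps only the final stride-32 feature map, while out_indices=[3] stops at stride 16. A small
# sketch for inspecting such a backbone:
#
#   import timm
#   backbone = timm.create_model('resnet26d', pretrained=False, features_only=True, out_indices=[4])
#   print(backbone.feature_info.channels())  # expected: [2048] for the final bottleneck stage
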
@register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
    return model


@register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer(
        'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
    return model

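# --- Editorial usage sketch (not part of the original diff) ---
# The *_distilled entrypoints pass distilled=True, adding DeiT's extra distillation token and head
# next to the class token:
#
#   import timm
#   model = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=True)
#   model.eval()  # in eval mode the class and distillation head outputs are typically averaged
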
@ -1 +1 @@
__version__ = '0.3.4'
__version__ = '0.4.0'

train.py
@ -28,7 +28,7 @@ import torch.nn as nn
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP

from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@ -64,8 +64,14 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
parser.add_argument('data_dir', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--train-split', metavar='NAME', default='train',
                    help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
                    help='dataset validation split (default: validation)')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "countception")')
parser.add_argument('--pretrained', action='store_true', default=False,
@ -76,8 +82,8 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
parser.add_argument('--no-resume-opt', action='store_true', default=False,
                    help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--num-classes', type=int, default=None, metavar='N',
                    help='number of label classes (Model default if None)')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
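# --- Editorial usage sketch (not part of the original diff) ---
# With the new dataset arguments above, a training run might be launched as follows
# (paths, dataset name and hyper-parameters are hypothetical):
#
#   python train.py /data/imagenet --dataset '' --train-split train --val-split validation \
#       --model resnet50 -b 128 --pretrained
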
@ -331,6 +337,9 @@ def main():
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.local_rank == 0:
        _logger.info('Model %s created, param count: %d' %
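# Editorial note: with --num-classes now defaulting to None, the script falls back to the model's
# own num_classes attribute unless a value is supplied on the command line or via a config file.
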
@ -437,19 +446,10 @@ def main():
    _logger.info('Scheduled epochs: {}'.format(num_epochs))

    # create the train and eval datasets
    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        _logger.error('Training folder does not exist at: {}'.format(train_dir))
        exit(1)
    dataset_train = Dataset(train_dir)

    eval_dir = os.path.join(args.data, 'val')
    if not os.path.isdir(eval_dir):
        eval_dir = os.path.join(args.data, 'validation')
        if not os.path.isdir(eval_dir):
            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
            exit(1)
    dataset_eval = Dataset(eval_dir)
    dataset_train = create_dataset(
        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
    dataset_eval = create_dataset(
        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

    # setup mixup / cutmix
    collate_fn = None
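# --- Editorial usage sketch (not part of the original diff) ---
# The create_dataset factory used above can also be called directly; an empty name selects the
# default folder/tar parser. Paths and batch size here are hypothetical:
#
#   from timm.data import create_dataset
#   dataset_train = create_dataset('', root='/data/imagenet', split='train', is_training=True, batch_size=128)
#   dataset_eval = create_dataset('', root='/data/imagenet', split='validation', is_training=False, batch_size=128)
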
@ -553,10 +553,10 @@ def main():

    try:
        for epoch in range(start_epoch, num_epochs):
            if args.distributed:
            if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                loader_train.sampler.set_epoch(epoch)

            train_metrics = train_epoch(
            train_metrics = train_one_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
@ -594,7 +594,7 @@ def main():
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


def train_epoch(
def train_one_epoch(
        epoch, model, loader, optimizer, loss_fn, args,
        lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
        loss_scaler=None, model_ema=None, mixup_fn=None):

validate.py
@ -20,7 +20,7 @@ from collections import OrderedDict
from contextlib import suppress

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config, RealLabelsImagenet
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy

has_apex = False
@ -44,7 +44,11 @@ _logger = logging.getLogger('validate')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
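# --- Editorial usage sketch (not part of the original diff) ---
# With the new --dataset/--split arguments, a validation run might look like this
# (path and model choice are hypothetical):
#
#   python validate.py /data/imagenet --split validation --model vit_base_patch16_384 --pretrained
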
@ -62,7 +66,7 @@ parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD'
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
@ -133,6 +137,9 @@ def validate(args):
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)
@ -159,10 +166,9 @@ def validate(args):

    criterion = nn.CrossEntropyLoss().cuda()

    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing, class_map=args.class_map)
    dataset = create_dataset(
        root=args.data, name=args.dataset, split=args.split,
        load_bytes=args.tf_preprocessing, class_map=args.class_map)

    if args.valid_labels:
        with open(args.valid_labels, 'r') as f:
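# --- Editorial usage sketch (not part of the original diff) ---
# Mirroring the call above, the same dataset can be constructed outside the script; the root path
# here is hypothetical:
#
#   from timm.data import create_dataset
#   dataset = create_dataset(name='', root='/data/imagenet', split='validation')
#   print(len(dataset))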