Mirror of https://github.com/huggingface/pytorch-image-models.git
Merge pull request #1327 from rwightman/edgenext_csp_and_more
EdgeNeXt, additional DarkNets, and more
Commit: 1ccce50d48

benchmark.py (69 changed lines)
@@ -6,24 +6,23 @@ An inference and train step benchmark script for timm models.
 
 Hacked together by Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import os
 import csv
 import json
-import time
 import logging
-import torch
-import torch.nn as nn
-import torch.nn.parallel
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from functools import partial
 
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+
+from timm.data import resolve_data_config
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.data import resolve_data_config
 from timm.utils import setup_default_logging, set_jit_fuser
 
 
 has_apex = False
 try:
     from apex import amp
@@ -71,6 +70,8 @@ parser.add_argument('--bench', default='both', type=str,
                     help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
 parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
+parser.add_argument('--no-retry', action='store_true', default=False,
+                    help='Do not decay batch size and retry on error.')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--num-warm-iter', default=10, type=int,
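In practice the new flag turns an out-of-memory style failure into an immediate error result instead of triggering the batch-decay retry loop in _try_run further down. A usage sketch (model name illustrative):

    python benchmark.py --model resnet50 --bench inference --batch-size 512 --no-retry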
@@ -169,10 +170,9 @@ def resolve_precision(precision: str):
 
 
 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
-    macs, _ = get_model_profile(
+    _, macs, _ = get_model_profile(
         model=model,
-        input_res=(batch_size,) + input_size,  # input shape or input to the input_constructor
-        input_constructor=None,  # if specified, a constructor taking input_res is used as input to the model
+        input_shape=(batch_size,) + input_size,  # input shape/resolution
         print_profile=detailed,  # prints the model graph with the measured profile attached to each module
         detailed=detailed,  # print the detailed profile
         warm_up=10,  # the number of warm-ups before measuring the time of each module
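The unpacking change tracks DeepSpeed's flops profiler API: recent versions of get_model_profile take input_shape and return a (flops, macs, params) triple, where older versions took input_res/input_constructor and returned (macs, params). A minimal sketch against the newer API (guarded, since DeepSpeed is an optional dependency here; the toy model is illustrative):

    try:
        from deepspeed.profiling.flops_profiler import get_model_profile
    except ImportError:
        get_model_profile = None

    if get_model_profile is not None:
        import torch.nn as nn
        # Tiny stand-in model just to exercise the profiler call shape
        model = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10))
        flops, macs, params = get_model_profile(
            model=model,
            input_shape=(1, 3, 224, 224),  # batch + input size, as in profile_deepspeed above
            print_profile=False,
            detailed=False,
            as_string=False,
        )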
@@ -197,8 +197,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
 
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
-            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            self,
+            model_name,
+            detail=False,
+            device='cuda',
+            torchscript=False,
+            aot_autograd=False,
+            precision='float32',
+            fuser='',
+            num_warm_iter=10,
+            num_bench_iter=50,
+            use_train_size=False,
+            **kwargs
+    ):
         self.model_name = model_name
         self.detail = detail
         self.device = device
@@ -225,11 +236,12 @@ class BenchmarkRunner:
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
 
-        data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.scripted = False
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
+        data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.input_size = data_config['input_size']
         self.batch_size = kwargs.pop('batch_size', 256)
 
@@ -255,7 +267,13 @@ class BenchmarkRunner:
 
 class InferenceBenchmarkRunner(BenchmarkRunner):
 
-    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self,
+            model_name,
+            device='cuda',
+            torchscript=False,
+            **kwargs
+    ):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.eval()
 
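The runners are normally driven through benchmark() and _try_run, but they can also be constructed directly. A hedged sketch; the run() entrypoint and result field names here are assumptions based on how this script reports results, not part of this hunk:

    runner = InferenceBenchmarkRunner(
        model_name='resnet50',
        device='cuda',
        precision='float32',
        batch_size=64,  # consumed from **kwargs by BenchmarkRunner
    )
    results = runner.run()
    print(results.get('samples_per_sec'))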
@@ -324,7 +342,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
 
 class TrainBenchmarkRunner(BenchmarkRunner):
 
-    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self,
+            model_name,
+            device='cuda',
+            torchscript=False,
+            **kwargs
+    ):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.train()
 
@@ -491,7 +515,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
     return max(0, int(out_batch_size))
 
 
-def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
+def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
     batch_size = initial_batch_size
     results = dict()
     error_str = 'Unknown'
@@ -506,8 +530,11 @@ def _try_run(model_name, bench_fn, bench_kwargs, initial_batch_size, no_batch_size_retry=False):
             if 'channels_last' in error_str:
                 _logger.error(f'{model_name} not supported in channels_last, skipping.')
                 break
-            _logger.warning(f'"{error_str}" while running benchmark. Reducing batch size to {batch_size} for retry.')
+            _logger.error(f'"{error_str}" while running benchmark.')
+            if no_batch_size_retry:
+                break
             batch_size = decay_batch_exp(batch_size)
+            _logger.warning(f'Reducing batch size to {batch_size} for retry.')
     results['error'] = error_str
     return results
 
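Only the tail of decay_batch_exp is visible in this hunk; it halves the batch and snaps it to a divisor, and _try_run walks down that schedule until a run succeeds or, with no_batch_size_retry, gives up immediately. An illustrative sketch of the decay schedule (a simplified stand-in, not the exact function body):

    def decay_batch_sketch(batch_size, factor=0.5, divisor=16):
        # Halve, then round up to a multiple of `divisor` while still large enough.
        out = batch_size * factor
        if out > divisor:
            out = (int(out) + divisor - 1) // divisor * divisor
        return max(0, int(out))

    sizes, bs = [], 256
    while bs > 0:
        bs = decay_batch_sketch(bs)
        sizes.append(bs)
    print(sizes)  # [128, 64, 32, 16, 8, 4, 2, 1, 0]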
@@ -549,7 +576,13 @@ def benchmark(args):
 
     model_results = OrderedDict(model=model)
     for prefix, bench_fn in zip(prefixes, bench_fns):
-        run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
+        run_results = _try_run(
+            model,
+            bench_fn,
+            bench_kwargs=bench_kwargs,
+            initial_batch_size=batch_size,
+            no_batch_size_retry=args.no_retry,
+        )
         if prefix and 'error' not in run_results:
             run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
         model_results.update(run_results)
timm/data/__init__.py
@@ -6,7 +6,8 @@ from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
-from .parsers import create_parser
+from .parsers import create_parser,\
+    get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
 from .transforms import *
 from .transforms_factory import create_transform
timm/data/config.py
@@ -64,11 +64,15 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
         new_config['std'] = default_cfg['std']
 
     # resolve default crop percentage
-    new_config['crop_pct'] = DEFAULT_CROP_PCT
+    crop_pct = DEFAULT_CROP_PCT
     if 'crop_pct' in args and args['crop_pct'] is not None:
-        new_config['crop_pct'] = args['crop_pct']
-    elif 'crop_pct' in default_cfg:
-        new_config['crop_pct'] = default_cfg['crop_pct']
+        crop_pct = args['crop_pct']
+    else:
+        if use_test_size and 'test_crop_pct' in default_cfg:
+            crop_pct = default_cfg['test_crop_pct']
+        elif 'crop_pct' in default_cfg:
+            crop_pct = default_cfg['crop_pct']
+    new_config['crop_pct'] = crop_pct
 
     if verbose:
         _logger.info('Data processing configuration for current model + dataset:')
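The practical effect: models that publish a separate test-time crop (e.g. test_crop_pct=1.0 alongside a larger test_input_size) now resolve it automatically when use_test_size is set. A sketch of the resolution order; the config values are illustrative, DEFAULT_CROP_PCT is timm's global 0.875 default:

    DEFAULT_CROP_PCT = 0.875

    def resolve_crop_pct(args, default_cfg, use_test_size=False):
        # Mirrors the logic above: explicit arg > test_crop_pct (test size only)
        # > crop_pct > global default.
        crop_pct = DEFAULT_CROP_PCT
        if args.get('crop_pct') is not None:
            crop_pct = args['crop_pct']
        elif use_test_size and 'test_crop_pct' in default_cfg:
            crop_pct = default_cfg['test_crop_pct']
        elif 'crop_pct' in default_cfg:
            crop_pct = default_cfg['crop_pct']
        return crop_pct

    cfg = dict(crop_pct=0.95, test_crop_pct=1.0)
    assert resolve_crop_pct({}, cfg) == 0.95
    assert resolve_crop_pct({}, cfg, use_test_size=True) == 1.0
    assert resolve_crop_pct({'crop_pct': 0.9}, cfg, use_test_size=True) == 0.9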
timm/data/dataset_factory.py
@@ -26,8 +26,8 @@ _TORCH_BASIC_DS = dict(
     kmnist=KMNIST,
     fashion_mnist=FashionMNIST,
 )
-_TRAIN_SYNONYM = {'train', 'training'}
-_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
+_TRAIN_SYNONYM = dict(train=None, training=None)
+_EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)
 
 
 def _search_split(root, split):
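The switch from set literals to dicts keeps membership tests identical, since `in` on a dict checks its keys (and the dict form additionally preserves an explicit ordering). A small sketch:

    _TRAIN_SYNONYM = dict(train=None, training=None)
    _EVAL_SYNONYM = dict(val=None, valid=None, validation=None, eval=None, evaluation=None)

    def is_train_split(split: str) -> bool:
        # `in` on a dict tests keys, so behavior matches the old set version
        return split in _TRAIN_SYNONYM

    assert is_train_split('training')
    assert 'eval' in _EVAL_SYNONYM and 'test' not in _EVAL_SYNONYM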
timm/data/parsers/__init__.py
@@ -1 +1,2 @@
 from .parser_factory import create_parser
+from .img_extensions import *
timm/data/parsers/constants.py (deleted)
@@ -1 +0,0 @@
-IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')
timm/data/parsers/img_extensions.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+from copy import deepcopy
+
+__all__ = ['get_img_extensions', 'is_img_extension', 'set_img_extensions', 'add_img_extensions', 'del_img_extensions']
+
+
+IMG_EXTENSIONS = ('.png', '.jpg', '.jpeg')  # singleton, kept public for bwd compat use
+_IMG_EXTENSIONS_SET = set(IMG_EXTENSIONS)  # set version, private, kept in sync
+
+
+def _set_extensions(extensions):
+    global IMG_EXTENSIONS
+    global _IMG_EXTENSIONS_SET
+    dedupe = set()  # NOTE de-duping tuple while keeping original order
+    IMG_EXTENSIONS = tuple(x for x in extensions if x not in dedupe and not dedupe.add(x))
+    _IMG_EXTENSIONS_SET = set(extensions)
+
+
+def _valid_extension(x: str):
+    return x and isinstance(x, str) and len(x) >= 2 and x.startswith('.')
+
+
+def is_img_extension(ext):
+    return ext in _IMG_EXTENSIONS_SET
+
+
+def get_img_extensions(as_set=False):
+    return deepcopy(_IMG_EXTENSIONS_SET if as_set else IMG_EXTENSIONS)
+
+
+def set_img_extensions(extensions):
+    assert len(extensions)
+    for x in extensions:
+        assert _valid_extension(x)
+    _set_extensions(extensions)
+
+
+def add_img_extensions(ext):
+    if not isinstance(ext, (list, tuple, set)):
+        ext = (ext,)
+    for x in ext:
+        assert _valid_extension(x)
+    extensions = IMG_EXTENSIONS + tuple(ext)
+    _set_extensions(extensions)
+
+
+def del_img_extensions(ext):
+    if not isinstance(ext, (list, tuple, set)):
+        ext = (ext,)
+    extensions = tuple(x for x in IMG_EXTENSIONS if x not in ext)
+    _set_extensions(extensions)
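A short usage sketch of the new API (the '.webp' value is illustrative):

    from timm.data.parsers.img_extensions import (
        add_img_extensions, del_img_extensions, get_img_extensions, is_img_extension)

    print(get_img_extensions())        # ('.png', '.jpg', '.jpeg')
    add_img_extensions('.webp')        # must be a '.'-prefixed string per _valid_extension
    assert is_img_extension('.webp')
    print(get_img_extensions())        # ('.png', '.jpg', '.jpeg', '.webp')
    del_img_extensions('.webp')
    assert not is_img_extension('.webp')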
timm/data/parsers/parser_factory.py
@@ -1,7 +1,6 @@
 import os
 
 from .parser_image_folder import ParserImageFolder
 from .parser_image_tar import ParserImageTar
 from .parser_image_in_tar import ParserImageInTar
 
-
timm/data/parsers/parser_image_folder.py
@@ -6,15 +6,35 @@ on the folder hierarchy, just leaf folders by default.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import os
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 from timm.utils.misc import natural_key
 
-from .parser import Parser
 from .class_map import load_class_map
-from .constants import IMG_EXTENSIONS
+from .img_extensions import get_img_extensions
+from .parser import Parser
 
 
-def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
+def find_images_and_targets(
+        folder: str,
+        types: Optional[Union[List, Tuple, Set]] = None,
+        class_to_idx: Optional[Dict] = None,
+        leaf_name_only: bool = True,
+        sort: bool = True
+):
+    """ Walk folder recursively to discover images and map them to classes by folder names.
+
+    Args:
+        folder: root of folder to recursively search
+        types: types (file extensions) to search for in path
+        class_to_idx: specify mapping for class (folder name) to class index if set
+        leaf_name_only: use only leaf-name of folder walk for class names
+        sort: re-sort found images by name (for consistent ordering)
+
+    Returns:
+        A list of image and target tuples, class_to_idx mapping
+    """
+    types = get_img_extensions(as_set=True) if not types else set(types)
     labels = []
     filenames = []
     for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
@@ -51,7 +71,8 @@ class ParserImageFolder(Parser):
         self.samples, self.class_to_idx = find_images_and_targets(root, class_to_idx=class_to_idx)
         if len(self.samples) == 0:
             raise RuntimeError(
-                f'Found 0 images in subfolders of {root}. Supported image extensions are {", ".join(IMG_EXTENSIONS)}')
+                f'Found 0 images in subfolders of {root}. '
+                f'Supported image extensions are {", ".join(get_img_extensions())}')
 
     def __getitem__(self, index):
         path, target = self.samples[index]
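Usage sketch for the retyped function (directory layout hypothetical); passing types is now optional and defaults to the global extension set:

    from timm.data.parsers.parser_image_folder import find_images_and_targets

    # Hypothetical layout: data/train/cat/*.jpg, data/train/dog/*.png
    samples, class_to_idx = find_images_and_targets('data/train')
    print(class_to_idx)            # e.g. {'cat': 0, 'dog': 1}
    path, target = samples[0]      # (file path, class index)

    # Restrict to an explicit extension set instead of get_img_extensions()
    jpeg_only, _ = find_images_and_targets('data/train', types=('.jpg', '.jpeg'))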
timm/data/parsers/parser_image_in_tar.py
@@ -9,20 +9,20 @@ Labels are based on the combined folder and/or tar name structure.
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import os
-import tarfile
-import pickle
 import logging
-import numpy as np
+import os
+import pickle
+import tarfile
 from glob import glob
-from typing import List, Dict
+from typing import List, Tuple, Dict, Set, Optional, Union
+
+import numpy as np
 
 from timm.utils.misc import natural_key
 
-from .parser import Parser
 from .class_map import load_class_map
-from .constants import IMG_EXTENSIONS
+from .img_extensions import get_img_extensions
+from .parser import Parser
 
 _logger = logging.getLogger(__name__)
 CACHE_FILENAME_SUFFIX = '_tarinfos.pickle'
@@ -39,7 +39,7 @@ class TarState:
         self.tf = None
 
 
-def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions=IMG_EXTENSIONS):
+def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
     sample_count = 0
     for i, ti in enumerate(tf):
         if not ti.isfile():
@@ -60,7 +60,14 @@ def _extract_tarinfo(tf: tarfile.TarFile, parent_info: Dict, extensions: Set[str]):
     return sample_count
 
 
-def extract_tarinfos(root, class_name_to_idx=None, cache_tarinfo=None, extensions=IMG_EXTENSIONS, sort=True):
+def extract_tarinfos(
+        root,
+        class_name_to_idx: Optional[Dict] = None,
+        cache_tarinfo: Optional[bool] = None,
+        extensions: Optional[Union[List, Tuple, Set]] = None,
+        sort: bool = True
+):
+    extensions = get_img_extensions(as_set=True) if not extensions else set(extensions)
     root_is_tar = False
     if os.path.isfile(root):
         assert os.path.splitext(root)[-1].lower() == '.tar'
@@ -176,8 +183,8 @@ class ParserImageInTar(Parser):
         self.samples, self.targets, self.class_name_to_idx, tarfiles = extract_tarinfos(
             self.root,
             class_name_to_idx=class_name_to_idx,
-            cache_tarinfo=cache_tarinfo,
-            extensions=IMG_EXTENSIONS)
+            cache_tarinfo=cache_tarinfo
+        )
         self.class_idx_to_name = {v: k for k, v in self.class_name_to_idx.items()}
         if len(tarfiles) == 1 and tarfiles[0][0] is None:
             self.root_is_tar = True
timm/data/parsers/parser_image_tar.py
@@ -8,13 +8,15 @@ Hacked together by / Copyright 2020 Ross Wightman
 import os
 import tarfile
 
-from .parser import Parser
-from .class_map import load_class_map
-from .constants import IMG_EXTENSIONS
 from timm.utils.misc import natural_key
 
+from .class_map import load_class_map
+from .img_extensions import get_img_extensions
+from .parser import Parser
+
 
 def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
+    extensions = get_img_extensions(as_set=True)
     files = []
     labels = []
     for ti in tarfile.getmembers():
@@ -23,7 +25,7 @@ def extract_tarinfo(tarfile, class_to_idx=None, sort=True):
         dirname, basename = os.path.split(ti.path)
         label = os.path.basename(dirname)
         ext = os.path.splitext(basename)[1]
-        if ext.lower() in IMG_EXTENSIONS:
+        if ext.lower() in extensions:
             files.append(ti)
             labels.append(label)
         if class_to_idx is None:
timm/models/__init__.py
@@ -12,6 +12,7 @@ from .deit import *
 from .densenet import *
 from .dla import *
 from .dpn import *
+from .edgenext import *
 from .efficientnet import *
 from .ghostnet import *
 from .gluon_resnet import *
timm/models/convnext.py
@@ -17,9 +17,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .fx_features import register_notrace_module
 from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
-from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp
+from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
 from .registry import register_model
 
 
@@ -44,6 +43,7 @@ default_cfgs = dict(
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
 
     convnext_nano_hnf=_cfg(url=''),
+    convnext_nano_ols=_cfg(url=''),
     convnext_tiny_hnf=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
         crop_pct=0.95),
@@ -88,35 +88,6 @@ default_cfgs = dict(
 )
 
 
-def _is_contiguous(tensor: torch.Tensor) -> bool:
-    # jit is oh so lovely :/
-    # if torch.jit.is_tracing():
-    #     return True
-    if torch.jit.is_scripting():
-        return tensor.is_contiguous()
-    else:
-        return tensor.is_contiguous(memory_format=torch.contiguous_format)
-
-
-@register_notrace_module
-class LayerNorm2d(nn.LayerNorm):
-    r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
-    """
-
-    def __init__(self, normalized_shape, eps=1e-6):
-        super().__init__(normalized_shape, eps=eps)
-
-    def forward(self, x) -> torch.Tensor:
-        if _is_contiguous(x):
-            return F.layer_norm(
-                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
-        else:
-            s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
-            x = (x - u) * torch.rsqrt(s + self.eps)
-            x = x * self.weight[:, None, None] + self.bias[:, None, None]
-            return x
-
-
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
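LayerNorm2d and its _is_contiguous helper were not dropped; they moved into the shared layers package (note the new `from .layers import ... LayerNorm2d` above). A sketch of the replacement import and its channels-first behavior:

    import torch
    from timm.models.layers import LayerNorm2d  # previously defined locally in convnext.py

    norm = LayerNorm2d(64)
    x = torch.randn(2, 64, 7, 7)   # N, C, H, W (channels-first)
    y = norm(x)                    # normalizes over C at each spatial position
    assert y.shape == x.shape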
@@ -133,16 +104,30 @@ class ConvNeXtBlock(nn.Module):
         ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
     """
 
-    def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None):
+    def __init__(
+            self,
+            dim,
+            dim_out=None,
+            stride=1,
+            mlp_ratio=4,
+            conv_mlp=False,
+            conv_bias=True,
+            ls_init_value=1e-6,
+            norm_layer=None,
+            act_layer=nn.GELU,
+            drop_path=0.,
+    ):
         super().__init__()
+        dim_out = dim_out or dim
         if not norm_layer:
             norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
-        self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
-        self.norm = norm_layer(dim)
-        self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
+
+        self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
+        self.norm = norm_layer(dim_out)
+        self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
@@ -158,6 +143,7 @@ class ConvNeXtBlock(nn.Module):
             x = x.permute(0, 3, 1, 2)
         if self.gamma is not None:
             x = x.mul(self.gamma.reshape(1, -1, 1, 1))
+
         x = self.drop_path(x) + shortcut
         return x
 
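The block now routes its depthwise conv through create_conv2d and accepts dim_out/stride/conv_bias. With the defaults the residual path is unchanged; a minimal shape check (import path as in this commit):

    import torch
    from timm.models.convnext import ConvNeXtBlock

    block = ConvNeXtBlock(dim=96, conv_mlp=True)
    x = torch.randn(1, 96, 56, 56)
    y = block(x)               # residual block: same shape in and out
    assert y.shape == x.shape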
@@ -165,25 +151,44 @@ class ConvNeXtBlock(nn.Module):
 class ConvNeXtStage(nn.Module):
 
     def __init__(
-            self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False,
-            norm_layer=None, cl_norm_layer=None, cross_stage=False):
+            self,
+            in_chs,
+            out_chs,
+            stride=2,
+            depth=2,
+            drop_path_rates=None,
+            ls_init_value=1.0,
+            conv_mlp=False,
+            conv_bias=True,
+            norm_layer=None,
+            norm_layer_cl=None
+    ):
         super().__init__()
         self.grad_checkpointing = False
 
         if in_chs != out_chs or stride > 1:
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
-                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride),
+                nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
             )
+            in_chs = out_chs
         else:
             self.downsample = nn.Identity()
 
-        dp_rates = dp_rates or [0.] * depth
-        self.blocks = nn.Sequential(*[ConvNeXtBlock(
-            dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
-            norm_layer=norm_layer if conv_mlp else cl_norm_layer)
-            for j in range(depth)]
-        )
+        drop_path_rates = drop_path_rates or [0.] * depth
+        stage_blocks = []
+        for i in range(depth):
+            stage_blocks.append(ConvNeXtBlock(
+                dim=in_chs,
+                dim_out=out_chs,
+                drop_path=drop_path_rates[i],
+                ls_init_value=ls_init_value,
+                conv_mlp=conv_mlp,
+                conv_bias=conv_bias,
+                norm_layer=norm_layer if conv_mlp else norm_layer_cl
+            ))
+            in_chs = out_chs
+        self.blocks = nn.Sequential(*stage_blocks)
 
     def forward(self, x):
         x = self.downsample(x)
@@ -210,41 +215,56 @@ class ConvNeXt(nn.Module):
     """
 
     def __init__(
-            self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4,
-            depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch',
-            head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0.,
+            self,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            output_stride=32,
+            depths=(3, 3, 9, 3),
+            dims=(96, 192, 384, 768),
+            ls_init_value=1e-6,
+            stem_type='patch',
+            stem_kernel_size=4,
+            stem_stride=4,
+            head_init_scale=1.,
+            head_norm_first=False,
+            conv_mlp=False,
+            conv_bias=True,
+            norm_layer=None,
+            drop_rate=0.,
+            drop_path_rate=0.,
     ):
         super().__init__()
         assert output_stride == 32
         if norm_layer is None:
             norm_layer = partial(LayerNorm2d, eps=1e-6)
-            cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
+            norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         else:
             assert conv_mlp,\
                 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
-            cl_norm_layer = norm_layer
+            norm_layer_cl = norm_layer
 
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.feature_info = []
 
-        # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
+        assert stem_type in ('patch', 'overlap')
         if stem_type == 'patch':
+            assert stem_kernel_size == stem_stride
+            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size),
+                nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
                 norm_layer(dims[0])
            )
-            curr_stride = patch_size
-            prev_chs = dims[0]
         else:
             self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1),
-                norm_layer(32),
-                nn.GELU(),
-                nn.Conv2d(32, 64, kernel_size=3, padding=1),
+                nn.Conv2d(
+                    in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
+                    padding=stem_kernel_size // 2, bias=conv_bias),
+                norm_layer(dims[0]),
             )
-            curr_stride = 2
-            prev_chs = 64
+        prev_chs = dims[0]
+        curr_stride = stem_stride
 
        self.stages = nn.Sequential()
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
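The two stem types differ only in overlap: 'patch' is a non-overlapping conv with stem_kernel_size == stem_stride (the ViT PatchEmbed form), while 'overlap' allows kernel > stride with half-kernel padding. A shape sketch using the convnext_nano_ols values below (kernel 9, stride 4):

    import torch
    import torch.nn as nn

    x = torch.randn(1, 3, 224, 224)
    patch = nn.Conv2d(3, 80, kernel_size=4, stride=4)                    # 'patch' stem conv
    overlap = nn.Conv2d(3, 80, kernel_size=9, stride=4, padding=9 // 2)  # 'overlap' stem conv
    print(patch(x).shape, overlap(x).shape)  # both torch.Size([1, 80, 56, 56])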
@@ -256,16 +276,23 @@ class ConvNeXt(nn.Module):
             curr_stride *= stride
             out_chs = dims[i]
             stages.append(ConvNeXtStage(
-                prev_chs, out_chs, stride=stride,
-                depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp,
-                norm_layer=norm_layer, cl_norm_layer=cl_norm_layer)
-            )
+                prev_chs,
+                out_chs,
+                stride=stride,
+                depth=depths[i],
+                drop_path_rates=dp_rates[i],
+                ls_init_value=ls_init_value,
+                conv_mlp=conv_mlp,
+                conv_bias=conv_bias,
+                norm_layer=norm_layer,
+                norm_layer_cl=norm_layer_cl
+            ))
             prev_chs = out_chs
             # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
             self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         self.num_features = prev_chs
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
@@ -327,10 +354,11 @@ class ConvNeXt(nn.Module):
 def _init_weights(module, name=None, head_init_scale=1.0):
     if isinstance(module, nn.Conv2d):
         trunc_normal_(module.weight, std=.02)
-        nn.init.constant_(module.bias, 0)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Linear):
         trunc_normal_(module.weight, std=.02)
-        nn.init.constant_(module.bias, 0)
+        nn.init.zeros_(module.bias)
         if name and 'head.' in name:
             module.weight.data.mul_(head_init_scale)
             module.bias.data.mul_(head_init_scale)
@@ -371,14 +399,25 @@ def _create_convnext(variant, pretrained=False, **kwargs):
 
 @register_model
 def convnext_nano_hnf(pretrained=False, **kwargs):
-    model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
+    model_args = dict(
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
     model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
     return model
 
 
+@register_model
+def convnext_nano_ols(pretrained=False, **kwargs):
+    model_args = dict(
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True,
+        conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
+    model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
+    return model
+
+
 @register_model
 def convnext_tiny_hnf(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
+    model_args = dict(
+        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
     model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
     return model
 
@@ -386,7 +425,7 @@ def convnext_tiny_hnf(pretrained=False, **kwargs):
 @register_model
 def convnext_tiny_hnfd(pretrained=False, **kwargs):
     model_args = dict(
-        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, stem_type='dual', **kwargs)
+        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
     model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
     return model
 
(One file's diff suppressed because it is too large.)
timm/models/deit.py
@@ -1,7 +1,10 @@
 """ DeiT - Data-efficient Image Transformers
 
 DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
-paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
 
 Modifications copyright 2021, Ross Wightman
 """
@@ -53,6 +56,46 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
         input_size=(3, 384, 384), crop_pct=1.0,
         classifier=('head', 'head_dist')),
+
+    'deit3_small_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
+    'deit3_small_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_base_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
+    'deit3_base_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_large_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
+    'deit3_large_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_huge_patch14_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
+
+    'deit3_small_patch16_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
+        crop_pct=1.0),
+    'deit3_small_patch16_384_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_base_patch16_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
+        crop_pct=1.0),
+    'deit3_base_patch16_384_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_large_patch16_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
+        crop_pct=1.0),
+    'deit3_large_patch16_384_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_huge_patch14_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
+        crop_pct=1.0),
 }
 
 
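Each of these configs resolves through timm's registry once the entrypoints further down are defined; a usage sketch:

    import timm

    model = timm.create_model('deit3_small_patch16_224', pretrained=True)
    model.eval()

    # The _in21ft1k variants share the architecture and differ only in weights
    model_21k = timm.create_model('deit3_small_patch16_224_in21ft1k', pretrained=True)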
@@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer):
         super().__init__(*args, **kwargs, weight_init='skip')
         assert self.global_pool in ('token',)
 
-        self.num_tokens = 2
+        self.num_prefix_tokens = 2
         self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
         self.distilled_training = False  # must set this True to train w/ distillation token
 
@@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
     model = _create_deit(
         'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
     return model
+
+
+@register_model
+def deit3_small_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_small_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_huge_patch14_224(pretrained=False, **kwargs):
+    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
timm/models/edgenext.py (new file, 557 lines)
@@ -0,0 +1,557 @@
+""" EdgeNeXt
+
+Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
+ - https://arxiv.org/abs/2206.10589
+
+Original code and weights from https://github.com/mmaaz60/EdgeNeXt
+
+Modifications and additions for timm by / Copyright 2022, Ross Wightman
+"""
+import math
+import torch
+from collections import OrderedDict
+from functools import partial
+from typing import Tuple
+
+from torch import nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .fx_features import register_notrace_module
+from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
+from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
+from .registry import register_model
+
+
+__all__ = ['EdgeNeXt']  # model_registry will add each entrypoint fn to this
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+        'crop_pct': 0.9, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = dict(
+    edgenext_xx_small=_cfg(
+        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"),
+    edgenext_x_small=_cfg(
+        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"),
+    # edgenext_small=_cfg(
+    #     url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"),
+    edgenext_small=_cfg(  # USI weights
+        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
+        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
+    ),
+
+    edgenext_small_rw=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
+        test_input_size=(3, 320, 320), test_crop_pct=1.0,
+    ),
+)
+
+
+@register_notrace_module  # reason: FX can't symbolically trace torch.arange in forward method
+class PositionalEncodingFourier(nn.Module):
+    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
+        super().__init__()
+        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
+        self.scale = 2 * math.pi
+        self.temperature = temperature
+        self.hidden_dim = hidden_dim
+        self.dim = dim
+
+    def forward(self, shape: Tuple[int, int, int]):
+        inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool)
+        y_embed = inv_mask.cumsum(1, dtype=torch.float32)
+        x_embed = inv_mask.cumsum(2, dtype=torch.float32)
+        eps = 1e-6
+        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device)
+        dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(),
+             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(),
+             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        pos = self.token_projection(pos)
+
+        return pos
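The module builds DETR-style fixed sine/cosine embeddings for a (batch, height, width) grid and projects them to dim channels with a 1x1 conv. A shape sketch (import path assumes this new module):

    from timm.models.edgenext import PositionalEncodingFourier

    pe = PositionalEncodingFourier(hidden_dim=32, dim=168)
    pos = pe((2, 7, 7))                  # (batch, H, W) grid
    assert pos.shape == (2, 168, 7, 7)   # -> (batch, dim, H, W) embedding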
+class ConvBlock(nn.Module):
+    def __init__(
+            self,
+            dim,
+            dim_out=None,
+            kernel_size=7,
+            stride=1,
+            conv_bias=True,
+            expand_ratio=4,
+            ls_init_value=1e-6,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU, drop_path=0.,
+    ):
+        super().__init__()
+        dim_out = dim_out or dim
+        self.shortcut_after_dw = stride > 1 or dim != dim_out
+
+        self.conv_dw = create_conv2d(
+            dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias)
+        self.norm = norm_layer(dim_out)
+        self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        shortcut = x
+        x = self.conv_dw(x)
+        if self.shortcut_after_dw:
+            shortcut = x
+
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.norm(x)
+        x = self.mlp(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = shortcut + self.drop_path(x)
+        return x
+
+
+class CrossCovarianceAttn(nn.Module):
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=False,
+            attn_drop=0.,
+            proj_drop=0.
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
+        q, k, v = qkv.unbind(0)
+
+        # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
+        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'temperature'}
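Cross-covariance attention (from XCiT) attends over channels rather than tokens: q and k are L2-normalized along the token axis and the per-head attention map is (C // heads) x (C // heads), so cost is linear in sequence length. A shape sketch:

    import torch
    from timm.models.edgenext import CrossCovarianceAttn

    attn = CrossCovarianceAttn(dim=168, num_heads=4, qkv_bias=True)
    x = torch.randn(2, 49, 168)   # B, N tokens, C channels
    y = attn(x)
    assert y.shape == x.shape     # attn map shape is per-head 42 x 42, not 49 x 49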
+class SplitTransposeBlock(nn.Module):
+    def __init__(
+            self,
+            dim,
+            num_scales=1,
+            num_heads=8,
+            expand_ratio=4,
+            use_pos_emb=True,
+            conv_bias=True,
+            qkv_bias=True,
+            ls_init_value=1e-6,
+            norm_layer=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU,
+            drop_path=0.,
+            attn_drop=0.,
+            proj_drop=0.
+    ):
+        super().__init__()
+        width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales)))
+        self.width = width
+        self.num_scales = max(1, num_scales - 1)
+
+        convs = []
+        for i in range(self.num_scales):
+            convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias))
+        self.convs = nn.ModuleList(convs)
+
+        self.pos_embd = None
+        if use_pos_emb:
+            self.pos_embd = PositionalEncodingFourier(dim=dim)
+        self.norm_xca = norm_layer(dim)
+        self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
+        self.xca = CrossCovarianceAttn(
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
+
+        self.norm = norm_layer(dim, eps=1e-6)
+        self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer)
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        shortcut = x
+
+        # scales code re-written for torchscript as per my res2net fixes -rw
+        spx = torch.split(x, self.width, 1)
+        spo = []
+        sp = spx[0]
+        for i, conv in enumerate(self.convs):
+            if i > 0:
+                sp = sp + spx[i]
+            sp = conv(sp)
+            spo.append(sp)
+        spo.append(spx[-1])
+        x = torch.cat(spo, 1)
+
+        # XCA
+        B, C, H, W = x.shape
+        x = x.reshape(B, C, H * W).permute(0, 2, 1)
+        if self.pos_embd is not None:
+            pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
+            x = x + pos_encoding
+        x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
+        x = x.reshape(B, H, W, C)
+
+        # Inverted Bottleneck
+        x = self.norm(x)
+        x = self.mlp(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = shortcut + self.drop_path(x)
+        return x
+
+
+class EdgeNeXtStage(nn.Module):
+    def __init__(
+            self,
+            in_chs,
+            out_chs,
+            stride=2,
+            depth=2,
+            num_global_blocks=1,
+            num_heads=4,
+            scales=2,
+            kernel_size=7,
+            expand_ratio=4,
+            use_pos_emb=False,
+            downsample_block=False,
+            conv_bias=True,
+            ls_init_value=1.0,
+            drop_path_rates=None,
+            norm_layer=LayerNorm2d,
+            norm_layer_cl=partial(nn.LayerNorm, eps=1e-6),
+            act_layer=nn.GELU
+    ):
+        super().__init__()
+        self.grad_checkpointing = False
+
+        if downsample_block or stride == 1:
+            self.downsample = nn.Identity()
+        else:
+            self.downsample = nn.Sequential(
+                norm_layer(in_chs),
+                nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias)
+            )
+            in_chs = out_chs
+
+        stage_blocks = []
+        for i in range(depth):
+            if i < depth - num_global_blocks:
+                stage_blocks.append(
+                    ConvBlock(
+                        dim=in_chs,
+                        dim_out=out_chs,
+                        stride=stride if downsample_block and i == 0 else 1,
+                        conv_bias=conv_bias,
+                        kernel_size=kernel_size,
+                        expand_ratio=expand_ratio,
+                        ls_init_value=ls_init_value,
+                        drop_path=drop_path_rates[i],
+                        norm_layer=norm_layer_cl,
+                        act_layer=act_layer,
+                    )
+                )
+            else:
+                stage_blocks.append(
+                    SplitTransposeBlock(
+                        dim=in_chs,
+                        num_scales=scales,
+                        num_heads=num_heads,
+                        expand_ratio=expand_ratio,
+                        use_pos_emb=use_pos_emb,
+                        conv_bias=conv_bias,
+                        ls_init_value=ls_init_value,
+                        drop_path=drop_path_rates[i],
+                        norm_layer=norm_layer_cl,
+                        act_layer=act_layer,
+                    )
+                )
+            in_chs = out_chs
+        self.blocks = nn.Sequential(*stage_blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x)
+        else:
+            x = self.blocks(x)
+        return x
+class EdgeNeXt(nn.Module):
+    def __init__(
+            self,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            dims=(24, 48, 88, 168),
+            depths=(3, 3, 9, 3),
+            global_block_counts=(0, 1, 1, 1),
+            kernel_sizes=(3, 5, 7, 9),
+            heads=(8, 8, 8, 8),
+            d2_scales=(2, 2, 3, 4),
+            use_pos_emb=(False, True, False, False),
+            ls_init_value=1e-6,
+            head_init_scale=1.,
+            expand_ratio=4,
+            downsample_block=False,
+            conv_bias=True,
+            stem_type='patch',
+            head_norm_first=False,
+            act_layer=nn.GELU,
+            drop_path_rate=0.,
+            drop_rate=0.,
+    ):
+        super().__init__()
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.drop_rate = drop_rate
+        norm_layer = partial(LayerNorm2d, eps=1e-6)
+        norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
+        self.feature_info = []
+
+        assert stem_type in ('patch', 'overlap')
+        if stem_type == 'patch':
+            self.stem = nn.Sequential(
+                nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias),
+                norm_layer(dims[0]),
+            )
+        else:
+            self.stem = nn.Sequential(
+                nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias),
+                norm_layer(dims[0]),
+            )
+
+        curr_stride = 4
+        stages = []
+        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+        in_chs = dims[0]
+        for i in range(4):
+            stride = 2 if curr_stride == 2 or i > 0 else 1
+            # FIXME support dilation / output_stride
+            curr_stride *= stride
+            stages.append(EdgeNeXtStage(
+                in_chs=in_chs,
+                out_chs=dims[i],
+                stride=stride,
+                depth=depths[i],
+                num_global_blocks=global_block_counts[i],
+                num_heads=heads[i],
+                drop_path_rates=dp_rates[i],
+                scales=d2_scales[i],
+                expand_ratio=expand_ratio,
+                kernel_size=kernel_sizes[i],
+                use_pos_emb=use_pos_emb[i],
+                ls_init_value=ls_init_value,
+                downsample_block=downsample_block,
+                conv_bias=conv_bias,
+                norm_layer=norm_layer,
+                norm_layer_cl=norm_layer_cl,
+                act_layer=act_layer,
+            ))
+            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+            in_chs = dims[i]
+            self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
+
+        self.stages = nn.Sequential(*stages)
+
+        self.num_features = dims[-1]
+        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
+        self.head = nn.Sequential(OrderedDict([
+            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
+            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+            ('drop', nn.Dropout(self.drop_rate)),
+            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
+
+        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^stem',
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)),  # blocks
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm_pre', (99999,))
+            ]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes=0, global_pool=None):
+        if global_pool is not None:
+            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        x = self.stages(x)
+        x = self.norm_pre(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
+        x = self.head.global_pool(x)
+        x = self.head.norm(x)
+        x = self.head.flatten(x)
+        x = self.head.drop(x)
+        return x if pre_logits else self.head.fc(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _init_weights(module, name=None, head_init_scale=1.0):
+    if isinstance(module, nn.Conv2d):
+        trunc_normal_tf_(module.weight, std=.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Linear):
+        trunc_normal_tf_(module.weight, std=.02)
+        nn.init.zeros_(module.bias)
+        if name and 'head.' in name:
+            module.weight.data.mul_(head_init_scale)
+            module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
|
||||
""" Remap FB checkpoints -> timm """
|
||||
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
|
||||
return state_dict # non-FB checkpoint
|
||||
|
||||
# models were released as train checkpoints... :/
|
||||
if 'model_ema' in state_dict:
|
||||
state_dict = state_dict['model_ema']
|
||||
elif 'model' in state_dict:
|
||||
state_dict = state_dict['model']
|
||||
elif 'state_dict' in state_dict:
|
||||
state_dict = state_dict['state_dict']
|
||||
|
||||
out_dict = {}
|
||||
import re
|
||||
for k, v in state_dict.items():
|
||||
k = k.replace('downsample_layers.0.', 'stem.')
|
||||
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
|
||||
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
|
||||
k = k.replace('dwconv', 'conv_dw')
|
||||
k = k.replace('pwconv', 'mlp.fc')
|
||||
k = k.replace('head.', 'head.fc.')
|
||||
if k.startswith('norm.'):
|
||||
k = k.replace('norm', 'head.norm')
|
||||
if v.ndim == 2 and 'head' not in k:
|
||||
model_shape = model.state_dict()[k].shape
|
||||
v = v.reshape(model_shape)
|
||||
out_dict[k] = v
|
||||
return out_dict
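

# Illustrative sketch of the remapping that checkpoint_filter_fn above applies to an
# FB-style EdgeNeXt state_dict; the keys and shapes below are hypothetical examples.
import re
import torch

_fb_sd = {
    'downsample_layers.0.0.weight': torch.zeros(1),
    'stages.1.2.dwconv.weight': torch.zeros(1),
    'head.weight': torch.zeros(1),
}
_out = {}
for _k, _v in _fb_sd.items():
    _k = _k.replace('downsample_layers.0.', 'stem.')
    _k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', _k)
    _k = _k.replace('dwconv', 'conv_dw')
    _k = _k.replace('head.', 'head.fc.')
    _out[_k] = _v
print(list(_out))  # ['stem.0.weight', 'stages.1.blocks.2.conv_dw.weight', 'head.fc.weight']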


def _create_edgenext(variant, pretrained=False, **kwargs):
    model = build_model_with_cfg(
        EdgeNeXt, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs)
    return model


@register_model
def edgenext_xx_small(pretrained=False, **kwargs):
    # 1.33M & 260.58M @ 256 resolution
    # 71.23% Top-1 accuracy
    # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
    # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
    # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
    model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs)
    return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs)


@register_model
def edgenext_x_small(pretrained=False, **kwargs):
    # 2.34M & 538.0M @ 256 resolution
    # 75.00% Top-1 accuracy
    # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
    # Jetson FPS=31.61 versus 28.49 for MobileViT_XS
    # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
    model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs)
    return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs)


@register_model
def edgenext_small(pretrained=False, **kwargs):
    # 5.59M & 1260.59M @ 256 resolution
    # 79.43% Top-1 accuracy
    # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
    # Jetson FPS=20.47 versus 18.86 for MobileViT_S
    # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
    model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs)
    return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)


@register_model
def edgenext_small_rw(pretrained=False, **kwargs):
    # timm variant of edgenext_small: wider final stages, overlapping 3x3 stem,
    # downsample blocks, no conv bias (the stats quoted for edgenext_small above do not apply)
    model_kwargs = dict(
        depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
        downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
    return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs)
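

# Usage sketch (illustrative): once registered above, the EdgeNeXt variants are built
# through the standard timm factory; pretrained weight availability depends on release.
import timm
import torch

_model = timm.create_model('edgenext_small', pretrained=False, num_classes=10)
print(_model(torch.randn(1, 3, 256, 256)).shape)  # torch.Size([1, 10])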


@ -25,7 +25,7 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm import GroupNorm, GroupNorm1, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
@ -39,4 +39,4 @@ from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_

@ -2,6 +2,7 @@

Hacked together by / Copyright 2020 Ross Wightman
"""
import functools
from torch import nn as nn

from .create_conv2d import create_conv2d
@ -40,12 +41,26 @@ class ConvNormAct(nn.Module):
ConvBnAct = ConvNormAct


def create_aa(aa_layer, channels, stride=2, enable=True):
    if not aa_layer or not enable:
        return nn.Identity()
    if isinstance(aa_layer, functools.partial):
        if issubclass(aa_layer.func, nn.AvgPool2d):
            return aa_layer()
        else:
            return aa_layer(channels)
    elif issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    else:
        return aa_layer(channels=channels, stride=stride)
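

# Illustrative check of the dispatch above (uses the create_aa just defined): a plain
# nn.AvgPool2d class is instantiated with the stride, a partial is called as configured,
# and a falsy/disabled aa_layer degrades to nn.Identity.
import functools
from torch import nn as nn

print(create_aa(None, channels=64))                    # Identity()
print(create_aa(nn.AvgPool2d, channels=64, stride=2))  # AvgPool2d(kernel_size=2, stride=2, ...)
print(create_aa(functools.partial(nn.AvgPool2d, kernel_size=3, stride=2), channels=64))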


class ConvNormActAa(nn.Module):
    def __init__(
            self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
            bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
        super(ConvNormActAa, self).__init__()
        use_aa = aa_layer is not None
        use_aa = aa_layer is not None and stride == 2

        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
@ -56,7 +71,7 @@ class ConvNormActAa(nn.Module):
        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
        norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
        self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
        self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
        self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)

    @property
    def in_channels(self):

@ -22,7 +22,7 @@ def get_attn(attn_type):
    if isinstance(attn_type, torch.nn.Module):
        return attn_type
    module_cls = None
    if attn_type is not None:
    if attn_type:
        if isinstance(attn_type, str):
            attn_type = attn_type.lower()
            # Lightweight attention modules (channel and/or coarse spatial).

@ -14,11 +14,59 @@ class GroupNorm(nn.GroupNorm):
        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)


class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class LayerNorm2d(nn.LayerNorm):
    """ LayerNorm for channels of '2D' spatial BCHW tensors """
    def __init__(self, num_channels):
        super().__init__(num_channels)
    """ LayerNorm for channels of '2D' spatial NCHW tensors """
    def __init__(self, num_channels, eps=1e-6, affine=True):
        super().__init__(num_channels, eps=eps, elementwise_affine=affine)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)


def _is_contiguous(tensor: torch.Tensor) -> bool:
    # jit is oh so lovely :/
    # if torch.jit.is_tracing():
    #     return True
    if torch.jit.is_scripting():
        return tensor.is_contiguous()
    else:
        return tensor.is_contiguous(memory_format=torch.contiguous_format)


@torch.jit.script
def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
    s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
    x = (x - u) * torch.rsqrt(s + eps)
    x = x * weight[:, None, None] + bias[:, None, None]
    return x


class LayerNormExp2d(nn.LayerNorm):
    """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).

    Experimental implementation w/ manual norm for non-contiguous tensors.

    This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
    layout. However, benefits are not always clear and can perform worse on other GPUs.
    """

    def __init__(self, num_channels, eps=1e-6):
        super().__init__(num_channels, eps=eps)

    def forward(self, x) -> torch.Tensor:
        if _is_contiguous(x):
            x = F.layer_norm(
                x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
        else:
            x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
        return x
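

# Quick numeric check (illustrative) that the manual channels-first norm in
# _layer_norm_cf matches F.layer_norm over C after a permute round-trip:
import torch
import torch.nn.functional as F

_x = torch.randn(2, 8, 4, 4)
_w, _b, _eps = torch.randn(8), torch.randn(8), 1e-6
_s, _u = torch.var_mean(_x, dim=1, unbiased=False, keepdim=True)
_manual = (_x - _u) * torch.rsqrt(_s + _eps) * _w[:, None, None] + _b[:, None, None]
_ref = F.layer_norm(_x.permute(0, 2, 3, 1), (8,), _w, _b, _eps).permute(0, 3, 1, 2)
print(torch.allclose(_manual, _ref, atol=1e-5))  # True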


@ -36,7 +36,7 @@ class TestTimePoolHead(nn.Module):
        return x.view(x.size(0), -1)


def apply_test_time_pool(model, config, use_test_size=True):
def apply_test_time_pool(model, config, use_test_size=False):
    test_time_pool = False
    if not hasattr(model, 'default_cfg') or not model.default_cfg:
        return model, False

@ -49,6 +49,11 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
    applied while sampling the normal with mean/std applied, therefore a, b args
    should be adjusted to match the range of mean, std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
@ -62,6 +67,35 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this 'tf' variant behaves closer to the Tensorflow / JAX impl where the
    bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
    and the result is subsequently scaled and shifted by the mean and std args.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
    with torch.no_grad():
        tensor.mul_(std).add_(mean)
    return tensor
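

# Illustrative contrast of the two inits (import path per the layers/__init__.py change
# earlier in this diff): with std=.02 the 'tf' variant truncates the unit normal at
# [-2, 2] *before* scaling, so samples stay within +/- .04; the default variant's
# bounds barely bind at that std.
import torch
from timm.models.layers import trunc_normal_, trunc_normal_tf_

_w1, _w2 = torch.empty(10000), torch.empty(10000)
trunc_normal_(_w1, std=.02)
trunc_normal_tf_(_w2, std=.02)
print(_w1.abs().max().item() > .04)   # True (with overwhelming probability)
print(_w2.abs().max().item() <= .04)  # True (always)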


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
@ -75,7 +109,7 @@ def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":

@ -1,7 +1,8 @@
""" MobileViT

Paper:
`MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680

MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)
@ -13,7 +14,7 @@ Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#
import math
from typing import Union, Callable, Dict, Tuple, Optional
from typing import Union, Callable, Dict, Tuple, Optional, Sequence

import torch
from torch import nn
@ -21,7 +22,7 @@ import torch.nn.functional as F

from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
from .fx_features import register_notrace_module
from .layers import to_2tuple, make_divisible
from .layers import to_2tuple, make_divisible, LayerNorm2d, GroupNorm1, ConvMlp, DropPath
from .vision_transformer import Block as TransformerBlock
from .helpers import build_model_with_cfg
from .registry import register_model
@ -48,6 +49,48 @@ default_cfgs = {
    'mobilevit_s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevit_s-38a5a959.pth'),
    'semobilevit_s': _cfg(),

    'mobilevitv2_050': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_050-49951ee2.pth',
        crop_pct=0.888),
    'mobilevitv2_075': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_075-b5556ef6.pth',
        crop_pct=0.888),
    'mobilevitv2_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_100-e464ef3b.pth',
        crop_pct=0.888),
    'mobilevitv2_125': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_125-0ae35027.pth',
        crop_pct=0.888),
    'mobilevitv2_150': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150-737c5019.pth',
        crop_pct=0.888),
    'mobilevitv2_175': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175-16462ee2.pth',
        crop_pct=0.888),
    'mobilevitv2_200': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200-b3422f67.pth',
        crop_pct=0.888),

    'mobilevitv2_150_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_in22ft1k-0b555d7b.pth',
        crop_pct=0.888),
    'mobilevitv2_175_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_in22ft1k-4117fa1f.pth',
        crop_pct=0.888),
    'mobilevitv2_200_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_in22ft1k-1d7c8927.pth',
        crop_pct=0.888),

    'mobilevitv2_150_384_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_150_384_in22ft1k-9e142854.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_175_384_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_175_384_in22ft1k-059cbe56.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_200_384_in22ft1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-mvit-weights/mobilevitv2_200_384_in22ft1k-32c87503.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
}


@ -72,6 +115,40 @@ def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4,
    )


def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5):
    # inverted residual + mobilevit blocks as per MobileViT network
    return (
        _inverted_residual_block(d=d, c=c, s=s, br=br),
        ByoBlockCfg(
            type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1,
            block_kwargs=dict(
                transformer_depth=transformer_depth,
                patch_size=patch_size)
        )
    )


def _mobilevitv2_cfg(multiplier=1.0):
    chs = (64, 128, 256, 384, 512)
    if multiplier != 1.0:
        chs = tuple([int(c * multiplier) for c in chs])
    cfg = ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0),
            _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0),
            _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2),
            _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4),
            _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3),
        ),
        stem_chs=int(32 * multiplier),
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
    )
    return cfg
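

# Quick check (illustrative) of the width-multiplier arithmetic above for the 0.5x config:
_chs = tuple([int(c * 0.5) for c in (64, 128, 256, 384, 512)])
print(_chs, int(32 * 0.5))  # (32, 64, 128, 192, 256) with stem_chs 16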


model_cfgs = dict(
    mobilevit_xxs=ByoModelCfg(
        blocks=(
@ -137,11 +214,19 @@ model_cfgs = dict(
        attn_kwargs=dict(rd_ratio=1/8),
        num_features=640,
    ),

    mobilevitv2_050=_mobilevitv2_cfg(.50),
    mobilevitv2_075=_mobilevitv2_cfg(.75),
    mobilevitv2_125=_mobilevitv2_cfg(1.25),
    mobilevitv2_100=_mobilevitv2_cfg(1.0),
    mobilevitv2_150=_mobilevitv2_cfg(1.5),
    mobilevitv2_175=_mobilevitv2_cfg(1.75),
    mobilevitv2_200=_mobilevitv2_cfg(2.0),
)


@register_notrace_module
class MobileViTBlock(nn.Module):
class MobileVitBlock(nn.Module):
    """ MobileViT block
    Paper: https://arxiv.org/abs/2110.02178?context=cs.LG
    """
@ -165,9 +250,9 @@ class MobileViTBlock(nn.Module):
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = nn.LayerNorm,
            downsample: str = ''
            **kwargs,  # eat unused args
    ):
        super(MobileViTBlock, self).__init__()
        super(MobileVitBlock, self).__init__()

        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
@ -241,7 +326,270 @@ class MobileViTBlock(nn.Module):
        return x


register_block('mobilevit', MobileViTBlock)
class LinearSelfAttention(nn.Module):
    """
    This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
    This layer can be used for self- as well as cross-attention.
    Args:
        embed_dim (int): :math:`C` from an expected input of size :math:`(B, C, P, N)`
        attn_drop (float): Dropout value for context scores. Default: 0.0
        bias (bool): Use bias in learnable layers. Default: True
    Shape:
        - Input: :math:`(B, C, P, N)` where :math:`B` is the batch size, :math:`C` is the input channels,
          :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches
        - Output: same as the input
    .. note::
        For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
        in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
        we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be
        expensive on resource-constrained devices) that may be required to convert the unfolded tensor from
        channel-first to channel-last format in case of a linear layer.
    """

    def __init__(
            self,
            embed_dim: int,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        self.qkv_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
        )
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=bias,
            kernel_size=1,
        )
        self.out_drop = nn.Dropout(proj_drop)

    def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, P, N] --> [B, h + 2d, P, N]
        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # value, key --> [B, d, P, N]
        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # Compute context vector
        # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

    @torch.jit.ignore()
    def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x --> [B, C, P, N]
        # x_prev = [B, C, P, M]
        batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
        q_patch_area, q_num_patches = x_prev.shape[-2:]

        assert (
            kv_patch_area == q_patch_area
        ), "The number of pixels in a patch for query and key_value should be the same"

        # compute query, key, and value
        # [B, C, P, M] --> [B, 1 + d, P, M]
        qk = F.conv2d(
            x_prev,
            weight=self.qkv_proj.weight[:self.embed_dim + 1],
            bias=self.qkv_proj.bias[:self.embed_dim + 1],
        )

        # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
        query, key = qk.split([1, self.embed_dim], dim=1)
        # [B, C, P, N] --> [B, d, P, N]
        value = F.conv2d(
            x,
            weight=self.qkv_proj.weight[self.embed_dim + 1:],
            bias=self.qkv_proj.bias[self.embed_dim + 1:] if self.qkv_proj.bias is not None else None,
        )

        # apply softmax along M dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # compute context vector
        # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

    def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        if x_prev is None:
            return self._forward_self_attn(x)
        else:
            return self._forward_cross_attn(x, x_prev=x_prev)
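

# A tensor-level sketch (illustrative) of the O(N) attention math used above: a
# single-channel query is softmaxed over the N patches, pooled against the keys into
# one context vector per patch pixel, then broadcast onto the (ReLU'd) values.
import torch
import torch.nn.functional as F

_B, _d, _P, _N = 2, 16, 4, 64
_query = torch.randn(_B, 1, _P, _N)
_key = torch.randn(_B, _d, _P, _N)
_value = torch.randn(_B, _d, _P, _N)

_scores = F.softmax(_query, dim=-1)
_context = (_key * _scores).sum(dim=-1, keepdim=True)  # [B, d, P, 1]
_out = F.relu(_value) * _context.expand_as(_value)     # [B, d, P, N]
print(_out.shape)  # torch.Size([2, 16, 4, 64])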


class LinearTransformerBlock(nn.Module):
    """
    This class defines the pre-norm transformer encoder with linear self-attention in `MobileViTv2 paper <https://arxiv.org/abs/2206.02680>`_
    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(B, C_{in}, P, N)`
        mlp_ratio (float): Inner dimension ratio of the FFN relative to embed_dim
        drop (float): Dropout rate. Default: 0.0
        attn_drop (float): Dropout rate for attention in multi-head attention. Default: 0.0
        drop_path (float): Stochastic depth rate. Default: 0.0
        norm_layer (Callable): Normalization layer. Default: GroupNorm1
    Shape:
        - Input: :math:`(B, C_{in}, P, N)` where :math:`B` is batch size, :math:`C_{in}` is input embedding dim,
          :math:`P` is number of pixels in a patch, and :math:`N` is number of patches,
        - Output: same shape as the input
    """

    def __init__(
            self,
            embed_dim: int,
            mlp_ratio: float = 2.0,
            drop: float = 0.0,
            attn_drop: float = 0.0,
            drop_path: float = 0.0,
            act_layer=None,
            norm_layer=None,
    ) -> None:
        super().__init__()
        act_layer = act_layer or nn.SiLU
        norm_layer = norm_layer or GroupNorm1

        self.norm1 = norm_layer(embed_dim)
        self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path1 = DropPath(drop_path)

        self.norm2 = norm_layer(embed_dim)
        self.mlp = ConvMlp(
            in_features=embed_dim,
            hidden_features=int(embed_dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)
        self.drop_path2 = DropPath(drop_path)

    def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        if x_prev is None:
            # self-attention
            x = x + self.drop_path1(self.attn(self.norm1(x)))
        else:
            # cross-attention
            res = x
            x = self.norm1(x)  # norm
            x = self.attn(x, x_prev)  # attn
            x = self.drop_path1(x) + res  # residual

        # Feed forward network
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x


@register_notrace_module
class MobileVitV2Block(nn.Module):
    """
    This class defines the `MobileViTv2 block <https://arxiv.org/abs/2206.02680>`_
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,
            kernel_size: int = 3,
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = 1,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            attn_drop: float = 0.,
            drop: float = 0.,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = GroupNorm1,
            **kwargs,  # eat unused args
    ):
        super(MobileVitV2Block, self).__init__()
        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=1, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        self.transformer = nn.Sequential(*[
            LinearTransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        patch_h, patch_w = self.patch_size
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        if new_h != H or new_w != W:
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
        C = x.shape[1]
        x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
        x = x.reshape(B, C, -1, num_patches)

        # Global representations
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
        x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
        x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)

        x = self.conv_proj(x)
        return x
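

# Round-trip sketch (illustrative) of the unfold/fold reshapes in the forward above:
import torch

_B, _C, _H, _W = 1, 8, 16, 16
_ph, _pw = 2, 2
_nh, _nw = _H // _ph, _W // _pw
_x = torch.randn(_B, _C, _H, _W)

_u = _x.reshape(_B, _C, _nh, _ph, _nw, _pw).permute(0, 1, 3, 5, 2, 4).reshape(_B, _C, _ph * _pw, _nh * _nw)
_f = _u.reshape(_B, _C, _ph, _pw, _nh, _nw).permute(0, 1, 4, 2, 5, 3).reshape(_B, _C, _nh * _ph, _nw * _pw)
print(torch.equal(_x, _f))  # True: fold exactly inverts unfold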


register_block('mobilevit', MobileVitBlock)
register_block('mobilevit2', MobileVitV2Block)


def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
@ -252,6 +600,14 @@ def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
        **kwargs)


def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)


@register_model
def mobilevit_xxs(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
@ -269,4 +625,75 @@ def mobilevit_s(pretrained=False, **kwargs):

@register_model
def semobilevit_s(pretrained=False, **kwargs):
    return _create_mobilevit('semobilevit_s', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_050(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_075(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_100(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_125(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_150(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_175(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_200(pretrained=False, **kwargs):
    return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_150_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_150_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_175_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_175_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_200_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_200_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_150_384_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_150_384_in22ft1k', cfg_variant='mobilevitv2_150', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_175_384_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_175_384_in22ft1k', cfg_variant='mobilevitv2_175', pretrained=pretrained, **kwargs)


@register_model
def mobilevitv2_200_384_in22ft1k(pretrained=False, **kwargs):
    return _create_mobilevit(
        'mobilevitv2_200_384_in22ft1k', cfg_variant='mobilevitv2_200', pretrained=pretrained, **kwargs)

@ -26,7 +26,7 @@ import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, checkpoint_seq
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp
from .layers import DropPath, trunc_normal_, to_2tuple, ConvMlp, GroupNorm1
from .registry import register_model


@ -80,15 +80,6 @@ class PatchEmbed(nn.Module):
        return x


class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)


class Pooling(nn.Module):
    def __init__(self, pool_size=3):
        super().__init__()

@ -35,6 +35,16 @@ def _cfg(url='', **kwargs):

default_cfgs = {
    # ResNet and Wide ResNet
    'resnet10t': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth',
        input_size=(3, 176, 176), pool_size=(6, 6),
        test_crop_pct=0.95, test_input_size=(3, 224, 224),
        first_conv='conv1.0'),
    'resnet14t': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth',
        input_size=(3, 176, 176), pool_size=(6, 6),
        test_crop_pct=0.95, test_input_size=(3, 224, 224),
        first_conv='conv1.0'),
    'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
    'resnet18d': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
@ -262,6 +272,9 @@ default_cfgs = {
    'resnetblur101d': _cfg(
        url='',
        interpolation='bicubic', first_conv='conv1.0'),
    'resnetaa50': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0, interpolation='bicubic'),
    'resnetaa50d': _cfg(
        url='',
        interpolation='bicubic', first_conv='conv1.0'),
@ -723,6 +736,24 @@ def _create_resnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)


@register_model
def resnet10t(pretrained=False, **kwargs):
    """Constructs a ResNet-10-T model.
    """
    model_args = dict(
        block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
    return _create_resnet('resnet10t', pretrained, **model_args)


@register_model
def resnet14t(pretrained=False, **kwargs):
    """Constructs a ResNet-14-T model.
    """
    model_args = dict(
        block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
    return _create_resnet('resnet14t', pretrained, **model_args)


@register_model
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
@ -1436,6 +1467,14 @@ def resnetblur101d(pretrained=False, **kwargs):
    return _create_resnet('resnetblur101d', pretrained, **model_args)


@register_model
def resnetaa50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model with avgpool anti-aliasing
    """
    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs)
    return _create_resnet('resnetaa50', pretrained, **model_args)
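

# Usage sketch (illustrative): resnetaa50 passes aa_layer=nn.AvgPool2d, presumably
# resolved through the new create_aa helper (conv_bn_act.py earlier in this diff), so
# stride-2 convs become stride-1 conv followed by AvgPool2d downsampling.
import timm
import torch

_model = timm.create_model('resnetaa50', pretrained=False, num_classes=10)
print(_model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 10])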


@register_model
def resnetaa50d(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model with avgpool anti-aliasing

@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
            embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
            class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
            weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
        """
        Args:
            img_size (int, tuple): input image size
@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1 if class_token else 0
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
        self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)
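

# Standalone sketch (illustrative) of the two orderings handled by _pos_embed above:
import torch

_N, _D = 4, 8
_x, _cls = torch.zeros(1, _N, _D), torch.zeros(1, 1, _D)

# no_embed_class=True: pos_embed covers patches only -> add, then concat class token
_out_a = torch.cat((_cls, _x + torch.randn(1, _N, _D)), dim=1)
# no_embed_class=False: pos_embed includes a class-token slot -> concat, then add
_out_b = torch.cat((_cls, _x), dim=1) + torch.randn(1, 1 + _N, _D)
print(_out_a.shape, _out_b.shape)  # torch.Size([1, 5, 8]) twice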

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self._pos_embed(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
@ -442,7 +457,7 @@

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
            pos_embed_w,
            model.pos_embed,
            getattr(model, 'num_prefix_tokens', 1),
            model.patch_embed.grid_size
        )
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@ -585,16 +604,16 @@
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))


def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if num_tokens:
        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
        ntok_new -= num_tokens
    if num_prefix_tokens:
        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
        ntok_new -= num_prefix_tokens
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
    gs_old = int(math.sqrt(len(posemb_grid)))
    if not len(gs_new):  # backwards compatibility
        gs_new = [int(math.sqrt(ntok_new))] * 2
@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
    return posemb
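

# Minimal sketch (illustrative) of the grid resize above: a 14x14 pos-embed grid
# (no prefix tokens) interpolated to 16x16 with bicubic resampling.
import torch
import torch.nn.functional as F

_D = 8
_grid = torch.randn(1, 14 * 14, _D)
_g = _grid.reshape(1, 14, 14, _D).permute(0, 3, 1, 2)
_g = F.interpolate(_g, size=(16, 16), mode='bicubic', align_corners=False)
_g = _g.permute(0, 2, 3, 1).reshape(1, 16 * 16, _D)
print(_g.shape)  # torch.Size([1, 256, 8])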


def checkpoint_filter_fn(state_dict, model):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    import re
    out_dict = {}
    if 'model' in state_dict:
        # For deit models
        state_dict = state_dict['model']

    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
            # For old models that I trained prior to conv based patchification
            O, I, H, W = model.patch_embed.proj.weight.shape
            v = v.reshape(O, -1, H, W)
        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            v = resize_pos_embed(
                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
                v,
                model.pos_embed,
                getattr(model, 'num_prefix_tokens', 1),
                model.patch_embed.grid_size
            )
        elif 'gamma_' in k:
            # remap layer-scale gamma into sub-module (deit3 models)
            k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
        elif 'pre_logits' in k:
            # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
            continue

@ -8,6 +8,7 @@ import math
import logging
from functools import partial
from collections import OrderedDict
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
@ -16,7 +17,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model

@ -47,9 +48,16 @@ default_cfgs = {
    'vit_relpos_base_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),

    'vit_srelpos_small_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
    'vit_srelpos_medium_patch16_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),

    'vit_relpos_medium_patch16_cls_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
    'vit_relpos_base_patch16_cls_224': _cfg(
        url=''),
    'vit_relpos_base_patch16_gapcls_224': _cfg(
    'vit_relpos_base_patch16_clsgap_224': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),

    'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
@ -59,35 +67,43 @@ default_cfgs = {
}


def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
    # cut and paste w/ modifications from swin / beit codebase
    # cls to token & token 2 cls & cls to cls
def gen_relative_position_index(
        q_size: Tuple[int, int],
        k_size: Tuple[int, int] = None,
        class_token: bool = False) -> torch.Tensor:
    # Adapted with significant modifications from Swin / BeiT codebases
    # get pair-wise relative position index for each token inside the window
    window_area = win_size[0] * win_size[1]
    coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1)  # 2, Wh, Ww
    relative_coords = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += win_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += win_size[1] - 1
    relative_coords[:, :, 0] *= 2 * win_size[1] - 1
    q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, Wh, Ww
    if k_size is None:
        k_coords = q_coords
        k_size = q_size
    else:
        # different q vs k sizes is a WIP
        k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
    relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)

    if class_token:
        num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3
        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
        # NOTE not intended or tested with MLP log-coords
        max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
        num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
        relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1
    else:
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    return relative_position_index

    return relative_position_index.contiguous()
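

# Tiny numeric check (illustrative) of the unique-based indexing above for a 2x2 window:
import torch

_coords = torch.stack(torch.meshgrid([torch.arange(2), torch.arange(2)])).flatten(1)
_rel = (_coords[:, :, None] - _coords[:, None, :]).permute(1, 2, 0)
_, _idx = torch.unique(_rel.reshape(-1, 2), return_inverse=True, dim=0)
print(_idx.view(4, 4))  # 4x4 map into the (2*2-1)**2 == 9 unique relative offsets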


def gen_relative_log_coords(
        win_size: Tuple[int, int],
        pretrained_win_size: Tuple[int, int] = (0, 0),
        mode='swin'
        mode='swin',
):
    # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
    assert mode in ('swin', 'cr', 'rw')
    # as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@ -100,12 +116,22 @@ def gen_relative_log_coords(
        relative_coords_table[:, :, 0] /= (win_size[0] - 1)
        relative_coords_table[:, :, 1] /= (win_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        scale = math.log2(8)
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            1.0 + relative_coords_table.abs()) / math.log2(8)
    else:
        # FIXME we should support a form of normalization (to -1/1) for this mode?
        scale = math.log2(math.e)
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            1.0 + relative_coords_table.abs()) / scale
        if mode == 'rw':
            # cr w/ window size normalization -> [-1,1] log coords
            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
            relative_coords_table *= 8  # scale to -8, 8
            relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
                1.0 + relative_coords_table.abs())
            relative_coords_table /= math.log2(9)  # -> [-1, 1]
        else:
            # mode == 'cr'
            relative_coords_table = torch.sign(relative_coords_table) * torch.log(
                1.0 + relative_coords_table.abs())

    return relative_coords_table
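

# Quick sketch (illustrative) of the 'swin' mode transform above for a 3x3 window,
# i.e. relative offsets -2..2 normalized, scaled to [-8, 8], then log2-compressed:
import math
import torch

_t = torch.arange(-2, 3, dtype=torch.float32) / 2 * 8
_out = torch.sign(_t) * torch.log2(1.0 + _t.abs()) / math.log2(8)
print(_out)  # monotonic, ~[-1.057, -0.774, 0.0, 0.774, 1.057]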
|
||||
|
||||
|
||||
@ -115,19 +141,29 @@ class RelPosMlp(nn.Module):
|
||||
window_size,
|
||||
num_heads=8,
|
||||
hidden_dim=128,
|
||||
class_token=False,
|
||||
prefix_tokens=0,
|
||||
mode='cr',
|
||||
pretrained_window_size=(0, 0)
|
||||
):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.window_area = self.window_size[0] * self.window_size[1]
|
||||
self.class_token = 1 if class_token else 0
|
||||
self.prefix_tokens = prefix_tokens
|
||||
self.num_heads = num_heads
|
||||
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
|
||||
self.apply_sigmoid = mode == 'swin'
|
||||
if mode == 'swin':
|
||||
self.bias_act = nn.Sigmoid()
|
||||
self.bias_gain = 16
|
||||
mlp_bias = (True, False)
|
||||
elif mode == 'rw':
|
||||
self.bias_act = nn.Tanh()
|
||||
self.bias_gain = 4
|
||||
mlp_bias = True
|
||||
else:
|
||||
self.bias_act = nn.Identity()
|
||||
self.bias_gain = None
|
||||
mlp_bias = True
|
||||
|
||||
mlp_bias = (True, False) if mode == 'swin' else True
|
||||
self.mlp = Mlp(
|
||||
2, # x, y
|
||||
hidden_features=hidden_dim,
|
||||
@ -155,10 +191,11 @@ class RelPosMlp(nn.Module):
|
||||
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.view(self.bias_shape)
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1)
|
||||
if self.apply_sigmoid:
|
||||
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
||||
if self.class_token:
|
||||
relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
|
||||
relative_position_bias = self.bias_act(relative_position_bias)
|
||||
if self.bias_gain is not None:
|
||||
relative_position_bias = self.bias_gain * relative_position_bias
|
||||
if self.prefix_tokens:
|
||||
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
|
||||
return relative_position_bias.unsqueeze(0).contiguous()
|
||||
|
||||
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
|
||||
@ -167,18 +204,18 @@ class RelPosMlp(nn.Module):
|
||||
|
||||
class RelPosBias(nn.Module):
|
||||
|
||||
def __init__(self, window_size, num_heads, class_token=False):
|
||||
def __init__(self, window_size, num_heads, prefix_tokens=0):
|
||||
super().__init__()
|
||||
assert prefix_tokens <= 1
|
||||
self.window_size = window_size
|
||||
self.window_area = window_size[0] * window_size[1]
|
||||
self.class_token = 1 if class_token else 0
|
||||
self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
|
||||
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
|
||||
|
||||
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
|
||||
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
|
||||
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
|
||||
self.register_buffer(
|
||||
"relative_position_index",
|
||||
gen_relative_position_index(self.window_size, class_token=self.class_token),
|
||||
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
@ -306,11 +343,32 @@ class VisionTransformerRelPos(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
|
||||
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
|
||||
class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
|
||||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
|
||||
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
global_pool='avg',
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
init_values=1e-6,
|
||||
class_token=False,
|
||||
fc_norm=False,
|
||||
rel_pos_type='mlp',
|
||||
rel_pos_dim=None,
|
||||
shared_rel_pos=False,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
weight_init='skip',
|
||||
embed_layer=PatchEmbed,
|
||||
norm_layer=None,
|
||||
act_layer=None,
|
||||
block_fn=RelPosBlock
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
@ -345,19 +403,22 @@ class VisionTransformerRelPos(nn.Module):
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1 if class_token else 0
        self.num_prefix_tokens = 1 if class_token else 0
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        feat_size = self.patch_embed.grid_size

        rel_pos_args = dict(window_size=feat_size, class_token=class_token)
        rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
        if rel_pos_type.startswith('mlp'):
            if rel_pos_dim:
                rel_pos_args['hidden_dim'] = rel_pos_dim
            # FIXME experimenting with different relpos log coord configs
            if 'swin' in rel_pos_type:
                rel_pos_args['mode'] = 'swin'
            elif 'rw' in rel_pos_type:
                rel_pos_args['mode'] = 'rw'
            rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
        else:
            rel_pos_cls = partial(RelPosBias, **rel_pos_args)
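The `partial(...)` here is a class-factory pattern: shared constructor arguments are bound once, then each transformer block instantiates its own rel-pos module. A minimal sketch with a hypothetical `DummyRelPos` stand-in (not a timm class):

from functools import partial

import torch.nn as nn

# Pre-bind the shared args, instantiate per block with block-specific args.
class DummyRelPos(nn.Module):
    def __init__(self, window_size, num_heads, prefix_tokens=0):
        super().__init__()
        self.bias_shape = (window_size[0] * window_size[1] + prefix_tokens,) * 2 + (num_heads,)

rel_pos_cls = partial(DummyRelPos, window_size=(14, 14), prefix_tokens=1)
per_block = [rel_pos_cls(num_heads=h) for h in (6, 6, 12)]  # one module per block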
@ -367,7 +428,7 @@ class VisionTransformerRelPos(nn.Module):
            # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
            rel_pos_cls = None

        self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
        self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
@ -434,7 +495,7 @@ class VisionTransformerRelPos(nn.Module):

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)
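A quick shape sketch of the pooling above: 'avg' pooling averages over patch tokens only, skipping `num_prefix_tokens` leading tokens, while 'token' pooling takes the class token itself. Illustrative shapes only:

import torch

# (batch, prefix + patches, dim) for a 14x14 patch grid with one class token
x = torch.randn(2, 1 + 196, 768)
num_prefix_tokens = 1

pooled_avg = x[:, num_prefix_tokens:].mean(dim=1)   # 'avg': mean over patch tokens only
pooled_cls = x[:, 0]                                # 'token': take the class token
assert pooled_avg.shape == pooled_cls.shape == (2, 768)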
@ -502,6 +563,41 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
    return model


@register_model
def vit_srelpos_small_patch16_224(pretrained=False, **kwargs):
    """ ViT-Small (ViT-S/16) w/ shared relative log-coord position, no class token
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
        rel_pos_dim=384, shared_rel_pos=True, **kwargs)
    model = _create_vision_transformer_relpos('vit_srelpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs):
    """ ViT-Medium (ViT-M/16) w/ shared relative log-coord position, no class token
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
        rel_pos_dim=512, shared_rel_pos=True, **kwargs)
    model = _create_vision_transformer_relpos(
        'vit_srelpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs):
    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, class token present
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
        rel_pos_dim=256, class_token=True, global_pool='token', **kwargs)
    model = _create_vision_transformer_relpos(
        'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
@ -514,14 +610,14 @@ def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):


@register_model
def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs):
    """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
    NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
    Leaving here for comparisons w/ a future re-train as it performs quite well.
    """
    model_kwargs = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer_relpos('vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **model_kwargs)
    return model
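Models registered via @register_model become available through the usual factory; a quick usage sketch (pretrained weights may not exist for every variant at this point in the release cycle):

import timm
import torch

# Instantiate one of the newly registered relpos variants.
model = timm.create_model('vit_srelpos_small_patch16_224', pretrained=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])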
timm/version.py
@ -1 +1 @@
__version__ = '0.6.2.dev0'
__version__ = '0.6.3.dev0'
29
validate.py
@ -38,6 +38,12 @@ try:
except AttributeError:
    pass

try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError as e:
    has_functorch = False

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -61,6 +67,8 @@ parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
                    help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@ -101,8 +109,11 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                             help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
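The mutually exclusive group means argparse itself rejects passing both scripting paths at once. A standalone sketch of the pattern:

import argparse

# argparse enforces that at most one flag in the group is supplied.
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument('--torchscript', action='store_true')
group.add_argument('--aot-autograd', action='store_true')

parser.parse_args(['--torchscript'])                      # ok
# parser.parse_args(['--torchscript', '--aot-autograd'])  # error: not allowed together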
@ -155,14 +166,22 @@ def validate(args):
    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    data_config = resolve_data_config(
        vars(args),
        model=model,
        use_test_size=not args.use_train_size,
        verbose=True
    )
    test_time_pool = False
    if args.test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
        model, test_time_pool = apply_test_time_pool(model, data_config)
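In effect, --use-train-size flips use_test_size off so validation resolves the train-time input size. A usage sketch of resolve_data_config outside the script (assuming the timm 0.6.x signature shown above):

import timm
from timm.data import resolve_data_config

# Resolve the input config for a model, preferring its train-time input size
# (this mirrors what passing --use-train-size does in validate.py).
model = timm.create_model('resnet50', pretrained=False)
config = resolve_data_config({}, model=model, use_test_size=False, verbose=True)
print(config['input_size'])  # e.g. (3, 224, 224)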
    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)
        model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size']))
    if args.aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)

    model = model.cuda()
    if args.apex_amp:
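The script -> trace change above is worth noting: torch.jit.trace records the ops executed for one example input, so the traced module is specialized to that input's shape and control flow. A minimal sketch with a hypothetical stand-in model:

import torch
import torch.nn as nn

# Trace a small model with a fixed example input, as validate.py now does
# with torch.randn((args.batch_size,) + data_config['input_size']).
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()
example = torch.randn(2, 3, 224, 224)   # stands in for (batch,) + input_size

traced = torch.jit.trace(model, example_inputs=example)
out = traced(torch.randn(2, 3, 224, 224))   # same shape as the trace: fine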