Merge pull request #533 from rwightman/pit_and_vit_update

Addition of PiT models and update/cleanup of ViT, a new NFNet weight, a TFDS wrapper fix, and a few misc fixes/updates.
commit d5ed58d623

README.md (+18 lines)
@@ -23,6 +23,22 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor

 ## What's New

+### April 1, 2021
+* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
+* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
+  * Merged distilled variant into main for torchscript compatibility
+  * Some `timm` cleanup/style tweaks and weights have hub download support
+* Cleanup Vision Transformer (ViT) models
+  * Merge distilled (DeiT) model into main so that torchscript can work
+  * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
+  * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
+  * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (ViT) with distilled variants
+  * nn.Sequential for block stack (does not break downstream compat)
+* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
+* Add RegNetY-160 weights from DeiT teacher model
+* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights, 82.75 top-1 @ 288x288
+* Some fixes/improvements for TFDS dataset wrapper
+
 ### March 17, 2021
 * Add new ECA-NFNet-L0 (rename `nfnet_l0c`->`eca_nfnet_l0`) weights trained by myself.
   * 82.6 top-1 @ 288x288, 82.8 @ 320x320, trained at 224x224
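
With this merge the PiT models are available through the normal `timm` factory functions. A minimal usage sketch (not part of the README diff; assumes network access so the release weights referenced in `timm/models/pit.py` below can be downloaded):

```python
import torch
import timm

# 'pit_s_224' is one of the variants registered in the new timm/models/pit.py
model = timm.create_model('pit_s_224', pretrained=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])
```
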
@@ -189,6 +205,7 @@ A full version of the list below with source links can be found in the [document
 * NFNet-F - https://arxiv.org/abs/2102.06171
 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
 * PNasNet - https://arxiv.org/abs/1712.00559
+* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
 * RegNet - https://arxiv.org/abs/2003.13678
 * RepVGG - https://arxiv.org/abs/2101.03697
 * ResNet/ResNeXt

@@ -204,6 +221,7 @@ A full version of the list below with source links can be found in the [document
 * ReXNet - https://arxiv.org/abs/2007.00992
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
+* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
 * TResNet - https://arxiv.org/abs/2003.13630
 * Vision Transformer - https://arxiv.org/abs/2010.11929
 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667

benchmark.py (new executable file, 481 lines)

#!/usr/bin/env python3
""" Model Benchmark Script

An inference and train step benchmark script for timm models.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import json
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from functools import partial

from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.utils import AverageMeter, setup_default_logging


has_apex = False
try:
    from apex import amp
    has_apex = True
except ImportError:
    pass

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')


parser = argparse.ArgumentParser(description='PyTorch Benchmark')

# benchmark specific args
parser.add_argument('--model-list', metavar='NAME', default='',
                    help='txt file based list of model names to benchmark')
parser.add_argument('--bench', default='both', type=str,
                    help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'")
parser.add_argument('--detail', action='store_true', default=False,
                    help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--num-warm-iter', default=10, type=int,
                    metavar='N', help='Number of warmup iterations (default: 10)')
parser.add_argument('--num-bench-iter', default=40, type=int,
                    metavar='N', help='Number of benchmark iterations (default: 40)')

# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
                    help='model architecture (default: resnet50)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number classes in dataset')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
                    help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--precision', default='float32', type=str,
                    help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')


# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
                    help='Optimizer Epsilon (default: None, use opt default)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                    help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='Optimizer momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0001,
                    help='weight decay (default: 0.0001)')
parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                    help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')


# model regularization / loss params that impact model or loss fn
parser.add_argument('--smoothing', type=float, default=0.1,
                    help='Label smoothing (default: 0.1)')
parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                    help='Dropout rate (default: 0.)')
parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
                    help='Drop path rate (default: None)')
parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                    help='Drop block rate (default: None)')


def timestamp(sync=False):
    return time.perf_counter()


def cuda_timestamp(sync=False, device=None):
    if sync:
        torch.cuda.synchronize(device=device)
    return time.perf_counter()


def count_params(model: nn.Module):
    return sum([m.numel() for m in model.parameters()])


def resolve_precision(precision: str):
    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
    use_amp = False
    model_dtype = torch.float32
    data_dtype = torch.float32
    if precision == 'amp':
        use_amp = True
    elif precision == 'float16':
        model_dtype = torch.float16
        data_dtype = torch.float16
    elif precision == 'bfloat16':
        model_dtype = torch.bfloat16
        data_dtype = torch.bfloat16
    return use_amp, model_dtype, data_dtype


class BenchmarkRunner:
    def __init__(
            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
            num_warm_iter=10, num_bench_iter=50, **kwargs):
        self.model_name = model_name
        self.detail = detail
        self.device = device
        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
        self.channels_last = kwargs.pop('channels_last', False)
        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress

        self.model = create_model(
            model_name,
            num_classes=kwargs.pop('num_classes', None),
            in_chans=3,
            global_pool=kwargs.pop('gp', 'fast'),
            scriptable=torchscript)
        self.model.to(
            device=self.device,
            dtype=self.model_dtype,
            memory_format=torch.channels_last if self.channels_last else None)
        self.num_classes = self.model.num_classes
        self.param_count = count_params(self.model)
        _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
        if torchscript:
            self.model = torch.jit.script(self.model)

        data_config = resolve_data_config(kwargs, model=self.model, use_test_size=True)
        self.input_size = data_config['input_size']
        self.batch_size = kwargs.pop('batch_size', 256)

        self.example_inputs = None
        self.num_warm_iter = num_warm_iter
        self.num_bench_iter = num_bench_iter
        self.log_freq = num_bench_iter // 5
        if 'cuda' in self.device:
            self.time_fn = partial(cuda_timestamp, device=self.device)
        else:
            self.time_fn = timestamp

    def _init_input(self):
        self.example_inputs = torch.randn(
            (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
        if self.channels_last:
            self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)


class InferenceBenchmarkRunner(BenchmarkRunner):

    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.eval()

    def run(self):
        def _step():
            t_step_start = self.time_fn()
            with self.amp_autocast():
                output = self.model(self.example_inputs)
            t_step_end = self.time_fn(True)
            return t_step_end - t_step_start

        _logger.info(
            f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        with torch.no_grad():
            self._init_input()

            for _ in range(self.num_warm_iter):
                _step()

            total_step = 0.
            num_samples = 0
            t_run_start = self.time_fn()
            for i in range(self.num_bench_iter):
                delta_fwd = _step()
                total_step += delta_fwd
                num_samples += self.batch_size
                num_steps = i + 1
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Infer [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_end = self.time_fn(True)
            t_run_elapsed = t_run_end - t_run_start

        results = dict(
            samples_per_sec=round(num_samples / t_run_elapsed, 2),
            step_time=round(1000 * total_step / self.num_bench_iter, 3),
            batch_size=self.batch_size,
            img_size=self.input_size[-1],
            param_count=round(self.param_count / 1e6, 2),
        )

        _logger.info(
            f"Inference benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")

        return results


class TrainBenchmarkRunner(BenchmarkRunner):

    def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
        super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
        self.model.train()

        if kwargs.pop('smoothing', 0) > 0:
            self.loss = nn.CrossEntropyLoss().to(self.device)
        else:
            self.loss = nn.CrossEntropyLoss().to(self.device)
        self.target_shape = tuple()

        self.optimizer = create_optimizer_v2(
            self.model,
            optimizer_name=kwargs.pop('opt', 'sgd'),
            learning_rate=kwargs.pop('lr', 1e-4))

    def _gen_target(self, batch_size):
        return torch.empty(
            (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)

    def run(self):
        def _step(detail=False):
            self.optimizer.zero_grad()  # can this be ignored?
            t_start = self.time_fn()
            t_fwd_end = t_start
            t_bwd_end = t_start
            with self.amp_autocast():
                output = self.model(self.example_inputs)
                if isinstance(output, tuple):
                    output = output[0]
                if detail:
                    t_fwd_end = self.time_fn(True)
                target = self._gen_target(output.shape[0])
                self.loss(output, target).backward()
                if detail:
                    t_bwd_end = self.time_fn(True)
            self.optimizer.step()
            t_end = self.time_fn(True)
            if detail:
                delta_fwd = t_fwd_end - t_start
                delta_bwd = t_bwd_end - t_fwd_end
                delta_opt = t_end - t_bwd_end
                return delta_fwd, delta_bwd, delta_opt
            else:
                delta_step = t_end - t_start
                return delta_step

        _logger.info(
            f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ '
            f'input size {self.input_size} and batch size {self.batch_size}.')

        self._init_input()

        for _ in range(self.num_warm_iter):
            _step()

        t_run_start = self.time_fn()
        if self.detail:
            total_fwd = 0.
            total_bwd = 0.
            total_opt = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_fwd, delta_bwd, delta_opt = _step(True)
                num_samples += self.batch_size
                total_fwd += delta_fwd
                total_bwd += delta_bwd
                total_opt += delta_opt
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    total_step = total_fwd + total_bwd + total_opt
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd,"
                        f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd,"
                        f" {1000 * total_opt / num_steps:0.3f} ms/step opt."
                    )
            total_step = total_fwd + total_bwd + total_opt
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3),
                bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3),
                opt_time=round(1000 * total_opt / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )
        else:
            total_step = 0.
            num_samples = 0
            for i in range(self.num_bench_iter):
                delta_step = _step(False)
                num_samples += self.batch_size
                total_step += delta_step
                num_steps = (i + 1)
                if num_steps % self.log_freq == 0:
                    _logger.info(
                        f"Train [{num_steps}/{self.num_bench_iter}]."
                        f" {num_samples / total_step:0.2f} samples/sec."
                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
            t_run_elapsed = self.time_fn() - t_run_start
            results = dict(
                samples_per_sec=round(num_samples / t_run_elapsed, 2),
                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                batch_size=self.batch_size,
                img_size=self.input_size[-1],
                param_count=round(self.param_count / 1e6, 2),
            )

        _logger.info(
            f"Train benchmark of {self.model_name} done. "
            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample")

        return results


def decay_batch_exp(batch_size, factor=0.5, divisor=16):
    out_batch_size = batch_size * factor
    if out_batch_size > divisor:
        out_batch_size = (out_batch_size + 1) // divisor * divisor
    else:
        out_batch_size = batch_size - 1
    return max(0, int(out_batch_size))


def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
    batch_size = initial_batch_size
    results = dict()
    while batch_size >= 1:
        try:
            bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs)
            results = bench.run()
            return results
        except RuntimeError as e:
            torch.cuda.empty_cache()
            batch_size = decay_batch_exp(batch_size)
            print(f'Error: {str(e)} while running benchmark. Reducing batch size to {batch_size} for retry.')
    return results


def benchmark(args):
    if args.amp:
        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
        args.precision = 'amp'
    _logger.info(f'Benchmarking in {args.precision} precision. '
                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                 f'torchscript {"enabled" if args.torchscript else "disabled"}')

    bench_kwargs = vars(args).copy()
    bench_kwargs.pop('amp')
    model = bench_kwargs.pop('model')
    batch_size = bench_kwargs.pop('batch_size')

    bench_fns = (InferenceBenchmarkRunner,)
    prefixes = ('infer',)
    if args.bench == 'both':
        bench_fns = (
            InferenceBenchmarkRunner,
            TrainBenchmarkRunner
        )
        prefixes = ('infer', 'train')
    elif args.bench == 'train':
        bench_fns = TrainBenchmarkRunner,
        prefixes = 'train',

    model_results = OrderedDict(model=model)
    for prefix, bench_fn in zip(prefixes, bench_fns):
        run_results = _try_run(model, bench_fn, initial_batch_size=batch_size, bench_kwargs=bench_kwargs)
        if prefix:
            run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()}
        model_results.update(run_results)
    param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0))
    model_results.setdefault('param_count', param_count)
    model_results.pop('train_param_count', 0)
    return model_results


def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []

    if args.model_list:
        args.model = ''
        with open(args.model_list) as f:
            model_names = [line.rstrip() for line in f]
        model_cfgs = [(n, None) for n in model_names]
    elif args.model == 'all':
        # validate all models in a list of names with pretrained checkpoints
        args.pretrained = True
        model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
        model_cfgs = [(n, None) for n in model_names]
    elif not is_model(args.model):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model)
        model_cfgs = [(n, None) for n in model_names]

    if len(model_cfgs):
        results_file = args.results_file or './benchmark.csv'
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        results = []
        try:
            for m, _ in model_cfgs:
                if not m:
                    continue
                args.model = m
                r = benchmark(args)
                results.append(r)
        except KeyboardInterrupt as e:
            pass
        sort_key = 'train_samples_per_sec' if 'train' in args.bench else 'infer_samples_per_sec'
        results = sorted(results, key=lambda x: x[sort_key], reverse=True)
        if len(results):
            write_results(results_file, results)

        import json
        json_str = json.dumps(results, indent=4)
        print(json_str)
    else:
        benchmark(args)


def write_results(results_file, results):
    with open(results_file, mode='w') as cf:
        dw = csv.DictWriter(cf, fieldnames=results[0].keys())
        dw.writeheader()
        for r in results:
            dw.writerow(r)
        cf.flush()


if __name__ == '__main__':
    main()
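
The script is normally driven from the command line via the argparse options above, but the runner classes can also be used directly. A minimal sketch (illustrative only; assumes a CUDA device is available and that benchmark.py is importable from the working directory):

from benchmark import InferenceBenchmarkRunner

runner = InferenceBenchmarkRunner(
    model_name='resnet50',   # any timm model name
    batch_size=64,
    precision='float16',     # 'float32', 'bfloat16' or 'amp' also accepted
    num_warm_iter=5,
    num_bench_iter=20)
results = runner.run()
print(results)  # dict with samples_per_sec, step_time, batch_size, img_size, param_count
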
@@ -14,7 +14,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
     torch._C._jit_set_profiling_mode(False)

 # transformer models don't support many of the spatial / feature based model functionalities
-NON_STD_FILTERS = ['vit_*', 'tnt_*']
+NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*']
 NUM_NON_STD = len(NON_STD_FILTERS)

 # exclude models that cause specific test failures
@@ -5,7 +5,7 @@ from .constants import *
 _logger = logging.getLogger(__name__)


-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
@@ -73,12 +73,13 @@ class IterableImageDataset(data.IterableDataset):
             batch_size=None,
             class_map='',
             load_bytes=False,
+            repeats=0,
             transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
         else:
             self.parser = parser
         self.transform = transform
@@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
+        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)
@@ -29,6 +29,11 @@ SHUFFLE_SIZE = 16834  # samples to shuffle in DS queue
 PREFETCH_SIZE = 4096  # samples to prefetch


+def even_split_indices(split, n, num_samples):
+    partitions = [round(i * num_samples / n) for i in range(n + 1)]
+    return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)]
+
+
 class ParserTfds(Parser):
     """ Wrap Tensorflow Datasets for use in PyTorch
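
The new even_split_indices() helper builds TFDS subsplit strings so that non-training samples spread evenly across worker pipelines. An illustrative call (numbers invented for the example):

# 4 global worker pipelines over a 50,000-sample split
even_split_indices('validation', 4, 50000)
# -> ['validation[0:12500]', 'validation[12500:25000]',
#     'validation[25000:37500]', 'validation[37500:50000]']
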
@@ -52,7 +57,7 @@ class ParserTfds(Parser):
         components.

     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +67,8 @@ class ParserTfds(Parser):
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
+        self.repeats = repeats
+        self.subsplit = None

         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -95,6 +102,7 @@ class ParserTfds(Parser):
         if worker_info is not None:
             self.worker_info = worker_info
             num_workers = worker_info.num_workers
+            global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id

             # FIXME I need to spend more time figuring out the best way to distribute/split data across
@@ -114,19 +122,31 @@ class ParserTfds(Parser):
             #     split = split + '[{}:]'.format(start)
             # else:
             #     split = split + '[{}:{}]'.format(start, start + split_size)
+            if not self.is_training and '[' not in self.split:
+                # If not training, and split doesn't define a subsplit, manually split the dataset
+                # for more even samples / worker
+                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
+                    self.dist_rank * num_workers + worker_id]
+
+        if self.subsplit is None:
             input_context = tf.distribute.InputContext(
                 num_input_pipelines=self.dist_num_replicas * num_workers,
                 input_pipeline_id=self.dist_rank * num_workers + worker_id,
-                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this have any impact?
+                num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
+        else:
+            input_context = None

-        read_config = tfds.ReadConfig(input_context=input_context)
-        ds = self.builder.as_dataset(split=split, shuffle_files=self.shuffle, read_config=read_config)
+        read_config = tfds.ReadConfig(
+            shuffle_seed=42,
+            shuffle_reshuffle_each_iteration=True,
+            input_context=input_context)
+        ds = self.builder.as_dataset(
+            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
@@ -143,7 +163,7 @@ class ParserTfds(Parser):
         #    This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
@@ -160,8 +180,8 @@ class ParserTfds(Parser):
         if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME this needs more testing, possible for sharding / split api to cause differences of > 1?
-            assert target_sample_count - sample_count == 1  # should only be off by 1 or sharding is not optimal
+            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
+            # approach is not optimal
             yield img, sample['label']  # yield prev sample again
             sample_count += 1

@@ -176,7 +196,7 @@ class ParserTfds(Parser):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples
@@ -14,6 +14,7 @@ from .inception_v4 import *
 from .mobilenetv3 import *
 from .nasnet import *
 from .nfnet import *
+from .pit import *
 from .pnasnet import *
 from .regnet import *
 from .res2net import *

@@ -28,6 +29,7 @@ from .tnt import *
 from .tresnet import *
 from .vgg import *
 from .vision_transformer import *
+from .vision_transformer_hybrid import *
 from .vovnet import *
 from .xception import *
 from .xception_aligned import *
@@ -198,15 +198,19 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
             _logger.warning(
                 f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

-    classifier_name = default_cfg.get('classifier', None)
+    classifiers = default_cfg.get('classifier', None)
     label_offset = default_cfg.get('label_offset', 0)
-    if classifier_name is not None:
+    if classifiers is not None:
+        if isinstance(classifiers, str):
+            classifiers = (classifiers,)
         if num_classes != default_cfg['num_classes']:
+            for classifier_name in classifiers:
                 # completely discard fully connected if model num_classes doesn't match pretrained weights
                 del state_dict[classifier_name + '.weight']
                 del state_dict[classifier_name + '.bias']
             strict = False
         elif label_offset > 0:
+            for classifier_name in classifiers:
                 # special case for pretrained weights with an extra background class in pretrained weights
                 classifier_weight = state_dict[classifier_name + '.weight']
                 state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
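
The switch from a single classifier name to a tuple of names is what lets the distilled PiT models below (default_cfg classifier=('head', 'head_dist')) load cleanly when fine-tuning. A hedged sketch of the case this enables (not part of the diff):

import timm

# num_classes != 1000, so load_pretrained() now discards both 'head.*' and 'head_dist.*'
# from the pretrained state_dict instead of only 'head.*'
model = timm.create_model('pit_s_distilled_224', pretrained=True, num_classes=10)
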
@@ -31,4 +31,4 @@ from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
-from .weight_init import trunc_normal_
+from .weight_init import trunc_normal_, variance_scaling_, lecun_normal_
@@ -2,6 +2,8 @@ import torch
 import math
 import warnings

+from torch.nn.init import _calculate_fan_in_and_fan_out
+

 def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     # Cut & paste from PyTorch official master until it's in a few official releases - RW

@@ -58,3 +60,30 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
         >>> nn.init.trunc_normal_(w)
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
+    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+    if mode == 'fan_in':
+        denom = fan_in
+    elif mode == 'fan_out':
+        denom = fan_out
+    elif mode == 'fan_avg':
+        denom = (fan_in + fan_out) / 2
+
+    variance = scale / denom
+
+    if distribution == "truncated_normal":
+        # constant is stddev of standard normal truncated to (-2, 2)
+        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
+    elif distribution == "normal":
+        tensor.normal_(std=math.sqrt(variance))
+    elif distribution == "uniform":
+        bound = math.sqrt(3 * variance)
+        tensor.uniform_(-bound, bound)
+    else:
+        raise ValueError(f"invalid distribution {distribution}")
+
+
+def lecun_normal_(tensor):
+    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
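
variance_scaling_ reproduces the JAX/TF-style initializer family (fan_in / fan_out / fan_avg scaling with normal, truncated-normal, or uniform sampling), and lecun_normal_ is the fan_in + truncated-normal special case. A small usage sketch (illustrative, not part of the diff):

import torch.nn as nn
from timm.models.layers import lecun_normal_, variance_scaling_

fc = nn.Linear(768, 1000)
lecun_normal_(fc.weight)  # std ~= sqrt(1 / fan_in), truncated normal
variance_scaling_(fc.weight, scale=2.0, mode='fan_out', distribution='uniform')
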
@@ -100,14 +100,16 @@ default_cfgs = dict(
     nfnet_f7s=_dcfg(
         url='', pool_size=(15, 15), input_size=(3, 480, 480), test_input_size=(3, 608, 608)),

-    nfnet_l0a=_dcfg(
-        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
-    nfnet_l0b=_dcfg(
-        url='', pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288)),
+    nfnet_l0=_dcfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nfnet_l0_ra2-45c6688d.pth',
+        pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
     eca_nfnet_l0=_dcfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l0_ra2-e3e9ac50.pth',
         hf_hub='timm/eca_nfnet_l0',
         pool_size=(7, 7), input_size=(3, 224, 224), test_input_size=(3, 288, 288), crop_pct=1.0),
+    eca_nfnet_l1=_dcfg(
+        url='',
+        pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),

     nf_regnet_b0=_dcfg(
         url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),

@@ -232,15 +234,15 @@ model_cfgs = dict(
     nfnet_f6s=_nfnet_cfg(depths=(7, 14, 42, 21), act_layer='silu'),
     nfnet_f7s=_nfnet_cfg(depths=(8, 16, 48, 24), act_layer='silu'),

-    # Experimental 'light' versions of nfnet-f that are little leaner
-    nfnet_l0a=_nfnet_cfg(
-        depths=(1, 2, 6, 3), channels=(256, 512, 1280, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
-        attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
-    nfnet_l0b=_nfnet_cfg(
-        depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
+    # Experimental 'light' versions of NFNet-F that are little leaner
+    nfnet_l0=_nfnet_cfg(
+        depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
         attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
     eca_nfnet_l0=_nfnet_cfg(
-        depths=(1, 2, 6, 3), channels=(256, 512, 1536, 1536), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
+        depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
+        attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
+    eca_nfnet_l1=_nfnet_cfg(
+        depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25,
         attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),

     # EffNet influenced RegNet defs.

@@ -789,29 +791,29 @@ def nfnet_f7s(pretrained=False, **kwargs):


 @register_model
-def nfnet_l0a(pretrained=False, **kwargs):
-    """ NFNet-L0a w/ SiLU
-    My experimental 'light' model w/ 1280 width stage 3, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
-    """
-    return _create_normfreenet('nfnet_l0a', pretrained=pretrained, **kwargs)
-
-
-@register_model
-def nfnet_l0b(pretrained=False, **kwargs):
+def nfnet_l0(pretrained=False, **kwargs):
     """ NFNet-L0b w/ SiLU
-    My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
+    My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & SE ratio
     """
-    return _create_normfreenet('nfnet_l0b', pretrained=pretrained, **kwargs)
+    return _create_normfreenet('nfnet_l0', pretrained=pretrained, **kwargs)


 @register_model
 def eca_nfnet_l0(pretrained=False, **kwargs):
     """ ECA-NFNet-L0 w/ SiLU
-    My experimental 'light' model w/ 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
+    My experimental 'light' model w/ F0 repeats, 1.5x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
     """
     return _create_normfreenet('eca_nfnet_l0', pretrained=pretrained, **kwargs)


+@register_model
+def eca_nfnet_l1(pretrained=False, **kwargs):
+    """ ECA-NFNet-L1 w/ SiLU
+    My experimental 'light' model w/ F1 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
+    """
+    return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def nf_regnet_b0(pretrained=False, **kwargs):
     """ Normalization-Free RegNet-B0
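
After the rename, nfnet_l0 (formerly nfnet_l0b) is the experimental light NFNet that now ships pretrained weights (82.75 top-1 @ 288x288 per the README entry above), while eca_nfnet_l1 is registered with an empty weight URL for now. A usage sketch (illustrative):

import timm

model = timm.create_model('nfnet_l0', pretrained=True)         # downloads the new nfnet_l0_ra2 weights
model_l1 = timm.create_model('eca_nfnet_l1', pretrained=False)  # config only, no weights yet
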

timm/models/pit.py (new file, 388 lines)

|
|||||||
|
""" Pooling-based Vision Transformer (PiT) in PyTorch
|
||||||
|
|
||||||
|
A PyTorch implement of Pooling-based Vision Transformers as described in
|
||||||
|
'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302
|
||||||
|
|
||||||
|
This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below.
|
||||||
|
|
||||||
|
Modifications for timm by / Copyright 2020 Ross Wightman
|
||||||
|
"""
|
||||||
|
# PiT
|
||||||
|
# Copyright 2021-present NAVER Corp.
|
||||||
|
# Apache License v2.0
|
||||||
|
|
||||||
|
import math
|
||||||
|
import re
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import partial
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from .helpers import build_model_with_cfg, overlay_external_default_cfg
|
||||||
|
from .layers import trunc_normal_, to_2tuple
|
||||||
|
from .registry import register_model
|
||||||
|
from .vision_transformer import Block
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
|
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||||
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||||
|
'first_conv': 'patch_embed.conv', 'classifier': 'head',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
# deit models (FB weights)
|
||||||
|
'pit_ti_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_730.pth'),
|
||||||
|
'pit_xs_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_781.pth'),
|
||||||
|
'pit_s_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_809.pth'),
|
||||||
|
'pit_b_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_820.pth'),
|
||||||
|
'pit_ti_distilled_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_ti_distill_746.pth',
|
||||||
|
classifier=('head', 'head_dist')),
|
||||||
|
'pit_xs_distilled_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_xs_distill_791.pth',
|
||||||
|
classifier=('head', 'head_dist')),
|
||||||
|
'pit_s_distilled_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_s_distill_819.pth',
|
||||||
|
classifier=('head', 'head_dist')),
|
||||||
|
'pit_b_distilled_224': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-pit-weights/pit_b_distill_840.pth',
|
||||||
|
classifier=('head', 'head_dist')),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SequentialTuple(nn.Sequential):
|
||||||
|
""" This module exists to work around torchscript typing issues list -> list"""
|
||||||
|
def __init__(self, *args):
|
||||||
|
super(SequentialTuple, self).__init__(*args)
|
||||||
|
|
||||||
|
def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
for module in self:
|
||||||
|
x = module(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, base_dim, depth, heads, mlp_ratio, pool=None, drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None):
|
||||||
|
super(Transformer, self).__init__()
|
||||||
|
self.layers = nn.ModuleList([])
|
||||||
|
embed_dim = base_dim * heads
|
||||||
|
|
||||||
|
self.blocks = nn.Sequential(*[
|
||||||
|
Block(
|
||||||
|
dim=embed_dim,
|
||||||
|
num_heads=heads,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=True,
|
||||||
|
drop=drop_rate,
|
||||||
|
attn_drop=attn_drop_rate,
|
||||||
|
drop_path=drop_path_prob[i],
|
||||||
|
norm_layer=partial(nn.LayerNorm, eps=1e-6)
|
||||||
|
)
|
||||||
|
for i in range(depth)])
|
||||||
|
|
||||||
|
self.pool = pool
|
||||||
|
|
||||||
|
def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
x, cls_tokens = x
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
token_length = cls_tokens.shape[1]
|
||||||
|
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
|
||||||
|
x = self.blocks(x)
|
||||||
|
|
||||||
|
cls_tokens = x[:, :token_length]
|
||||||
|
x = x[:, token_length:]
|
||||||
|
x = x.transpose(1, 2).reshape(B, C, H, W)
|
||||||
|
|
||||||
|
if self.pool is not None:
|
||||||
|
x, cls_tokens = self.pool(x, cls_tokens)
|
||||||
|
return x, cls_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ConvHeadPooling(nn.Module):
|
||||||
|
def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
|
||||||
|
super(ConvHeadPooling, self).__init__()
|
||||||
|
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
in_feature, out_feature, kernel_size=stride + 1, padding=stride // 2, stride=stride,
|
||||||
|
padding_mode=padding_mode, groups=in_feature)
|
||||||
|
self.fc = nn.Linear(in_feature, out_feature)
|
||||||
|
|
||||||
|
def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
|
||||||
|
x = self.conv(x)
|
||||||
|
cls_token = self.fc(cls_token)
|
||||||
|
|
||||||
|
return x, cls_token
|
||||||
|
|
||||||
|
|
||||||
|
class ConvEmbedding(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels, patch_size, stride, padding):
|
||||||
|
super(ConvEmbedding, self).__init__()
|
||||||
|
self.conv = nn.Conv2d(
|
||||||
|
            in_channels, out_channels, kernel_size=patch_size, stride=stride, padding=padding, bias=True)

    def forward(self, x):
        x = self.conv(x)
        return x


class PoolingVisionTransformer(nn.Module):
    """ Pooling-based Vision Transformer

    A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
        - https://arxiv.org/abs/2103.16302
    """
    def __init__(self, img_size, patch_size, stride, base_dims, depth, heads,
                 mlp_ratio, num_classes=1000, in_chans=3, distilled=False,
                 attn_drop_rate=.0, drop_rate=.0, drop_path_rate=.0):
        super(PoolingVisionTransformer, self).__init__()

        padding = 0
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        height = math.floor((img_size[0] + 2 * padding - patch_size[0]) / stride + 1)
        width = math.floor((img_size[1] + 2 * padding - patch_size[1]) / stride + 1)

        self.base_dims = base_dims
        self.heads = heads
        self.num_classes = num_classes
        self.num_tokens = 2 if distilled else 1

        self.patch_size = patch_size
        self.pos_embed = nn.Parameter(torch.randn(1, base_dims[0] * heads[0], height, width))
        self.patch_embed = ConvEmbedding(in_chans, base_dims[0] * heads[0], patch_size, stride, padding)

        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, base_dims[0] * heads[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)

        transformers = []
        # stochastic depth decay rule
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
        for stage in range(len(depth)):
            pool = None
            if stage < len(heads) - 1:
                pool = ConvHeadPooling(
                    base_dims[stage] * heads[stage], base_dims[stage + 1] * heads[stage + 1], stride=2)
            transformers += [Transformer(
                base_dims[stage], depth[stage], heads[stage], mlp_ratio, pool=pool,
                drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_prob=dpr[stage])
            ]
        self.transformers = SequentialTuple(*transformers)
        self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
        self.embed_dim = base_dims[-1] * heads[-1]

        # Classifier head
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
            if num_classes > 0 and distilled else nn.Identity()

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
            if num_classes > 0 and self.num_tokens == 2 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self.pos_drop(x + self.pos_embed)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x, cls_tokens = self.transformers((x, cls_tokens))
        cls_tokens = self.norm(cls_tokens)
        return cls_tokens

    def forward(self, x):
        x = self.forward_features(x)
        x_cls = self.head(x[:, 0])
        if self.num_tokens > 1:
            x_dist = self.head_dist(x[:, 1])
            if self.training and not torch.jit.is_scripting():
                return x_cls, x_dist
            else:
                return (x_cls + x_dist) / 2
        else:
            return x_cls


def checkpoint_filter_fn(state_dict, model):
    """ preprocess checkpoints """
    out_dict = {}
    p_blocks = re.compile(r'pools\.(\d)\.')
    for k, v in state_dict.items():
        # FIXME need to update resize for PiT impl
        # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
        #     # To resize pos embedding when using model at different size from pretrained weights
        #     v = resize_pos_embed(v, model.pos_embed)
        k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', k)
        out_dict[k] = v
    return out_dict
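
# NOTE: illustrative sketch, not part of this diff. The p_blocks regex above remaps the
# pretrained PiT checkpoints, which index the pooling layers as 'pools.N.*', onto this
# implementation's layout where each Transformer stage owns its 'pool' submodule.
# The key names below are hypothetical examples, purely for illustration.
#
#   import re
#   p_blocks = re.compile(r'pools\.(\d)\.')
#   for key in ('pools.0.conv.weight', 'pools.1.fc.weight', 'transformers.0.blocks.0.attn.qkv.weight'):
#       print(p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1))}.pool.', key))
#   # pools.0.conv.weight -> transformers.0.pool.conv.weight
#   # pools.1.fc.weight   -> transformers.1.pool.fc.weight
#   # keys without a 'pools.N.' prefix pass through unchanged
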


def _create_pit(variant, pretrained=False, **kwargs):
    default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-2:]
    img_size = kwargs.pop('img_size', default_img_size)
    num_classes = kwargs.pop('num_classes', default_num_classes)

    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    model = build_model_with_cfg(
        PoolingVisionTransformer, variant, pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)

    return model


@register_model
def pit_b_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=14, stride=7, base_dims=[64, 64, 64], depth=[3, 6, 4], heads=[4, 8, 16],
        mlp_ratio=4, **kwargs)
    return _create_pit('pit_b_224', pretrained, **model_kwargs)


@register_model
def pit_s_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[3, 6, 12],
        mlp_ratio=4, **kwargs)
    return _create_pit('pit_s_224', pretrained, **model_kwargs)


@register_model
def pit_xs_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[2, 4, 8],
        mlp_ratio=4, **kwargs)
    return _create_pit('pit_xs_224', pretrained, **model_kwargs)


@register_model
def pit_ti_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=16, stride=8, base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8],
        mlp_ratio=4, **kwargs)
    return _create_pit('pit_ti_224', pretrained, **model_kwargs)


@register_model
def pit_b_distilled_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=14, stride=7, base_dims=[64, 64, 64], depth=[3, 6, 4], heads=[4, 8, 16],
        mlp_ratio=4, distilled=True, **kwargs)
    return _create_pit('pit_b_distilled_224', pretrained, **model_kwargs)


@register_model
def pit_s_distilled_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[3, 6, 12],
        mlp_ratio=4, distilled=True, **kwargs)
    return _create_pit('pit_s_distilled_224', pretrained, **model_kwargs)


@register_model
def pit_xs_distilled_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[2, 4, 8],
        mlp_ratio=4, distilled=True, **kwargs)
    return _create_pit('pit_xs_distilled_224', pretrained, **model_kwargs)


@register_model
def pit_ti_distilled_224(pretrained, **kwargs):
    model_kwargs = dict(
        patch_size=16, stride=8, base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8],
        mlp_ratio=4, distilled=True, **kwargs)
    return _create_pit('pit_ti_distilled_224', pretrained, **model_kwargs)
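
# NOTE: minimal usage sketch, not part of this diff. The entrypoints above are picked up by
# timm's model registry, so the PiT models are created like any other timm model.
#
#   import torch
#   import timm
#
#   model = timm.create_model('pit_s_224', pretrained=False, num_classes=10)
#   logits = model(torch.randn(2, 3, 224, 224))
#   print(logits.shape)  # torch.Size([2, 10])
#
#   # Distilled variants mirror forward() above: a (cls, dist) logit pair while training,
#   # and the averaged logits in eval / scripted mode.
#   dist_model = timm.create_model('pit_s_distilled_224', pretrained=False).eval()
#   print(dist_model(torch.randn(2, 3, 224, 224)).shape)  # torch.Size([2, 1000])
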
timm/models/regnet.py

@@ -57,12 +57,13 @@ model_cfgs = dict(
 )
 
 
-def _cfg(url=''):
+def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        **kwargs
     }
 
 
@@ -84,12 +85,16 @@ default_cfgs = dict(
     regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
     regnety_008=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth'),
     regnety_016=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth'),
-    regnety_032=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'),
+    regnety_032=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth',
+        crop_pct=1.0, test_input_size=(3, 288, 288)),
     regnety_040=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'),
     regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
     regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
     regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
-    regnety_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'),
+    regnety_160=_cfg(
+        url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth',  # from Facebook DeiT GitHub repository
+        crop_pct=1.0, test_input_size=(3, 288, 288)),
     regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
 )
 
 
@@ -328,11 +333,20 @@ class RegNet(nn.Module):
         return x
 
 
+def _filter_fn(state_dict):
+    """ Unpack checkpoints that nest the weights under a 'model' key. """
+    if 'model' in state_dict:
+        # For DeiT trained regnety_160 pretrained model
+        state_dict = state_dict['model']
+    return state_dict
+
+
 def _create_regnet(variant, pretrained, **kwargs):
     return build_model_with_cfg(
         RegNet, variant, pretrained,
         default_cfg=default_cfgs[variant],
         model_cfg=model_cfgs[variant],
+        pretrained_filter_fn=_filter_fn,
         **kwargs)
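
# NOTE: usage sketch, not part of this diff. With the config change above, regnety_160 now points
# at the DeiT teacher checkpoint (whose weights sit under a 'model' key, hence the _filter_fn
# unwrap) and carries crop_pct=1.0 plus a 288x288 test input size in its default_cfg.
#
#   import timm
#
#   model = timm.create_model('regnety_160', pretrained=False)  # pretrained=True downloads the DeiT teacher weights
#   print(model.default_cfg.get('test_input_size'), model.default_cfg.get('crop_pct'))
#   # (3, 288, 288) 1.0
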
timm/models/resnetv2.py

@@ -274,7 +274,9 @@ class ResNetStage(nn.Module):
         return x
 
 
-def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
+def create_resnetv2_stem(
+        in_chs, out_chs=64, stem_type='', preact=True,
+        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
 
@@ -322,7 +324,8 @@ class ResNetV2(nn.Module):
 
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
-        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+        self.stem = create_resnetv2_stem(
+            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
         stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
         self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
timm/models/vision_transformer.py

@@ -5,6 +5,9 @@ A PyTorch implement of Vision Transformers as described in
 The official jax code is released and available at https://github.com/google-research/vision_transformer
 
+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
 Acknowledgments:
 * The paper authors for releasing code and weights, thanks!
 * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
@@ -12,9 +15,6 @@ for some einops/einsum fun
 * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
 * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
 
-DeiT model defs and weights from https://github.com/facebookresearch/deit,
-paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
-
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
@@ -29,9 +29,7 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, overlay_external_default_cfg
-from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
+from .layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
-from .resnet import resnet26d, resnet50d
-from .resnetv2 import ResNetV2
 from .registry import register_model
 
 _logger = logging.getLogger(__name__)
@@ -98,20 +96,6 @@ default_cfgs = {
         hf_hub='timm/vit_huge_patch14_224_in21k',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
 
-    # hybrid models (weights ported from official Google JAX impl)
-    'vit_base_resnet50_224_in21k': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
-        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
-    'vit_base_resnet50_384': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
-
-    # hybrid models (my experiments)
-    'vit_small_resnet26d_224': _cfg(),
-    'vit_small_resnet50d_s3_224': _cfg(),
-    'vit_base_resnet26d_224': _cfg(),
-    'vit_base_resnet50d_224': _cfg(),
-
     # deit models (FB weights)
     'vit_deit_tiny_patch16_224': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
@@ -123,14 +107,17 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
     'vit_deit_tiny_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_small_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
+        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
 }
 
@@ -158,7 +145,6 @@ class Attention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
         self.scale = qk_scale or head_dim ** -0.5
 
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@@ -224,56 +210,20 @@ class PatchEmbed(nn.Module):
         return x
 
 
-class HybridEmbed(nn.Module):
-    (the entire HybridEmbed CNN feature-map embedding class is deleted here; an updated version is
-    added in the new timm/models/vision_transformer_hybrid.py below)
-
-
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
-    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
-        https://arxiv.org/abs/2010.11929
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+        - https://arxiv.org/abs/2010.11929
+
+    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+        - https://arxiv.org/abs/2012.12877
     """
 
     def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
+                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+                 act_layer=None, weight_init=''):
         """
         Args:
             img_size (int, tuple): input image size
@@ -287,39 +237,40 @@ class VisionTransformer(nn.Module):
             qkv_bias (bool): enable bias for qkv if True
             qk_scale (float): override default qk scale of head_dim ** -0.5 if set
             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            distilled (bool): model includes a distillation token and head as in DeiT models
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
-            hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
+            embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
+            weight_init: (str): weight init scheme
         """
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 2 if distilled else 1
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
 
-        if hybrid_backbone is not None:
-            self.patch_embed = HybridEmbed(
-                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
-        else:
-            self.patch_embed = PatchEmbed(
-                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        self.patch_embed = embed_layer(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-        self.blocks = nn.ModuleList([
+        self.blocks = nn.Sequential(*[
             Block(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         self.norm = norm_layer(embed_dim)
 
         # Representation layer
-        if representation_size:
+        if representation_size and not distilled:
             self.num_features = representation_size
             self.pre_logits = nn.Sequential(OrderedDict([
                 ('fc', nn.Linear(embed_dim, representation_size)),
@@ -328,110 +279,118 @@ class VisionTransformer(nn.Module):
         else:
             self.pre_logits = nn.Identity()
 
-        # Classifier head
+        # Classifier head(s)
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = None
+        if distilled:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
+        # Weight init
+        assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+        head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
         trunc_normal_(self.pos_embed, std=.02)
-        trunc_normal_(self.cls_token, std=.02)
-        self.apply(self._init_weights)
+        if self.dist_token is not None:
+            trunc_normal_(self.dist_token, std=.02)
+        if weight_init.startswith('jax'):
+            # leave cls token as zeros to match jax impl
+            for n, m in self.named_modules():
+                _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+        else:
+            trunc_normal_(self.cls_token, std=.02)
+            self.apply(_init_vit_weights)
 
     def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
+        # this fn left here for compat with downstream users
+        _init_vit_weights(m)
 
     @torch.jit.ignore
     def no_weight_decay(self):
-        return {'pos_embed', 'cls_token'}
+        return {'pos_embed', 'cls_token', 'dist_token'}
 
     def get_classifier(self):
-        return self.head
+        if self.dist_token is None:
+            return self.head
+        else:
+            return self.head, self.head_dist
 
     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        if self.num_tokens == 2:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
-        B = x.shape[0]
         x = self.patch_embed(x)
-        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        x = torch.cat((cls_tokens, x), dim=1)
-        x = x + self.pos_embed
-        x = self.pos_drop(x)
-        for blk in self.blocks:
-            x = blk(x)
-        x = self.norm(x)[:, 0]
-        x = self.pre_logits(x)
-        return x
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.dist_token is None:
+            x = torch.cat((cls_token, x), dim=1)
+        else:
+            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.blocks(x)
         x = self.norm(x)
+        if self.dist_token is None:
+            return self.pre_logits(x[:, 0])
+        else:
+            return x[:, 0], x[:, 1]
 
     def forward(self, x):
         x = self.forward_features(x)
-        x = self.head(x)
-        return x
-
-
-class DistilledVisionTransformer(VisionTransformer):
-    (the separate DeiT `DistilledVisionTransformer` subclass is deleted here; its distillation token,
-    second head and train/eval forward behaviour are folded into `VisionTransformer` above so that
-    torchscript works)
+        if self.head_dist is not None:
+            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
+            if self.training and not torch.jit.is_scripting():
+                # during inference, return the average of both classifier predictions
+                return x, x_dist
+            else:
+                return (x + x_dist) / 2
+        else:
+            x = self.head(x)
+        return x
 
 
-def resize_pos_embed(posemb, posemb_new):
+def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+    """ ViT weight initialization
+    * When called without n, head_bias, jax_impl args it will behave exactly the same
+      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+    """
+    if isinstance(m, nn.Linear):
+        if n.startswith('head'):
+            nn.init.zeros_(m.weight)
+            nn.init.constant_(m.bias, head_bias)
+        elif n.startswith('pre_logits'):
+            lecun_normal_(m.weight)
+            nn.init.zeros_(m.bias)
+        else:
+            if jax_impl:
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    if 'mlp' in n:
+                        nn.init.normal_(m.bias, std=1e-6)
+                    else:
+                        nn.init.zeros_(m.bias)
+            else:
+                trunc_normal_(m.weight, std=.02)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+    elif jax_impl and isinstance(m, nn.Conv2d):
+        # NOTE conv was left to pytorch default in my original init
+        lecun_normal_(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.LayerNorm):
+        nn.init.zeros_(m.bias)
+        nn.init.ones_(m.weight)
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if True:
-        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
-        ntok_new -= 1
+    if num_tokens:
+        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+        ntok_new -= num_tokens
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
     gs_old = int(math.sqrt(len(posemb_grid)))
@@ -457,12 +416,13 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed)
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
         out_dict[k] = v
     return out_dict
 
 
-def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
-    default_cfg = deepcopy(default_cfgs[variant])
+def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
+    if default_cfg is None:
+        default_cfg = deepcopy(default_cfgs[variant])
     overlay_external_default_cfg(default_cfg, kwargs)
     default_num_classes = default_cfg['num_classes']
@@ -480,9 +440,8 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
-    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
     model = build_model_with_cfg(
-        model_cls, variant, pretrained,
+        VisionTransformer, variant, pretrained,
         default_cfg=default_cfg,
         img_size=img_size,
         num_classes=num_classes,
@@ -495,7 +454,11 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
 
 @register_model
 def vit_small_patch16_224(pretrained=False, **kwargs):
-    """ My custom 'small' ViT model. Depth=8, heads=8= mlp_ratio=3."""
+    """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
+    NOTE:
+        * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
+        * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
+    """
     model_kwargs = dict(
         patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
         qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
@@ -640,76 +603,6 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     return model
 
 
-@register_model
-def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
-    (the six hybrid model entrypoints vit_base_resnet50_224_in21k, vit_base_resnet50_384,
-    vit_small_resnet26d_224, vit_small_resnet50d_s3_224, vit_base_resnet26d_224 and
-    vit_base_resnet50d_224 are deleted here; they are re-defined in the new
-    timm/models/vision_transformer_hybrid.py below)
 
 
 @register_model
 def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
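
# NOTE: usage sketch, not part of this diff. The merged distilled defs behave like the PiT ones:
# a (cls, dist) logit pair while training, averaged logits in eval / scripted mode. weight_init is
# forwarded to VisionTransformer like any other model kwarg; '' keeps the previous init behaviour.
#
#   import torch
#   import timm
#
#   deit = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=False)
#   x = torch.randn(2, 3, 224, 224)
#   deit.train()
#   cls_logits, dist_logits = deit(x)   # two heads while training
#   deit.eval()
#   avg_logits = deit(x)                # single averaged prediction at inference
#
#   vit = timm.create_model('vit_base_patch16_224', pretrained=False, weight_init='jax')
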
313
timm/models/vision_transformer_hybrid.py
Normal file
313
timm/models/vision_transformer_hybrid.py
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
""" Hybrid Vision Transformer (ViT) in PyTorch
|
||||||
|
|
||||||
|
A PyTorch implement of the Hybrid Vision Transformers as described in
|
||||||
|
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
|
||||||
|
- https://arxiv.org/abs/2010.11929
|
||||||
|
|
||||||
|
NOTE This relies on code in vision_transformer.py. The hybrid model definitions were moved here to
|
||||||
|
keep file sizes sane.
|
||||||
|
|
||||||
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
|
"""
|
||||||
|
from copy import deepcopy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||||
|
from .layers import StdConv2dSame, StdConv2d, to_2tuple
|
||||||
|
from .resnet import resnet26d, resnet50d
|
||||||
|
from .resnetv2 import ResNetV2, create_resnetv2_stem
|
||||||
|
from .registry import register_model
|
||||||
|
from timm.models.vision_transformer import _create_vision_transformer
|
||||||
|
|
||||||
|
|
||||||
|
def _cfg(url='', **kwargs):
|
||||||
|
return {
|
||||||
|
'url': url,
|
||||||
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
|
||||||
|
'crop_pct': .9, 'interpolation': 'bicubic',
|
||||||
|
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
|
||||||
|
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
|
||||||
|
**kwargs
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
default_cfgs = {
|
||||||
|
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
|
||||||
|
'vit_base_r50_s16_224_in21k': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
|
||||||
|
num_classes=21843, crop_pct=0.9),
|
||||||
|
|
||||||
|
# hybrid in-1k models (weights ported from official JAX impl)
|
||||||
|
'vit_base_r50_s16_384': _cfg(
|
||||||
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
||||||
|
input_size=(3, 384, 384), crop_pct=1.0),
|
||||||
|
|
||||||
|
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
|
||||||
|
'vit_tiny_r_s16_p8_224': _cfg(),
|
||||||
|
'vit_small_r_s16_p8_224': _cfg(),
|
||||||
|
'vit_small_r20_s16_p2_224': _cfg(),
|
||||||
|
'vit_small_r20_s16_224': _cfg(),
|
||||||
|
'vit_small_r26_s32_224': _cfg(),
|
||||||
|
'vit_base_r20_s16_224': _cfg(),
|
||||||
|
'vit_base_r26_s32_224': _cfg(),
|
||||||
|
'vit_base_r50_s16_224': _cfg(),
|
||||||
|
'vit_large_r50_s32_224': _cfg(),
|
||||||
|
|
||||||
|
# hybrid models (using timm resnet backbones)
|
||||||
|
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||||
|
'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||||
|
'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||||
|
'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class HybridEmbed(nn.Module):
|
||||||
|
""" CNN Feature Map Embedding
|
||||||
|
Extract feature map from CNN, flatten, project to embedding dim.
|
||||||
|
"""
|
||||||
|
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
|
||||||
|
super().__init__()
|
||||||
|
assert isinstance(backbone, nn.Module)
|
||||||
|
img_size = to_2tuple(img_size)
|
||||||
|
patch_size = to_2tuple(patch_size)
|
||||||
|
self.img_size = img_size
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.backbone = backbone
|
||||||
|
if feature_size is None:
|
||||||
|
with torch.no_grad():
|
||||||
|
# NOTE Most reliable way of determining output dims is to run forward pass
|
||||||
|
training = backbone.training
|
||||||
|
if training:
|
||||||
|
backbone.eval()
|
||||||
|
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
|
||||||
|
if isinstance(o, (list, tuple)):
|
||||||
|
o = o[-1] # last feature if backbone outputs list/tuple of features
|
||||||
|
feature_size = o.shape[-2:]
|
||||||
|
feature_dim = o.shape[1]
|
||||||
|
backbone.train(training)
|
||||||
|
else:
|
||||||
|
feature_size = to_2tuple(feature_size)
|
||||||
|
if hasattr(self.backbone, 'feature_info'):
|
||||||
|
feature_dim = self.backbone.feature_info.channels()[-1]
|
||||||
|
else:
|
||||||
|
feature_dim = self.backbone.num_features
|
||||||
|
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
||||||
|
self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
|
||||||
|
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.backbone(x)
|
||||||
|
if isinstance(x, (list, tuple)):
|
||||||
|
x = x[-1] # last feature if backbone outputs list/tuple of features
|
||||||
|
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
|
||||||
|
default_cfg = deepcopy(default_cfgs[variant])
|
||||||
|
embed_layer = partial(HybridEmbed, backbone=backbone)
|
||||||
|
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
|
||||||
|
return _create_vision_transformer(
|
||||||
|
variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def _resnetv2(layers=(3, 4, 9), **kwargs):
|
||||||
|
""" ResNet-V2 backbone helper"""
|
||||||
|
padding_same = kwargs.get('padding_same', True)
|
||||||
|
if padding_same:
|
||||||
|
stem_type = 'same'
|
||||||
|
conv_layer = StdConv2dSame
|
||||||
|
else:
|
||||||
|
stem_type = ''
|
||||||
|
conv_layer = StdConv2d
|
||||||
|
if len(layers):
|
||||||
|
backbone = ResNetV2(
|
||||||
|
layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
|
||||||
|
preact=False, stem_type=stem_type, conv_layer=conv_layer)
|
||||||
|
else:
|
||||||
|
backbone = create_resnetv2_stem(
|
||||||
|
kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
|
||||||
|
return backbone
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
|
||||||
|
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
|
||||||
|
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
|
||||||
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
|
||||||
|
# NOTE this is forwarding to model def above for backwards compatibility
|
||||||
|
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_base_r50_s16_384(pretrained=False, **kwargs):
|
||||||
|
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
||||||
|
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2((3, 4, 9), **kwargs)
|
||||||
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_base_resnet50_384(pretrained=False, **kwargs):
|
||||||
|
# NOTE this is forwarding to model def above for backwards compatibility
|
||||||
|
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
||||||
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2(layers=(), **kwargs)
|
||||||
|
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
|
||||||
|
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2(layers=(), **kwargs)
|
||||||
|
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
|
||||||
|
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2((2, 4), **kwargs)
|
||||||
|
model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_small_r20_s16_224(pretrained=False, **kwargs):
|
||||||
|
""" R20+ViT-S/S16 hybrid.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
||||||
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@register_model
|
||||||
|
def vit_small_r26_s32_224(pretrained=False, **kwargs):
|
||||||
|
""" R26+ViT-S/S32 hybrid.
|
||||||
|
"""
|
||||||
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
||||||
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
|
||||||
|
model = _create_vision_transformer_hybrid(
|
||||||
|
        'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs):
    """ R20+ViT-B/S16 hybrid.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs):
    """ R26+ViT-B/S32 hybrid.
    """
    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs):
    """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    """
    backbone = _resnetv2((3, 4, 9), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
    """ R50+ViT-L/S32 hybrid.
    """
    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs):
    """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
    """
    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
    """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet26d_224(pretrained=False, **kwargs):
    """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
    """
    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet50d_224(pretrained=False, **kwargs):
    """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model
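The defs above are registered with timm's model factory, so they can be instantiated by name. A minimal usage sketch (not part of the diff), assuming this branch of timm is installed; most of these hybrid defs ship without pretrained weights, so pretrained=False is used:

import torch
import timm

# Build one of the hybrid defs registered above; weights are randomly
# initialized since no pretrained checkpoint is assumed for this def.
model = timm.create_model('vit_base_r50_s16_224', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # defs target 224x224 input
print(out.shape)  # torch.Size([1, 1000]), the default ImageNet-sized head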
timm/optim/__init__.py
@@ -10,4 +10,4 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP

-from .optim_factory import create_optimizer
+from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
timm/optim/optim_factory.py
@@ -1,8 +1,11 @@
 """ Optimizer Factory w/ Custom Weight Decay
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Optional
+
 import torch
-from torch import optim as optim
+import torch.nn as nn
+import torch.optim as optim

 from .adafactor import Adafactor
 from .adahessian import Adahessian
@@ -37,9 +40,63 @@ def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
         {'params': decay, 'weight_decay': weight_decay}]


+def optimizer_kwargs(cfg):
+    """ cfg/argparse to kwargs helper
+    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
+    """
+    kwargs = dict(
+        optimizer_name=cfg.opt,
+        learning_rate=cfg.lr,
+        weight_decay=cfg.weight_decay,
+        momentum=cfg.momentum)
+    if getattr(cfg, 'opt_eps', None) is not None:
+        kwargs['eps'] = cfg.opt_eps
+    if getattr(cfg, 'opt_betas', None) is not None:
+        kwargs['betas'] = cfg.opt_betas
+    if getattr(cfg, 'opt_args', None) is not None:
+        kwargs.update(cfg.opt_args)
+    return kwargs
+
+
 def create_optimizer(args, model, filter_bias_and_bn=True):
-    opt_lower = args.opt.lower()
-    weight_decay = args.weight_decay
+    """ Legacy optimizer factory for backwards compatibility.
+    NOTE: Use create_optimizer_v2 for new code.
+    """
+    return create_optimizer_v2(
+        model,
+        **optimizer_kwargs(cfg=args),
+        filter_bias_and_bn=filter_bias_and_bn,
+    )
+
+
+def create_optimizer_v2(
+        model: nn.Module,
+        optimizer_name: str = 'sgd',
+        learning_rate: Optional[float] = None,
+        weight_decay: float = 0.,
+        momentum: float = 0.9,
+        filter_bias_and_bn: bool = True,
+        **kwargs):
+    """ Create an optimizer.
+
+    TODO currently the model is passed in and all parameters are selected for optimization.
+    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
+      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
+      * expose the parameters interface and leave it up to caller
+
+    Args:
+        model (nn.Module): model containing parameters to optimize
+        optimizer_name: name of optimizer to create
+        learning_rate: initial learning rate
+        weight_decay: weight decay to apply in optimizer
+        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
+        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
+        **kwargs: extra optimizer specific kwargs to pass through
+
+    Returns:
+        Optimizer
+    """
+    opt_lower = optimizer_name.lower()
     if weight_decay and filter_bias_and_bn:
         skip = {}
         if hasattr(model, 'no_weight_decay'):
@@ -48,26 +105,18 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
         weight_decay = 0.
     else:
         parameters = model.parameters()

     if 'fused' in opt_lower:
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

-    opt_args = dict(lr=args.lr, weight_decay=weight_decay)
-    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
-        opt_args['eps'] = args.opt_eps
-    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
-        opt_args['betas'] = args.opt_betas
-    if hasattr(args, 'opt_args') and args.opt_args is not None:
-        opt_args.update(args.opt_args)
+    opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)

     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
     if opt_lower == 'sgd' or opt_lower == 'nesterov':
         opt_args.pop('eps', None)
-        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'momentum':
         opt_args.pop('eps', None)
-        optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
     elif opt_lower == 'adam':
         optimizer = optim.Adam(parameters, **opt_args)
     elif opt_lower == 'adamw':
@@ -79,29 +128,29 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     elif opt_lower == 'adamp':
         optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
     elif opt_lower == 'sgdp':
-        optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(parameters, **opt_args)
     elif opt_lower == 'adafactor':
-        if not args.lr:
+        if not learning_rate:
             opt_args['lr'] = None
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'adahessian':
         optimizer = Adahessian(parameters, **opt_args)
     elif opt_lower == 'rmsprop':
-        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
     elif opt_lower == 'rmsproptf':
-        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
+        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
     elif opt_lower == 'novograd':
         optimizer = NovoGrad(parameters, **opt_args)
     elif opt_lower == 'nvnovograd':
         optimizer = NvNovoGrad(parameters, **opt_args)
     elif opt_lower == 'fusedsgd':
         opt_args.pop('eps', None)
-        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
+        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
     elif opt_lower == 'fusedmomentum':
         opt_args.pop('eps', None)
-        optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
+        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
     elif opt_lower == 'fusedadam':
         optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
     elif opt_lower == 'fusedadamw':
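For reference, a minimal sketch (not part of the diff) of calling the new factory directly with plain keyword args instead of an argparse namespace; the toy model and hyper-parameter values below are illustrative, not repo defaults:

import torch.nn as nn
from timm.optim import create_optimizer_v2

# Toy model with both 2d weights and 1d (bias/norm) params.
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 10))

optimizer = create_optimizer_v2(
    model,
    optimizer_name='adamw',
    learning_rate=1e-3,
    weight_decay=0.05,    # bias/norm (1d) params get weight_decay=0 via add_weight_decay
    betas=(0.9, 0.999),   # optimizer-specific extra, passed through **kwargs to AdamW
)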
timm/version.py
@@ -1 +1 @@
-__version__ = '0.4.6'
+__version__ = '0.4.7'
train.py
@@ -33,7 +33,7 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_c
     convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
-from timm.optim import create_optimizer
+from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler

@@ -142,6 +142,8 @@ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                     help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 2)')
+parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
+                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
@@ -258,6 +260,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
+parser.add_argument('--experiment', default='', type=str, metavar='NAME',
+                    help='name of train experiment, name of sub-folder for output')
 parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
                     help='Best metric (default: "top1"')
 parser.add_argument('--tta', type=int, default=0, metavar='N',
@@ -385,7 +389,7 @@ def main():
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)

-    optimizer = create_optimizer(args, model)
+    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
@@ -451,7 +455,9 @@ def main():

     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+        args.dataset,
+        root=args.data_dir, split=args.train_split, is_training=True,
+        batch_size=args.batch_size, repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
         args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

@@ -541,13 +547,15 @@ def main():
     saver = None
     output_dir = ''
     if args.local_rank == 0:
-        output_base = args.output if args.output else './output'
-        exp_name = '-'.join([
-            datetime.now().strftime("%Y%m%d-%H%M%S"),
-            safe_model_name(args.model),
-            str(data_config['input_size'][-1])
-        ])
-        output_dir = get_outdir(output_base, 'train', exp_name)
+        if args.experiment:
+            exp_name = args.experiment
+        else:
+            exp_name = '-'.join([
+                datetime.now().strftime("%Y%m%d-%H%M%S"),
+                safe_model_name(args.model),
+                str(data_config['input_size'][-1])
+            ])
+        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
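A sketch (not part of the diff) of what the updated optimizer wiring in main() amounts to, using a SimpleNamespace in place of the parsed argparse args; the field names match the cfg attributes that optimizer_kwargs reads, and the values are illustrative:

from types import SimpleNamespace

import torch.nn as nn
from timm.optim import create_optimizer_v2, optimizer_kwargs

# Stand-in for parsed train.py args; only the fields read by optimizer_kwargs.
args = SimpleNamespace(
    opt='sgd', lr=0.1, weight_decay=1e-4, momentum=0.9,
    opt_eps=None, opt_betas=None, opt_args=None)

model = nn.Linear(16, 10)
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
print(type(optimizer).__name__)  # SGD (the 'sgd' branch uses nesterov=True)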
validate.py
@@ -152,7 +152,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))

-    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
+    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
     test_time_pool = False
     if not args.no_test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)