Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Pass --model-kwargs (and --opt-kwargs for train) from the command line through to model __init__. Update some models to improve arg overlay. Cleanup along the way.
parent add3fb864e · commit e861b74cf8
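The new `--model-kwargs` flag relies on a `ParseKwargs` argparse action (imported from `timm.utils` in the diffs below). A minimal sketch of what such an action can look like, assuming `key=value` tokens that are literal-eval'd with a plain-string fallback:

```python
import argparse
import ast


class ParseKwargs(argparse.Action):
    """Collect repeated KEY=VALUE tokens into a dict on the parsed namespace."""

    def __call__(self, parser, namespace, values, option_string=None):
        kw = {}
        for value in values:
            key, value = value.split('=')
            try:
                # literal_eval handles numbers, bools, tuples, etc.
                kw[key] = ast.literal_eval(value)
            except ValueError:
                kw[key] = str(value)  # fall back to a bare string

        setattr(namespace, self.dest, kw)


parser = argparse.ArgumentParser()
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
args = parser.parse_args(['--model-kwargs', 'act_layer=silu', 'depth=6'])
print(args.model_kwargs)  # {'act_layer': 'silu', 'depth': 6}
```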
benchmark.py (26 changed lines)

@@ -22,7 +22,7 @@ from timm.data import resolve_data_config
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs

 has_apex = False
 try:
@@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False,
                     help='Enable gradient checkpointing through model blocks/stages')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
+parser.add_argument('--amp-dtype', default='float16', type=str,
+                    help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd optimization.")
-

 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
@@ -168,19 +170,21 @@ def count_params(model: nn.Module):


 def resolve_precision(precision: str):
-    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
-    use_amp = False
+    assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
+    amp_dtype = None  # amp disabled
     model_dtype = torch.float32
     data_dtype = torch.float32
     if precision == 'amp':
-        use_amp = True
+        amp_dtype = torch.float16
+    elif precision == 'amp_bfloat16':
+        amp_dtype = torch.bfloat16
    elif precision == 'float16':
         model_dtype = torch.float16
         data_dtype = torch.float16
     elif precision == 'bfloat16':
         model_dtype = torch.bfloat16
         data_dtype = torch.bfloat16
-    return use_amp, model_dtype, data_dtype
+    return amp_dtype, model_dtype, data_dtype


 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
@@ -228,9 +232,12 @@ class BenchmarkRunner:
         self.model_name = model_name
         self.detail = detail
         self.device = device
-        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
+        self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
-        self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
+        if self.amp_dtype is not None:
+            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+        else:
+            self.amp_autocast = suppress

         if fuser:
             set_jit_fuser(fuser)
@@ -243,6 +250,7 @@ class BenchmarkRunner:
             drop_rate=kwargs.pop('drop', 0.),
             drop_path_rate=kwargs.pop('drop_path', None),
             drop_block_rate=kwargs.pop('drop_block', None),
+            **kwargs.pop('model_kwargs', {}),
         )
         self.model.to(
             device=self.device,
@@ -560,7 +568,7 @@ def _try_run(
 def benchmark(args):
     if args.amp:
         _logger.warning("Overriding precision to 'amp' since --amp flag set.")
-        args.precision = 'amp'
+        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
     _logger.info(f'Benchmarking in {args.precision} precision. '
                  f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                  f'torchscript {"enabled" if args.torchscript else "disabled"}')
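A quick sanity check of the new `resolve_precision` contract (a sketch; `resolve_precision` is the benchmark.py function from the hunk above, where an `amp_dtype` value replaces the old `use_amp` bool):

```python
import torch
from contextlib import suppress
from functools import partial

amp_dtype, model_dtype, data_dtype = resolve_precision('amp_bfloat16')
assert amp_dtype is torch.bfloat16 and model_dtype is torch.float32

# autocast is only wrapped when AMP is active, exactly as in BenchmarkRunner
amp_autocast = partial(torch.cuda.amp.autocast, dtype=amp_dtype) if amp_dtype is not None else suppress
```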
inference.py (14 changed lines)

@@ -20,7 +20,7 @@ import torch
 from timm.data import create_dataset, create_loader, resolve_data_config
 from timm.layers import apply_test_time_pool
 from timm.models import create_model
-from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
+from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs

 try:
     from apex import amp
@@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--in-chans', type=int, default=None, metavar='N',
+                    help='Image input channels (default: None => 3)')
 parser.add_argument('--input-size', default=None, nargs=3, type=int,
                     metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 parser.add_argument('--use-train-size', action='store_true', default=False,
@@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -170,12 +173,19 @@ def main():
         set_jit_fuser(args.fuser)

     # create model
+    in_chans = 3
+    if args.in_chans is not None:
+        in_chans = args.in_chans
+    elif args.input_size is not None:
+        in_chans = args.input_size[0]
+
     model = create_model(
         args.model,
         num_classes=args.num_classes,
-        in_chans=3,
+        in_chans=in_chans,
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint,
+        **args.model_kwargs,
     )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
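Putting the two inference.py changes together, a hypothetical invocation and the call it produces (illustrative values, not from the diff):

```python
# python inference.py --model resnet50 --in-chans 1 --model-kwargs act_layer=silu
#
# args.in_chans = 1, args.model_kwargs = {'act_layer': 'silu'}, so main() builds:
model = create_model(
    'resnet50',
    num_classes=None,
    in_chans=1,              # resolved from --in-chans / --input-size, else 3
    pretrained=False,
    checkpoint_path='',
    act_layer='silu',        # overlaid from --model-kwargs
)
```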
timm/models/byobnet.py

@@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):


 def interleave_blocks(
-        types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
+        types: Tuple[str, str], d,
+        every: Union[int, List[int]] = 1,
+        first: bool = False,
+        **kwargs,
 ) -> Tuple[ByoBlockCfg]:
     """ interleave 2 block types in stack
     """
@@ -1587,15 +1590,32 @@ class ByobNet(nn.Module):
             in_chans=3,
             global_pool='avg',
             output_stride=32,
-            zero_init_last=True,
             img_size=None,
             drop_rate=0.,
             drop_path_rate=0.,
+            zero_init_last=True,
+            **kwargs,
     ):
+        """
+
+        Args:
+            cfg (ByoModelCfg): Model architecture configuration
+            num_classes (int): Number of classifier classes (default: 1000)
+            in_chans (int): Number of input channels (default: 3)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            img_size (Union[int, Tuple[int]]): Image size for fixed image size models (i.e. self-attn)
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
         layers = get_layer_fns(cfg)
         if cfg.fixed_input_size:
             assert img_size is not None, 'img_size argument is required for fixed input size model'
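The `cfg = replace(cfg, **kwargs)` line is the "arg overlay" mentioned in the commit message: `dataclasses.replace` returns a copy of the config dataclass with the given fields swapped out, and raises `TypeError` for names that are not config fields, so stray CLI kwargs fail loudly. A standalone sketch with a made-up config:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ModelCfg:          # hypothetical stand-in for ByoModelCfg / CspModelCfg / NfCfg
    stem_chs: int = 32
    act_layer: str = 'relu'


cfg = ModelCfg()
cfg = replace(cfg, **{'act_layer': 'silu'})   # e.g. kwargs collected from --model-kwargs
print(cfg)                                    # ModelCfg(stem_chs=32, act_layer='silu')

replace(cfg, not_a_field=1)                   # raises TypeError: unexpected keyword
```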
timm/models/convnext.py

@@ -167,7 +167,7 @@ class ConvNeXtStage(nn.Module):
                 conv_bias=conv_bias,
                 use_grn=use_grn,
                 act_layer=act_layer,
-                norm_layer=norm_layer if conv_mlp else norm_layer_cl
+                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
             ))
             in_chs = out_chs
         self.blocks = nn.Sequential(*stage_blocks)
@@ -184,16 +184,6 @@ class ConvNeXtStage(nn.Module):
 class ConvNeXt(nn.Module):
     r""" ConvNeXt
     A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
-
-    Args:
-        in_chans (int): Number of input image channels. Default: 3
-        num_classes (int): Number of classes for classification head. Default: 1000
-        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
-        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
-        drop_rate (float): Head dropout rate
-        drop_path_rate (float): Stochastic depth rate. Default: 0.
-        ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
-        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
     """

     def __init__(
@@ -218,6 +208,28 @@ class ConvNeXt(nn.Module):
             drop_rate=0.,
             drop_path_rate=0.,
     ):
+        """
+        Args:
+            in_chans (int): Number of input image channels (default: 3)
+            num_classes (int): Number of classes for classification head (default: 1000)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
+            dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
+            kernel_sizes (Union[int, List[int]]): Depthwise convolution kernel-sizes for each stage (default: 7)
+            ls_init_value (float): Init value for Layer Scale (default: 1e-6)
+            stem_type (str): Type of stem (default: 'patch')
+            patch_size (int): Stem patch size for patch stem (default: 4)
+            head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
+            head_norm_first (bool): Apply normalization before global pool + head (default: False)
+            conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
+            conv_bias (bool): Use bias layers w/ all convolutions (default: True)
+            use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
+            act_layer (Union[str, nn.Module]): Activation Layer
+            norm_layer (Union[str, nn.Module]): Normalization Layer
+            drop_rate (float): Head dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth rate (default: 0.)
+        """
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
@@ -279,7 +291,7 @@ class ConvNeXt(nn.Module):
                 use_grn=use_grn,
                 act_layer=act_layer,
                 norm_layer=norm_layer,
-                norm_layer_cl=norm_layer_cl
+                norm_layer_cl=norm_layer_cl,
             ))
             prev_chs = out_chs
             # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
timm/models/cspnet.py

@@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage

 Hacked together by / Copyright 2020 Ross Wightman
 """
-from dataclasses import dataclass, asdict
+from dataclasses import dataclass, asdict, replace
 from functools import partial
 from typing import Any, Dict, Optional, Tuple, Union

@@ -518,7 +518,7 @@ class CrossStage(nn.Module):
             cross_linear=False,
             block_dpr=None,
             block_fn=BottleneckBlock,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(CrossStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -558,7 +558,7 @@ class CrossStage(nn.Module):
                 bottle_ratio=bottle_ratio,
                 groups=groups,
                 drop_path=block_dpr[i] if block_dpr is not None else 0.,
-                **block_kwargs
+                **block_kwargs,
             ))
             prev_chs = block_out_chs

@@ -597,7 +597,7 @@ class CrossStage3(nn.Module):
             cross_linear=False,
             block_dpr=None,
             block_fn=BottleneckBlock,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(CrossStage3, self).__init__()
         first_dilation = first_dilation or dilation
@@ -635,7 +635,7 @@ class CrossStage3(nn.Module):
                 bottle_ratio=bottle_ratio,
                 groups=groups,
                 drop_path=block_dpr[i] if block_dpr is not None else 0.,
-                **block_kwargs
+                **block_kwargs,
             ))
             prev_chs = block_out_chs

@@ -668,7 +668,7 @@ class DarkStage(nn.Module):
             avg_down=False,
             block_fn=BottleneckBlock,
             block_dpr=None,
-            **block_kwargs
+            **block_kwargs,
     ):
         super(DarkStage, self).__init__()
         first_dilation = first_dilation or dilation
@@ -715,7 +715,7 @@ def create_csp_stem(
         padding='',
         act_layer=nn.ReLU,
         norm_layer=nn.BatchNorm2d,
-        aa_layer=None
+        aa_layer=None,
 ):
     stem = nn.Sequential()
     feature_info = []
@@ -738,7 +738,7 @@ def create_csp_stem(
             stride=conv_stride,
             padding=padding if i == 0 else '',
             act_layer=act_layer,
-            norm_layer=norm_layer
+            norm_layer=norm_layer,
         ))
         stem_stride *= conv_stride
         prev_chs = chs
@@ -800,7 +800,7 @@ def create_csp_stages(
         cfg: CspModelCfg,
         drop_path_rate: float,
         output_stride: int,
-        stem_feat: Dict[str, Any]
+        stem_feat: Dict[str, Any],
 ):
     cfg_dict = asdict(cfg.stages)
     num_stages = len(cfg.stages.depth)
@@ -868,12 +868,27 @@ class CspNet(nn.Module):
             global_pool='avg',
             drop_rate=0.,
             drop_path_rate=0.,
-            zero_init_last=True
+            zero_init_last=True,
+            **kwargs,
     ):
+        """
+        Args:
+            cfg (CspModelCfg): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            global_pool (str): Global pooling type (default: 'avg')
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         assert output_stride in (8, 16, 32)
+
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
         layer_args = dict(
             act_layer=cfg.act_layer,
             norm_layer=cfg.norm_layer,
timm/models/nfnet.py

@@ -17,7 +17,7 @@ Status:
 Hacked together by / copyright Ross Wightman, 2021.
 """
 from collections import OrderedDict
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Tuple, Optional

@@ -159,11 +159,25 @@ class NfCfg:


 def _nfres_cfg(
-        depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
+        depths,
+        channels=(256, 512, 1024, 2048),
+        group_size=None,
+        act_layer='relu',
+        attn_layer=None,
+        attn_kwargs=None,
+):
     attn_kwargs = attn_kwargs or {}
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
-        group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+        depths=depths,
+        channels=channels,
+        stem_type='7x7_pool',
+        stem_chs=64,
+        bottle_ratio=0.25,
+        group_size=group_size,
+        act_layer=act_layer,
+        attn_layer=attn_layer,
+        attn_kwargs=attn_kwargs,
+    )
     return cfg


@@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
     num_features = 1280 * channels[-1] // 440
     attn_kwargs = dict(rd_ratio=0.5)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
-        num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
+        depths=depths,
+        channels=channels,
+        stem_type='3x3',
+        group_size=8,
+        width_factor=0.75,
+        bottle_ratio=2.25,
+        num_features=num_features,
+        reg=True,
+        attn_layer='se',
+        attn_kwargs=attn_kwargs,
+    )
     return cfg


 def _nfnet_cfg(
-        depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
-        act_layer='gelu', attn_layer='se', attn_kwargs=None):
+        depths,
+        channels=(256, 512, 1536, 1536),
+        group_size=128,
+        bottle_ratio=0.5,
+        feat_mult=2.,
+        act_layer='gelu',
+        attn_layer='se',
+        attn_kwargs=None,
+):
     num_features = int(channels[-1] * feat_mult)
     attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
-        bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
-        attn_layer=attn_layer, attn_kwargs=attn_kwargs)
+        depths=depths,
+        channels=channels,
+        stem_type='deep_quad',
+        stem_chs=128,
+        group_size=group_size,
+        bottle_ratio=bottle_ratio,
+        extra_conv=True,
+        num_features=num_features,
+        act_layer=act_layer,
+        attn_layer=attn_layer,
+        attn_kwargs=attn_kwargs,
+    )
     return cfg


-def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
+def _dm_nfnet_cfg(
+        depths,
+        channels=(256, 512, 1536, 1536),
+        act_layer='gelu',
+        skipinit=True,
+):
     cfg = NfCfg(
-        depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
-        bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
-        num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
+        depths=depths,
+        channels=channels,
+        stem_type='deep_quad',
+        stem_chs=128,
+        group_size=128,
+        bottle_ratio=0.5,
+        extra_conv=True,
+        gamma_in_act=True,
+        same_padding=True,
+        skipinit=skipinit,
+        num_features=int(channels[-1] * 2.0),
+        act_layer=act_layer,
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.5),
+    )
     return cfg


@@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.):

 class DownsampleAvg(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            conv_layer=ScaledStdConv2d,
+    ):
         """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
         super(DownsampleAvg, self).__init__()
         avg_stride = stride if dilation == 1 else 1
@@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
-            alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
-            skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs=None,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            alpha=1.0,
+            beta=1.0,
+            bottle_ratio=0.25,
+            group_size=None,
+            ch_div=1,
+            reg=True,
+            extra_conv=False,
+            skipinit=False,
+            attn_layer=None,
+            attn_gain=2.0,
+            act_layer=None,
+            conv_layer=None,
+            drop_path_rate=0.,
+    ):
         super().__init__()
         first_dilation = first_dilation or dilation
         out_chs = out_chs or in_chs
@@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module):

         if in_chs != out_chs or stride != 1 or dilation != first_dilation:
             self.downsample = DownsampleAvg(
-                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                conv_layer=conv_layer,
+            )
         else:
             self.downsample = None

@@ -452,14 +538,33 @@ class NormFreeNet(nn.Module):
         for what it is/does. Approx 8-10% throughput loss.
     """
     def __init__(
-            self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
-            drop_rate=0., drop_path_rate=0.
+            self,
+            cfg: NfCfg,
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            output_stride=32,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            **kwargs,
     ):
+        """
+        Args:
+            cfg (NfCfg): Model architecture configuration
+            num_classes (int): Number of classifier classes (default: 1000)
+            in_chans (int): Number of input channels (default: 3)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False

+        cfg = replace(cfg, **kwargs)
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
         conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
         if cfg.gamma_in_act:
@@ -472,7 +577,12 @@ class NormFreeNet(nn.Module):

         stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
         self.stem, stem_stride, stem_feat = create_stem(
-            in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
+            in_chans,
+            stem_chs,
+            cfg.stem_type,
+            conv_layer=conv_layer,
+            act_layer=act_layer,
+        )

         self.feature_info = [stem_feat]
         drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
timm/models/regnet.py

@@ -14,7 +14,7 @@ Weights from original impl have been modified
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
 from functools import partial
 from typing import Optional, Union, Callable

@@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la


 def create_shortcut(
-        downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
+        downsample_type,
+        in_chs,
+        out_chs,
+        kernel_size,
+        stride,
+        dilation=(1, 1),
+        norm_layer=None,
+        preact=False,
+):
     assert downsample_type in ('avg', 'conv1x1', '', None)
     if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
         dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
@@ -259,9 +267,21 @@ class Bottleneck(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
-            downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1,
+            group_size=1,
+            se_ratio=0.25,
+            downsample='conv1x1',
+            linear_out=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(Bottleneck, self).__init__()
         act_layer = get_act_layer(act_layer)
         bottleneck_chs = int(round(out_chs * bottle_ratio))
@@ -307,9 +327,21 @@ class PreBottleneck(nn.Module):
     """

     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
-            downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
-            drop_block=None, drop_path_rate=0.):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=(1, 1),
+            bottle_ratio=1,
+            group_size=1,
+            se_ratio=0.25,
+            downsample='conv1x1',
+            linear_out=False,
+            act_layer=nn.ReLU,
+            norm_layer=nn.BatchNorm2d,
+            drop_block=None,
+            drop_path_rate=0.,
+    ):
         super(PreBottleneck, self).__init__()
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         bottleneck_chs = int(round(out_chs * bottle_ratio))
@@ -353,8 +385,16 @@ class RegStage(nn.Module):
     """Stage (sequence of blocks w/ the same output shape)."""

     def __init__(
-            self, depth, in_chs, out_chs, stride, dilation,
-            drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
+            self,
+            depth,
+            in_chs,
+            out_chs,
+            stride,
+            dilation,
+            drop_path_rates=None,
+            block_fn=Bottleneck,
+            **block_kwargs,
+    ):
         super(RegStage, self).__init__()
         self.grad_checkpointing = False

@@ -367,8 +407,13 @@ class RegStage(nn.Module):
             name = "b{}".format(i + 1)
             self.add_module(
                 name, block_fn(
-                    block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
-                    drop_path_rate=dpr, **block_kwargs)
+                    block_in_chs,
+                    out_chs,
+                    stride=block_stride,
+                    dilation=block_dilation,
+                    drop_path_rate=dpr,
+                    **block_kwargs,
+                )
             )
             first_dilation = dilation

@@ -389,12 +434,35 @@ class RegNet(nn.Module):
     """

     def __init__(
-            self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
-            drop_rate=0., drop_path_rate=0., zero_init_last=True):
+            self,
+            cfg: RegNetCfg,
+            in_chans=3,
+            num_classes=1000,
+            output_stride=32,
+            global_pool='avg',
+            drop_rate=0.,
+            drop_path_rate=0.,
+            zero_init_last=True,
+            **kwargs,
+    ):
+        """
+
+        Args:
+            cfg (RegNetCfg): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            global_pool (str): Global pooling type (default: 'avg')
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            zero_init_last (bool): Zero-init last weight of residual path
+            kwargs (dict): Extra kwargs overlayed onto cfg
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         assert output_stride in (8, 16, 32)
+        cfg = replace(cfg, **kwargs)  # update cfg with extra passed kwargs

         # Construct the stem
         stem_width = cfg.stem_width
@@ -461,8 +529,12 @@ class RegNet(nn.Module):
             dict(zip(arg_names, params)) for params in
             zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)]
         common_args = dict(
-            downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
-            act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
+            downsample=cfg.downsample,
+            se_ratio=cfg.se_ratio,
+            linear_out=cfg.linear_out,
+            act_layer=cfg.act_layer,
+            norm_layer=cfg.norm_layer,
+        )
         return per_stage_args, common_args

     @torch.jit.ignore
@@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False):


 def _filter_fn(state_dict):
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
     if 'classy_state_dict' in state_dict:
         import re
         state_dict = state_dict['classy_state_dict']['base_model']['model']
timm/models/resnet.py

@@ -16,7 +16,7 @@ import torch.nn.functional as F

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
-    create_classifier
+    get_act_layer, get_norm_layer, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, model_entrypoint
@@ -500,7 +500,14 @@ class Bottleneck(nn.Module):


 def downsample_conv(
-        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        dilation=1,
+        first_dilation=None,
+        norm_layer=None,
+):
     norm_layer = norm_layer or nn.BatchNorm2d
     kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
     first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
@@ -514,7 +521,14 @@ def downsample_conv(


 def downsample_avg(
-        in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None):
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        dilation=1,
+        first_dilation=None,
+        norm_layer=None,
+):
     norm_layer = norm_layer or nn.BatchNorm2d
     avg_stride = stride if dilation == 1 else 1
     if stride == 1 and dilation == 1:
@@ -627,31 +641,6 @@ class ResNet(nn.Module):

     SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
     reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
-
-    Parameters
-    ----------
-    block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl.
-    layers : list of int, number of layers in each block
-    num_classes : int, default 1000, number of classification classes.
-    in_chans : int, default 3, number of input (color) channels.
-    output_stride : int, default 32, output stride of the network, 32, 16, or 8.
-    global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
-    cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck.
-    base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
-    stem_width : int, default 64, number of channels in stem convolutions
-    stem_type : str, default ''
-        The type of stem:
-          * '', default - a single 7x7 conv with a width of stem_width
-          * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
-          * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
-    block_reduce_first : int, default 1
-        Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2
-    down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets
-    avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample.
-    act_layer : nn.Module, activation layer
-    norm_layer : nn.Module, normalization layer
-    aa_layer : nn.Module, anti-aliasing layer
-    drop_rate : float, default 0. Dropout probability before classifier, for training
     """

     def __init__(
@@ -679,6 +668,36 @@ class ResNet(nn.Module):
             zero_init_last=True,
             block_args=None,
     ):
+        """
+        Args:
+            block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
+            layers (List[int]) : number of layers in each block
+            num_classes (int): number of classification classes (default 1000)
+            in_chans (int): number of input (color) channels. (default 3)
+            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
+            global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
+            cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
+            base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
+            stem_width (int): number of channels in stem convolutions (default 64)
+            stem_type (str): The type of stem (default ''):
+                * '', default - a single 7x7 conv with a width of stem_width
+                * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
+                * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
+            replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
+            block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
+                1 for all archs except senets, where 2 (default 1)
+            down_kernel_size (int): kernel size of residual block downsample path,
+                1x1 for most, 3x3 for senets (default: 1)
+            avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
+            act_layer (str, nn.Module): activation layer
+            norm_layer (str, nn.Module): normalization layer
+            aa_layer (nn.Module): anti-aliasing layer
+            drop_rate (float): Dropout probability before classifier, for training (default 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
+            drop_block_rate (float): Drop block rate (default 0.)
+            zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
+            block_args (dict): Extra kwargs to pass through to block module
+        """
         super(ResNet, self).__init__()
         block_args = block_args or dict()
         assert output_stride in (8, 16, 32)
@@ -686,6 +705,9 @@ class ResNet(nn.Module):
         self.drop_rate = drop_rate
         self.grad_checkpointing = False

+        act_layer = get_act_layer(act_layer)
+        norm_layer = get_norm_layer(norm_layer)
+
         # Stem
         deep_stem = 'deep' in stem_type
         inplanes = stem_width * 2 if deep_stem else 64
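The `get_act_layer` / `get_norm_layer` calls added to `ResNet.__init__` resolve either a string name or an `nn.Module` class to a layer class, which is what lets string-valued `--model-kwargs` entries like `act_layer=silu` reach the constructor unconverted. A hedged illustration (registry contents vary slightly across timm versions):

```python
import torch.nn as nn
from timm.layers import get_act_layer, get_norm_layer

act_layer = get_act_layer('silu')       # string name -> activation class (nn.SiLU here)
same = get_act_layer(nn.SiLU)           # classes pass through unchanged
norm_layer = get_norm_layer('groupnorm')  # same lookup idea for norm layers
print(act_layer, same, norm_layer)
```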
timm/models/resnetv2.py

@@ -37,7 +37,7 @@ import torch.nn as nn

 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
-    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
+    ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
 from ._registry import register_model
@@ -276,8 +276,16 @@ class Bottleneck(nn.Module):

 class DownsampleConv(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
-            conv_layer=None, norm_layer=None):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            preact=True,
+            conv_layer=None,
+            norm_layer=None,
+    ):
         super(DownsampleConv, self).__init__()
         self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
         self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
@@ -288,8 +296,16 @@ class DownsampleConv(nn.Module):

 class DownsampleAvg(nn.Module):
     def __init__(
-            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
-            preact=True, conv_layer=None, norm_layer=None):
+            self,
+            in_chs,
+            out_chs,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            preact=True,
+            conv_layer=None,
+            norm_layer=None,
+    ):
         """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
         super(DownsampleAvg, self).__init__()
         avg_stride = stride if dilation == 1 else 1
@@ -334,9 +350,18 @@ class ResNetStage(nn.Module):
             drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
             stride = stride if block_idx == 0 else 1
             self.blocks.add_module(str(block_idx), block_fn(
-                prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
-                first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
-                **layer_kwargs, **block_kwargs))
+                prev_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                bottle_ratio=bottle_ratio,
+                groups=groups,
+                first_dilation=first_dilation,
+                proj_layer=proj_layer,
+                drop_path_rate=drop_path_rate,
+                **layer_kwargs,
+                **block_kwargs,
+            ))
             prev_chs = out_chs
             first_dilation = dilation
             proj_layer = None
@@ -413,21 +438,49 @@ class ResNetV2(nn.Module):
             avg_down=False,
             preact=True,
             act_layer=nn.ReLU,
-            conv_layer=StdConv2d,
             norm_layer=partial(GroupNormAct, num_groups=32),
+            conv_layer=StdConv2d,
             drop_rate=0.,
             drop_path_rate=0.,
             zero_init_last=False,
     ):
+        """
+        Args:
+            layers (List[int]) : number of layers in each block
+            channels (List[int]) : number of channels in each block
+            num_classes (int): number of classification classes (default 1000)
+            in_chans (int): number of input (color) channels. (default 3)
+            global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
+            output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
+            width_factor (int): channel (width) multiplication factor
+            stem_chs (int): stem width (default: 64)
+            stem_type (str): stem type (default: '' == 7x7)
+            avg_down (bool): average pooling in residual downsampling (default: False)
+            preact (bool): pre-activation (default: True)
+            act_layer (Union[str, nn.Module]): activation layer
+            norm_layer (Union[str, nn.Module]): normalization layer
+            conv_layer (nn.Module): convolution module
+            drop_rate: classifier dropout rate (default: 0.)
+            drop_path_rate: stochastic depth rate (default: 0.)
+            zero_init_last: zero-init last weight in residual path (default: False)
+        """
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         wf = width_factor
+        norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
+        act_layer = get_act_layer(act_layer)

         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
         self.stem = create_resnetv2_stem(
-            in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+            in_chans,
+            stem_chs,
+            stem_type,
+            preact,
+            conv_layer=conv_layer,
+            norm_layer=norm_layer,
+        )
         stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
         self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

@ -1152,8 +1152,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
|||||||
def vit_tiny_patch16_224(pretrained=False, **kwargs):
|
def vit_tiny_patch16_224(pretrained=False, **kwargs):
|
||||||
""" ViT-Tiny (Vit-Ti/16)
|
""" ViT-Tiny (Vit-Ti/16)
|
||||||
"""
|
"""
|
||||||
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
|
||||||
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
|
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
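The same two-line change repeats for every variant below: the variant defaults no longer absorb **kwargs when model_kwargs is built; instead the caller's kwargs are overlaid at the call site, so user-supplied values (e.g. from --model-kwargs) win over the variant defaults. A minimal sketch of the merge semantics:

    defaults = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
    overrides = dict(depth=6)             # e.g. parsed from --model-kwargs depth=6
    merged = dict(defaults, **overrides)  # overrides take precedence
    assert merged == dict(patch_size=16, embed_dim=192, depth=6, num_heads=3)
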
@@ -1161,8 +1161,8 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs):
 def vit_tiny_patch16_384(pretrained=False, **kwargs):
     """ ViT-Tiny (ViT-Ti/16) @ 384x384.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
-    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+    model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1170,8 +1170,8 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs):
 def vit_small_patch32_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32)
     """
-    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1179,8 +1179,8 @@ def vit_small_patch32_224(pretrained=False, **kwargs):
 def vit_small_patch32_384(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/32) at 384x384.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1188,8 +1188,8 @@ def vit_small_patch32_384(pretrained=False, **kwargs):
 def vit_small_patch16_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1197,8 +1197,8 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
 def vit_small_patch16_384(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/16)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1206,8 +1206,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs):
 def vit_small_patch8_224(pretrained=False, **kwargs):
     """ ViT-Small (ViT-S/8)
     """
-    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
-    model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
+    model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1216,8 +1216,8 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
+    model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1226,8 +1226,8 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
+    model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1236,8 +1236,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+    model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1246,8 +1246,8 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+    model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1256,8 +1256,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
-    model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
+    model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1265,8 +1265,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
 def vit_large_patch32_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1275,8 +1275,8 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1285,8 +1285,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1295,8 +1295,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1304,8 +1304,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
 def vit_large_patch14_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14)
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
+    model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1313,8 +1313,8 @@ def vit_large_patch14_224(pretrained=False, **kwargs):
 def vit_huge_patch14_224(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
+    model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1322,8 +1322,8 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
 def vit_giant_patch14_224(pretrained=False, **kwargs):
     """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
+    model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1331,8 +1331,9 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
 def vit_gigantic_patch14_224(pretrained=False, **kwargs):
     """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
     """
-    model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
-    model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
+    model = _create_vision_transformer(
+        'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1341,8 +1342,9 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
+    model = _create_vision_transformer(
+        'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1352,8 +1354,9 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
-    model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
+        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

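For the gap variants the change also removes a latent failure mode: dict(k=..., **kwargs) raises TypeError when kwargs itself contains k, so the old global_pool=kwargs.get('global_pool', 'avg') default could collide with a caller-supplied global_pool. With the overlay form the default simply loses to the override:

    model_kwargs = dict(global_pool='avg')          # variant default
    merged = dict(model_kwargs, global_pool='max')  # caller override, no collision
    assert merged['global_pool'] == 'max'
    # old form: dict(global_pool=kwargs.get('global_pool', 'avg'), **kwargs)
    # raises TypeError if 'global_pool' is in kwargs (duplicate keyword argument)
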
@@ -1363,8 +1366,9 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
-    model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
+        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1374,8 +1378,9 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
     """
     model_kwargs = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
-        global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
-    model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
+        global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1384,9 +1389,9 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
-        global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1395,8 +1400,9 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs):
     """ ViT-B/32 CLIP image tower @ 224x224
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1405,8 +1411,9 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs):
     """ ViT-B/32 CLIP image tower @ 384x384
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1415,8 +1422,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
     """ ViT-B/32 CLIP image tower @ 448x448
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1424,9 +1432,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
 def vit_base_patch16_clip_224(pretrained=False, **kwargs):
     """ ViT-B/16 CLIP image tower
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1434,9 +1442,9 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs):
 def vit_base_patch16_clip_384(pretrained=False, **kwargs):
     """ ViT-B/16 CLIP image tower @ 384x384
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1444,9 +1452,9 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs):
 def vit_large_patch14_clip_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14) CLIP image tower
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1454,9 +1462,9 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs):
 def vit_large_patch14_clip_336(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1464,9 +1472,9 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs):
 def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) CLIP image tower.
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1474,9 +1482,9 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
 def vit_huge_patch14_clip_336(pretrained=False, **kwargs):
     """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
     """
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1486,9 +1494,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
     Pretrained weights from CLIP image tower.
     """
     model_kwargs = dict(
-        patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
-        pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
-    model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs)
+        patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1498,8 +1506,9 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
 def vit_base_patch32_plus_256(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32+)
     """
-    model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
-    model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
+    model = _create_vision_transformer(
+        'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1507,8 +1516,9 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs):
 def vit_base_patch16_plus_240(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16+)
     """
-    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
+    model = _create_vision_transformer(
+        'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1517,9 +1527,10 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) w/ residual post-norm
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False,
-        block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
+        class_token=False, block_fn=ResPostBlock, global_pool='avg')
+    model = _create_vision_transformer(
+        'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1529,8 +1540,9 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
     Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
     Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
+    model = _create_vision_transformer(
+        'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1541,8 +1553,9 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
     Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
-    model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
+    model = _create_vision_transformer(
+        'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1551,27 +1564,26 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
     """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
     Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
     """
-    model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
-    model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
+    model = _create_vision_transformer(
+        'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
 
 
 @register_model
 def eva_large_patch14_196(pretrained=False, **kwargs):
     """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
-    model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
+    model = _create_vision_transformer(
+        'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model
 
 
 @register_model
 def eva_large_patch14_336(pretrained=False, **kwargs):
     """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
-    model_kwargs = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
-    model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
+    model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1579,8 +1591,8 @@ def eva_large_patch14_336(pretrained=False, **kwargs):
 def flexivit_small(pretrained=False, **kwargs):
     """ FlexiViT-Small
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs)
-    model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
+    model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1588,8 +1600,8 @@ def flexivit_small(pretrained=False, **kwargs):
 def flexivit_base(pretrained=False, **kwargs):
     """ FlexiViT-Base
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs)
-    model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
+    model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

@@ -1597,6 +1609,6 @@ def flexivit_base(pretrained=False, **kwargs):
 def flexivit_large(pretrained=False, **kwargs):
     """ FlexiViT-Large
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs)
-    model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
+    model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
     return model

timm/models/vovnet.py

@@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential):
 class OsaBlock(nn.Module):
 
     def __init__(
-            self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
-            depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
+            self,
+            in_chs,
+            mid_chs,
+            out_chs,
+            layer_per_block,
+            residual=False,
+            depthwise=False,
+            attn='',
+            norm_layer=BatchNormAct2d,
+            act_layer=nn.ReLU,
+            drop_path=None,
+    ):
         super(OsaBlock, self).__init__()
 
         self.residual = residual

@@ -232,9 +242,20 @@ class OsaBlock(nn.Module):
 class OsaStage(nn.Module):
 
     def __init__(
-            self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
-            residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
-            drop_path_rates=None):
+            self,
+            in_chs,
+            mid_chs,
+            out_chs,
+            block_per_stage,
+            layer_per_block,
+            downsample=True,
+            residual=True,
+            depthwise=False,
+            attn='ese',
+            norm_layer=BatchNormAct2d,
+            act_layer=nn.ReLU,
+            drop_path_rates=None,
+    ):
         super(OsaStage, self).__init__()
         self.grad_checkpointing = False

@@ -270,16 +291,38 @@ class OsaStage(nn.Module):
 class VovNet(nn.Module):
 
     def __init__(
-            self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
-            output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
-        """ VovNet (v2)
+            self,
+            cfg,
+            in_chans=3,
+            num_classes=1000,
+            global_pool='avg',
+            output_stride=32,
+            norm_layer=BatchNormAct2d,
+            act_layer=nn.ReLU,
+            drop_rate=0.,
+            drop_path_rate=0.,
+            **kwargs,
+    ):
+        """
+        Args:
+            cfg (dict): Model architecture configuration
+            in_chans (int): Number of input channels (default: 3)
+            num_classes (int): Number of classifier classes (default: 1000)
+            global_pool (str): Global pooling type (default: 'avg')
+            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
+            norm_layer (Union[str, nn.Module]): normalization layer
+            act_layer (Union[str, nn.Module]): activation layer
+            drop_rate (float): Dropout rate (default: 0.)
+            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
+            kwargs (dict): Extra kwargs overlayed onto cfg
         """
         super(VovNet, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        assert stem_stride in (4, 2)
         assert output_stride == 32  # FIXME support dilation
 
+        cfg = dict(cfg, **kwargs)
+        stem_stride = cfg.get("stem_stride", 4)
         stem_chs = cfg["stem_chs"]
         stage_conv_chs = cfg["stage_conv_chs"]
         stage_out_chs = cfg["stage_out_chs"]
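The constructor now folds extra kwargs into the architecture config before reading it, so per-model cfg entries such as stem_stride become overridable from the outside (e.g. via --model-kwargs). A small sketch of the overlay, using an illustrative cfg subset rather than a real model config:

    cfg = dict(stem_chs=[64, 64, 128], stem_stride=4)  # hypothetical cfg excerpt
    kwargs = dict(stem_stride=2)                       # e.g. --model-kwargs stem_stride=2
    cfg = dict(cfg, **kwargs)
    stem_stride = cfg.get("stem_stride", 4)            # -> 2
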
@@ -307,9 +350,15 @@ class VovNet(nn.Module):
         for i in range(4):  # num_stages
             downsample = stem_stride == 2 or i > 0  # first stage has no stride/downsample if stem_stride is 4
             stages += [OsaStage(
-                in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
-                downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
-            ]
+                in_ch_list[i],
+                stage_conv_chs[i],
+                stage_out_chs[i],
+                block_per_stage[i],
+                layer_per_block,
+                downsample=downsample,
+                drop_path_rates=stage_dpr[i],
+                **stage_args,
+            )]
             self.num_features = stage_out_chs[i]
             current_stride *= 2 if downsample else 1
             self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]

@@ -324,7 +373,6 @@ class VovNet(nn.Module):
             elif isinstance(m, nn.Linear):
                 nn.init.zeros_(m.bias)
 
-
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         return dict(

|
@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
|
|||||||
from .jit import set_jit_legacy, set_jit_fuser
|
from .jit import set_jit_legacy, set_jit_fuser
|
||||||
from .log import setup_default_logging, FormatterNoInfo
|
from .log import setup_default_logging, FormatterNoInfo
|
||||||
from .metrics import AverageMeter, accuracy
|
from .metrics import AverageMeter, accuracy
|
||||||
from .misc import natural_key, add_bool_arg
|
from .misc import natural_key, add_bool_arg, ParseKwargs
|
||||||
from .model import unwrap_model, get_state_dict, freeze, unfreeze
|
from .model import unwrap_model, get_state_dict, freeze, unfreeze
|
||||||
from .model_ema import ModelEma, ModelEmaV2
|
from .model_ema import ModelEma, ModelEmaV2
|
||||||
from .random import random_seed
|
from .random import random_seed
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
Hacked together by / Copyright 2020 Ross Wightman
|
Hacked together by / Copyright 2020 Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
|
import ast
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''):
|
|||||||
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
|
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
|
||||||
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
|
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
|
||||||
parser.set_defaults(**{dest_name: default})
|
parser.set_defaults(**{dest_name: default})
|
||||||
|
|
||||||
|
|
||||||
|
class ParseKwargs(argparse.Action):
|
||||||
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
|
kw = {}
|
||||||
|
for value in values:
|
||||||
|
key, value = value.split('=')
|
||||||
|
try:
|
||||||
|
kw[key] = ast.literal_eval(value)
|
||||||
|
except ValueError:
|
||||||
|
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
|
||||||
|
setattr(namespace, self.dest, kw)
|
||||||
|
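A usage sketch for the new action (hypothetical invocation): each token is key=value, ast.literal_eval converts numbers, bools, and tuples to Python objects, and anything that fails to parse falls back to a plain string:

    import argparse
    from timm.utils import ParseKwargs

    parser = argparse.ArgumentParser()
    parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
    args = parser.parse_args(['--model-kwargs', 'drop_path_rate=0.1', 'act_layer=gelu'])
    assert args.model_kwargs == {'drop_path_rate': 0.1, 'act_layer': 'gelu'}
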
train.py

@@ -118,7 +118,8 @@ group.add_argument('--img-size', type=int, default=None, metavar='N',
 group.add_argument('--in-chans', type=int, default=None, metavar='N',
                    help='Image input channels (default: None => 3)')
 group.add_argument('--input-size', default=None, nargs=3, type=int,
-                   metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
+                   metavar='N N N',
+                   help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 group.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop percent (for validation only)')
 group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',

@@ -139,6 +140,7 @@ group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
 group.add_argument('--fast-norm', default=False, action='store_true',
                    help='enable experimental fast-norm')
+group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
 
 scripting_group = group.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',

@@ -166,6 +168,7 @@ group.add_argument('--clip-mode', type=str, default='norm',
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')
 group.add_argument('--layer-decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')
+group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
 
 # Learning rate schedule parameters
 group = parser.add_argument_group('Learning rate schedule parameters')

@@ -371,8 +374,6 @@ def main():
     torch.backends.cuda.matmul.allow_tf32 = True
     torch.backends.cudnn.benchmark = True
 
-    if args.data and not args.data_dir:
-        args.data_dir = args.data
     args.prefetcher = not args.no_prefetcher
     device = utils.init_distributed_device(args)
     if args.distributed:

@@ -383,14 +384,6 @@ def main():
         _logger.info(f'Training with a single process on 1 device ({args.device}).')
     assert args.rank >= 0
 
-    if utils.is_primary(args) and args.log_wandb:
-        if has_wandb:
-            wandb.init(project=args.experiment, config=args)
-        else:
-            _logger.warning(
-                "You've requested to log metrics to wandb but package not found. "
-                "Metrics not being logged to wandb, try `pip install wandb`")
-
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16

@@ -432,6 +425,7 @@ def main():
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint,
+        **args.model_kwargs,
     )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
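With the pass-through in place, a command line such as the following (hypothetical values) reaches the model constructor; the equivalent direct call is shown for clarity:

    # python train.py /data --model resnetv2_50 --model-kwargs zero_init_last=True drop_path_rate=0.1
    from timm.models import create_model

    model_kwargs = {'zero_init_last': True, 'drop_path_rate': 0.1}  # as produced by ParseKwargs
    model = create_model('resnetv2_50', pretrained=False, **model_kwargs)
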
@@ -504,7 +498,11 @@ def main():
             f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
             f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
 
-    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
+    optimizer = create_optimizer_v2(
+        model,
+        **optimizer_kwargs(cfg=args),
+        **args.opt_kwargs,
+    )
 
     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
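Likewise for the optimizer: --opt-kwargs entries land as extra keyword args on create_optimizer_v2, after the standard ones derived from the parsed config. A hedged sketch with an illustrative override:

    import torch.nn as nn
    from timm.optim import create_optimizer_v2

    opt_kwargs = {'eps': 1e-6}  # e.g. --opt-kwargs eps=1e-6
    optimizer = create_optimizer_v2(nn.Linear(10, 10), opt='adamw', lr=1e-3, **opt_kwargs)
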
@@ -559,6 +557,8 @@ def main():
     # NOTE: EMA model does not need to be wrapped by DDP
 
     # create the train and eval datasets
+    if args.data and not args.data_dir:
+        args.data_dir = args.data
     dataset_train = create_dataset(
         args.dataset,
         root=args.data_dir,

@@ -712,6 +712,14 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
+    if utils.is_primary(args) and args.log_wandb:
+        if has_wandb:
+            wandb.init(project=args.experiment, config=args)
+        else:
+            _logger.warning(
+                "You've requested to log metrics to wandb but package not found. "
+                "Metrics not being logged to wandb, try `pip install wandb`")
+
     # setup learning rate schedule and starting epoch
     updates_per_epoch = len(loader_train)
     lr_scheduler, num_epochs = create_scheduler_v2(

validate.py

@@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
 from timm.layers import apply_test_time_pool, set_fast_norm
 from timm.models import create_model, load_checkpoint, is_model, list_models
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
-    decay_batch_step, check_batch_size_retry
+    decay_batch_step, check_batch_size_retry, ParseKwargs
 
 try:
     from apex import amp

@@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=None, type=int,
                     metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--in-chans', type=int, default=None, metavar='N',
+                    help='Image input channels (default: None => 3)')
 parser.add_argument('--input-size', default=None, nargs=3, type=int,
                     metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
 parser.add_argument('--use-train-size', action='store_true', default=False,

@@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',

@@ -181,13 +185,20 @@ def validate(args):
         set_fast_norm()
 
     # create model
+    in_chans = 3
+    if args.in_chans is not None:
+        in_chans = args.in_chans
+    elif args.input_size is not None:
+        in_chans = args.input_size[0]
+
     model = create_model(
         args.model,
         pretrained=args.pretrained,
         num_classes=args.num_classes,
-        in_chans=3,
+        in_chans=in_chans,
         global_pool=args.gp,
         scriptable=args.torchscript,
+        **args.model_kwargs,
     )
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
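The input-channel resolution added above gives --in-chans precedence, then the channel dim of --input-size, then 3. A sketch of the same logic as a function (a hypothetical helper, not in the diff):

    def resolve_in_chans(in_chans=None, input_size=None):
        if in_chans is not None:
            return in_chans
        if input_size is not None:
            return input_size[0]  # input-size is (chans, height, width)
        return 3

    assert resolve_in_chans(input_size=(1, 224, 224)) == 1
    assert resolve_in_chans(in_chans=4) == 4
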
@@ -232,8 +243,9 @@ def validate(args):
 
     criterion = nn.CrossEntropyLoss().to(device)
 
+    root_dir = args.data or args.data_dir
     dataset = create_dataset(
-        root=args.data,
+        root=root_dir,
         name=args.dataset,
         split=args.split,
         download=args.dataset_download,

@@ -389,7 +401,7 @@ def main():
     if args.model == 'all':
         # validate all models in a list of names with pretrained checkpoints
         args.pretrained = True
-        model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
+        model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae'])
         model_cfgs = [(n, '') for n in model_names]
     elif not is_model(args.model):
         # model name doesn't exist, try as wildcard filter