[NAS] Modify the infrastructure for NAS training

brian1009 2023-03-12 17:52:14 +08:00
parent 0e1a79af52
commit afadc46e64
4 changed files with 357 additions and 157 deletions

View File

@@ -28,6 +28,8 @@ def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
     print_freq = 10
     for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
+        if args.nas_mode:
+            model.module.set_random_sample_config()
         samples = samples.to(device, non_blocking=True)
         targets = targets.to(device, non_blocking=True)
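
Note: set_random_sample_config() itself is not shown in this commit. Below is a minimal sketch, under the assumption that the supernet simply keeps the callable installed via set_random_config_fn() (see model_sparse.py further down) and feeds its output to set_sample_config(); the attribute and class names here are hypothetical.

    import torch.nn as nn

    class _RandomSampleSketch(nn.Module):
        """Illustration only: mirrors the set_random_config_fn()/set_sample_config()
        API exposed by model_sparse.py later in this commit."""

        def set_sample_config(self, sparse_configs):
            # In the real model this pushes one keep-ratio into every SparseLinearSuper.
            self.current_config = list(sparse_configs)

        def set_random_config_fn(self, fn):
            self.random_config_fn = fn          # fn() returns one ratio per sparse layer

        def set_random_sample_config(self):
            # What the call in train_one_epoch is presumed to do: sample a new
            # sub-network configuration and activate it for this iteration.
            self.set_sample_config(self.random_config_fn())

    sketch = _RandomSampleSketch()
    sketch.set_random_config_fn(lambda: [0.5, 1.0, 0.25])
    sketch.set_random_sample_config()
    print(sketch.current_config)                # [0.5, 1.0, 0.25]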

main.py (62 changes)
View File

@@ -27,7 +27,8 @@ from augment import new_data_aug_generator

 import models
 import models_v2
+import model_sparse
+import random
 import utils
 from sparsity_factory.pruners import weight_pruner_loader, prune_weights_reparam, check_valid_pruner
@ -173,7 +174,7 @@ def get_args_parser():
parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation") parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--num_workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true', parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
@@ -185,13 +186,24 @@ def get_args_parser():
                         help='number of distributed processes')
     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

-    # sparsity parameters
-    parser.add_argument('--pruner', type=str, help='pruning criterion')
-    parser.add_argument('--sparsity', type=float, default=1.0, help = 'the sparisty level (ratio of unpruned weight)')
-    parser.add_argument('--custom-config', type=str, help='customized configuration of sparsity level for each linear layer')
+    # Sparsity Training Related Flag
+    parser.add_argument('--nas-config', type=str, help='configuration for supernet training')
+    parser.add_argument('--nas-mode', action='store_true')
     return parser


+def gen_random_config_fn(config):
+    if utils.get_rank() == 0:  # print whether to use non_uniform sampling at initialization on the main process
+        print(f"Set up the uniform sampling function")
+    def _fn_uni():
+        def weights(ratios):
+            return [1 for _ in ratios]
+        res = []
+        for ratios in config['sparsity']['choices']:
+            res.append(random.choices(ratios, weights(ratios))[0])
+        return res
+    return _fn_uni
+
+
 def main(args):
     utils.init_distributed_mode(args)
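
gen_random_config_fn() only reads nas_config['sparsity']['choices'], so the file passed via --nas-config is presumed to load into a structure like the one below. The concrete ratio values and the per-layer comments are illustrative, not taken from the repository; uniform_sample() is equivalent to the _fn_uni closure above when every weight is 1.

    import random

    # Assumed shape of the --nas-config YAML once loaded: one list of candidate
    # keep-ratios per SparseLinearSuper layer, in module order (values are made up).
    nas_config = {
        'sparsity': {
            'choices': [
                [0.25, 0.5, 1.0],   # e.g. qkv of block 0
                [0.25, 0.5, 1.0],   # e.g. proj of block 0
                [0.25, 0.5, 1.0],   # e.g. fc1 of block 0
                [0.25, 0.5, 1.0],   # e.g. fc2 of block 0
                # ... one entry per remaining sparse linear layer
            ],
        },
    }

    # Equivalent of the _fn_uni closure above (all sampling weights equal):
    def uniform_sample(config):
        return [random.choice(ratios) for ratios in config['sparsity']['choices']]

    print(uniform_sample(nas_config))   # e.g. [0.5, 1.0, 0.25, 0.25]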
@@ -263,6 +275,9 @@ def main(args):
             prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
             label_smoothing=args.smoothing, num_classes=args.nb_classes)

+    with open(args.nas_config) as f:
+        nas_config = yaml.load(f, Loader=SafeLoader)
+
     print(f"Creating model: {args.model}")
     model = create_model(
         args.model,
@@ -274,29 +289,6 @@ def main(args):
         img_size=args.input_size
     )

-    if args.pruner == 'custom':
-        if args.custom_config:
-            with open(args.custom_config) as f:
-                config = yaml.load(f, Loader=SafeLoader)
-        else:
-            raise ValueError("Please provide the configuration file when using the custom mode")
-        mode = config['sparsity']['mode']
-        sparsity_config = config['sparsity']['level']
-        pruner = weight_pruner_loader(args.pruner)
-        pruner(model, mode, sparsity_config)
-    elif check_valid_pruner(args.pruner):
-        pruner = weight_pruner_loader(args.pruner)
-        prune_weights_reparam(model)
-        pruner(model, args.sparsity)
-    else:
-        raise ValueError(f"Pruner '{args.pruner}' is not supported")
-
     if args.finetune:
         if args.finetune.startswith('https'):
             checkpoint = torch.hub.load_state_dict_from_url(
@@ -370,6 +362,15 @@ def main(args):
     if args.distributed:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
         model_without_ddp = model.module
+
+    if args.nas_mode:
+        smallest_config = []
+        for ratios in nas_config['sparsity']['choices']:
+            smallest_config.append(ratios[0])
+        model_without_ddp.set_random_config_fn(gen_random_config_fn(nas_config))
+        model_without_ddp.set_sample_config(smallest_config)
+
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
     print('number of params:', n_parameters)

     if not args.unscale_lr:
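
The nas_mode branch above initializes the supernet at its smallest sub-network by taking ratios[0] from every choice list, which assumes each list is ordered with the lowest keep-ratio first. A tiny illustration of that assumption and an order-independent alternative (not from the commit):

    # Toy choice lists, the second deliberately unsorted.
    choices = [[0.25, 0.5, 1.0], [0.5, 0.25, 1.0]]
    by_index = [ratios[0] for ratios in choices]    # what the commit does -> [0.25, 0.5]
    by_value = [min(ratios) for ratios in choices]  # order-independent    -> [0.25, 0.25]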
@ -493,9 +494,6 @@ def main(args):
'epoch': epoch, 'epoch': epoch,
'n_parameters': n_parameters} 'n_parameters': n_parameters}
if args.output_dir and utils.is_main_process(): if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f: with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n") f.write(json.dumps(log_stats) + "\n")

View File

@@ -313,7 +313,6 @@ class SparseVisionTransformer(nn.Module):
     def set_sample_config(self, sparse_configs):
         for ratio, layer in zip(sparse_configs, filter(lambda x: isinstance(x, SparseLinearSuper), self.modules())):
-            #print(ratio, layer)
             layer.set_sample_config(ratio)

     def set_random_config_fn(self, fn):
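
SparseLinearSuper comes from sparse_linear, which this commit does not touch, so its behavior is not visible here. A rough, hypothetical sketch of the kind of layer that the set_sample_config(ratio) call implies, assuming magnitude-based masking with a configurable keep-ratio; the real implementation may differ.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SparseLinearSuperSketch(nn.Linear):
        """Hypothetical stand-in for sparse_linear.SparseLinearSuper: a linear
        layer whose active sparsity level can be switched at run time."""

        def __init__(self, in_features, out_features, bias=True):
            super().__init__(in_features, out_features, bias=bias)
            self.keep_ratio = 1.0                  # 1.0 = dense

        def set_sample_config(self, ratio):
            # ratio is read as the fraction of weights kept, matching the wording
            # of the removed --sparsity flag ("ratio of unpruned weight").
            self.keep_ratio = float(ratio)

        def forward(self, x):
            if self.keep_ratio >= 1.0:
                return F.linear(x, self.weight, self.bias)
            k = max(1, int(self.weight.numel() * self.keep_ratio))
            # Keep the k largest-magnitude weights, zero out the rest.
            threshold = self.weight.abs().flatten().kthvalue(self.weight.numel() - k + 1).values
            mask = (self.weight.abs() >= threshold).to(self.weight.dtype)
            return F.linear(x, self.weight * mask, self.bias)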

models.py (429 changes)
View File

@@ -1,82 +1,348 @@
-# Copyright (c) 2015-present, Facebook, Inc.
-# All rights reserved.
+""" Vision Transformer (ViT) in PyTorch
+
+A PyTorch implement of Vision Transformers as described in:
+
+'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
+    - https://arxiv.org/abs/2010.11929
+
+`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
+    - https://arxiv.org/abs/2106.10270
+
+The official jax code is released and available at https://github.com/google-research/vision_transformer
+
+DeiT model defs and weights from https://github.com/facebookresearch/deit,
+paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+Acknowledgments:
+* The paper authors for releasing code and weights, thanks!
+* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
+for some einops/einsum fun
+* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
+* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
+
+Hacked together by / Copyright 2020, Ross Wightman
+"""
+import math
+import logging
+from functools import partial
+from collections import OrderedDict
+from copy import deepcopy
+from statistics import mode
+
 import torch
 import torch.nn as nn
-from functools import partial
-from timm.models.vision_transformer import VisionTransformer, _cfg
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.models.helpers import load_pretrained
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 from timm.models.registry import register_model
-from timm.models.layers import trunc_normal_
+
+from sparse_linear import SparseLinearSuper
+
+_logger = logging.getLogger(__name__)
-__all__ = [
-    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
-    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
-    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
-    'deit_base_distilled_patch16_384',
-]
-
-
-class DistilledVisionTransformer(VisionTransformer):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    # patch models
+    'vit_small_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+    ),
+    'vit_base_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+    ),
+    'vit_base_patch16_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+    'vit_base_patch32_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+    'vit_large_patch16_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'vit_large_patch16_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+    'vit_large_patch32_384': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+    'vit_huge_patch16_224': _cfg(),
+    'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
+    # hybrid models
+    'vit_small_resnet26d_224': _cfg(),
+    'vit_small_resnet50d_s3_224': _cfg(),
+    'vit_base_resnet26d_224': _cfg(),
+    'vit_base_resnet50d_224': _cfg(),
+}
+
+
+class LRMlpSuper(nn.Module):
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+
+        self.act = act_layer()
+        self.drop = nn.Dropout(drop)
+        self.fc1 = SparseLinearSuper(in_features, hidden_features)
+        self.fc2 = SparseLinearSuper(hidden_features, out_features)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class LRAttentionSuper(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.proj = SparseLinearSuper(dim, dim)
+        self.qkv = SparseLinearSuper(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x):
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = LRAttentionSuper(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = LRMlpSuper(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+    def forward(self, x):
+        x = x + self.drop_path(self.attn(self.norm1(x)))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+    """
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+    def forward(self, x):
+        B, C, H, W = x.shape
+        # FIXME look at relaxing size constraints
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+        x = self.proj(x).flatten(2).transpose(1, 2)
+        return x
+
+    def num_params(self):
+        return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+
+class SparseVisionTransformer(nn.Module):
+    """ Vision Transformer
+
+    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+        - https://arxiv.org/abs/2010.11929
+
+    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+        - https://arxiv.org/abs/2012.12877
+    """
+
+    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+                 num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+                 act_layer=None, weight_init=''):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            num_classes (int): number of classes for classification head
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+            distilled (bool): model includes a distillation token and head as in DeiT models
+            drop_rate (float): dropout rate
+            attn_drop_rate (float): attention dropout rate
+            drop_path_rate (float): stochastic depth rate
+            embed_layer (nn.Module): patch embedding layer
+            norm_layer: (nn.Module): normalization layer
+            weight_init: (str): weight init scheme
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 2 if distilled else 1
+        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
+
+        self.patch_embed = embed_layer(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         num_patches = self.patch_embed.num_patches
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
-        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
-
-        trunc_normal_(self.dist_token, std=.02)
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        self.blocks = nn.Sequential(*[
+            Block(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
+                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+            for i in range(depth)])
+        self.norm = norm_layer(embed_dim)
+
+        # Representation layer
+        if representation_size and not distilled:
+            self.num_features = representation_size
+            self.pre_logits = nn.Sequential(OrderedDict([
+                ('fc', nn.Linear(embed_dim, representation_size)),
+                ('act', nn.Tanh())
+            ]))
+        else:
+            self.pre_logits = nn.Identity()
+
+        # Classifier head(s)
+        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head_dist = None
+        if distilled:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
         trunc_normal_(self.pos_embed, std=.02)
-        self.head_dist.apply(self._init_weights)
+        trunc_normal_(self.cls_token, std=.02)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {'pos_embed', 'cls_token', 'dist_token'}
+
+    def get_classifier(self):
+        if self.dist_token is None:
+            return self.head
+        else:
+            return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=''):
+        self.num_classes = num_classes
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        if self.num_tokens == 2:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
     def forward_features(self, x):
-        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-        # with slight modifications to add the dist_token
-        B = x.shape[0]
         x = self.patch_embed(x)
-
-        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
-        dist_token = self.dist_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, dist_token, x), dim=1)
-
-        x = x + self.pos_embed
-        x = self.pos_drop(x)
-
-        for blk in self.blocks:
-            x = blk(x)
-
+        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        if self.dist_token is None:
+            x = torch.cat((cls_token, x), dim=1)
+        else:
+            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = self.pos_drop(x + self.pos_embed)
+        x = self.blocks(x)
         x = self.norm(x)
-        return x[:, 0], x[:, 1]
+        if self.dist_token is None:
+            return self.pre_logits(x[:, 0])
+        else:
+            return x[:, 0], x[:, 1]

     def forward(self, x):
-        x, x_dist = self.forward_features(x)
-        x = self.head(x)
-        x_dist = self.head_dist(x_dist)
-        if self.training:
-            return x, x_dist
-        else:
-            # during inference, return the average of both classifier predictions
-            return (x + x_dist) / 2
+        x = self.forward_features(x)
+        if self.head_dist is not None:
+            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
+            if self.training and not torch.jit.is_scripting():
+                # during inference, return the average of both classifier predictions
+                return x, x_dist
+            else:
+                return (x + x_dist) / 2
+        else:
+            x = self.head(x)
+        return x
 @register_model
-def deit_tiny_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(
-        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+def deit_base_patch16_224(pretrained=False, **kwargs):
+    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model = SparseVisionTransformer(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
+            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
             map_location="cpu", check_hash=True
         )
         model.load_state_dict(checkpoint["model"])
     return model


 @register_model
 def deit_small_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(
+    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model = SparseVisionTransformer(
         patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
@@ -90,90 +356,25 @@ def deit_small_patch16_224(pretrained=False, **kwargs):

 @register_model
-def deit_base_patch16_224(pretrained=False, **kwargs):
-    model = VisionTransformer(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = _cfg()
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-
-@register_model
-def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
-    model = DistilledVisionTransformer(
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
+    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model = SparseVisionTransformer(
         patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
     model.default_cfg = _cfg()
     if pretrained:
         checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
+            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
             map_location="cpu", check_hash=True
         )
         model.load_state_dict(checkpoint["model"])
     return model
-
-
-@register_model
-def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
-    model = DistilledVisionTransformer(
-        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = _cfg()
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-
-@register_model
-def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
-    model = DistilledVisionTransformer(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = _cfg()
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-
-@register_model
-def deit_base_patch16_384(pretrained=False, **kwargs):
-    model = VisionTransformer(
-        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = _cfg()
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-
-@register_model
-def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
-    model = DistilledVisionTransformer(
-        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model.default_cfg = _cfg()
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
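
For reference, a short usage sketch of the pieces above outside main.py: build one of the registered models, apply a uniform keep-ratio to every sparse layer, and run a forward pass. The 0.5 ratio is illustrative, and the sketch assumes models.py and sparse_linear.py from this repository are importable.

    import torch
    from timm.models import create_model

    import models                                   # registers the deit_* entrypoints defined above
    from models import SparseVisionTransformer
    from sparse_linear import SparseLinearSuper     # same import used by models.py

    model = create_model('deit_tiny_patch16_224', pretrained=False, num_classes=1000)
    assert isinstance(model, SparseVisionTransformer)

    # One keep-ratio per SparseLinearSuper module; 0.5 everywhere is illustrative.
    num_sparse = sum(isinstance(m, SparseLinearSuper) for m in model.modules())
    model.set_sample_config([0.5] * num_sparse)

    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)                             # torch.Size([1, 1000])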