mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'rwightman:main' into main
This commit is contained in:
commit
84178fca60
22
README.md
22
README.md
@ -21,7 +21,25 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
|
||||
## What's New
|
||||
|
||||
# Dec 6, 2022
|
||||
### 🤗 Survey: Feedback Appreciated 🤗
|
||||
|
||||
For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing.
|
||||
|
||||
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
|
||||
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏
|
||||
|
||||
### Dec 8, 2022
|
||||
* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some)
|
||||
* original source: https://github.com/baaivision/EVA
|
||||
|
||||
| model | top1 | param_count | gmac | macts | hub |
|
||||
|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------|
|
||||
| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
|
||||
### Dec 6, 2022
|
||||
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
|
||||
* original source: https://github.com/baaivision/EVA
|
||||
* paper: https://arxiv.org/abs/2211.07636
|
||||
@ -33,7 +51,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
| eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
| eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
|
||||
# Dec 5, 2022
|
||||
### Dec 5, 2022
|
||||
|
||||
* Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm`
|
||||
* vision_transformer, maxvit, convnext are the first three model impl w/ support
|
||||
|
@ -16,7 +16,7 @@ import argparse
|
||||
import os
|
||||
import glob
|
||||
import hashlib
|
||||
from timm.models.helpers import load_state_dict
|
||||
from timm.models import load_state_dict
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
|
||||
parser.add_argument('--input', default='', type=str, metavar='PATH',
|
||||
|
@ -19,7 +19,8 @@ import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
|
||||
from timm.data import resolve_data_config
|
||||
from timm.models import create_model, is_model, list_models, set_fast_norm
|
||||
from timm.layers import set_fast_norm
|
||||
from timm.models import create_model, is_model, list_models
|
||||
from timm.optim import create_optimizer_v2
|
||||
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
|
||||
|
||||
|
@ -13,7 +13,7 @@ import os
|
||||
import hashlib
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from timm.models.helpers import load_state_dict
|
||||
from timm.models import load_state_dict
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
|
@ -1,4 +1,3 @@
|
||||
dependencies = ['torch']
|
||||
from timm.models import registry
|
||||
|
||||
globals().update(registry._model_entrypoints)
|
||||
import timm
|
||||
globals().update(timm.models._registry._model_entrypoints)
|
||||
|
@ -5,11 +5,11 @@ An example inference script that outputs top-k class ids for images in a folder
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from functools import partial
|
||||
|
||||
@ -17,12 +17,11 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint
|
||||
from timm.data import create_dataset, create_loader, resolve_data_config
|
||||
from timm.layers import apply_test_time_pool
|
||||
from timm.models import create_model
|
||||
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
|
||||
|
||||
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
has_apex = True
|
||||
|
@ -1,10 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import platform
|
||||
import os
|
||||
|
||||
from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
|
||||
from timm.layers import create_act_layer, set_layer_config
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
@ -14,7 +14,7 @@ except ImportError:
|
||||
|
||||
import timm
|
||||
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
|
||||
from timm.models.fx_features import _leaf_modules, _autowrap_functions
|
||||
from timm.models._features_fx import _leaf_modules, _autowrap_functions
|
||||
|
||||
if hasattr(torch._C, '_jit_set_profiling_executor'):
|
||||
# legacy executor is too slow to compile large models for unit tests
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .version import __version__
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
|
||||
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
|
||||
is_scriptable, is_exportable, set_scriptable, set_exportable, \
|
||||
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
||||
|
@ -1,4 +1,4 @@
|
||||
""" AutoAugment, RandAugment, and AugMix for PyTorch
|
||||
""" AutoAugment, RandAugment, AugMix, and 3-Augment for PyTorch
|
||||
|
||||
This code implements the searched ImageNet policies with various tweaks and improvements and
|
||||
does not include any of the search code.
|
||||
@ -9,18 +9,24 @@ AA and RA Implementation adapted from:
|
||||
AugMix adapted from:
|
||||
https://github.com/google-research/augmix
|
||||
|
||||
3-Augment based on: https://github.com/facebookresearch/deit/blob/main/README_revenge.md
|
||||
|
||||
Papers:
|
||||
AutoAugment: Learning Augmentation Policies from Data - https://arxiv.org/abs/1805.09501
|
||||
Learning Data Augmentation Strategies for Object Detection - https://arxiv.org/abs/1906.11172
|
||||
RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719
|
||||
AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty - https://arxiv.org/abs/1912.02781
|
||||
3-Augment: DeiT III: Revenge of the ViT - https://arxiv.org/abs/2204.07118
|
||||
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
import random
|
||||
import math
|
||||
import re
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
|
||||
import PIL
|
||||
import numpy as np
|
||||
|
||||
@ -175,6 +181,24 @@ def sharpness(img, factor, **__):
|
||||
return ImageEnhance.Sharpness(img).enhance(factor)
|
||||
|
||||
|
||||
def gaussian_blur(img, factor, **__):
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
|
||||
return img
|
||||
|
||||
|
||||
def gaussian_blur_rand(img, factor, **__):
|
||||
radius_min = 0.1
|
||||
radius_max = 2.0
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
|
||||
return img
|
||||
|
||||
|
||||
def desaturate(img, factor, **_):
|
||||
factor = min(1., max(0., 1. - factor))
|
||||
# enhance factor 0 = grayscale, 1.0 = no-change
|
||||
return ImageEnhance.Color(img).enhance(factor)
|
||||
|
||||
|
||||
def _randomly_negate(v):
|
||||
"""With 50% prob, negate the value"""
|
||||
return -v if random.random() > 0.5 else v
|
||||
@ -200,6 +224,14 @@ def _enhance_increasing_level_to_arg(level, _hparams):
|
||||
return level,
|
||||
|
||||
|
||||
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
|
||||
level = (level / _LEVEL_DENOM)
|
||||
min_val + (max_val - min_val) * level
|
||||
if clamp:
|
||||
level = max(min_val, min(max_val, level))
|
||||
return level,
|
||||
|
||||
|
||||
def _shear_level_to_arg(level, _hparams):
|
||||
# range [-0.3, 0.3]
|
||||
level = (level / _LEVEL_DENOM) * 0.3
|
||||
@ -246,7 +278,7 @@ def _posterize_original_level_to_arg(level, _hparams):
|
||||
def _solarize_level_to_arg(level, _hparams):
|
||||
# range [0, 256]
|
||||
# intensity/severity of augmentation decreases with level
|
||||
return int((level / _LEVEL_DENOM) * 256),
|
||||
return min(256, int((level / _LEVEL_DENOM) * 256)),
|
||||
|
||||
|
||||
def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
@ -257,7 +289,7 @@ def _solarize_increasing_level_to_arg(level, _hparams):
|
||||
|
||||
def _solarize_add_level_to_arg(level, _hparams):
|
||||
# range [0, 110]
|
||||
return int((level / _LEVEL_DENOM) * 110),
|
||||
return min(128, int((level / _LEVEL_DENOM) * 110)),
|
||||
|
||||
|
||||
LEVEL_TO_ARG = {
|
||||
@ -286,6 +318,9 @@ LEVEL_TO_ARG = {
|
||||
'TranslateY': _translate_abs_level_to_arg,
|
||||
'TranslateXRel': _translate_rel_level_to_arg,
|
||||
'TranslateYRel': _translate_rel_level_to_arg,
|
||||
'Desaturate': partial(_minmax_level_to_arg, min_val=0.5, max_val=1.0),
|
||||
'GaussianBlur': partial(_minmax_level_to_arg, min_val=0.1, max_val=2.0),
|
||||
'GaussianBlurRand': _minmax_level_to_arg,
|
||||
}
|
||||
|
||||
|
||||
@ -314,6 +349,9 @@ NAME_TO_OP = {
|
||||
'TranslateY': translate_y_abs,
|
||||
'TranslateXRel': translate_x_rel,
|
||||
'TranslateYRel': translate_y_rel,
|
||||
'Desaturate': desaturate,
|
||||
'GaussianBlur': gaussian_blur,
|
||||
'GaussianBlurRand': gaussian_blur_rand,
|
||||
}
|
||||
|
||||
|
||||
@ -347,6 +385,7 @@ class AugmentOp:
|
||||
if self.magnitude_std > 0:
|
||||
# magnitude randomization enabled
|
||||
if self.magnitude_std == float('inf'):
|
||||
# inf == uniform sampling
|
||||
magnitude = random.uniform(0, magnitude)
|
||||
elif self.magnitude_std > 0:
|
||||
magnitude = random.gauss(magnitude, self.magnitude_std)
|
||||
@ -499,6 +538,16 @@ def auto_augment_policy_originalr(hparams):
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy_3a(hparams):
|
||||
policy = [
|
||||
[('Solarize', 1.0, 5)], # 128 solarize threshold @ 5 magnitude
|
||||
[('Desaturate', 1.0, 10)], # grayscale at 10 magnitude
|
||||
[('GaussianBlurRand', 1.0, 10)],
|
||||
]
|
||||
pc = [[AugmentOp(*a, hparams=hparams) for a in sp] for sp in policy]
|
||||
return pc
|
||||
|
||||
|
||||
def auto_augment_policy(name='v0', hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
if name == 'original':
|
||||
@ -509,6 +558,8 @@ def auto_augment_policy(name='v0', hparams=None):
|
||||
return auto_augment_policy_v0(hparams)
|
||||
elif name == 'v0r':
|
||||
return auto_augment_policy_v0r(hparams)
|
||||
elif name == '3a':
|
||||
return auto_augment_policy_3a(hparams)
|
||||
else:
|
||||
assert False, 'Unknown AA policy (%s)' % name
|
||||
|
||||
@ -534,19 +585,23 @@ class AutoAugment:
|
||||
return fs
|
||||
|
||||
|
||||
def auto_augment_transform(config_str, hparams):
|
||||
def auto_augment_transform(config_str: str, hparams: Optional[Dict] = None):
|
||||
"""
|
||||
Create a AutoAugment transform
|
||||
|
||||
:param config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
The remaining sections, not order sepecific determine
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
Args:
|
||||
config_str: String defining configuration of auto augmentation. Consists of multiple sections separated by
|
||||
dashes ('-').
|
||||
The first section defines the AutoAugment policy (one of 'v0', 'v0r', 'original', 'originalr').
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
The remaining sections:
|
||||
'mstd' - float std deviation of magnitude noise applied
|
||||
Ex 'original-mstd0.5' results in AutoAugment with original policy, magnitude_std 0.5
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
hparams: Other hparams (kwargs) for the AutoAugmentation scheme
|
||||
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
"""
|
||||
config = config_str.split('-')
|
||||
policy_name = config[0]
|
||||
@ -605,42 +660,80 @@ _RAND_INCREASING_TRANSFORMS = [
|
||||
]
|
||||
|
||||
|
||||
_RAND_3A = [
|
||||
'SolarizeIncreasing',
|
||||
'Desaturate',
|
||||
'GaussianBlur',
|
||||
]
|
||||
|
||||
|
||||
_RAND_CHOICE_3A = {
|
||||
'SolarizeIncreasing': 6,
|
||||
'Desaturate': 6,
|
||||
'GaussianBlur': 6,
|
||||
'Rotate': 3,
|
||||
'ShearX': 2,
|
||||
'ShearY': 2,
|
||||
'PosterizeIncreasing': 1,
|
||||
'AutoContrast': 1,
|
||||
'ColorIncreasing': 1,
|
||||
'SharpnessIncreasing': 1,
|
||||
'ContrastIncreasing': 1,
|
||||
'BrightnessIncreasing': 1,
|
||||
'Equalize': 1,
|
||||
'Invert': 1,
|
||||
}
|
||||
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
||||
# They may not result in increased performance, but could likely be tuned to so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
'Rotate': 0.3,
|
||||
'ShearX': 0.2,
|
||||
'ShearY': 0.2,
|
||||
'TranslateXRel': 0.1,
|
||||
'TranslateYRel': 0.1,
|
||||
'Color': .025,
|
||||
'Sharpness': 0.025,
|
||||
'AutoContrast': 0.025,
|
||||
'Solarize': .005,
|
||||
'SolarizeAdd': .005,
|
||||
'Contrast': .005,
|
||||
'Brightness': .005,
|
||||
'Equalize': .005,
|
||||
'Posterize': 0,
|
||||
'Invert': 0,
|
||||
'Rotate': 3,
|
||||
'ShearX': 2,
|
||||
'ShearY': 2,
|
||||
'TranslateXRel': 1,
|
||||
'TranslateYRel': 1,
|
||||
'ColorIncreasing': .25,
|
||||
'SharpnessIncreasing': 0.25,
|
||||
'AutoContrast': 0.25,
|
||||
'SolarizeIncreasing': .05,
|
||||
'SolarizeAdd': .05,
|
||||
'ContrastIncreasing': .05,
|
||||
'BrightnessIncreasing': .05,
|
||||
'Equalize': .05,
|
||||
'PosterizeIncreasing': 0.05,
|
||||
'Invert': 0.05,
|
||||
}
|
||||
|
||||
|
||||
def _select_rand_weights(weight_idx=0, transforms=None):
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
assert weight_idx == 0 # only one set of weights currently
|
||||
rand_weights = _RAND_CHOICE_WEIGHTS_0
|
||||
probs = [rand_weights[k] for k in transforms]
|
||||
probs /= np.sum(probs)
|
||||
return probs
|
||||
def _get_weighted_transforms(transforms: Dict):
|
||||
transforms, probs = list(zip(*transforms.items()))
|
||||
probs = np.array(probs)
|
||||
probs = probs / np.sum(probs)
|
||||
return transforms, probs
|
||||
|
||||
|
||||
def rand_augment_ops(magnitude=10, hparams=None, transforms=None):
|
||||
def rand_augment_choices(name: str, increasing=True):
|
||||
if name == 'weights':
|
||||
return _RAND_CHOICE_WEIGHTS_0
|
||||
elif name == '3aw':
|
||||
return _RAND_CHOICE_3A
|
||||
elif name == '3a':
|
||||
return _RAND_3A
|
||||
else:
|
||||
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
|
||||
|
||||
|
||||
def rand_augment_ops(
|
||||
magnitude: Union[int, float] = 10,
|
||||
prob: float = 0.5,
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[Dict, List]] = None,
|
||||
):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _RAND_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
|
||||
|
||||
class RandAugment:
|
||||
@ -648,11 +741,16 @@ class RandAugment:
|
||||
self.ops = ops
|
||||
self.num_layers = num_layers
|
||||
self.choice_weights = choice_weights
|
||||
print(self.ops, self.choice_weights)
|
||||
|
||||
def __call__(self, img):
|
||||
# no replacement when using weighted choice
|
||||
ops = np.random.choice(
|
||||
self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights)
|
||||
self.ops,
|
||||
self.num_layers,
|
||||
replace=self.choice_weights is None,
|
||||
p=self.choice_weights,
|
||||
)
|
||||
for op in ops:
|
||||
img = op(img)
|
||||
return img
|
||||
@ -665,61 +763,84 @@ class RandAugment:
|
||||
return fs
|
||||
|
||||
|
||||
def rand_augment_transform(config_str, hparams):
|
||||
def rand_augment_transform(
|
||||
config_str: str,
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[str, Dict, List]] = None,
|
||||
):
|
||||
"""
|
||||
Create a RandAugment transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'w' - integer probabiliy weight index (index of a set of weights to influence choice of op)
|
||||
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
|
||||
Args:
|
||||
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
|
||||
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
|
||||
The remaining sections, not order sepecific determine
|
||||
'm' - integer magnitude of rand augment
|
||||
'n' - integer num layers (number of transform ops selected per image)
|
||||
'p' - float probability of applying each layer (default 0.5)
|
||||
'mstd' - float std deviation of magnitude noise applied, or uniform sampling if infinity (or > 100)
|
||||
'mmax' - set upper bound for magnitude to something other than default of _LEVEL_DENOM (10)
|
||||
'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0)
|
||||
't' - str name of transform set to use
|
||||
Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
|
||||
'rand-mstd1-tweights' results in mag std 1.0, weighted transforms, default mag of 10 and num_layers 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the RandAugmentation scheme
|
||||
hparams (dict): Other hparams (kwargs) for the RandAugmentation scheme
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = _LEVEL_DENOM # default to _LEVEL_DENOM for magnitude (currently 10)
|
||||
num_layers = 2 # default to 2 ops per image
|
||||
weight_idx = None # default to no probability weights for op choice
|
||||
transforms = _RAND_TRANSFORMS
|
||||
increasing = False
|
||||
prob = 0.5
|
||||
config = config_str.split('-')
|
||||
assert config[0] == 'rand'
|
||||
config = config[1:]
|
||||
for c in config:
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param / randomization of magnitude values
|
||||
mstd = float(val)
|
||||
if mstd > 100:
|
||||
# use uniform sampling in 0 to magnitude if mstd is > 100
|
||||
mstd = float('inf')
|
||||
hparams.setdefault('magnitude_std', mstd)
|
||||
elif key == 'mmax':
|
||||
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
|
||||
hparams.setdefault('magnitude_max', int(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
transforms = _RAND_INCREASING_TRANSFORMS
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'w':
|
||||
weight_idx = int(val)
|
||||
if c.startswith('t'):
|
||||
# NOTE old 'w' key was removed, 'w0' is not equivalent to 'tweights'
|
||||
val = str(c[1:])
|
||||
if transforms is None:
|
||||
transforms = val
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams, transforms=transforms)
|
||||
choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
|
||||
# numeric options
|
||||
cs = re.split(r'(\d.*)', c)
|
||||
if len(cs) < 2:
|
||||
continue
|
||||
key, val = cs[:2]
|
||||
if key == 'mstd':
|
||||
# noise param / randomization of magnitude values
|
||||
mstd = float(val)
|
||||
if mstd > 100:
|
||||
# use uniform sampling in 0 to magnitude if mstd is > 100
|
||||
mstd = float('inf')
|
||||
hparams.setdefault('magnitude_std', mstd)
|
||||
elif key == 'mmax':
|
||||
# clip magnitude between [0, mmax] instead of default [0, _LEVEL_DENOM]
|
||||
hparams.setdefault('magnitude_max', int(val))
|
||||
elif key == 'inc':
|
||||
if bool(val):
|
||||
increasing = True
|
||||
elif key == 'm':
|
||||
magnitude = int(val)
|
||||
elif key == 'n':
|
||||
num_layers = int(val)
|
||||
elif key == 'p':
|
||||
prob = float(val)
|
||||
else:
|
||||
assert False, 'Unknown RandAugment config section'
|
||||
|
||||
if isinstance(transforms, str):
|
||||
transforms = rand_augment_choices(transforms, increasing=increasing)
|
||||
elif transforms is None:
|
||||
transforms = _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
|
||||
|
||||
choice_weights = None
|
||||
if isinstance(transforms, Dict):
|
||||
transforms, choice_weights = _get_weighted_transforms(transforms)
|
||||
|
||||
ra_ops = rand_augment_ops(magnitude=magnitude, prob=prob, hparams=hparams, transforms=transforms)
|
||||
return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
|
||||
|
||||
|
||||
@ -740,11 +861,19 @@ _AUGMIX_TRANSFORMS = [
|
||||
]
|
||||
|
||||
|
||||
def augmix_ops(magnitude=10, hparams=None, transforms=None):
|
||||
def augmix_ops(
|
||||
magnitude: Union[int, float] = 10,
|
||||
hparams: Optional[Dict] = None,
|
||||
transforms: Optional[Union[str, Dict, List]] = None,
|
||||
):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
transforms = transforms or _AUGMIX_TRANSFORMS
|
||||
return [AugmentOp(
|
||||
name, prob=1.0, magnitude=magnitude, hparams=hparams) for name in transforms]
|
||||
name,
|
||||
prob=1.0,
|
||||
magnitude=magnitude,
|
||||
hparams=hparams
|
||||
) for name in transforms]
|
||||
|
||||
|
||||
class AugMixAugment:
|
||||
@ -820,22 +949,24 @@ class AugMixAugment:
|
||||
return fs
|
||||
|
||||
|
||||
def augment_and_mix_transform(config_str, hparams):
|
||||
def augment_and_mix_transform(config_str: str, hparams: Optional[Dict] = None):
|
||||
""" Create AugMix PyTorch transform
|
||||
|
||||
:param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by
|
||||
dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining
|
||||
sections, not order sepecific determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
Args:
|
||||
config_str (str): String defining configuration of random augmentation. Consists of multiple sections separated
|
||||
by dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand').
|
||||
The remaining sections, not order sepecific determine
|
||||
'm' - integer magnitude (severity) of augmentation mix (default: 3)
|
||||
'w' - integer width of augmentation chain (default: 3)
|
||||
'd' - integer depth of augmentation chain (-1 is random [1, 3], default: -1)
|
||||
'b' - integer (bool), blend each branch of chain into end result without a final blend, less CPU (default: 0)
|
||||
'mstd' - float std deviation of magnitude noise applied (default: 0)
|
||||
Ex 'augmix-m5-w4-d2' results in AugMix with severity 5, chain width 4, chain depth 2
|
||||
|
||||
:param hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
hparams: Other hparams (kwargs) for the Augmentation transforms
|
||||
|
||||
:return: A PyTorch compatible Transform
|
||||
Returns:
|
||||
A PyTorch compatible Transform
|
||||
"""
|
||||
magnitude = 3
|
||||
width = 3
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
def load_class_map(map_or_filename, root=''):
|
||||
if isinstance(map_or_filename, dict):
|
||||
assert dict, 'class_map dict must be non-empty'
|
||||
@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''):
|
||||
with open(class_map_path) as f:
|
||||
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
|
||||
elif class_map_ext == '.pkl':
|
||||
with open(class_map_path,'rb') as f:
|
||||
with open(class_map_path, 'rb') as f:
|
||||
class_to_idx = pickle.load(f)
|
||||
else:
|
||||
assert False, f'Unsupported class map file extension ({class_map_ext}).'
|
||||
|
@ -59,6 +59,7 @@ def transforms_imagenet_train(
|
||||
re_count=1,
|
||||
re_num_splits=0,
|
||||
separate=False,
|
||||
force_color_jitter=False,
|
||||
):
|
||||
"""
|
||||
If separate==True, the transforms are returned as a tuple of 3 separate transforms
|
||||
@ -77,8 +78,12 @@ def transforms_imagenet_train(
|
||||
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
|
||||
|
||||
secondary_tfl = []
|
||||
disable_color_jitter = False
|
||||
if auto_augment:
|
||||
assert isinstance(auto_augment, str)
|
||||
# color jitter is typically disabled if AA/RA on,
|
||||
# this allows override without breaking old hparm cfgs
|
||||
disable_color_jitter = not (force_color_jitter or '3a' in auto_augment)
|
||||
if isinstance(img_size, (tuple, list)):
|
||||
img_size_min = min(img_size)
|
||||
else:
|
||||
@ -96,8 +101,9 @@ def transforms_imagenet_train(
|
||||
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
|
||||
else:
|
||||
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
|
||||
elif color_jitter is not None:
|
||||
# color jitter is enabled when not using AA
|
||||
|
||||
if color_jitter is not None and not disable_color_jitter:
|
||||
# color jitter is enabled when not using AA or when forced
|
||||
if isinstance(color_jitter, (list, tuple)):
|
||||
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
|
||||
# or 4 if also augmenting hue
|
||||
|
44
timm/layers/__init__.py
Normal file
44
timm/layers/__init__.py
Normal file
@ -0,0 +1,44 @@
|
||||
from .activations import *
|
||||
from .adaptive_avgmax_pool import \
|
||||
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
||||
from .blur_pool import BlurPool2d
|
||||
from .classifier import ClassifierHead, create_classifier
|
||||
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
|
||||
set_layer_config
|
||||
from .conv2d_same import Conv2dSame, conv2d_same
|
||||
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
|
||||
from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||
from .create_attn import get_attn, create_attn
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm import get_norm_layer, create_norm_layer
|
||||
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
|
||||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
|
||||
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
|
||||
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
|
||||
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
|
||||
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
|
||||
from .gather_excite import GatherExcite
|
||||
from .global_context import GlobalContext
|
||||
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
|
||||
from .inplace_abn import InplaceAbn
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
|
||||
from .padding import get_padding, get_same_padding, pad_same
|
||||
from .patch_embed import PatchEmbed
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
||||
from .selective_kernel import SelectiveKernel
|
||||
from .separable_conv import SeparableConv2d, SeparableConvNormAct
|
||||
from .space_to_depth import SpaceToDepthModule
|
||||
from .split_attn import SplitAttn
|
||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
from .trace_utils import _assert, _float_to_int
|
||||
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
@ -65,12 +65,18 @@ from .xception import *
|
||||
from .xception_aligned import *
|
||||
from .xcit import *
|
||||
|
||||
from .factory import create_model, parse_model_name, safe_model_name
|
||||
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
|
||||
from .layers import TestTimePoolHead, apply_test_time_pool
|
||||
from .layers import convert_splitbn_model, convert_sync_batchnorm
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
|
||||
from .layers import set_fast_norm
|
||||
from .pretrained import PretrainedCfg, filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
|
||||
from .registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules,\
|
||||
from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
|
||||
set_pretrained_download_progress, set_pretrained_check_hash
|
||||
from ._factory import create_model, parse_model_name, safe_model_name
|
||||
from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
|
||||
from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
|
||||
register_notrace_module, register_notrace_function
|
||||
from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_checkpoint, resume_checkpoint
|
||||
from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
|
||||
from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
|
||||
group_modules, group_parameters, checkpoint_seq, adapt_input_conv
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, \
|
||||
filter_pretrained_cfg, generate_default_cfgs, split_model_name_tag
|
||||
from ._prune import adapt_model_from_string
|
||||
from ._registry import register_model, model_entrypoint, list_models, list_pretrained, is_model, list_modules, \
|
||||
is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
||||
|
399
timm/models/_builder.py
Normal file
399
timm/models/_builder.py
Normal file
@ -0,0 +1,399 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Dict, Callable, Any, Tuple
|
||||
|
||||
from torch import nn as nn
|
||||
from torch.hub import load_state_dict_from_url
|
||||
|
||||
from timm.models._features import FeatureListNet, FeatureHookNet
|
||||
from timm.models._features_fx import FeatureGraphNet
|
||||
from timm.models._helpers import load_state_dict
|
||||
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
|
||||
from timm.models._manipulate import adapt_input_conv
|
||||
from timm.models._pretrained import PretrainedCfg
|
||||
from timm.models._prune import adapt_model_from_file
|
||||
from timm.models._registry import get_pretrained_cfg
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variables for rarely used pretrained checkpoint download progress and hash check.
|
||||
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
|
||||
_DOWNLOAD_PROGRESS = False
|
||||
_CHECK_HASH = False
|
||||
|
||||
|
||||
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
|
||||
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
|
||||
|
||||
|
||||
def _resolve_pretrained_source(pretrained_cfg):
|
||||
cfg_source = pretrained_cfg.get('source', '')
|
||||
pretrained_url = pretrained_cfg.get('url', None)
|
||||
pretrained_file = pretrained_cfg.get('file', None)
|
||||
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
|
||||
# resolve where to load pretrained weights from
|
||||
load_from = ''
|
||||
pretrained_loc = ''
|
||||
if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
|
||||
# hf-hub specified as source via model identifier
|
||||
load_from = 'hf-hub'
|
||||
assert hf_hub_id
|
||||
pretrained_loc = hf_hub_id
|
||||
else:
|
||||
# default source == timm or unspecified
|
||||
if pretrained_file:
|
||||
load_from = 'file'
|
||||
pretrained_loc = pretrained_file
|
||||
elif pretrained_url:
|
||||
load_from = 'url'
|
||||
pretrained_loc = pretrained_url
|
||||
elif hf_hub_id and has_hf_hub(necessary=True):
|
||||
# hf-hub available as alternate weight source in default_cfg
|
||||
load_from = 'hf-hub'
|
||||
pretrained_loc = hf_hub_id
|
||||
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
|
||||
# if a filename override is set, return tuple for location w/ (hub_id, filename)
|
||||
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
|
||||
return load_from, pretrained_loc
|
||||
|
||||
|
||||
def set_pretrained_download_progress(enable=True):
|
||||
""" Set download progress for pretrained weights on/off (globally). """
|
||||
global _DOWNLOAD_PROGRESS
|
||||
_DOWNLOAD_PROGRESS = enable
|
||||
|
||||
|
||||
def set_pretrained_check_hash(enable=True):
|
||||
""" Set hash checking for pretrained weights on/off (globally). """
|
||||
global _CHECK_HASH
|
||||
_CHECK_HASH = enable
|
||||
|
||||
|
||||
def load_custom_pretrained(
|
||||
model: nn.Module,
|
||||
pretrained_cfg: Optional[Dict] = None,
|
||||
load_fn: Optional[Callable] = None,
|
||||
):
|
||||
r"""Loads a custom (read non .pth) weight file
|
||||
|
||||
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
|
||||
a passed in custom load fun, or the `load_pretrained` model member fn.
|
||||
|
||||
If the object is already present in `model_dir`, it's deserialized and returned.
|
||||
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
|
||||
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
|
||||
|
||||
Args:
|
||||
model: The instantiated model to load weights into
|
||||
pretrained_cfg (dict): Default pretrained model cfg
|
||||
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
|
||||
'laod_pretrained' on the model will be called if it exists
|
||||
"""
|
||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||
if not pretrained_cfg:
|
||||
_logger.warning("Invalid pretrained config, cannot load weights.")
|
||||
return
|
||||
|
||||
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
||||
if not load_from:
|
||||
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
||||
return
|
||||
if load_from == 'hf-hub': # FIXME
|
||||
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
|
||||
elif load_from == 'url':
|
||||
pretrained_loc = download_cached_file(
|
||||
pretrained_loc,
|
||||
check_hash=_CHECK_HASH,
|
||||
progress=_DOWNLOAD_PROGRESS
|
||||
)
|
||||
|
||||
if load_fn is not None:
|
||||
load_fn(model, pretrained_loc)
|
||||
elif hasattr(model, 'load_pretrained'):
|
||||
model.load_pretrained(pretrained_loc)
|
||||
else:
|
||||
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
|
||||
|
||||
|
||||
def load_pretrained(
|
||||
model: nn.Module,
|
||||
pretrained_cfg: Optional[Dict] = None,
|
||||
num_classes: int = 1000,
|
||||
in_chans: int = 3,
|
||||
filter_fn: Optional[Callable] = None,
|
||||
strict: bool = True,
|
||||
):
|
||||
""" Load pretrained checkpoint
|
||||
|
||||
Args:
|
||||
model (nn.Module) : PyTorch model module
|
||||
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
|
||||
num_classes (int): num_classes for target model
|
||||
in_chans (int): in_chans for target model
|
||||
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
|
||||
strict (bool): strict load of checkpoint
|
||||
|
||||
"""
|
||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||
if not pretrained_cfg:
|
||||
_logger.warning("Invalid pretrained config, cannot load weights.")
|
||||
return
|
||||
|
||||
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
||||
if load_from == 'file':
|
||||
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
|
||||
state_dict = load_state_dict(pretrained_loc)
|
||||
elif load_from == 'url':
|
||||
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
|
||||
state_dict = load_state_dict_from_url(
|
||||
pretrained_loc,
|
||||
map_location='cpu',
|
||||
progress=_DOWNLOAD_PROGRESS,
|
||||
check_hash=_CHECK_HASH,
|
||||
)
|
||||
elif load_from == 'hf-hub':
|
||||
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
|
||||
if isinstance(pretrained_loc, (list, tuple)):
|
||||
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
||||
else:
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
||||
else:
|
||||
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
|
||||
return
|
||||
|
||||
if filter_fn is not None:
|
||||
# for backwards compat with filter fn that take one arg, try one first, the two
|
||||
try:
|
||||
state_dict = filter_fn(state_dict)
|
||||
except TypeError:
|
||||
state_dict = filter_fn(state_dict, model)
|
||||
|
||||
input_convs = pretrained_cfg.get('first_conv', None)
|
||||
if input_convs is not None and in_chans != 3:
|
||||
if isinstance(input_convs, str):
|
||||
input_convs = (input_convs,)
|
||||
for input_conv_name in input_convs:
|
||||
weight_name = input_conv_name + '.weight'
|
||||
try:
|
||||
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
|
||||
_logger.info(
|
||||
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
|
||||
except NotImplementedError as e:
|
||||
del state_dict[weight_name]
|
||||
strict = False
|
||||
_logger.warning(
|
||||
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
|
||||
|
||||
classifiers = pretrained_cfg.get('classifier', None)
|
||||
label_offset = pretrained_cfg.get('label_offset', 0)
|
||||
if classifiers is not None:
|
||||
if isinstance(classifiers, str):
|
||||
classifiers = (classifiers,)
|
||||
if num_classes != pretrained_cfg['num_classes']:
|
||||
for classifier_name in classifiers:
|
||||
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
||||
state_dict.pop(classifier_name + '.weight', None)
|
||||
state_dict.pop(classifier_name + '.bias', None)
|
||||
strict = False
|
||||
elif label_offset > 0:
|
||||
for classifier_name in classifiers:
|
||||
# special case for pretrained weights with an extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def pretrained_cfg_for_features(pretrained_cfg):
|
||||
pretrained_cfg = deepcopy(pretrained_cfg)
|
||||
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
||||
to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
|
||||
for tr in to_remove:
|
||||
pretrained_cfg.pop(tr, None)
|
||||
return pretrained_cfg
|
||||
|
||||
|
||||
def _filter_kwargs(kwargs, names):
|
||||
if not kwargs or not names:
|
||||
return
|
||||
for n in names:
|
||||
kwargs.pop(n, None)
|
||||
|
||||
|
||||
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
||||
""" Update the default_cfg and kwargs before passing to model
|
||||
|
||||
Args:
|
||||
pretrained_cfg: input pretrained cfg (updated in-place)
|
||||
kwargs: keyword args passed to model build fn (updated in-place)
|
||||
kwargs_filter: keyword arg keys that must be removed before model __init__
|
||||
"""
|
||||
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
||||
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
|
||||
if pretrained_cfg.get('fixed_input_size', False):
|
||||
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
|
||||
default_kwarg_names += ('img_size',)
|
||||
|
||||
for n in default_kwarg_names:
|
||||
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
|
||||
# pretrained_cfg has one input_size=(C, H ,W) entry
|
||||
if n == 'img_size':
|
||||
input_size = pretrained_cfg.get('input_size', None)
|
||||
if input_size is not None:
|
||||
assert len(input_size) == 3
|
||||
kwargs.setdefault(n, input_size[-2:])
|
||||
elif n == 'in_chans':
|
||||
input_size = pretrained_cfg.get('input_size', None)
|
||||
if input_size is not None:
|
||||
assert len(input_size) == 3
|
||||
kwargs.setdefault(n, input_size[0])
|
||||
else:
|
||||
default_val = pretrained_cfg.get(n, None)
|
||||
if default_val is not None:
|
||||
kwargs.setdefault(n, pretrained_cfg[n])
|
||||
|
||||
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
||||
_filter_kwargs(kwargs, names=kwargs_filter)
|
||||
|
||||
|
||||
def resolve_pretrained_cfg(
|
||||
variant: str,
|
||||
pretrained_cfg=None,
|
||||
pretrained_cfg_overlay=None,
|
||||
) -> PretrainedCfg:
|
||||
model_with_tag = variant
|
||||
pretrained_tag = None
|
||||
if pretrained_cfg:
|
||||
if isinstance(pretrained_cfg, dict):
|
||||
# pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
|
||||
pretrained_cfg = PretrainedCfg(**pretrained_cfg)
|
||||
elif isinstance(pretrained_cfg, str):
|
||||
pretrained_tag = pretrained_cfg
|
||||
pretrained_cfg = None
|
||||
|
||||
# fallback to looking up pretrained cfg in model registry by variant identifier
|
||||
if not pretrained_cfg:
|
||||
if pretrained_tag:
|
||||
model_with_tag = '.'.join([variant, pretrained_tag])
|
||||
pretrained_cfg = get_pretrained_cfg(model_with_tag)
|
||||
|
||||
if not pretrained_cfg:
|
||||
_logger.warning(
|
||||
f"No pretrained configuration specified for {model_with_tag} model. Using a default."
|
||||
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
|
||||
pretrained_cfg = PretrainedCfg() # instance with defaults
|
||||
|
||||
pretrained_cfg_overlay = pretrained_cfg_overlay or {}
|
||||
if not pretrained_cfg.architecture:
|
||||
pretrained_cfg_overlay.setdefault('architecture', variant)
|
||||
pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
|
||||
|
||||
return pretrained_cfg
|
||||
|
||||
|
||||
def build_model_with_cfg(
|
||||
model_cls: Callable,
|
||||
variant: str,
|
||||
pretrained: bool,
|
||||
pretrained_cfg: Optional[Dict] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict] = None,
|
||||
model_cfg: Optional[Any] = None,
|
||||
feature_cfg: Optional[Dict] = None,
|
||||
pretrained_strict: bool = True,
|
||||
pretrained_filter_fn: Optional[Callable] = None,
|
||||
kwargs_filter: Optional[Tuple[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
""" Build model with specified default_cfg and optional model_cfg
|
||||
|
||||
This helper fn aids in the construction of a model including:
|
||||
* handling default_cfg and associated pretrained weight loading
|
||||
* passing through optional model_cfg for models with config based arch spec
|
||||
* features_only model adaptation
|
||||
* pruning config / model adaptation
|
||||
|
||||
Args:
|
||||
model_cls (nn.Module): model class
|
||||
variant (str): model variant name
|
||||
pretrained (bool): load pretrained weights
|
||||
pretrained_cfg (dict): model's pretrained weight/task config
|
||||
model_cfg (Optional[Dict]): model's architecture config
|
||||
feature_cfg (Optional[Dict]: feature extraction adapter config
|
||||
pretrained_strict (bool): load pretrained weights strictly
|
||||
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
|
||||
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
|
||||
**kwargs: model args passed through to model __init__
|
||||
"""
|
||||
pruned = kwargs.pop('pruned', False)
|
||||
features = False
|
||||
feature_cfg = feature_cfg or {}
|
||||
|
||||
# resolve and update model pretrained config and model kwargs
|
||||
pretrained_cfg = resolve_pretrained_cfg(
|
||||
variant,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay
|
||||
)
|
||||
|
||||
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
|
||||
pretrained_cfg = pretrained_cfg.to_dict()
|
||||
|
||||
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
||||
|
||||
# Setup for feature extraction wrapper done at end of this fn
|
||||
if kwargs.pop('features_only', False):
|
||||
features = True
|
||||
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
||||
if 'out_indices' in kwargs:
|
||||
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
||||
|
||||
# Instantiate the model
|
||||
if model_cfg is None:
|
||||
model = model_cls(**kwargs)
|
||||
else:
|
||||
model = model_cls(cfg=model_cfg, **kwargs)
|
||||
model.pretrained_cfg = pretrained_cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||
|
||||
if pruned:
|
||||
model = adapt_model_from_file(model, variant)
|
||||
|
||||
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
||||
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
||||
if pretrained:
|
||||
if pretrained_cfg.get('custom_load', False):
|
||||
load_custom_pretrained(
|
||||
model,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
)
|
||||
else:
|
||||
load_pretrained(
|
||||
model,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
num_classes=num_classes_pretrained,
|
||||
in_chans=kwargs.get('in_chans', 3),
|
||||
filter_fn=pretrained_filter_fn,
|
||||
strict=pretrained_strict,
|
||||
)
|
||||
|
||||
# Wrap the model in a feature extraction module if enabled
|
||||
if features:
|
||||
feature_cls = FeatureListNet
|
||||
if 'feature_cls' in feature_cfg:
|
||||
feature_cls = feature_cfg.pop('feature_cls')
|
||||
if isinstance(feature_cls, str):
|
||||
feature_cls = feature_cls.lower()
|
||||
if 'hook' in feature_cls:
|
||||
feature_cls = FeatureHookNet
|
||||
elif feature_cls == 'fx':
|
||||
feature_cls = FeatureGraphNet
|
||||
else:
|
||||
assert False, f'Unknown feature class {feature_cls}'
|
||||
model = feature_cls(model, **feature_cfg)
|
||||
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||
|
||||
return model
|
@ -2,13 +2,12 @@
|
||||
|
||||
Hacked together by / Copyright 2019, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
|
||||
from timm.layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
|
||||
|
||||
__all__ = [
|
||||
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
|
@ -14,8 +14,8 @@ from functools import partial
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .efficientnet_blocks import *
|
||||
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
|
||||
from ._efficientnet_blocks import *
|
||||
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
|
||||
|
||||
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
|
||||
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
|
103
timm/models/_factory.py
Normal file
103
timm/models/_factory.py
Normal file
@ -0,0 +1,103 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from timm.layers import set_layer_config
|
||||
from ._pretrained import PretrainedCfg, split_model_name_tag
|
||||
from ._helpers import load_checkpoint
|
||||
from ._hub import load_model_config_from_hf
|
||||
from ._registry import is_model, model_entrypoint
|
||||
|
||||
|
||||
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
||||
|
||||
|
||||
def parse_model_name(model_name):
|
||||
if model_name.startswith('hf_hub'):
|
||||
# NOTE for backwards compat, deprecate hf_hub use
|
||||
model_name = model_name.replace('hf_hub', 'hf-hub')
|
||||
parsed = urlsplit(model_name)
|
||||
assert parsed.scheme in ('', 'timm', 'hf-hub')
|
||||
if parsed.scheme == 'hf-hub':
|
||||
# FIXME may use fragment as revision, currently `@` in URI path
|
||||
return parsed.scheme, parsed.path
|
||||
else:
|
||||
model_name = os.path.split(parsed.path)[-1]
|
||||
return 'timm', model_name
|
||||
|
||||
|
||||
def safe_model_name(model_name, remove_source=True):
|
||||
# return a filename / path safe model name
|
||||
def make_safe(name):
|
||||
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
|
||||
if remove_source:
|
||||
model_name = parse_model_name(model_name)[-1]
|
||||
return make_safe(model_name)
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
pretrained: bool = False,
|
||||
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
||||
checkpoint_path: str = '',
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a model
|
||||
|
||||
Lookup model's entrypoint function and pass relevant args to create a new model.
|
||||
|
||||
**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
|
||||
and then the model class __init__(). kwargs values set to None are pruned before passing.
|
||||
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
|
||||
pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
|
||||
checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
|
||||
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
|
||||
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
|
||||
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
|
||||
|
||||
Keyword Args:
|
||||
drop_rate (float): dropout rate for training (default: 0.0)
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are consumed by builder or model __init__()
|
||||
"""
|
||||
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
||||
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
||||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
model_source, model_name = parse_model_name(model_name)
|
||||
if model_source == 'hf-hub':
|
||||
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
|
||||
else:
|
||||
model_name, pretrained_tag = split_model_name_tag(model_name)
|
||||
if not pretrained_cfg:
|
||||
# a valid pretrained_cfg argument takes priority over tag in model name
|
||||
pretrained_cfg = pretrained_tag
|
||||
|
||||
if not is_model(model_name):
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
create_fn = model_entrypoint(model_name)
|
||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||
model = create_fn(
|
||||
pretrained=pretrained,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
287
timm/models/_features.py
Normal file
287
timm/models/_features.py
Normal file
@ -0,0 +1,287 @@
|
||||
""" PyTorch Feature Extraction Helpers
|
||||
|
||||
A collection of classes, functions, modules to help extract features from models
|
||||
and provide a common interface for describing them.
|
||||
|
||||
The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
|
||||
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||
|
||||
|
||||
class FeatureInfo:
|
||||
|
||||
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
||||
prev_reduction = 1
|
||||
for fi in feature_info:
|
||||
# sanity check the mandatory fields, there may be additional fields depending on the model
|
||||
assert 'num_chs' in fi and fi['num_chs'] > 0
|
||||
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
||||
prev_reduction = fi['reduction']
|
||||
assert 'module' in fi
|
||||
self.out_indices = out_indices
|
||||
self.info = feature_info
|
||||
|
||||
def from_other(self, out_indices: Tuple[int]):
|
||||
return FeatureInfo(deepcopy(self.info), out_indices)
|
||||
|
||||
def get(self, key, idx=None):
|
||||
""" Get value by key at specified index (indices)
|
||||
if idx == None, returns value for key at each output index
|
||||
if idx is an integer, return value for that feature module index (ignoring output indices)
|
||||
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
||||
"""
|
||||
if idx is None:
|
||||
return [self.info[i][key] for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i][key] for i in idx]
|
||||
else:
|
||||
return self.info[idx][key]
|
||||
|
||||
def get_dicts(self, keys=None, idx=None):
|
||||
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
||||
"""
|
||||
if idx is None:
|
||||
if keys is None:
|
||||
return [self.info[i] for i in self.out_indices]
|
||||
else:
|
||||
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
||||
else:
|
||||
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
||||
|
||||
def channels(self, idx=None):
|
||||
""" feature channels accessor
|
||||
"""
|
||||
return self.get('num_chs', idx)
|
||||
|
||||
def reduction(self, idx=None):
|
||||
""" feature reduction (output stride) accessor
|
||||
"""
|
||||
return self.get('reduction', idx)
|
||||
|
||||
def module_name(self, idx=None):
|
||||
""" feature module name accessor
|
||||
"""
|
||||
return self.get('module', idx)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.info[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.info)
|
||||
|
||||
|
||||
class FeatureHooks:
|
||||
""" Feature Hook Helper
|
||||
|
||||
This module helps with the setup and extraction of hooks for extracting features from
|
||||
internal nodes in a model by node name. This works quite well in eager Python but needs
|
||||
redesign for torchscript.
|
||||
"""
|
||||
|
||||
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
||||
# setup feature hooks
|
||||
modules = {k: v for k, v in named_modules}
|
||||
for i, h in enumerate(hooks):
|
||||
hook_name = h['module']
|
||||
m = modules[hook_name]
|
||||
hook_id = out_map[i] if out_map else hook_name
|
||||
hook_fn = partial(self._collect_output_hook, hook_id)
|
||||
hook_type = h.get('hook_type', default_hook_type)
|
||||
if hook_type == 'forward_pre':
|
||||
m.register_forward_pre_hook(hook_fn)
|
||||
elif hook_type == 'forward':
|
||||
m.register_forward_hook(hook_fn)
|
||||
else:
|
||||
assert False, "Unsupported hook type"
|
||||
self._feature_outputs = defaultdict(OrderedDict)
|
||||
|
||||
def _collect_output_hook(self, hook_id, *args):
|
||||
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
||||
if isinstance(x, tuple):
|
||||
x = x[0] # unwrap input tuple
|
||||
self._feature_outputs[x.device][hook_id] = x
|
||||
|
||||
def get_output(self, device) -> Dict[str, torch.tensor]:
|
||||
output = self._feature_outputs[device]
|
||||
self._feature_outputs[device] = OrderedDict() # clear after reading
|
||||
return output
|
||||
|
||||
|
||||
def _module_list(module, flatten_sequential=False):
|
||||
# a yield/iter would be better for this but wouldn't be compatible with torchscript
|
||||
ml = []
|
||||
for name, module in module.named_children():
|
||||
if flatten_sequential and isinstance(module, nn.Sequential):
|
||||
# first level of Sequential containers is flattened into containing model
|
||||
for child_name, child_module in module.named_children():
|
||||
combined = [name, child_name]
|
||||
ml.append(('_'.join(combined), '.'.join(combined), child_module))
|
||||
else:
|
||||
ml.append((name, name, module))
|
||||
return ml
|
||||
|
||||
|
||||
def _get_feature_info(net, out_indices):
|
||||
feature_info = getattr(net, 'feature_info')
|
||||
if isinstance(feature_info, FeatureInfo):
|
||||
return feature_info.from_other(out_indices)
|
||||
elif isinstance(feature_info, (list, tuple)):
|
||||
return FeatureInfo(net.feature_info, out_indices)
|
||||
else:
|
||||
assert False, "Provided feature_info is not valid"
|
||||
|
||||
|
||||
def _get_return_layers(feature_info, out_map):
|
||||
module_names = feature_info.module_name()
|
||||
return_layers = {}
|
||||
for i, name in enumerate(module_names):
|
||||
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
||||
return return_layers
|
||||
|
||||
|
||||
class FeatureDictNet(nn.ModuleDict):
|
||||
""" Feature extractor with OrderedDict return
|
||||
|
||||
Wrap a model and extract features as specified by the out indices, the network is
|
||||
partially re-built from contained modules.
|
||||
|
||||
There is a strong assumption that the modules have been registered into the model in the same
|
||||
order as they are used. There should be no reuse of the same nn.Module more than once, including
|
||||
trivial modules like `self.relu = nn.ReLU`.
|
||||
|
||||
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
|
||||
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
||||
All Sequential containers that are directly assigned to the original model will have their
|
||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||
|
||||
Arguments:
|
||||
model (nn.Module): model from which we will extract the features
|
||||
out_indices (tuple[int]): model output indices to extract features for
|
||||
out_map (sequence): list or tuple specifying desired return id for each out index,
|
||||
otherwise str(index) is used
|
||||
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
||||
vs select element [0]
|
||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
super(FeatureDictNet, self).__init__()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.concat = feature_concat
|
||||
self.return_layers = {}
|
||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = set(return_layers.keys())
|
||||
layers = OrderedDict()
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
if old_name in remaining:
|
||||
# return id has to be consistently str type for torchscript
|
||||
self.return_layers[new_name] = str(return_layers[old_name])
|
||||
remaining.remove(old_name)
|
||||
if not remaining:
|
||||
break
|
||||
assert not remaining and len(self.return_layers) == len(return_layers), \
|
||||
f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
|
||||
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
||||
out = OrderedDict()
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
if name in self.return_layers:
|
||||
out_id = self.return_layers[name]
|
||||
if isinstance(x, (tuple, list)):
|
||||
# If model tap is a tuple or list, concat or select first element
|
||||
# FIXME this may need to be more generic / flexible for some nets
|
||||
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
|
||||
else:
|
||||
out[out_id] = x
|
||||
return out
|
||||
|
||||
def forward(self, x) -> Dict[str, torch.Tensor]:
|
||||
return self._collect(x)
|
||||
|
||||
|
||||
class FeatureListNet(FeatureDictNet):
|
||||
""" Feature extractor with list return
|
||||
|
||||
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
||||
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
super(FeatureListNet, self).__init__(
|
||||
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
||||
flatten_sequential=flatten_sequential)
|
||||
|
||||
def forward(self, x) -> (List[torch.Tensor]):
|
||||
return list(self._collect(x).values())
|
||||
|
||||
|
||||
class FeatureHookNet(nn.ModuleDict):
|
||||
""" FeatureHookNet
|
||||
|
||||
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
|
||||
|
||||
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
|
||||
network in any way.
|
||||
|
||||
If `no_rewrite` is False, the model will be re-written as in the
|
||||
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
|
||||
|
||||
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
||||
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
||||
super(FeatureHookNet, self).__init__()
|
||||
assert not torch.jit.is_scripting()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.out_as_dict = out_as_dict
|
||||
layers = OrderedDict()
|
||||
hooks = []
|
||||
if no_rewrite:
|
||||
assert not flatten_sequential
|
||||
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
||||
model.reset_classifier(0)
|
||||
layers['body'] = model
|
||||
hooks.extend(self.feature_info.get_dicts())
|
||||
else:
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||
for f in self.feature_info.get_dicts()}
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
for fn, fm in module.named_modules(prefix=old_name):
|
||||
if fn in remaining:
|
||||
hooks.append(dict(module=fn, hook_type=remaining[fn]))
|
||||
del remaining[fn]
|
||||
if not remaining:
|
||||
break
|
||||
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||
|
||||
def forward(self, x):
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
out = self.hooks.get_output(x.device)
|
||||
return out if self.out_as_dict else list(out.values())
|
110
timm/models/_features_fx.py
Normal file
110
timm/models/_features_fx.py
Normal file
@ -0,0 +1,110 @@
|
||||
""" PyTorch FX Based Feature Extraction Helpers
|
||||
Using https://pytorch.org/vision/stable/feature_extraction.html
|
||||
"""
|
||||
from typing import Callable, List, Dict, Union, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ._features import _get_feature_info
|
||||
|
||||
try:
|
||||
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
|
||||
has_fx_feature_extraction = True
|
||||
except ImportError:
|
||||
has_fx_feature_extraction = False
|
||||
|
||||
# Layers we went to treat as leaf modules
|
||||
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
|
||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||
|
||||
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
|
||||
# BUT modules from timm.models should use the registration mechanism below
|
||||
_leaf_modules = {
|
||||
BilinearAttnTransform, # reason: flow control t <= 1
|
||||
# Reason: get_same_padding has a max which raises a control flow error
|
||||
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
|
||||
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
|
||||
}
|
||||
|
||||
try:
|
||||
from timm.layers import InplaceAbn
|
||||
_leaf_modules.add(InplaceAbn)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
|
||||
'FeatureGraphNet', 'GraphExtractNet']
|
||||
|
||||
|
||||
def register_notrace_module(module: Type[nn.Module]):
|
||||
"""
|
||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||
"""
|
||||
_leaf_modules.add(module)
|
||||
return module
|
||||
|
||||
|
||||
# Functions we want to autowrap (treat them as leaves)
|
||||
_autowrap_functions = set()
|
||||
|
||||
|
||||
def register_notrace_function(func: Callable):
|
||||
"""
|
||||
Decorator for functions which ought not to be traced through
|
||||
"""
|
||||
_autowrap_functions.add(func)
|
||||
return func
|
||||
|
||||
|
||||
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
return _create_feature_extractor(
|
||||
model, return_nodes,
|
||||
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
|
||||
)
|
||||
|
||||
|
||||
class FeatureGraphNet(nn.Module):
|
||||
""" A FX Graph based feature extractor that works with the model feature_info metadata
|
||||
"""
|
||||
def __init__(self, model, out_indices, out_map=None):
|
||||
super().__init__()
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
if out_map is not None:
|
||||
assert len(out_map) == len(out_indices)
|
||||
return_nodes = {
|
||||
info['module']: out_map[i] if out_map is not None else info['module']
|
||||
for i, info in enumerate(self.feature_info) if i in out_indices}
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
|
||||
def forward(self, x):
|
||||
return list(self.graph_module(x).values())
|
||||
|
||||
|
||||
class GraphExtractNet(nn.Module):
|
||||
""" A standalone feature extraction wrapper that maps dict -> list or single tensor
|
||||
NOTE:
|
||||
* one can use feature_extractor directly if dictionary output is desired
|
||||
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
|
||||
metadata for builtin feature extraction mode
|
||||
* create_feature_extractor can be used directly if dictionary output is acceptable
|
||||
|
||||
Args:
|
||||
model: model to extract features from
|
||||
return_nodes: node names to return features from (dict or list)
|
||||
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
||||
"""
|
||||
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
|
||||
super().__init__()
|
||||
self.squeeze_out = squeeze_out
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
|
||||
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
out = list(self.graph_module(x).values())
|
||||
if self.squeeze_out and len(out) == 1:
|
||||
return out[0]
|
||||
return out
|
115
timm/models/_helpers.py
Normal file
115
timm/models/_helpers.py
Normal file
@ -0,0 +1,115 @@
|
||||
""" Model creation / weight loading / state_dict helpers
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
|
||||
|
||||
|
||||
def clean_state_dict(state_dict):
|
||||
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
||||
cleaned_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module.') else k
|
||||
cleaned_state_dict[name] = v
|
||||
return cleaned_state_dict
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_path, use_ema=True):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
state_dict_key = ''
|
||||
if isinstance(checkpoint, dict):
|
||||
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
|
||||
state_dict_key = 'state_dict_ema'
|
||||
elif use_ema and checkpoint.get('model_ema', None) is not None:
|
||||
state_dict_key = 'model_ema'
|
||||
elif 'state_dict' in checkpoint:
|
||||
state_dict_key = 'state_dict'
|
||||
elif 'model' in checkpoint:
|
||||
state_dict_key = 'model'
|
||||
state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
|
||||
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
||||
return state_dict
|
||||
else:
|
||||
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, 'load_pretrained'):
|
||||
timm.models._model_builder.load_pretrained(checkpoint_path)
|
||||
else:
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
if remap:
|
||||
state_dict = remap_checkpoint(model, state_dict)
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
return incompatible_keys
|
||||
|
||||
|
||||
def remap_checkpoint(model, state_dict, allow_reshape=True):
|
||||
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
|
||||
This assumes models (and originating state dict) were created with params registered in same order.
|
||||
"""
|
||||
out_dict = {}
|
||||
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
|
||||
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
if va.shape != vb.shape:
|
||||
if allow_reshape:
|
||||
vb = vb.reshape(va.shape)
|
||||
else:
|
||||
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
out_dict[ka] = vb
|
||||
return out_dict
|
||||
|
||||
|
||||
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring model state from checkpoint...')
|
||||
state_dict = clean_state_dict(checkpoint['state_dict'])
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if optimizer is not None and 'optimizer' in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring optimizer state from checkpoint...')
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
||||
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
||||
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
||||
|
||||
if 'epoch' in checkpoint:
|
||||
resume_epoch = checkpoint['epoch']
|
||||
if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
||||
|
||||
if log_info:
|
||||
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
if log_info:
|
||||
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
return resume_epoch
|
||||
else:
|
||||
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
220
timm/models/_hub.py
Normal file
220
timm/models/_hub.py
Normal file
@ -0,0 +1,220 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
from timm import __version__
|
||||
from timm.models._pretrained import filter_pretrained_cfg
|
||||
|
||||
try:
|
||||
from huggingface_hub import (
|
||||
create_repo, get_hf_file_metadata,
|
||||
hf_hub_download, hf_hub_url,
|
||||
repo_type_and_id_from_hf_id, upload_folder)
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
||||
_has_hf_hub = True
|
||||
except ImportError:
|
||||
hf_hub_download = None
|
||||
_has_hf_hub = False
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
|
||||
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
|
||||
|
||||
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
Returns the location of the directory where models are cached (and creates it if necessary).
|
||||
"""
|
||||
# Issue warning to move data if old env is set
|
||||
if os.getenv('TORCH_MODEL_ZOO'):
|
||||
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
||||
|
||||
hub_dir = get_dir()
|
||||
child_dir = () if not child_dir else (child_dir,)
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
return model_dir
|
||||
|
||||
|
||||
def download_cached_file(url, check_hash=True, progress=False):
|
||||
if isinstance(url, (list, tuple)):
|
||||
url, filename = url
|
||||
else:
|
||||
parts = urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(get_cache_dir(), filename)
|
||||
if not os.path.exists(cached_file):
|
||||
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
if check_hash:
|
||||
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
|
||||
hash_prefix = r.group(1) if r else None
|
||||
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
||||
return cached_file
|
||||
|
||||
|
||||
def has_hf_hub(necessary=False):
|
||||
if not _has_hf_hub and necessary:
|
||||
# if no HF Hub module installed, and it is necessary to continue, raise error
|
||||
raise RuntimeError(
|
||||
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
||||
return _has_hf_hub
|
||||
|
||||
|
||||
def hf_split(hf_id):
|
||||
# FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
|
||||
rev_split = hf_id.split('@')
|
||||
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
|
||||
hf_model_id = rev_split[0]
|
||||
hf_revision = rev_split[-1] if len(rev_split) > 1 else None
|
||||
return hf_model_id, hf_revision
|
||||
|
||||
|
||||
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
def _download_from_hf(model_id: str, filename: str):
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
|
||||
|
||||
|
||||
def load_model_config_from_hf(model_id: str):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = _download_from_hf(model_id, 'config.json')
|
||||
|
||||
hf_config = load_cfg_from_json(cached_file)
|
||||
if 'pretrained_cfg' not in hf_config:
|
||||
# old form, pull pretrain_cfg out of the base dict
|
||||
pretrained_cfg = hf_config
|
||||
hf_config = {}
|
||||
hf_config['architecture'] = pretrained_cfg.pop('architecture')
|
||||
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
|
||||
if 'labels' in pretrained_cfg:
|
||||
hf_config['label_name'] = pretrained_cfg.pop('labels')
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
|
||||
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
|
||||
pretrained_cfg = hf_config['pretrained_cfg']
|
||||
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
||||
pretrained_cfg['source'] = 'hf-hub'
|
||||
if 'num_classes' in hf_config:
|
||||
# model should be created with parent num_classes if they exist
|
||||
pretrained_cfg['num_classes'] = hf_config['num_classes']
|
||||
model_name = hf_config['architecture']
|
||||
|
||||
return pretrained_cfg, model_name
|
||||
|
||||
|
||||
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = _download_from_hf(model_id, filename)
|
||||
state_dict = torch.load(cached_file, map_location='cpu')
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_for_hf(model, save_directory, model_config=None):
|
||||
assert has_hf_hub(True)
|
||||
model_config = model_config or {}
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
weights_path = save_directory / 'pytorch_model.bin'
|
||||
torch.save(model.state_dict(), weights_path)
|
||||
|
||||
config_path = save_directory / 'config.json'
|
||||
hf_config = {}
|
||||
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
|
||||
# set some values at root config level
|
||||
hf_config['architecture'] = pretrained_cfg.pop('architecture')
|
||||
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
|
||||
hf_config['num_features'] = model_config.get('num_features', model.num_features)
|
||||
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
|
||||
|
||||
if 'label' in model_config:
|
||||
_logger.warning(
|
||||
"'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
|
||||
"Using provided 'label' field as 'label_name'.")
|
||||
model_config['label_name'] = model_config.pop('label')
|
||||
|
||||
label_name = model_config.pop('label_name', None)
|
||||
if label_name:
|
||||
assert isinstance(label_name, (dict, list, tuple))
|
||||
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
|
||||
# can be a dict id: name if there are id gaps, or tuple/list if no gaps.
|
||||
hf_config['label_name'] = model_config['label_name']
|
||||
|
||||
display_name = model_config.pop('display_name', None)
|
||||
if display_name:
|
||||
assert isinstance(display_name, dict)
|
||||
# map label_name -> user interface display name
|
||||
hf_config['display_name'] = model_config['display_name']
|
||||
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
hf_config.update(model_config)
|
||||
|
||||
with config_path.open('w') as f:
|
||||
json.dump(hf_config, f, indent=2)
|
||||
|
||||
|
||||
def push_to_hf_hub(
|
||||
model,
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
):
|
||||
# Create repo if it doesn't exist yet
|
||||
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||
|
||||
# Infer complete repo_id from repo_url
|
||||
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||
repo_id = f"{repo_owner}/{repo_name}"
|
||||
|
||||
# Check if README file already exist in repo
|
||||
try:
|
||||
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||
has_readme = True
|
||||
except EntryNotFoundError:
|
||||
has_readme = False
|
||||
|
||||
# Dump model and push to Hub
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Save model weights and config.
|
||||
save_for_hf(model, tmpdir, model_config=model_config)
|
||||
|
||||
# Add readme if it does not exist
|
||||
if not has_readme:
|
||||
model_name = repo_id.split('/')[-1]
|
||||
readme_path = Path(tmpdir) / "README.md"
|
||||
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
|
||||
readme_path.write_text(readme_text)
|
||||
|
||||
# Upload model and return
|
||||
return upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=tmpdir,
|
||||
revision=revision,
|
||||
create_pr=create_pr,
|
||||
commit_message=commit_message,
|
||||
)
|
258
timm/models/_manipulate.py
Normal file
258
timm/models/_manipulate.py
Normal file
@ -0,0 +1,258 @@
|
||||
import collections.abc
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Callable, Union, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
|
||||
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
|
||||
|
||||
|
||||
def model_parameters(model, exclude_head=False):
|
||||
if exclude_head:
|
||||
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
|
||||
return [p for p in model.parameters()][:-2]
|
||||
else:
|
||||
return model.parameters()
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
return module
|
||||
|
||||
|
||||
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
if not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
yield from named_modules(
|
||||
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
yield name, module
|
||||
|
||||
|
||||
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
if module._parameters and not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
yield from named_modules_with_params(
|
||||
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if module._parameters and depth_first and include_root:
|
||||
yield name, module
|
||||
|
||||
|
||||
MATCH_PREV_GROUP = (99999,)
|
||||
|
||||
|
||||
def group_with_matcher(
|
||||
named_objects,
|
||||
group_matcher: Union[Dict, Callable],
|
||||
output_values: bool = False,
|
||||
reverse: bool = False
|
||||
):
|
||||
if isinstance(group_matcher, dict):
|
||||
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
|
||||
compiled = []
|
||||
for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
|
||||
if mspec is None:
|
||||
continue
|
||||
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
|
||||
if isinstance(mspec, (tuple, list)):
|
||||
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
|
||||
for sspec in mspec:
|
||||
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
|
||||
else:
|
||||
compiled += [(re.compile(mspec), (group_ordinal,), None)]
|
||||
group_matcher = compiled
|
||||
|
||||
def _get_grouping(name):
|
||||
if isinstance(group_matcher, (list, tuple)):
|
||||
for match_fn, prefix, suffix in group_matcher:
|
||||
r = match_fn.match(name)
|
||||
if r:
|
||||
parts = (prefix, r.groups(), suffix)
|
||||
# map all tuple elem to int for numeric sort, filter out None entries
|
||||
return tuple(map(float, chain.from_iterable(filter(None, parts))))
|
||||
return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
|
||||
else:
|
||||
ord = group_matcher(name)
|
||||
if not isinstance(ord, collections.abc.Iterable):
|
||||
return ord,
|
||||
return tuple(ord)
|
||||
|
||||
# map layers into groups via ordinals (ints or tuples of ints) from matcher
|
||||
grouping = defaultdict(list)
|
||||
for k, v in named_objects:
|
||||
grouping[_get_grouping(k)].append(v if output_values else k)
|
||||
|
||||
# remap to integers
|
||||
layer_id_to_param = defaultdict(list)
|
||||
lid = -1
|
||||
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
|
||||
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
|
||||
lid += 1
|
||||
layer_id_to_param[lid].extend(grouping[k])
|
||||
|
||||
if reverse:
|
||||
assert not output_values, "reverse mapping only sensible for name output"
|
||||
# output reverse mapping
|
||||
param_to_layer_id = {}
|
||||
for lid, lm in layer_id_to_param.items():
|
||||
for n in lm:
|
||||
param_to_layer_id[n] = lid
|
||||
return param_to_layer_id
|
||||
|
||||
return layer_id_to_param
|
||||
|
||||
|
||||
def group_parameters(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
|
||||
|
||||
|
||||
def group_modules(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
|
||||
|
||||
|
||||
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
|
||||
prefix_is_tuple = isinstance(prefix, tuple)
|
||||
if isinstance(module_types, str):
|
||||
if module_types == 'container':
|
||||
module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
|
||||
else:
|
||||
module_types = (nn.Sequential,)
|
||||
for name, module in named_modules:
|
||||
if depth and isinstance(module, module_types):
|
||||
yield from flatten_modules(
|
||||
module.named_children(),
|
||||
depth - 1,
|
||||
prefix=(name,) if prefix_is_tuple else name,
|
||||
module_types=module_types,
|
||||
)
|
||||
else:
|
||||
if prefix_is_tuple:
|
||||
name = prefix + (name,)
|
||||
yield name, module
|
||||
else:
|
||||
if prefix:
|
||||
name = '.'.join([prefix, name])
|
||||
yield name, module
|
||||
|
||||
|
||||
def checkpoint_seq(
|
||||
functions,
|
||||
x,
|
||||
every=1,
|
||||
flatten=False,
|
||||
skip_last=False,
|
||||
preserve_rng_state=True
|
||||
):
|
||||
r"""A helper function for checkpointing sequential models.
|
||||
|
||||
Sequential models execute a list of modules/functions in order
|
||||
(sequentially). Therefore, we can divide such a sequence into segments
|
||||
and checkpoint each segment. All segments except run in :func:`torch.no_grad`
|
||||
manner, i.e., not storing the intermediate activations. The inputs of each
|
||||
checkpointed segment will be saved for re-running the segment in the backward pass.
|
||||
|
||||
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
|
||||
|
||||
.. warning::
|
||||
Checkpointing currently only supports :func:`torch.autograd.backward`
|
||||
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
|
||||
is not supported.
|
||||
|
||||
.. warning:
|
||||
At least one of the inputs needs to have :code:`requires_grad=True` if
|
||||
grads are needed for model inputs, otherwise the checkpointed part of the
|
||||
model won't have gradients.
|
||||
|
||||
Args:
|
||||
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
|
||||
x: A Tensor that is input to :attr:`functions`
|
||||
every: checkpoint every-n functions (default: 1)
|
||||
flatten (bool): flatten nn.Sequential of nn.Sequentials
|
||||
skip_last (bool): skip checkpointing the last function in the sequence if True
|
||||
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
|
||||
the RNG state during each checkpoint.
|
||||
|
||||
Returns:
|
||||
Output of running :attr:`functions` sequentially on :attr:`*inputs`
|
||||
|
||||
Example:
|
||||
>>> model = nn.Sequential(...)
|
||||
>>> input_var = checkpoint_seq(model, input_var, every=2)
|
||||
"""
|
||||
def run_function(start, end, functions):
|
||||
def forward(_x):
|
||||
for j in range(start, end + 1):
|
||||
_x = functions[j](_x)
|
||||
return _x
|
||||
return forward
|
||||
|
||||
if isinstance(functions, torch.nn.Sequential):
|
||||
functions = functions.children()
|
||||
if flatten:
|
||||
functions = chain.from_iterable(functions)
|
||||
if not isinstance(functions, (tuple, list)):
|
||||
functions = tuple(functions)
|
||||
|
||||
num_checkpointed = len(functions)
|
||||
if skip_last:
|
||||
num_checkpointed -= 1
|
||||
end = -1
|
||||
for start in range(0, num_checkpointed, every):
|
||||
end = min(start + every - 1, num_checkpointed - 1)
|
||||
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
|
||||
if skip_last:
|
||||
return run_function(end + 1, len(functions) - 1, functions)(x)
|
||||
return x
|
||||
|
||||
|
||||
def adapt_input_conv(in_chans, conv_weight):
|
||||
conv_type = conv_weight.dtype
|
||||
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
||||
O, I, J, K = conv_weight.shape
|
||||
if in_chans == 1:
|
||||
if I > 3:
|
||||
assert conv_weight.shape[1] % 3 == 0
|
||||
# For models with space2depth stems
|
||||
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
||||
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
||||
else:
|
||||
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
||||
elif in_chans != 3:
|
||||
if I != 3:
|
||||
raise NotImplementedError('Weight format not supported by conversion.')
|
||||
else:
|
||||
# NOTE this strategy should be better than random init, but there could be other combinations of
|
||||
# the original RGB input layer weights that'd work better for specific cases.
|
||||
repeat = int(math.ceil(in_chans / 3))
|
||||
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
||||
conv_weight *= (3 / float(in_chans))
|
||||
conv_weight = conv_weight.to(conv_type)
|
||||
return conv_weight
|
@ -4,6 +4,9 @@ from dataclasses import dataclass, field, replace, asdict
|
||||
from typing import Any, Deque, Dict, Tuple, Optional, Union
|
||||
|
||||
|
||||
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretrainedCfg:
|
||||
"""
|
113
timm/models/_prune.py
Normal file
113
timm/models/_prune.py
Normal file
@ -0,0 +1,113 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
|
||||
|
||||
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
|
||||
|
||||
|
||||
def extract_layer(model, layer):
|
||||
layer = layer.split('.')
|
||||
module = model
|
||||
if hasattr(model, 'module') and layer[0] != 'module':
|
||||
module = model.module
|
||||
if not hasattr(model, 'module') and layer[0] == 'module':
|
||||
layer = layer[1:]
|
||||
for l in layer:
|
||||
if hasattr(module, l):
|
||||
if not l.isdigit():
|
||||
module = getattr(module, l)
|
||||
else:
|
||||
module = module[int(l)]
|
||||
else:
|
||||
return module
|
||||
return module
|
||||
|
||||
|
||||
def set_layer(model, layer, val):
|
||||
layer = layer.split('.')
|
||||
module = model
|
||||
if hasattr(model, 'module') and layer[0] != 'module':
|
||||
module = model.module
|
||||
lst_index = 0
|
||||
module2 = module
|
||||
for l in layer:
|
||||
if hasattr(module2, l):
|
||||
if not l.isdigit():
|
||||
module2 = getattr(module2, l)
|
||||
else:
|
||||
module2 = module2[int(l)]
|
||||
lst_index += 1
|
||||
lst_index -= 1
|
||||
for l in layer[:lst_index]:
|
||||
if not l.isdigit():
|
||||
module = getattr(module, l)
|
||||
else:
|
||||
module = module[int(l)]
|
||||
l = layer[lst_index]
|
||||
setattr(module, l, val)
|
||||
|
||||
|
||||
def adapt_model_from_string(parent_module, model_string):
|
||||
separator = '***'
|
||||
state_dict = {}
|
||||
lst_shape = model_string.split(separator)
|
||||
for k in lst_shape:
|
||||
k = k.split(':')
|
||||
key = k[0]
|
||||
shape = k[1][1:-1].split(',')
|
||||
if shape[0] != '':
|
||||
state_dict[key] = [int(i) for i in shape]
|
||||
|
||||
new_module = deepcopy(parent_module)
|
||||
for n, m in parent_module.named_modules():
|
||||
old_module = extract_layer(parent_module, n)
|
||||
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
|
||||
if isinstance(old_module, Conv2dSame):
|
||||
conv = Conv2dSame
|
||||
else:
|
||||
conv = nn.Conv2d
|
||||
s = state_dict[n + '.weight']
|
||||
in_channels = s[1]
|
||||
out_channels = s[0]
|
||||
g = 1
|
||||
if old_module.groups > 1:
|
||||
in_channels = out_channels
|
||||
g = in_channels
|
||||
new_conv = conv(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
|
||||
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
||||
groups=g, stride=old_module.stride)
|
||||
set_layer(new_module, n, new_conv)
|
||||
elif isinstance(old_module, BatchNormAct2d):
|
||||
new_bn = BatchNormAct2d(
|
||||
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
||||
affine=old_module.affine, track_running_stats=True)
|
||||
new_bn.drop = old_module.drop
|
||||
new_bn.act = old_module.act
|
||||
set_layer(new_module, n, new_bn)
|
||||
elif isinstance(old_module, nn.BatchNorm2d):
|
||||
new_bn = nn.BatchNorm2d(
|
||||
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
||||
affine=old_module.affine, track_running_stats=True)
|
||||
set_layer(new_module, n, new_bn)
|
||||
elif isinstance(old_module, nn.Linear):
|
||||
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
||||
num_features = state_dict[n + '.weight'][1]
|
||||
new_fc = Linear(
|
||||
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
|
||||
set_layer(new_module, n, new_fc)
|
||||
if hasattr(new_module, 'num_features'):
|
||||
new_module.num_features = num_features
|
||||
new_module.eval()
|
||||
parent_module.eval()
|
||||
|
||||
return new_module
|
||||
|
||||
|
||||
def adapt_model_from_file(parent_module, model_variant):
|
||||
adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
|
||||
with open(adapt_file, 'r') as f:
|
||||
return adapt_model_from_string(parent_module, f.read().strip())
|
212
timm/models/_registry.py
Normal file
212
timm/models/_registry.py
Normal file
@ -0,0 +1,212 @@
|
||||
""" Model Registry
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
|
||||
__all__ = [
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
|
||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module = {} # mapping of model names to module names
|
||||
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
|
||||
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
# lookup containing module
|
||||
mod = sys.modules[fn.__module__]
|
||||
module_name_split = fn.__module__.split('.')
|
||||
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||
|
||||
# add model to __all__ in module
|
||||
model_name = fn.__name__
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(model_name)
|
||||
else:
|
||||
mod.__all__ = [model_name]
|
||||
|
||||
# add entries to registry dict/sets
|
||||
_model_entrypoints[model_name] = fn
|
||||
_model_to_module[model_name] = module_name
|
||||
_module_to_models[module_name].add(model_name)
|
||||
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
|
||||
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
|
||||
# entrypoints or non-matching combos
|
||||
cfg = mod.default_cfgs[model_name]
|
||||
if not isinstance(cfg, DefaultCfg):
|
||||
# new style default cfg dataclass w/ multiple entries per model-arch
|
||||
assert isinstance(cfg, dict)
|
||||
# old style cfg dict per model-arch
|
||||
cfg = PretrainedCfg(**cfg)
|
||||
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
|
||||
|
||||
for tag_idx, tag in enumerate(cfg.tags):
|
||||
is_default = tag_idx == 0
|
||||
pretrained_cfg = cfg.cfgs[tag]
|
||||
if is_default:
|
||||
_model_pretrained_cfgs[model_name] = pretrained_cfg
|
||||
if pretrained_cfg.has_weights:
|
||||
# add tagless entry if it's default and has weights
|
||||
_model_has_pretrained.add(model_name)
|
||||
if tag:
|
||||
model_name_tag = '.'.join([model_name, tag])
|
||||
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
|
||||
if pretrained_cfg.has_weights:
|
||||
# add model w/ tag if tag is valid
|
||||
_model_has_pretrained.add(model_name_tag)
|
||||
_model_with_tags[model_name].append(model_name_tag)
|
||||
else:
|
||||
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
|
||||
|
||||
_model_default_cfgs[model_name] = cfg
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(
|
||||
filter: Union[str, List[str]] = '',
|
||||
module: str = '',
|
||||
pretrained=False,
|
||||
exclude_filters: str = '',
|
||||
name_matches_cfg: bool = False,
|
||||
include_tags: Optional[bool] = None,
|
||||
):
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||
pretrained (bool) - Include only models with valid pretrained weights if True
|
||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
|
||||
set to True when pretrained=True else False (default: None)
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
"""
|
||||
if include_tags is None:
|
||||
# FIXME should this be default behaviour? or default to include_tags=True?
|
||||
include_tags = pretrained
|
||||
|
||||
if module:
|
||||
all_models = list(_module_to_models[module])
|
||||
else:
|
||||
all_models = _model_entrypoints.keys()
|
||||
|
||||
if include_tags:
|
||||
# expand model names to include names w/ pretrained tags
|
||||
models_with_tags = []
|
||||
for m in all_models:
|
||||
models_with_tags.extend(_model_with_tags[m])
|
||||
all_models = models_with_tags
|
||||
|
||||
if filter:
|
||||
models = []
|
||||
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
||||
for f in include_filters:
|
||||
include_models = fnmatch.filter(all_models, f) # include these models
|
||||
if len(include_models):
|
||||
models = set(models).union(include_models)
|
||||
else:
|
||||
models = all_models
|
||||
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
exclude_filters = [exclude_filters]
|
||||
for xf in exclude_filters:
|
||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||
if len(exclude_models):
|
||||
models = set(models).difference(exclude_models)
|
||||
|
||||
if pretrained:
|
||||
models = _model_has_pretrained.intersection(models)
|
||||
|
||||
if name_matches_cfg:
|
||||
models = set(_model_pretrained_cfgs).intersection(models)
|
||||
|
||||
return list(sorted(models, key=_natural_key))
|
||||
|
||||
|
||||
def list_pretrained(
|
||||
filter: Union[str, List[str]] = '',
|
||||
exclude_filters: str = '',
|
||||
):
|
||||
return list_models(
|
||||
filter=filter,
|
||||
pretrained=True,
|
||||
exclude_filters=exclude_filters,
|
||||
include_tags=True,
|
||||
)
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
return arch_name in _model_entrypoints
|
||||
|
||||
|
||||
def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
||||
"""Fetch a model entrypoint for specified model name
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
|
||||
raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.')
|
||||
return _model_entrypoints[arch_name]
|
||||
|
||||
|
||||
def list_modules():
|
||||
""" Return list of module names that contain models / model entrypoints
|
||||
"""
|
||||
modules = _module_to_models.keys()
|
||||
return list(sorted(modules))
|
||||
|
||||
|
||||
def is_model_in_modules(model_name, module_names):
|
||||
"""Check if a model exists within a subset of modules
|
||||
Args:
|
||||
model_name (str) - name of model to check
|
||||
module_names (tuple, list, set) - names of modules to search in
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(arch_name in _module_to_models[n] for n in module_names)
|
||||
|
||||
|
||||
def is_model_pretrained(model_name):
|
||||
return model_name in _model_has_pretrained
|
||||
|
||||
|
||||
def get_pretrained_cfg(model_name):
|
||||
if model_name in _model_pretrained_cfgs:
|
||||
return deepcopy(_model_pretrained_cfgs[model_name])
|
||||
raise RuntimeError(f'No pretrained config exists for model {model_name}.')
|
||||
|
||||
|
||||
def get_pretrained_cfg_value(model_name, cfg_key):
|
||||
""" Get a specific model default_cfg value by key. None if key doesn't exist.
|
||||
"""
|
||||
if model_name in _model_pretrained_cfgs:
|
||||
return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
|
||||
raise RuntimeError(f'No pretrained config exist for model {model_name}.')
|
@ -61,12 +61,14 @@ import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from .pretrained import generate_default_cfgs
|
||||
from .registry import register_model
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
from .vision_transformer import checkpoint_filter_fn
|
||||
|
||||
__all__ = ['Beit']
|
||||
|
||||
|
||||
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
|
||||
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
||||
|
@ -13,9 +13,9 @@ Consider all of the models definitions here as experimental WIP and likely to ch
|
||||
Hacked together by / copyright Ross Wightman, 2021.
|
||||
"""
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model
|
||||
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
|
||||
from .helpers import build_model_with_cfg
|
||||
from .registry import register_model
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
@ -26,18 +26,18 @@ Hacked together by / copyright Ross Wightman, 2021.
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, named_apply, checkpoint_seq
|
||||
from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
|
||||
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
|
||||
EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
|
||||
from .registry import register_model
|
||||
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
|
||||
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
|
||||
|
||||
|
@ -8,17 +8,16 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
||||
"""
|
||||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, checkpoint_seq
|
||||
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from .registry import register_model
|
||||
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']
|
||||
|
||||
|
@ -7,7 +7,6 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
|
||||
|
||||
Modified from timm/models/vision_transformer.py
|
||||
"""
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Tuple, List, Union
|
||||
|
||||
@ -16,19 +15,11 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
|
||||
from .registry import register_model
|
||||
from .layers import _assert
|
||||
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
"coat_tiny",
|
||||
"coat_mini",
|
||||
"coat_lite_tiny",
|
||||
"coat_lite_mini",
|
||||
"coat_lite_small"
|
||||
]
|
||||
__all__ = ['CoaT']
|
||||
|
||||
|
||||
def _cfg_coat(url='', **kwargs):
|
||||
|
@ -22,20 +22,20 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
||||
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
'''
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
|
||||
from .registry import register_model
|
||||
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._registry import register_model
|
||||
from .vision_transformer_hybrid import HybridEmbed
|
||||
from .fx_features import register_notrace_module
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = ['ConViT']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -5,9 +5,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.registry import register_model
|
||||
from .helpers import build_model_with_cfg, checkpoint_seq
|
||||
from .layers import SelectAdaptivePool2d
|
||||
from timm.layers import SelectAdaptivePool2d
|
||||
from ._registry import register_model
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
|
||||
__all__ = ['ConvMixer']
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -18,12 +18,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
|
||||
from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
|
||||
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
|
||||
create_conv2d, get_act_layer, make_divisible, to_ntuple
|
||||
from .pretrained import generate_default_cfgs
|
||||
from .registry import register_model
|
||||
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._pretrained import generate_default_cfgs
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -24,21 +24,22 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W
|
||||
Modifed from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
|
||||
"""
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.hub
|
||||
from functools import partial
|
||||
from typing import List
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_notrace_function
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import DropPath, to_2tuple, trunc_normal_, _assert
|
||||
from .registry import register_model
|
||||
from .vision_transformer import Mlp, Block
|
||||
from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_function
|
||||
from ._registry import register_model
|
||||
from .vision_transformer import Block
|
||||
|
||||
__all__ = ['CrossViT'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -12,20 +12,18 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import collections.abc
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from dataclasses import dataclass, asdict
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, named_apply, MATCH_PREV_GROUP
|
||||
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
|
||||
from .registry import register_model
|
||||
|
||||
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import named_apply, MATCH_PREV_GROUP
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['CspNet'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -17,9 +17,11 @@ from torch import nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
|
||||
from .helpers import build_model_with_cfg, checkpoint_seq
|
||||
from .registry import register_model
|
||||
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -4,7 +4,6 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
|
||||
"""
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -13,9 +12,10 @@ import torch.utils.checkpoint as cp
|
||||
from torch.jit.annotations import List
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg, MATCH_PREV_GROUP
|
||||
from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
|
||||
from .registry import register_model
|
||||
from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._manipulate import MATCH_PREV_GROUP
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['DenseNet']
|
||||
|
||||
|
@ -13,9 +13,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import create_classifier
|
||||
from .registry import register_model
|
||||
from timm.layers import create_classifier
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['DLA']
|
||||
|
||||
|
@ -15,9 +15,9 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
|
||||
from .registry import register_model
|
||||
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['DPN']
|
||||
|
||||
|
@ -8,20 +8,20 @@ Original code and weights from https://github.com/mmaaz60/EdgeNeXt
|
||||
Modifications and additions for timm by / Copyright 2022, Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from typing import Tuple
|
||||
|
||||
from torch import nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .fx_features import register_notrace_module
|
||||
from .layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
|
||||
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
|
||||
from .registry import register_model
|
||||
|
||||
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._manipulate import named_apply, checkpoint_seq
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
@ -18,9 +18,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import build_model_with_cfg
|
||||
from .layers import DropPath, trunc_normal_, to_2tuple, Mlp
|
||||
from .registry import register_model
|
||||
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['EfficientFormer'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
|
@ -42,15 +42,15 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .efficientnet_blocks import SqueezeExcite
|
||||
from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights,\
|
||||
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
|
||||
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
|
||||
from ._efficientnet_blocks import SqueezeExcite
|
||||
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
|
||||
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
|
||||
from .features import FeatureInfo, FeatureHooks
|
||||
from .helpers import build_model_with_cfg, pretrained_cfg_for_features, checkpoint_seq
|
||||
from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
|
||||
from .registry import register_model
|
||||
from ._features import FeatureInfo, FeatureHooks
|
||||
from ._manipulate import checkpoint_seq
|
||||
from ._registry import register_model
|
||||
|
||||
__all__ = ['EfficientNet', 'EfficientNetFeatures']
|
||||
|
||||
|
@ -1,100 +1,4 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import urlsplit
|
||||
from ._factory import *
|
||||
|
||||
from .pretrained import PretrainedCfg, split_model_name_tag
|
||||
from .helpers import load_checkpoint
|
||||
from .hub import load_model_config_from_hf
|
||||
from .layers import set_layer_config
|
||||
from .registry import is_model, model_entrypoint
|
||||
|
||||
|
||||
def parse_model_name(model_name):
|
||||
if model_name.startswith('hf_hub'):
|
||||
# NOTE for backwards compat, deprecate hf_hub use
|
||||
model_name = model_name.replace('hf_hub', 'hf-hub')
|
||||
parsed = urlsplit(model_name)
|
||||
assert parsed.scheme in ('', 'timm', 'hf-hub')
|
||||
if parsed.scheme == 'hf-hub':
|
||||
# FIXME may use fragment as revision, currently `@` in URI path
|
||||
return parsed.scheme, parsed.path
|
||||
else:
|
||||
model_name = os.path.split(parsed.path)[-1]
|
||||
return 'timm', model_name
|
||||
|
||||
|
||||
def safe_model_name(model_name, remove_source=True):
|
||||
# return a filename / path safe model name
|
||||
def make_safe(name):
|
||||
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
|
||||
if remove_source:
|
||||
model_name = parse_model_name(model_name)[-1]
|
||||
return make_safe(model_name)
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
pretrained: bool = False,
|
||||
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
||||
checkpoint_path: str = '',
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a model
|
||||
|
||||
Lookup model's entrypoint function and pass relevant args to create a new model.
|
||||
|
||||
**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
|
||||
and then the model class __init__(). kwargs values set to None are pruned before passing.
|
||||
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
|
||||
pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
|
||||
checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
|
||||
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
|
||||
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
|
||||
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
|
||||
|
||||
Keyword Args:
|
||||
drop_rate (float): dropout rate for training (default: 0.0)
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are consumed by builder or model __init__()
|
||||
"""
|
||||
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
||||
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
||||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
model_source, model_name = parse_model_name(model_name)
|
||||
if model_source == 'hf-hub':
|
||||
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
|
||||
else:
|
||||
model_name, pretrained_tag = split_model_name_tag(model_name)
|
||||
if not pretrained_cfg:
|
||||
# a valid pretrained_cfg argument takes priority over tag in model name
|
||||
pretrained_cfg = pretrained_tag
|
||||
|
||||
if not is_model(model_name):
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
create_fn = model_entrypoint(model_name)
|
||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||
model = create_fn(
|
||||
pretrained=pretrained,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user