Merge remote-tracking branch 'origin/main' into focalnet_and_swin_refactor

This commit is contained in:
Ross Wightman 2023-03-15 15:58:39 -07:00
commit c30a160d3e
18 changed files with 549 additions and 248 deletions

View File

@ -19,6 +19,7 @@ jobs:
python: ['3.10'] python: ['3.10']
torch: ['1.13.0'] torch: ['1.13.0']
torchvision: ['0.14.0'] torchvision: ['0.14.0']
testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
@ -30,7 +31,7 @@ jobs:
- name: Install testing dependencies - name: Install testing dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest pip install -r requirements-dev.txt
- name: Install torch on mac - name: Install torch on mac
if: startsWith(matrix.os, 'macOS') if: startsWith(matrix.os, 'macOS')
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }} run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
@ -54,10 +55,10 @@ jobs:
PYTHONDONTWRITEBYTECODE: 1 PYTHONDONTWRITEBYTECODE: 1
run: | run: |
pytest -vv tests pytest -vv tests
- name: Run tests on Linux / Mac - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac
if: ${{ !startsWith(matrix.os, 'windows') }} if: ${{ !startsWith(matrix.os, 'windows') }}
env: env:
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
PYTHONDONTWRITEBYTECODE: 1 PYTHONDONTWRITEBYTECODE: 1
run: | run: |
pytest -vv --forked --durations=0 tests pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests

107
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,107 @@
*This guideline is very much a work-in-progress.*
Contriubtions to `timm` for code, documentation, tests are more than welcome!
There haven't been any formal guidelines to date so please bear with me, and feel free to add to this guide.
# Coding style
Code linting and auto-format (black) are not currently in place but open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.
A few specific differences from Google style (or black)
1. Line length is 120 char. Going over is okay in some cases (e.g. I prefer not to break URL across lines).
2. Hanging indents are always prefered, please avoid aligning arguments with closing brackets or braces.
Example, from Google guide, but this is a NO here:
```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
var_three, var_four)
meal = (spam,
beans)
# Aligned with opening delimiter in a dictionary.
foo = {
'long_dictionary_key': value1 +
value2,
...
}
```
This is YES:
```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
var_one, var_two, var_three,
var_four
)
meal = (
spam,
beans,
)
# 4-space hanging indent in a dictionary.
foo = {
'long_dictionary_key':
long_dictionary_value,
...
}
```
When there is descrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider current goal), please follow the style in a given file.
In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:
```
black --skip-string-normalization --line-length 120 <path-to-file>
```
Avoid formatting code that is unrelated to your PR though.
PR with pure formatting / style fixes will be accepted but only in isolation from functional changes, best to ask before starting such a change.
# Documentation
As with code style, docstrings style based on the Google guide: guide: https://google.github.io/styleguide/pyguide.html
The goal for the code is to eventually move to have all major functions and `__init__` methods use PEP484 type annotations.
When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings, please leave annotations as the one source of truth re typing.
There are a LOT of gaps in current documentation relative to the functionality in timm, please, document away!
# Installation
Create a Python virtual environment using Python 3.10. Inside the environment, install torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).
Then install the remaining dependencies:
```
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt # for testing
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
## Unit tests
Run the tests using:
```
pytest tests/
```
Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:
```
pytest -k "substring-to-match" -n 4 tests/
```
## Building documentation
Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).
# Questions
If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

View File

@ -24,6 +24,15 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗ * ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch. * Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Feb 26, 2023
* Add ConvNeXt-XXLarge CLIP pretrained image tower weights for fine-tune & features (fine-tuning TBD) -- see [model card](https://huggingface.co/laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup)
* Update `convnext_xxlarge` default LayerNorm eps to 1e-5 (for CLIP weights, improved stability)
* 0.8.15dev0
### Feb 20, 2023
* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_lage_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
### Feb 16, 2023 ### Feb 16, 2023
* `safetensor` checkpoint support added * `safetensor` checkpoint support added
* Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block * Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
@ -435,9 +444,7 @@ The work of many others is present here. I've tried to make sure all source mate
## Models ## Models
All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started. All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated.
A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
* Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723 * Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
* BEiT - https://arxiv.org/abs/2106.08254 * BEiT - https://arxiv.org/abs/2106.08254
@ -538,15 +545,15 @@ Several (less common) features that I often utilize in my projects are included.
* All models have a common default configuration interface and API for * All models have a common default configuration interface and API for
* accessing/changing the classifier - `get_classifier` and `reset_classifier` * accessing/changing the classifier - `get_classifier` and `reset_classifier`
* doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/)) * doing a forward pass on just the features - `forward_features` (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
* these makes it easy to write consistent network wrappers that work with any of the models * these makes it easy to write consistent network wrappers that work with any of the models
* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/)) * All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
* `create_model(name, features_only=True, out_indices=..., output_stride=...)` * `create_model(name, features_only=True, out_indices=..., output_stride=...)`
* `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level. * `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
* `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this. * `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
* feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member * feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
* All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired * All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
* High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes: * High performance [reference training, validation, and inference scripts](https://huggingface.co/docs/timm/training_script) that work in several process/GPU modes:
* NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional) * NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
* PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled) * PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
* PyTorch w/ single GPU single process (AMP optional) * PyTorch w/ single GPU single process (AMP optional)
@ -600,19 +607,17 @@ Model validation results can be found in the [results tables](results/README.md)
## Getting Started (Documentation) ## Getting Started (Documentation)
My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics. The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) will be the documentation focus going forward and will eventually replace the `github.io` docs above.
[Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. [timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.
## Train, Validation, Inference Scripts ## Train, Validation, Inference Scripts
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://rwightman.github.io/pytorch-image-models/scripts/) for some basics and [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) for some train examples that produce SOTA ImageNet results. The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://huggingface.co/docs/timm/training_script).
## Awesome PyTorch Resources ## Awesome PyTorch Resources

View File

@ -17,10 +17,14 @@ import os
import glob import glob
import hashlib import hashlib
from timm.models import load_state_dict from timm.models import load_state_dict
import safetensors.torch try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
DEFAULT_OUTPUT = "./average.pth" DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./average.safetensors" DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH', parser.add_argument('--input', default='', type=str, metavar='PATH',
@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
parser.add_argument('--safetensors', action='store_true', parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).') help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path): def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path): if not checkpoint_path or not os.path.isfile(checkpoint_path):
return {} return {}
@ -63,14 +68,20 @@ def main():
if args.safetensors and args.output == DEFAULT_OUTPUT: if args.safetensors and args.output == DEFAULT_OUTPUT:
# Default path changes if using safetensors # Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT args.output = DEFAULT_SAFE_OUTPUT
if args.safetensors and not args.output.endswith(".safetensors"):
output, output_ext = os.path.splitext(args.output)
if not output_ext:
output_ext = ('.safetensors' if args.safetensors else '.pth')
output = output + output_ext
if args.safetensors and not output_ext == ".safetensors":
print( print(
"Warning: saving weights as safetensors but output file extension is not " "Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}" f"set to '.safetensors': {args.output}"
) )
if os.path.exists(args.output): if os.path.exists(output):
print("Error: Output filename ({}) already exists.".format(args.output)) print("Error: Output filename ({}) already exists.".format(output))
exit(1) exit(1)
pattern = args.input pattern = args.input
@ -87,22 +98,27 @@ def main():
checkpoint_metrics.append((metric, c)) checkpoint_metrics.append((metric, c))
checkpoint_metrics = list(sorted(checkpoint_metrics)) checkpoint_metrics = list(sorted(checkpoint_metrics))
checkpoint_metrics = checkpoint_metrics[-args.n:] checkpoint_metrics = checkpoint_metrics[-args.n:]
if checkpoint_metrics:
print("Selected checkpoints:") print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics] [print(m, c) for m, c in checkpoint_metrics]
avg_checkpoints = [c for m, c in checkpoint_metrics] avg_checkpoints = [c for m, c in checkpoint_metrics]
else: else:
avg_checkpoints = checkpoints avg_checkpoints = checkpoints
if avg_checkpoints:
print("Selected checkpoints:") print("Selected checkpoints:")
[print(c) for c in checkpoints] [print(c) for c in checkpoints]
if not avg_checkpoints:
print('Error: No checkpoints found to average.')
exit(1)
avg_state_dict = {} avg_state_dict = {}
avg_counts = {} avg_counts = {}
for c in avg_checkpoints: for c in avg_checkpoints:
new_state_dict = load_state_dict(c, args.use_ema) new_state_dict = load_state_dict(c, args.use_ema)
if not new_state_dict: if not new_state_dict:
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) print(f"Error: Checkpoint ({c}) doesn't exist")
continue continue
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
if k not in avg_state_dict: if k not in avg_state_dict:
avg_state_dict[k] = v.clone().to(dtype=torch.float64) avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@ -122,16 +138,14 @@ def main():
final_state_dict[k] = v.to(dtype=torch.float32) final_state_dict[k] = v.to(dtype=torch.float32)
if args.safetensors: if args.safetensors:
safetensors.torch.save_file(final_state_dict, args.output) assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(final_state_dict, output)
else: else:
try: torch.save(final_state_dict, output)
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
with open(args.output, 'rb') as f: with open(output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest() sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -11,9 +11,14 @@ import torch
import argparse import argparse
import os import os
import hashlib import hashlib
import safetensors.torch
import shutil import shutil
import tempfile
from timm.models import load_state_dict from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path') help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present') help='use ema version of weights if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true', parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).') help='Save weights using safetensors instead of the default torch way (pickle).')
_TEMP_NAME = './_checkpoint.pth'
def main(): def main():
args = parser.parse_args() args = parser.parse_args()
@ -37,10 +42,24 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output)) print("Error: Output filename ({}) already exists.".format(args.output))
exit(1) exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors) clean_checkpoint(
args.checkpoint,
args.output,
not args.no_use_ema,
args.no_hash,
args.clean_aux_bn,
safe_serialization=args.safetensors,
)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False): def clean_checkpoint(
checkpoint,
output,
use_ema=True,
no_hash=False,
clean_aux_bn=False,
safe_serialization: bool=False,
):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint): if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint)) print("=> Loading checkpoint '{}'".format(checkpoint))
@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
new_state_dict[name] = v new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint)) print("=> Loaded state_dict from '{}'".format(checkpoint))
if safe_serialization: ext = ''
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
else:
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if output: if output:
checkpoint_root, checkpoint_base = os.path.split(output) checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base = os.path.splitext(checkpoint_base)[0] checkpoint_base, ext = os.path.splitext(checkpoint_base)
else: else:
checkpoint_root = '' checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0] checkpoint_base = os.path.split(checkpoint)[1]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth') checkpoint_base = os.path.splitext(checkpoint_base)[0]
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
temp_filename = '__' + checkpoint_base
if safe_serialization:
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(new_state_dict, temp_filename)
else:
torch.save(new_state_dict, temp_filename)
with open(temp_filename, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if ext:
final_ext = ext
else:
final_ext = ('.safetensors' if safe_serialization else '.pth')
if no_hash:
final_filename = checkpoint_base + final_ext
else:
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename return final_filename
else: else:

14
pyproject.toml Normal file
View File

@ -0,0 +1,14 @@
[tool.pytest.ini_options]
markers = [
"base: marker for model tests using the basic setup",
"cfg: marker for model tests checking the config",
"torchscript: marker for model tests using torchscript",
"features: marker for model tests checking feature extraction",
"fxforward: marker for model tests using torch fx (only forward)",
"fxbackward: marker for model tests using torch fx (only backward)",
]
[tool.black]
line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
skip-string-normalization = true

5
requirements-dev.txt Normal file
View File

@ -0,0 +1,5 @@
pytest
pytest-timeout
pytest-xdist
pytest-forked
expecttest

View File

@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
setup( setup(
name='timm', name='timm',
version=__version__, version=__version__,
description='(Unofficial) PyTorch Image Models', description='PyTorch Image Models',
long_description=long_description, long_description=long_description,
long_description_content_type='text/markdown', long_description_content_type='text/markdown',
url='https://github.com/rwightman/pytorch-image-models', url='https://github.com/huggingface/pytorch-image-models',
author='Ross Wightman', author='Ross Wightman',
author_email='hello@rwightman.com', author_email='ross@huggingface.co',
classifiers=[ classifiers=[
# How mature is this project? Common values are # How mature is this project? Common values are
# 3 - Alpha # 3 - Alpha
@ -29,11 +29,11 @@ setup(
'Intended Audience :: Education', 'Intended Audience :: Education',
'Intended Audience :: Science/Research', 'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License', 'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development', 'Topic :: Software Development',
@ -45,7 +45,7 @@ setup(
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit', keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']), packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True, include_package_data=True,
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'], install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
python_requires='>=3.6', python_requires='>=3.7',
) )

View File

@ -1,3 +1,16 @@
"""Run tests for all models
Tests that run on CI should have a specific marker, e.g. @pytest.mark.base. This
marker is used to parallelize the CI runs, with one runner for each marker.
If new tests are added, ensure that they use one of the existing markers
(documented in pyproject.toml > pytest > markers) or that a new marker is added
for this set of tests. If using a new marker, adjust the test matrix in
.github/workflows/tests.yml to run tests with this new marker, otherwise the
tests will be skipped on CI.
"""
import pytest import pytest
import torch import torch
import platform import platform
@ -83,6 +96,7 @@ def _get_input_size(model=None, model_name='', target=None):
return input_size return input_size
@pytest.mark.base
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@ -101,6 +115,7 @@ def test_model_forward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.base
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2]) @pytest.mark.parametrize('batch_size', [2])
@ -128,6 +143,7 @@ def test_model_backward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.cfg
@pytest.mark.timeout(300) @pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@ -190,6 +206,7 @@ def test_model_default_cfgs(model_name, batch_size):
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params' assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
@pytest.mark.cfg
@pytest.mark.timeout(300) @pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True)) @pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@ -274,6 +291,7 @@ EXCLUDE_JIT_FILTERS = [
] ]
@pytest.mark.torchscript
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
@ -303,6 +321,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
@pytest.mark.features
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@ -379,6 +398,7 @@ if not _IS_MAC:
] ]
@pytest.mark.fxforward
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
@ -412,6 +432,7 @@ if not _IS_MAC:
assert not torch.isnan(outputs).any(), 'Output included NaNs' assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.fxbackward
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models( @pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True)) exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))

View File

@ -54,7 +54,6 @@ def _interpolation(kwargs):
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION) interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
if isinstance(interpolation, (list, tuple)): if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation) return random.choice(interpolation)
else:
return interpolation return interpolation
@ -100,7 +99,7 @@ def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs) _check_args_tf(kwargs)
if _PIL_VER >= (5, 2): if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs) return img.rotate(degrees, **kwargs)
elif _PIL_VER >= (5, 0): if _PIL_VER >= (5, 0):
w, h = img.size w, h = img.size
post_trans = (0, 0) post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0) rotn_center = (w / 2.0, h / 2.0)
@ -124,7 +123,6 @@ def rotate(img, degrees, **kwargs):
matrix[2] += rotn_center[0] matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1] matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs) return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
else:
return img.rotate(degrees, resample=kwargs['resample']) return img.rotate(degrees, resample=kwargs['resample'])
@ -151,11 +149,12 @@ def solarize_add(img, add, thresh=128, **__):
lut.append(min(255, i + add)) lut.append(min(255, i + add))
else: else:
lut.append(i) lut.append(i)
if img.mode in ("L", "RGB"): if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256: if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut lut = lut + lut + lut
return img.point(lut) return img.point(lut)
else:
return img return img
@ -226,7 +225,7 @@ def _enhance_increasing_level_to_arg(level, _hparams):
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True): def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
level = (level / _LEVEL_DENOM) level = (level / _LEVEL_DENOM)
min_val + (max_val - min_val) * level level = min_val + (max_val - min_val) * level
if clamp: if clamp:
level = max(min_val, min(max_val, level)) level = max(min_val, min(max_val, level))
return level, return level,
@ -552,16 +551,15 @@ def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT hparams = hparams or _HPARAMS_DEFAULT
if name == 'original': if name == 'original':
return auto_augment_policy_original(hparams) return auto_augment_policy_original(hparams)
elif name == 'originalr': if name == 'originalr':
return auto_augment_policy_originalr(hparams) return auto_augment_policy_originalr(hparams)
elif name == 'v0': if name == 'v0':
return auto_augment_policy_v0(hparams) return auto_augment_policy_v0(hparams)
elif name == 'v0r': if name == 'v0r':
return auto_augment_policy_v0r(hparams) return auto_augment_policy_v0r(hparams)
elif name == '3a': if name == '3a':
return auto_augment_policy_3a(hparams) return auto_augment_policy_3a(hparams)
else: assert False, f'Unknown AA policy {name}'
assert False, 'Unknown AA policy (%s)' % name
class AutoAugment: class AutoAugment:
@ -576,7 +574,7 @@ class AutoAugment:
return img return img
def __repr__(self): def __repr__(self):
fs = self.__class__.__name__ + f'(policy=' fs = self.__class__.__name__ + '(policy='
for p in self.policy: for p in self.policy:
fs += '\n\t[' fs += '\n\t['
fs += ', '.join([str(op) for op in p]) fs += ', '.join([str(op) for op in p])
@ -636,7 +634,7 @@ _RAND_TRANSFORMS = [
'ShearY', 'ShearY',
'TranslateXRel', 'TranslateXRel',
'TranslateYRel', 'TranslateYRel',
#'Cutout' # NOTE I've implement this as random erasing separately # 'Cutout' # NOTE I've implement this as random erasing separately
] ]
@ -656,7 +654,7 @@ _RAND_INCREASING_TRANSFORMS = [
'ShearY', 'ShearY',
'TranslateXRel', 'TranslateXRel',
'TranslateYRel', 'TranslateYRel',
#'Cutout' # NOTE I've implement this as random erasing separately # 'Cutout' # NOTE I've implement this as random erasing separately
] ]
@ -667,7 +665,7 @@ _RAND_3A = [
] ]
_RAND_CHOICE_3A = { _RAND_WEIGHTED_3A = {
'SolarizeIncreasing': 6, 'SolarizeIncreasing': 6,
'Desaturate': 6, 'Desaturate': 6,
'GaussianBlur': 6, 'GaussianBlur': 6,
@ -687,7 +685,7 @@ _RAND_CHOICE_3A = {
# These experimental weights are based loosely on the relative improvements mentioned in paper. # These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so. # They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = { _RAND_WEIGHTED_0 = {
'Rotate': 3, 'Rotate': 3,
'ShearX': 2, 'ShearX': 2,
'ShearY': 2, 'ShearY': 2,
@ -715,12 +713,11 @@ def _get_weighted_transforms(transforms: Dict):
def rand_augment_choices(name: str, increasing=True): def rand_augment_choices(name: str, increasing=True):
if name == 'weights': if name == 'weights':
return _RAND_CHOICE_WEIGHTS_0 return _RAND_WEIGHTED_0
elif name == '3aw': if name == '3aw':
return _RAND_CHOICE_3A return _RAND_WEIGHTED_3A
elif name == '3a': if name == '3a':
return _RAND_3A return _RAND_3A
else:
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS

View File

@ -7,7 +7,11 @@ import os
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import safetensors.torch try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
import timm.models._builder import timm.models._builder
@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly # Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"): if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu') checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else: else:
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')

View File

@ -7,15 +7,21 @@ from functools import partial
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union from typing import Iterable, Optional, Union
import torch import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse from torch.hub import HASH_REGEX, download_url_to_file, urlparse
import safetensors.torch
try: try:
from torch.hub import get_dir from torch.hub import get_dir
except ImportError: except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import Literal from typing import Literal
else: else:
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''): def get_cache_dir(child_dir=''):
""" """
Returns the location of the directory where models are cached (and creates it if necessary). Returns the location of the directory where models are cached (and creates it if necessary).
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
hf_model_id, hf_revision = hf_split(model_id) hf_model_id, hf_revision = hf_split(model_id)
# Look for .safetensors alternatives and load from it if it exists # Look for .safetensors alternatives and load from it if it exists
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename): for safe_filename in _get_safe_alternatives(filename):
try: try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision) cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.") _logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu") return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError: except EntryNotFoundError:
pass pass
# Otherwise, load using pytorch.load # Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision) cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.") _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu') return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None): def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
):
model_config = model_config or {} model_config = model_config or {}
hf_config = {} hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@ -220,8 +234,8 @@ def save_for_hf(
model, model,
save_directory: str, save_directory: str,
model_config: Optional[dict] = None, model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False safe_serialization: Union[bool, Literal["both"]] = False,
): ):
assert has_hf_hub(True) assert has_hf_hub(True)
save_directory = Path(save_directory) save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True) save_directory.mkdir(exist_ok=True, parents=True)
@ -229,6 +243,7 @@ def save_for_hf(
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both. # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
tensors = model.state_dict() tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both": if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both": if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME) torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
@ -247,7 +262,7 @@ def push_to_hf_hub(
create_pr: bool = False, create_pr: bool = False,
model_config: Optional[dict] = None, model_config: Optional[dict] = None,
model_card: Optional[dict] = None, model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False safe_serialization: Union[bool, Literal["both"]] = False,
): ):
""" """
Arguments: Arguments:
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
readme_text += f"```bibtex\n{c}\n```\n" readme_text += f"```bibtex\n{c}\n```\n"
return readme_text return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]: def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename. """Returns potential safetensors alternatives for a given filename.
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
""" """
if filename == HF_WEIGHTS_NAME: if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME yield HF_SAFE_WEIGHTS_NAME
if filename.endswith(".bin"): if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
yield filename[:-4] + ".safetensors" return filename[:-4] + ".safetensors"

View File

@ -93,7 +93,7 @@ class DefaultCfg:
return tag, self.cfgs[tag] return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag=''): def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1) model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag tag = tag_list[0] if tag_list else no_tag
return model_name, tag return model_name, tag

View File

@ -8,7 +8,7 @@ import sys
from collections import defaultdict, deque from collections import defaultdict, deque
from copy import deepcopy from copy import deepcopy
from dataclasses import replace from dataclasses import replace
from typing import List, Optional, Union, Tuple from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -16,20 +16,20 @@ __all__ = [
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules', 'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name'] 'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module _module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names _model_to_module: Dict[str, str] = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns _model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present _model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects _model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs _model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names _model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]: def get_arch_name(model_name: str) -> str:
return split_model_name_tag(model_name)[0] return split_model_name_tag(model_name)[0]
def register_model(fn): def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module # lookup containing module
mod = sys.modules[fn.__module__] mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.') module_name_split = fn.__module__.split('.')
@ -40,7 +40,7 @@ def register_model(fn):
if hasattr(mod, '__all__'): if hasattr(mod, '__all__'):
mod.__all__.append(model_name) mod.__all__.append(model_name)
else: else:
mod.__all__ = [model_name] mod.__all__ = [model_name] # type: ignore
# add entries to registry dict/sets # add entries to registry dict/sets
_model_entrypoints[model_name] = fn _model_entrypoints[model_name] = fn
@ -87,28 +87,33 @@ def register_model(fn):
return fn return fn
def _natural_key(string_): def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models( def list_models(
filter: Union[str, List[str]] = '', filter: Union[str, List[str]] = '',
module: str = '', module: str = '',
pretrained=False, pretrained: bool = False,
exclude_filters: str = '', exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False, name_matches_cfg: bool = False,
include_tags: Optional[bool] = None, include_tags: Optional[bool] = None,
): ) -> List[str]:
""" Return list of available model names, sorted alphabetically """ Return list of available model names, sorted alphabetically
Args: Args:
filter (str) - Wildcard filter string that works with fnmatch filter - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer') module - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained (bool) - Include only models with valid pretrained weights if True pretrained - Include only models with valid pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter exclude_filters - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults include_tags - Include pretrained tags in model names (model.tag). If None, defaults
set to True when pretrained=True else False (default: None) set to True when pretrained=True else False (default: None)
Returns:
models - The sorted list of models
Example: Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@ -118,7 +123,7 @@ def list_models(
include_tags = pretrained include_tags = pretrained
if module: if module:
all_models = list(_module_to_models[module]) all_models: Iterable[str] = list(_module_to_models[module])
else: else:
all_models = _model_entrypoints.keys() all_models = _model_entrypoints.keys()
@ -130,14 +135,14 @@ def list_models(
all_models = models_with_tags all_models = models_with_tags
if filter: if filter:
models = [] models: Set[str] = set()
include_filters = filter if isinstance(filter, (tuple, list)) else [filter] include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters: for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models): if len(include_models):
models = set(models).union(include_models) models = models.union(include_models)
else: else:
models = all_models models = set(all_models)
if exclude_filters: if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)): if not isinstance(exclude_filters, (tuple, list)):
@ -145,7 +150,7 @@ def list_models(
for xf in exclude_filters: for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models): if len(exclude_models):
models = set(models).difference(exclude_models) models = models.difference(exclude_models)
if pretrained: if pretrained:
models = _model_has_pretrained.intersection(models) models = _model_has_pretrained.intersection(models)
@ -153,13 +158,13 @@ def list_models(
if name_matches_cfg: if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models) models = set(_model_pretrained_cfgs).intersection(models)
return list(sorted(models, key=_natural_key)) return sorted(models, key=_natural_key)
def list_pretrained( def list_pretrained(
filter: Union[str, List[str]] = '', filter: Union[str, List[str]] = '',
exclude_filters: str = '', exclude_filters: str = '',
): ) -> List[str]:
return list_models( return list_models(
filter=filter, filter=filter,
pretrained=True, pretrained=True,
@ -168,14 +173,14 @@ def list_pretrained(
) )
def is_model(model_name): def is_model(model_name: str) -> bool:
""" Check if a model name exists """ Check if a model name exists
""" """
arch_name = get_arch_name(model_name) arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints return arch_name in _model_entrypoints
def model_entrypoint(model_name, module_filter: Optional[str] = None): def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
"""Fetch a model entrypoint for specified model name """Fetch a model entrypoint for specified model name
""" """
arch_name = get_arch_name(model_name) arch_name = get_arch_name(model_name)
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
return _model_entrypoints[arch_name] return _model_entrypoints[arch_name]
def list_modules(): def list_modules() -> List[str]:
""" Return list of module names that contain models / model entrypoints """ Return list of module names that contain models / model entrypoints
""" """
modules = _module_to_models.keys() modules = _module_to_models.keys()
return list(sorted(modules)) return sorted(modules)
def is_model_in_modules(model_name, module_names): def is_model_in_modules(
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
"""Check if a model exists within a subset of modules """Check if a model exists within a subset of modules
Args: Args:
model_name (str) - name of model to check model_name - name of model to check
module_names (tuple, list, set) - names of modules to search in module_names - names of modules to search in
""" """
arch_name = get_arch_name(model_name) arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set)) assert isinstance(module_names, (tuple, list, set))
return any(arch_name in _module_to_models[n] for n in module_names) return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name): def is_model_pretrained(model_name: str) -> bool:
return model_name in _model_has_pretrained return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name, allow_unregistered=True): def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
if model_name in _model_pretrained_cfgs: if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name]) return deepcopy(_model_pretrained_cfgs[model_name])
arch_name, tag = split_model_name_tag(model_name) arch_name, tag = split_model_name_tag(model_name)
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.') raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name, cfg_key): def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
""" Get a specific model default_cfg value by key. None if key doesn't exist. """ Get a specific model default_cfg value by key. None if key doesn't exist.
""" """
cfg = get_pretrained_cfg(model_name, allow_unregistered=False) cfg = get_pretrained_cfg(model_name, allow_unregistered=False)

View File

@ -45,7 +45,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \ from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
@ -56,6 +56,28 @@ from ._registry import register_model
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
class Downsample(nn.Module):
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
super().__init__()
avg_stride = stride if dilation == 1 else 1
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
else:
self.pool = nn.Identity()
if in_chs != out_chs:
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
else:
self.conv = nn.Identity()
def forward(self, x):
x = self.pool(x)
x = self.conv(x)
return x
class ConvNeXtBlock(nn.Module): class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block """ ConvNeXt Block
There are two equivalent implementations: There are two equivalent implementations:
@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
Args:
in_chs (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
""" """
def __init__( def __init__(
self, self,
in_chs, in_chs: int,
out_chs=None, out_chs: Optional[int] = None,
kernel_size=7, kernel_size: int = 7,
stride=1, stride: int = 1,
dilation=1, dilation: Union[int, Tuple[int, int]] = (1, 1),
mlp_ratio=4, mlp_ratio: float = 4,
conv_mlp=False, conv_mlp: bool = False,
conv_bias=True, conv_bias: bool = True,
use_grn=False, use_grn: bool = False,
ls_init_value=1e-6, ls_init_value: Optional[float] = 1e-6,
act_layer='gelu', act_layer: Union[str, Callable] = 'gelu',
norm_layer=None, norm_layer: Optional[Callable] = None,
drop_path=0., drop_path: float = 0.,
): ):
"""
Args:
in_chs: Block input channels.
out_chs: Block output channels (same as in_chs if None).
kernel_size: Depthwise convolution kernel size.
stride: Stride of depthwise convolution.
dilation: Tuple specifying input and output dilation of block.
mlp_ratio: MLP expansion ratio.
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
conv_bias: Apply bias for all convolution (linear) layers.
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
ls_init_value: Layer-scale init values, layer-scale applied if not None.
act_layer: Activation layer.
norm_layer: Normalization layer (defaults to LN if not specified).
drop_path: Stochastic depth probability.
"""
super().__init__() super().__init__()
out_chs = out_chs or in_chs out_chs = out_chs or in_chs
dilation = to_ntuple(2)(dilation)
act_layer = get_act_layer(act_layer) act_layer = get_act_layer(act_layer)
if not norm_layer: if not norm_layer:
norm_layer = LayerNorm2d if conv_mlp else LayerNorm norm_layer = LayerNorm2d if conv_mlp else LayerNorm
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp) mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
self.use_conv_mlp = conv_mlp self.use_conv_mlp = conv_mlp
self.conv_dw = create_conv2d( self.conv_dw = create_conv2d(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias) in_chs,
out_chs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[0],
depthwise=True,
bias=conv_bias,
)
self.norm = norm_layer(out_chs) self.norm = norm_layer(out_chs)
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
else:
self.shortcut = nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x): def forward(self, x):
@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
if self.gamma is not None: if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut x = self.drop_path(x) + self.shortcut(shortcut)
return x return x
@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
norm_layer(in_chs), norm_layer(in_chs),
create_conv2d( create_conv2d(
in_chs, out_chs, kernel_size=ds_ks, stride=stride, in_chs,
dilation=dilation[0], padding=pad, bias=conv_bias), out_chs,
kernel_size=ds_ks,
stride=stride,
dilation=dilation[0],
padding=pad,
bias=conv_bias,
),
) )
in_chs = out_chs in_chs = out_chs
else: else:
@ -773,136 +825,147 @@ default_cfgs = generate_default_cfgs({
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
'convnext_xxlarge.clip_laion2b_soup': _cfg(
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
'convnext_xxlarge.clip_laion2b_rewind': _cfg(
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
}) })
@register_model @register_model
def convnext_atto(pretrained=False, **kwargs): def convnext_atto(pretrained=False, **kwargs):
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
model_args = dict( model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs) model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_atto_ols(pretrained=False, **kwargs): def convnext_atto_ols(pretrained=False, **kwargs):
# timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
model_args = dict( model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs) model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_femto(pretrained=False, **kwargs): def convnext_femto(pretrained=False, **kwargs):
# timm femto variant # timm femto variant
model_args = dict( model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs) model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_femto_ols(pretrained=False, **kwargs): def convnext_femto_ols(pretrained=False, **kwargs):
# timm femto variant # timm femto variant
model_args = dict( model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs) model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_pico(pretrained=False, **kwargs): def convnext_pico(pretrained=False, **kwargs):
# timm pico variant # timm pico variant
model_args = dict( model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs) model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_pico_ols(pretrained=False, **kwargs): def convnext_pico_ols(pretrained=False, **kwargs):
# timm nano variant with overlapping 3x3 conv stem # timm nano variant with overlapping 3x3 conv stem
model_args = dict( model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs) model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_nano(pretrained=False, **kwargs): def convnext_nano(pretrained=False, **kwargs):
# timm nano variant with standard stem and head # timm nano variant with standard stem and head
model_args = dict( model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs) model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_nano_ols(pretrained=False, **kwargs): def convnext_nano_ols(pretrained=False, **kwargs):
# experimental nano variant with overlapping conv stem # experimental nano variant with overlapping conv stem
model_args = dict( model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs) model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_tiny_hnf(pretrained=False, **kwargs): def convnext_tiny_hnf(pretrained=False, **kwargs):
# experimental tiny variant with norm before pooling in head (head norm first) # experimental tiny variant with norm before pooling in head (head norm first)
model_args = dict( model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_tiny(pretrained=False, **kwargs): def convnext_tiny(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args) model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnext_small(pretrained=False, **kwargs): def convnext_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
model = _create_convnext('convnext_small', pretrained=pretrained, **model_args) model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnext_base(pretrained=False, **kwargs): def convnext_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
model = _create_convnext('convnext_base', pretrained=pretrained, **model_args) model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnext_large(pretrained=False, **kwargs): def convnext_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
model = _create_convnext('convnext_large', pretrained=pretrained, **model_args) model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnext_large_mlp(pretrained=False, **kwargs): def convnext_large_mlp(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args) model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnext_xlarge(pretrained=False, **kwargs): def convnext_xlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args) model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnext_xxlarge(pretrained=False, **kwargs): def convnext_xxlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs) model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args) model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@ -910,8 +973,8 @@ def convnext_xxlarge(pretrained=False, **kwargs):
def convnextv2_atto(pretrained=False, **kwargs): def convnextv2_atto(pretrained=False, **kwargs):
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
model_args = dict( model_args = dict(
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs) depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@ -919,8 +982,8 @@ def convnextv2_atto(pretrained=False, **kwargs):
def convnextv2_femto(pretrained=False, **kwargs): def convnextv2_femto(pretrained=False, **kwargs):
# timm femto variant # timm femto variant
model_args = dict( model_args = dict(
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs) depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@ -928,8 +991,8 @@ def convnextv2_femto(pretrained=False, **kwargs):
def convnextv2_pico(pretrained=False, **kwargs): def convnextv2_pico(pretrained=False, **kwargs):
# timm pico variant # timm pico variant
model_args = dict( model_args = dict(
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs) depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@ -937,42 +1000,41 @@ def convnextv2_pico(pretrained=False, **kwargs):
def convnextv2_nano(pretrained=False, **kwargs): def convnextv2_nano(pretrained=False, **kwargs):
# timm nano variant with standard stem and head # timm nano variant with standard stem and head
model_args = dict( model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs) depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnextv2_tiny(pretrained=False, **kwargs): def convnextv2_tiny(pretrained=False, **kwargs):
model_args = dict( model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs) model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnextv2_small(pretrained=False, **kwargs): def convnextv2_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnextv2_base(pretrained=False, **kwargs): def convnextv2_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnextv2_large(pretrained=False, **kwargs): def convnextv2_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model
@register_model @register_model
def convnextv2_huge(pretrained=False, **kwargs): def convnextv2_huge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args) model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
return model return model

View File

@ -20,7 +20,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module from ._features_fx import register_notrace_module
from ._registry import register_model from ._registry import register_model
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
self.patch_size = to_2tuple(patch_size) self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1] self.patch_area = self.patch_size[0] * self.patch_size[1]
self.coreml_exportable = is_exportable()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape B, C, H, W = x.shape
@ -580,6 +581,9 @@ class MobileVitV2Block(nn.Module):
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N] # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1] C = x.shape[1]
if self.coreml_exportable:
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
else:
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4) x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches) x = x.reshape(B, C, -1, num_patches)
@ -588,9 +592,15 @@ class MobileVitV2Block(nn.Module):
x = self.norm(x) x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W] # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
if self.coreml_exportable:
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
x = F.pixel_shuffle(x, upscale_factor=patch_h)
else:
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3) x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w) x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x) x = self.conv_proj(x)
return x return x

View File

@ -1 +1 @@
__version__ = '0.8.12dev0' __version__ = '0.8.15dev0'

View File

@ -514,8 +514,14 @@ def main():
if utils.is_primary(args): if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native': elif use_amp == 'native':
try:
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
if device.type == 'cuda': except (AttributeError, TypeError):
# fallback to CUDA only AMP for PyTorch < 1.10
assert device.type == 'cuda'
amp_autocast = torch.cuda.amp.autocast
if device.type == 'cuda' and amp_dtype == torch.float16:
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it
loss_scaler = NativeScaler() loss_scaler = NativeScaler()
if utils.is_primary(args): if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.') _logger.info('Using native Torch AMP. Training in mixed precision.')