Merge remote-tracking branch 'origin/main' into focalnet_and_swin_refactor

This commit is contained in:
Ross Wightman 2023-03-15 15:58:39 -07:00
commit c30a160d3e
18 changed files with 549 additions and 248 deletions

View File

@ -19,6 +19,7 @@ jobs:
python: ['3.10']
torch: ['1.13.0']
torchvision: ['0.14.0']
testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
runs-on: ${{ matrix.os }}
steps:
@ -30,7 +31,7 @@ jobs:
- name: Install testing dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
pip install -r requirements-dev.txt
- name: Install torch on mac
if: startsWith(matrix.os, 'macOS')
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
@ -54,10 +55,10 @@ jobs:
PYTHONDONTWRITEBYTECODE: 1
run: |
pytest -vv tests
- name: Run tests on Linux / Mac
- name: Run '${{ matrix.testmarker }}' tests on Linux / Mac
if: ${{ !startsWith(matrix.os, 'windows') }}
env:
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
PYTHONDONTWRITEBYTECODE: 1
run: |
pytest -vv --forked --durations=0 tests
pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests

107
CONTRIBUTING.md Normal file
View File

@ -0,0 +1,107 @@
*This guideline is very much a work-in-progress.*
Contriubtions to `timm` for code, documentation, tests are more than welcome!
There haven't been any formal guidelines to date so please bear with me, and feel free to add to this guide.
# Coding style
Code linting and auto-format (black) are not currently in place but open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.
A few specific differences from Google style (or black)
1. Line length is 120 char. Going over is okay in some cases (e.g. I prefer not to break URL across lines).
2. Hanging indents are always prefered, please avoid aligning arguments with closing brackets or braces.
Example, from Google guide, but this is a NO here:
```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
var_three, var_four)
meal = (spam,
beans)
# Aligned with opening delimiter in a dictionary.
foo = {
'long_dictionary_key': value1 +
value2,
...
}
```
This is YES:
```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
var_one, var_two, var_three,
var_four
)
meal = (
spam,
beans,
)
# 4-space hanging indent in a dictionary.
foo = {
'long_dictionary_key':
long_dictionary_value,
...
}
```
When there is descrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider current goal), please follow the style in a given file.
In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:
```
black --skip-string-normalization --line-length 120 <path-to-file>
```
Avoid formatting code that is unrelated to your PR though.
PR with pure formatting / style fixes will be accepted but only in isolation from functional changes, best to ask before starting such a change.
# Documentation
As with code style, docstrings style based on the Google guide: guide: https://google.github.io/styleguide/pyguide.html
The goal for the code is to eventually move to have all major functions and `__init__` methods use PEP484 type annotations.
When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings, please leave annotations as the one source of truth re typing.
There are a LOT of gaps in current documentation relative to the functionality in timm, please, document away!
# Installation
Create a Python virtual environment using Python 3.10. Inside the environment, install torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).
Then install the remaining dependencies:
```
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt # for testing
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
## Unit tests
Run the tests using:
```
pytest tests/
```
Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:
```
pytest -k "substring-to-match" -n 4 tests/
```
## Building documentation
Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).
# Questions
If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

View File

@ -24,6 +24,15 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Feb 26, 2023
* Add ConvNeXt-XXLarge CLIP pretrained image tower weights for fine-tune & features (fine-tuning TBD) -- see [model card](https://huggingface.co/laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup)
* Update `convnext_xxlarge` default LayerNorm eps to 1e-5 (for CLIP weights, improved stability)
* 0.8.15dev0
### Feb 20, 2023
* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_lage_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
### Feb 16, 2023
* `safetensor` checkpoint support added
* Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
@ -112,7 +121,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
* `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
### Jan 5, 2023
* ConvNeXt-V2 models and weights added to existing `convnext.py`
@ -142,7 +151,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
### Dec 6, 2022
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
* original source: https://github.com/baaivision/EVA
* paper: https://arxiv.org/abs/2211.07636
@ -237,7 +246,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* `maxxvit_rmlp_small_rw_256` - 84.6 @ 256, 84.9 @ 288 (G) -- could be trained better, hparams need tuning (uses ConvNeXt block, no BN)
* `coatnet_rmlp_2_rw_224` - 84.6 @ 224, 85 @ 320 (T)
* NOTE: official MaxVit weights (in1k) have been released at https://github.com/google-research/maxvit -- some extra work is needed to port and adapt since my impl was created independently of theirs and has a few small differences + the whole TF same padding fun.
### Sept 23, 2022
* LAION-2B CLIP image towers supported as pretrained backbones for fine-tune or features (no classifier)
* vit_base_patch32_224_clip_laion2b
@ -268,7 +277,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* `coatnet_bn_0_rw_224` - 82.4 (T)
* `maxvit_nano_rw_256` - 82.9 @ 256 (T)
* `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T)
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
* (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained
* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes)
* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit)
@ -283,7 +292,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288
### Aug 5, 2022
* More custom ConvNeXt smaller model defs with weights
* More custom ConvNeXt smaller model defs with weights
* `convnext_femto` - 77.5 @ 224, 78.7 @ 288
* `convnext_femto_ols` - 77.9 @ 224, 78.9 @ 288
* `convnext_pico` - 79.5 @ 224, 80.4 @ 288
@ -304,7 +313,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* `cs3sedarknet_x` - 82.2 @ 256, 82.7 @ 288
* `cs3edgenet_x` - 82.2 @ 256, 82.7 @ 288
* `cs3se_edgenet_x` - 82.8 @ 256, 83.5 @ 320
* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program!
* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program!
* Add output_stride=8 and 16 support to ConvNeXt (dilation)
* deit3 models not being able to resize pos_emb fixed
* Version 0.6.7 PyPi release (/w above bug fixes and new weighs since 0.6.5)
@ -337,8 +346,8 @@ More models, more fixes
* Hugging Face Hub support fixes verified, demo notebook TBA
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
* previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
* Numerous bug fixes
* Currently testing for imminent PyPi 0.6.x release
@ -435,9 +444,7 @@ The work of many others is present here. I've tried to make sure all source mate
## Models
All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started.
A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated.
* Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
* BEiT - https://arxiv.org/abs/2106.08254
@ -538,15 +545,15 @@ Several (less common) features that I often utilize in my projects are included.
* All models have a common default configuration interface and API for
* accessing/changing the classifier - `get_classifier` and `reset_classifier`
* doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
* doing a forward pass on just the features - `forward_features` (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
* these makes it easy to write consistent network wrappers that work with any of the models
* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
* `create_model(name, features_only=True, out_indices=..., output_stride=...)`
* `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
* `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
* feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
* All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
* High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes:
* High performance [reference training, validation, and inference scripts](https://huggingface.co/docs/timm/training_script) that work in several process/GPU modes:
* NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
* PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
* PyTorch w/ single GPU single process (AMP optional)
@ -573,7 +580,7 @@ Several (less common) features that I often utilize in my projects are included.
* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
* AugMix w/ JSD loss (https://arxiv.org/abs/1912.02781), JSD w/ clean + augmented mixing support works with AutoAugment and RandAugment as well
* SplitBachNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
* DropBlock (https://arxiv.org/abs/1810.12890)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
@ -600,19 +607,17 @@ Model validation results can be found in the [results tables](results/README.md)
## Getting Started (Documentation)
My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) will be the documentation focus going forward and will eventually replace the `github.io` docs above.
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
[Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.
## Train, Validation, Inference Scripts
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://rwightman.github.io/pytorch-image-models/scripts/) for some basics and [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) for some train examples that produce SOTA ImageNet results.
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://huggingface.co/docs/timm/training_script).
## Awesome PyTorch Resources

View File

@ -17,10 +17,14 @@ import os
import glob
import hashlib
from timm.models import load_state_dict
import safetensors.torch
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
DEFAULT_OUTPUT = "./average.pth"
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path):
return {}
@ -63,14 +68,20 @@ def main():
if args.safetensors and args.output == DEFAULT_OUTPUT:
# Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT
if args.safetensors and not args.output.endswith(".safetensors"):
output, output_ext = os.path.splitext(args.output)
if not output_ext:
output_ext = ('.safetensors' if args.safetensors else '.pth')
output = output + output_ext
if args.safetensors and not output_ext == ".safetensors":
print(
"Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}"
)
if os.path.exists(args.output):
print("Error: Output filename ({}) already exists.".format(args.output))
if os.path.exists(output):
print("Error: Output filename ({}) already exists.".format(output))
exit(1)
pattern = args.input
@ -87,22 +98,27 @@ def main():
checkpoint_metrics.append((metric, c))
checkpoint_metrics = list(sorted(checkpoint_metrics))
checkpoint_metrics = checkpoint_metrics[-args.n:]
print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics]
if checkpoint_metrics:
print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics]
avg_checkpoints = [c for m, c in checkpoint_metrics]
else:
avg_checkpoints = checkpoints
print("Selected checkpoints:")
[print(c) for c in checkpoints]
if avg_checkpoints:
print("Selected checkpoints:")
[print(c) for c in checkpoints]
if not avg_checkpoints:
print('Error: No checkpoints found to average.')
exit(1)
avg_state_dict = {}
avg_counts = {}
for c in avg_checkpoints:
new_state_dict = load_state_dict(c, args.use_ema)
if not new_state_dict:
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
print(f"Error: Checkpoint ({c}) doesn't exist")
continue
for k, v in new_state_dict.items():
if k not in avg_state_dict:
avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@ -122,16 +138,14 @@ def main():
final_state_dict[k] = v.to(dtype=torch.float32)
if args.safetensors:
safetensors.torch.save_file(final_state_dict, args.output)
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(final_state_dict, output)
else:
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
torch.save(final_state_dict, output)
with open(args.output, 'rb') as f:
with open(output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
if __name__ == '__main__':

View File

@ -11,9 +11,14 @@ import torch
import argparse
import os
import hashlib
import safetensors.torch
import shutil
import tempfile
from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
_TEMP_NAME = './_checkpoint.pth'
def main():
args = parser.parse_args()
@ -37,10 +42,24 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
clean_checkpoint(
args.checkpoint,
args.output,
not args.no_use_ema,
args.no_hash,
args.clean_aux_bn,
safe_serialization=args.safetensors,
)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
def clean_checkpoint(
checkpoint,
output,
use_ema=True,
no_hash=False,
clean_aux_bn=False,
safe_serialization: bool=False,
):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint))
@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint))
if safe_serialization:
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
else:
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
ext = ''
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base = os.path.splitext(checkpoint_base)[0]
checkpoint_base, ext = os.path.splitext(checkpoint_base)
else:
checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
checkpoint_base = os.path.split(checkpoint)[1]
checkpoint_base = os.path.splitext(checkpoint_base)[0]
temp_filename = '__' + checkpoint_base
if safe_serialization:
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(new_state_dict, temp_filename)
else:
torch.save(new_state_dict, temp_filename)
with open(temp_filename, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if ext:
final_ext = ext
else:
final_ext = ('.safetensors' if safe_serialization else '.pth')
if no_hash:
final_filename = checkpoint_base + final_ext
else:
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename
else:

14
pyproject.toml Normal file
View File

@ -0,0 +1,14 @@
[tool.pytest.ini_options]
markers = [
"base: marker for model tests using the basic setup",
"cfg: marker for model tests checking the config",
"torchscript: marker for model tests using torchscript",
"features: marker for model tests checking feature extraction",
"fxforward: marker for model tests using torch fx (only forward)",
"fxbackward: marker for model tests using torch fx (only backward)",
]
[tool.black]
line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
skip-string-normalization = true

5
requirements-dev.txt Normal file
View File

@ -0,0 +1,5 @@
pytest
pytest-timeout
pytest-xdist
pytest-forked
expecttest

View File

@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
setup(
name='timm',
version=__version__,
description='(Unofficial) PyTorch Image Models',
description='PyTorch Image Models',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/rwightman/pytorch-image-models',
url='https://github.com/huggingface/pytorch-image-models',
author='Ross Wightman',
author_email='hello@rwightman.com',
author_email='ross@huggingface.co',
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
@ -29,11 +29,11 @@ setup(
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
@ -45,7 +45,7 @@ setup(
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True,
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
python_requires='>=3.6',
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
python_requires='>=3.7',
)

View File

@ -1,3 +1,16 @@
"""Run tests for all models
Tests that run on CI should have a specific marker, e.g. @pytest.mark.base. This
marker is used to parallelize the CI runs, with one runner for each marker.
If new tests are added, ensure that they use one of the existing markers
(documented in pyproject.toml > pytest > markers) or that a new marker is added
for this set of tests. If using a new marker, adjust the test matrix in
.github/workflows/tests.yml to run tests with this new marker, otherwise the
tests will be skipped on CI.
"""
import pytest
import torch
import platform
@ -83,6 +96,7 @@ def _get_input_size(model=None, model_name='', target=None):
return input_size
@pytest.mark.base
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
@ -101,6 +115,7 @@ def test_model_forward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.base
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
@ -128,6 +143,7 @@ def test_model_backward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.cfg
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
@ -190,6 +206,7 @@ def test_model_default_cfgs(model_name, batch_size):
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
@pytest.mark.cfg
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
@ -274,6 +291,7 @@ EXCLUDE_JIT_FILTERS = [
]
@pytest.mark.torchscript
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
@ -303,6 +321,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
@pytest.mark.features
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
@ -379,6 +398,7 @@ if not _IS_MAC:
]
@pytest.mark.fxforward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
@ -412,6 +432,7 @@ if not _IS_MAC:
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.fxbackward
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))

View File

@ -54,8 +54,7 @@ def _interpolation(kwargs):
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
else:
return interpolation
return interpolation
def _check_args_tf(kwargs):
@ -100,7 +99,7 @@ def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
elif _PIL_VER >= (5, 0):
if _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
@ -124,8 +123,7 @@ def rotate(img, degrees, **kwargs):
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
else:
return img.rotate(degrees, resample=kwargs['resample'])
return img.rotate(degrees, resample=kwargs['resample'])
def auto_contrast(img, **__):
@ -151,12 +149,13 @@ def solarize_add(img, add, thresh=128, **__):
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
else:
return img
return img
def posterize(img, bits_to_keep, **__):
@ -226,7 +225,7 @@ def _enhance_increasing_level_to_arg(level, _hparams):
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
level = (level / _LEVEL_DENOM)
min_val + (max_val - min_val) * level
level = min_val + (max_val - min_val) * level
if clamp:
level = max(min_val, min(max_val, level))
return level,
@ -552,16 +551,15 @@ def auto_augment_policy(name='v0', hparams=None):
hparams = hparams or _HPARAMS_DEFAULT
if name == 'original':
return auto_augment_policy_original(hparams)
elif name == 'originalr':
if name == 'originalr':
return auto_augment_policy_originalr(hparams)
elif name == 'v0':
if name == 'v0':
return auto_augment_policy_v0(hparams)
elif name == 'v0r':
if name == 'v0r':
return auto_augment_policy_v0r(hparams)
elif name == '3a':
if name == '3a':
return auto_augment_policy_3a(hparams)
else:
assert False, 'Unknown AA policy (%s)' % name
assert False, f'Unknown AA policy {name}'
class AutoAugment:
@ -576,7 +574,7 @@ class AutoAugment:
return img
def __repr__(self):
fs = self.__class__.__name__ + f'(policy='
fs = self.__class__.__name__ + '(policy='
for p in self.policy:
fs += '\n\t['
fs += ', '.join([str(op) for op in p])
@ -636,7 +634,7 @@ _RAND_TRANSFORMS = [
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # NOTE I've implement this as random erasing separately
# 'Cutout' # NOTE I've implement this as random erasing separately
]
@ -656,7 +654,7 @@ _RAND_INCREASING_TRANSFORMS = [
'ShearY',
'TranslateXRel',
'TranslateYRel',
#'Cutout' # NOTE I've implement this as random erasing separately
# 'Cutout' # NOTE I've implement this as random erasing separately
]
@ -667,7 +665,7 @@ _RAND_3A = [
]
_RAND_CHOICE_3A = {
_RAND_WEIGHTED_3A = {
'SolarizeIncreasing': 6,
'Desaturate': 6,
'GaussianBlur': 6,
@ -687,7 +685,7 @@ _RAND_CHOICE_3A = {
# These experimental weights are based loosely on the relative improvements mentioned in paper.
# They may not result in increased performance, but could likely be tuned to so.
_RAND_CHOICE_WEIGHTS_0 = {
_RAND_WEIGHTED_0 = {
'Rotate': 3,
'ShearX': 2,
'ShearY': 2,
@ -715,13 +713,12 @@ def _get_weighted_transforms(transforms: Dict):
def rand_augment_choices(name: str, increasing=True):
if name == 'weights':
return _RAND_CHOICE_WEIGHTS_0
elif name == '3aw':
return _RAND_CHOICE_3A
elif name == '3a':
return _RAND_WEIGHTED_0
if name == '3aw':
return _RAND_WEIGHTED_3A
if name == '3a':
return _RAND_3A
else:
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
def rand_augment_ops(

View File

@ -7,7 +7,11 @@ import os
from collections import OrderedDict
import torch
import safetensors.torch
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
import timm.models._builder
@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')

View File

@ -7,15 +7,21 @@ from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
import safetensors.torch
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
if sys.version_info >= (3, 8):
from typing import Literal
else:
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
hf_model_id, hf_revision = hf_split(model_id)
# Look for .safetensors alternatives and load from it if it exists
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
# Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
):
model_config = model_config or {}
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@ -220,8 +234,8 @@ def save_for_hf(
model,
save_directory: str,
model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False
):
safe_serialization: Union[bool, Literal["both"]] = False,
):
assert has_hf_hub(True)
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
@ -229,6 +243,7 @@ def save_for_hf(
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
@ -238,16 +253,16 @@ def save_for_hf(
def push_to_hf_hub(
model,
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False
model,
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
):
"""
Arguments:
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
if filename.endswith(".bin"):
yield filename[:-4] + ".safetensors"
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
return filename[:-4] + ".safetensors"

View File

@ -93,7 +93,7 @@ class DefaultCfg:
return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag=''):
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag

View File

@ -8,7 +8,7 @@ import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import List, Optional, Union, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -16,20 +16,20 @@ __all__ = [
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
def get_arch_name(model_name: str) -> str:
return split_model_name_tag(model_name)[0]
def register_model(fn):
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
@ -40,7 +40,7 @@ def register_model(fn):
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]
mod.__all__ = [model_name] # type: ignore
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
@ -87,28 +87,33 @@ def register_model(fn):
return fn
def _natural_key(string_):
def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(
filter: Union[str, List[str]] = '',
module: str = '',
pretrained=False,
exclude_filters: str = '',
pretrained: bool = False,
exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False,
include_tags: Optional[bool] = None,
):
) -> List[str]:
""" Return list of available model names, sorted alphabetically
Args:
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained (bool) - Include only models with valid pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
filter - Wildcard filter string that works with fnmatch
module - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained - Include only models with valid pretrained weights if True
exclude_filters - Wildcard filters to exclude models after including them with filter
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
set to True when pretrained=True else False (default: None)
Returns:
models - The sorted list of models
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@ -118,7 +123,7 @@ def list_models(
include_tags = pretrained
if module:
all_models = list(_module_to_models[module])
all_models: Iterable[str] = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
@ -130,14 +135,14 @@ def list_models(
all_models = models_with_tags
if filter:
models = []
models: Set[str] = set()
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models):
models = set(models).union(include_models)
models = models.union(include_models)
else:
models = all_models
models = set(all_models)
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
@ -145,7 +150,7 @@ def list_models(
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = set(models).difference(exclude_models)
models = models.difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
@ -153,13 +158,13 @@ def list_models(
if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models)
return list(sorted(models, key=_natural_key))
return sorted(models, key=_natural_key)
def list_pretrained(
filter: Union[str, List[str]] = '',
exclude_filters: str = '',
):
) -> List[str]:
return list_models(
filter=filter,
pretrained=True,
@ -168,14 +173,14 @@ def list_pretrained(
)
def is_model(model_name):
def is_model(model_name: str) -> bool:
""" Check if a model name exists
"""
arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints
def model_entrypoint(model_name, module_filter: Optional[str] = None):
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
"""Fetch a model entrypoint for specified model name
"""
arch_name = get_arch_name(model_name)
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
return _model_entrypoints[arch_name]
def list_modules():
def list_modules() -> List[str]:
""" Return list of module names that contain models / model entrypoints
"""
modules = _module_to_models.keys()
return list(sorted(modules))
return sorted(modules)
def is_model_in_modules(model_name, module_names):
def is_model_in_modules(
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
"""Check if a model exists within a subset of modules
Args:
model_name (str) - name of model to check
module_names (tuple, list, set) - names of modules to search in
model_name - name of model to check
module_names - names of modules to search in
"""
arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set))
return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name):
def is_model_pretrained(model_name: str) -> bool:
return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name, allow_unregistered=True):
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
arch_name, tag = split_model_name_tag(model_name)
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name, cfg_key):
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)

View File

@ -45,7 +45,7 @@ import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
@ -56,6 +56,28 @@ from ._registry import register_model
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
class Downsample(nn.Module):
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
super().__init__()
avg_stride = stride if dilation == 1 else 1
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
else:
self.pool = nn.Identity()
if in_chs != out_chs:
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
else:
self.conv = nn.Identity()
def forward(self, x):
x = self.pool(x)
x = self.conv(x)
return x
class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block
There are two equivalent implementations:
@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
Args:
in_chs (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(
self,
in_chs,
out_chs=None,
kernel_size=7,
stride=1,
dilation=1,
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
use_grn=False,
ls_init_value=1e-6,
act_layer='gelu',
norm_layer=None,
drop_path=0.,
in_chs: int,
out_chs: Optional[int] = None,
kernel_size: int = 7,
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = (1, 1),
mlp_ratio: float = 4,
conv_mlp: bool = False,
conv_bias: bool = True,
use_grn: bool = False,
ls_init_value: Optional[float] = 1e-6,
act_layer: Union[str, Callable] = 'gelu',
norm_layer: Optional[Callable] = None,
drop_path: float = 0.,
):
"""
Args:
in_chs: Block input channels.
out_chs: Block output channels (same as in_chs if None).
kernel_size: Depthwise convolution kernel size.
stride: Stride of depthwise convolution.
dilation: Tuple specifying input and output dilation of block.
mlp_ratio: MLP expansion ratio.
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
conv_bias: Apply bias for all convolution (linear) layers.
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
ls_init_value: Layer-scale init values, layer-scale applied if not None.
act_layer: Activation layer.
norm_layer: Normalization layer (defaults to LN if not specified).
drop_path: Stochastic depth probability.
"""
super().__init__()
out_chs = out_chs or in_chs
dilation = to_ntuple(2)(dilation)
act_layer = get_act_layer(act_layer)
if not norm_layer:
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
self.use_conv_mlp = conv_mlp
self.conv_dw = create_conv2d(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
in_chs,
out_chs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation[0],
depthwise=True,
bias=conv_bias,
)
self.norm = norm_layer(out_chs)
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
else:
self.shortcut = nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut
x = self.drop_path(x) + self.shortcut(shortcut)
return x
@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
self.downsample = nn.Sequential(
norm_layer(in_chs),
create_conv2d(
in_chs, out_chs, kernel_size=ds_ks, stride=stride,
dilation=dilation[0], padding=pad, bias=conv_bias),
in_chs,
out_chs,
kernel_size=ds_ks,
stride=stride,
dilation=dilation[0],
padding=pad,
bias=conv_bias,
),
)
in_chs = out_chs
else:
@ -773,136 +825,147 @@ default_cfgs = generate_default_cfgs({
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
'convnext_xxlarge.clip_laion2b_soup': _cfg(
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
'convnext_xxlarge.clip_laion2b_rewind': _cfg(
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
})
@register_model
def convnext_atto(pretrained=False, **kwargs):
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
model_args = dict(
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_atto_ols(pretrained=False, **kwargs):
# timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
model_args = dict(
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_femto(pretrained=False, **kwargs):
# timm femto variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_femto_ols(pretrained=False, **kwargs):
# timm femto variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_pico(pretrained=False, **kwargs):
# timm pico variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_pico_ols(pretrained=False, **kwargs):
# timm nano variant with overlapping 3x3 conv stem
model_args = dict(
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_nano(pretrained=False, **kwargs):
# timm nano variant with standard stem and head
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_nano_ols(pretrained=False, **kwargs):
# experimental nano variant with overlapping conv stem
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_tiny_hnf(pretrained=False, **kwargs):
# experimental tiny variant with norm before pooling in head (head norm first)
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_tiny(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_large_mlp(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_xlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnext_xxlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -910,8 +973,8 @@ def convnext_xxlarge(pretrained=False, **kwargs):
def convnextv2_atto(pretrained=False, **kwargs):
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
model_args = dict(
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args)
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -919,8 +982,8 @@ def convnextv2_atto(pretrained=False, **kwargs):
def convnextv2_femto(pretrained=False, **kwargs):
# timm femto variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args)
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -928,8 +991,8 @@ def convnextv2_femto(pretrained=False, **kwargs):
def convnextv2_pico(pretrained=False, **kwargs):
# timm pico variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args)
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -937,42 +1000,41 @@ def convnextv2_pico(pretrained=False, **kwargs):
def convnextv2_nano(pretrained=False, **kwargs):
# timm nano variant with standard stem and head
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args)
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnextv2_tiny(pretrained=False, **kwargs):
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnextv2_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnextv2_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnextv2_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args)
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def convnextv2_huge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args)
return model
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
return model

View File

@ -20,7 +20,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
self.patch_size = to_2tuple(patch_size)
self.patch_area = self.patch_size[0] * self.patch_size[1]
self.coreml_exportable = is_exportable()
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
@ -580,7 +581,10 @@ class MobileVitV2Block(nn.Module):
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
C = x.shape[1]
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
if self.coreml_exportable:
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
else:
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
x = x.reshape(B, C, -1, num_patches)
# Global representations
@ -588,8 +592,14 @@ class MobileVitV2Block(nn.Module):
x = self.norm(x)
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
if self.coreml_exportable:
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
x = F.pixel_shuffle(x, upscale_factor=patch_h)
else:
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
x = self.conv_proj(x)
return x

View File

@ -1 +1 @@
__version__ = '0.8.12dev0'
__version__ = '0.8.15dev0'

View File

@ -514,8 +514,14 @@ def main():
if utils.is_primary(args):
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
elif use_amp == 'native':
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
if device.type == 'cuda':
try:
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
except (AttributeError, TypeError):
# fallback to CUDA only AMP for PyTorch < 1.10
assert device.type == 'cuda'
amp_autocast = torch.cuda.amp.autocast
if device.type == 'cuda' and amp_dtype == torch.float16:
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it
loss_scaler = NativeScaler()
if utils.is_primary(args):
_logger.info('Using native Torch AMP. Training in mixed precision.')