mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge remote-tracking branch 'origin/main' into focalnet_and_swin_refactor
This commit is contained in:
commit
c30a160d3e
7
.github/workflows/tests.yml
vendored
7
.github/workflows/tests.yml
vendored
@ -19,6 +19,7 @@ jobs:
|
||||
python: ['3.10']
|
||||
torch: ['1.13.0']
|
||||
torchvision: ['0.14.0']
|
||||
testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
@ -30,7 +31,7 @@ jobs:
|
||||
- name: Install testing dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
|
||||
pip install -r requirements-dev.txt
|
||||
- name: Install torch on mac
|
||||
if: startsWith(matrix.os, 'macOS')
|
||||
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
|
||||
@ -54,10 +55,10 @@ jobs:
|
||||
PYTHONDONTWRITEBYTECODE: 1
|
||||
run: |
|
||||
pytest -vv tests
|
||||
- name: Run tests on Linux / Mac
|
||||
- name: Run '${{ matrix.testmarker }}' tests on Linux / Mac
|
||||
if: ${{ !startsWith(matrix.os, 'windows') }}
|
||||
env:
|
||||
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
|
||||
PYTHONDONTWRITEBYTECODE: 1
|
||||
run: |
|
||||
pytest -vv --forked --durations=0 tests
|
||||
pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests
|
||||
|
107
CONTRIBUTING.md
Normal file
107
CONTRIBUTING.md
Normal file
@ -0,0 +1,107 @@
|
||||
*This guideline is very much a work-in-progress.*
|
||||
|
||||
Contriubtions to `timm` for code, documentation, tests are more than welcome!
|
||||
|
||||
There haven't been any formal guidelines to date so please bear with me, and feel free to add to this guide.
|
||||
|
||||
# Coding style
|
||||
|
||||
Code linting and auto-format (black) are not currently in place but open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.
|
||||
|
||||
A few specific differences from Google style (or black)
|
||||
1. Line length is 120 char. Going over is okay in some cases (e.g. I prefer not to break URL across lines).
|
||||
2. Hanging indents are always prefered, please avoid aligning arguments with closing brackets or braces.
|
||||
|
||||
Example, from Google guide, but this is a NO here:
|
||||
```
|
||||
# Aligned with opening delimiter.
|
||||
foo = long_function_name(var_one, var_two,
|
||||
var_three, var_four)
|
||||
meal = (spam,
|
||||
beans)
|
||||
|
||||
# Aligned with opening delimiter in a dictionary.
|
||||
foo = {
|
||||
'long_dictionary_key': value1 +
|
||||
value2,
|
||||
...
|
||||
}
|
||||
```
|
||||
This is YES:
|
||||
|
||||
```
|
||||
# 4-space hanging indent; nothing on first line,
|
||||
# closing parenthesis on a new line.
|
||||
foo = long_function_name(
|
||||
var_one, var_two, var_three,
|
||||
var_four
|
||||
)
|
||||
meal = (
|
||||
spam,
|
||||
beans,
|
||||
)
|
||||
|
||||
# 4-space hanging indent in a dictionary.
|
||||
foo = {
|
||||
'long_dictionary_key':
|
||||
long_dictionary_value,
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
When there is descrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider current goal), please follow the style in a given file.
|
||||
|
||||
In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:
|
||||
|
||||
```
|
||||
black --skip-string-normalization --line-length 120 <path-to-file>
|
||||
```
|
||||
|
||||
Avoid formatting code that is unrelated to your PR though.
|
||||
|
||||
PR with pure formatting / style fixes will be accepted but only in isolation from functional changes, best to ask before starting such a change.
|
||||
|
||||
# Documentation
|
||||
|
||||
As with code style, docstrings style based on the Google guide: guide: https://google.github.io/styleguide/pyguide.html
|
||||
|
||||
The goal for the code is to eventually move to have all major functions and `__init__` methods use PEP484 type annotations.
|
||||
|
||||
When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings, please leave annotations as the one source of truth re typing.
|
||||
|
||||
There are a LOT of gaps in current documentation relative to the functionality in timm, please, document away!
|
||||
|
||||
# Installation
|
||||
|
||||
Create a Python virtual environment using Python 3.10. Inside the environment, install torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).
|
||||
|
||||
Then install the remaining dependencies:
|
||||
|
||||
```
|
||||
python -m pip install -r requirements.txt
|
||||
python -m pip install -r requirements-dev.txt # for testing
|
||||
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
|
||||
python -m pip install -e .
|
||||
```
|
||||
|
||||
## Unit tests
|
||||
|
||||
Run the tests using:
|
||||
|
||||
```
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:
|
||||
|
||||
```
|
||||
pytest -k "substring-to-match" -n 4 tests/
|
||||
```
|
||||
|
||||
## Building documentation
|
||||
|
||||
Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).
|
||||
|
||||
# Questions
|
||||
|
||||
If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).
|
45
README.md
45
README.md
@ -24,6 +24,15 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
|
||||
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
|
||||
|
||||
### Feb 26, 2023
|
||||
* Add ConvNeXt-XXLarge CLIP pretrained image tower weights for fine-tune & features (fine-tuning TBD) -- see [model card](https://huggingface.co/laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup)
|
||||
* Update `convnext_xxlarge` default LayerNorm eps to 1e-5 (for CLIP weights, improved stability)
|
||||
* 0.8.15dev0
|
||||
|
||||
### Feb 20, 2023
|
||||
* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_lage_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
|
||||
* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
|
||||
|
||||
### Feb 16, 2023
|
||||
* `safetensor` checkpoint support added
|
||||
* Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
|
||||
@ -112,7 +121,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
|
||||
* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
|
||||
* `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
|
||||
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
|
||||
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
|
||||
|
||||
### Jan 5, 2023
|
||||
* ConvNeXt-V2 models and weights added to existing `convnext.py`
|
||||
@ -142,7 +151,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
|
||||
|
||||
### Dec 6, 2022
|
||||
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
|
||||
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
|
||||
* original source: https://github.com/baaivision/EVA
|
||||
* paper: https://arxiv.org/abs/2211.07636
|
||||
|
||||
@ -237,7 +246,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* `maxxvit_rmlp_small_rw_256` - 84.6 @ 256, 84.9 @ 288 (G) -- could be trained better, hparams need tuning (uses ConvNeXt block, no BN)
|
||||
* `coatnet_rmlp_2_rw_224` - 84.6 @ 224, 85 @ 320 (T)
|
||||
* NOTE: official MaxVit weights (in1k) have been released at https://github.com/google-research/maxvit -- some extra work is needed to port and adapt since my impl was created independently of theirs and has a few small differences + the whole TF same padding fun.
|
||||
|
||||
|
||||
### Sept 23, 2022
|
||||
* LAION-2B CLIP image towers supported as pretrained backbones for fine-tune or features (no classifier)
|
||||
* vit_base_patch32_224_clip_laion2b
|
||||
@ -268,7 +277,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* `coatnet_bn_0_rw_224` - 82.4 (T)
|
||||
* `maxvit_nano_rw_256` - 82.9 @ 256 (T)
|
||||
* `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T)
|
||||
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
|
||||
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
|
||||
* (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained
|
||||
* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes)
|
||||
* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit)
|
||||
@ -283,7 +292,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288
|
||||
|
||||
### Aug 5, 2022
|
||||
* More custom ConvNeXt smaller model defs with weights
|
||||
* More custom ConvNeXt smaller model defs with weights
|
||||
* `convnext_femto` - 77.5 @ 224, 78.7 @ 288
|
||||
* `convnext_femto_ols` - 77.9 @ 224, 78.9 @ 288
|
||||
* `convnext_pico` - 79.5 @ 224, 80.4 @ 288
|
||||
@ -304,7 +313,7 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
|
||||
* `cs3sedarknet_x` - 82.2 @ 256, 82.7 @ 288
|
||||
* `cs3edgenet_x` - 82.2 @ 256, 82.7 @ 288
|
||||
* `cs3se_edgenet_x` - 82.8 @ 256, 83.5 @ 320
|
||||
* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program!
|
||||
* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program!
|
||||
* Add output_stride=8 and 16 support to ConvNeXt (dilation)
|
||||
* deit3 models not being able to resize pos_emb fixed
|
||||
* Version 0.6.7 PyPi release (/w above bug fixes and new weighs since 0.6.5)
|
||||
@ -337,8 +346,8 @@ More models, more fixes
|
||||
* Hugging Face Hub support fixes verified, demo notebook TBA
|
||||
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
|
||||
* Add support to change image extensions scanned by `timm` datasets/readers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
|
||||
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
|
||||
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
|
||||
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
|
||||
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
|
||||
* previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
|
||||
* Numerous bug fixes
|
||||
* Currently testing for imminent PyPi 0.6.x release
|
||||
@ -435,9 +444,7 @@ The work of many others is present here. I've tried to make sure all source mate
|
||||
|
||||
## Models
|
||||
|
||||
All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started.
|
||||
|
||||
A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
|
||||
All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated.
|
||||
|
||||
* Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
|
||||
* BEiT - https://arxiv.org/abs/2106.08254
|
||||
@ -538,15 +545,15 @@ Several (less common) features that I often utilize in my projects are included.
|
||||
|
||||
* All models have a common default configuration interface and API for
|
||||
* accessing/changing the classifier - `get_classifier` and `reset_classifier`
|
||||
* doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
|
||||
* doing a forward pass on just the features - `forward_features` (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
|
||||
* these makes it easy to write consistent network wrappers that work with any of the models
|
||||
* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
|
||||
* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
|
||||
* `create_model(name, features_only=True, out_indices=..., output_stride=...)`
|
||||
* `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
|
||||
* `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
|
||||
* feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
|
||||
* All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
|
||||
* High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes:
|
||||
* High performance [reference training, validation, and inference scripts](https://huggingface.co/docs/timm/training_script) that work in several process/GPU modes:
|
||||
* NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
|
||||
* PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
|
||||
* PyTorch w/ single GPU single process (AMP optional)
|
||||
@ -573,7 +580,7 @@ Several (less common) features that I often utilize in my projects are included.
|
||||
* AutoAugment (https://arxiv.org/abs/1805.09501) and RandAugment (https://arxiv.org/abs/1909.13719) ImageNet configurations modeled after impl for EfficientNet training (https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py)
|
||||
* AugMix w/ JSD loss (https://arxiv.org/abs/1912.02781), JSD w/ clean + augmented mixing support works with AutoAugment and RandAugment as well
|
||||
* SplitBachNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
|
||||
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
|
||||
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
|
||||
* DropBlock (https://arxiv.org/abs/1810.12890)
|
||||
* Blur Pooling (https://arxiv.org/abs/1904.11486)
|
||||
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
|
||||
@ -600,19 +607,17 @@ Model validation results can be found in the [results tables](results/README.md)
|
||||
|
||||
## Getting Started (Documentation)
|
||||
|
||||
My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
|
||||
|
||||
Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) will be the documentation focus going forward and will eventually replace the `github.io` docs above.
|
||||
The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
|
||||
|
||||
[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
|
||||
|
||||
[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
|
||||
[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
|
||||
|
||||
[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.
|
||||
|
||||
## Train, Validation, Inference Scripts
|
||||
|
||||
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://rwightman.github.io/pytorch-image-models/scripts/) for some basics and [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) for some train examples that produce SOTA ImageNet results.
|
||||
The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://huggingface.co/docs/timm/training_script).
|
||||
|
||||
## Awesome PyTorch Resources
|
||||
|
||||
|
@ -17,10 +17,14 @@ import os
|
||||
import glob
|
||||
import hashlib
|
||||
from timm.models import load_state_dict
|
||||
import safetensors.torch
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
DEFAULT_OUTPUT = "./average.pth"
|
||||
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
|
||||
DEFAULT_OUTPUT = "./averaged.pth"
|
||||
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
|
||||
parser.add_argument('--input', default='', type=str, metavar='PATH',
|
||||
@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
|
||||
parser.add_argument('--safetensors', action='store_true',
|
||||
help='Save weights using safetensors instead of the default torch way (pickle).')
|
||||
|
||||
|
||||
def checkpoint_metric(checkpoint_path):
|
||||
if not checkpoint_path or not os.path.isfile(checkpoint_path):
|
||||
return {}
|
||||
@ -63,14 +68,20 @@ def main():
|
||||
if args.safetensors and args.output == DEFAULT_OUTPUT:
|
||||
# Default path changes if using safetensors
|
||||
args.output = DEFAULT_SAFE_OUTPUT
|
||||
if args.safetensors and not args.output.endswith(".safetensors"):
|
||||
|
||||
output, output_ext = os.path.splitext(args.output)
|
||||
if not output_ext:
|
||||
output_ext = ('.safetensors' if args.safetensors else '.pth')
|
||||
output = output + output_ext
|
||||
|
||||
if args.safetensors and not output_ext == ".safetensors":
|
||||
print(
|
||||
"Warning: saving weights as safetensors but output file extension is not "
|
||||
f"set to '.safetensors': {args.output}"
|
||||
)
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
if os.path.exists(output):
|
||||
print("Error: Output filename ({}) already exists.".format(output))
|
||||
exit(1)
|
||||
|
||||
pattern = args.input
|
||||
@ -87,22 +98,27 @@ def main():
|
||||
checkpoint_metrics.append((metric, c))
|
||||
checkpoint_metrics = list(sorted(checkpoint_metrics))
|
||||
checkpoint_metrics = checkpoint_metrics[-args.n:]
|
||||
print("Selected checkpoints:")
|
||||
[print(m, c) for m, c in checkpoint_metrics]
|
||||
if checkpoint_metrics:
|
||||
print("Selected checkpoints:")
|
||||
[print(m, c) for m, c in checkpoint_metrics]
|
||||
avg_checkpoints = [c for m, c in checkpoint_metrics]
|
||||
else:
|
||||
avg_checkpoints = checkpoints
|
||||
print("Selected checkpoints:")
|
||||
[print(c) for c in checkpoints]
|
||||
if avg_checkpoints:
|
||||
print("Selected checkpoints:")
|
||||
[print(c) for c in checkpoints]
|
||||
|
||||
if not avg_checkpoints:
|
||||
print('Error: No checkpoints found to average.')
|
||||
exit(1)
|
||||
|
||||
avg_state_dict = {}
|
||||
avg_counts = {}
|
||||
for c in avg_checkpoints:
|
||||
new_state_dict = load_state_dict(c, args.use_ema)
|
||||
if not new_state_dict:
|
||||
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
|
||||
print(f"Error: Checkpoint ({c}) doesn't exist")
|
||||
continue
|
||||
|
||||
for k, v in new_state_dict.items():
|
||||
if k not in avg_state_dict:
|
||||
avg_state_dict[k] = v.clone().to(dtype=torch.float64)
|
||||
@ -122,16 +138,14 @@ def main():
|
||||
final_state_dict[k] = v.to(dtype=torch.float32)
|
||||
|
||||
if args.safetensors:
|
||||
safetensors.torch.save_file(final_state_dict, args.output)
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
safetensors.torch.save_file(final_state_dict, output)
|
||||
else:
|
||||
try:
|
||||
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
|
||||
except:
|
||||
torch.save(final_state_dict, args.output)
|
||||
torch.save(final_state_dict, output)
|
||||
|
||||
with open(args.output, 'rb') as f:
|
||||
with open(output, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
|
||||
print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -11,9 +11,14 @@ import torch
|
||||
import argparse
|
||||
import os
|
||||
import hashlib
|
||||
import safetensors.torch
|
||||
import shutil
|
||||
import tempfile
|
||||
from timm.models import load_state_dict
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
|
||||
help='output path')
|
||||
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
|
||||
help='use ema version of weights if present')
|
||||
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
|
||||
help='no hash in output filename')
|
||||
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
|
||||
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
|
||||
parser.add_argument('--safetensors', action='store_true',
|
||||
help='Save weights using safetensors instead of the default torch way (pickle).')
|
||||
|
||||
_TEMP_NAME = './_checkpoint.pth'
|
||||
|
||||
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
@ -37,10 +42,24 @@ def main():
|
||||
print("Error: Output filename ({}) already exists.".format(args.output))
|
||||
exit(1)
|
||||
|
||||
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
|
||||
clean_checkpoint(
|
||||
args.checkpoint,
|
||||
args.output,
|
||||
not args.no_use_ema,
|
||||
args.no_hash,
|
||||
args.clean_aux_bn,
|
||||
safe_serialization=args.safetensors,
|
||||
)
|
||||
|
||||
|
||||
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
|
||||
def clean_checkpoint(
|
||||
checkpoint,
|
||||
output,
|
||||
use_ema=True,
|
||||
no_hash=False,
|
||||
clean_aux_bn=False,
|
||||
safe_serialization: bool=False,
|
||||
):
|
||||
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
|
||||
if checkpoint and os.path.isfile(checkpoint):
|
||||
print("=> Loading checkpoint '{}'".format(checkpoint))
|
||||
@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
|
||||
new_state_dict[name] = v
|
||||
print("=> Loaded state_dict from '{}'".format(checkpoint))
|
||||
|
||||
if safe_serialization:
|
||||
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
|
||||
else:
|
||||
try:
|
||||
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
|
||||
except:
|
||||
torch.save(new_state_dict, _TEMP_NAME)
|
||||
|
||||
with open(_TEMP_NAME, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
ext = ''
|
||||
if output:
|
||||
checkpoint_root, checkpoint_base = os.path.split(output)
|
||||
checkpoint_base = os.path.splitext(checkpoint_base)[0]
|
||||
checkpoint_base, ext = os.path.splitext(checkpoint_base)
|
||||
else:
|
||||
checkpoint_root = ''
|
||||
checkpoint_base = os.path.splitext(checkpoint)[0]
|
||||
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
|
||||
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
|
||||
checkpoint_base = os.path.split(checkpoint)[1]
|
||||
checkpoint_base = os.path.splitext(checkpoint_base)[0]
|
||||
|
||||
temp_filename = '__' + checkpoint_base
|
||||
if safe_serialization:
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
safetensors.torch.save_file(new_state_dict, temp_filename)
|
||||
else:
|
||||
torch.save(new_state_dict, temp_filename)
|
||||
|
||||
with open(temp_filename, 'rb') as f:
|
||||
sha_hash = hashlib.sha256(f.read()).hexdigest()
|
||||
|
||||
if ext:
|
||||
final_ext = ext
|
||||
else:
|
||||
final_ext = ('.safetensors' if safe_serialization else '.pth')
|
||||
|
||||
if no_hash:
|
||||
final_filename = checkpoint_base + final_ext
|
||||
else:
|
||||
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
|
||||
|
||||
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
|
||||
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
|
||||
return final_filename
|
||||
else:
|
||||
|
14
pyproject.toml
Normal file
14
pyproject.toml
Normal file
@ -0,0 +1,14 @@
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"base: marker for model tests using the basic setup",
|
||||
"cfg: marker for model tests checking the config",
|
||||
"torchscript: marker for model tests using torchscript",
|
||||
"features: marker for model tests checking feature extraction",
|
||||
"fxforward: marker for model tests using torch fx (only forward)",
|
||||
"fxbackward: marker for model tests using torch fx (only backward)",
|
||||
]
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
|
||||
skip-string-normalization = true
|
5
requirements-dev.txt
Normal file
5
requirements-dev.txt
Normal file
@ -0,0 +1,5 @@
|
||||
pytest
|
||||
pytest-timeout
|
||||
pytest-xdist
|
||||
pytest-forked
|
||||
expecttest
|
12
setup.py
12
setup.py
@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
|
||||
setup(
|
||||
name='timm',
|
||||
version=__version__,
|
||||
description='(Unofficial) PyTorch Image Models',
|
||||
description='PyTorch Image Models',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
url='https://github.com/rwightman/pytorch-image-models',
|
||||
url='https://github.com/huggingface/pytorch-image-models',
|
||||
author='Ross Wightman',
|
||||
author_email='hello@rwightman.com',
|
||||
author_email='ross@huggingface.co',
|
||||
classifiers=[
|
||||
# How mature is this project? Common values are
|
||||
# 3 - Alpha
|
||||
@ -29,11 +29,11 @@ setup(
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Programming Language :: Python :: 3.8',
|
||||
'Programming Language :: Python :: 3.9',
|
||||
'Programming Language :: Python :: 3.10',
|
||||
'Programming Language :: Python :: 3.11',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
@ -45,7 +45,7 @@ setup(
|
||||
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
|
||||
packages=find_packages(exclude=['convert', 'tests', 'results']),
|
||||
include_package_data=True,
|
||||
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
|
||||
python_requires='>=3.6',
|
||||
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
|
||||
python_requires='>=3.7',
|
||||
)
|
||||
|
||||
|
@ -1,3 +1,16 @@
|
||||
"""Run tests for all models
|
||||
|
||||
Tests that run on CI should have a specific marker, e.g. @pytest.mark.base. This
|
||||
marker is used to parallelize the CI runs, with one runner for each marker.
|
||||
|
||||
If new tests are added, ensure that they use one of the existing markers
|
||||
(documented in pyproject.toml > pytest > markers) or that a new marker is added
|
||||
for this set of tests. If using a new marker, adjust the test matrix in
|
||||
.github/workflows/tests.yml to run tests with this new marker, otherwise the
|
||||
tests will be skipped on CI.
|
||||
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import platform
|
||||
@ -83,6 +96,7 @@ def _get_input_size(model=None, model_name='', target=None):
|
||||
return input_size
|
||||
|
||||
|
||||
@pytest.mark.base
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@ -101,6 +115,7 @@ def test_model_forward(model_name, batch_size):
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
@pytest.mark.base
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
|
||||
@pytest.mark.parametrize('batch_size', [2])
|
||||
@ -128,6 +143,7 @@ def test_model_backward(model_name, batch_size):
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
@pytest.mark.cfg
|
||||
@pytest.mark.timeout(300)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@ -190,6 +206,7 @@ def test_model_default_cfgs(model_name, batch_size):
|
||||
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
|
||||
|
||||
|
||||
@pytest.mark.cfg
|
||||
@pytest.mark.timeout(300)
|
||||
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@ -274,6 +291,7 @@ EXCLUDE_JIT_FILTERS = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.torchscript
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize(
|
||||
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
|
||||
@ -303,6 +321,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
|
||||
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
|
||||
|
||||
|
||||
@pytest.mark.features
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@ -379,6 +398,7 @@ if not _IS_MAC:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.fxforward
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
@ -412,6 +432,7 @@ if not _IS_MAC:
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
||||
|
||||
|
||||
@pytest.mark.fxbackward
|
||||
@pytest.mark.timeout(120)
|
||||
@pytest.mark.parametrize('model_name', list_models(
|
||||
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
|
||||
|
@ -54,8 +54,7 @@ def _interpolation(kwargs):
|
||||
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
|
||||
if isinstance(interpolation, (list, tuple)):
|
||||
return random.choice(interpolation)
|
||||
else:
|
||||
return interpolation
|
||||
return interpolation
|
||||
|
||||
|
||||
def _check_args_tf(kwargs):
|
||||
@ -100,7 +99,7 @@ def rotate(img, degrees, **kwargs):
|
||||
_check_args_tf(kwargs)
|
||||
if _PIL_VER >= (5, 2):
|
||||
return img.rotate(degrees, **kwargs)
|
||||
elif _PIL_VER >= (5, 0):
|
||||
if _PIL_VER >= (5, 0):
|
||||
w, h = img.size
|
||||
post_trans = (0, 0)
|
||||
rotn_center = (w / 2.0, h / 2.0)
|
||||
@ -124,8 +123,7 @@ def rotate(img, degrees, **kwargs):
|
||||
matrix[2] += rotn_center[0]
|
||||
matrix[5] += rotn_center[1]
|
||||
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
|
||||
else:
|
||||
return img.rotate(degrees, resample=kwargs['resample'])
|
||||
return img.rotate(degrees, resample=kwargs['resample'])
|
||||
|
||||
|
||||
def auto_contrast(img, **__):
|
||||
@ -151,12 +149,13 @@ def solarize_add(img, add, thresh=128, **__):
|
||||
lut.append(min(255, i + add))
|
||||
else:
|
||||
lut.append(i)
|
||||
|
||||
if img.mode in ("L", "RGB"):
|
||||
if img.mode == "RGB" and len(lut) == 256:
|
||||
lut = lut + lut + lut
|
||||
return img.point(lut)
|
||||
else:
|
||||
return img
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def posterize(img, bits_to_keep, **__):
|
||||
@ -226,7 +225,7 @@ def _enhance_increasing_level_to_arg(level, _hparams):
|
||||
|
||||
def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
|
||||
level = (level / _LEVEL_DENOM)
|
||||
min_val + (max_val - min_val) * level
|
||||
level = min_val + (max_val - min_val) * level
|
||||
if clamp:
|
||||
level = max(min_val, min(max_val, level))
|
||||
return level,
|
||||
@ -552,16 +551,15 @@ def auto_augment_policy(name='v0', hparams=None):
|
||||
hparams = hparams or _HPARAMS_DEFAULT
|
||||
if name == 'original':
|
||||
return auto_augment_policy_original(hparams)
|
||||
elif name == 'originalr':
|
||||
if name == 'originalr':
|
||||
return auto_augment_policy_originalr(hparams)
|
||||
elif name == 'v0':
|
||||
if name == 'v0':
|
||||
return auto_augment_policy_v0(hparams)
|
||||
elif name == 'v0r':
|
||||
if name == 'v0r':
|
||||
return auto_augment_policy_v0r(hparams)
|
||||
elif name == '3a':
|
||||
if name == '3a':
|
||||
return auto_augment_policy_3a(hparams)
|
||||
else:
|
||||
assert False, 'Unknown AA policy (%s)' % name
|
||||
assert False, f'Unknown AA policy {name}'
|
||||
|
||||
|
||||
class AutoAugment:
|
||||
@ -576,7 +574,7 @@ class AutoAugment:
|
||||
return img
|
||||
|
||||
def __repr__(self):
|
||||
fs = self.__class__.__name__ + f'(policy='
|
||||
fs = self.__class__.__name__ + '(policy='
|
||||
for p in self.policy:
|
||||
fs += '\n\t['
|
||||
fs += ', '.join([str(op) for op in p])
|
||||
@ -636,7 +634,7 @@ _RAND_TRANSFORMS = [
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # NOTE I've implement this as random erasing separately
|
||||
# 'Cutout' # NOTE I've implement this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
@ -656,7 +654,7 @@ _RAND_INCREASING_TRANSFORMS = [
|
||||
'ShearY',
|
||||
'TranslateXRel',
|
||||
'TranslateYRel',
|
||||
#'Cutout' # NOTE I've implement this as random erasing separately
|
||||
# 'Cutout' # NOTE I've implement this as random erasing separately
|
||||
]
|
||||
|
||||
|
||||
@ -667,7 +665,7 @@ _RAND_3A = [
|
||||
]
|
||||
|
||||
|
||||
_RAND_CHOICE_3A = {
|
||||
_RAND_WEIGHTED_3A = {
|
||||
'SolarizeIncreasing': 6,
|
||||
'Desaturate': 6,
|
||||
'GaussianBlur': 6,
|
||||
@ -687,7 +685,7 @@ _RAND_CHOICE_3A = {
|
||||
|
||||
# These experimental weights are based loosely on the relative improvements mentioned in paper.
|
||||
# They may not result in increased performance, but could likely be tuned to so.
|
||||
_RAND_CHOICE_WEIGHTS_0 = {
|
||||
_RAND_WEIGHTED_0 = {
|
||||
'Rotate': 3,
|
||||
'ShearX': 2,
|
||||
'ShearY': 2,
|
||||
@ -715,13 +713,12 @@ def _get_weighted_transforms(transforms: Dict):
|
||||
|
||||
def rand_augment_choices(name: str, increasing=True):
|
||||
if name == 'weights':
|
||||
return _RAND_CHOICE_WEIGHTS_0
|
||||
elif name == '3aw':
|
||||
return _RAND_CHOICE_3A
|
||||
elif name == '3a':
|
||||
return _RAND_WEIGHTED_0
|
||||
if name == '3aw':
|
||||
return _RAND_WEIGHTED_3A
|
||||
if name == '3a':
|
||||
return _RAND_3A
|
||||
else:
|
||||
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
|
||||
return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS
|
||||
|
||||
|
||||
def rand_augment_ops(
|
||||
|
@ -7,7 +7,11 @@ import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import safetensors.torch
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
# Check if safetensors or not and load weights accordingly
|
||||
if str(checkpoint_path).endswith(".safetensors"):
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
|
@ -7,15 +7,21 @@ from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
import safetensors.torch
|
||||
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
try:
|
||||
import safetensors.torch
|
||||
_has_safetensors = True
|
||||
except ImportError:
|
||||
_has_safetensors = False
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from typing import Literal
|
||||
else:
|
||||
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
|
||||
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
|
||||
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
|
||||
|
||||
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
Returns the location of the directory where models are cached (and creates it if necessary).
|
||||
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
|
||||
# Look for .safetensors alternatives and load from it if it exists
|
||||
for safe_filename in _get_safe_alternatives(filename):
|
||||
try:
|
||||
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
|
||||
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
|
||||
return safetensors.torch.load_file(cached_safe_file, device="cpu")
|
||||
except EntryNotFoundError:
|
||||
pass
|
||||
if _has_safetensors:
|
||||
for safe_filename in _get_safe_alternatives(filename):
|
||||
try:
|
||||
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
|
||||
_logger.info(
|
||||
f"[{model_id}] Safe alternative available for '{filename}' "
|
||||
f"(as '{safe_filename}'). Loading weights using safetensors.")
|
||||
return safetensors.torch.load_file(cached_safe_file, device="cpu")
|
||||
except EntryNotFoundError:
|
||||
pass
|
||||
|
||||
# Otherwise, load using pytorch.load
|
||||
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
||||
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
||||
return torch.load(cached_file, map_location='cpu')
|
||||
|
||||
|
||||
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
|
||||
def save_config_for_hf(
|
||||
model,
|
||||
config_path: str,
|
||||
model_config: Optional[dict] = None
|
||||
):
|
||||
model_config = model_config or {}
|
||||
hf_config = {}
|
||||
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
|
||||
@ -220,8 +234,8 @@ def save_for_hf(
|
||||
model,
|
||||
save_directory: str,
|
||||
model_config: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False
|
||||
):
|
||||
safe_serialization: Union[bool, Literal["both"]] = False,
|
||||
):
|
||||
assert has_hf_hub(True)
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
@ -229,6 +243,7 @@ def save_for_hf(
|
||||
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
|
||||
tensors = model.state_dict()
|
||||
if safe_serialization is True or safe_serialization == "both":
|
||||
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
||||
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
|
||||
if safe_serialization is False or safe_serialization == "both":
|
||||
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
|
||||
@ -238,16 +253,16 @@ def save_for_hf(
|
||||
|
||||
|
||||
def push_to_hf_hub(
|
||||
model,
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
model_card: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False
|
||||
model,
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
model_card: Optional[dict] = None,
|
||||
safe_serialization: Union[bool, Literal["both"]] = False,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
|
||||
readme_text += f"```bibtex\n{c}\n```\n"
|
||||
return readme_text
|
||||
|
||||
|
||||
def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
||||
"""Returns potential safetensors alternatives for a given filename.
|
||||
|
||||
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
||||
"""
|
||||
if filename == HF_WEIGHTS_NAME:
|
||||
yield HF_SAFE_WEIGHTS_NAME
|
||||
if filename.endswith(".bin"):
|
||||
yield filename[:-4] + ".safetensors"
|
||||
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
|
||||
return filename[:-4] + ".safetensors"
|
||||
|
@ -93,7 +93,7 @@ class DefaultCfg:
|
||||
return tag, self.cfgs[tag]
|
||||
|
||||
|
||||
def split_model_name_tag(model_name: str, no_tag=''):
|
||||
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
|
||||
model_name, *tag_list = model_name.split('.', 1)
|
||||
tag = tag_list[0] if tag_list else no_tag
|
||||
return model_name, tag
|
||||
|
@ -8,7 +8,7 @@ import sys
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from typing import List, Optional, Union, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
|
||||
@ -16,20 +16,20 @@ __all__ = [
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
|
||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module = {} # mapping of model names to module names
|
||||
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
|
||||
def get_arch_name(model_name: str) -> str:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||
# lookup containing module
|
||||
mod = sys.modules[fn.__module__]
|
||||
module_name_split = fn.__module__.split('.')
|
||||
@ -40,7 +40,7 @@ def register_model(fn):
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(model_name)
|
||||
else:
|
||||
mod.__all__ = [model_name]
|
||||
mod.__all__ = [model_name] # type: ignore
|
||||
|
||||
# add entries to registry dict/sets
|
||||
_model_entrypoints[model_name] = fn
|
||||
@ -87,28 +87,33 @@ def register_model(fn):
|
||||
return fn
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
def _natural_key(string_: str) -> List[Union[int, str]]:
|
||||
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(
|
||||
filter: Union[str, List[str]] = '',
|
||||
module: str = '',
|
||||
pretrained=False,
|
||||
exclude_filters: str = '',
|
||||
pretrained: bool = False,
|
||||
exclude_filters: Union[str, List[str]] = '',
|
||||
name_matches_cfg: bool = False,
|
||||
include_tags: Optional[bool] = None,
|
||||
):
|
||||
) -> List[str]:
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||
pretrained (bool) - Include only models with valid pretrained weights if True
|
||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
|
||||
filter - Wildcard filter string that works with fnmatch
|
||||
module - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||
pretrained - Include only models with valid pretrained weights if True
|
||||
exclude_filters - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
|
||||
set to True when pretrained=True else False (default: None)
|
||||
|
||||
Returns:
|
||||
models - The sorted list of models
|
||||
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
@ -118,7 +123,7 @@ def list_models(
|
||||
include_tags = pretrained
|
||||
|
||||
if module:
|
||||
all_models = list(_module_to_models[module])
|
||||
all_models: Iterable[str] = list(_module_to_models[module])
|
||||
else:
|
||||
all_models = _model_entrypoints.keys()
|
||||
|
||||
@ -130,14 +135,14 @@ def list_models(
|
||||
all_models = models_with_tags
|
||||
|
||||
if filter:
|
||||
models = []
|
||||
models: Set[str] = set()
|
||||
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
||||
for f in include_filters:
|
||||
include_models = fnmatch.filter(all_models, f) # include these models
|
||||
if len(include_models):
|
||||
models = set(models).union(include_models)
|
||||
models = models.union(include_models)
|
||||
else:
|
||||
models = all_models
|
||||
models = set(all_models)
|
||||
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
@ -145,7 +150,7 @@ def list_models(
|
||||
for xf in exclude_filters:
|
||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||
if len(exclude_models):
|
||||
models = set(models).difference(exclude_models)
|
||||
models = models.difference(exclude_models)
|
||||
|
||||
if pretrained:
|
||||
models = _model_has_pretrained.intersection(models)
|
||||
@ -153,13 +158,13 @@ def list_models(
|
||||
if name_matches_cfg:
|
||||
models = set(_model_pretrained_cfgs).intersection(models)
|
||||
|
||||
return list(sorted(models, key=_natural_key))
|
||||
return sorted(models, key=_natural_key)
|
||||
|
||||
|
||||
def list_pretrained(
|
||||
filter: Union[str, List[str]] = '',
|
||||
exclude_filters: str = '',
|
||||
):
|
||||
) -> List[str]:
|
||||
return list_models(
|
||||
filter=filter,
|
||||
pretrained=True,
|
||||
@ -168,14 +173,14 @@ def list_pretrained(
|
||||
)
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
def is_model(model_name: str) -> bool:
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
return arch_name in _model_entrypoints
|
||||
|
||||
|
||||
def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
||||
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
|
||||
"""Fetch a model entrypoint for specified model name
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
||||
return _model_entrypoints[arch_name]
|
||||
|
||||
|
||||
def list_modules():
|
||||
def list_modules() -> List[str]:
|
||||
""" Return list of module names that contain models / model entrypoints
|
||||
"""
|
||||
modules = _module_to_models.keys()
|
||||
return list(sorted(modules))
|
||||
return sorted(modules)
|
||||
|
||||
|
||||
def is_model_in_modules(model_name, module_names):
|
||||
def is_model_in_modules(
|
||||
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
|
||||
) -> bool:
|
||||
"""Check if a model exists within a subset of modules
|
||||
|
||||
Args:
|
||||
model_name (str) - name of model to check
|
||||
module_names (tuple, list, set) - names of modules to search in
|
||||
model_name - name of model to check
|
||||
module_names - names of modules to search in
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(arch_name in _module_to_models[n] for n in module_names)
|
||||
|
||||
|
||||
def is_model_pretrained(model_name):
|
||||
def is_model_pretrained(model_name: str) -> bool:
|
||||
return model_name in _model_has_pretrained
|
||||
|
||||
|
||||
def get_pretrained_cfg(model_name, allow_unregistered=True):
|
||||
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
|
||||
if model_name in _model_pretrained_cfgs:
|
||||
return deepcopy(_model_pretrained_cfgs[model_name])
|
||||
arch_name, tag = split_model_name_tag(model_name)
|
||||
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
|
||||
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
|
||||
|
||||
|
||||
def get_pretrained_cfg_value(model_name, cfg_key):
|
||||
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
|
||||
""" Get a specific model default_cfg value by key. None if key doesn't exist.
|
||||
"""
|
||||
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
|
||||
|
@ -45,7 +45,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
|
||||
from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
|
||||
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
|
||||
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
||||
from ._builder import build_model_with_cfg
|
||||
@ -56,6 +56,28 @@ from ._registry import register_model
|
||||
__all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
|
||||
super().__init__()
|
||||
avg_stride = stride if dilation == 1 else 1
|
||||
if stride > 1 or dilation > 1:
|
||||
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
||||
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
|
||||
else:
|
||||
self.pool = nn.Identity()
|
||||
|
||||
if in_chs != out_chs:
|
||||
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
|
||||
else:
|
||||
self.conv = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool(x)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvNeXtBlock(nn.Module):
|
||||
""" ConvNeXt Block
|
||||
There are two equivalent implementations:
|
||||
@ -65,41 +87,65 @@ class ConvNeXtBlock(nn.Module):
|
||||
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
||||
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
||||
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
||||
|
||||
Args:
|
||||
in_chs (int): Number of input channels.
|
||||
drop_path (float): Stochastic depth rate. Default: 0.0
|
||||
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_chs,
|
||||
out_chs=None,
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
mlp_ratio=4,
|
||||
conv_mlp=False,
|
||||
conv_bias=True,
|
||||
use_grn=False,
|
||||
ls_init_value=1e-6,
|
||||
act_layer='gelu',
|
||||
norm_layer=None,
|
||||
drop_path=0.,
|
||||
in_chs: int,
|
||||
out_chs: Optional[int] = None,
|
||||
kernel_size: int = 7,
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = (1, 1),
|
||||
mlp_ratio: float = 4,
|
||||
conv_mlp: bool = False,
|
||||
conv_bias: bool = True,
|
||||
use_grn: bool = False,
|
||||
ls_init_value: Optional[float] = 1e-6,
|
||||
act_layer: Union[str, Callable] = 'gelu',
|
||||
norm_layer: Optional[Callable] = None,
|
||||
drop_path: float = 0.,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
in_chs: Block input channels.
|
||||
out_chs: Block output channels (same as in_chs if None).
|
||||
kernel_size: Depthwise convolution kernel size.
|
||||
stride: Stride of depthwise convolution.
|
||||
dilation: Tuple specifying input and output dilation of block.
|
||||
mlp_ratio: MLP expansion ratio.
|
||||
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
|
||||
conv_bias: Apply bias for all convolution (linear) layers.
|
||||
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
|
||||
ls_init_value: Layer-scale init values, layer-scale applied if not None.
|
||||
act_layer: Activation layer.
|
||||
norm_layer: Normalization layer (defaults to LN if not specified).
|
||||
drop_path: Stochastic depth probability.
|
||||
"""
|
||||
super().__init__()
|
||||
out_chs = out_chs or in_chs
|
||||
dilation = to_ntuple(2)(dilation)
|
||||
act_layer = get_act_layer(act_layer)
|
||||
if not norm_layer:
|
||||
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
||||
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
|
||||
self.use_conv_mlp = conv_mlp
|
||||
self.conv_dw = create_conv2d(
|
||||
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation[0],
|
||||
depthwise=True,
|
||||
bias=conv_bias,
|
||||
)
|
||||
self.norm = norm_layer(out_chs)
|
||||
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
||||
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
|
||||
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
||||
self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
|
||||
else:
|
||||
self.shortcut = nn.Identity()
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
|
||||
if self.gamma is not None:
|
||||
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
||||
|
||||
x = self.drop_path(x) + shortcut
|
||||
x = self.drop_path(x) + self.shortcut(shortcut)
|
||||
return x
|
||||
|
||||
|
||||
@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
|
||||
self.downsample = nn.Sequential(
|
||||
norm_layer(in_chs),
|
||||
create_conv2d(
|
||||
in_chs, out_chs, kernel_size=ds_ks, stride=stride,
|
||||
dilation=dilation[0], padding=pad, bias=conv_bias),
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=ds_ks,
|
||||
stride=stride,
|
||||
dilation=dilation[0],
|
||||
padding=pad,
|
||||
bias=conv_bias,
|
||||
),
|
||||
)
|
||||
in_chs = out_chs
|
||||
else:
|
||||
@ -773,136 +825,147 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
|
||||
'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
|
||||
'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
|
||||
'convnext_xxlarge.clip_laion2b_soup': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
|
||||
'convnext_xxlarge.clip_laion2b_rewind': _cfg(
|
||||
hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
|
||||
hf_hub_filename='open_clip_pytorch_model.bin',
|
||||
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
|
||||
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_atto(pretrained=False, **kwargs):
|
||||
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
|
||||
model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_atto_ols(pretrained=False, **kwargs):
|
||||
# timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
|
||||
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
|
||||
model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_femto(pretrained=False, **kwargs):
|
||||
# timm femto variant
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
|
||||
model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_femto_ols(pretrained=False, **kwargs):
|
||||
# timm femto variant
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
|
||||
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
|
||||
model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_pico(pretrained=False, **kwargs):
|
||||
# timm pico variant
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
|
||||
model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_pico_ols(pretrained=False, **kwargs):
|
||||
# timm nano variant with overlapping 3x3 conv stem
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
|
||||
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
|
||||
model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_nano(pretrained=False, **kwargs):
|
||||
# timm nano variant with standard stem and head
|
||||
model_args = dict(
|
||||
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
|
||||
model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_nano_ols(pretrained=False, **kwargs):
|
||||
# experimental nano variant with overlapping conv stem
|
||||
model_args = dict(
|
||||
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
|
||||
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
|
||||
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_tiny_hnf(pretrained=False, **kwargs):
|
||||
# experimental tiny variant with norm before pooling in head (head norm first)
|
||||
model_args = dict(
|
||||
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
|
||||
model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_tiny(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
|
||||
model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
|
||||
model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_small(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
|
||||
model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
|
||||
model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_base(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
|
||||
model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
|
||||
model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_large(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
|
||||
model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
|
||||
model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_large_mlp(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
|
||||
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
|
||||
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_xlarge(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
|
||||
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
|
||||
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnext_xxlarge(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
|
||||
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
|
||||
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@ -910,8 +973,8 @@ def convnext_xxlarge(pretrained=False, **kwargs):
|
||||
def convnextv2_atto(pretrained=False, **kwargs):
|
||||
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args)
|
||||
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
|
||||
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@ -919,8 +982,8 @@ def convnextv2_atto(pretrained=False, **kwargs):
|
||||
def convnextv2_femto(pretrained=False, **kwargs):
|
||||
# timm femto variant
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args)
|
||||
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
|
||||
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@ -928,8 +991,8 @@ def convnextv2_femto(pretrained=False, **kwargs):
|
||||
def convnextv2_pico(pretrained=False, **kwargs):
|
||||
# timm pico variant
|
||||
model_args = dict(
|
||||
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args)
|
||||
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
|
||||
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@ -937,42 +1000,41 @@ def convnextv2_pico(pretrained=False, **kwargs):
|
||||
def convnextv2_nano(pretrained=False, **kwargs):
|
||||
# timm nano variant with standard stem and head
|
||||
model_args = dict(
|
||||
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
|
||||
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args)
|
||||
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
|
||||
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnextv2_tiny(pretrained=False, **kwargs):
|
||||
model_args = dict(
|
||||
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs)
|
||||
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnextv2_small(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs)
|
||||
model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnextv2_base(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs)
|
||||
model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnextv2_large(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs)
|
||||
model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args)
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def convnextv2_huge(pretrained=False, **kwargs):
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs)
|
||||
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args)
|
||||
return model
|
||||
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
|
||||
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
@ -20,7 +20,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
|
||||
from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
|
||||
from ._builder import build_model_with_cfg
|
||||
from ._features_fx import register_notrace_module
|
||||
from ._registry import register_model
|
||||
@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):
|
||||
|
||||
self.patch_size = to_2tuple(patch_size)
|
||||
self.patch_area = self.patch_size[0] * self.patch_size[1]
|
||||
self.coreml_exportable = is_exportable()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
@ -580,7 +581,10 @@ class MobileVitV2Block(nn.Module):
|
||||
|
||||
# Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
|
||||
C = x.shape[1]
|
||||
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
|
||||
if self.coreml_exportable:
|
||||
x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
|
||||
else:
|
||||
x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
|
||||
x = x.reshape(B, C, -1, num_patches)
|
||||
|
||||
# Global representations
|
||||
@ -588,8 +592,14 @@ class MobileVitV2Block(nn.Module):
|
||||
x = self.norm(x)
|
||||
|
||||
# Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
|
||||
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
|
||||
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
if self.coreml_exportable:
|
||||
# adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
|
||||
x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
|
||||
x = F.pixel_shuffle(x, upscale_factor=patch_h)
|
||||
else:
|
||||
x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
|
||||
x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
|
||||
|
||||
|
||||
x = self.conv_proj(x)
|
||||
return x
|
||||
|
@ -1 +1 @@
|
||||
__version__ = '0.8.12dev0'
|
||||
__version__ = '0.8.15dev0'
|
||||
|
10
train.py
10
train.py
@ -514,8 +514,14 @@ def main():
|
||||
if utils.is_primary(args):
|
||||
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
|
||||
elif use_amp == 'native':
|
||||
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
||||
if device.type == 'cuda':
|
||||
try:
|
||||
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
|
||||
except (AttributeError, TypeError):
|
||||
# fallback to CUDA only AMP for PyTorch < 1.10
|
||||
assert device.type == 'cuda'
|
||||
amp_autocast = torch.cuda.amp.autocast
|
||||
if device.type == 'cuda' and amp_dtype == torch.float16:
|
||||
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it
|
||||
loss_scaler = NativeScaler()
|
||||
if utils.is_primary(args):
|
||||
_logger.info('Using native Torch AMP. Training in mixed precision.')
|
||||
|
Loading…
x
Reference in New Issue
Block a user