Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

Merge remote-tracking branch 'origin/main' into focalnet_and_swin_refactor

This commit is contained in:
commit c30a160d3e

7  .github/workflows/tests.yml  (vendored)

@@ -19,6 +19,7 @@ jobs:
         python: ['3.10']
         torch: ['1.13.0']
         torchvision: ['0.14.0']
+        testmarker: ['-k "not test_models"', '-m base', '-m cfg', '-m torchscript', '-m features', '-m fxforward', '-m fxbackward']
     runs-on: ${{ matrix.os }}

     steps:
@@ -30,7 +31,7 @@ jobs:
     - name: Install testing dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
+        pip install -r requirements-dev.txt
     - name: Install torch on mac
       if: startsWith(matrix.os, 'macOS')
       run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
@@ -54,10 +55,10 @@ jobs:
         PYTHONDONTWRITEBYTECODE: 1
       run: |
        pytest -vv tests
-    - name: Run tests on Linux / Mac
+    - name: Run '${{ matrix.testmarker }}' tests on Linux / Mac
      if: ${{ !startsWith(matrix.os, 'windows') }}
      env:
        LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
        PYTHONDONTWRITEBYTECODE: 1
      run: |
-        pytest -vv --forked --durations=0 tests
+        pytest -vv --forked --durations=0 ${{ matrix.testmarker }} tests

107  CONTRIBUTING.md  (new file)

@@ -0,0 +1,107 @@
*This guideline is very much a work-in-progress.*

Contributions to `timm` for code, documentation, and tests are more than welcome!

There haven't been any formal guidelines to date so please bear with me, and feel free to add to this guide.

# Coding style

Code linting and auto-format (black) are not currently in place but open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.

A few specific differences from Google style (or black):

1. Line length is 120 char. Going over is okay in some cases (e.g. I prefer not to break URLs across lines).
2. Hanging indents are always preferred, please avoid aligning arguments with closing brackets or braces.

Example from the Google guide, but this is a NO here:

```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
                         var_three, var_four)
meal = (spam,
        beans)

# Aligned with opening delimiter in a dictionary.
foo = {
    'long_dictionary_key': value1 +
                           value2,
    ...
}
```

This is YES:

```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
    var_one, var_two, var_three,
    var_four
)
meal = (
    spam,
    beans,
)

# 4-space hanging indent in a dictionary.
foo = {
    'long_dictionary_key':
        long_dictionary_value,
    ...
}
```

When there is discrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider the current goal), please follow the style in that file.

In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:

```
black --skip-string-normalization --line-length 120 <path-to-file>
```

Avoid formatting code that is unrelated to your PR though.

PRs with pure formatting / style fixes will be accepted, but only in isolation from functional changes; it's best to ask before starting such a change.

# Documentation

As with code style, docstring style is based on the Google guide: https://google.github.io/styleguide/pyguide.html

The goal for the code is to eventually move to have all major functions and `__init__` methods use PEP484 type annotations.

When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings; please leave annotations as the one source of truth re typing.

There are a LOT of gaps in current documentation relative to the functionality in timm, please, document away!

# Installation

Create a Python virtual environment using Python 3.10. Inside the environment, install `torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).

Then install the remaining dependencies:

```
python -m pip install -r requirements.txt
python -m pip install -r requirements-dev.txt  # for testing
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```

## Unit tests

Run the tests using:

```
pytest tests/
```

Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:

```
pytest -k "substring-to-match" -n 4 tests/
```

## Building documentation

Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).

# Questions

If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

27  README.md

@@ -24,6 +24,15 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 * ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
 * Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
+
+### Feb 26, 2023
+* Add ConvNeXt-XXLarge CLIP pretrained image tower weights for fine-tune & features (fine-tuning TBD) -- see [model card](https://huggingface.co/laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup)
+* Update `convnext_xxlarge` default LayerNorm eps to 1e-5 (for CLIP weights, improved stability)
+* 0.8.15dev0
+
+### Feb 20, 2023
+* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_lage_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
+* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
+
 ### Feb 16, 2023
 * `safetensor` checkpoint support added
 * Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block

@@ -435,9 +444,7 @@ The work of many others is present here. I've tried to make sure all source mate

 ## Models

-All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated. Here are some example [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) to get you started.
+All model architecture families include variants with pretrained weights. There are specific model variants without any weights, it is NOT a bug. Help training new or better weights is always appreciated.

-A full version of the list below with source links can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/models/).
-
 * Aggregating Nested Transformers - https://arxiv.org/abs/2105.12723
 * BEiT - https://arxiv.org/abs/2106.08254

@@ -538,15 +545,15 @@ Several (less common) features that I often utilize in my projects are included.

 * All models have a common default configuration interface and API for
   * accessing/changing the classifier - `get_classifier` and `reset_classifier`
-  * doing a forward pass on just the features - `forward_features` (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
+  * doing a forward pass on just the features - `forward_features` (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
   * these makes it easy to write consistent network wrappers that work with any of the models
-* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://rwightman.github.io/pytorch-image-models/feature_extraction/))
+* All models support multi-scale feature map extraction (feature pyramids) via create_model (see [documentation](https://huggingface.co/docs/timm/feature_extraction))
   * `create_model(name, features_only=True, out_indices=..., output_stride=...)`
   * `out_indices` creation arg specifies which feature maps to return, these indices are 0 based and generally correspond to the `C(i + 1)` feature level.
   * `output_stride` creation arg controls output stride of the network by using dilated convolutions. Most networks are stride 32 by default. Not all networks support this.
   * feature map channel counts, reduction level (stride) can be queried AFTER model creation via the `.feature_info` member
 * All models have a consistent pretrained weight loader that adapts last linear if necessary, and from 3 to 1 channel input if desired
-* High performance [reference training, validation, and inference scripts](https://rwightman.github.io/pytorch-image-models/scripts/) that work in several process/GPU modes:
+* High performance [reference training, validation, and inference scripts](https://huggingface.co/docs/timm/training_script) that work in several process/GPU modes:
   * NVIDIA DDP w/ a single GPU per process, multiple processes with APEX present (AMP mixed-precision optional)
   * PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
   * PyTorch w/ single GPU single process (AMP optional)

@@ -600,19 +607,17 @@ Model validation results can be found in the [results tables](results/README.md)

 ## Getting Started (Documentation)

-My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
+The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.

-Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) will be the documentation focus going forward and will eventually replace the `github.io` docs above.
-
 [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.

-[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
+[timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.

 [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.

 ## Train, Validation, Inference Scripts

-The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://rwightman.github.io/pytorch-image-models/scripts/) for some basics and [training hparams](https://rwightman.github.io/pytorch-image-models/training_hparam_examples) for some train examples that produce SOTA ImageNet results.
+The root folder of the repository contains reference train, validation, and inference scripts that work with the included models and other features of this repository. They are adaptable for other datasets and use cases with a little hacking. See [documentation](https://huggingface.co/docs/timm/training_script).

 ## Awesome PyTorch Resources

@@ -17,10 +17,14 @@ import os
 import glob
 import hashlib
 from timm.models import load_state_dict
-import safetensors.torch
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

-DEFAULT_OUTPUT = "./average.pth"
-DEFAULT_SAFE_OUTPUT = "./average.safetensors"
+DEFAULT_OUTPUT = "./averaged.pth"
+DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
 parser.add_argument('--input', default='', type=str, metavar='PATH',
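
The guarded import introduced above (try/except around `safetensors` with a `_has_safetensors` flag) is the pattern this commit repeats across several files. A minimal sketch of how such a guard is typically used; the `save_weights` helper name is illustrative and not part of the diff:

```
import torch

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    # safetensors stays an optional dependency; only .safetensors I/O needs it
    _has_safetensors = False


def save_weights(state_dict, path, use_safetensors=False):
    # Dispatch on the requested format, failing loudly if the optional
    # dependency is missing rather than at the deferred attribute access.
    if use_safetensors:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(state_dict, path)
    else:
        torch.save(state_dict, path)
```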

@@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
 parser.add_argument('--safetensors', action='store_true',
                     help='Save weights using safetensors instead of the default torch way (pickle).')


 def checkpoint_metric(checkpoint_path):
     if not checkpoint_path or not os.path.isfile(checkpoint_path):
         return {}

@@ -63,14 +68,20 @@ def main():
     if args.safetensors and args.output == DEFAULT_OUTPUT:
         # Default path changes if using safetensors
         args.output = DEFAULT_SAFE_OUTPUT
-    if args.safetensors and not args.output.endswith(".safetensors"):
+
+    output, output_ext = os.path.splitext(args.output)
+    if not output_ext:
+        output_ext = ('.safetensors' if args.safetensors else '.pth')
+    output = output + output_ext
+
+    if args.safetensors and not output_ext == ".safetensors":
         print(
             "Warning: saving weights as safetensors but output file extension is not "
             f"set to '.safetensors': {args.output}"
         )

-    if os.path.exists(args.output):
-        print("Error: Output filename ({}) already exists.".format(args.output))
+    if os.path.exists(output):
+        print("Error: Output filename ({}) already exists.".format(output))
         exit(1)

     pattern = args.input

@@ -87,22 +98,27 @@ def main():
             checkpoint_metrics.append((metric, c))
         checkpoint_metrics = list(sorted(checkpoint_metrics))
         checkpoint_metrics = checkpoint_metrics[-args.n:]
+        if checkpoint_metrics:
             print("Selected checkpoints:")
             [print(m, c) for m, c in checkpoint_metrics]
         avg_checkpoints = [c for m, c in checkpoint_metrics]
     else:
         avg_checkpoints = checkpoints
+        if avg_checkpoints:
             print("Selected checkpoints:")
             [print(c) for c in checkpoints]

+    if not avg_checkpoints:
+        print('Error: No checkpoints found to average.')
+        exit(1)

     avg_state_dict = {}
     avg_counts = {}
     for c in avg_checkpoints:
         new_state_dict = load_state_dict(c, args.use_ema)
         if not new_state_dict:
-            print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
+            print(f"Error: Checkpoint ({c}) doesn't exist")
             continue

         for k, v in new_state_dict.items():
             if k not in avg_state_dict:
                 avg_state_dict[k] = v.clone().to(dtype=torch.float64)

@@ -122,16 +138,14 @@ def main():
         final_state_dict[k] = v.to(dtype=torch.float32)

     if args.safetensors:
-        safetensors.torch.save_file(final_state_dict, args.output)
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+        safetensors.torch.save_file(final_state_dict, output)
     else:
-        try:
-            torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
-        except:
-            torch.save(final_state_dict, args.output)
+        torch.save(final_state_dict, output)

-    with open(args.output, 'rb') as f:
+    with open(output, 'rb') as f:
         sha_hash = hashlib.sha256(f.read()).hexdigest()
-    print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
+    print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")


 if __name__ == '__main__':
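
The context lines above show the accumulation scheme the averager relies on: a float64 running sum per tensor, cast back to float32 before saving. A simplified, self-contained sketch of that idea (function and variable names here are illustrative, not the script's):

```
import torch


def average_state_dicts(state_dicts):
    # Average a list of state_dicts key-by-key, accumulating in float64 to
    # limit rounding error, then cast back to float32 for saving.
    avg, counts = {}, {}
    for sd in state_dicts:
        for k, v in sd.items():
            if k not in avg:
                avg[k] = v.clone().to(dtype=torch.float64)
                counts[k] = 1
            else:
                avg[k] += v.to(dtype=torch.float64)
                counts[k] += 1
    return {k: (t / counts[k]).to(dtype=torch.float32) for k, t in avg.items()}
```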

@@ -11,9 +11,14 @@ import torch
 import argparse
 import os
 import hashlib
-import safetensors.torch
 import shutil
+import tempfile
 from timm.models import load_state_dict
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',

@@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='output path')
 parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='use ema version of weights if present')
+parser.add_argument('--no-hash', dest='no_hash', action='store_true',
+                    help='no hash in output filename')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                     help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
 parser.add_argument('--safetensors', action='store_true',
                     help='Save weights using safetensors instead of the default torch way (pickle).')

-_TEMP_NAME = './_checkpoint.pth'


 def main():
     args = parser.parse_args()

@@ -37,10 +42,24 @@ def main():
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)

-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
+    clean_checkpoint(
+        args.checkpoint,
+        args.output,
+        not args.no_use_ema,
+        args.no_hash,
+        args.clean_aux_bn,
+        safe_serialization=args.safetensors,
+    )


-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
+def clean_checkpoint(
+        checkpoint,
+        output,
+        use_ema=True,
+        no_hash=False,
+        clean_aux_bn=False,
+        safe_serialization: bool=False,
+):
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if checkpoint and os.path.isfile(checkpoint):
         print("=> Loading checkpoint '{}'".format(checkpoint))

@@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(checkpoint))

-        if safe_serialization:
-            safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
-        else:
-            try:
-                torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
-            except:
-                torch.save(new_state_dict, _TEMP_NAME)
-
-        with open(_TEMP_NAME, 'rb') as f:
-            sha_hash = hashlib.sha256(f.read()).hexdigest()
-
+        ext = ''
         if output:
             checkpoint_root, checkpoint_base = os.path.split(output)
-            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+            checkpoint_base, ext = os.path.splitext(checkpoint_base)
         else:
             checkpoint_root = ''
-            checkpoint_base = os.path.splitext(checkpoint)[0]
-        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
-        shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
+            checkpoint_base = os.path.split(checkpoint)[1]
+            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+
+        temp_filename = '__' + checkpoint_base
+        if safe_serialization:
+            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+            safetensors.torch.save_file(new_state_dict, temp_filename)
+        else:
+            torch.save(new_state_dict, temp_filename)
+
+        with open(temp_filename, 'rb') as f:
+            sha_hash = hashlib.sha256(f.read()).hexdigest()
+
+        if ext:
+            final_ext = ext
+        else:
+            final_ext = ('.safetensors' if safe_serialization else '.pth')
+
+        if no_hash:
+            final_filename = checkpoint_base + final_ext
+        else:
+            final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
+
+        shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
         return final_filename
     else:
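
To make the new naming rule concrete: the cleaned checkpoint keeps any extension given in `--output` (falling back to `.safetensors` or `.pth`), and gets the first 8 hex characters of the file's SHA-256 appended unless `--no-hash` is set. A small illustrative helper capturing just that rule (not code from the script itself):

```
import hashlib


def final_checkpoint_name(base, ext, payload, no_hash, safe):
    # Keep a caller-supplied extension, otherwise pick one from the
    # serialization format, and append a short SHA-256 unless disabled.
    final_ext = ext or ('.safetensors' if safe else '.pth')
    if no_hash:
        return base + final_ext
    sha_hash = hashlib.sha256(payload).hexdigest()
    return '-'.join([base, sha_hash[:8]]) + final_ext


# e.g. final_checkpoint_name('model_best', '', b'...bytes...', no_hash=False, safe=True)
# -> something like 'model_best-1a2b3c4d.safetensors'
```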

14  pyproject.toml  (new file)

@@ -0,0 +1,14 @@
[tool.pytest.ini_options]
markers = [
    "base: marker for model tests using the basic setup",
    "cfg: marker for model tests checking the config",
    "torchscript: marker for model tests using torchscript",
    "features: marker for model tests checking feature extraction",
    "fxforward: marker for model tests using torch fx (only forward)",
    "fxbackward: marker for model tests using torch fx (only backward)",
]

[tool.black]
line-length = 120
target-version = ['py37', 'py38', 'py39', 'py310', 'py311']
skip-string-normalization = true

5  requirements-dev.txt  (new file)

@@ -0,0 +1,5 @@
pytest
pytest-timeout
pytest-xdist
pytest-forked
expecttest

12  setup.py

@@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
 setup(
     name='timm',
     version=__version__,
-    description='(Unofficial) PyTorch Image Models',
+    description='PyTorch Image Models',
     long_description=long_description,
     long_description_content_type='text/markdown',
-    url='https://github.com/rwightman/pytorch-image-models',
+    url='https://github.com/huggingface/pytorch-image-models',
     author='Ross Wightman',
-    author_email='hello@rwightman.com',
+    author_email='ross@huggingface.co',
     classifiers=[
         # How mature is this project? Common values are
         # 3 - Alpha
@@ -29,11 +29,11 @@ setup(
         'Intended Audience :: Education',
         'Intended Audience :: Science/Research',
         'License :: OSI Approved :: Apache Software License',
-        'Programming Language :: Python :: 3.6',
         'Programming Language :: Python :: 3.7',
         'Programming Language :: Python :: 3.8',
         'Programming Language :: Python :: 3.9',
         'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
         'Topic :: Scientific/Engineering',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
         'Topic :: Software Development',
@@ -45,7 +45,7 @@ setup(
     keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
     packages=find_packages(exclude=['convert', 'tests', 'results']),
     include_package_data=True,
-    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
+    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
-    python_requires='>=3.6',
+    python_requires='>=3.7',
 )

@@ -1,3 +1,16 @@
+"""Run tests for all models
+
+Tests that run on CI should have a specific marker, e.g. @pytest.mark.base. This
+marker is used to parallelize the CI runs, with one runner for each marker.
+
+If new tests are added, ensure that they use one of the existing markers
+(documented in pyproject.toml > pytest > markers) or that a new marker is added
+for this set of tests. If using a new marker, adjust the test matrix in
+.github/workflows/tests.yml to run tests with this new marker, otherwise the
+tests will be skipped on CI.
+"""
+
 import pytest
 import torch
 import platform
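
The docstring above describes the marker convention; as a concrete, hypothetical illustration (not part of the commit), a new test would be tagged with one of the markers registered in pyproject.toml so the CI matrix entry for that marker picks it up:

```
import pytest
import torch

import timm


@pytest.mark.base          # must be one of the markers declared in pyproject.toml
@pytest.mark.timeout(120)
def test_tiny_model_forward():
    # Illustrative test: build a small model and check the forward pass
    # produces a finite output of the expected shape.
    model = timm.create_model('resnet18', pretrained=False, num_classes=10)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 10)
    assert not torch.isnan(out).any()
```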

@@ -83,6 +96,7 @@ def _get_input_size(model=None, model_name='', target=None):
     return input_size


+@pytest.mark.base
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])

@@ -101,6 +115,7 @@ def test_model_forward(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'


+@pytest.mark.base
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])

@@ -128,6 +143,7 @@ def test_model_backward(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'


+@pytest.mark.cfg
 @pytest.mark.timeout(300)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])

@@ -190,6 +206,7 @@ def test_model_default_cfgs(model_name, batch_size):
     assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'


+@pytest.mark.cfg
 @pytest.mark.timeout(300)
 @pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])

@@ -274,6 +291,7 @@ EXCLUDE_JIT_FILTERS = [
 ]


+@pytest.mark.torchscript
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize(
     'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))

@@ -303,6 +321,7 @@ if 'GITHUB_ACTIONS' in os.environ:  # and 'Linux' in platform.system():
     EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']


+@pytest.mark.features
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
 @pytest.mark.parametrize('batch_size', [1])

@@ -379,6 +398,7 @@ if not _IS_MAC:
     ]


+@pytest.mark.fxforward
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS))
 @pytest.mark.parametrize('batch_size', [1])

@@ -412,6 +432,7 @@ if not _IS_MAC:
     assert not torch.isnan(outputs).any(), 'Output included NaNs'


+@pytest.mark.fxbackward
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(
     exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))

@@ -54,7 +54,6 @@ def _interpolation(kwargs):
     interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
     if isinstance(interpolation, (list, tuple)):
         return random.choice(interpolation)
-    else:
     return interpolation


@@ -100,7 +99,7 @@ def rotate(img, degrees, **kwargs):
     _check_args_tf(kwargs)
     if _PIL_VER >= (5, 2):
         return img.rotate(degrees, **kwargs)
-    elif _PIL_VER >= (5, 0):
+    if _PIL_VER >= (5, 0):
         w, h = img.size
         post_trans = (0, 0)
         rotn_center = (w / 2.0, h / 2.0)

@@ -124,7 +123,6 @@ def rotate(img, degrees, **kwargs):
         matrix[2] += rotn_center[0]
         matrix[5] += rotn_center[1]
         return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
-    else:
     return img.rotate(degrees, resample=kwargs['resample'])


@@ -151,11 +149,12 @@ def solarize_add(img, add, thresh=128, **__):
            lut.append(min(255, i + add))
        else:
            lut.append(i)

    if img.mode in ("L", "RGB"):
        if img.mode == "RGB" and len(lut) == 256:
            lut = lut + lut + lut
        return img.point(lut)
-    else:
    return img


@@ -226,7 +225,7 @@ def _enhance_increasing_level_to_arg(level, _hparams):

 def _minmax_level_to_arg(level, _hparams, min_val=0., max_val=1.0, clamp=True):
     level = (level / _LEVEL_DENOM)
-    min_val + (max_val - min_val) * level
+    level = min_val + (max_val - min_val) * level
     if clamp:
         level = max(min_val, min(max_val, level))
     return level,
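
The one-line fix above matters: the old statement computed `min_val + (max_val - min_val) * level` and discarded the result, so the scaled value was never assigned. A tiny standalone check of the corrected mapping, with the denominator inlined as an assumed stand-in for the module's `_LEVEL_DENOM`:

```
_LEVEL_DENOM = 10.  # stand-in for the module-level constant


def minmax_level_to_arg(level, min_val=0., max_val=1.0, clamp=True):
    # scale a raw magnitude in [0, _LEVEL_DENOM] into [min_val, max_val]
    level = level / _LEVEL_DENOM
    level = min_val + (max_val - min_val) * level   # the line the commit fixes
    if clamp:
        level = max(min_val, min(max_val, level))
    return level


# e.g. with min_val=0.1, max_val=0.9: level 5 -> 0.1 + 0.8 * 0.5 = 0.5
assert abs(minmax_level_to_arg(5, min_val=0.1, max_val=0.9) - 0.5) < 1e-9
```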

@@ -552,16 +551,15 @@ def auto_augment_policy(name='v0', hparams=None):
     hparams = hparams or _HPARAMS_DEFAULT
     if name == 'original':
         return auto_augment_policy_original(hparams)
-    elif name == 'originalr':
+    if name == 'originalr':
         return auto_augment_policy_originalr(hparams)
-    elif name == 'v0':
+    if name == 'v0':
         return auto_augment_policy_v0(hparams)
-    elif name == 'v0r':
+    if name == 'v0r':
         return auto_augment_policy_v0r(hparams)
-    elif name == '3a':
+    if name == '3a':
         return auto_augment_policy_3a(hparams)
-    else:
-        assert False, 'Unknown AA policy (%s)' % name
+    assert False, f'Unknown AA policy {name}'


 class AutoAugment:

@@ -576,7 +574,7 @@ class AutoAugment:
         return img

     def __repr__(self):
-        fs = self.__class__.__name__ + f'(policy='
+        fs = self.__class__.__name__ + '(policy='
         for p in self.policy:
             fs += '\n\t['
             fs += ', '.join([str(op) for op in p])

@@ -636,7 +634,7 @@ _RAND_TRANSFORMS = [
     'ShearY',
     'TranslateXRel',
     'TranslateYRel',
-    #'Cutout'  # NOTE I've implement this as random erasing separately
+    # 'Cutout'  # NOTE I've implement this as random erasing separately
 ]


@@ -656,7 +654,7 @@ _RAND_INCREASING_TRANSFORMS = [
     'ShearY',
     'TranslateXRel',
     'TranslateYRel',
-    #'Cutout'  # NOTE I've implement this as random erasing separately
+    # 'Cutout'  # NOTE I've implement this as random erasing separately
 ]


@@ -667,7 +665,7 @@ _RAND_3A = [
 ]


-_RAND_CHOICE_3A = {
+_RAND_WEIGHTED_3A = {
     'SolarizeIncreasing': 6,
     'Desaturate': 6,
     'GaussianBlur': 6,

@@ -687,7 +685,7 @@ _RAND_CHOICE_3A = {

 # These experimental weights are based loosely on the relative improvements mentioned in paper.
 # They may not result in increased performance, but could likely be tuned to so.
-_RAND_CHOICE_WEIGHTS_0 = {
+_RAND_WEIGHTED_0 = {
     'Rotate': 3,
     'ShearX': 2,
     'ShearY': 2,

@@ -715,12 +713,11 @@ def _get_weighted_transforms(transforms: Dict):

 def rand_augment_choices(name: str, increasing=True):
     if name == 'weights':
-        return _RAND_CHOICE_WEIGHTS_0
+        return _RAND_WEIGHTED_0
-    elif name == '3aw':
+    if name == '3aw':
-        return _RAND_CHOICE_3A
+        return _RAND_WEIGHTED_3A
-    elif name == '3a':
+    if name == '3a':
         return _RAND_3A
-    else:
     return _RAND_INCREASING_TRANSFORMS if increasing else _RAND_TRANSFORMS

@@ -7,7 +7,11 @@ import os
 from collections import OrderedDict

 import torch
-import safetensors.torch
+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

 import timm.models._builder

@@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         # Check if safetensors or not and load weights accordingly
         if str(checkpoint_path).endswith(".safetensors"):
+            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
             checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
         else:
             checkpoint = torch.load(checkpoint_path, map_location='cpu')
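
With this change, the loader picks the backend from the file extension and only requires `safetensors` when a `.safetensors` path is actually passed. A hedged usage sketch (the checkpoint file names are placeholders, not files from the repo):

```
from timm.models import load_state_dict

# Plain torch pickle checkpoint -- works without the safetensors package.
state_dict = load_state_dict('checkpoint.pth', use_ema=True)

# Safetensors checkpoint -- hits the new assert unless `safetensors` is installed.
safe_state_dict = load_state_dict('checkpoint.safetensors')

print(len(state_dict), len(safe_state_dict))
```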

@@ -7,15 +7,21 @@ from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
 from typing import Iterable, Optional, Union

 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
-import safetensors.torch

 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False
+
 if sys.version_info >= (3, 8):
     from typing import Literal
 else:

@@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
 HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
 HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version


 def get_cache_dir(child_dir=''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).

@@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     hf_model_id, hf_revision = hf_split(model_id)

     # Look for .safetensors alternatives and load from it if it exists
+    if _has_safetensors:
         for safe_filename in _get_safe_alternatives(filename):
             try:
                 cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
-                _logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
+                _logger.info(
+                    f"[{model_id}] Safe alternative available for '{filename}' "
+                    f"(as '{safe_filename}'). Loading weights using safetensors.")
                 return safetensors.torch.load_file(cached_safe_file, device="cpu")
             except EntryNotFoundError:
                 pass

     # Otherwise, load using pytorch.load
     cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
-    _logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
     return torch.load(cached_file, map_location='cpu')


-def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
+def save_config_for_hf(
+        model,
+        config_path: str,
+        model_config: Optional[dict] = None
+):
     model_config = model_config or {}
     hf_config = {}
     pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)

@@ -220,8 +234,8 @@ def save_for_hf(
         model,
         save_directory: str,
         model_config: Optional[dict] = None,
-        safe_serialization: Union[bool, Literal["both"]] = False
+        safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)

@@ -229,6 +243,7 @@ def save_for_hf(
     # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
     tensors = model.state_dict()
     if safe_serialization is True or safe_serialization == "both":
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
         safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
     if safe_serialization is False or safe_serialization == "both":
         torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
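
A hedged usage sketch of the `safe_serialization` switch on `save_for_hf` (per the signature above): `True` writes only model.safetensors, `False` only pytorch_model.bin, and `"both"` writes the two side by side. The model and output directories are placeholders, and the import is from the internal `timm.models._hub` module this diff touches (requires `huggingface_hub` installed; the public entry point for publishing is `push_to_hf_hub`):

```
import timm
from timm.models._hub import save_for_hf  # internal module shown in this diff

model = timm.create_model('resnet18', pretrained=False)

# Writes model.safetensors only (needs the safetensors package).
save_for_hf(model, './hf_export_safe', safe_serialization=True)

# Writes pytorch_model.bin and model.safetensors next to each other.
save_for_hf(model, './hf_export_both', safe_serialization="both")
```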

@@ -247,7 +262,7 @@ def push_to_hf_hub(
         create_pr: bool = False,
         model_config: Optional[dict] = None,
         model_card: Optional[dict] = None,
-        safe_serialization: Union[bool, Literal["both"]] = False
+        safe_serialization: Union[bool, Literal["both"]] = False,
 ):
     """
     Arguments:

@@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
         readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text


 def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """Returns potential safetensors alternatives for a given filename.

@@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    if filename.endswith(".bin"):
+    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
-        yield filename[:-4] + ".safetensors"
+        return filename[:-4] + ".safetensors"
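
Taken together, `load_state_dict_from_hf` now resolves weights in a fixed order: try the safetensors alternatives produced by `_get_safe_alternatives` (only when the package is importable), then fall back to the original `.bin` file. A simplified sketch of that resolution order, with a hypothetical `fetch`-free setup where `available` stands in for what `hf_hub_download` would find on the Hub:

```
HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"


def safe_alternatives(filename):
    # Roughly mirrors _get_safe_alternatives: prefer the canonical safetensors
    # name, otherwise swap a .bin suffix for .safetensors.
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"


def resolve_weights(filename, available, has_safetensors=True):
    # Return the filename that would be downloaded first.
    if has_safetensors:
        for candidate in safe_alternatives(filename):
            if candidate in available:
                return candidate
    return filename  # fall back to the original torch pickle


print(resolve_weights(HF_WEIGHTS_NAME, available={"model.safetensors"}))   # model.safetensors
print(resolve_weights(HF_WEIGHTS_NAME, available=set()))                   # pytorch_model.bin
```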

@@ -93,7 +93,7 @@ class DefaultCfg:
         return tag, self.cfgs[tag]


-def split_model_name_tag(model_name: str, no_tag=''):
+def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
     model_name, *tag_list = model_name.split('.', 1)
     tag = tag_list[0] if tag_list else no_tag
     return model_name, tag
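
The newly annotated function splits a `model.tag` style name at the first dot; an illustrative standalone run (the model/tag string is just an example):

```
def split_model_name_tag(model_name, no_tag=''):
    model_name, *tag_list = model_name.split('.', 1)
    tag = tag_list[0] if tag_list else no_tag
    return model_name, tag


print(split_model_name_tag('vit_base_patch16_224.augreg_in21k'))
# ('vit_base_patch16_224', 'augreg_in21k')
print(split_model_name_tag('resnet50'))
# ('resnet50', '')
```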
|
@ -8,7 +8,7 @@ import sys
|
|||||||
from collections import defaultdict, deque
|
from collections import defaultdict, deque
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
from typing import List, Optional, Union, Tuple
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||||
|
|
||||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||||
|
|
||||||
@ -16,20 +16,20 @@ __all__ = [
|
|||||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||||
|
|
||||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||||
_model_to_module = {} # mapping of model names to module names
|
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||||
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
|
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
|
||||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
|
||||||
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
|
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||||
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
|
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||||
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||||
|
|
||||||
|
|
||||||
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
|
def get_arch_name(model_name: str) -> str:
|
||||||
return split_model_name_tag(model_name)[0]
|
return split_model_name_tag(model_name)[0]
|
||||||
|
|
||||||
|
|
||||||
def register_model(fn):
|
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
|
||||||
# lookup containing module
|
# lookup containing module
|
||||||
mod = sys.modules[fn.__module__]
|
mod = sys.modules[fn.__module__]
|
||||||
module_name_split = fn.__module__.split('.')
|
module_name_split = fn.__module__.split('.')
|
||||||
@ -40,7 +40,7 @@ def register_model(fn):
|
|||||||
if hasattr(mod, '__all__'):
|
if hasattr(mod, '__all__'):
|
||||||
mod.__all__.append(model_name)
|
mod.__all__.append(model_name)
|
||||||
else:
|
else:
|
||||||
mod.__all__ = [model_name]
|
mod.__all__ = [model_name] # type: ignore
|
||||||
|
|
||||||
# add entries to registry dict/sets
|
# add entries to registry dict/sets
|
||||||
_model_entrypoints[model_name] = fn
|
_model_entrypoints[model_name] = fn
|
||||||
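For context, `register_model` simply records each entrypoint function by name so it can be resolved later by `model_entrypoint` / `create_model`. A stripped-down, self-contained sketch of that mechanism (illustrative only, not timm's actual code):

```python
from typing import Any, Callable, Dict

_entrypoints: Dict[str, Callable[..., Any]] = {}

def register(fn: Callable[..., Any]) -> Callable[..., Any]:
    # record the entrypoint under its function name, then return it unchanged
    _entrypoints[fn.__name__] = fn
    return fn

@register
def toy_net(pretrained: bool = False, **kwargs: Any) -> str:
    return f'toy_net(pretrained={pretrained}, kwargs={kwargs})'

# later, the name alone is enough to find and call the entrypoint
print(_entrypoints['toy_net'](pretrained=True))
```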
```diff
@@ -87,28 +87,33 @@ def register_model(fn):
     return fn


-def _natural_key(string_):
+def _natural_key(string_: str) -> List[Union[int, str]]:
+    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]

```
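`_natural_key` exists so numeric suffixes sort in human order rather than lexicographically; copying the one-liner from the hunk above makes the difference easy to see:

```python
import re
from typing import List, Union

def _natural_key(string_: str) -> List[Union[int, str]]:
    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]

names = ['resnet101', 'resnet18', 'resnet50', 'convnext_tiny']
print(sorted(names))                    # ['convnext_tiny', 'resnet101', 'resnet18', 'resnet50']
print(sorted(names, key=_natural_key))  # ['convnext_tiny', 'resnet18', 'resnet50', 'resnet101']
```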
```diff
 def list_models(
         filter: Union[str, List[str]] = '',
         module: str = '',
-        pretrained=False,
-        exclude_filters: str = '',
+        pretrained: bool = False,
+        exclude_filters: Union[str, List[str]] = '',
         name_matches_cfg: bool = False,
         include_tags: Optional[bool] = None,
-):
+) -> List[str]:
     """ Return list of available model names, sorted alphabetically

     Args:
-        filter (str) - Wildcard filter string that works with fnmatch
-        module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
-        pretrained (bool) - Include only models with valid pretrained weights if True
-        exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
-        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
-        include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
+        filter - Wildcard filter string that works with fnmatch
+        module - Limit model selection to a specific submodule (ie 'vision_transformer')
+        pretrained - Include only models with valid pretrained weights if True
+        exclude_filters - Wildcard filters to exclude models after including them with filter
+        name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
+        include_tags - Include pretrained tags in model names (model.tag). If None, defaults
             set to True when pretrained=True else False (default: None)

+    Returns:
+        models - The sorted list of models
+
     Example:
         model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
         model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
```
```diff
@@ -118,7 +123,7 @@ def list_models(
         include_tags = pretrained

     if module:
-        all_models = list(_module_to_models[module])
+        all_models: Iterable[str] = list(_module_to_models[module])
     else:
         all_models = _model_entrypoints.keys()

@@ -130,14 +135,14 @@ def list_models(
         all_models = models_with_tags

     if filter:
-        models = []
+        models: Set[str] = set()
         include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
         for f in include_filters:
             include_models = fnmatch.filter(all_models, f)  # include these models
             if len(include_models):
-                models = set(models).union(include_models)
+                models = models.union(include_models)
     else:
-        models = all_models
+        models = set(all_models)

     if exclude_filters:
         if not isinstance(exclude_filters, (tuple, list)):
@@ -145,7 +150,7 @@ def list_models(
         for xf in exclude_filters:
             exclude_models = fnmatch.filter(models, xf)  # exclude these models
             if len(exclude_models):
-                models = set(models).difference(exclude_models)
+                models = models.difference(exclude_models)

     if pretrained:
         models = _model_has_pretrained.intersection(models)
```
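With these changes `list_models` always returns a `List[str]`, and with `pretrained=True` the `arch.tag` names are included by default. Typical usage, assuming a recent `timm` install (exact results depend on the installed version):

```python
import timm

# all ConvNeXt architectures known to the registry
print(timm.list_models('convnext*')[:3])

# only entries with pretrained weights; pretrained tags are included by default now
print(timm.list_models('convnext*', pretrained=True)[:3])

# wildcard excludes are applied after the include filter
print(timm.list_models('convnext*', exclude_filters=['*_ols'])[:3])
```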
```diff
@@ -153,13 +158,13 @@ def list_models(
     if name_matches_cfg:
         models = set(_model_pretrained_cfgs).intersection(models)

-    return list(sorted(models, key=_natural_key))
+    return sorted(models, key=_natural_key)


 def list_pretrained(
         filter: Union[str, List[str]] = '',
         exclude_filters: str = '',
-):
+) -> List[str]:
     return list_models(
         filter=filter,
         pretrained=True,
@@ -168,14 +173,14 @@ def list_pretrained(
     )


-def is_model(model_name):
+def is_model(model_name: str) -> bool:
     """ Check if a model name exists
     """
     arch_name = get_arch_name(model_name)
     return arch_name in _model_entrypoints


-def model_entrypoint(model_name, module_filter: Optional[str] = None):
+def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
     """Fetch a model entrypoint for specified model name
     """
     arch_name = get_arch_name(model_name)
@@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
     return _model_entrypoints[arch_name]


-def list_modules():
+def list_modules() -> List[str]:
     """ Return list of module names that contain models / model entrypoints
     """
     modules = _module_to_models.keys()
-    return list(sorted(modules))
+    return sorted(modules)


-def is_model_in_modules(model_name, module_names):
+def is_model_in_modules(
+        model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
+) -> bool:
     """Check if a model exists within a subset of modules

     Args:
-        model_name (str) - name of model to check
-        module_names (tuple, list, set) - names of modules to search in
+        model_name - name of model to check
+        module_names - names of modules to search in
     """
     arch_name = get_arch_name(model_name)
     assert isinstance(module_names, (tuple, list, set))
     return any(arch_name in _module_to_models[n] for n in module_names)


-def is_model_pretrained(model_name):
+def is_model_pretrained(model_name: str) -> bool:
     return model_name in _model_has_pretrained


-def get_pretrained_cfg(model_name, allow_unregistered=True):
+def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
     if model_name in _model_pretrained_cfgs:
         return deepcopy(_model_pretrained_cfgs[model_name])
     arch_name, tag = split_model_name_tag(model_name)
@@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
     raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')


-def get_pretrained_cfg_value(model_name, cfg_key):
+def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
     """ Get a specific model default_cfg value by key. None if key doesn't exist.
     """
     cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
```
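Note that these lookup helpers key on the architecture portion of a name, so full `arch.tag` strings are accepted everywhere. A small sketch, assuming the helpers are re-exported from `timm.models` as in earlier releases:

```python
import timm
from timm.models import get_pretrained_cfg_value

print(timm.is_model('convnext_tiny.fb_in22k'))        # True - only the 'convnext_tiny' arch part is checked
print(timm.is_model('convnext_tiny.not_a_real_tag'))  # also True, for the same reason

# pretrained cfg lookups do use the full arch.tag key; a missing cfg_key returns None
print(get_pretrained_cfg_value('convnext_tiny.fb_in22k', 'input_size'))  # e.g. (3, 224, 224)
```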
```diff
@@ -45,7 +45,7 @@ import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
+from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
@@ -56,6 +56,28 @@ from ._registry import register_model
 __all__ = ['ConvNeXt']  # model_registry will add each entrypoint fn to this


+class Downsample(nn.Module):
+
+    def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+        super().__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+
+        if in_chs != out_chs:
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+        else:
+            self.conv = nn.Identity()
+
+    def forward(self, x):
+        x = self.pool(x)
+        x = self.conv(x)
+        return x
+
+
 class ConvNeXtBlock(nn.Module):
     """ ConvNeXt Block
     There are two equivalent implementations:
```
|
|||||||
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
||||||
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
||||||
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
||||||
|
|
||||||
Args:
|
|
||||||
in_chs (int): Number of input channels.
|
|
||||||
drop_path (float): Stochastic depth rate. Default: 0.0
|
|
||||||
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_chs,
|
in_chs: int,
|
||||||
out_chs=None,
|
out_chs: Optional[int] = None,
|
||||||
kernel_size=7,
|
kernel_size: int = 7,
|
||||||
stride=1,
|
stride: int = 1,
|
||||||
dilation=1,
|
dilation: Union[int, Tuple[int, int]] = (1, 1),
|
||||||
mlp_ratio=4,
|
mlp_ratio: float = 4,
|
||||||
conv_mlp=False,
|
conv_mlp: bool = False,
|
||||||
conv_bias=True,
|
conv_bias: bool = True,
|
||||||
use_grn=False,
|
use_grn: bool = False,
|
||||||
ls_init_value=1e-6,
|
ls_init_value: Optional[float] = 1e-6,
|
||||||
act_layer='gelu',
|
act_layer: Union[str, Callable] = 'gelu',
|
||||||
norm_layer=None,
|
norm_layer: Optional[Callable] = None,
|
||||||
drop_path=0.,
|
drop_path: float = 0.,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_chs: Block input channels.
|
||||||
|
out_chs: Block output channels (same as in_chs if None).
|
||||||
|
kernel_size: Depthwise convolution kernel size.
|
||||||
|
stride: Stride of depthwise convolution.
|
||||||
|
dilation: Tuple specifying input and output dilation of block.
|
||||||
|
mlp_ratio: MLP expansion ratio.
|
||||||
|
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
|
||||||
|
conv_bias: Apply bias for all convolution (linear) layers.
|
||||||
|
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
|
||||||
|
ls_init_value: Layer-scale init values, layer-scale applied if not None.
|
||||||
|
act_layer: Activation layer.
|
||||||
|
norm_layer: Normalization layer (defaults to LN if not specified).
|
||||||
|
drop_path: Stochastic depth probability.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
out_chs = out_chs or in_chs
|
out_chs = out_chs or in_chs
|
||||||
|
dilation = to_ntuple(2)(dilation)
|
||||||
act_layer = get_act_layer(act_layer)
|
act_layer = get_act_layer(act_layer)
|
||||||
if not norm_layer:
|
if not norm_layer:
|
||||||
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
||||||
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
|
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
|
||||||
self.use_conv_mlp = conv_mlp
|
self.use_conv_mlp = conv_mlp
|
||||||
self.conv_dw = create_conv2d(
|
self.conv_dw = create_conv2d(
|
||||||
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
|
in_chs,
|
||||||
|
out_chs,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation[0],
|
||||||
|
depthwise=True,
|
||||||
|
bias=conv_bias,
|
||||||
|
)
|
||||||
self.norm = norm_layer(out_chs)
|
self.norm = norm_layer(out_chs)
|
||||||
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
||||||
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
|
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
|
||||||
|
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
||||||
|
self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
|
||||||
|
else:
|
||||||
|
self.shortcut = nn.Identity()
|
||||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
@ -116,7 +162,7 @@ class ConvNeXtBlock(nn.Module):
|
|||||||
if self.gamma is not None:
|
if self.gamma is not None:
|
||||||
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
||||||
|
|
||||||
x = self.drop_path(x) + shortcut
|
x = self.drop_path(x) + self.shortcut(shortcut)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
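`dilation` can now be a single int or an `(in, out)` pair; `to_ntuple(2)` normalizes it, and a `Downsample` shortcut is built whenever channels, stride, or dilation differ across the block. A tiny sketch of that normalization step (hypothetical `to_2tuple` helper, mirroring what `timm.layers.to_ntuple(2)` does):

```python
from itertools import repeat
from typing import Tuple, Union

def to_2tuple(x: Union[int, Tuple[int, int]]) -> Tuple[int, ...]:
    # pass tuples/lists through unchanged, repeat scalars twice
    return tuple(x) if isinstance(x, (tuple, list)) else tuple(repeat(x, 2))

print(to_2tuple(1))       # (1, 1) -> same dilation in and out
print(to_2tuple((1, 2)))  # (1, 2) -> dilation changes across the block, so the shortcut path is used
```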
```diff
@@ -148,8 +194,14 @@ class ConvNeXtStage(nn.Module):
             self.downsample = nn.Sequential(
                 norm_layer(in_chs),
                 create_conv2d(
-                    in_chs, out_chs, kernel_size=ds_ks, stride=stride,
-                    dilation=dilation[0], padding=pad, bias=conv_bias),
+                    in_chs,
+                    out_chs,
+                    kernel_size=ds_ks,
+                    stride=stride,
+                    dilation=dilation[0],
+                    padding=pad,
+                    bias=conv_bias,
+                ),
             )
             in_chs = out_chs
         else:
```
```diff
@@ -773,136 +825,147 @@ default_cfgs = generate_default_cfgs({
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
         input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
+    'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
+    'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
+    'convnext_xxlarge.clip_laion2b_soup': _cfg(
+        hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
+    'convnext_xxlarge.clip_laion2b_rewind': _cfg(
+        hf_hub_id='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
 })
```
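These entries register the LAION-2B CLIP image towers as pretrained tags pulled from the HF Hub. A usage sketch (requires network access and the HF Hub dependency; `num_classes` above appears to be the CLIP projection width rather than an ImageNet classifier, so the default output is an embedding):

```python
import timm

# image tower from the LAION-2B CLIP "soup" checkpoint; output dim follows num_classes above
model = timm.create_model('convnext_xxlarge.clip_laion2b_soup', pretrained=True)

# or drop the head entirely and use it as a feature backbone
backbone = timm.create_model('convnext_xxlarge.clip_laion2b_soup', pretrained=True, num_classes=0)
```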
```diff
 @register_model
 def convnext_atto(pretrained=False, **kwargs):
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
-    model_args = dict(
-        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_atto', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
+    model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
```
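The entrypoints now keep their defaults in `model_args` and merge caller overrides via `dict(model_args, **kwargs)`. Presumably the motivation: with the old `dict(..., conv_mlp=True, **kwargs)` form, passing a keyword that was already a default raised a duplicate-keyword `TypeError`, whereas the new merge lets user kwargs win. A standalone illustration:

```python
defaults = dict(depths=(2, 2, 6, 2), conv_mlp=True)
overrides = dict(conv_mlp=False)

# old pattern: defaults and **kwargs expanded into the same dict(...) call -> duplicate keyword error
try:
    dict(depths=(2, 2, 6, 2), conv_mlp=True, **overrides)
except TypeError as e:
    print('old pattern fails:', e)

# new pattern: positional mapping first, keyword expansion second -> overrides win
merged = dict(defaults, **overrides)
print(merged['conv_mlp'])  # False
```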
```diff
 @register_model
 def convnext_atto_ols(pretrained=False, **kwargs):
     # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
-    model_args = dict(
-        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
-    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
+    model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_femto(pretrained=False, **kwargs):
     # timm femto variant
-    model_args = dict(
-        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_femto', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
+    model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_femto_ols(pretrained=False, **kwargs):
     # timm femto variant
-    model_args = dict(
-        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
-    model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
+    model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_pico(pretrained=False, **kwargs):
     # timm pico variant
-    model_args = dict(
-        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_pico', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
+    model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_pico_ols(pretrained=False, **kwargs):
     # timm nano variant with overlapping 3x3 conv stem
-    model_args = dict(
-        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered', **kwargs)
-    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
+    model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_nano(pretrained=False, **kwargs):
     # timm nano variant with standard stem and head
-    model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_nano', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
+    model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_nano_ols(pretrained=False, **kwargs):
     # experimental nano variant with overlapping conv stem
-    model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap', **kwargs)
-    model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
+    model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_tiny_hnf(pretrained=False, **kwargs):
     # experimental tiny variant with norm before pooling in head (head norm first)
-    model_args = dict(
-        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
+    model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_tiny(pretrained=False, **kwargs):
-    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
-    model = _create_convnext('convnext_tiny', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
+    model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_small(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
-    model = _create_convnext('convnext_small', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
+    model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_base(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
-    model = _create_convnext('convnext_base', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
+    model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_large(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
-    model = _create_convnext('convnext_large', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
+    model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_large_mlp(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
-    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
+    model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_xlarge(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
-    model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
+    model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnext_xxlarge(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
-    model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
+    model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
```
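`convnext_xxlarge` additionally uses `kwargs.pop('norm_eps', 1e-5)` so the architecture gets a non-standard default epsilon while a caller-supplied value still takes precedence (and is not passed twice in the later merge). A hypothetical helper showing the behaviour:

```python
def make_args(**kwargs):
    # hypothetical helper mirroring the convnext_xxlarge pattern above
    model_args = dict(depths=[3, 4, 30, 3], norm_eps=kwargs.pop('norm_eps', 1e-5))
    return dict(model_args, **kwargs)

print(make_args()['norm_eps'])               # 1e-05 (architecture-specific default)
print(make_args(norm_eps=1e-6)['norm_eps'])  # 1e-06 (caller override still wins)
```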
```diff
@@ -910,8 +973,8 @@ def convnext_xxlarge(pretrained=False, **kwargs):
 def convnextv2_atto(pretrained=False, **kwargs):
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
     model_args = dict(
-        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
+    model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


@@ -919,8 +982,8 @@ def convnextv2_atto(pretrained=False, **kwargs):
 def convnextv2_femto(pretrained=False, **kwargs):
     # timm femto variant
     model_args = dict(
-        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
+    model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


@@ -928,8 +991,8 @@ def convnextv2_femto(pretrained=False, **kwargs):
 def convnextv2_pico(pretrained=False, **kwargs):
     # timm pico variant
     model_args = dict(
-        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args)
+        depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
+    model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


@@ -937,42 +1000,41 @@ def convnextv2_pico(pretrained=False, **kwargs):
 def convnextv2_nano(pretrained=False, **kwargs):
     # timm nano variant with standard stem and head
     model_args = dict(
-        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
-    model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args)
+        depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
+    model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnextv2_tiny(pretrained=False, **kwargs):
-    model_args = dict(
-        depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs)
-    model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
+    model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
+    model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnextv2_small(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs)
-    model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
+    model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnextv2_base(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs)
-    model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
+    model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnextv2_large(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs)
-    model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
+    model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
     return model


 @register_model
 def convnextv2_huge(pretrained=False, **kwargs):
-    model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs)
-    model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args)
+    model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
+    model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
```
```diff
@@ -20,7 +20,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath
+from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._registry import register_model
@@ -564,6 +564,7 @@ class MobileVitV2Block(nn.Module):

         self.patch_size = to_2tuple(patch_size)
         self.patch_area = self.patch_size[0] * self.patch_size[1]
+        self.coreml_exportable = is_exportable()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, C, H, W = x.shape
@@ -580,6 +581,9 @@ class MobileVitV2Block(nn.Module):

         # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
         C = x.shape[1]
-        x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
+        if self.coreml_exportable:
+            x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
+        else:
+            x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
         x = x.reshape(B, C, -1, num_patches)

@@ -588,9 +592,15 @@ class MobileVitV2Block(nn.Module):
         x = self.norm(x)

         # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
-        x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
-        x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
+        if self.coreml_exportable:
+            # adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
+            x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
+            x = F.pixel_shuffle(x, upscale_factor=patch_h)
+        else:
+            x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
+            x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)

         x = self.conv_proj(x)
         return x
```
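For CoreML export the patch unfold/fold is expressed with `F.unfold` / `F.pixel_shuffle` instead of reshape-permute chains. The two unfold paths produce identical tensors, which is easy to verify in isolation (the fold path matches in the same way when `patch_h == patch_w`):

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 16, 16
ph = pw = 2
nph, npw = H // ph, W // pw
num_patches = nph * npw
x = torch.randn(B, C, H, W)

# export-friendly path: im2col-style unfold, then reshape to [B, C, P, N]
a = F.unfold(x, kernel_size=(ph, pw), stride=(ph, pw)).reshape(B, C, ph * pw, num_patches)

# original path: split H/W into (blocks, patch) dims, permute, then flatten to [B, C, P, N]
b = x.reshape(B, C, nph, ph, npw, pw).permute(0, 1, 3, 5, 2, 4).reshape(B, C, -1, num_patches)

assert torch.allclose(a, b)
print('unfold paths match:', a.shape)
```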
```diff
@@ -1 +1 @@
-__version__ = '0.8.12dev0'
+__version__ = '0.8.15dev0'
```
train.py
```diff
@@ -514,8 +514,14 @@ def main():
         if utils.is_primary(args):
             _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
     elif use_amp == 'native':
-        amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
-        if device.type == 'cuda':
+        try:
+            amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
+        except (AttributeError, TypeError):
+            # fallback to CUDA only AMP for PyTorch < 1.10
+            assert device.type == 'cuda'
+            amp_autocast = torch.cuda.amp.autocast
+        if device.type == 'cuda' and amp_dtype == torch.float16:
+            # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
             loss_scaler = NativeScaler()
         if utils.is_primary(args):
             _logger.info('Using native Torch AMP. Training in mixed precision.')
```
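Outside of `train.py`, the same pattern looks roughly like the sketch below: device-agnostic `torch.autocast` where available, the CUDA-only context manager otherwise, and a gradient scaler only for float16 (timm's `NativeScaler` essentially wraps `torch.cuda.amp.GradScaler`). A toy model and optimizer stand in for the real training loop.

```python
import torch
import torch.nn.functional as F
from functools import partial

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
amp_dtype = torch.float16 if device.type == 'cuda' else torch.bfloat16

try:
    # PyTorch >= 1.10 exposes a device-agnostic autocast
    amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
except (AttributeError, TypeError):
    # older releases only have the CUDA-specific context manager
    amp_autocast = torch.cuda.amp.autocast

# a gradient scaler is only needed for float16; bfloat16 has enough dynamic range without it
scaler = torch.cuda.amp.GradScaler() if (device.type == 'cuda' and amp_dtype == torch.float16) else None

model = torch.nn.Linear(8, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)

with amp_autocast():
    loss = F.cross_entropy(model(x), y)

if scaler is not None:
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
else:
    loss.backward()
    opt.step()
```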