Merge remote-tracking branch 'origin/main' into adafactor_bv
commit 326e5dcfd4
@@ -7,7 +7,7 @@ Benchmark all 'vit*' models:
 python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512

 Validate all models:
-python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
+python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py --data-dir /imagenet/validation/ --amp -b 512 --retry

 Hacked together by Ross Wightman (https://github.com/rwightman)
 """
@@ -228,7 +228,7 @@ Datasets & transform refactoring
 * Enabling either dynamic mode will break FX tracing unless PatchEmbed module added as leaf.
 * Existing method of resizing position embedding by passing different `img_size` (interpolate pretrained embed weights once) on creation still works.
 * Existing method of changing `patch_size` (resize pretrained patch_embed weights once) on creation still works.
-* Example validation cmd `python validate.py /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dyamic_img_pad=True`
+* Example validation cmd `python validate.py --data-dir /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dyamic_img_pad=True`

 ### Aug 25, 2023
 * Many new models since last release
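The dynamic size flags in the validation example above map to ordinary `create_model` keyword arguments; a minimal sketch of using them directly from Python (model name and input sizes are illustrative, not part of the change):

```python
import torch
import timm

# Same kwargs the CLI passes via --model-kwargs
model = timm.create_model(
    'vit_base_patch16_224',
    pretrained=False,
    dynamic_img_size=True,   # resample pos embed per input resolution
    dynamic_img_pad=True,    # pad inputs that aren't a multiple of patch size
).eval()

with torch.no_grad():
    out_224 = model(torch.randn(1, 3, 224, 224))
    out_255 = model(torch.randn(1, 3, 255, 255))  # padded to a full patch grid
print(out_224.shape, out_255.shape)  # both (1, num_classes)
```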
@@ -245,8 +245,8 @@ Datasets & transform refactoring

 ### Aug 11, 2023
 * Swin, MaxViT, CoAtNet, and BEiT models support resizing of image/window size on creation with adaptation of pretrained weights
-* Example validation cmd to test w/ non-square resize `python validate.py /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320`
+* Example validation cmd to test w/ non-square resize `python validate.py --data-dir /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320`

 ### Aug 3, 2023
 * Add GluonCV weights for HRNet w18_small and w18_small_v2. Converted by [SeeFun](https://github.com/seefun)
 * Fix `selecsls*` model naming regression
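The `--model-kwargs` in the Swin example above correspond to `create_model` keyword arguments; a rough Python equivalent (an untested sketch, model name taken from the changelog entry):

```python
import timm

# Non-square image / window size at creation; pretrained weights are adapted on load
model = timm.create_model(
    'swin_base_patch4_window7_224.ms_in22k_ft_in1k',
    pretrained=True,
    img_size=(256, 320),
    window_size=(8, 10),
)
```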
@@ -385,7 +385,7 @@ Datasets & transform refactoring
 * Refactor LeViT models to stages, add `features_only=True` support to new `conv` variants, weight remap required.
 * Move ImageNet meta-data (synsets, indices) from `/results` to [`timm/data/_info`](timm/data/_info/).
 * Add ImageNetInfo / DatasetInfo classes to provide labelling for various ImageNet classifier layouts in `timm`
-* Update `inference.py` to use, try: `python inference.py /folder/to/images --model convnext_small.in12k --label-type detail --topk 5`
+* Update `inference.py` to use, try: `python inference.py --data-dir /folder/to/images --model convnext_small.in12k --label-type detail --topk 5`
 * Ready for 0.8.10 pypi pre-release (final testing).

 ### Jan 20, 2023
@@ -449,8 +449,8 @@ Datasets & transform refactoring

 ### Jan 6, 2023
 * Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
-* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
-* `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
+* `train.py --data-dir /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
+* `train.py --data-dir /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
 * Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.

 ### Jan 5, 2023
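The `--model-kwargs` passthrough shown above is the CLI face of the same keyword arguments `timm.create_model` accepts directly; as a rough illustration (values copied from the first changelog line above):

```python
import timm

# CLI: --model resnet50 --model-kwargs output_stride=16 act_layer=silu
model = timm.create_model('resnet50', pretrained=False, output_stride=16, act_layer='silu')
```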
@@ -12,7 +12,7 @@ The variety of training args is large and not all combinations of options (or ev
 To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:

 ```bash
-./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
+./distributed_train.sh 4 --data-dir /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
 ```

 <Tip>
@@ -27,13 +27,13 @@ Validation and inference scripts are similar in usage. One outputs metrics on a
 To validate with the model's pretrained weights (if they exist):

 ```bash
-python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained
+python validate.py --data-dir /imagenet/validation/ --model seresnext26_32x4d --pretrained
 ```

 To run inference from a checkpoint:

 ```bash
-python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
+python inference.py --data-dir /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
 ```

 ## Training Examples
@@ -43,7 +43,7 @@ python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkp
 These params are for dual Titan RTX cards with NVIDIA Apex installed:

 ```bash
-./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016
+./distributed_train.sh 2 --data-dir /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016
 ```

 ### MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5
@@ -51,7 +51,7 @@ These params are for dual Titan RTX cards with NVIDIA Apex installed:
 This params are for dual Titan RTX cards with NVIDIA Apex installed:

 ```bash
-./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce
+./distributed_train.sh 2 --data-dir /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce
 ```

 ### SE-ResNeXt-26-D and SE-ResNeXt-26-T
@@ -59,7 +59,7 @@ This params are for dual Titan RTX cards with NVIDIA Apex installed:
 These hparams (or similar) work well for a wide range of ResNet architecture, generally a good idea to increase the epoch # as the model size increases... ie approx 180-200 for ResNe(X)t50, and 220+ for larger. Increase batch size and LR proportionally for better GPUs or with AMP enabled. These params were for 2 1080Ti cards:

 ```bash
-./distributed_train.sh 2 /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112
+./distributed_train.sh 2 --data-dir /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112
 ```
 ### EfficientNet-B3 with RandAugment - 81.5 top-1, 95.7 top-5

@@ -70,26 +70,26 @@ The training of this model started with the same command line as EfficientNet-B2
 [Michael Klachko](https://github.com/michaelklachko) achieved these results with the command line for B2 adapted for larger batch size, with the recommended B0 dropout rate of 0.2.

 ```bash
-./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048
+./distributed_train.sh 2 --data-dir /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048
 ```
 ### ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5

 Trained on two older 1080Ti cards, this took a while. Only slightly, non statistically better ImageNet validation result than my first good AugMix training of 78.99. However, these weights are more robust on tests with ImageNetV2, ImageNet-Sketch, etc. Unlike my first AugMix runs, I've enabled SplitBatchNorm, disabled random erasing on the clean split, and cranked up random erasing prob on the 2 augmented paths.

 ```bash
-./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce
+./distributed_train.sh 2 --data-dir /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce
 ```
 ### EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5

 Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model EMA was not used, final checkpoint is the average of 8 best checkpoints during training.

 ```bash
-./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
+./distributed_train.sh 8 --data-dir /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
 ```
 ### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5

 ```bash
-./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
+./distributed_train.sh 2 --data-dir /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
 ```

 ### ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5
@@ -97,5 +97,5 @@ These params will also work well for SE-ResNeXt-50 and SK-ResNeXt-50 and likely


 ```bash
-./distributed_train.sh 8 /imagenet --model resnext50_32x4d --lr 0.6 --warmup-epochs 5 --epochs 240 --weight-decay 1e-4 --sched cosine --reprob 0.4 --recount 3 --remode pixel --aa rand-m7-mstd0.5-inc1 -b 192 -j 6 --amp --dist-bn reduce
+./distributed_train.sh 8 --data-dir /imagenet --model resnext50_32x4d --lr 0.6 --warmup-epochs 5 --epochs 240 --weight-decay 1e-4 --sched cosine --reprob 0.4 --recount 3 --remode pixel --aa rand-m7-mstd0.5-inc1 -b 192 -j 6 --amp --dist-bn reduce
 ```
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn

 import importlib
 import os
@@ -76,3 +76,46 @@ def test_hard_swish_grad():
 def test_hard_mish_grad():
     for _ in range(100):
         _run_act_layer_grad('hard_mish')
+
+
+def test_get_act_layer_empty_string():
+    # Empty string should return None
+    assert get_act_layer('') is None
+
+
+def test_create_act_layer_inplace_error():
+    class NoInplaceAct(nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x):
+            return x
+
+    # Should recover when inplace arg causes TypeError
+    layer = create_act_layer(NoInplaceAct, inplace=True)
+    assert isinstance(layer, NoInplaceAct)
+
+
+def test_create_act_layer_edge_cases():
+    # Test None input
+    assert create_act_layer(None) is None
+
+    # Test TypeError handling for inplace
+    class CustomAct(nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+        def forward(self, x):
+            return x
+
+    result = create_act_layer(CustomAct, inplace=True)
+    assert isinstance(result, CustomAct)
+
+
+def test_get_act_fn_callable():
+    def custom_act(x):
+        return x
+    assert get_act_fn(custom_act) is custom_act
+
+
+def test_get_act_fn_none():
+    assert get_act_fn(None) is None
+    assert get_act_fn('') is None
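The new tests above exercise edge cases of `get_act_layer` / `get_act_fn` / `create_act_layer`; for orientation, a minimal sketch of the ordinary string-based lookup path (the activation name is illustrative):

```python
import torch.nn as nn
from timm.layers import get_act_layer, create_act_layer

act_cls = get_act_layer('silu')               # resolve a name to an nn.Module class
act = create_act_layer('silu', inplace=True)  # instantiate it, passing inplace where supported
assert isinstance(act, nn.Module)
```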
@@ -11,6 +11,8 @@ from copy import deepcopy
 import torch
 from torch.testing._internal.common_utils import TestCase
 from torch.nn import Parameter

+from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
 from timm.scheduler import PlateauLRScheduler

 from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
@@ -495,3 +497,82 @@ def test_lookahead_radam(optimizer):
         lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
     )
+
+
+def test_param_groups_layer_decay_with_end_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        end_layer_decay=0.5,
+        verbose=True
+    )
+
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied with end decay
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert group['lr_scale'] <= 1.0
+        assert group['lr_scale'] >= 0.5
+
+
+def test_param_groups_layer_decay_with_matcher():
+    class ModelWithMatcher(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(10, 5)
+            self.layer2 = torch.nn.Linear(5, 2)
+
+        def group_matcher(self, coarse=False):
+            return lambda name: int(name.split('.')[0][-1])
+
+    model = ModelWithMatcher()
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        verbose=True
+    )
+
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert 'weight_decay' in group
+        assert len(group['params']) > 0
+
+
+def test_param_groups_weight_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+    weight_decay = 0.01
+    no_weight_decay_list = ['1.weight']
+
+    param_groups = param_groups_weight_decay(
+        model,
+        weight_decay=weight_decay,
+        no_weight_decay_list=no_weight_decay_list
+    )
+
+    assert len(param_groups) == 2
+    assert param_groups[0]['weight_decay'] == 0.0
+    assert param_groups[1]['weight_decay'] == weight_decay
+
+    # Verify parameters are correctly grouped
+    no_decay_params = set(param_groups[0]['params'])
+    decay_params = set(param_groups[1]['params'])
+
+    for name, param in model.named_parameters():
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
+            assert param in no_decay_params
+        else:
+            assert param in decay_params
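For context, a rough sketch of how the parameter-group helpers tested above are typically consumed, feeding the groups straight into an optimizer (a plain torch optimizer is used here purely for illustration):

```python
import torch
from timm.optim.optim_factory import param_groups_weight_decay

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))

# Biases / 1-d params land in a group with weight_decay=0.0, the rest keep the requested value
groups = param_groups_weight_decay(model, weight_decay=0.01)
optimizer = torch.optim.AdamW(groups, lr=1e-3)
```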
@@ -2,8 +2,15 @@ from torch.nn.modules.batchnorm import BatchNorm2d
 from torchvision.ops.misc import FrozenBatchNorm2d

 import timm
+import pytest
 from timm.utils.model import freeze, unfreeze
+from timm.utils.model import ActivationStatsHook
+from timm.utils.model import extract_spp_stats

+from timm.utils.model import _freeze_unfreeze
+from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
+from timm.utils.model import reparameterize_model
+from timm.utils.model import get_state_dict

 def test_freeze_unfreeze():
     model = timm.create_model('resnet18')
@@ -54,4 +61,132 @@ def test_freeze_unfreeze():
     freeze(model.layer1[0], ['bn1'])
     assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
     unfreeze(model.layer1[0], ['bn1'])
     assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+
+
+def test_activation_stats_hook_validation():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    # Test error case with mismatched lengths
+    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
+        ActivationStatsHook(
+            model,
+            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
+            hook_fns=[test_hook]
+        )
+
+
+def test_extract_spp_stats():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    stats = extract_spp_stats(
+        model,
+        hook_fn_locs=['layer1.0.conv1'],
+        hook_fns=[test_hook],
+        input_shape=[2, 3, 32, 32]
+    )
+
+    assert isinstance(stats, dict)
+    assert test_hook.__name__ in stats
+    assert isinstance(stats[test_hook.__name__], list)
+    assert len(stats[test_hook.__name__]) > 0
+
+
+def test_freeze_unfreeze_bn_root():
+    import torch.nn as nn
+    from timm.layers import BatchNormAct2d
+
+    # Create batch norm layers
+    bn = nn.BatchNorm2d(10)
+    bn_act = BatchNormAct2d(10)
+
+    # Test with BatchNorm2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn, mode="freeze")
+
+    # Test with BatchNormAct2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn_act, mode="freeze")
+
+
+def test_activation_stats_functions():
+    import torch
+
+    # Create sample input tensor [batch, channels, height, width]
+    x = torch.randn(2, 3, 4, 4)
+
+    # Test avg_sq_ch_mean
+    result1 = avg_sq_ch_mean(None, None, x)
+    assert isinstance(result1, float)
+
+    # Test avg_ch_var
+    result2 = avg_ch_var(None, None, x)
+    assert isinstance(result2, float)
+
+    # Test avg_ch_var_residual
+    result3 = avg_ch_var_residual(None, None, x)
+    assert isinstance(result3, float)
+
+
+def test_reparameterize_model():
+    import torch.nn as nn
+
+    class FusableModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(3, 3, 1)
+
+        def fuse(self):
+            return nn.Identity()
+
+    class ModelWithFusable(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fusable = FusableModule()
+            self.normal = nn.Linear(10, 10)
+
+    model = ModelWithFusable()
+
+    # Test with inplace=False (should create a copy)
+    new_model = reparameterize_model(model, inplace=False)
+    assert isinstance(new_model.fusable, nn.Identity)
+    assert isinstance(model.fusable, FusableModule)  # Original unchanged
+
+    # Test with inplace=True
+    reparameterize_model(model, inplace=True)
+    assert isinstance(model.fusable, nn.Identity)
+
+
+def test_get_state_dict_custom_unwrap():
+    import torch.nn as nn
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 10)
+
+    model = CustomModel()
+
+    def custom_unwrap(m):
+        return m
+
+    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
+    assert 'linear.weight' in state_dict
+    assert 'linear.bias' in state_dict
+
+
+def test_freeze_unfreeze_string_input():
+    model = timm.create_model('resnet18')
+
+    # Test with string input
+    _freeze_unfreeze(model, 'layer1', mode='freeze')
+    assert model.layer1[0].conv1.weight.requires_grad == False
+
+    # Test unfreezing with string input
+    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
+    assert model.layer1[0].conv1.weight.requires_grad == True
@@ -38,6 +38,7 @@ class ReaderHfds(Reader):
             input_key: str = 'image',
             target_key: str = 'label',
             download: bool = False,
+            trust_remote_code: bool = False
     ):
         """
        """
@@ -48,6 +49,7 @@ class ReaderHfds(Reader):
             name,  # 'name' maps to path arg in hf datasets
             split=split,
             cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            trust_remote_code=trust_remote_code
         )
         # leave decode for caller, plus we want easy access to original path names...
         self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
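The new `trust_remote_code` argument above is simply forwarded to `datasets.load_dataset()`; a rough sketch of how it might be reached through timm's dataset factory (the `hfds/` dataset name is illustrative, and the kwarg plumbing via `**kwargs` is an assumption worth checking against the timm version in use):

```python
from timm.data import create_dataset

# Assumption: extra kwargs are passed through to the ReaderHfds constructor
ds = create_dataset(
    'hfds/imagenet-1k',      # illustrative HF Hub dataset name
    root='/data/hf_cache',
    split='validation',
    trust_remote_code=True,  # forwarded to datasets.load_dataset()
)
```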
@@ -75,9 +75,11 @@ class VisionTransformerDistilled(VisionTransformer):
     def _pos_embed(self, x):
         if self.dynamic_img_size:
             B, H, W, C = x.shape
+            prev_grid_size = self.patch_embed.grid_size
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
-                (H, W),
+                new_size=(H, W),
+                old_size=prev_grid_size,
                 num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
             )
             x = x.view(B, -1, C)
@@ -560,9 +560,11 @@ class Eva(nn.Module):
         if self.dynamic_img_size:
             B, H, W, C = x.shape
             if self.pos_embed is not None:
+                prev_grid_size = self.patch_embed.grid_size
                 pos_embed = resample_abs_pos_embed(
                     self.pos_embed,
-                    (H, W),
+                    new_size=(H, W),
+                    old_size=prev_grid_size,
                     num_prefix_tokens=self.num_prefix_tokens,
                 )
             else:
@@ -669,9 +669,11 @@ class VisionTransformer(nn.Module):

         if self.dynamic_img_size:
             B, H, W, C = x.shape
+            prev_grid_size = self.patch_embed.grid_size
             pos_embed = resample_abs_pos_embed(
                 self.pos_embed,
-                (H, W),
+                new_size=(H, W),
+                old_size=prev_grid_size,
                 num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
             )
             x = x.view(B, -1, C)
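All three `_pos_embed` edits above pass the pre-computed patch grid explicitly as `old_size` so the resample helper does not have to re-derive it from the embedding length. A rough standalone sketch of the call (shapes are illustrative; check `timm.layers.resample_abs_pos_embed` in your version for exact defaults):

```python
import torch
from timm.layers import resample_abs_pos_embed

# 1 class token + 14x14 grid of 768-dim embeddings, as in a ViT-B/16 at 224x224
pos_embed = torch.randn(1, 1 + 14 * 14, 768)

resampled = resample_abs_pos_embed(
    pos_embed,
    new_size=(16, 20),   # target grid, e.g. a 256x320 input with 16px patches
    old_size=(14, 14),   # passed explicitly, as in the diff above
    num_prefix_tokens=1,
)
print(resampled.shape)  # torch.Size([1, 321, 768]) -> 1 prefix token + 16*20 grid positions
```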
@@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
@@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
@@ -196,11 +196,15 @@ def create_scheduler_v2(
     )

     if hasattr(lr_scheduler, 'get_cycle_length'):
-        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # NOTE: Warmup prefix added in get_cycle_lengths() if enabled
         t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
         if step_on_epochs:
             num_epochs = t_with_cycles_and_cooldown
         else:
             num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
+    else:
+        if warmup_prefix:
+            num_epochs += warmup_epochs

     return lr_scheduler, num_epochs
@@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
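With the scheduler change above, `get_cycle_length()` now folds the warmup span into its result when `warmup_prefix=True`, which is what `create_scheduler_v2` relies on when computing total scheduled epochs. A small illustrative check, assuming the post-merge behavior (values are arbitrary):

```python
import torch
from timm.scheduler import CosineLRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

sched = CosineLRScheduler(
    opt,
    t_initial=100,
    warmup_t=5,
    warmup_prefix=True,   # warmup runs before the cosine cycle starts
)
print(sched.get_cycle_length())  # 105: the 5 warmup steps are added on top of the cycle
```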
train.py
@@ -369,7 +369,7 @@ group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
 group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
 group.add_argument('--save-images', action='store_true', default=False,
-                   help='save images of input bathes every log interval for debugging')
+                   help='save images of input batches every log interval for debugging')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -840,8 +840,13 @@ def main():
         lr_scheduler.step(start_epoch)

     if utils.is_primary(args):
+        if args.warmup_prefix:
+            sched_explain = '(warmup_epochs + epochs + cooldown_epochs). Warmup added to total when warmup_prefix=True'
+        else:
+            sched_explain = '(epochs + cooldown_epochs). Warmup within epochs when warmup_prefix=False'
         _logger.info(
-            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
+            f'Scheduled epochs: {num_epochs} {sched_explain}. '
+            f'LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

     results = []
     try: