Merge remote-tracking branch 'origin/main' into adafactor_bv

Ross Wightman 2024-11-12 16:52:42 -08:00
commit 326e5dcfd4
15 changed files with 311 additions and 32 deletions


@@ -7,7 +7,7 @@ Benchmark all 'vit*' models:
python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512
Validate all models:
-python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry
+python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py --data-dir /imagenet/validation/ --amp -b 512 --retry
Hacked together by Ross Wightman (https://github.com/rwightman)
"""


@@ -228,7 +228,7 @@ Datasets & transform refactoring
* Enabling either dynamic mode will break FX tracing unless PatchEmbed module added as leaf.
* Existing method of resizing position embedding by passing different `img_size` (interpolate pretrained embed weights once) on creation still works.
* Existing method of changing `patch_size` (resize pretrained patch_embed weights once) on creation still works.
-* Example validation cmd `python validate.py /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dyamic_img_pad=True`
+* Example validation cmd `python validate.py --data-dir /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dyamic_img_pad=True`
### Aug 25, 2023
* Many new models since last release
@@ -245,7 +245,7 @@ Datasets & transform refactoring
### Aug 11, 2023
* Swin, MaxViT, CoAtNet, and BEiT models support resizing of image/window size on creation with adaptation of pretrained weights
-* Example validation cmd to test w/ non-square resize `python validate.py /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320`
+* Example validation cmd to test w/ non-square resize `python validate.py --data-dir /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320`
### Aug 3, 2023
* Add GluonCV weights for HRNet w18_small and w18_small_v2. Converted by [SeeFun](https://github.com/seefun)
@@ -385,7 +385,7 @@ Datasets & transform refactoring
* Refactor LeViT models to stages, add `features_only=True` support to new `conv` variants, weight remap required.
* Move ImageNet meta-data (synsets, indices) from `/results` to [`timm/data/_info`](timm/data/_info/).
* Add ImageNetInfo / DatasetInfo classes to provide labelling for various ImageNet classifier layouts in `timm`
-* Update `inference.py` to use, try: `python inference.py /folder/to/images --model convnext_small.in12k --label-type detail --topk 5`
+* Update `inference.py` to use, try: `python inference.py --data-dir /folder/to/images --model convnext_small.in12k --label-type detail --topk 5`
* Ready for 0.8.10 pypi pre-release (final testing).
### Jan 20, 2023
@@ -449,8 +449,8 @@ Datasets & transform refactoring
### Jan 6, 2023
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
-* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
-* `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
+* `train.py --data-dir /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
+* `train.py --data-dir /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
### Jan 5, 2023


@@ -12,7 +12,7 @@ The variety of training args is large and not all combinations of options (or ev
To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:
```bash
-./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
+./distributed_train.sh 4 --data-dir /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
```
<Tip>
@@ -27,13 +27,13 @@ Validation and inference scripts are similar in usage. One outputs metrics on a
To validate with the model's pretrained weights (if they exist):
```bash
-python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained
+python validate.py --data-dir /imagenet/validation/ --model seresnext26_32x4d --pretrained
```
To run inference from a checkpoint:
```bash
-python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
+python inference.py --data-dir /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
```
## Training Examples
@@ -43,7 +43,7 @@ python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkp
These params are for dual Titan RTX cards with NVIDIA Apex installed:
```bash
-./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016
+./distributed_train.sh 2 --data-dir /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016
```
### MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5
@@ -51,7 +51,7 @@ These params are for dual Titan RTX cards with NVIDIA Apex installed:
These params are for dual Titan RTX cards with NVIDIA Apex installed:
```bash
-./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce
+./distributed_train.sh 2 --data-dir /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce
```
### SE-ResNeXt-26-D and SE-ResNeXt-26-T
@@ -59,7 +59,7 @@ These params are for dual Titan RTX cards with NVIDIA Apex installed:
These hparams (or similar) work well for a wide range of ResNet architectures; it's generally a good idea to increase the epoch count as the model size increases, i.e. approx 180-200 for ResNe(X)t50 and 220+ for larger. Increase batch size and LR proportionally for better GPUs or with AMP enabled. These params were for 2 1080Ti cards:
```bash
-./distributed_train.sh 2 /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112
+./distributed_train.sh 2 --data-dir /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112
```
### EfficientNet-B3 with RandAugment - 81.5 top-1, 95.7 top-5
@@ -70,26 +70,26 @@ The training of this model started with the same command line as EfficientNet-B2
[Michael Klachko](https://github.com/michaelklachko) achieved these results with the command line for B2 adapted for larger batch size, with the recommended B0 dropout rate of 0.2.
```bash
-./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048
+./distributed_train.sh 2 --data-dir /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048
```
### ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5
Trained on two older 1080Ti cards, this took a while. Only a slightly (not statistically significantly) better ImageNet validation result than my first good AugMix training of 78.99. However, these weights are more robust on tests with ImageNetV2, ImageNet-Sketch, etc. Unlike my first AugMix runs, I've enabled SplitBatchNorm, disabled random erasing on the clean split, and cranked up random erasing prob on the 2 augmented paths.
```bash
-./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce
+./distributed_train.sh 2 --data-dir /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce
```
### EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5
Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model EMA was not used, final checkpoint is the average of 8 best checkpoints during training.
```bash
-./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
+./distributed_train.sh 8 --data-dir /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
```
### MobileNetV3-Large-100 - 75.766 top-1, 92.542 top-5
```bash
-./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
+./distributed_train.sh 2 --data-dir /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
```
### ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5
@@ -97,5 +97,5 @@ These params will also work well for SE-ResNeXt-50 and SK-ResNeXt-50 and likely
```bash
-./distributed_train.sh 8 /imagenet --model resnext50_32x4d --lr 0.6 --warmup-epochs 5 --epochs 240 --weight-decay 1e-4 --sched cosine --reprob 0.4 --recount 3 --remode pixel --aa rand-m7-mstd0.5-inc1 -b 192 -j 6 --amp --dist-bn reduce
+./distributed_train.sh 8 --data-dir /imagenet --model resnext50_32x4d --lr 0.6 --warmup-epochs 5 --epochs 240 --weight-decay 1e-4 --sched cosine --reprob 0.4 --recount 3 --remode pixel --aa rand-m7-mstd0.5-inc1 -b 192 -j 6 --amp --dist-bn reduce
```


@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from timm.layers import create_act_layer, set_layer_config
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
import importlib
import os
@@ -76,3 +76,46 @@ def test_hard_swish_grad():
def test_hard_mish_grad():
    for _ in range(100):
        _run_act_layer_grad('hard_mish')
+def test_get_act_layer_empty_string():
+    # Empty string should return None
+    assert get_act_layer('') is None
+def test_create_act_layer_inplace_error():
+    class NoInplaceAct(nn.Module):
+        def __init__(self):
+            super().__init__()
+        def forward(self, x):
+            return x
+    # Should recover when inplace arg causes TypeError
+    layer = create_act_layer(NoInplaceAct, inplace=True)
+    assert isinstance(layer, NoInplaceAct)
+def test_create_act_layer_edge_cases():
+    # Test None input
+    assert create_act_layer(None) is None
+    # Test TypeError handling for inplace
+    class CustomAct(nn.Module):
+        def __init__(self, **kwargs):
+            super().__init__()
+        def forward(self, x):
+            return x
+    result = create_act_layer(CustomAct, inplace=True)
+    assert isinstance(result, CustomAct)
+def test_get_act_fn_callable():
+    def custom_act(x):
+        return x
+    assert get_act_fn(custom_act) is custom_act
+def test_get_act_fn_none():
+    assert get_act_fn(None) is None
+    assert get_act_fn('') is None
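For reference, a minimal sketch (not part of this commit) of the activation helpers these new tests exercise, with behavior matching the assertions above:

```python
# Minimal sketch of the helpers exercised by the new tests (not part of this diff).
import torch
from timm.layers import get_act_layer, get_act_fn, create_act_layer

act_cls = get_act_layer('relu')      # resolves a name to an activation class; '' resolves to None
act_mod = create_act_layer('gelu')   # instantiates an activation module, handling the inplace arg safely
act_fn = get_act_fn('sigmoid')       # resolves a name to a plain callable; None / '' resolve to None

print(act_cls, act_mod, act_fn(torch.zeros(2)))
```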


@@ -11,6 +11,8 @@ from copy import deepcopy
import torch
from torch.testing._internal.common_utils import TestCase
from torch.nn import Parameter
+from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay
from timm.scheduler import PlateauLRScheduler
from timm.optim import create_optimizer_v2, list_optimizers, get_optimizer_class
@@ -495,3 +497,82 @@ def test_lookahead_radam(optimizer):
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-4)
    )
+def test_param_groups_layer_decay_with_end_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        end_layer_decay=0.5,
+        verbose=True
+    )
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied with end decay
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert group['lr_scale'] <= 1.0
+        assert group['lr_scale'] >= 0.5
+def test_param_groups_layer_decay_with_matcher():
+    class ModelWithMatcher(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = torch.nn.Linear(10, 5)
+            self.layer2 = torch.nn.Linear(5, 2)
+        def group_matcher(self, coarse=False):
+            return lambda name: int(name.split('.')[0][-1])
+    model = ModelWithMatcher()
+    param_groups = param_groups_layer_decay(
+        model,
+        weight_decay=0.05,
+        layer_decay=0.75,
+        verbose=True
+    )
+    assert len(param_groups) > 0
+    # Verify layer scaling is applied
+    for group in param_groups:
+        assert 'lr_scale' in group
+        assert 'weight_decay' in group
+        assert len(group['params']) > 0
+def test_param_groups_weight_decay():
+    model = torch.nn.Sequential(
+        torch.nn.Linear(10, 5),
+        torch.nn.ReLU(),
+        torch.nn.Linear(5, 2)
+    )
+    weight_decay = 0.01
+    no_weight_decay_list = ['1.weight']
+    param_groups = param_groups_weight_decay(
+        model,
+        weight_decay=weight_decay,
+        no_weight_decay_list=no_weight_decay_list
+    )
+    assert len(param_groups) == 2
+    assert param_groups[0]['weight_decay'] == 0.0
+    assert param_groups[1]['weight_decay'] == weight_decay
+    # Verify parameters are correctly grouped
+    no_decay_params = set(param_groups[0]['params'])
+    decay_params = set(param_groups[1]['params'])
+    for name, param in model.named_parameters():
+        if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
+            assert param in no_decay_params
+        else:
+            assert param in decay_params
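For reference, a minimal sketch (not part of this commit) showing the typical way these parameter groups are consumed, assuming standard torch optimizer usage:

```python
# Minimal usage sketch (not part of this diff): the groups plug straight into a torch optimizer.
import torch
from timm.optim.optim_factory import param_groups_weight_decay

model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2))
param_groups = param_groups_weight_decay(model, weight_decay=0.05)

# Group 0 holds biases / 1-d params with weight_decay=0.0, group 1 holds the 2-d weights.
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
```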


@@ -2,8 +2,15 @@ from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d
import timm
+import pytest
from timm.utils.model import freeze, unfreeze
+from timm.utils.model import ActivationStatsHook
+from timm.utils.model import extract_spp_stats
+from timm.utils.model import _freeze_unfreeze
+from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
+from timm.utils.model import reparameterize_model
+from timm.utils.model import get_state_dict

def test_freeze_unfreeze():
    model = timm.create_model('resnet18')
@@ -55,3 +62,131 @@ def test_freeze_unfreeze():
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+def test_activation_stats_hook_validation():
+    model = timm.create_model('resnet18')
+    def test_hook(model, input, output):
+        return output.mean().item()
+    # Test error case with mismatched lengths
+    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
+        ActivationStatsHook(
+            model,
+            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
+            hook_fns=[test_hook]
+        )
+def test_extract_spp_stats():
+    model = timm.create_model('resnet18')
+    def test_hook(model, input, output):
+        return output.mean().item()
+    stats = extract_spp_stats(
+        model,
+        hook_fn_locs=['layer1.0.conv1'],
+        hook_fns=[test_hook],
+        input_shape=[2, 3, 32, 32]
+    )
+    assert isinstance(stats, dict)
+    assert test_hook.__name__ in stats
+    assert isinstance(stats[test_hook.__name__], list)
+    assert len(stats[test_hook.__name__]) > 0
+def test_freeze_unfreeze_bn_root():
+    import torch.nn as nn
+    from timm.layers import BatchNormAct2d
+    # Create batch norm layers
+    bn = nn.BatchNorm2d(10)
+    bn_act = BatchNormAct2d(10)
+    # Test with BatchNorm2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn, mode="freeze")
+    # Test with BatchNormAct2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn_act, mode="freeze")
+def test_activation_stats_functions():
+    import torch
+    # Create sample input tensor [batch, channels, height, width]
+    x = torch.randn(2, 3, 4, 4)
+    # Test avg_sq_ch_mean
+    result1 = avg_sq_ch_mean(None, None, x)
+    assert isinstance(result1, float)
+    # Test avg_ch_var
+    result2 = avg_ch_var(None, None, x)
+    assert isinstance(result2, float)
+    # Test avg_ch_var_residual
+    result3 = avg_ch_var_residual(None, None, x)
+    assert isinstance(result3, float)
+def test_reparameterize_model():
+    import torch.nn as nn
+    class FusableModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(3, 3, 1)
+        def fuse(self):
+            return nn.Identity()
+    class ModelWithFusable(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fusable = FusableModule()
+            self.normal = nn.Linear(10, 10)
+    model = ModelWithFusable()
+    # Test with inplace=False (should create a copy)
+    new_model = reparameterize_model(model, inplace=False)
+    assert isinstance(new_model.fusable, nn.Identity)
+    assert isinstance(model.fusable, FusableModule)  # Original unchanged
+    # Test with inplace=True
+    reparameterize_model(model, inplace=True)
+    assert isinstance(model.fusable, nn.Identity)
+def test_get_state_dict_custom_unwrap():
+    import torch.nn as nn
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 10)
+    model = CustomModel()
+    def custom_unwrap(m):
+        return m
+    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
+    assert 'linear.weight' in state_dict
+    assert 'linear.bias' in state_dict
+def test_freeze_unfreeze_string_input():
+    model = timm.create_model('resnet18')
+    # Test with string input
+    _freeze_unfreeze(model, 'layer1', mode='freeze')
+    assert model.layer1[0].conv1.weight.requires_grad == False
+    # Test unfreezing with string input
+    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
+    assert model.layer1[0].conv1.weight.requires_grad == True


@@ -38,6 +38,7 @@ class ReaderHfds(Reader):
            input_key: str = 'image',
            target_key: str = 'label',
            download: bool = False,
+            trust_remote_code: bool = False
    ):
        """
        """
@@ -48,6 +49,7 @@ class ReaderHfds(Reader):
            name,  # 'name' maps to path arg in hf datasets
            split=split,
            cache_dir=self.root,  # timm doesn't expect hidden cache dir for datasets, specify a path
+            trust_remote_code=trust_remote_code
        )
        # leave decode for caller, plus we want easy access to original path names...
        self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))
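A hedged usage sketch (not part of this commit): this assumes the `hfds/` reader prefix and that `create_dataset` forwards extra kwargs down to `ReaderHfds`; the dataset id is a placeholder.

```python
# Hedged sketch (not part of this diff): assumes create_dataset forwards extra
# kwargs down to ReaderHfds; the Hub dataset id below is a placeholder.
from timm.data import create_dataset

ds = create_dataset(
    'hfds/some-org/some-image-dataset',  # hypothetical dataset id on the HF Hub
    root='/data/hf_cache',
    split='validation',
    trust_remote_code=True,              # forwarded to datasets.load_dataset(...)
)
```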


@@ -75,9 +75,11 @@ class VisionTransformerDistilled(VisionTransformer):
    def _pos_embed(self, x):
        if self.dynamic_img_size:
            B, H, W, C = x.shape
+            prev_grid_size = self.patch_embed.grid_size
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
-                (H, W),
+                new_size=(H, W),
+                old_size=prev_grid_size,
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)


@@ -560,9 +560,11 @@ class Eva(nn.Module):
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            if self.pos_embed is not None:
+                prev_grid_size = self.patch_embed.grid_size
                pos_embed = resample_abs_pos_embed(
                    self.pos_embed,
-                    (H, W),
+                    new_size=(H, W),
+                    old_size=prev_grid_size,
                    num_prefix_tokens=self.num_prefix_tokens,
                )
            else:


@@ -669,9 +669,11 @@ class VisionTransformer(nn.Module):
        if self.dynamic_img_size:
            B, H, W, C = x.shape
+            prev_grid_size = self.patch_embed.grid_size
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
-                (H, W),
+                new_size=(H, W),
+                old_size=prev_grid_size,
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
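For reference, a minimal sketch (not part of this commit, illustrative shapes) of the `resample_abs_pos_embed` call these models now make with an explicit `old_size`:

```python
# Minimal sketch (illustrative shapes): resample a ViT position embedding to a new grid.
import torch
from timm.layers import resample_abs_pos_embed

pos_embed = torch.randn(1, 1 + 14 * 14, 768)  # class token + 14x14 pretrained grid
resized = resample_abs_pos_embed(
    pos_embed,
    new_size=(12, 20),    # current feature grid (H, W)
    old_size=(14, 14),    # pretrained patch_embed.grid_size, now passed explicitly
    num_prefix_tokens=1,
)
print(resized.shape)      # expected: torch.Size([1, 241, 768])
```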


@@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler):
    def get_cycle_length(self, cycles=0):
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
        else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
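For reference, a minimal sketch (not part of this commit, toy optimizer) of the new `get_cycle_length()` arithmetic when `warmup_prefix` is set:

```python
# Minimal sketch (toy setup): get_cycle_length() now counts the warmup prefix.
import torch
from timm.scheduler import CosineLRScheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
sched = CosineLRScheduler(
    optimizer,
    t_initial=100,
    warmup_t=5,
    warmup_prefix=True,  # warmup runs before, not inside, the 100-epoch cosine cycle
)
assert sched.get_cycle_length() == 100 + 5  # previously this returned 100
```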


@@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler):
    def get_cycle_length(self, cycles=0):
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
        else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t


@@ -196,11 +196,15 @@ def create_scheduler_v2(
    )

    if hasattr(lr_scheduler, 'get_cycle_length'):
-        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # NOTE: Warmup prefix added in get_cycle_lengths() if enabled
        t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
        if step_on_epochs:
            num_epochs = t_with_cycles_and_cooldown
        else:
            num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
+    else:
+        if warmup_prefix:
+            num_epochs += warmup_epochs

    return lr_scheduler, num_epochs
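A hedged usage sketch (not part of this commit, toy model; argument names follow this factory): the returned epoch count now includes the warmup prefix.

```python
# Hedged sketch (toy setup): the second return value is what the training loop should run.
import torch
from timm.scheduler import create_scheduler_v2

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

lr_scheduler, num_epochs = create_scheduler_v2(
    optimizer,
    sched='cosine',
    num_epochs=100,
    warmup_epochs=5,
    cooldown_epochs=10,
    warmup_prefix=True,   # warmup is prepended, so it is added to the total
)
print(num_epochs)  # expected ~115 (100 + 5 warmup + 10 cooldown) with this change
```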


@@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler):
    def get_cycle_length(self, cycles=0):
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
        else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t


@@ -369,7 +369,7 @@ group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                   help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
-                   help='save images of input bathes every log interval for debugging')
+                   help='save images of input batches every log interval for debugging')
group.add_argument('--pin-mem', action='store_true', default=False,
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
@@ -840,8 +840,13 @@ def main():
            lr_scheduler.step(start_epoch)

    if utils.is_primary(args):
+        if args.warmup_prefix:
+            sched_explain = '(warmup_epochs + epochs + cooldown_epochs). Warmup added to total when warmup_prefix=True'
+        else:
+            sched_explain = '(epochs + cooldown_epochs). Warmup within epochs when warmup_prefix=False'
        _logger.info(
-            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
+            f'Scheduled epochs: {num_epochs} {sched_explain}. '
+            f'LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')

    results = []
    try: