From 7e5477acf5a58a9a74c73d0ee7528ef7ce1da766 Mon Sep 17 00:00:00 2001 From: Josua Rieder Date: Mon, 4 Nov 2024 12:22:06 +0100 Subject: [PATCH 01/10] Replace deprecated positional argument with --data-dir --- bulk_runner.py | 2 +- hfdocs/source/changes.mdx | 12 ++++++------ hfdocs/source/training_script.mdx | 22 +++++++++++----------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/bulk_runner.py b/bulk_runner.py index 286059c2..cb8bc7d1 100755 --- a/bulk_runner.py +++ b/bulk_runner.py @@ -7,7 +7,7 @@ Benchmark all 'vit*' models: python bulk_runner.py --model-list 'vit*' --results-file vit_bench.csv benchmark.py --amp -b 512 Validate all models: -python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py /imagenet/validation/ --amp -b 512 --retry +python bulk_runner.py --model-list all --results-file val.csv --pretrained validate.py --data-dir /imagenet/validation/ --amp -b 512 --retry Hacked together by Ross Wightman (https://github.com/rwightman) """ diff --git a/hfdocs/source/changes.mdx b/hfdocs/source/changes.mdx index ddcf2dc5..014fa0da 100644 --- a/hfdocs/source/changes.mdx +++ b/hfdocs/source/changes.mdx @@ -228,7 +228,7 @@ Datasets & transform refactoring * Enabling either dynamic mode will break FX tracing unless PatchEmbed module added as leaf. * Existing method of resizing position embedding by passing different `img_size` (interpolate pretrained embed weights once) on creation still works. * Existing method of changing `patch_size` (resize pretrained patch_embed weights once) on creation still works. - * Example validation cmd `python validate.py /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dyamic_img_pad=True` + * Example validation cmd `python validate.py --data-dir /imagenet --model vit_base_patch16_224 --amp --amp-dtype bfloat16 --img-size 255 --crop-pct 1.0 --model-kwargs dynamic_img_size=True dyamic_img_pad=True` ### Aug 25, 2023 * Many new models since last release @@ -245,8 +245,8 @@ Datasets & transform refactoring ### Aug 11, 2023 * Swin, MaxViT, CoAtNet, and BEiT models support resizing of image/window size on creation with adaptation of pretrained weights -* Example validation cmd to test w/ non-square resize `python validate.py /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320` - +* Example validation cmd to test w/ non-square resize `python validate.py --data-dir /imagenet --model swin_base_patch4_window7_224.ms_in22k_ft_in1k --amp --amp-dtype bfloat16 --input-size 3 256 320 --model-kwargs window_size=8,10 img_size=256,320` + ### Aug 3, 2023 * Add GluonCV weights for HRNet w18_small and w18_small_v2. Converted by [SeeFun](https://github.com/seefun) * Fix `selecsls*` model naming regression @@ -385,7 +385,7 @@ Datasets & transform refactoring * Refactor LeViT models to stages, add `features_only=True` support to new `conv` variants, weight remap required. * Move ImageNet meta-data (synsets, indices) from `/results` to [`timm/data/_info`](timm/data/_info/). 
* Add ImageNetInfo / DatasetInfo classes to provide labelling for various ImageNet classifier layouts in `timm` - * Update `inference.py` to use, try: `python inference.py /folder/to/images --model convnext_small.in12k --label-type detail --topk 5` + * Update `inference.py` to use, try: `python inference.py --data-dir /folder/to/images --model convnext_small.in12k --label-type detail --topk 5` * Ready for 0.8.10 pypi pre-release (final testing). ### Jan 20, 2023 @@ -449,8 +449,8 @@ Datasets & transform refactoring ### Jan 6, 2023 * Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line - * `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu` - * `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12` + * `train.py --data-dir /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu` + * `train.py --data-dir /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12` * Cleanup some popular models to better support arg passthrough / merge with model configs, more to go. ### Jan 5, 2023 diff --git a/hfdocs/source/training_script.mdx b/hfdocs/source/training_script.mdx index 3eb772a3..a3641033 100644 --- a/hfdocs/source/training_script.mdx +++ b/hfdocs/source/training_script.mdx @@ -12,7 +12,7 @@ The variety of training args is large and not all combinations of options (or ev To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value: ```bash -./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4 +./distributed_train.sh 4 --data-dir /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4 ``` @@ -27,13 +27,13 @@ Validation and inference scripts are similar in usage. 
One outputs metrics on a To validate with the model's pretrained weights (if they exist): ```bash -python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained +python validate.py --data-dir /imagenet/validation/ --model seresnext26_32x4d --pretrained ``` To run inference from a checkpoint: ```bash -python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar +python inference.py --data-dir /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar ``` ## Training Examples @@ -43,7 +43,7 @@ python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkp These params are for dual Titan RTX cards with NVIDIA Apex installed: ```bash -./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016 +./distributed_train.sh 2 --data-dir /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016 ``` ### MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5 @@ -51,7 +51,7 @@ These params are for dual Titan RTX cards with NVIDIA Apex installed: This params are for dual Titan RTX cards with NVIDIA Apex installed: ```bash -./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce +./distributed_train.sh 2 --data-dir /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce ``` ### SE-ResNeXt-26-D and SE-ResNeXt-26-T @@ -59,7 +59,7 @@ This params are for dual Titan RTX cards with NVIDIA Apex installed: These hparams (or similar) work well for a wide range of ResNet architecture, generally a good idea to increase the epoch # as the model size increases... ie approx 180-200 for ResNe(X)t50, and 220+ for larger. Increase batch size and LR proportionally for better GPUs or with AMP enabled. These params were for 2 1080Ti cards: ```bash -./distributed_train.sh 2 /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112 +./distributed_train.sh 2 --data-dir /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112 ``` ### EfficientNet-B3 with RandAugment - 81.5 top-1, 95.7 top-5 @@ -70,26 +70,26 @@ The training of this model started with the same command line as EfficientNet-B2 [Michael Klachko](https://github.com/michaelklachko) achieved these results with the command line for B2 adapted for larger batch size, with the recommended B0 dropout rate of 0.2. 
```bash -./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048 +./distributed_train.sh 2 --data-dir /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048 ``` ### ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5 Trained on two older 1080Ti cards, this took a while. Only slightly, non statistically better ImageNet validation result than my first good AugMix training of 78.99. However, these weights are more robust on tests with ImageNetV2, ImageNet-Sketch, etc. Unlike my first AugMix runs, I've enabled SplitBatchNorm, disabled random erasing on the clean split, and cranked up random erasing prob on the 2 augmented paths. ```bash -./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce +./distributed_train.sh 2 --data-dir /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce ``` ### EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5 Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model EMA was not used, final checkpoint is the average of 8 best checkpoints during training. 
 ```bash
-./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
+./distributed_train.sh 8 --data-dir /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
 ```
 
 ### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5
 
 ```bash
-./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
+./distributed_train.sh 2 --data-dir /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
 ```
 
 ### ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5
@@ -97,5 +97,5 @@ These params will also work well for SE-ResNeXt-50 and SK-ResNeXt-50 and likely
 
 
 ```bash
-./distributed_train.sh 8 /imagenet --model resnext50_32x4d --lr 0.6 --warmup-epochs 5 --epochs 240 --weight-decay 1e-4 --sched cosine --reprob 0.4 --recount 3 --remode pixel --aa rand-m7-mstd0.5-inc1 -b 192 -j 6 --amp --dist-bn reduce
+./distributed_train.sh 8 --data-dir /imagenet --model resnext50_32x4d --lr 0.6 --warmup-epochs 5 --epochs 240 --weight-decay 1e-4 --sched cosine --reprob 0.4 --recount 3 --remode pixel --aa rand-m7-mstd0.5-inc1 -b 192 -j 6 --amp --dist-bn reduce
 ```

From 51ac8d2efb926c6b7c34eeb1dc52bcf57999e2de Mon Sep 17 00:00:00 2001
From: Josua Rieder
Date: Mon, 4 Nov 2024 11:44:33 +0100
Subject: [PATCH 02/10] fix typo in train.py: bathes > batches

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index a82fa0a8..7aab20a7 100755
--- a/train.py
+++ b/train.py
@@ -368,7 +368,7 @@ group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
 group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                    help='how many training processes to use (default: 4)')
 group.add_argument('--save-images', action='store_true', default=False,
-                   help='save images of input bathes every log interval for debugging')
+                   help='save images of input batches every log interval for debugging')
 group.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
 group.add_argument('--no-prefetcher', action='store_true', default=False,

From 3ae3f44288a45a79e68476cd60b90d7fa271925d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Wojtek=20Jasi=C5=84ski?=
Date: Fri, 1 Nov 2024 19:46:46 +0100
Subject: [PATCH 03/10] Fix positional embedding resampling for non-square
 inputs in ViT

---
 timm/models/vision_transformer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 8bc09e94..b3b0ddca 100644
--- 
a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -669,9 +669,11 @@ class VisionTransformer(nn.Module): if self.dynamic_img_size: B, H, W, C = x.shape + prev_grid_size = self.patch_embed.grid_size pos_embed = resample_abs_pos_embed( self.pos_embed, - (H, W), + new_size=(H, W), + old_size=prev_grid_size, num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) x = x.view(B, -1, C) From 3c7822c621b3cb765935441eb949ee6aa2e3ae72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojtek=20Jasi=C5=84ski?= Date: Fri, 1 Nov 2024 23:24:13 +0100 Subject: [PATCH 04/10] fix pos embed dynamic resampling for deit --- timm/models/deit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index 63662c02..0072013b 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -75,9 +75,11 @@ class VisionTransformerDistilled(VisionTransformer): def _pos_embed(self, x): if self.dynamic_img_size: B, H, W, C = x.shape + prev_grid_size = self.patch_embed.grid_size pos_embed = resample_abs_pos_embed( self.pos_embed, - (H, W), + new_size=(H, W), + old_size=prev_grid_size, num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, ) x = x.view(B, -1, C) From eb94efb21817eb68629cfab63083bfa34ddddd70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wojtek=20Jasi=C5=84ski?= Date: Fri, 1 Nov 2024 23:27:06 +0100 Subject: [PATCH 05/10] fix pos embed dynamic resampling for eva --- timm/models/eva.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/eva.py b/timm/models/eva.py index 62e986ba..fe871540 100644 --- a/timm/models/eva.py +++ b/timm/models/eva.py @@ -560,9 +560,11 @@ class Eva(nn.Module): if self.dynamic_img_size: B, H, W, C = x.shape if self.pos_embed is not None: + prev_grid_size = self.patch_embed.grid_size pos_embed = resample_abs_pos_embed( self.pos_embed, - (H, W), + new_size=(H, W), + old_size=prev_grid_size, num_prefix_tokens=self.num_prefix_tokens, ) else: From 7f0c1b1f301e7005d46b02707e688b0c4368282b Mon Sep 17 00:00:00 2001 From: Augustin Godinot Date: Fri, 8 Nov 2024 15:50:55 +0100 Subject: [PATCH 06/10] Add trust_remote_code argument to ReaderHfds --- timm/data/readers/reader_hfds.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timm/data/readers/reader_hfds.py b/timm/data/readers/reader_hfds.py index 77846606..b2054472 100644 --- a/timm/data/readers/reader_hfds.py +++ b/timm/data/readers/reader_hfds.py @@ -38,6 +38,7 @@ class ReaderHfds(Reader): input_key: str = 'image', target_key: str = 'label', download: bool = False, + trust_remote_code: bool = False ): """ """ @@ -48,6 +49,7 @@ class ReaderHfds(Reader): name, # 'name' maps to path arg in hf datasets split=split, cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path + trust_remote_code=trust_remote_code ) # leave decode for caller, plus we want easy access to original path names... 
         self.dataset = self.dataset.cast_column(input_key, datasets.Image(decode=False))

From 363b043c13a66628f86ddcc53ad84cbd65dc2c15 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 7 Nov 2024 10:24:06 -0800
Subject: [PATCH 07/10] Extend train epoch schedule by warmup_epochs if
 warmup_prefix enabled, allows schedule to reach end w/ prefix enabled

---
 timm/scheduler/cosine_lr.py         | 5 +++--
 timm/scheduler/poly_lr.py           | 5 +++--
 timm/scheduler/scheduler_factory.py | 6 +++++-
 timm/scheduler/tanh_lr.py           | 5 +++--
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py
index 4eaaa86a..00dd9357 100644
--- a/timm/scheduler/cosine_lr.py
+++ b/timm/scheduler/cosine_lr.py
@@ -111,6 +111,7 @@ class CosineLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
\ No newline at end of file
diff --git a/timm/scheduler/poly_lr.py b/timm/scheduler/poly_lr.py
index 8875e15b..f7971302 100644
--- a/timm/scheduler/poly_lr.py
+++ b/timm/scheduler/poly_lr.py
@@ -107,6 +107,7 @@ class PolyLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t
diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py
index caf68fad..08c5e180 100644
--- a/timm/scheduler/scheduler_factory.py
+++ b/timm/scheduler/scheduler_factory.py
@@ -196,11 +196,15 @@ def create_scheduler_v2(
     )
 
     if hasattr(lr_scheduler, 'get_cycle_length'):
-        # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # For cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
+        # NOTE: Warmup prefix added in get_cycle_length() if enabled
         t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
         if step_on_epochs:
             num_epochs = t_with_cycles_and_cooldown
         else:
             num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
+    else:
+        if warmup_prefix:
+            num_epochs += warmup_epochs
 
     return lr_scheduler, num_epochs
 
diff --git a/timm/scheduler/tanh_lr.py b/timm/scheduler/tanh_lr.py
index 94455302..93222926 100644
--- a/timm/scheduler/tanh_lr.py
+++ b/timm/scheduler/tanh_lr.py
@@ -108,6 +108,7 @@ class TanhLRScheduler(Scheduler):
     def get_cycle_length(self, cycles=0):
         cycles = max(1, cycles or self.cycle_limit)
         if self.cycle_mul == 1.0:
-            return self.t_initial * cycles
+            t = self.t_initial * cycles
         else:
-            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+            t = int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
+        return t + self.warmup_t if self.warmup_prefix else t

From 9f5c279bad296447a0d896ac564846da677f93b2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 8 Nov 2024 10:57:29 -0800
Subject: [PATCH 08/10] Update log to describe
scheduling behaviour diff w/ warmup_prefix --- train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index 7aab20a7..5179e31d 100755 --- a/train.py +++ b/train.py @@ -832,8 +832,13 @@ def main(): lr_scheduler.step(start_epoch) if utils.is_primary(args): + if args.warmup_prefix: + sched_explain = '(warmup_epochs + epochs + cooldown_epochs). Warmup added to total when warmup_prefix=True' + else: + sched_explain = '(epochs + cooldown_epochs). Warmup within epochs when warmup_prefix=False' _logger.info( - f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.') + f'Scheduled epochs: {num_epochs} {sched_explain}. ' + f'LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.') results = [] try: From 68d5a64e45f3304bf87231fc0f924f6c44bcfecb Mon Sep 17 00:00:00 2001 From: Tal Date: Sun, 10 Nov 2024 06:57:39 +0000 Subject: [PATCH 09/10] extend existing unittests --- tests/test_layers.py | 45 +++++++++++++- tests/test_optim.py | 81 ++++++++++++++++++++++++- tests/test_utils.py | 137 ++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 260 insertions(+), 3 deletions(-) diff --git a/tests/test_layers.py b/tests/test_layers.py index 92f6b683..2cc8420a 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from timm.layers import create_act_layer, set_layer_config +from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn import importlib import os @@ -76,3 +76,46 @@ def test_hard_swish_grad(): def test_hard_mish_grad(): for _ in range(100): _run_act_layer_grad('hard_mish') + +def test_get_act_layer_empty_string(): + # Empty string should return None + assert get_act_layer('') is None + + +def test_create_act_layer_inplace_error(): + class NoInplaceAct(nn.Module): + def __init__(self): + super().__init__() + def forward(self, x): + return x + + # Should recover when inplace arg causes TypeError + layer = create_act_layer(NoInplaceAct, inplace=True) + assert isinstance(layer, NoInplaceAct) + + +def test_create_act_layer_edge_cases(): + # Test None input + assert create_act_layer(None) is None + + # Test TypeError handling for inplace + class CustomAct(nn.Module): + def __init__(self, **kwargs): + super().__init__() + def forward(self, x): + return x + + result = create_act_layer(CustomAct, inplace=True) + assert isinstance(result, CustomAct) + + +def test_get_act_fn_callable(): + def custom_act(x): + return x + assert get_act_fn(custom_act) is custom_act + + +def test_get_act_fn_none(): + assert get_act_fn(None) is None + assert get_act_fn('') is None + diff --git a/tests/test_optim.py b/tests/test_optim.py index 38f625fb..1ad8baea 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import TestCase from torch.nn import Parameter from timm.scheduler import PlateauLRScheduler -from timm.optim import create_optimizer_v2 +from timm.optim import create_optimizer_v2, param_groups_layer_decay, param_groups_weight_decay import importlib import os @@ -741,3 +741,82 @@ def test_lookahead_radam(optimizer): lambda params: create_optimizer_v2(params, optimizer, lr=1e-4) ) + +def test_param_groups_layer_decay_with_end_decay(): + model = torch.nn.Sequential( + torch.nn.Linear(10, 5), + torch.nn.ReLU(), + torch.nn.Linear(5, 2) + ) + + param_groups = param_groups_layer_decay( + model, + weight_decay=0.05, + layer_decay=0.75, + end_layer_decay=0.5, + 
verbose=True + ) + + assert len(param_groups) > 0 + # Verify layer scaling is applied with end decay + for group in param_groups: + assert 'lr_scale' in group + assert group['lr_scale'] <= 1.0 + assert group['lr_scale'] >= 0.5 + + +def test_param_groups_layer_decay_with_matcher(): + class ModelWithMatcher(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer1 = torch.nn.Linear(10, 5) + self.layer2 = torch.nn.Linear(5, 2) + + def group_matcher(self, coarse=False): + return lambda name: int(name.split('.')[0][-1]) + + model = ModelWithMatcher() + param_groups = param_groups_layer_decay( + model, + weight_decay=0.05, + layer_decay=0.75, + verbose=True + ) + + assert len(param_groups) > 0 + # Verify layer scaling is applied + for group in param_groups: + assert 'lr_scale' in group + assert 'weight_decay' in group + assert len(group['params']) > 0 + + +def test_param_groups_weight_decay(): + model = torch.nn.Sequential( + torch.nn.Linear(10, 5), + torch.nn.ReLU(), + torch.nn.Linear(5, 2) + ) + weight_decay = 0.01 + no_weight_decay_list = ['1.weight'] + + param_groups = param_groups_weight_decay( + model, + weight_decay=weight_decay, + no_weight_decay_list=no_weight_decay_list + ) + + assert len(param_groups) == 2 + assert param_groups[0]['weight_decay'] == 0.0 + assert param_groups[1]['weight_decay'] == weight_decay + + # Verify parameters are correctly grouped + no_decay_params = set(param_groups[0]['params']) + decay_params = set(param_groups[1]['params']) + + for name, param in model.named_parameters(): + if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list: + assert param in no_decay_params + else: + assert param in decay_params + diff --git a/tests/test_utils.py b/tests/test_utils.py index b0f890d2..1e2126ee 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,8 +2,15 @@ from torch.nn.modules.batchnorm import BatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d import timm +import pytest from timm.utils.model import freeze, unfreeze +from timm.utils.model import ActivationStatsHook +from timm.utils.model import extract_spp_stats +from timm.utils.model import _freeze_unfreeze +from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual +from timm.utils.model import reparameterize_model +from timm.utils.model import get_state_dict def test_freeze_unfreeze(): model = timm.create_model('resnet18') @@ -54,4 +61,132 @@ def test_freeze_unfreeze(): freeze(model.layer1[0], ['bn1']) assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) unfreeze(model.layer1[0], ['bn1']) - assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + +def test_activation_stats_hook_validation(): + model = timm.create_model('resnet18') + + def test_hook(model, input, output): + return output.mean().item() + + # Test error case with mismatched lengths + with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"): + ActivationStatsHook( + model, + hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'], + hook_fns=[test_hook] + ) + + +def test_extract_spp_stats(): + model = timm.create_model('resnet18') + + def test_hook(model, input, output): + return output.mean().item() + + stats = extract_spp_stats( + model, + hook_fn_locs=['layer1.0.conv1'], + hook_fns=[test_hook], + input_shape=[2, 3, 32, 32] + ) + + assert isinstance(stats, dict) + assert test_hook.__name__ in stats + assert isinstance(stats[test_hook.__name__], list) + assert 
len(stats[test_hook.__name__]) > 0 + +def test_freeze_unfreeze_bn_root(): + import torch.nn as nn + from timm.layers import BatchNormAct2d + + # Create batch norm layers + bn = nn.BatchNorm2d(10) + bn_act = BatchNormAct2d(10) + + # Test with BatchNorm2d as root + with pytest.raises(AssertionError): + _freeze_unfreeze(bn, mode="freeze") + + # Test with BatchNormAct2d as root + with pytest.raises(AssertionError): + _freeze_unfreeze(bn_act, mode="freeze") + + +def test_activation_stats_functions(): + import torch + + # Create sample input tensor [batch, channels, height, width] + x = torch.randn(2, 3, 4, 4) + + # Test avg_sq_ch_mean + result1 = avg_sq_ch_mean(None, None, x) + assert isinstance(result1, float) + + # Test avg_ch_var + result2 = avg_ch_var(None, None, x) + assert isinstance(result2, float) + + # Test avg_ch_var_residual + result3 = avg_ch_var_residual(None, None, x) + assert isinstance(result3, float) + + +def test_reparameterize_model(): + import torch.nn as nn + + class FusableModule(nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + + def fuse(self): + return nn.Identity() + + class ModelWithFusable(nn.Module): + def __init__(self): + super().__init__() + self.fusable = FusableModule() + self.normal = nn.Linear(10, 10) + + model = ModelWithFusable() + + # Test with inplace=False (should create a copy) + new_model = reparameterize_model(model, inplace=False) + assert isinstance(new_model.fusable, nn.Identity) + assert isinstance(model.fusable, FusableModule) # Original unchanged + + # Test with inplace=True + reparameterize_model(model, inplace=True) + assert isinstance(model.fusable, nn.Identity) + + +def test_get_state_dict_custom_unwrap(): + import torch.nn as nn + + class CustomModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 10) + + model = CustomModel() + + def custom_unwrap(m): + return m + + state_dict = get_state_dict(model, unwrap_fn=custom_unwrap) + assert 'linear.weight' in state_dict + assert 'linear.bias' in state_dict + + +def test_freeze_unfreeze_string_input(): + model = timm.create_model('resnet18') + + # Test with string input + _freeze_unfreeze(model, 'layer1', mode='freeze') + assert model.layer1[0].conv1.weight.requires_grad == False + + # Test unfreezing with string input + _freeze_unfreeze(model, 'layer1', mode='unfreeze') + assert model.layer1[0].conv1.weight.requires_grad == True + From e31e5d2d6430d8cd14664d88f680aa8d62ad4726 Mon Sep 17 00:00:00 2001 From: mrT23 Date: Mon, 11 Nov 2024 08:00:05 +0200 Subject: [PATCH 10/10] imports --- tests/test_optim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_optim.py b/tests/test_optim.py index 1ad8baea..66aaadbf 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -11,9 +11,11 @@ from copy import deepcopy import torch from torch.testing._internal.common_utils import TestCase from torch.nn import Parameter + +from timm.optim.optim_factory import param_groups_layer_decay, param_groups_weight_decay from timm.scheduler import PlateauLRScheduler -from timm.optim import create_optimizer_v2, param_groups_layer_decay, param_groups_weight_decay +from timm.optim import create_optimizer_v2 import importlib import os
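For context on PATCH 07: with `warmup_prefix` enabled, `get_cycle_length()` now counts the warmup steps, so the factory extends the total scheduled epochs by `warmup_epochs` instead of compressing the decay into fewer steps. Below is a minimal illustrative sketch of that behaviour; it is not part of the changeset and assumes the `CosineLRScheduler` keyword arguments `t_initial`, `warmup_t`, and `warmup_prefix` as found in current timm.

```python
# Illustrative sketch only; assumes CosineLRScheduler(optimizer, t_initial=...,
# warmup_t=..., warmup_prefix=...) as in current timm. Adjust if the signature differs.
import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

# With warmup_prefix=True, the 5 warmup epochs are prepended, so after PATCH 07
# the reported cycle length is t_initial + warmup_t and the cosine decay can
# reach its end point instead of being cut short.
sched_prefix = CosineLRScheduler(optimizer, t_initial=100, warmup_t=5, warmup_prefix=True)
print(sched_prefix.get_cycle_length())  # expected: 105 with the patch (100 before it)

# Without the prefix, warmup is counted inside t_initial and the length is unchanged.
sched_no_prefix = CosineLRScheduler(optimizer, t_initial=100, warmup_t=5, warmup_prefix=False)
print(sched_no_prefix.get_cycle_length())  # expected: 100
```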