Ross Wightman
eeee38e972
Avoid unecessary compat break btw train script and nearby timm versions w/ dtype addition.
2025-01-08 21:10:15 -08:00
Ross Wightman
1969528296
Fix dtype log when default (None) is used w/o AMP
2025-01-07 11:47:22 -08:00
Ross Wightman
92f610c982
Add half-precision (bfloat16, float16) support to train & validate scripts. Should push dtype handling into model factory / pretrained load at some point...
2025-01-07 10:25:14 -08:00
Augustin Godinot
2dff16fa58
Add --dataset-trust-remote-code to the train.py and validate.py scripts
2024-11-08 18:15:10 +01:00
Ross Wightman
c3992d5c4c
Remove extra space
2024-10-18 14:54:16 -07:00
Ross Wightman
015fbe457a
Merge branch 'MengqingCao-npu_support' into device_amp_cleanup
2024-10-18 14:50:44 -07:00
Ross Wightman
1766a01f96
Cleanup some amp related behaviour to better support different (non-cuda) devices
2024-10-18 13:54:16 -07:00
MengqingCao
37c731ca37
fix device check
2024-10-17 12:38:02 +00:00
MengqingCao
234f975787
add npu support
2024-10-16 07:13:45 +00:00
Ross Wightman
4d4bdd64a9
Add --torchcompile-mode args to train, validation, inference, benchmark scripts
2024-10-02 15:17:53 -07:00
Ross Wightman
be0944edae
Significant transforms, dataset, dataloading enhancements.
2024-01-08 09:38:42 -08:00
Ross Wightman
5242ba6edc
MobileOne and FastViT weights on HF hub, more code cleanup and tweaks, features_only working. Add reparam flag to validate and benchmark, support reparm of all models with fuse(), reparameterize() or switch_to_deploy() methods on modules
2023-08-23 22:50:37 -07:00
Lorenzo Baraldi
13d5b21ecd
Changed help_string of --worker
...
It seems like 4 is the correct default value
2023-06-01 17:27:51 +02:00
Ross Wightman
b3e816d6d7
Improve filtering behaviour for tag + non-tagged model wildcard consistency.
2023-03-22 10:21:22 -07:00
Ross Wightman
3448cc689c
Use gather (fancy indexing) for valid labels instead of bool mask in validate.py
2023-03-18 15:08:19 -07:00
Ross Wightman
e861b74cf8
Pass through --model-kwargs (and --opt-kwargs for train) from command line through to model __init__. Update some models to improve arg overlay. Cleanup along the way.
2023-01-06 16:12:33 -08:00
Ross Wightman
d5e7d6b27e
Merge remote-tracking branch 'origin/main' into refactor-imports
2022-12-09 14:49:44 -08:00
Lorenzo Baraldi
3d6bc42aa1
Put validation loss under amp_autocast
...
Secured the loss evaluation under the amp, avoiding function to operate on float16
2022-12-09 12:03:23 +01:00
Ross Wightman
927f031293
Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models
2022-12-06 15:00:06 -08:00
Ross Wightman
dbe7531aa3
Update scripts to support torch.compile(). Make --results_file arg more consistent across benchmark/validate/inference. Fix #1570
2022-12-05 10:21:34 -08:00
Ross Wightman
9da7e3a799
Add crop_mode for pretraind config / image transforms. Add support for dynamo compilation to benchmark/train/validate
2022-12-05 10:21:34 -08:00
Ross Wightman
eb333d6641
Update validate.py to use updated amp args for impl/dtype
2022-10-14 15:51:20 -07:00
Ross Wightman
87939e6fab
Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed.
2022-09-23 16:08:59 -07:00
Ross Wightman
ff6a919cf5
Add --fast-norm arg to benchmark.py, train.py, validate.py
2022-08-25 17:20:46 -07:00
Ross Wightman
56596e4e84
jit trace comparisons snuck into torchscript part of validate.py, fixed
2022-07-31 21:13:56 -07:00
Ross Wightman
0dbd9352ce
Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...
2022-07-18 18:04:54 -07:00
Ross Wightman
7c7ecd2492
Add --use-train-size flag to force use of train input_size (over test input size) for validation. Default test-time pooling to use train input size (fixes issues).
2022-07-07 22:01:24 -07:00
Ross Wightman
500c190860
Add --aot-autograd (functorch efficient mem fusion) support to validate.py
2022-07-07 15:15:25 -07:00
Yonghye Kwon
57f8361a01
fix a function parameter typo(cropt_pct -> crop_pct)
2022-05-25 00:36:28 +09:00
Ross Wightman
95cfc9b3e8
Merge remote-tracking branch 'origin/master' into norm_norm_norm
2022-01-25 22:20:45 -08:00
Ross Wightman
abc9ba2544
Transitioning default_cfg -> pretrained_cfg. Improving handling of pretrained_cfg source (HF-Hub, files, timm config, etc). Checkpoint handling tweaks.
2022-01-25 21:54:13 -08:00
Ross Wightman
cf4334391e
Update benchmark and validate scripts to output results in JSON with a fixed delimiter for use in multi-process launcher
2022-01-24 14:46:47 -08:00
Ross Wightman
f0f9eccda8
Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
2022-01-17 13:54:25 -08:00
Ross Wightman
b669f4a588
Add ConvNeXt 22k->1k fine-tuned and 384 22k-1k fine-tuned weights after testing
2022-01-15 15:44:36 -08:00
Ross Wightman
ba65dfe2c6
Dataset work
...
* support some torchvision datasets
* improvements to TFDS wrapper for subsplit handling (fix #942 ), shuffle seed
* add class-map support to train (fix #957 )
2021-11-09 22:34:15 -08:00
Ross Wightman
e0b3a3fab3
Make test-pooling flag for validate.py opt in
2021-10-06 16:12:20 -07:00
Ross Wightman
5d6983c462
Batch validate a list of files if model is a text file with model per line
2021-09-23 15:45:17 -07:00
Ross Wightman
a04427d8ce
Add _in22k to bulk validate filter
2021-04-17 12:18:03 -07:00
Ross Wightman
a5310a3451
Merge remote-tracking branch 'origin/benchmark-fixes-vit_hybrids' into pit_and_vit_update
2021-04-01 12:15:34 -07:00
Ross Wightman
d584e7f617
Support for huggingface hub via create_model and default_cfgs.
...
* improve consistency of model creation helper fns
* add comments to some of the model helpers
* support passing external default_cfgs so they can be sourced from hub
2021-03-16 22:48:26 -07:00
Ross Wightman
2db2d87ff7
Add epoch-repeats arg to multiply the number of dataset passes per epoch. Currently for iterable datasets (read TFDS wrapper) only.
2021-02-23 17:31:42 -08:00
Ross Wightman
d8e69206be
Merge pull request #419 from rwightman/byob_vgg_models
...
More models, GPU-Efficient Nets, RepVGG, classic VGG, and flexible Byob backbone.
2021-02-10 15:44:09 -08:00
Ross Wightman
0356e773f5
Default to native PyTorch AMP instead of APEX amp. Too many APEX issues cropping up lately.
2021-02-10 14:31:18 -08:00
Ross Wightman
b4e216e377
Fix a few small things.
2021-02-09 17:33:43 -08:00
Csaba Kertesz
5114c214fc
Change the Python interpreter to Python 3.x in the scripts
2021-02-09 21:20:28 +02:00
Ross Wightman
2a8c4dc63b
Add validation script update for using test_input_size in model default_cfgs
2021-02-07 21:35:50 -08:00
Ross Wightman
38d8f67570
Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg)
2021-01-25 11:53:34 -08:00
Ross Wightman
855d6cc217
More dataset work including factories and a tensorflow datasets (TFDS) wrapper
...
* Add parser/dataset factory methods for more flexible dataset & parser creation
* Add dataset parser that wraps TFDS image classification datasets
* Tweak num_classes handling bug for 21k models
* Add initial deit models so they can be benchmarked in next csv results runs
2021-01-15 17:26:20 -08:00
Ross Wightman
59ec7e6a53
Merge branch 'master' into imagenet21k_datasets_more
2021-01-04 12:11:05 -08:00
Csaba Kertesz
e42b140ade
Add --input-size option to scripts to specify full input dimensions from command-line
2021-01-04 00:25:29 +02:00