Commit Graph

202 Commits (228e080e39ce5d7599ba91a311b59bbf6fd3f93a)

Author SHA1 Message Date
Josua Rieder cb4cea561a add arguments to the respective argument groups 2025-01-20 10:54:35 -08:00
Josua Rieder 634b68ae50 Fix metavar for --input-size 2025-01-20 10:53:46 -08:00
Ross Wightman eeee38e972 Avoid unnecessary compat break between train script and nearby timm versions w/ dtype addition. 2025-01-08 21:10:15 -08:00
Ross Wightman c173886e75 Merge branch 'main' into caojiaolong-main 2025-01-08 09:11:50 -08:00
Ross Wightman 1969528296 Fix dtype log when default (None) is used w/o AMP 2025-01-07 11:47:22 -08:00
Ross Wightman 92f610c982 Add half-precision (bfloat16, float16) support to train & validate scripts. Should push dtype handling into model factory / pretrained load at some point... 2025-01-07 10:25:14 -08:00
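As context for the half-precision commits above: a minimal, hypothetical sketch of uniform low-precision (non-AMP) inference, where weights and inputs are cast to one dtype up front rather than per-op via autocast. The model name and shapes are illustrative only, not the script's actual flag handling.

```python
import torch
import timm

# Illustrative: cast model weights and inputs to a single low-precision dtype.
# Unlike AMP, where autocast chooses a dtype per-op, everything here runs in bfloat16.
dtype = torch.bfloat16  # or torch.float16 on hardware without bfloat16 support
model = timm.create_model('resnet18', pretrained=False).to(dtype=dtype).eval()
x = torch.randn(2, 3, 224, 224, dtype=dtype)
with torch.no_grad():
    out = model(x)
```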
Jiao-Long Cao 40c19f3939 Add wandb project name argument and allow changing the wandb run name 2025-01-07 16:43:34 +08:00
Ross Wightman 95d903fd87 Merge branch 'main' of github.com:grodino/pytorch-image-models into grodino-dataset_trust_remote 2024-12-06 11:14:26 -08:00
Sina Hajimiri 3a6cc4fb17 Improve wandb logging 2024-11-20 21:04:07 -08:00
Ross Wightman 620cb4f3cb Improve the parsable results dump at end of train, stop excessive output, only display top-10. 2024-11-20 16:47:06 -08:00
Ross Wightman 36b5d1adaa In dist training, update loss running avg every step, only sync on log updates / final. 2024-11-20 16:47:06 -08:00
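The pattern behind the commit above, sketched as a hypothetical helper (assumes an initialized process group; the class name is not from the codebase):

```python
import torch
import torch.distributed as dist

class RunningAvg:
    """Accumulate a loss average locally every step; all-reduce only when logging."""
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, value, n=1):
        # Called every step; no communication here.
        self.sum += value * n
        self.count += n

    def synced(self, device):
        # One collective per log interval (and at the end) instead of one per step.
        t = torch.tensor([self.sum, self.count], device=device, dtype=torch.float64)
        dist.all_reduce(t)
        return (t[0] / t[1]).item()
```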
Ross Wightman ee5f6e76bb An optimizer overhaul: added an improved factory, list_optimizers, a class helper, and info classes with descriptions and arg configs 2024-11-12 20:49:01 -08:00
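Of the names below, only list_optimizers comes straight from the commit message; the two getters are a reading of "class helper" and "info classes" and should be checked against the timm docs before use.

```python
import timm.optim

names = timm.optim.list_optimizers()           # available optimizer names
info = timm.optim.get_optimizer_info('adamw')  # assumed: description + arg config
cls = timm.optim.get_optimizer_class('adamw')  # assumed: underlying optimizer class
```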
Ross Wightman 9f5c279bad Update log to describe scheduling behaviour diff w/ warmup_prefix 2024-11-08 11:01:11 -08:00
Augustin Godinot 2dff16fa58 Add --dataset-trust-remote-code to the train.py and validate.py scripts 2024-11-08 18:15:10 +01:00
Josua Rieder 51ac8d2efb fix typo in train.py: bathes > batches 2024-11-05 08:53:55 -08:00
Ross Wightman 015fbe457a Merge branch 'MengqingCao-npu_support' into device_amp_cleanup 2024-10-18 14:50:44 -07:00
Ross Wightman 1766a01f96 Cleanup some amp related behaviour to better support different (non-cuda) devices 2024-10-18 13:54:16 -07:00
MengqingCao 234f975787 add npu support 2024-10-16 07:13:45 +00:00
Ross Wightman 4d4bdd64a9 Add --torchcompile-mode args to train, validation, inference, benchmark scripts 2024-10-02 15:17:53 -07:00
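For reference, torch.compile's mode argument is a plain string hint, so the new flag presumably just forwards it:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 10))
# Valid modes include 'default', 'reduce-overhead', and 'max-autotune';
# they trade compile time against runtime overhead.
compiled = torch.compile(model, mode='reduce-overhead')
out = compiled(torch.randn(4, 16))
```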
Zirunis 4ed93fce93 Fix LR scheduler help in train.py
The default is, and always has been, the cosine scheduler, yet the help text claimed the default was the step scheduler. Whatever was originally intended, the default should remain cosine for backwards compatibility, so the help text now reflects that.
2024-07-22 23:04:00 +02:00
Tianyi Wang d3ce5a8665 Avoid zero division error 2024-07-15 12:45:46 +10:00
Ross Wightman e25bbfceec Fix #2097 a small typo in train.py 2024-04-10 09:40:14 -07:00
Ross Wightman 5a58f4d3dc Remove test MESA support, no signal that it's helpful so far 2024-02-10 14:38:01 -08:00
Ross Wightman c7ac37693d Add device arg to validate() calls in train.py 2024-02-04 10:14:57 -08:00
Ross Wightman bee0471f91 forward() pass through for ema model, flag for ema warmup, comment about warmup 2024-02-03 16:24:45 -08:00
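A minimal, hypothetical EMA sketch illustrating the two ideas in the commit above, a forward() pass-through and decay warmup. This is not timm's ModelEmaV3 (buffers and device handling omitted), and the warmup curve shown is just one common schedule.

```python
import copy
import torch
import torch.nn as nn

class SimpleEma(nn.Module):
    def __init__(self, model, decay=0.9999):
        super().__init__()
        self.module = copy.deepcopy(model).eval()  # shadow copy holding EMA weights
        self.decay = decay
        self.step = 0

    @torch.no_grad()
    def update(self, model):
        self.step += 1
        # Warm the decay up from ~0 so early EMA weights track the live model
        # closely (one common schedule; the actual flag may use a different curve).
        d = min(self.decay, (1 + self.step) / (10 + self.step))
        for ema_p, p in zip(self.module.parameters(), model.parameters()):
            ema_p.lerp_(p, 1.0 - d)  # ema = d * ema + (1 - d) * p

    def forward(self, *args, **kwargs):
        # forward() passes through to the EMA weights, as the commit describes.
        return self.module(*args, **kwargs)
```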
Ross Wightman 5e4a4b2adc Merge branch 'device_flex' into mesa_ema 2024-02-02 09:45:30 -08:00
Ross Wightman dd84ef2cd5 ModelEmaV3 and MESA experiments 2024-02-02 09:45:04 -08:00
Ross Wightman 809a9e14e2 Pass train-crop-mode to create_loader/transforms from train.py args 2024-01-24 16:19:02 -08:00
Ross Wightman a48ab818f5 Improving device flexibility in train. Fix #2081 2024-01-20 15:10:20 -08:00
lorenzbaraldi 8c663c4b86 Fixed index out of range in case of resume 2024-01-12 23:33:32 -08:00
Ross Wightman c50004db79 Allow training w/o validation split set 2024-01-08 09:38:42 -08:00
Ross Wightman be0944edae Significant transforms, dataset, dataloading enhancements. 2024-01-08 09:38:42 -08:00
Ross Wightman b5a4fa9c3b Add pos_weight and support for summing over classes to BCE impl in train scripts 2023-12-30 12:13:06 -08:00
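The pos_weight and sum-over-classes options map onto standard BCEWithLogitsLoss usage; the values below are illustrative:

```python
import torch
import torch.nn as nn

logits = torch.randn(8, 5)                     # batch of 8, 5 classes
targets = torch.randint(0, 2, (8, 5)).float()  # multi-label 0/1 targets

# pos_weight up-weights positive examples, one weight per class.
bce = nn.BCEWithLogitsLoss(pos_weight=torch.full((5,), 2.0), reduction='none')

# Sum-over-classes style: sum per-class losses, then average over the batch.
loss = bce(logits, targets).sum(dim=-1).mean()
```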
Ross Wightman f2fdd97e9f Add parsable json results output for train.py, tweak --pretrained-path to force head adaptation 2023-12-22 11:18:25 -08:00
Ross Wightman 60b170b200 Add --pretrained-path arg to train script to allow passing local checkpoint as pretrained. Add missing/unexpected keys log. 2023-12-11 12:10:29 -08:00
Ross Wightman a83e9f2d3b forward & backward in same no_sync context, slightly easier to read than splitting 2023-04-20 08:14:05 -07:00
Ross Wightman 4cd7fb88b2 clip gradients with update 2023-04-19 23:36:20 -07:00
Ross Wightman df81d8d85b Cleanup gradient accumulation, fix a few issues, a few other small cleanups in related code. 2023-04-19 23:11:00 -07:00
Ross Wightman ab7ca62a6e Merge branch 'main' of github.com:rwightman/pytorch-image-models into wip-voidbag-accumulate-grad 2023-04-19 11:08:12 -07:00
Ross Wightman ec6cca4b37 Add head-init-scale and head-init-bias args that works for all models, fix #1718 2023-04-14 17:59:23 -07:00
Ross Wightman 43e6143bef Fix #1712 broken support for AMP w/ PyTorch < 1.10. Disable loss scaler for bfloat16 2023-03-11 15:26:09 -08:00
Taeksang Kim 7f29a46d44 Add gradient accumulation option to train.py
option: iters-to-accum (iterations to accumulate)

Gradient accumulation improves training throughput (samples/s). It can reduce the frequency of gradient synchronization between nodes, which helps when the network is the bottleneck.

Signed-off-by: Taeksang Kim <voidbag@puzzle-ai.com>
2023-02-06 09:24:48 +09:00
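Putting the accumulation commits together (the no_sync forward/backward from a83e9f2d3b and clip-with-update from 4cd7fb88b2): a self-contained sketch, with hypothetical names and a toy in-memory loader standing in for the real script.

```python
import contextlib
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()
loader = [(torch.randn(8, 16), torch.randn(8, 4)) for _ in range(8)]
accum_steps = 4  # illustrative value for the accumulation option
is_ddp = isinstance(model, nn.parallel.DistributedDataParallel)

for i, (inputs, targets) in enumerate(loader):
    update_step = (i + 1) % accum_steps == 0
    # Under DDP, no_sync() skips the gradient all-reduce on non-update steps;
    # this is the network-bandwidth saving the commit message describes.
    ctx = model.no_sync() if (is_ddp and not update_step) else contextlib.nullcontext()
    with ctx:  # forward & backward inside the same context
        loss = criterion(model(inputs), targets) / accum_steps
        loss.backward()
    if update_step:
        clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip only with the update
        optimizer.step()
        optimizer.zero_grad()
```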
Fredo Guan 81ca323751 Davit update formatting and fix grad checkpointing (#7)
Fixed head to gap->norm->fc as per ConvNeXt, along with an option for norm->gap->fc.
Tests failed due to CLIP ConvNeXt models; DaViT tests passed.
2023-01-15 14:34:56 -08:00
Ross Wightman d5e7d6b27e Merge remote-tracking branch 'origin/main' into refactor-imports 2022-12-09 14:49:44 -08:00
Lorenzo Baraldi 3d6bc42aa1 Put validation loss under amp_autocast
Moves the loss evaluation inside the AMP autocast context, so the loss function no longer operates on raw float16 tensors outside autocast.
2022-12-09 12:03:23 +01:00
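The gist of that fix, sketched with the current torch.autocast API (CPU/bfloat16 here so the snippet runs anywhere; names are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))

# Compute the validation loss inside the same autocast context as the forward
# pass, so the criterion never receives unhandled half-precision tensors.
with torch.no_grad(), torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    loss = criterion(model(x), y)
```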
Ross Wightman 927f031293 Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models 2022-12-06 15:00:06 -08:00
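The user-visible effect of that restructure, for code importing layers directly (the old path was kept as a deprecation shim at the time):

```python
# Old path, deprecated by the restructure:
#   from timm.models.layers import DropPath, trunc_normal_
# New canonical path:
from timm.layers import DropPath, trunc_normal_
```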
Ross Wightman dbe7531aa3 Update scripts to support torch.compile(). Make --results_file arg more consistent across benchmark/validate/inference. Fix #1570 2022-12-05 10:21:34 -08:00
Ross Wightman 9da7e3a799 Add crop_mode for pretrained config / image transforms. Add support for dynamo compilation to benchmark/train/validate 2022-12-05 10:21:34 -08:00
Ross Wightman 4714a4910e Merge pull request #1525 from TianyiFranklinWang/main (✏️ fix typo) 2022-11-03 20:55:43 -07:00
klae01 ddd6361904 Update train.py: fix typo args.in_chanes 2022-11-01 16:55:05 +09:00