Commit Graph

160 Commits (1825b5e314b702d603e3e7a2d02616a8ffd49ea2)

Author SHA1 Message Date
Ross Wightman e861b74cf8 Pass through --model-kwargs (and --opt-kwargs for train) from command line through to model __init__. Update some models to improve arg overlay. Cleanup along the way. 2023-01-06 16:12:33 -08:00
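The `--model-kwargs` pass-through above forwards arbitrary `key=value` pairs from the command line into a model's `__init__`. A minimal sketch of that idea (the `parse_kv_args` helper is hypothetical, not timm's actual parser):

```python
import ast

def parse_kv_args(pairs):
    """Parse 'key=value' strings into a kwargs dict, interpreting values
    as Python literals where possible (a sketch of the idea behind
    --model-kwargs, not timm's actual implementation)."""
    kwargs = {}
    for pair in pairs:
        key, _, value = pair.partition('=')
        try:
            kwargs[key] = ast.literal_eval(value)  # numbers, bools, tuples...
        except (ValueError, SyntaxError):
            kwargs[key] = value  # fall back to the raw string
    return kwargs

# e.g. --model-kwargs drop_path_rate=0.1 act_layer=silu
print(parse_kv_args(['drop_path_rate=0.1', 'act_layer=silu']))
```

The resulting dict can then be splatted into the model constructor, e.g. `create_model(name, **kwargs)`.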
Ross Wightman d5e7d6b27e Merge remote-tracking branch 'origin/main' into refactor-imports 2022-12-09 14:49:44 -08:00
Lorenzo Baraldi 3d6bc42aa1 Put validation loss under amp_autocast
Secures the loss evaluation under AMP, so the loss function does not operate on float16 tensors outside the autocast context.
2022-12-09 12:03:23 +01:00
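The change above moves the validation loss computation inside the autocast context. A minimal sketch of the pattern, using a dummy context manager as a stand-in for `torch.cuda.amp.autocast` so it runs without PyTorch:

```python
from contextlib import contextmanager

# Stand-in for torch.cuda.amp.autocast so this sketch runs without PyTorch;
# everything inside the context counts as "under AMP".
_AMP_ACTIVE = False

@contextmanager
def amp_autocast():
    global _AMP_ACTIVE
    prev = _AMP_ACTIVE
    _AMP_ACTIVE = True
    try:
        yield
    finally:
        _AMP_ACTIVE = prev

def loss_fn(output, target):
    # Records whether it ran under autocast; in real AMP code, a float32
    # loss fed float16 activations outside autocast is the failure mode.
    return ('amp' if _AMP_ACTIVE else 'fp32', abs(output - target))

def validate_step(model, x, target):
    with amp_autocast():
        output = model(x)
        loss = loss_fn(output, target)  # loss now computed inside autocast
    return loss

print(validate_step(lambda x: x * 2, 3, 5))
```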
Ross Wightman 927f031293 Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models 2022-12-06 15:00:06 -08:00
Ross Wightman dbe7531aa3 Update scripts to support torch.compile(). Make --results_file arg more consistent across benchmark/validate/inference. Fix #1570 2022-12-05 10:21:34 -08:00
Ross Wightman 9da7e3a799 Add crop_mode for pretrained config / image transforms. Add support for dynamo compilation to benchmark/train/validate 2022-12-05 10:21:34 -08:00
Ross Wightman 4714a4910e Merge pull request #1525 from TianyiFranklinWang/main (✏️ fix typo) 2022-11-03 20:55:43 -07:00
klae01 ddd6361904 Update train.py (fix typo args.in_chanes) 2022-11-01 16:55:05 +09:00
NPU-Franklin 9152b10478 ✏️ fix typo 2022-10-30 08:49:40 +08:00
hova88 29baf32327 fix typo: missing back quote 2022-10-28 09:30:51 +08:00
Simon Schrodi aceb79e002 Fix typo 2022-10-17 22:06:17 +02:00
Ross Wightman 285771972e Change --amp flags, no more --apex-amp and --native-amp, add --amp-impl to select apex, and --amp-dtype to allow bfloat16 AMP dtype 2022-10-07 15:27:25 -07:00
Ross Wightman b1b024dfed Scheduler update, add v2 factory method, support scheduling on updates instead of just epochs. Add LR to summary csv. Add lr_base scaling calculations to train script. Fix #1168 2022-10-07 10:43:04 -07:00
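The `lr_base` scaling calculation mentioned above follows the common convention of scaling the learning rate with global batch size. A sketch of the idea (the `base_size=256` default and method names are illustrative, not timm's exact interface):

```python
import math

def scale_lr(lr_base, global_batch_size, base_size=256, method='linear'):
    """Scale a base learning rate, specified for 'base_size' samples per
    step, to the actual global batch size (a sketch of the convention,
    not timm's exact code)."""
    ratio = global_batch_size / base_size
    if method == 'sqrt':
        return lr_base * math.sqrt(ratio)  # softer scaling, common for adaptive opts
    return lr_base * ratio  # linear scaling rule

print(scale_lr(0.1, 512))                   # doubles the LR
print(scale_lr(0.1, 1024, method='sqrt'))   # sqrt(4x) = 2x
```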
Ross Wightman b8c8550841 Data improvements. Improve train support for in_chans != 3. Add wds dataset support from bits_and_tpu branch w/ fixes and tweaks. TFDS tweaks. 2022-09-29 16:42:58 -07:00
Ross Wightman 87939e6fab Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed. 2022-09-23 16:08:59 -07:00
Ross Wightman ff6a919cf5 Add --fast-norm arg to benchmark.py, train.py, validate.py 2022-08-25 17:20:46 -07:00
Xiao Wang 11060f84c5 make train.py compatible with torchrun 2022-07-07 14:44:55 -07:00
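torchrun compatibility mostly means reading the rank/world-size environment variables that the launcher sets, rather than relying on a `--local_rank` argument. A sketch of that pattern (real code would also call `torch.distributed.init_process_group`):

```python
import os

def init_distributed_env():
    """Read the environment variables torchrun sets, falling back to
    single-process defaults (a sketch; not timm's actual init code)."""
    rank = int(os.environ.get('RANK', 0))            # global rank
    local_rank = int(os.environ.get('LOCAL_RANK', 0))  # rank on this node
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    return rank, local_rank, world_size

# Simulate what torchrun would export for process 3 of 8, local rank 1:
os.environ.update(RANK='3', LOCAL_RANK='1', WORLD_SIZE='8')
print(init_distributed_env())
```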
Ross Wightman a29fba307d disable dist_bn when sync_bn active 2022-06-24 21:30:17 -07:00
Ross Wightman 879df47c0a Support BatchNormAct2d for sync-bn use. Fix #1254 2022-06-24 14:51:26 -07:00
Ross Wightman 037e5e6c09 Fix #1309, move wandb init after distributed init, only init on rank == 0 process 2022-06-21 12:32:40 -07:00
Jakub Kaczmarzyk 9e12530433 use utils namespace instead of function/classnames
This fixes buggy behavior introduced by
https://github.com/rwightman/pytorch-image-models/pull/1266.

Related to https://github.com/rwightman/pytorch-image-models/pull/1273.
2022-06-12 22:39:41 -07:00
Xiao Wang ca991c1fa5 add --aot-autograd 2022-06-07 18:01:52 -07:00
Ross Wightman fd360ac951 Merge pull request #1266 from kaczmarj/enh/no-star-imports (ENH: replace star imports with imported names in train.py) 2022-05-20 08:55:07 -07:00
Jakub Kaczmarzyk ce5578bc3a replace star imports with imported names 2022-05-18 11:04:10 -04:00
Jakub Kaczmarzyk dcad288fd6 use argparse groups to group arguments 2022-05-18 10:27:33 -04:00
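Argparse groups, as introduced in the commit above, organize related flags under shared headings in `--help` output. A small sketch (the specific flag names and defaults here are illustrative):

```python
import argparse

parser = argparse.ArgumentParser(description='Training')
# Grouping related options makes the --help output easier to scan.
group_data = parser.add_argument_group('Dataset parameters')
group_data.add_argument('--data-dir', default='./data')

group_opt = parser.add_argument_group('Optimizer parameters')
group_opt.add_argument('--lr', type=float, default=0.05)
group_opt.add_argument('--weight-decay', type=float, default=2e-5)

args = parser.parse_args(['--lr', '0.1'])
print(args.lr, args.weight_decay)
```

Groups affect only help formatting; parsed arguments still land on the flat `args` namespace.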
Jakub Kaczmarzyk e1e4c9bbae rm whitespace 2022-05-18 10:17:02 -04:00
han a16171335b fix: change milestones to decay-milestones
- change argparser option `milestone` to `decay-milestone`
2022-05-10 07:57:19 +09:00
han 57a988df30 fix: multistep lr decay epoch bugs
- add milestones arguments
- change decay_epochs to milestones variable
2022-05-06 13:14:43 +09:00
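The two entries above concern a multistep schedule that decays the LR at fixed milestone epochs. The idea can be sketched as (milestone values and the `gamma` default are illustrative):

```python
import bisect

def multistep_lr(lr_base, epoch, milestones=(30, 60), gamma=0.1):
    """Decay the LR by 'gamma' at each milestone epoch (a sketch of a
    multistep decay schedule, not timm's scheduler implementation)."""
    # Number of milestones already passed = number of decay steps applied.
    return lr_base * gamma ** bisect.bisect_right(milestones, epoch)

print([multistep_lr(0.1, e) for e in (0, 29, 30, 61)])
```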
Ross Wightman b049a5c5c6 Merge remote-tracking branch 'origin/master' into norm_norm_norm 2022-03-21 13:41:43 -07:00
Ross Wightman 04db5833eb Merge pull request #986 from hankyul2/master (fix: typo of argument parser desc in train.py) 2022-03-21 12:13:51 -07:00
Ross Wightman 0557c8257d Fix bug introduced in non layer_decay weight_decay application. Remove debug print, fix arg desc. 2022-02-28 17:06:32 -08:00
Ross Wightman 372ad5fa0d Significant model refactor and additions:
* All models updated with revised forward_features / forward_head interface
* Vision transformer and MLP based models consistently output sequence from forward_features (pooling or token selection considered part of 'head')
* WIP param grouping interface to allow consistent grouping of parameters for layer-wise decay across all model types
* Add gradient checkpointing support to a significant % of models, especially popular architectures
* Formatting and interface consistency improvements across models
* layer-wise LR decay impl part of optimizer factory w/ scale support in scheduler
* Poolformer and Volo architectures added
2022-02-28 13:56:23 -08:00
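The forward_features / forward_head split in the refactor above separates feature extraction (no pooling or token selection) from the head (pooling + classifier). A plain-Python stand-in to illustrate the shape of the interface (not a real timm model):

```python
class Model:
    """Illustrative stand-in: forward_features returns the unpooled
    sequence/feature map; forward_head applies pooling and the classifier."""

    def forward_features(self, x):
        # Feature extraction only; pooling is deliberately NOT done here.
        return [xi * 2 for xi in x]

    def forward_head(self, feats):
        # Pooling + 'classifier' live in the head.
        return sum(feats) / len(feats)

    def __call__(self, x):
        return self.forward_head(self.forward_features(x))

m = Model()
print(m.forward_features([1, 2, 3]), m([1, 2, 3]))
```

The benefit of the split is that downstream code can call `forward_features` alone to get consistent, unpooled outputs across CNNs, ViTs, and MLP models.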
Ross Wightman 95cfc9b3e8 Merge remote-tracking branch 'origin/master' into norm_norm_norm 2022-01-25 22:20:45 -08:00
Ross Wightman abc9ba2544 Transitioning default_cfg -> pretrained_cfg. Improving handling of pretrained_cfg source (HF-Hub, files, timm config, etc). Checkpoint handling tweaks. 2022-01-25 21:54:13 -08:00
Ross Wightman f0f9eccda8 Add --fuser arg to train/validate/benchmark scripts to select jit fuser type 2022-01-17 13:54:25 -08:00
Ross Wightman 5ccf682a8f Remove deprecated bn-tf train arg and create_model handler. Add evos/evob models back into fx test filter until norm_norm_norm branch merged. 2022-01-06 18:08:39 -08:00
han ab5ae32f75 fix: typo of argument parser desc in train.py (remove duplicated `of`) 2021-11-24 09:32:05 +09:00
Ross Wightman ba65dfe2c6 Dataset work
* support some torchvision datasets
* improvements to TFDS wrapper for subsplit handling (fix #942), shuffle seed
* add class-map support to train (fix #957)
2021-11-09 22:34:15 -08:00
Ross Wightman cd638d50a5 Merge pull request #880 from rwightman/fixes_bce_regnet (A collection of fixes, model experiments, etc) 2021-10-03 19:37:01 -07:00
Ross Wightman d9abfa48df Make broadcast_buffers disable its own flag for now (needs more testing on interaction with dist_bn) 2021-10-01 13:43:55 -07:00
Ross Wightman 80075b0b8a Add worker_seeding arg to allow selecting old vs updated data loader worker seed for (old) experiment repeatability 2021-09-28 16:37:45 -07:00
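The worker_seeding option above selects between the updated per-worker seeding and the legacy behavior for experiment repeatability. A rough sketch of the distinction (the mode names and the seed derivation here are assumptions for illustration, not timm's actual scheme):

```python
def worker_seed(base_seed, worker_id, mode='all'):
    """Derive an RNG seed for a data loader worker (illustrative sketch):
    'all'  -> each worker gets a distinct seed derived from the base seed;
    'old'  -> legacy behavior where the derivation ignored the worker id."""
    if mode == 'all':
        return base_seed + worker_id
    return base_seed  # legacy: every worker seeded identically

print([worker_seed(42, w) for w in range(3)])
```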
Shoufa Chen 908563d060 fix `use_amp` (fix https://github.com/rwightman/pytorch-image-models/issues/881) 2021-09-26 12:32:22 +08:00
Ross Wightman 0387e6057e Update binary cross ent impl to use thresholding as an option (convert soft targets from mixup/cutmix to 0, 1) 2021-09-23 15:45:39 -07:00
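The thresholding option above converts the soft targets produced by mixup/cutmix back into hard 0/1 labels before the BCE loss. The core transform can be sketched as (the cutoff value is illustrative):

```python
def threshold_targets(soft_targets, threshold=0.2):
    """Convert soft mixup/cutmix targets to hard 0/1 labels by
    thresholding (a sketch of the idea; the cutoff is illustrative)."""
    return [1.0 if t >= threshold else 0.0 for t in soft_targets]

# A mixup'd target vector: the two mixed classes carry soft mass.
print(threshold_targets([0.9, 0.1, 0.35, 0.0]))
```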
Ross Wightman 0639d9a591 Fix updated validation_batch_size fallback 2021-09-02 14:44:53 -07:00
Ross Wightman 5db057dca0 Fix misnamed arg, tweak other train script args for better defaults. 2021-09-02 14:15:49 -07:00
Ross Wightman fb94350896 Update training script and loader factory to allow use of scheduler updates, repeat augment, and bce loss 2021-09-01 17:46:40 -07:00
SamuelGabriel 7c19c35d9f Global instead of local rank. 2021-06-09 19:11:58 +02:00
Ross Wightman e15e68d881 Fix #566, summary.csv writing to pwd on local_rank != 0. Tweak benchmark mem handling to see if it reduces likelihood of 'bad' exceptions on OOM. 2021-04-15 23:03:56 -07:00
Ross Wightman e685618f45 Merge pull request #550 from amaarora/wandb (Wandb Support) 2021-04-15 09:26:35 -07:00
Ross Wightman 7c97e66f7c Remove commented code, add more consistent seed fn 2021-04-12 09:51:36 -07:00