Josua Rieder
8d81fdf3d9
Fix typos
2025-01-19 13:39:40 -08:00
Ross Wightman
deb9895600
Update checkpoint save to fix old hard-link + fuse issue I ran into again... fix #340
2025-01-08 15:36:58 -08:00
Ross Wightman
3bef09f831
Tweak a few docstrings
2024-11-13 10:12:31 -08:00
Ross Wightman
015fbe457a
Merge branch 'MengqingCao-npu_support' into device_amp_cleanup
2024-10-18 14:50:44 -07:00
Ross Wightman
1766a01f96
Cleanup some amp related behaviour to better support different (non-cuda) devices
2024-10-18 13:54:16 -07:00
MengqingCao
234f975787
add npu support
2024-10-16 07:13:45 +00:00
Ross Wightman
0cbf4fa586
_orig_mod still causing issues even though I thought it was fixed in pytorch, add unwrap / clean helpers
2024-07-19 11:03:45 -07:00
Ross Wightman
e748805be3
Add regex matching support to AttentionExtract. Add return_dict support to graph extractors and use returned output in AttentionExtractor
2024-05-22 14:33:39 -07:00
Ross Wightman
44f72c04b3
Change node/module name matching for AttentionExtract so it keeps outputs in order. #1232
2024-05-22 13:45:25 -07:00
Ross Wightman
e57625e814
Tweak dist_backend to use device_type (before possible :)
2024-05-15 08:49:25 -07:00
Setepenre
8848dad362
Update distributed.py
2024-05-13 16:55:42 -04:00
Ross Wightman
07535f408a
Add AttentionExtract helper module
2024-05-04 14:10:00 -07:00
Ross Wightman
24f6d4f7f8
Fix #2127 move to ema device
2024-04-10 21:29:09 -07:00
Ross Wightman
ba641e07ae
Add support for dynamo based onnx export
2024-03-13 12:05:26 -07:00
Ross Wightman
47c9bc4dc6
Fix device idx split
2024-02-10 21:41:14 -08:00
Ross Wightman
a08b57e801
Fix distributed flag bug w/ flex device handling
2024-02-03 16:26:15 -08:00
Ross Wightman
bee0471f91
forward() pass through for ema model, flag for ema warmup, comment about warmup
2024-02-03 16:24:45 -08:00
Ross Wightman
5e4a4b2adc
Merge branch 'device_flex' into mesa_ema
2024-02-02 09:45:30 -08:00
Ross Wightman
dd84ef2cd5
ModelEmaV3 and MESA experiments
2024-02-02 09:45:04 -08:00
Ross Wightman
d0ff315eed
Merge remote-tracking branch 'emav3/faster_ema' into mesa_ema
2024-01-27 14:52:10 -08:00
Ross Wightman
a48ab818f5
Improving device flexibility in train. Fix #2081
2024-01-20 15:10:20 -08:00
Ross Wightman
c50004db79
Allow training w/o validation split set
2024-01-08 09:38:42 -08:00
Ross Wightman
5242ba6edc
MobileOne and FastViT weights on HF hub, more code cleanup and tweaks, features_only working. Add reparam flag to validate and benchmark, support reparm of all models with fuse(), reparameterize() or switch_to_deploy() methods on modules
2023-08-23 22:50:37 -07:00
Ross Wightman
4cd7fb88b2
clip gradients with update
2023-04-19 23:36:20 -07:00
Ross Wightman
df81d8d85b
Cleanup gradient accumulation, fix a few issues, a few other small cleanups in related code.
2023-04-19 23:11:00 -07:00
Ross Wightman
ab7ca62a6e
Merge branch 'main' of github.com:rwightman/pytorch-image-models into wip-voidbag-accumulate-grad
2023-04-19 11:08:12 -07:00
Ross Wightman
aef6e562e4
Add onnx utils and export code, tweak padding and conv2d_same for better dynamic export with recent PyTorch
2023-04-11 17:03:57 -07:00
Taeksang Kim
7f29a46d44
Add gradient accumulation option to train.py
...
option: iters-to-accum(iterations to accmulate)
Gradient accumulation improves training performance(samples/s).
It can reduce the number of parameter sharing between each node.
This option can be helpful when network is bottleneck.
Signed-off-by: Taeksang Kim <voidbag@puzzle-ai.com>
2023-02-06 09:24:48 +09:00
Fredo Guan
81ca323751
Davit update formatting and fix grad checkpointing ( #7 )
...
fixed head to gap->norm->fc as per convnext, along with option for norm->gap->fc
failed tests due to clip convnext models, davit tests passed
2023-01-15 14:34:56 -08:00
Jerome Rony
3491506fec
Add foreach option for faster EMA
2022-11-30 14:06:58 -05:00
Jerome Rony
6ec5cd6a99
Use in-place operations for EMA
2022-11-17 11:53:29 -05:00
Ross Wightman
b1b024dfed
Scheduler update, add v2 factory method, support scheduling on updates instead of just epochs. Add LR to summary csv. Add lr_base scaling calculations to train script. Fix #1168
2022-10-07 10:43:04 -07:00
Ross Wightman
87939e6fab
Refactor device handling in scripts, distributed init to be less 'cuda' centric. More device args passed through where needed.
2022-09-23 16:08:59 -07:00
Ross Wightman
0dbd9352ce
Add bulk_runner script and updates to benchmark.py and validate.py for better error handling in bulk runs (used for benchmark and validation result runs). Improved batch size decay stepping on retry...
2022-07-18 18:04:54 -07:00
Ross Wightman
324a4e58b6
disable nvfuser for jit te/legacy modes (for PT 1.12+)
2022-07-13 10:34:34 -07:00
Ross Wightman
2f2b22d8c7
Disable nvfuser fma / opt level overrides per #1244
2022-05-13 09:27:13 -07:00
jjsjann123
f88c606fcf
fixing channels_last on cond_conv2d; update nvfuser debug env variable
2022-04-25 12:41:46 -07:00
Ross Wightman
f0f9eccda8
Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
2022-01-17 13:54:25 -08:00
Ross Wightman
57992509f9
Fix some formatting in utils/model.py
2021-10-23 20:35:36 -07:00
Ross Wightman
e5da481073
Small post-merge tweak for freeze/unfreeze, add to __init__ for utils
2021-10-06 17:00:27 -07:00
Alexander Soare
431e60c83f
Add acknowledgements for freeze_batch_norm inspiration
2021-10-06 14:28:49 +01:00
Alexander Soare
65c3d78b96
Freeze unfreeze functionality finalized. Tests added
2021-10-02 15:55:08 +01:00
Alexander Soare
0cb8ea432c
wip
2021-10-02 15:55:08 +01:00
Ross Wightman
d667351eac
Tweak accuracy topk safety. Fix #807
2021-08-19 14:18:53 -07:00
Yohann Lereclus
35c9740826
Fix accuracy when topk > num_classes
2021-08-19 11:58:59 +02:00
Ross Wightman
e685618f45
Merge pull request #550 from amaarora/wandb
...
Wandb Support
2021-04-15 09:26:35 -07:00
Ross Wightman
7c97e66f7c
Remove commented code, add more consistent seed fn
2021-04-12 09:51:36 -07:00
Aman Arora
5772c55c57
Make wandb optional
2021-04-10 01:34:20 -04:00
Aman Arora
f54897cc0b
make wandb not required but rather optional as huggingface_hub
2021-04-10 01:27:23 -04:00
Aman Arora
3f028ebc0f
import wandb in summary.py
2021-04-08 03:48:51 -04:00