Commit Graph

497 Commits (709d7c07e8e7d0c0d4b2077304642cbd73e00240)

Author SHA1 Message Date
Ross Wightman de97be9146 Spell out diff between my small and deit small vit models. 2021-02-23 16:22:55 -08:00
Ross Wightman f0ffdf89b3 Add numerous experimental ViT Hybrid models w/ ResNetV2 base. Update the ViT naming for hybrids. Fix #426 for pretrained vit resizing. 2021-02-23 15:54:55 -08:00
Ross Wightman 0e16d4e9fb Add benchmark.py script, and update optimizer factory to be more friendly to use outside of argparse interface. 2021-02-23 15:38:12 -08:00
Ross Wightman 4bc103f504 Fix CUDA crash w/ channels-last + CSP models. Remove use of chunk() 2021-02-23 13:15:52 -08:00
Ross Wightman 8563609b28 Update notes in ScaledStdConv impl 2021-02-18 12:44:08 -08:00
Ross Wightman 678ba4e0a2 Add NFNet-F model weights ported from DeepMind Haiku impl and new set of models w/ compatible config. 2021-02-18 12:28:46 -08:00
Ross Wightman 9de2ec5e44 Update README for AGC and bump version to 0.4.4 2021-02-16 09:13:03 -08:00
Ross Wightman 4f49b94311 Initial AGC impl. Still testing. 2021-02-15 23:22:44 -08:00
Ross Wightman 5f9aff395c Fix stem width in NFNet-F models, add some more comments, add some 'light' NFNet models for testing. 2021-02-13 16:58:51 -08:00
Ross Wightman d86dbe45c2 Update README.md and few more comments 2021-02-12 22:07:18 -08:00
Ross Wightman 0d253e2c5e Fix issue with nfnet tests, bit more cleanup. 2021-02-12 21:05:41 -08:00
Ross Wightman cb06c7a910 Add NFNet-F models and tweak existing NF models. 2021-02-12 18:28:56 -08:00
Ross Wightman e4de077021 Add first 'Normalizer Free' models. nf_regnet_b1 79.3 @ 288x288 test, and nf_resnet50 80.3 @ 256x256 test (80.68 @ 288x288). 2021-02-11 13:20:11 -08:00
Ross Wightman d8e69206be
Merge pull request #419 from rwightman/byob_vgg_models
More models, GPU-Efficient Nets, RepVGG, classic VGG, and flexible Byob backbone.
2021-02-10 15:44:09 -08:00
Ross Wightman ca9b078ac7 Update README.md and docs. Version bumped to 0.4.3 2021-02-10 14:46:07 -08:00
Ross Wightman 6853b07bbd Improve RegVGG block identity/vs non for clariy and fix attn usage. Add comments. 2021-02-10 14:40:29 -08:00
Ross Wightman 0356e773f5 Default to native PyTorch AMP instead of APEX amp. Too many APEX issues cropping up lately. 2021-02-10 14:31:18 -08:00
Reuben 94ca140b67 update collections.abc import 2021-02-10 23:54:35 +11:00
Ross Wightman b4e216e377 Fix a few small things. 2021-02-09 17:33:43 -08:00
Ross Wightman dc85e5a237 Add ByobNet w/ GPU-EfficientNets and RepVGG. Also add classic vgg models. 2021-02-09 16:22:52 -08:00
Ross Wightman 1bcc69e0ad Use in_channels for depthwise groups, allows using `out_channels=N * in_channels` (does not impact existing models). Fix #354. 2021-02-09 16:22:52 -08:00
Ross Wightman 9811e229f7 Fix regression in models with 1001 class pretrained weights. Improve batchnorm arg and BatchNormAct layer handling in several models. 2021-02-09 16:22:52 -08:00
Ross Wightman a39c3ee216
Merge branch 'master' into eca-weights 2021-02-08 11:52:31 -08:00
Ross Wightman e9d6fe293c Update README for new weights. Version 0.4.2 2021-02-08 11:51:16 -08:00
Ross Wightman 666de85cf1 Move stride in EdgeResidual block to 3x3 expansion conv. Fix #414 2021-02-07 22:10:18 -08:00
Ross Wightman 3b57490a63 Fix some half removed resnet model defs, pooling for ecaresnet269d 2021-02-07 22:09:25 -08:00
Ross Wightman 68a4144882 Add new weights for ecaresnet26t/50t/269d models. Remove distinction between 't' and 'tn' (tiered models), tn is now t. Add test time img size spec to default cfg. 2021-02-06 16:30:02 -08:00
Ross Wightman b9843f954b
Merge pull request #282 from tigert1998/patch-1
Add symbolic for SwishJitAutoFn to support onnx
2021-02-04 12:18:40 -08:00
hwangdeyu 7a4be5c035 add operator HardSwishJitAutoFn export to onnx 2021-02-03 09:06:53 +08:00
Ross Wightman 4203efa36d Fix #387 so that checkpoint saver works with max history of 1. Add checkpoint-hist arg to train.py. 2021-01-31 20:14:51 -08:00
Ross Wightman f0e65e37b7 Fix NF-ResNet101 model defs 2021-01-30 23:26:19 -08:00
Ross Wightman 2c988c3b6e Update README.md for NF-nets, bump version to 0.4.1 for merge 2021-01-30 23:19:45 -08:00
Ross Wightman 2de54d174a Fix pool size defs for NFNet models, add a comment. 2021-01-30 18:02:33 -08:00
Ross Wightman 90980de4a9 Fix up a few details in NFResNet models, managed stable training. Add support for gamma gain to be applied in activation or ScaleStdConv. Some tweaks to ScaledStdConv. 2021-01-30 16:32:07 -08:00
Ross Wightman 5a8e1e643e Initial Normalizer-Free Reg/ResNet impl. A bit of related layer refactoring. 2021-01-27 22:06:57 -08:00
Ross Wightman 38d8f67570 Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg) 2021-01-25 11:53:34 -08:00
Ross Wightman 587780e56b Update README.md and bump version to 0.4.0 2021-01-25 11:22:11 -08:00
Ross Wightman bb50ac4708 Add DeiT distilled weights and distilled model def. Remove some redudant ViT model args. 2021-01-25 11:05:23 -08:00
Ross Wightman c16e965037 Add some ViT comments and fix a few minor issues. 2021-01-24 23:18:35 -08:00
Ross Wightman 22748f1a2d Convert samples/targets in ParserImageInTar to numpy arrays, slightly less mem usage for massive datasets. Add a few more se/eca model defs to resnet.py 2021-01-22 16:54:33 -08:00
Ross Wightman 5d4c3d0af3 Add enhanced ParserImageInTar that can read images from tars within tars, folders with multiple tars, etc. Additional comment cleanup. 2021-01-22 10:52:04 -08:00
Ross Wightman 55f7dfa9ea Refactor vision_transformer entrpy fns, add pos embedding resize support for fine tuning, add some deit models for testing 2021-01-18 16:11:02 -08:00
Ross Wightman d55bcc0fee Finishing adding stochastic depth support to BiT ResNetV2 models 2021-01-16 16:32:03 -08:00
Ross Wightman 855d6cc217 More dataset work including factories and a tensorflow datasets (TFDS) wrapper
* Add parser/dataset factory methods for more flexible dataset & parser creation
* Add dataset parser that wraps TFDS image classification datasets
* Tweak num_classes handling bug for 21k models
* Add initial deit models so they can be benchmarked in next csv results runs
2021-01-15 17:26:20 -08:00
Ross Wightman 20516abc18 Fix some broken tests for ResNetV2 BiT models 2021-01-04 23:21:39 -08:00
Ross Wightman 59ec7e6a53 Merge branch 'master' into imagenet21k_datasets_more 2021-01-04 12:11:05 -08:00
Ross Wightman e7a9ddf982
Merge pull request #334 from kecsap/links
Follow symbolic links during dataset scanning
2021-01-04 10:30:58 -08:00
Csaba Kertesz 7cae7e7035 Follow links during dataset scanning 2021-01-04 00:16:45 +02:00
Ross Wightman c96e9f99a0 Update version to 0.3.3 2021-01-03 12:43:44 -08:00
Ross Wightman 4e2533db77 Add 320x320 model default cfgs for 101D and 152D ResNets. Add SEResNet-152D weights and 320x320 cfg. 2021-01-03 12:10:25 -08:00
Ross Wightman 0167f749d3 Remove some old __future__ imports 2021-01-03 11:24:16 -08:00
Ross Wightman e35e9760a6 More work on dataset / parser split and imagenet21k (tar) support 2020-12-28 16:59:15 -08:00
Ross Wightman ce69de70d3 Add 21k weight urls to vision_transformer. Cleanup feature_info for preact ResNetV2 (BiT) models 2020-12-28 16:59:15 -08:00
Ross Wightman 231d04e91a ResNetV2 pre-act and non-preact model, w/ BiT pretrained weights and support for ViT R50 model. Tweaks for in21k num_classes passing. More to do... tests failing. 2020-12-28 16:59:15 -08:00
Ross Wightman de6046e213 Initial commit for dataset / parser reorg to support additional datasets / types 2020-12-28 16:59:15 -08:00
Ross Wightman 392595c7eb Add pool_size to default cfgs for new models to prevent tests from failing. Add explicit 200D_320 model entrypoint for next benchmark run. 2020-12-18 21:28:47 -08:00
Ross Wightman b1f1228a41 Add ResNet101D, 152D, and 200D weights, remove meh 66d model 2020-12-18 17:13:37 -08:00
Jasha 7c56c718f3 Configure create_optimizer with args.opt_args
Closes #301
2020-12-08 00:03:09 -06:00
Ross Wightman 9a25fdf3ad
Merge pull request #297 from rwightman/ema_simplify
Simplified JIT compatible Ema module. Fixes for SiLU export and torchscript training w/ Linear layer.
2020-12-05 11:42:45 -08:00
Tymoteusz Wiśniewski de15b43865 Fix a bug with accuracy retrieving from RealLabels 2020-12-04 16:12:50 +01:00
Ross Wightman cd72e66eff Bug in last mod for features_only default_cfg 2020-12-03 12:33:01 -08:00
Ross Wightman 867a0e5a04 Add default_cfg back to models wrapped in feature extraction module as per discussion in #294. 2020-12-03 10:24:35 -08:00
Ross Wightman 4ca52d73d8 Add separate set and update method to ModelEmaV2 2020-12-03 10:05:09 -08:00
Ross Wightman 2ed8f24715 A few more changes for 0.3.2 maint release. Linear layer change for mobilenetv3 and inception_v3, support no bias for linear wrapper. 2020-11-30 16:19:52 -08:00
Ross Wightman 6504a42832 Version 0.3.2 2020-11-30 13:39:08 -08:00
Ross Wightman 460eba7f24 Work around casting issue with combination of native torch AMP and torchscript for Linear layers 2020-11-30 13:30:51 -08:00
Ross Wightman 5f4b6076d8 Fix inplace arg compat for GELU and PreLU via activation factory 2020-11-30 13:27:40 -08:00
Ross Wightman fd962c4b4a Native SiLU (Swish) op doesn't export to ONNX 2020-11-29 21:56:55 -08:00
Ross Wightman 27bbc70d71 Add back old ModelEma and rename new one to ModelEmaV2 to avoid compat breaks in dependant code. Shuffle train script, add a few comments, remove DataParallel support, support experimental torchscript training. 2020-11-29 16:22:19 -08:00
tigertang 43f2500c26
Add symbolic for SwishJitAutoFn to support onnx 2020-11-18 14:36:12 +08:00
Ross Wightman 9214ca0716 Simplifying EMA... 2020-11-16 12:51:52 -08:00
Ross Wightman 53aeed3499 ver 0.3.1 2020-10-31 18:14:58 -07:00
Ross Wightman 30ab4a1494 Fix issue in optim factory with sgd / eps flag. Bump version to 0.3.1 2020-10-31 18:05:30 -07:00
Ross Wightman 741572dc9d Bump version to 0.3.0 for pending PyPi push 2020-10-29 17:31:39 -07:00
Ross Wightman b401952caf Add newly added vision transformer large/base 224x224 weights ported from JAX official repo 2020-10-29 17:31:01 -07:00
Ross Wightman 61200db0ab in_chans=1 working w/ pretrained weights for vision_transformer 2020-10-29 15:49:36 -07:00
Ross Wightman e90edce438 Support native silu activation (aka swish). An optimized ver is available in PyTorch 1.7. 2020-10-29 15:45:17 -07:00
Ross Wightman da6cd2cc1f Fix regression for pretrained classifier loading when using entrypt functions directly 2020-10-29 15:43:39 -07:00
Ross Wightman f591e90b0d Make sure num_features attr is present in vit models as with others 2020-10-29 15:33:47 -07:00
Ross Wightman 4a3df7842a Fix topn metric view regression on PyTorch 1.7 2020-10-29 14:04:15 -07:00
Ross Wightman f944242cb0 Fix #262, num_classes arg mixup. Make vision_transformers a bit closer to other models wrt get/reset classfier/forward_features. Fix torchscript for ViT. 2020-10-29 13:58:28 -07:00
Ross Wightman 736f209e7d Update vision transformers to be compatible with official code. Port official ViT weights from jax impl. 2020-10-26 18:42:11 -07:00
Ross Wightman 477a78ed81 Fix optimizer factory regressin for optimizers like sgd/momentum that don't have an eps arg 2020-10-22 15:59:47 -07:00
Ross Wightman 27a93e9de7 Improve test crop for ViT models. Small now 77.85, added base weights at 79.35 top-1. 2020-10-21 23:35:25 -07:00
Ross Wightman d4db9e7977 Add small vision transformer weights. 77.42 top-1. 2020-10-21 12:14:12 -07:00
talrid 27fadaa922 asymmetric_loss 2020-10-16 17:12:28 +03:00
Ross Wightman f31933cb37 Initial Vision Transformer impl w/ patch and hybrid variants. Refactor tuple helpers. 2020-10-13 13:33:44 -07:00
Ross Wightman a4d8fea61e Add model based wd skip support. Improve cross version compat of optimizer factory. Fix #247 2020-10-13 12:49:47 -07:00
Ross Wightman 80078c47bb Add Adafactor and Adahessian optimizers, cleanup optimizer arg passing, add gradient clipping support. 2020-10-09 17:24:43 -07:00
Ross Wightman fcb6258877 Add missing leaky_relu layer factory defn, update Apex/Native loss scaler interfaces to support unscaled grad clipping. Bump ver to 0.2.2 for pending release. 2020-10-02 16:19:39 -07:00
Ross Wightman e8e2d9cabf Add DropPath (stochastic depth) to ReXNet and VoVNet. RegNet DropPath impl tweak and dedupe se args. 2020-09-24 18:20:36 -07:00
Ross Wightman e8ca45854c More models in sotabench, more control over sotabench run, dataset filename extraction consistency 2020-09-24 15:56:57 -07:00
Ross Wightman 9c406532bd Add EfficientNet-EdgeTPU-M (efficientnet_em) model trained natively in PyTorch. More sotabench fiddling. 2020-09-23 17:12:07 -07:00
Ross Wightman c40384f5bd Add ResNet weights. 80.5 (top-1) ResNet-50-D, 77.1 ResNet-34-D, 72.7 ResNet-18-D. 2020-09-18 12:05:37 -07:00
Ross Wightman 47a7b3b5b1 More flexible mixup mode, add 'half' mode. 2020-09-07 20:03:06 -07:00
Ross Wightman 532e3b417d Reorg of utils into separate modules 2020-09-07 13:58:09 -07:00
Ross Wightman 33f8a1bf36 Updated README, add wide_resnet50_2 and seresnext50_32x4d weights 2020-09-03 10:45:17 -07:00
Ross Wightman 751b0bba98 Add global_pool (--gp) arg changes to allow passing 'fast' easily for train/validate to avoid channels_last issue with AdaptiveAvgPool 2020-09-02 16:13:47 -07:00
Ross Wightman 9c297ec67d Cleanup Apex vs native AMP scaler state save/load. Cleanup CheckpointSaver a bit. 2020-09-02 15:12:59 -07:00
Ross Wightman 80c9d9cc72 Add 'fast' global pool option, remove redundant SEModule from tresnet, normal one is now 'fast' 2020-09-02 09:11:48 -07:00
Ross Wightman 90a01f47d1 hrnet features_only pretrained weight loading issue. Fix #232. 2020-09-01 17:37:55 -07:00
Ross Wightman 110a7c4982 AdaptiveAvgPool2d -> mean((2,3)) for all SE/attn layers to avoid NaN with AMP + channels_last layout. See https://github.com/pytorch/pytorch/issues/43992 2020-09-01 16:05:32 -07:00
Ross Wightman c2cd1a332e Improve torch amp support and add channels_last support for train/validate scripts 2020-08-31 17:58:16 -07:00
Ross Wightman 470220b1f4 Fix MobileNetV3 crash with global_pool='', output consistent with other models but not equivalent due to efficient head. 2020-08-18 14:11:30 -07:00
Ross Wightman fc8b8afb6f Fix a silly bug in Sample version of EvoNorm missing x* part of swish, update EvoNormBatch to accumulated unbiased variance. 2020-08-13 18:25:01 -07:00
Ross Wightman f614df7921 Bump version to 0.2.1 and update README 2020-08-12 18:05:07 -07:00
Ross Wightman b423bc8362
Merge pull request #218 from rwightman/cutmix
CutMix + MixUp overhaul
2020-08-12 17:17:41 -07:00
Ross Wightman 8c9814e3f5 Final cleanup of mixup/cutmix. Element/batch modes working with both collate (prefetcher active) and without prefetcher. 2020-08-12 17:01:32 -07:00
Ross Wightman 0f5d9d8166 Add CSPResNet50 weights, 79.6 top-1 at 256x256 2020-08-12 11:20:11 -07:00
Ross Wightman b1b6e7c361 Fix a few more issues related to #216 w/ TResNet (space2depth) and FP16 weights in wide resnets. Also don't completely dump pretrained weights in in_chans != 1 or 3 cases. 2020-08-11 18:57:47 -07:00
Ross Wightman 512b2dd645 Add new EfficientNet-B3 and RegNetY-3.2GF weights, both just over 82 top-1 2020-08-11 14:18:51 -07:00
Ross Wightman 6890300877 Add DropPath (stochastic depth) to RegNet 2020-08-11 14:08:53 -07:00
Ross Wightman cd23f55397 Fix mixed prec issues with new mixup code 2020-08-11 12:17:43 -07:00
Yusuke Uchida f6b56602f9 fix test_model_default_cfgs 2020-08-11 23:23:57 +09:00
Ross Wightman f471c17c9d More cutmix/mixup overhaul, ready to kick-off some trials. 2020-08-11 00:10:33 -07:00
Ross Wightman d5145fa4d5 Change default_cfg names for senet to include the legacy and match model names 2020-08-08 11:12:58 -07:00
Ross Wightman 92f2d0d65d Merge branch 'master' into cutmix. Fixup a few issues. 2020-08-07 15:59:52 -07:00
Ross Wightman 1696499ce5 Bump version to 0.2.0, ready to roll (I think) 2020-08-05 16:55:18 -07:00
Ross Wightman e62758cf4f More documentation updates, fix a typo 2020-08-05 15:59:31 -07:00
Ross Wightman dfe80414a6 Add bool arg helper 2020-08-05 13:17:23 -07:00
Ross Wightman fa28067704 Add more augmentation arguments, including a no_aug disable flag. Fix #209 2020-08-05 13:16:44 -07:00
Ross Wightman b1f1a54de9 More uniform treatment of classifiers across all models, reduce code duplication. 2020-08-03 22:18:24 -07:00
Ross Wightman d72ddafe56 Fix some checkpoint / model str regressions 2020-07-29 19:43:01 -07:00
Ross Wightman ac18adb9c3 Remove debug print from RexNet 2020-07-29 11:15:19 -07:00
Ross Wightman c53ec33ae0 Add synset/label indices for results generation. Add 'valid labels' to validation script to support imagenet-a/r label subsets properly. 2020-07-29 00:58:57 -07:00
Ross Wightman ec4976fdba Add EfficientNet-Lite0 weights trained with this code by @hal-314, 75.484 top-1 2020-07-29 00:32:08 -07:00
Ross Wightman 9ecd16bd7b Add new seresnet50 (non-legacy) model weights, 80.274 top-1 2020-07-29 00:17:42 -07:00
Ross Wightman 7995295968 Merge branch 'logger' into features. Change 'logger' to '_logger'. 2020-07-27 18:00:46 -07:00
Ross Wightman 1998bd3180 Merge branch 'feature/AB/logger' of https://github.com/antoinebrl/pytorch-image-models into logger 2020-07-27 16:06:01 -07:00
Ross Wightman 6c17d57a2c Fix some attributions, add copyrights to some file docstrings 2020-07-27 13:44:56 -07:00
Ross Wightman a69c0e04f0 Fix pool size in cspnet 2020-07-27 13:44:02 -07:00
Ross Wightman 14ef7a0dd6 Rename csp.py -> cspnet.py 2020-07-27 11:15:07 -07:00
Ross Wightman ec37008432 Add pretrained weight links to CSPNet for cspdarknet53, cspresnext50 2020-07-27 11:13:21 -07:00
Sangdoo Yun e93e571f7a Add `adamp` and 'sgdp' optimizers.
Update requirements.txt

Update optim_factory.py

Add `adamp` optimizer

Update __init__.py

copy files of adamp & sgdp

Create adamp.py

Update __init__.py

Create sgdp.py

Update optim_factory.py

Update optim_factory.py

Update requirements.txt

Update adamp.py

Update sgdp.py

Update sgdp.py

Update adamp.py
2020-07-25 15:33:20 -07:00
Ross Wightman 08016e839d Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights 2020-07-24 17:59:21 -07:00
Ross Wightman 7ba5a384d3 Add ReXNet w/ remapped weights, feature support 2020-07-23 10:28:57 -07:00
Ross Wightman c9d54bc1c3 Add HRNet feature extraction, fix senet type, lower feature testing res to 96x96 2020-07-21 17:39:29 -07:00
Ross Wightman 2ac663f340 Add feature support to legacy senets, add 32x32 resnext models to exclude list for feature testing. 2020-07-21 11:15:30 -07:00
Ross Wightman c146b54abc Cleanup EfficientNet/MobileNetV3 feature extraction a bit, only two tap locations now, small mobilenetv3 models work 2020-07-21 01:21:38 -07:00
Ross Wightman 68fd8a267b Merge branch 'master' into features 2020-07-20 16:11:38 -07:00
Ross Wightman 4e61c6a12d Cleanup, refactoring of Feature extraction code, add tests, fix tests, non hook feature extraction working with torchscript 2020-07-20 16:10:31 -07:00
Ross Wightman 6eec3fb4a4 Move FeatureHooks into features.py, switch EfficientNet, MobileNetV3 to use build model helper 2020-07-19 15:00:43 -07:00
Ross Wightman 9eba134d79 More models supporting feature extraction, xception, gluon_xception, inception_v3, inception_v4, pnasnet, nasnet, dla. Fix DLA unused projection params. 2020-07-19 14:02:02 -07:00
Ross Wightman 298fba09ac Back out some activation hacks trialing upcoming pytorch changes 2020-07-17 18:41:37 -07:00
Ross Wightman 3b9004bef9 Lots of changes to model creation helpers, close to finalizing feature extraction / interfaces 2020-07-17 17:54:26 -07:00
Ross Wightman e2cc481310 Update CSP ResNets for cross expansion without activation. Fix VovNet IABN compatibility with fixed activation arg. 2020-07-13 16:24:55 -07:00
Ross Wightman 3b6cce4c95 Add initial impl of CrossStagePartial networks, yet to be trained, not quite the same as darknet cfgs. 2020-07-13 15:01:06 -07:00
Ross Wightman 3aebc2f06c Switch DPN to use BnAct layer, train a new DPN 68b model with RA to 79.21 2020-07-12 11:13:06 -07:00
Ross Wightman f122f0274b Significant ResNet refactor:
* stage creation + make_layer moved to separate fn with more sensible dilation/output_stride calc
* drop path rate decay easy to impl with refactored block creation loops
* fix dilation + blur pool combo
2020-07-05 00:48:12 -07:00
Ross Wightman a66df5fb91 More model feature extraction support, start to deprecate senet.py, dilations added to regnet, add proper aligned xception 2020-07-03 00:41:30 -07:00