Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Big re-org, working towards making pip/module as 'timm'
This commit is contained in:
  parent 871f4c1b0c
  commit aa4354f466
@@ -11,9 +11,9 @@ import argparse
 import numpy as np
 import torch

-from models import create_model, apply_test_time_pool
-from data import Dataset, create_loader, resolve_data_config
-from utils import AverageMeter
+from timm.models import create_model, apply_test_time_pool
+from timm.data import Dataset, create_loader, resolve_data_config
+from timm.utils import AverageMeter

 torch.backends.cudnn.benchmark = True

@@ -55,6 +55,9 @@ parser.add_argument('--topk', default=5, type=int,
 def main():
     args = parser.parse_args()

+    # might as well try to do something useful...
+    args.pretrained = args.pretrained or not args.checkpoint
+
     # create model
     model = create_model(
         args.model,
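The fallback added above makes the script load pretrained weights by default whenever args.checkpoint is empty. A minimal sketch of that boolean logic (the helper function is illustrative only, not part of the commit):

    def resolve_pretrained(pretrained, checkpoint):
        # use pretrained weights unless an explicit checkpoint path was given
        return pretrained or not checkpoint

    assert resolve_pretrained(False, '') is True             # no checkpoint -> fall back to pretrained
    assert resolve_pretrained(False, 'model.pth') is False   # a checkpoint suppresses the fallback
    assert resolve_pretrained(True, 'model.pth') is True     # an explicit pretrained flag always wins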
requirements.txt  (new file, 2 lines)
@@ -0,0 +1,2 @@
+torch~=1.0
+torchvision
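These two pinned dependencies can be installed ahead of the package itself with standard pip usage, run from the repository root (nothing here is specific to this commit):

    pip install -r requirements.txt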
setup.py  (new file, 55 lines)
@@ -0,0 +1,55 @@
+""" Setup
+"""
+from setuptools import setup, find_packages
+from codecs import open
+from os import path
+
+here = path.abspath(path.dirname(__file__))
+
+# Get the long description from the README file
+with open(path.join(here, 'README.md'), encoding='utf-8') as f:
+    long_description = f.read()
+
+exec(open('timm/version.py').read())
+setup(
+    name='timm',
+    version=__version__,
+    description='(Unofficial) PyTorch Image Models',
+    long_description=long_description,
+    url='https://github.com/rwightman/pytorch-image-models',
+    author='Ross Wightman',
+    author_email='hello@rwightman.com',
+    classifiers=[  # Optional
+        # How mature is this project? Common values are
+        #   3 - Alpha
+        #   4 - Beta
+        #   5 - Production/Stable
+        'Development Status :: 3 - Alpha',
+        'Intended Audience :: Developers',
+        'Topic :: Software Development :: Build Tools',
+        'License :: OSI Approved :: Apache License',
+        'Programming Language :: Python :: 3.6',
+    ],
+
+    # Note that this is a string of words separated by whitespace, not a list.
+    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
+
+    # You can just specify package directories manually here if your project is
+    # simple. Or you can use find_packages().
+    #
+    # Alternatively, if you just want to distribute a single Python file, use
+    # the `py_modules` argument instead as follows, which will expect a file
+    # called `my_module.py` to exist:
+    #
+    #   py_modules=["my_module"],
+    #
+    packages=find_packages(exclude=['convert']),
+
+    # This field lists other packages that your project depends on to run.
+    # Any package you put here will be installed by pip when your project is
+    # installed, so they must be valid existing projects.
+    #
+    # For an analysis of "install_requires" vs pip's requirements files see:
+    # https://packaging.python.org/en/latest/requirements.html
+    install_requires=['torch', 'torchvision'],
+)
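The exec(open('timm/version.py').read()) line pulls __version__ into setup.py's namespace so the version string lives in exactly one place. With this file in place the repo becomes installable with standard setuptools/pip commands; a typical development workflow, assuming a local checkout at the repo root:

    pip install -e .        # editable install of the 'timm' package
    python setup.py sdist   # build a source distribution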
timm/__init__.py  (new file, 2 lines)
@@ -0,0 +1,2 @@
+from .version import __version__
+from .models import create_model
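With these two re-exports the package exposes a single high-level entry point. A minimal usage sketch, assuming timm is importable; 'resnet50' is one of the names registered in the _models lists further down this diff, and the keyword arguments mirror the create_model signature shown in a later hunk (calling .eval() assumes the returned object is a torch nn.Module, which the model files indicate):

    import timm

    model = timm.create_model('resnet50', pretrained=False, num_classes=1000)
    model.eval()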
@@ -1,4 +1,4 @@
-from data.constants import *
+from .constants import *


 def resolve_data_config(model, args, default_cfg={}, verbose=True):
@@ -1,7 +1,7 @@
 import torch.utils.data
-from data.transforms import *
-from data.distributed_sampler import OrderedDistributedSampler
-from data.mixup import FastCollateMixup
+from .transforms import *
+from .distributed_sampler import OrderedDistributedSampler
+from .mixup import FastCollateMixup


 def fast_collate(batch):
@@ -101,7 +101,7 @@ def create_loader(
         img_size = input_size

     if tf_preprocessing and use_prefetcher:
-        from data.tf_preprocessing import TfPreprocessTransform
+        from timm.data.tf_preprocessing import TfPreprocessTransform
         transform = TfPreprocessTransform(is_training=is_training, size=img_size)
     else:
         if is_training:
@@ -1,5 +1,3 @@
-from __future__ import absolute_import
-
 import random
 import math
 import torch
@@ -7,8 +7,8 @@ import math
 import random
 import numpy as np

-from data import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from data.random_erasing import RandomErasing
+from .constants import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .random_erasing import RandomErasing


 class ToNumpy:
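The constants being re-pathed here (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, and friends) are also re-exported from timm.data, as the model hunks below show. A short, illustrative torchvision preprocessing pipeline built on them (not part of this commit):

    from torchvision import transforms
    from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])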
@@ -2,14 +2,11 @@
 This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
 fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from collections import OrderedDict

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import *
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import *
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 import re

 _models = ['densenet121', 'densenet169', 'densenet201', 'densenet161']
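Each model file keeps its _models list (and, as the neighbouring hunks show, an __all__ built from it), so the individual entrypoints remain importable after the move into the package. A hedged sketch; the num_classes keyword follows the create_model signature shown later in this diff, and per-entrypoint kwargs may differ:

    from timm.models import densenet121

    model = densenet121(pretrained=False, num_classes=10)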
@@ -9,15 +9,13 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import os
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from collections import OrderedDict

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import select_adaptive_pool2d
-from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import select_adaptive_pool2d
+from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD

 _models = ['dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
 __all__ = ['DPN'] + _models
@@ -22,10 +22,10 @@ from copy import deepcopy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from models.conv2d_same import sconv2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .conv2d_same import sconv2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = [
     'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', 'semnasnet_050', 'semnasnet_075',
@@ -3,13 +3,12 @@ This file evolved from https://github.com/pytorch/vision 'resnet.py' with (SE)-R
 and ports of Gluon variations (https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/resnet.py)
 by Ross Wightman
 """
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = [
     'gluon_resnet18_v1b', 'gluon_resnet34_v1b', 'gluon_resnet50_v1b', 'gluon_resnet101_v1b', 'gluon_resnet152_v1b',
@@ -2,12 +2,9 @@
 Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
 based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
 """
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import *
-from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import *
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 _models = ['inception_resnet_v2']
 __all__ = ['InceptionResnetV2'] + _models
@@ -1,6 +1,6 @@
 from torchvision.models import Inception3
-from models.helpers import load_pretrained
-from data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 _models = ['inception_v3', 'tf_inception_v3', 'adv_inception_v3', 'gluon_inception_v3']
 __all__ = _models
@@ -2,12 +2,9 @@
 Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
 based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
 """
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import *
-from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import *
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

 _models = ['inception_v4']
 __all__ = ['InceptionV4'] + _models
@@ -1,21 +1,21 @@
-from models.inception_v4 import *
-from models.inception_resnet_v2 import *
-from models.densenet import *
-from models.resnet import *
-from models.dpn import *
-from models.senet import *
-from models.xception import *
-from models.pnasnet import *
-from models.gen_efficientnet import *
-from models.inception_v3 import *
-from models.gluon_resnet import *
+from .inception_v4 import *
+from .inception_resnet_v2 import *
+from .densenet import *
+from .resnet import *
+from .dpn import *
+from .senet import *
+from .xception import *
+from .pnasnet import *
+from .gen_efficientnet import *
+from .inception_v3 import *
+from .gluon_resnet import *

-from models.helpers import load_checkpoint
+from .helpers import load_checkpoint


 def create_model(
-        model_name='resnet50',
-        pretrained=None,
+        model_name,
+        pretrained=False,
         num_classes=1000,
         in_chans=3,
         checkpoint_path='',
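Two behavioural changes ride along with the import cleanup here: model_name becomes a required positional argument, and pretrained becomes a plain boolean defaulting to False (previously None, with 'resnet50' as an implicit default architecture). A before/after sketch with illustrative values; 'seresnext26_32x4d' comes from the senet _models list further down this diff:

    from timm.models import create_model

    # before this commit: the architecture could be omitted and defaulted to 'resnet50'
    # model = create_model(pretrained=None)

    # after this commit: the architecture must be named explicitly
    model = create_model('seresnext26_32x4d', pretrained=False, num_classes=1000, in_chans=3)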
@@ -12,8 +12,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d

 _models = ['pnasnet5large']
 __all__ = ['PNASNet5Large'] + _models
@@ -4,13 +4,12 @@ additional dropout and dynamic global avg/max pool.

 ResNext additions added by Ross Wightman
 """
-import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
            'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
@@ -15,9 +15,9 @@ import math
 import torch.nn as nn
 import torch.nn.functional as F

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import SelectAdaptivePool2d
-from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 _models = ['seresnet18', 'seresnet34', 'seresnet50', 'seresnet101', 'seresnet152', 'senet154',
            'seresnext26_32x4d', 'seresnext50_32x4d', 'seresnext101_32x4d']
@@ -1,6 +1,6 @@
 from torch import nn
 import torch.nn.functional as F
-from models.adaptive_avgmax_pool import adaptive_avgmax_pool2d
+from .adaptive_avgmax_pool import adaptive_avgmax_pool2d


 class TestTimePoolHead(nn.Module):
@@ -27,8 +27,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-from models.helpers import load_pretrained
-from models.adaptive_avgmax_pool import select_adaptive_pool2d
+from .helpers import load_pretrained
+from .adaptive_avgmax_pool import select_adaptive_pool2d

 _models = ['xception']
 __all__ = ['Xception'] + _models
@@ -1,5 +1,5 @@
 from torch import optim as optim
-from optim import Nadam, RMSpropTF
+from timm.optim import Nadam, RMSpropTF


 def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
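For orientation, a hedged sketch of how the add_weight_decay helper above is typically wired into a torch optimizer. It assumes the function returns torch-style parameter groups (a list of dicts with 'params' and 'weight_decay' keys), which its name and signature suggest but which this hunk does not show:

    import torch
    import torchvision.models as tvm

    model = tvm.resnet18()  # any nn.Module stands in here
    # add_weight_decay as defined in the hunk above; assumed to return parameter groups
    param_groups = add_weight_decay(model, weight_decay=1e-4)
    optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)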
@@ -1,7 +1,6 @@
-from scheduler.cosine_lr import CosineLRScheduler
-from scheduler.plateau_lr import PlateauLRScheduler
-from scheduler.tanh_lr import TanhLRScheduler
-from scheduler.step_lr import StepLRScheduler
+from .cosine_lr import CosineLRScheduler
+from .tanh_lr import TanhLRScheduler
+from .step_lr import StepLRScheduler


 def create_scheduler(args, optimizer):
timm/version.py  (new file, 1 line)
@@ -0,0 +1 @@
+__version__ = '0.1.1'
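Since timm/__init__.py re-exports this value, the installed package version can be checked the usual way (assumes the package is importable):

    import timm
    print(timm.__version__)   # expected: 0.1.1 at this commit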
train.py  (13 lines changed)
@@ -11,16 +11,15 @@ try:
 except ImportError:
     has_apex = False

-from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint, load_checkpoint
-from utils import *
-from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
-from optim import create_optimizer
-from scheduler import create_scheduler
+from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
+from timm.models import create_model, resume_checkpoint
+from timm.utils import *
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from timm.optim import create_optimizer
+from timm.scheduler import create_scheduler

 import torch
 import torch.nn as nn
-import torch.distributed as dist
 import torchvision.utils

 torch.backends.cudnn.benchmark = True
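After this hunk the training script resolves everything through the timm package, so timm must be importable (for example via the editable install shown under setup.py, or by running from the repo root). A quick smoke test built only from imports that appear in this diff:

    from timm.data import Dataset, create_loader, resolve_data_config
    from timm.models import create_model, resume_checkpoint
    from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
    from timm.scheduler import create_scheduler
    print('timm package imports resolve')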
@@ -12,9 +12,9 @@ import torch.nn as nn
 import torch.nn.parallel
 from collections import OrderedDict

-from models import create_model, apply_test_time_pool, load_checkpoint
-from data import Dataset, create_loader, resolve_data_config
-from utils import accuracy, AverageMeter, natural_key
+from timm.models import create_model, apply_test_time_pool, load_checkpoint
+from timm.data import Dataset, create_loader, resolve_data_config
+from timm.utils import accuracy, AverageMeter, natural_key

 torch.backends.cudnn.benchmark = True
