mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'master' into densenet_update_and_more
This commit is contained in:
commit
7df83258c9
42
.github/workflows/tests.yml
vendored
Normal file
42
.github/workflows/tests.yml
vendored
Normal file
@ -0,0 +1,42 @@
|
||||
name: Python tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ master ]
|
||||
pull_request:
|
||||
branches: [ master ]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Run tests on ${{ matrix.os }} with Python ${{ matrix.python }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macOS-latest]
|
||||
python: ['3.8']
|
||||
torch: ['1.5.0']
|
||||
torchvision: ['0.6.0']
|
||||
runs-on: ${{ matrix.os }}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python }}
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: ${{ matrix.python }}
|
||||
- name: Install testing dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install pytest pytest-timeout
|
||||
- name: Install torch on mac
|
||||
if: startsWith(matrix.os, 'macOS')
|
||||
run: pip install torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
|
||||
- name: Install torch on ubuntu
|
||||
if: startsWith(matrix.os, 'ubuntu')
|
||||
run: pip install torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
- name: Install requirements
|
||||
run: |
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -vv --durations=0 ./tests
|
39
README.md
39
README.md
@ -2,6 +2,9 @@
|
||||
|
||||
## What's New
|
||||
|
||||
### May 12, 2020
|
||||
* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955))
|
||||
|
||||
### May 3, 2020
|
||||
* Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo)
|
||||
|
||||
@ -70,41 +73,6 @@
|
||||
* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
|
||||
* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
|
||||
|
||||
### Dec 30, 2019
|
||||
* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch
|
||||
|
||||
### Dec 28, 2019
|
||||
* Add new model weights and training hparams (see Training Hparams section)
|
||||
* `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct
|
||||
* trained with RandAugment, ended up with an interesting but less than perfect result (see training section)
|
||||
* `seresnext26d_32x4d`- 77.6 top-1, 93.6 top-5
|
||||
* deep stem (32, 32, 64), avgpool downsample
|
||||
* stem/dowsample from bag-of-tricks paper
|
||||
* `seresnext26t_32x4d`- 78.0 top-1, 93.7 top-5
|
||||
* deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant)
|
||||
* stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments
|
||||
|
||||
### Dec 23, 2019
|
||||
* Add RandAugment trained MixNet-XL weights with 80.48 top-1.
|
||||
* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval
|
||||
|
||||
### Dec 4, 2019
|
||||
* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5).
|
||||
|
||||
### Nov 29, 2019
|
||||
* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded.
|
||||
* AdvProp weights added
|
||||
* Official TF MobileNetv3 weights added
|
||||
* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here...
|
||||
* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification
|
||||
* Consistency in global pooling, `reset_classifer`, and `forward_features` across models
|
||||
* `forward_features` always returns unpooled feature maps now
|
||||
* Reasonable chance I broke something... let me know
|
||||
|
||||
### Nov 22, 2019
|
||||
* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update.
|
||||
* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise.
|
||||
|
||||
## Introduction
|
||||
|
||||
For each competition, personal, or freelance project involving images + Convolution Neural Networks, I build on top of an evolving collection of code and models. This repo contains a (somewhat) cleaned up and paired down iteration of that code. Hopefully it'll be of use to others.
|
||||
@ -130,6 +98,7 @@ Included models:
|
||||
* Instagram trained / ImageNet tuned ResNeXt101-32x8d to 32x48d from from [facebookresearch](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/)
|
||||
* Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)
|
||||
* Selective Kernel (SK) Nets (https://arxiv.org/abs/1903.06586)
|
||||
* ResNeSt (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)
|
||||
* DLA
|
||||
* Original (https://github.com/ucbdrive/dla, https://arxiv.org/abs/1707.06484)
|
||||
* Res2Net (https://github.com/gasvn/Res2Net, https://arxiv.org/abs/1904.01169)
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
19
tests/test_inference.py
Normal file
19
tests/test_inference.py
Normal file
@ -0,0 +1,19 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from timm import list_models, create_model
|
||||
|
||||
|
||||
@pytest.mark.timeout(300)
|
||||
@pytest.mark.parametrize('model_name', list_models(exclude_filters='*efficientnet_l2*'))
|
||||
@pytest.mark.parametrize('batch_size', [1])
|
||||
def test_model_forward(model_name, batch_size):
|
||||
"""Run a single forward pass with each model"""
|
||||
model = create_model(model_name, pretrained=False)
|
||||
model.eval()
|
||||
|
||||
inputs = torch.randn((batch_size, *model.default_cfg['input_size']))
|
||||
outputs = model(inputs)
|
||||
|
||||
assert outputs.shape[0] == batch_size
|
||||
assert not torch.isnan(outputs).any(), 'Output included NaNs'
|
@ -18,6 +18,7 @@ from .dla import *
|
||||
from .hrnet import *
|
||||
from .sknet import *
|
||||
from .tresnet import *
|
||||
from .resnest import *
|
||||
|
||||
from .registry import *
|
||||
from .factory import create_model
|
||||
|
@ -1,121 +1,561 @@
|
||||
from torchvision.models import Inception3
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
from .helpers import load_pretrained
|
||||
from .registry import register_model
|
||||
from .layers import trunc_normal_, SelectAdaptivePool2d
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
|
||||
'first_conv': 'conv1', 'classifier': 'fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
__all__ = []
|
||||
|
||||
default_cfgs = {
|
||||
# original PyTorch weights, ported from Tensorflow but modified
|
||||
'inception_v3': {
|
||||
'url': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN, # also works well enough with resnet defaults
|
||||
'std': IMAGENET_INCEPTION_STD, # also works well enough with resnet defaults
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
},
|
||||
'inception_v3': _cfg(
|
||||
url='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
||||
has_aux=True), # checkpoint has aux logit layer weights
|
||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||
'tf_inception_v3': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN,
|
||||
'std': IMAGENET_INCEPTION_STD,
|
||||
'num_classes': 1001,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
},
|
||||
'tf_inception_v3': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
|
||||
num_classes=1001, has_aux=False),
|
||||
# my port of Tensorflow adversarially trained Inception V3 from
|
||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||
'adv_inception_v3': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_INCEPTION_MEAN,
|
||||
'std': IMAGENET_INCEPTION_STD,
|
||||
'num_classes': 1001,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
},
|
||||
'adv_inception_v3': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
|
||||
num_classes=1001, has_aux=False),
|
||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||
'gluon_inception_v3': {
|
||||
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
|
||||
'input_size': (3, 299, 299),
|
||||
'crop_pct': 0.875,
|
||||
'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
|
||||
'std': IMAGENET_DEFAULT_STD, # also works well with inception defaults
|
||||
'num_classes': 1000,
|
||||
'first_conv': 'conv0',
|
||||
'classifier': 'fc'
|
||||
}
|
||||
'gluon_inception_v3': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_inception_v3-9f746940.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, # also works well with inception defaults
|
||||
std=IMAGENET_DEFAULT_STD, # also works well with inception defaults
|
||||
has_aux=False,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def _assert_default_kwargs(kwargs):
|
||||
# for imported models (ie torchvision) without capability to change these params,
|
||||
# make sure they aren't being set to non-defaults
|
||||
assert kwargs.pop('global_pool', 'avg') == 'avg'
|
||||
assert kwargs.pop('drop_rate', 0.) == 0.
|
||||
class InceptionV3Aux(nn.Module):
|
||||
"""InceptionV3 with AuxLogits
|
||||
"""
|
||||
|
||||
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
|
||||
super(InceptionV3Aux, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
if inception_blocks is None:
|
||||
inception_blocks = [
|
||||
BasicConv2d, InceptionA, InceptionB, InceptionC,
|
||||
InceptionD, InceptionE, InceptionAux
|
||||
]
|
||||
assert len(inception_blocks) == 7
|
||||
conv_block = inception_blocks[0]
|
||||
inception_a = inception_blocks[1]
|
||||
inception_b = inception_blocks[2]
|
||||
inception_c = inception_blocks[3]
|
||||
inception_d = inception_blocks[4]
|
||||
inception_e = inception_blocks[5]
|
||||
inception_aux = inception_blocks[6]
|
||||
|
||||
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
|
||||
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
|
||||
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
|
||||
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
|
||||
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
|
||||
self.Mixed_5b = inception_a(192, pool_features=32)
|
||||
self.Mixed_5c = inception_a(256, pool_features=64)
|
||||
self.Mixed_5d = inception_a(288, pool_features=64)
|
||||
self.Mixed_6a = inception_b(288)
|
||||
self.Mixed_6b = inception_c(768, channels_7x7=128)
|
||||
self.Mixed_6c = inception_c(768, channels_7x7=160)
|
||||
self.Mixed_6d = inception_c(768, channels_7x7=160)
|
||||
self.Mixed_6e = inception_c(768, channels_7x7=192)
|
||||
self.AuxLogits = inception_aux(768, num_classes)
|
||||
self.Mixed_7a = inception_d(768)
|
||||
self.Mixed_7b = inception_e(1280)
|
||||
self.Mixed_7c = inception_e(2048)
|
||||
|
||||
self.num_features = 2048
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
||||
trunc_normal_(m.weight, std=stddev)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward_features(self, x):
|
||||
# N x 3 x 299 x 299
|
||||
x = self.Conv2d_1a_3x3(x)
|
||||
# N x 32 x 149 x 149
|
||||
x = self.Conv2d_2a_3x3(x)
|
||||
# N x 32 x 147 x 147
|
||||
x = self.Conv2d_2b_3x3(x)
|
||||
# N x 64 x 147 x 147
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
# N x 64 x 73 x 73
|
||||
x = self.Conv2d_3b_1x1(x)
|
||||
# N x 80 x 73 x 73
|
||||
x = self.Conv2d_4a_3x3(x)
|
||||
# N x 192 x 71 x 71
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
# N x 192 x 35 x 35
|
||||
x = self.Mixed_5b(x)
|
||||
# N x 256 x 35 x 35
|
||||
x = self.Mixed_5c(x)
|
||||
# N x 288 x 35 x 35
|
||||
x = self.Mixed_5d(x)
|
||||
# N x 288 x 35 x 35
|
||||
x = self.Mixed_6a(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6b(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6c(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6d(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6e(x)
|
||||
# N x 768 x 17 x 17
|
||||
aux = self.AuxLogits(x) if self.training else None
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_7a(x)
|
||||
# N x 1280 x 8 x 8
|
||||
x = self.Mixed_7b(x)
|
||||
# N x 2048 x 8 x 8
|
||||
x = self.Mixed_7c(x)
|
||||
# N x 2048 x 8 x 8
|
||||
return x, aux
|
||||
|
||||
def get_classifier(self):
|
||||
return self.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.num_classes = num_classes
|
||||
if self.num_classes > 0:
|
||||
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
else:
|
||||
self.fc = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x, aux = self.forward_features(x)
|
||||
x = self.global_pool(x).flatten(1)
|
||||
if self.drop_rate > 0:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.fc(x)
|
||||
return x, aux
|
||||
|
||||
|
||||
class InceptionV3(nn.Module):
|
||||
"""Inception-V3 with no AuxLogits
|
||||
FIXME two class defs are redundant, but less screwing around with torchsript fussyness and inconsistent returns
|
||||
"""
|
||||
|
||||
def __init__(self, inception_blocks=None, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
|
||||
super(InceptionV3, self).__init__()
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
|
||||
if inception_blocks is None:
|
||||
inception_blocks = [
|
||||
BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE]
|
||||
assert len(inception_blocks) >= 6
|
||||
conv_block = inception_blocks[0]
|
||||
inception_a = inception_blocks[1]
|
||||
inception_b = inception_blocks[2]
|
||||
inception_c = inception_blocks[3]
|
||||
inception_d = inception_blocks[4]
|
||||
inception_e = inception_blocks[5]
|
||||
|
||||
self.Conv2d_1a_3x3 = conv_block(in_chans, 32, kernel_size=3, stride=2)
|
||||
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
|
||||
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
|
||||
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
|
||||
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
|
||||
self.Mixed_5b = inception_a(192, pool_features=32)
|
||||
self.Mixed_5c = inception_a(256, pool_features=64)
|
||||
self.Mixed_5d = inception_a(288, pool_features=64)
|
||||
self.Mixed_6a = inception_b(288)
|
||||
self.Mixed_6b = inception_c(768, channels_7x7=128)
|
||||
self.Mixed_6c = inception_c(768, channels_7x7=160)
|
||||
self.Mixed_6d = inception_c(768, channels_7x7=160)
|
||||
self.Mixed_6e = inception_c(768, channels_7x7=192)
|
||||
self.Mixed_7a = inception_d(768)
|
||||
self.Mixed_7b = inception_e(1280)
|
||||
self.Mixed_7c = inception_e(2048)
|
||||
|
||||
self.num_features = 2048
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.fc = nn.Linear(2048, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
||||
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
||||
trunc_normal_(m.weight, std=stddev)
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward_features(self, x):
|
||||
# N x 3 x 299 x 299
|
||||
x = self.Conv2d_1a_3x3(x)
|
||||
# N x 32 x 149 x 149
|
||||
x = self.Conv2d_2a_3x3(x)
|
||||
# N x 32 x 147 x 147
|
||||
x = self.Conv2d_2b_3x3(x)
|
||||
# N x 64 x 147 x 147
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
# N x 64 x 73 x 73
|
||||
x = self.Conv2d_3b_1x1(x)
|
||||
# N x 80 x 73 x 73
|
||||
x = self.Conv2d_4a_3x3(x)
|
||||
# N x 192 x 71 x 71
|
||||
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
# N x 192 x 35 x 35
|
||||
x = self.Mixed_5b(x)
|
||||
# N x 256 x 35 x 35
|
||||
x = self.Mixed_5c(x)
|
||||
# N x 288 x 35 x 35
|
||||
x = self.Mixed_5d(x)
|
||||
# N x 288 x 35 x 35
|
||||
x = self.Mixed_6a(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6b(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6c(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6d(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_6e(x)
|
||||
# N x 768 x 17 x 17
|
||||
x = self.Mixed_7a(x)
|
||||
# N x 1280 x 8 x 8
|
||||
x = self.Mixed_7b(x)
|
||||
# N x 2048 x 8 x 8
|
||||
x = self.Mixed_7c(x)
|
||||
# N x 2048 x 8 x 8
|
||||
return x
|
||||
|
||||
def get_classifier(self):
|
||||
return self.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
||||
self.num_classes = num_classes
|
||||
if self.num_classes > 0:
|
||||
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
||||
else:
|
||||
self.fc = nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.global_pool(x).flatten(1)
|
||||
if self.drop_rate > 0:
|
||||
x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
||||
class InceptionA(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, pool_features, conv_block=None):
|
||||
super(InceptionA, self).__init__()
|
||||
if conv_block is None:
|
||||
conv_block = BasicConv2d
|
||||
self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
|
||||
|
||||
self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
|
||||
self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
|
||||
|
||||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
||||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
||||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
|
||||
|
||||
self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
|
||||
|
||||
def _forward(self, x):
|
||||
branch1x1 = self.branch1x1(x)
|
||||
|
||||
branch5x5 = self.branch5x5_1(x)
|
||||
branch5x5 = self.branch5x5_2(branch5x5)
|
||||
|
||||
branch3x3dbl = self.branch3x3dbl_1(x)
|
||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
||||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
||||
|
||||
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
||||
branch_pool = self.branch_pool(branch_pool)
|
||||
|
||||
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
||||
return outputs
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self._forward(x)
|
||||
return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
class InceptionB(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, conv_block=None):
|
||||
super(InceptionB, self).__init__()
|
||||
if conv_block is None:
|
||||
conv_block = BasicConv2d
|
||||
self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
|
||||
|
||||
self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
|
||||
self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
|
||||
self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
|
||||
|
||||
def _forward(self, x):
|
||||
branch3x3 = self.branch3x3(x)
|
||||
|
||||
branch3x3dbl = self.branch3x3dbl_1(x)
|
||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
||||
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
||||
|
||||
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
|
||||
outputs = [branch3x3, branch3x3dbl, branch_pool]
|
||||
return outputs
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self._forward(x)
|
||||
return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
class InceptionC(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, channels_7x7, conv_block=None):
|
||||
super(InceptionC, self).__init__()
|
||||
if conv_block is None:
|
||||
conv_block = BasicConv2d
|
||||
self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
|
||||
|
||||
c7 = channels_7x7
|
||||
self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
|
||||
self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
||||
self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
|
||||
|
||||
self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
|
||||
self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
||||
self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
|
||||
self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
|
||||
self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
|
||||
|
||||
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
||||
|
||||
def _forward(self, x):
|
||||
branch1x1 = self.branch1x1(x)
|
||||
|
||||
branch7x7 = self.branch7x7_1(x)
|
||||
branch7x7 = self.branch7x7_2(branch7x7)
|
||||
branch7x7 = self.branch7x7_3(branch7x7)
|
||||
|
||||
branch7x7dbl = self.branch7x7dbl_1(x)
|
||||
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
||||
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
||||
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
||||
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
||||
|
||||
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
||||
branch_pool = self.branch_pool(branch_pool)
|
||||
|
||||
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
||||
return outputs
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self._forward(x)
|
||||
return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
class InceptionD(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, conv_block=None):
|
||||
super(InceptionD, self).__init__()
|
||||
if conv_block is None:
|
||||
conv_block = BasicConv2d
|
||||
self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
||||
self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
|
||||
|
||||
self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
|
||||
self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
|
||||
self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
|
||||
self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
|
||||
|
||||
def _forward(self, x):
|
||||
branch3x3 = self.branch3x3_1(x)
|
||||
branch3x3 = self.branch3x3_2(branch3x3)
|
||||
|
||||
branch7x7x3 = self.branch7x7x3_1(x)
|
||||
branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
|
||||
branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
|
||||
branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
|
||||
|
||||
branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
|
||||
outputs = [branch3x3, branch7x7x3, branch_pool]
|
||||
return outputs
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self._forward(x)
|
||||
return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
class InceptionE(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, conv_block=None):
|
||||
super(InceptionE, self).__init__()
|
||||
if conv_block is None:
|
||||
conv_block = BasicConv2d
|
||||
self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
|
||||
|
||||
self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
|
||||
self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
||||
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
||||
|
||||
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
|
||||
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
|
||||
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
||||
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
||||
|
||||
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
||||
|
||||
def _forward(self, x):
|
||||
branch1x1 = self.branch1x1(x)
|
||||
|
||||
branch3x3 = self.branch3x3_1(x)
|
||||
branch3x3 = [
|
||||
self.branch3x3_2a(branch3x3),
|
||||
self.branch3x3_2b(branch3x3),
|
||||
]
|
||||
branch3x3 = torch.cat(branch3x3, 1)
|
||||
|
||||
branch3x3dbl = self.branch3x3dbl_1(x)
|
||||
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
||||
branch3x3dbl = [
|
||||
self.branch3x3dbl_3a(branch3x3dbl),
|
||||
self.branch3x3dbl_3b(branch3x3dbl),
|
||||
]
|
||||
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
||||
|
||||
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
|
||||
branch_pool = self.branch_pool(branch_pool)
|
||||
|
||||
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
||||
return outputs
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self._forward(x)
|
||||
return torch.cat(outputs, 1)
|
||||
|
||||
|
||||
class InceptionAux(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, num_classes, conv_block=None):
|
||||
super(InceptionAux, self).__init__()
|
||||
if conv_block is None:
|
||||
conv_block = BasicConv2d
|
||||
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
|
||||
self.conv1 = conv_block(128, 768, kernel_size=5)
|
||||
self.conv1.stddev = 0.01
|
||||
self.fc = nn.Linear(768, num_classes)
|
||||
self.fc.stddev = 0.001
|
||||
|
||||
def forward(self, x):
|
||||
# N x 768 x 17 x 17
|
||||
x = F.avg_pool2d(x, kernel_size=5, stride=3)
|
||||
# N x 768 x 5 x 5
|
||||
x = self.conv0(x)
|
||||
# N x 128 x 5 x 5
|
||||
x = self.conv1(x)
|
||||
# N x 768 x 1 x 1
|
||||
# Adaptive average pooling
|
||||
x = F.adaptive_avg_pool2d(x, (1, 1))
|
||||
# N x 768 x 1 x 1
|
||||
x = torch.flatten(x, 1)
|
||||
# N x 768
|
||||
x = self.fc(x)
|
||||
# N x 1000
|
||||
return x
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(BasicConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
||||
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return F.relu(x, inplace=True)
|
||||
|
||||
|
||||
def _inception_v3(variant, pretrained=False, **kwargs):
|
||||
default_cfg = default_cfgs[variant]
|
||||
if kwargs.pop('features_only', False):
|
||||
assert False, 'Not Implemented' # TODO
|
||||
load_strict = False
|
||||
model_kwargs.pop('num_classes', 0)
|
||||
model_class = InceptionV3
|
||||
else:
|
||||
aux_logits = kwargs.pop('aux_logits', False)
|
||||
if aux_logits:
|
||||
model_class = InceptionV3Aux
|
||||
load_strict = default_cfg['has_aux']
|
||||
else:
|
||||
model_class = InceptionV3
|
||||
load_strict = not default_cfg['has_aux']
|
||||
|
||||
model = model_class(**kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model,
|
||||
num_classes=kwargs.get('num_classes', 0),
|
||||
in_chans=kwargs.get('in_chans', 3),
|
||||
strict=load_strict)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def inception_v3(pretrained=False, **kwargs):
|
||||
# original PyTorch weights, ported from Tensorflow but modified
|
||||
default_cfg = default_cfgs['inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=True, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
model = _inception_v3('inception_v3', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tf_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def tf_inception_v3(pretrained=False, **kwargs):
|
||||
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
|
||||
default_cfg = default_cfgs['tf_inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
model = _inception_v3('tf_inception_v3', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def adv_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def adv_inception_v3(pretrained=False, **kwargs):
|
||||
# my port of Tensorflow adversarially trained Inception V3 from
|
||||
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
|
||||
default_cfg = default_cfgs['adv_inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
model = _inception_v3('adv_inception_v3', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def gluon_inception_v3(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
def gluon_inception_v3(pretrained=False, **kwargs):
|
||||
# from gluon pretrained models, best performing in terms of accuracy/loss metrics
|
||||
# https://gluon-cv.mxnet.io/model_zoo/classification.html
|
||||
default_cfg = default_cfgs['gluon_inception_v3']
|
||||
assert in_chans == 3
|
||||
_assert_default_kwargs(kwargs)
|
||||
model = Inception3(num_classes=num_classes, aux_logits=False, transform_input=False)
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
model.default_cfg = default_cfg
|
||||
model = _inception_v3('gluon_inception_v3', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
@ -22,3 +22,4 @@ from .blur_pool import BlurPool2d
|
||||
from .norm_act import BatchNormAct2d
|
||||
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
|
||||
from .create_norm_act import create_norm_act
|
||||
from .weight_init import trunc_normal_
|
||||
|
@ -22,43 +22,88 @@ import math
|
||||
|
||||
|
||||
def drop_block_2d(
|
||||
x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7,
|
||||
gamma_scale: float = 1.0, drop_with_noise: bool = False):
|
||||
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
|
||||
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
|
||||
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||
|
||||
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
|
||||
runs with success, but needs further validation and possibly optimization for lower runtime impact.
|
||||
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
_, _, height, width = x.shape
|
||||
total_size = width * height
|
||||
clipped_block_size = min(block_size, min(width, height))
|
||||
B, C, H, W = x.shape
|
||||
total_size = W * H
|
||||
clipped_block_size = min(block_size, min(W, H))
|
||||
# seed_drop_rate, the gamma parameter
|
||||
seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||
(width - block_size + 1) *
|
||||
(height - block_size + 1))
|
||||
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||
(W - block_size + 1) * (H - block_size + 1))
|
||||
|
||||
# Forces the block to be inside the feature map.
|
||||
w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
|
||||
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
|
||||
((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
|
||||
valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
|
||||
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
|
||||
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
|
||||
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
|
||||
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
|
||||
|
||||
uniform_noise = torch.rand_like(x, dtype=torch.float32)
|
||||
block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
|
||||
if batchwise:
|
||||
# one mask for whole batch, quite a bit faster
|
||||
uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
uniform_noise = torch.rand_like(x)
|
||||
block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
|
||||
block_mask = -F.max_pool2d(
|
||||
-block_mask,
|
||||
kernel_size=clipped_block_size, # block_size, ???
|
||||
kernel_size=clipped_block_size, # block_size,
|
||||
stride=1,
|
||||
padding=clipped_block_size // 2)
|
||||
|
||||
if drop_with_noise:
|
||||
normal_noise = torch.randn_like(x)
|
||||
if with_noise:
|
||||
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
|
||||
if inplace:
|
||||
x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
|
||||
else:
|
||||
x = x * block_mask + normal_noise * (1 - block_mask)
|
||||
else:
|
||||
normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
|
||||
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
|
||||
if inplace:
|
||||
x.mul_(block_mask * normalize_scale)
|
||||
else:
|
||||
x = x * block_mask * normalize_scale
|
||||
return x
|
||||
|
||||
|
||||
def drop_block_fast_2d(
|
||||
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
|
||||
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
|
||||
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
|
||||
|
||||
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
|
||||
block mask at edges.
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
total_size = W * H
|
||||
clipped_block_size = min(block_size, min(W, H))
|
||||
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
|
||||
(W - block_size + 1) * (H - block_size + 1))
|
||||
|
||||
if batchwise:
|
||||
# one mask for whole batch, quite a bit faster
|
||||
block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
|
||||
else:
|
||||
# mask per batch element
|
||||
block_mask = torch.rand_like(x) < gamma
|
||||
block_mask = F.max_pool2d(
|
||||
block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
|
||||
|
||||
if with_noise:
|
||||
normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
|
||||
if inplace:
|
||||
x.mul_(1. - block_mask).add_(normal_noise * block_mask)
|
||||
else:
|
||||
x = x * (1. - block_mask) + normal_noise * block_mask
|
||||
else:
|
||||
block_mask = 1 - block_mask
|
||||
normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
|
||||
if inplace:
|
||||
x.mul_(block_mask * normalize_scale)
|
||||
else:
|
||||
x = x * block_mask * normalize_scale
|
||||
return x
|
||||
|
||||
@ -70,15 +115,28 @@ class DropBlock2d(nn.Module):
|
||||
drop_prob=0.1,
|
||||
block_size=7,
|
||||
gamma_scale=1.0,
|
||||
with_noise=False):
|
||||
with_noise=False,
|
||||
inplace=False,
|
||||
batchwise=False,
|
||||
fast=True):
|
||||
super(DropBlock2d, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.gamma_scale = gamma_scale
|
||||
self.block_size = block_size
|
||||
self.with_noise = with_noise
|
||||
self.inplace = inplace
|
||||
self.batchwise = batchwise
|
||||
self.fast = fast # FIXME finish comparisons of fast vs not
|
||||
|
||||
def forward(self, x):
|
||||
return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
|
||||
if not self.training or not self.drop_prob:
|
||||
return x
|
||||
if self.fast:
|
||||
return drop_block_fast_2d(
|
||||
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
|
||||
else:
|
||||
return drop_block_2d(
|
||||
x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
||||
|
80
timm/models/layers/split_attn.py
Normal file
80
timm/models/layers/split_attn.py
Normal file
@ -0,0 +1,80 @@
|
||||
""" Split Attention Conv2d (for ResNeSt Models)
|
||||
|
||||
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
|
||||
|
||||
Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
|
||||
|
||||
Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class RadixSoftmax(nn.Module):
|
||||
def __init__(self, radix, cardinality):
|
||||
super(RadixSoftmax, self).__init__()
|
||||
self.radix = radix
|
||||
self.cardinality = cardinality
|
||||
|
||||
def forward(self, x):
|
||||
batch = x.size(0)
|
||||
if self.radix > 1:
|
||||
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
|
||||
x = F.softmax(x, dim=1)
|
||||
x = x.reshape(batch, -1)
|
||||
else:
|
||||
x = torch.sigmoid(x)
|
||||
return x
|
||||
|
||||
|
||||
class SplitAttnConv2d(nn.Module):
|
||||
"""Split-Attention Conv2d
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
|
||||
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
|
||||
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
|
||||
super(SplitAttnConv2d, self).__init__()
|
||||
self.radix = radix
|
||||
self.drop_block = drop_block
|
||||
mid_chs = out_channels * radix
|
||||
attn_chs = max(in_channels * radix // reduction_factor, 32)
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels, mid_chs, kernel_size, stride, padding, dilation,
|
||||
groups=groups * radix, bias=bias, **kwargs)
|
||||
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
|
||||
self.act0 = act_layer(inplace=True)
|
||||
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
|
||||
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
|
||||
self.rsoftmax = RadixSoftmax(radix, groups)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
if self.bn0 is not None:
|
||||
x = self.bn0(x)
|
||||
if self.drop_block is not None:
|
||||
x = self.drop_block(x)
|
||||
x = self.act0(x)
|
||||
|
||||
B, RC, H, W = x.shape
|
||||
if self.radix > 1:
|
||||
x = x.reshape((B, self.radix, RC // self.radix, H, W))
|
||||
x_gap = x.sum(dim=1)
|
||||
else:
|
||||
x_gap = x
|
||||
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
|
||||
x_gap = self.fc1(x_gap)
|
||||
if self.bn1 is not None:
|
||||
x_gap = self.bn1(x_gap)
|
||||
x_gap = self.act1(x_gap)
|
||||
x_attn = self.fc2(x_gap)
|
||||
|
||||
x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
|
||||
if self.radix > 1:
|
||||
out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
|
||||
else:
|
||||
out = x * x_attn
|
||||
return out.contiguous()
|
60
timm/models/layers/weight_init.py
Normal file
60
timm/models/layers/weight_init.py
Normal file
@ -0,0 +1,60 @@
|
||||
import torch
|
||||
import math
|
||||
import warnings
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
||||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect.",
|
||||
stacklevel=2)
|
||||
|
||||
with torch.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
l = norm_cdf((a - mean) / std)
|
||||
u = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to
|
||||
# [2l-1, 2u-1].
|
||||
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tensor.erfinv_()
|
||||
|
||||
# Transform to proper mean, std
|
||||
tensor.mul_(std * math.sqrt(2.))
|
||||
tensor.add_(mean)
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tensor.clamp_(min=a, max=b)
|
||||
return tensor
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
||||
# type: (Tensor, float, float, float, float) -> Tensor
|
||||
r"""Fills the input Tensor with values drawn from a truncated
|
||||
normal distribution. The values are effectively drawn from the
|
||||
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
||||
with values outside :math:`[a, b]` redrawn until they are within
|
||||
the bounds. The method used for generating the random values works
|
||||
best when :math:`a \leq \text{mean} \leq b`.
|
||||
Args:
|
||||
tensor: an n-dimensional `torch.Tensor`
|
||||
mean: the mean of the normal distribution
|
||||
std: the standard deviation of the normal distribution
|
||||
a: the minimum cutoff value
|
||||
b: the maximum cutoff value
|
||||
Examples:
|
||||
>>> w = torch.empty(3, 5)
|
||||
>>> nn.init.trunc_normal_(w)
|
||||
"""
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
@ -42,12 +42,14 @@ def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(filter='', module='', pretrained=False):
|
||||
def list_models(filter='', module='', pretrained=False, exclude_filters=''):
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
|
||||
pretrained (bool) - Include only models with pretrained weights if True
|
||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
||||
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
@ -58,7 +60,14 @@ def list_models(filter='', module='', pretrained=False):
|
||||
else:
|
||||
models = _model_entrypoints.keys()
|
||||
if filter:
|
||||
models = fnmatch.filter(models, filter)
|
||||
models = fnmatch.filter(models, filter) # include these models
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, list):
|
||||
exclude_filters = [exclude_filters]
|
||||
for xf in exclude_filters:
|
||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||
if len(exclude_models):
|
||||
models = set(models).difference(exclude_models)
|
||||
if pretrained:
|
||||
models = _model_has_pretrained.intersection(models)
|
||||
return list(sorted(models, key=_natural_key))
|
||||
|
264
timm/models/resnest.py
Normal file
264
timm/models/resnest.py
Normal file
@ -0,0 +1,264 @@
|
||||
""" ResNeSt Models
|
||||
|
||||
Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
|
||||
|
||||
Adapted from original PyTorch impl w/ weights at https://github.com/zhanghang1989/ResNeSt by Hang Zhang
|
||||
|
||||
Modified for torchscript compat, and consistency with timm by Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from timm.models.layers import DropBlock2d
|
||||
from .helpers import load_pretrained
|
||||
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
|
||||
from .layers.split_attn import SplitAttnConv2d
|
||||
from .registry import register_model
|
||||
from .resnet import ResNet
|
||||
|
||||
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'conv1', 'classifier': 'fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
default_cfgs = {
|
||||
'resnest14d': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth'),
|
||||
'resnest26d': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth'),
|
||||
'resnest50d': _cfg(
|
||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50-528c19ca.pth'),
|
||||
'resnest101e': _cfg(
|
||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest101-22405ba7.pth', input_size=(3, 256, 256)),
|
||||
'resnest200e': _cfg(
|
||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest200-75117900.pth', input_size=(3, 320, 320)),
|
||||
'resnest269e': _cfg(
|
||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest269-0cc87c48.pth', input_size=(3, 416, 416)),
|
||||
'resnest50d_4s2x40d': _cfg(
|
||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_4s2x40d-41d14ed0.pth',
|
||||
interpolation='bicubic'),
|
||||
'resnest50d_1s4x24d': _cfg(
|
||||
url='https://hangzh.s3.amazonaws.com/encoding/models/resnest50_fast_1s4x24d-d4a4f76f.pth',
|
||||
interpolation='bicubic')
|
||||
}
|
||||
|
||||
|
||||
class ResNestBottleneck(nn.Module):
|
||||
"""ResNet Bottleneck
|
||||
"""
|
||||
# pylint: disable=unused-argument
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
|
||||
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
|
||||
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
|
||||
super(ResNestBottleneck, self).__init__()
|
||||
assert reduce_first == 1 # not supported
|
||||
assert attn_layer is None # not supported
|
||||
assert aa_layer is None # TODO not yet supported
|
||||
assert drop_path is None # TODO not yet supported
|
||||
|
||||
group_width = int(planes * (base_width / 64.)) * cardinality
|
||||
first_dilation = first_dilation or dilation
|
||||
if avd and (stride > 1 or is_first):
|
||||
avd_stride = stride
|
||||
stride = 1
|
||||
else:
|
||||
avd_stride = 0
|
||||
self.radix = radix
|
||||
self.drop_block = drop_block
|
||||
|
||||
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
|
||||
self.bn1 = norm_layer(group_width)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
|
||||
|
||||
if self.radix >= 1:
|
||||
self.conv2 = SplitAttnConv2d(
|
||||
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
|
||||
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
|
||||
self.act2 = None
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(
|
||||
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
|
||||
dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.bn2 = norm_layer(group_width)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
|
||||
|
||||
self.conv3 = nn.Conv2d(group_width, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = norm_layer(planes*4)
|
||||
self.act3 = act_layer(inplace=True)
|
||||
self.downsample = downsample
|
||||
|
||||
def zero_init_last_bn(self):
|
||||
nn.init.zeros_(self.bn3.weight)
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
out = self.act1(out)
|
||||
|
||||
if self.avd_first is not None:
|
||||
out = self.avd_first(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
if self.bn2 is not None:
|
||||
out = self.bn2(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
out = self.act2(out)
|
||||
|
||||
if self.avd_last is not None:
|
||||
out = self.avd_last(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
if self.drop_block is not None:
|
||||
out = self.drop_block(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.act3(out)
|
||||
return out
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest14d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ResNeSt-14d model. Weights ported from GluonCV.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest14d']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [1, 1, 1, 1], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
|
||||
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ResNeSt-26d model. Weights ported from GluonCV.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest26d']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
|
||||
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ResNeSt-50d model. Matches paper ResNeSt-50 model, https://arxiv.org/abs/2004.08955
|
||||
Since this codebase supports all possible variations, 'd' for deep stem, stem_width 32, avg in downsample.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest50d']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
|
||||
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest101e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ResNeSt-101e model. Matches paper ResNeSt-101 model, https://arxiv.org/abs/2004.08955
|
||||
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest101e']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
|
||||
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest200e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ResNeSt-200e model. Matches paper ResNeSt-200 model, https://arxiv.org/abs/2004.08955
|
||||
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest200e']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [3, 24, 36, 3], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
|
||||
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest269e(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
""" ResNeSt-269e model. Matches paper ResNeSt-269 model, https://arxiv.org/abs/2004.08955
|
||||
Since this codebase supports all possible variations, 'e' for deep stem, stem_width 64, avg in downsample.
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest269e']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [3, 30, 48, 8], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
|
||||
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest50d_4s2x40d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""ResNeSt-50 4s2x40d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest50d_4s2x40d']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
|
||||
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def resnest50d_1s4x24d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
"""ResNeSt-50 1s4x24d from https://github.com/zhanghang1989/ResNeSt/blob/master/ablation.md
|
||||
"""
|
||||
default_cfg = default_cfgs['resnest50d_1s4x24d']
|
||||
model = ResNet(
|
||||
ResNestBottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans,
|
||||
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
|
||||
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
Loading…
x
Reference in New Issue
Block a user