mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
* factor out data related constants to own file * move data related config helpers to own file * add a variant of RandomResizeCrop that randomizes interpolation method * remove old Numpy version of RandomErasing * cleanup torch version of RandomErasing and use it in either GPU loader batch mode or single image cpu Transform
333 lines
11 KiB
Python
333 lines
11 KiB
Python
""" Pytorch Inception-Resnet-V2 implementation
|
|
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
|
based upon Google's TensorFlow implementation and pretrained weights (Apache 2.0 License)
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from models.helpers import load_pretrained
|
|
from models.adaptive_avgmax_pool import *
|
|
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
|
|
|
# Per-model pretrained-weight metadata, keyed by model name. The entry is
# attached to the model as ``default_cfg`` and handed to ``load_pretrained``
# by the factory function in this file.
default_cfgs = {
    'inception_resnet_v2': {
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
        'num_classes': 1001,
        'input_size': (3, 299, 299),
        'pool_size': (8, 8),
        'crop_pct': 0.8975,
        'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        # Dotted attribute paths of the first conv and the classifier layer.
        'first_conv': 'conv2d_1a.conv',
        'classifier': 'last_linear',
    },
}
|
|
|
|
|
|
class BasicConv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    The convolution is created without a bias term and the batch norm uses
    ``eps=0.001``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride.
        padding: convolution padding (int or tuple), defaults to 0.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        # Submodules are registered in the order conv, bn, relu; the attribute
        # names are part of the state_dict keys and must not change.
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-3)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        return self.relu(self.bn(self.conv(x)))
|
|
|
|
|
|
class Mixed_5b(nn.Module):
    """Four parallel branches over a 192-channel input, concatenated on dim 1.

    The branches emit 96, 64, 96 and 64 channels respectively, so the
    concatenated output carries 320 channels.
    """

    def __init__(self):
        super(Mixed_5b, self).__init__()
        in_ch = 192

        # Branch 0: single 1x1 conv.
        self.branch0 = BasicConv2d(in_ch, 96, kernel_size=1, stride=1)

        # Branch 1: 1x1 reduction followed by a padded 5x5 conv.
        branch1_layers = [
            BasicConv2d(in_ch, 48, kernel_size=1, stride=1),
            BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # Branch 2: 1x1 reduction followed by two padded 3x3 convs.
        branch2_layers = [
            BasicConv2d(in_ch, 64, kernel_size=1, stride=1),
            BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
            BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        # Branch 3: 3x3 average pool (padding excluded from the mean), then 1x1 conv.
        branch3_layers = [
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            BasicConv2d(in_ch, 64, kernel_size=1, stride=1),
        ]
        self.branch3 = nn.Sequential(*branch3_layers)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
|
|
|
|
|
|
class Block35(nn.Module):
    """Residual block over a 320-channel input.

    Three branches (32 + 32 + 64 channels) are concatenated, projected back to
    320 channels by a 1x1 conv, multiplied by ``scale`` and added to the input
    before a final ReLU.

    Args:
        scale: multiplier applied to the projected branch output before the
            residual addition.
    """

    def __init__(self, scale=1.0):
        super(Block35, self).__init__()
        self.scale = scale
        in_ch = 320

        self.branch0 = BasicConv2d(in_ch, 32, kernel_size=1, stride=1)

        branch1_layers = [
            BasicConv2d(in_ch, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        branch2_layers = [
            BasicConv2d(in_ch, 32, kernel_size=1, stride=1),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        # 1x1 projection (with bias) from the 128 concatenated channels back to 320.
        self.conv2d = nn.Conv2d(128, in_ch, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        mixed = torch.cat([self.branch0(x), self.branch1(x), self.branch2(x)], 1)
        projected = self.conv2d(mixed)
        return self.relu(projected * self.scale + x)
|
|
|
|
|
|
class Mixed_6a(nn.Module):
    """Stride-2 block over a 320-channel input.

    Two conv branches (384 channels each) and a max-pool branch (which keeps
    the 320 input channels) are concatenated along dim 1, giving 1088 channels.
    """

    def __init__(self):
        super(Mixed_6a, self).__init__()
        in_ch = 320

        # Branch 0: unpadded 3x3 conv with stride 2.
        self.branch0 = BasicConv2d(in_ch, 384, kernel_size=3, stride=2)

        # Branch 1: 1x1 -> padded 3x3 -> unpadded 3x3 with stride 2.
        branch1_layers = [
            BasicConv2d(in_ch, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # Branch 2: 3x3 max pool with stride 2.
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        branches = (self.branch0, self.branch1, self.branch2)
        return torch.cat([branch(x) for branch in branches], 1)
|
|
|
|
|
|
class Block17(nn.Module):
    """Residual block over a 1088-channel input.

    Two branches (192 channels each) are concatenated, projected back to 1088
    channels by a 1x1 conv, multiplied by ``scale`` and added to the input
    before a final ReLU.

    Args:
        scale: multiplier applied to the projected branch output before the
            residual addition.
    """

    def __init__(self, scale=1.0):
        super(Block17, self).__init__()
        self.scale = scale
        in_ch = 1088

        self.branch0 = BasicConv2d(in_ch, 192, kernel_size=1, stride=1)

        # 1x1 reduction, then a 7x7 receptive field factorised as 1x7 and 7x1.
        branch1_layers = [
            BasicConv2d(in_ch, 128, kernel_size=1, stride=1),
            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # 1x1 projection (with bias) from the 384 concatenated channels back to 1088.
        self.conv2d = nn.Conv2d(384, in_ch, kernel_size=1, stride=1)
        self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``relu(conv2d(cat(branches)) * scale + x)``."""
        mixed = torch.cat([self.branch0(x), self.branch1(x)], 1)
        projected = self.conv2d(mixed)
        return self.relu(projected * self.scale + x)
|
|
|
|
|
|
class Mixed_7a(nn.Module):
    """Stride-2 block over a 1088-channel input.

    Three conv branches (384, 288 and 320 channels) and a max-pool branch
    (which keeps the 1088 input channels) are concatenated along dim 1, giving
    2080 channels.
    """

    def __init__(self):
        super(Mixed_7a, self).__init__()
        in_ch = 1088

        # Branch 0: 1x1 reduction -> unpadded 3x3 with stride 2.
        branch0_layers = [
            BasicConv2d(in_ch, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 384, kernel_size=3, stride=2),
        ]
        self.branch0 = nn.Sequential(*branch0_layers)

        # Branch 1: same shape as branch 0 but with 288 output channels.
        branch1_layers = [
            BasicConv2d(in_ch, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=2),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # Branch 2: 1x1 reduction -> padded 3x3 -> unpadded 3x3 with stride 2.
        branch2_layers = [
            BasicConv2d(in_ch, 256, kernel_size=1, stride=1),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2),
        ]
        self.branch2 = nn.Sequential(*branch2_layers)

        # Branch 3: 3x3 max pool with stride 2.
        self.branch3 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along dim 1."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
|
|
|
|
|
|
class Block8(nn.Module):
    """Residual block over a 2080-channel input.

    Two branches (192 and 256 channels) are concatenated, projected back to
    2080 channels by a 1x1 conv, multiplied by ``scale`` and added to the
    input. A final ReLU is applied unless ``noReLU`` is set.

    Args:
        scale: multiplier applied to the projected branch output before the
            residual addition.
        noReLU: when True, no ``relu`` submodule is created and the block
            returns the residual sum without an activation.
    """

    def __init__(self, scale=1.0, noReLU=False):
        super(Block8, self).__init__()
        self.scale = scale
        self.noReLU = noReLU
        in_ch = 2080

        self.branch0 = BasicConv2d(in_ch, 192, kernel_size=1, stride=1)

        # 1x1 reduction, then a 3x3 receptive field factorised as 1x3 and 3x1.
        branch1_layers = [
            BasicConv2d(in_ch, 192, kernel_size=1, stride=1),
            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)),
        ]
        self.branch1 = nn.Sequential(*branch1_layers)

        # 1x1 projection (with bias) from the 448 concatenated channels back to 2080.
        self.conv2d = nn.Conv2d(448, in_ch, kernel_size=1, stride=1)
        if not self.noReLU:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        """Return ``conv2d(cat(branches)) * scale + x``, ReLU-activated unless ``noReLU``."""
        mixed = torch.cat([self.branch0(x), self.branch1(x)], 1)
        summed = self.conv2d(mixed) * self.scale + x
        if self.noReLU:
            return summed
        return self.relu(summed)
|
|
|
|
|
|
class InceptionResnetV2(nn.Module):
    """Inception-ResNet-V2 image classifier.

    Args:
        num_classes: output size of the final linear classifier.
        in_chans: number of channels of the input tensor.
        drop_rate: dropout probability applied to the pooled features before
            the classifier; dropout is skipped when this is 0.
        global_pool: pooling mode passed to ``select_adaptive_pool2d``.
    """

    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super(InceptionResnetV2, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = 1536

        # Stem.
        self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
        self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
        self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
        self.maxpool_5a = nn.MaxPool2d(3, stride=2)

        # Repeated residual stages. nn.Sequential indexes positional modules
        # 0..N-1, so building them from a comprehension yields the same
        # state_dict keys (and the same construction order) as listing every
        # block by hand.
        self.mixed_5b = Mixed_5b()
        self.repeat = nn.Sequential(*[Block35(scale=0.17) for _ in range(10)])
        self.mixed_6a = Mixed_6a()
        self.repeat_1 = nn.Sequential(*[Block17(scale=0.10) for _ in range(20)])
        self.mixed_7a = Mixed_7a()
        self.repeat_2 = nn.Sequential(*[Block8(scale=0.20) for _ in range(9)])

        # Final residual block has no ReLU and uses the default scale of 1.0.
        self.block8 = Block8(noReLU=True)
        self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
        # NOTE(review): in_features is fixed at num_features; if a pooling mode
        # changes the pooled channel count this would not match -- confirm
        # against select_adaptive_pool2d.
        self.last_linear = nn.Linear(self.num_features, num_classes)

    def get_classifier(self):
        """Return the classifier layer, or None after ``reset_classifier(0)``."""
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classifier head and pooling mode.

        Args:
            num_classes: output size of the new linear layer; a falsy value
                removes the classifier (``last_linear`` becomes None).
            global_pool: new pooling mode for ``forward_features``.
        """
        self.global_pool = global_pool
        self.num_classes = num_classes
        # Drop the old entry first so a plain None can be assigned afterwards.
        del self.last_linear
        if num_classes:
            self.last_linear = nn.Linear(self.num_features, num_classes)
        else:
            self.last_linear = None

    def forward_features(self, x, pool=True):
        """Run the convolutional trunk.

        Args:
            x: input tensor fed to the first convolution.
            pool: when True, apply ``select_adaptive_pool2d`` with the current
                ``global_pool`` mode and flatten to ``(batch, -1)``; when
                False, return the output of ``conv2d_7b`` as is.
        """
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
        x = self.conv2d_2b(x)
        x = self.maxpool_3a(x)
        x = self.conv2d_3b(x)
        x = self.conv2d_4a(x)
        x = self.maxpool_5a(x)
        x = self.mixed_5b(x)
        x = self.repeat(x)
        x = self.mixed_6a(x)
        x = self.repeat_1(x)
        x = self.mixed_7a(x)
        x = self.repeat_2(x)
        x = self.block8(x)
        x = self.conv2d_7b(x)
        if pool:
            x = select_adaptive_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        """Return classifier outputs, or pooled features if the classifier was removed."""
        x = self.forward_features(x, pool=True)
        # reset_classifier(0) leaves last_linear as None; calling it would
        # raise TypeError, so hand back the pooled features instead.
        if self.last_linear is None:
            return x
        if self.drop_rate > 0:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return self.last_linear(x)
|
|
|
|
|
|
def inception_resnet_v2(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Build an InceptionResnetV2 model as described in the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.

    Args:
        num_classes: classifier output size passed to the model constructor.
        in_chans: number of input channels passed to the model constructor.
        pretrained: when True, call ``load_pretrained`` on the new model with
            this model's ``default_cfgs`` entry.
        **kwargs: forwarded unchanged to ``InceptionResnetV2``.

    Returns:
        The constructed model, with its ``default_cfg`` attribute set.
    """
    cfg = default_cfgs['inception_resnet_v2']
    net = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
    net.default_cfg = cfg
    if not pretrained:
        return net
    load_pretrained(net, cfg, num_classes, in_chans)
    return net
|
|
|