"""Pytorch Densenet implementation w/ tweaks
|
|
This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
|
|
fixed kwargs passthrough and addition of dynamic global avg/max pool.
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from collections import OrderedDict
|
|
|
|
from models.helpers import load_pretrained
|
|
from models.adaptive_avgmax_pool import *
|
|
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
import re
|
|
|
|
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
|
|
|
|
|
|
def _cfg(url=''):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'features.conv0', 'classifier': 'classifier',
    }


default_cfgs = {
    'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-241335ed.pth'),
    'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-6f0f7f60.pth'),
    'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-4c113574.pth'),
    'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-17b70270.pth'),
}

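# Each cfg describes the preprocessing the pretrained weights expect (mean/std,
# input size, crop pct); the factory functions below attach it to the model and
# pass it to load_pretrained(), which fetches weights from the 'url' entry.

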
def _filter_pretrained(state_dict):
    # Old checkpoints name params inside each denselayer like 'norm.1' / 'conv.2';
    # remap them to 'norm1' / 'conv2' so they match the current module names.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    return state_dict

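# Illustrative effect of the remap above on one old-style checkpoint key:
#   'features.denseblock1.denselayer1.norm.1.weight'
#       -> 'features.denseblock1.denselayer1.norm1.weight'

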
def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    """
    default_cfg = default_cfgs['densenet121']
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    """
    default_cfg = default_cfgs['densenet169']
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    """
    default_cfg = default_cfgs['densenet201']
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    """
    default_cfg = default_cfgs['densenet161']
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)

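# _DenseLayer implements the paper's bottleneck (DenseNet-B) layer: the 1x1 conv
# first narrows the input to bn_size * growth_rate channels before the 3x3 conv,
# and forward() returns the input concatenated with growth_rate new channels.

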
class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)

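# Within a block the channel count grows linearly: a block with num_layers=6 fed
# num_input_features=64 emits 64 + 6 * growth_rate channels (256 at growth_rate=32).

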
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))

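# _Transition sits between dense blocks and halves both the channel count (via
# the 1x1 conv; callers below pass num_output_features = num_features // 2) and
# the spatial resolution (via the stride-2 average pool).

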
class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for bottleneck layer width
            (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0,
                 num_classes=1000, in_chans=3, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(
                    num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        self.num_features = num_features

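    # Channel bookkeeping for the default (DenseNet-121) config built above:
    # 64 -> 256 after denseblock1 -> 128 after transition1 -> 512 -> 256
    # -> 1024 -> 512 -> 1024 after denseblock4, so self.num_features == 1024.
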
    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.classifier
        if num_classes:
            self.classifier = nn.Linear(self.num_features, num_classes)
        else:
            self.classifier = None

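    # Fine-tuning sketch (illustrative): swap the 1000-class head for a smaller
    # one while keeping the backbone weights, e.g.
    #   model = densenet121(pretrained=True)
    #   model.reset_classifier(num_classes=10)
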
    def forward_features(self, x, pool=True):
        x = self.features(x)
        x = F.relu(x, inplace=True)
        if pool:
            x = select_adaptive_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

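    # With pool=True the result is flattened to (N, self.num_features) for the
    # classifier; with pool=False the unpooled (N, C, H, W) map is returned.
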
    def forward(self, x):
        return self.classifier(self.forward_features(x, pool=True))
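

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original file): build an
    # untrained DenseNet-121 and push a random ImageNet-sized batch through it.
    model = densenet121(num_classes=1000, pretrained=False)
    x = torch.randn(2, 3, 224, 224)
    logits = model(x)
    print(logits.shape)  # torch.Size([2, 1000])
    feats = model.forward_features(x, pool=True)
    print(feats.shape)   # torch.Size([2, 1024])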