mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fixup a few comments, add PyTorch version aware Flatten and finish as_sequential for GenEfficientNet
This commit is contained in:
parent
7ac6db4543
commit
35e8f0c5e7
@ -8,6 +8,7 @@ import numpy as np
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
# Tuple helpers ripped from PyTorch
|
||||||
def _ntuple(n):
|
def _ntuple(n):
|
||||||
def parse(x):
|
def parse(x):
|
||||||
if isinstance(x, container_abcs.Iterable):
|
if isinstance(x, container_abcs.Iterable):
|
||||||
@ -77,7 +78,7 @@ def get_padding_value(padding, kernel_size, **kwargs):
|
|||||||
# static case, no extra overhead
|
# static case, no extra overhead
|
||||||
padding = _get_padding(kernel_size, **kwargs)
|
padding = _get_padding(kernel_size, **kwargs)
|
||||||
else:
|
else:
|
||||||
# dynamic padding
|
# dynamic 'SAME' padding, has runtime/GPU memory overhead
|
||||||
padding = 0
|
padding = 0
|
||||||
dynamic = True
|
dynamic = True
|
||||||
elif padding == 'valid':
|
elif padding == 'valid':
|
||||||
@ -101,6 +102,7 @@ def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
|
|||||||
|
|
||||||
class MixedConv2d(nn.Module):
|
class MixedConv2d(nn.Module):
|
||||||
""" Mixed Grouped Convolution
|
""" Mixed Grouped Convolution
|
||||||
|
|
||||||
Based on MDConv and GroupedConv in MixNet impl:
|
Based on MDConv and GroupedConv in MixNet impl:
|
||||||
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
|
||||||
"""
|
"""
|
||||||
@ -152,7 +154,11 @@ def get_condconv_initializer(initializer, num_experts, expert_shape):
|
|||||||
|
|
||||||
class CondConv2d(nn.Module):
|
class CondConv2d(nn.Module):
|
||||||
""" Conditional Convolution
|
""" Conditional Convolution
|
||||||
|
|
||||||
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
||||||
|
|
||||||
|
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
||||||
|
https://github.com/pytorch/pytorch/issues/17983
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3,
|
def __init__(self, in_channels, out_channels, kernel_size=3,
|
||||||
@ -211,6 +217,7 @@ class CondConv2d(nn.Module):
|
|||||||
if self._use_groups:
|
if self._use_groups:
|
||||||
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
|
||||||
weight = weight.view(new_weight_shape)
|
weight = weight.view(new_weight_shape)
|
||||||
|
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel
|
||||||
x = x.view(1, B * C, H, W)
|
x = x.view(1, B * C, H, W)
|
||||||
out = self.conv_fn(
|
out = self.conv_fn(
|
||||||
x, weight, bias, stride=self.stride, padding=self.padding,
|
x, weight, bias, stride=self.stride, padding=self.padding,
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
A generic class with building blocks to support a variety of models with efficient architectures:
|
A generic class with building blocks to support a variety of models with efficient architectures:
|
||||||
* EfficientNet (B0-B7)
|
* EfficientNet (B0-B7)
|
||||||
|
* EfficientNet-EdgeTPU
|
||||||
|
* EfficientNet-CondConv
|
||||||
* MixNet (Small, Medium, and Large)
|
* MixNet (Small, Medium, and Large)
|
||||||
* MnasNet B1, A1 (SE), Small
|
* MnasNet B1, A1 (SE), Small
|
||||||
* MobileNet V1, V2, and V3
|
* MobileNet V1, V2, and V3
|
||||||
@ -31,6 +33,7 @@ from .registry import register_model, model_entrypoint
|
|||||||
from .helpers import load_pretrained
|
from .helpers import load_pretrained
|
||||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||||
from .conv2d_layers import select_conv2d
|
from .conv2d_layers import select_conv2d
|
||||||
|
from .layers import Flatten
|
||||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||||
|
|
||||||
|
|
||||||
@ -1050,16 +1053,14 @@ class GenEfficientNet(_GenEfficientNet):
|
|||||||
layers = [self.conv_stem, self.bn1, self.act1]
|
layers = [self.conv_stem, self.bn1, self.act1]
|
||||||
layers.extend(self.blocks)
|
layers.extend(self.blocks)
|
||||||
if self.head_conv == 'efficient':
|
if self.head_conv == 'efficient':
|
||||||
layers.extend([self.global_pool, self.bn2, self.act2])
|
layers.extend([self.global_pool, self.conv_head, self.act2])
|
||||||
else:
|
else:
|
||||||
layers.extend([self.conv_head, self.bn2, self.act2])
|
layers.extend([self.conv_head, self.bn2, self.act2])
|
||||||
if self.global_pool is not None:
|
if self.global_pool is not None:
|
||||||
layers.append(self.global_pool)
|
layers.append(self.global_pool)
|
||||||
#append flatten layer
|
layers.extend([Flatten(), nn.Dropout(self.drop_rate), self.classifier])
|
||||||
layers.append(self.classifier)
|
|
||||||
return nn.Sequential(*layers)
|
return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
def get_classifier(self):
|
def get_classifier(self):
|
||||||
return self.classifier
|
return self.classifier
|
||||||
|
|
||||||
@ -1106,7 +1107,8 @@ class GenEfficientNetFeatures(_GenEfficientNet):
|
|||||||
#assert len(block_args) >= num_stages - 1
|
#assert len(block_args) >= num_stages - 1
|
||||||
#block_args = block_args[:num_stages - 1]
|
#block_args = block_args[:num_stages - 1]
|
||||||
|
|
||||||
super(GenEfficientNetFeatures, self).__init__( # FIXME it would be nice if Python made this nicer
|
# FIXME it would be nice if Python made this nicer without using kwargs and erasing IDE hints, etc
|
||||||
|
super(GenEfficientNetFeatures, self).__init__(
|
||||||
block_args, in_chans=in_chans, stem_size=stem_size,
|
block_args, in_chans=in_chans, stem_size=stem_size,
|
||||||
output_stride=output_stride, pad_type=pad_type, act_layer=act_layer,
|
output_stride=output_stride, pad_type=pad_type, act_layer=act_layer,
|
||||||
drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location,
|
drop_rate=drop_rate, drop_connect_rate=drop_connect_rate, feature_location=feature_location,
|
||||||
@ -1548,6 +1550,11 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
|
|||||||
|
|
||||||
|
|
||||||
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
|
||||||
|
""" Creates an EfficientNet-EdgeTPU model
|
||||||
|
|
||||||
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
|
||||||
|
"""
|
||||||
|
|
||||||
arch_def = [
|
arch_def = [
|
||||||
# NOTE `fc` is present to override a mismatch between stem channels and in chs not
|
# NOTE `fc` is present to override a mismatch between stem channels and in chs not
|
||||||
# present in other models
|
# present in other models
|
||||||
@ -1573,8 +1580,10 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
|
|||||||
|
|
||||||
def _gen_efficientnet_condconv(
|
def _gen_efficientnet_condconv(
|
||||||
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
|
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
|
||||||
|
"""Creates an EfficientNet-CondConv model.
|
||||||
|
|
||||||
"""Creates an efficientnet-condconv model."""
|
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
|
||||||
|
"""
|
||||||
arch_def = [
|
arch_def = [
|
||||||
['ds_r1_k3_s1_e1_c16_se0.25'],
|
['ds_r1_k3_s1_e1_c16_se0.25'],
|
||||||
['ir_r2_k3_s2_e6_c24_se0.25'],
|
['ir_r2_k3_s2_e6_c24_se0.25'],
|
||||||
@ -1584,6 +1593,8 @@ def _gen_efficientnet_condconv(
|
|||||||
['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
|
['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
|
||||||
['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
|
['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
|
||||||
]
|
]
|
||||||
|
# NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
|
||||||
|
# the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
|
||||||
model_kwargs = dict(
|
model_kwargs = dict(
|
||||||
block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
|
block_args=_decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
|
||||||
num_features=_round_channels(1280, channel_multiplier, 8, None),
|
num_features=_round_channels(1280, channel_multiplier, 8, None),
|
||||||
@ -2056,7 +2067,7 @@ def tf_efficientnet_el(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
|
def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
|
||||||
""" EfficientNet-B0 """
|
""" EfficientNet-CondConv-B0 w/ 4 Experts. Tensorflow compatible variant """
|
||||||
# NOTE for train, drop_rate should be 0.2
|
# NOTE for train, drop_rate should be 0.2
|
||||||
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
||||||
@ -2068,7 +2079,7 @@ def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
||||||
""" EfficientNet-B0 """
|
""" EfficientNet-CondConv-B0 w/ 8 Experts. Tensorflow compatible variant """
|
||||||
# NOTE for train, drop_rate should be 0.2
|
# NOTE for train, drop_rate should be 0.2
|
||||||
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
||||||
@ -2080,7 +2091,7 @@ def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs):
|
|||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs):
|
||||||
""" EfficientNet-B0 """
|
""" EfficientNet-CondConv-B1 w/ 8 Experts. Tensorflow compatible variant """
|
||||||
# NOTE for train, drop_rate should be 0.2
|
# NOTE for train, drop_rate should be 0.2
|
||||||
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
#kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
|
||||||
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
|
||||||
|
31
timm/models/layers.py
Normal file
31
timm/models/layers.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def versiontuple(v):
|
||||||
|
return tuple(map(int, (v.split("."))))[:3]
|
||||||
|
|
||||||
|
|
||||||
|
if versiontuple(torch.__version__) >= versiontuple('1.2.0'):
|
||||||
|
Flatten = nn.Flatten
|
||||||
|
else:
|
||||||
|
class Flatten(nn.Module):
|
||||||
|
r"""
|
||||||
|
Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.
|
||||||
|
Args:
|
||||||
|
start_dim: first dim to flatten (default = 1).
|
||||||
|
end_dim: last dim to flatten (default = -1).
|
||||||
|
Shape:
|
||||||
|
- Input: :math:`(N, *dims)`
|
||||||
|
- Output: :math:`(N, \prod *dims)` (for the default case).
|
||||||
|
"""
|
||||||
|
__constants__ = ['start_dim', 'end_dim']
|
||||||
|
|
||||||
|
def __init__(self, start_dim=1, end_dim=-1):
|
||||||
|
super(Flatten, self).__init__()
|
||||||
|
self.start_dim = start_dim
|
||||||
|
self.end_dim = end_dim
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return input.flatten(self.start_dim, self.end_dim)
|
Loading…
x
Reference in New Issue
Block a user