mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Tweak some comments, add SKNet models with weights to sotabench, remove an unused branch
This commit is contained in:
parent
91e2b33d72
commit
569419b38d
18
sotabench.py
18
sotabench.py
@ -56,8 +56,7 @@ model_list = [
|
|||||||
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
|
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
|
||||||
_entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
|
_entry('efficientnet_es', 'EfficientNet-EdgeTPU-S', '1905.11946',
|
||||||
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
|
model_desc='Trained from scratch in PyTorch w/ RandAugment'),
|
||||||
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
|
|
||||||
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
|
|
||||||
_entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'),
|
_entry('gluon_inception_v3', 'Inception V3', '1512.00567', model_desc='Ported from GluonCV Model Zoo'),
|
||||||
_entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
|
_entry('gluon_resnet18_v1b', 'ResNet-18', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
|
||||||
_entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
|
_entry('gluon_resnet34_v1b', 'ResNet-34', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
|
||||||
@ -82,14 +81,22 @@ model_list = [
|
|||||||
_entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
|
_entry('gluon_seresnext101_64x4d', 'SE-ResNeXt-101 64x4d', '1812.01187', model_desc='Ported from GluonCV Model Zoo'),
|
||||||
_entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2,
|
_entry('gluon_xception65', 'Modified Aligned Xception', '1802.02611', batch_size=BATCH_SIZE//2,
|
||||||
model_desc='Ported from GluonCV Model Zoo'),
|
model_desc='Ported from GluonCV Model Zoo'),
|
||||||
|
|
||||||
_entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"),
|
_entry('mixnet_xl', 'MixNet-XL', '1907.09595', model_desc="My own scaling beyond paper's MixNet Large"),
|
||||||
_entry('mixnet_l', 'MixNet-L', '1907.09595'),
|
_entry('mixnet_l', 'MixNet-L', '1907.09595'),
|
||||||
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
|
_entry('mixnet_m', 'MixNet-M', '1907.09595'),
|
||||||
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
|
_entry('mixnet_s', 'MixNet-S', '1907.09595'),
|
||||||
|
|
||||||
|
_entry('fbnetc_100', 'FBNet-C', '1812.03443',
|
||||||
|
model_desc='Trained in PyTorch with RMSProp, exponential LR decay'),
|
||||||
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
|
_entry('mnasnet_100', 'MnasNet-B1', '1807.11626'),
|
||||||
|
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
|
||||||
|
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
|
||||||
|
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
|
||||||
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
|
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
|
||||||
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
|
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
|
||||||
'paper as closely as possible.'),
|
'paper as closely as possible.'),
|
||||||
|
|
||||||
_entry('resnet18', 'ResNet-18', '1812.01187'),
|
_entry('resnet18', 'ResNet-18', '1812.01187'),
|
||||||
_entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'),
|
_entry('resnet26', 'ResNet-26', '1812.01187', model_desc='Block cfg of ResNet-34 w/ Bottleneck'),
|
||||||
_entry('resnet26d', 'ResNet-26-D', '1812.01187',
|
_entry('resnet26d', 'ResNet-26-D', '1812.01187',
|
||||||
@ -103,7 +110,7 @@ model_list = [
|
|||||||
_entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187',
|
_entry('resnext50d_32x4d', 'ResNeXt-50-D 32x4d', '1812.01187',
|
||||||
model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with "
|
model_desc="'D' variant (3x3 deep stem w/ avg-pool downscale). Trained with "
|
||||||
"SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"),
|
"SGD w/ cosine LR decay, random-erasing (gaussian per-pixel noise) and label-smoothing"),
|
||||||
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
|
|
||||||
_entry('seresnet18', 'SE-ResNet-18', '1709.01507'),
|
_entry('seresnet18', 'SE-ResNet-18', '1709.01507'),
|
||||||
_entry('seresnet34', 'SE-ResNet-34', '1709.01507'),
|
_entry('seresnet34', 'SE-ResNet-34', '1709.01507'),
|
||||||
_entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507',
|
_entry('seresnext26_32x4d', 'SE-ResNeXt-26 32x4d', '1709.01507',
|
||||||
@ -114,8 +121,9 @@ model_list = [
|
|||||||
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'),
|
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered stem, and avg-pool in downsample layers.'),
|
||||||
_entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187',
|
_entry('seresnext26tn_32x4d', 'SE-ResNeXt-26-TN 32x4d', '1812.01187',
|
||||||
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'),
|
model_desc='Block cfg of SE-ResNeXt-34 w/ Bottleneck, deep tiered narrow stem, and avg-pool in downsample layers.'),
|
||||||
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
|
|
||||||
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
|
_entry('skresnet18', 'SK-ResNet-18', '1903.06586'),
|
||||||
|
_entry('skresnext50_32x4d', 'SKNet-50', '1903.06586'),
|
||||||
|
|
||||||
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
|
_entry('tf_efficientnet_b0', 'EfficientNet-B0 (AutoAugment)', '1905.11946',
|
||||||
model_desc='Ported from official Google AI Tensorflow weights'),
|
model_desc='Ported from official Google AI Tensorflow weights'),
|
||||||
|
@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
|
Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
|
||||||
|
|
||||||
|
WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
|
||||||
|
some tasks, especially fine-grained ones, it seems. I may end up removing this impl.
|
||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
""" Conditional Convolution
|
""" PyTorch Conditionally Parameterized Convolution (CondConv)
|
||||||
|
|
||||||
|
Paper: CondConv: Conditionally Parameterized Convolutions for Efficient Inference
|
||||||
|
(https://arxiv.org/abs/1904.04971)
|
||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
@ -28,7 +31,7 @@ def get_condconv_initializer(initializer, num_experts, expert_shape):
|
|||||||
|
|
||||||
|
|
||||||
class CondConv2d(nn.Module):
|
class CondConv2d(nn.Module):
|
||||||
""" Conditional Convolution
|
""" Conditionally Parameterized Convolution
|
||||||
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py
|
||||||
|
|
||||||
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion:
|
||||||
|
@ -42,7 +42,7 @@ class EcaModule(nn.Module):
|
|||||||
"""Constructs an ECA module.
|
"""Constructs an ECA module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
channel: Number of channels of the input feature map for use in adaptive kernel sizes
|
channels: Number of channels of the input feature map for use in adaptive kernel sizes
|
||||||
for actual calculations according to channel.
|
for actual calculations according to channels.
|
||||||
gamma, beta: when channel is given parameters of mapping function
|
gamma, beta: when channels is given, parameters of the mapping function
|
||||||
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
|
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
""" Conditional Convolution
|
""" PyTorch Mixed Convolution
|
||||||
|
|
||||||
|
Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
|
||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
""" Selective Kernel Convolution Attention
|
""" Selective Kernel Convolution/Attention
|
||||||
|
|
||||||
|
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
|
||||||
|
|
||||||
Hacked together by Ross Wightman
|
Hacked together by Ross Wightman
|
||||||
"""
|
"""
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
|
""" Selective Kernel Networks (ResNet base)
|
||||||
|
|
||||||
|
Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
|
||||||
|
|
||||||
|
Hacked together by Ross Wightman
|
||||||
|
"""
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
@ -47,19 +53,11 @@ class SelectiveKernelBasic(nn.Module):
|
|||||||
outplanes = planes * self.expansion
|
outplanes = planes * self.expansion
|
||||||
first_dilation = first_dilation or dilation
|
first_dilation = first_dilation or dilation
|
||||||
|
|
||||||
_selective_first = True # FIXME temporary, for experiments
|
self.conv1 = SelectiveKernelConv(
|
||||||
if _selective_first:
|
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
|
||||||
self.conv1 = SelectiveKernelConv(
|
conv_kwargs['act_layer'] = None
|
||||||
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
|
self.conv2 = ConvBnAct(
|
||||||
conv_kwargs['act_layer'] = None
|
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
|
||||||
self.conv2 = ConvBnAct(
|
|
||||||
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
|
|
||||||
else:
|
|
||||||
self.conv1 = ConvBnAct(
|
|
||||||
inplanes, first_planes, kernel_size=3, stride=stride, dilation=first_dilation, **conv_kwargs)
|
|
||||||
conv_kwargs['act_layer'] = None
|
|
||||||
self.conv2 = SelectiveKernelConv(
|
|
||||||
first_planes, outplanes, dilation=dilation, **conv_kwargs, **sk_kwargs)
|
|
||||||
self.se = create_attn(attn_layer, outplanes)
|
self.se = create_attn(attn_layer, outplanes)
|
||||||
self.act = act_layer(inplace=True)
|
self.act = act_layer(inplace=True)
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
@ -222,7 +220,7 @@ def skresnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
|||||||
@register_model
|
@register_model
|
||||||
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
def skresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||||
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
|
"""Constructs a Selective Kernel ResNeXt50-32x4d model. This should be equivalent to
|
||||||
the SKNet50 model in the Select Kernel Paper
|
the SKNet-50 model in the Selective Kernel Networks paper
|
||||||
"""
|
"""
|
||||||
default_cfg = default_cfgs['skresnext50_32x4d']
|
default_cfg = default_cfgs['skresnext50_32x4d']
|
||||||
model = ResNet(
|
model = ResNet(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user