mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Some TResNet cleanup.
* allow use of global pool arg, test-time-pooling * clean checkpoints to just contain state dict, add 448 res checkpoints * support DataParallel via lazy filter creation for JIT Downsample * some minor formatting (mostly alignment) preferences
This commit is contained in:
parent
64fe37d008
commit
0004f37d25
@ -70,10 +70,11 @@ class AdaptiveCatAvgMaxPool2d(nn.Module):
|
||||
class SelectAdaptivePool2d(nn.Module):
|
||||
"""Selectable global pooling layer with dynamic input kernel size
|
||||
"""
|
||||
def __init__(self, output_size=1, pool_type='avg'):
|
||||
def __init__(self, output_size=1, pool_type='avg', flatten=False):
|
||||
super(SelectAdaptivePool2d, self).__init__()
|
||||
self.output_size = output_size
|
||||
self.pool_type = pool_type
|
||||
self.flatten = flatten
|
||||
if pool_type == 'avgmax':
|
||||
self.pool = AdaptiveAvgMaxPool2d(output_size)
|
||||
elif pool_type == 'catavgmax':
|
||||
@ -86,7 +87,10 @@ class SelectAdaptivePool2d(nn.Module):
|
||||
self.pool = nn.AdaptiveAvgPool2d(output_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.pool(x)
|
||||
x = self.pool(x)
|
||||
if self.flatten:
|
||||
x = x.flatten(1)
|
||||
return x
|
||||
|
||||
def feat_mult(self):
|
||||
return adaptive_pool_feat_mult(self.pool_type)
|
||||
|
@ -5,13 +5,14 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class AntiAliasDownsampleLayer(nn.Module):
|
||||
def __init__(self, remove_aa_jit: bool = False, filt_size: int = 3, stride: int = 2,
|
||||
channels: int = 0):
|
||||
def __init__(self, no_jit: bool = False, filt_size: int = 3, stride: int = 2, channels: int = 0):
|
||||
super(AntiAliasDownsampleLayer, self).__init__()
|
||||
if not remove_aa_jit:
|
||||
self.op = DownsampleJIT(filt_size, stride, channels)
|
||||
else:
|
||||
if no_jit:
|
||||
self.op = Downsample(filt_size, stride, channels)
|
||||
else:
|
||||
self.op = DownsampleJIT(filt_size, stride, channels)
|
||||
|
||||
# FIXME I should probably override _apply and clear DownsampleJIT filter cache for .cuda(), .half(), etc calls
|
||||
|
||||
def forward(self, x):
|
||||
return self.op(x)
|
||||
@ -23,20 +24,21 @@ class DownsampleJIT(object):
|
||||
self.stride = stride
|
||||
self.filt_size = filt_size
|
||||
self.channels = channels
|
||||
|
||||
assert self.filt_size == 3
|
||||
assert stride == 2
|
||||
a = torch.tensor([1., 2., 1.])
|
||||
self.filt = {} # lazy init by device for DataParallel compat
|
||||
|
||||
filt = (a[:, None] * a[None, :]).clone().detach()
|
||||
def _create_filter(self, like: torch.Tensor):
|
||||
filt = torch.tensor([1., 2., 1.], dtype=like.dtype, device=like.device)
|
||||
filt = filt[:, None] * filt[None, :]
|
||||
filt = filt / torch.sum(filt)
|
||||
self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1)).cuda().half()
|
||||
filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
|
||||
return filt
|
||||
|
||||
def __call__(self, input: torch.Tensor):
|
||||
if input.dtype != self.filt.dtype:
|
||||
self.filt = self.filt.float()
|
||||
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
|
||||
return F.conv2d(input_pad, self.filt, stride=2, padding=0, groups=input.shape[1])
|
||||
filt = self.filt.get(str(input.device), self._create_filter(input))
|
||||
return F.conv2d(input_pad, filt, stride=2, padding=0, groups=input.shape[1])
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
@ -46,11 +48,9 @@ class Downsample(nn.Module):
|
||||
self.stride = stride
|
||||
self.channels = channels
|
||||
|
||||
|
||||
assert self.filt_size == 3
|
||||
a = torch.tensor([1., 2., 1.])
|
||||
|
||||
filt = (a[:, None] * a[None, :])
|
||||
filt = torch.tensor([1., 2., 1.])
|
||||
filt = filt[:, None] * filt[None, :]
|
||||
filt = filt / torch.sum(filt)
|
||||
|
||||
# self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
|
||||
@ -58,4 +58,4 @@ class Downsample(nn.Module):
|
||||
|
||||
def forward(self, input):
|
||||
input_pad = F.pad(input, (1, 1, 1, 1), 'reflect')
|
||||
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
|
||||
return F.conv2d(input_pad, self.filt, stride=self.stride, padding=0, groups=input.shape[1])
|
||||
|
@ -28,9 +28,9 @@ class SpaceToDepthJit(object):
|
||||
|
||||
|
||||
class SpaceToDepthModule(nn.Module):
|
||||
def __init__(self, remove_model_jit=False):
|
||||
def __init__(self, no_jit=False):
|
||||
super().__init__()
|
||||
if not remove_model_jit:
|
||||
if not no_jit:
|
||||
self.op = SpaceToDepthJit()
|
||||
else:
|
||||
self.op = SpaceToDepth()
|
||||
|
@ -8,8 +8,9 @@ Original model: https://github.com/mrT23/TResNet
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from collections import OrderedDict
|
||||
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer
|
||||
from .layers import SpaceToDepthModule, AntiAliasDownsampleLayer, SelectAdaptivePool2d
|
||||
from .registry import register_model
|
||||
from .helpers import load_pretrained
|
||||
|
||||
@ -27,18 +28,27 @@ def _cfg(url='', **kwargs):
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
||||
'mean': (0, 0, 0), 'std': (1, 1, 1),
|
||||
'first_conv': 'layer0.conv1', 'classifier': 'head',
|
||||
'first_conv': 'layer0.conv1', 'classifier': 'head.fc',
|
||||
**kwargs
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'tresnet_m':
|
||||
_cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_m_80_8.pth'),
|
||||
'tresnet_l':
|
||||
_cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_l_81_5.pth'),
|
||||
'tresnet_xl':
|
||||
_cfg(url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/tresnet/tresnet_xl_82_0.pth')
|
||||
'tresnet_m': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_80_8-dbc13962.pth'),
|
||||
'tresnet_l': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_81_5-235b486c.pth'),
|
||||
'tresnet_xl': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_82_0-a2d51b00.pth'),
|
||||
'tresnet_m_448': _cfg(
|
||||
input_size=(3, 448, 448),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_m_448-bc359d10.pth'),
|
||||
'tresnet_l_448': _cfg(
|
||||
input_size=(3, 448, 448),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_l_448-940d0cd1.pth'),
|
||||
'tresnet_xl_448': _cfg(
|
||||
input_size=(3, 448, 448),
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/tresnet_xl_448-8c1815de.pth')
|
||||
}
|
||||
|
||||
|
||||
@ -54,6 +64,9 @@ class FastGlobalAvgPool2d(nn.Module):
|
||||
else:
|
||||
return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
||||
|
||||
def feat_mult(self):
|
||||
return 1
|
||||
|
||||
|
||||
class FastSEModule(nn.Module):
|
||||
|
||||
@ -78,14 +91,15 @@ def IABN2Float(module: nn.Module) -> nn.Module:
|
||||
"If `module` is IABN don't use half precision."
|
||||
if isinstance(module, InPlaceABN):
|
||||
module.float()
|
||||
for child in module.children(): IABN2Float(child)
|
||||
for child in module.children():
|
||||
IABN2Float(child)
|
||||
return module
|
||||
|
||||
|
||||
def conv2d_ABN(ni, nf, stride, activation="leaky_relu", kernel_size=3, activation_param=1e-2, groups=1):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups,
|
||||
bias=False),
|
||||
nn.Conv2d(
|
||||
ni, nf, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, groups=groups, bias=False),
|
||||
InPlaceABN(num_features=nf, activation=activation, activation_param=activation_param)
|
||||
)
|
||||
|
||||
@ -101,8 +115,9 @@ class BasicBlock(nn.Module):
|
||||
if anti_alias_layer is None:
|
||||
self.conv1 = conv2d_ABN(inplanes, planes, stride=2, activation_param=1e-3)
|
||||
else:
|
||||
self.conv1 = nn.Sequential(conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
|
||||
anti_alias_layer(channels=planes, filt_size=3, stride=2))
|
||||
self.conv1 = nn.Sequential(
|
||||
conv2d_ABN(inplanes, planes, stride=1, activation_param=1e-3),
|
||||
anti_alias_layer(channels=planes, filt_size=3, stride=2))
|
||||
|
||||
self.conv2 = conv2d_ABN(planes, planes, stride=1, activation="identity")
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
@ -120,12 +135,11 @@ class BasicBlock(nn.Module):
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
|
||||
if self.se is not None: out = self.se(out)
|
||||
if self.se is not None:
|
||||
out = self.se(out)
|
||||
|
||||
out += residual
|
||||
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -134,22 +148,22 @@ class Bottleneck(nn.Module):
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True, anti_alias_layer=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = conv2d_ABN(inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu",
|
||||
activation_param=1e-3)
|
||||
self.conv1 = conv2d_ABN(
|
||||
inplanes, planes, kernel_size=1, stride=1, activation="leaky_relu", activation_param=1e-3)
|
||||
if stride == 1:
|
||||
self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu",
|
||||
activation_param=1e-3)
|
||||
self.conv2 = conv2d_ABN(
|
||||
planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3)
|
||||
else:
|
||||
if anti_alias_layer is None:
|
||||
self.conv2 = conv2d_ABN(planes, planes, kernel_size=3, stride=2, activation="leaky_relu",
|
||||
activation_param=1e-3)
|
||||
self.conv2 = conv2d_ABN(
|
||||
planes, planes, kernel_size=3, stride=2, activation="leaky_relu", activation_param=1e-3)
|
||||
else:
|
||||
self.conv2 = nn.Sequential(conv2d_ABN(planes, planes, kernel_size=3, stride=1,
|
||||
activation="leaky_relu", activation_param=1e-3),
|
||||
anti_alias_layer(channels=planes, filt_size=3, stride=2))
|
||||
self.conv2 = nn.Sequential(
|
||||
conv2d_ABN(planes, planes, kernel_size=3, stride=1, activation="leaky_relu", activation_param=1e-3),
|
||||
anti_alias_layer(channels=planes, filt_size=3, stride=2))
|
||||
|
||||
self.conv3 = conv2d_ABN(planes, planes * self.expansion, kernel_size=1, stride=1,
|
||||
activation="identity")
|
||||
self.conv3 = conv2d_ABN(
|
||||
planes, planes * self.expansion, kernel_size=1, stride=1, activation="identity")
|
||||
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
@ -166,7 +180,8 @@ class Bottleneck(nn.Module):
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.conv2(out)
|
||||
if self.se is not None: out = self.se(out)
|
||||
if self.se is not None:
|
||||
out = self.se(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = out + residual # no inplace
|
||||
@ -176,29 +191,32 @@ class Bottleneck(nn.Module):
|
||||
|
||||
|
||||
class TResNet(nn.Module):
|
||||
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, remove_aa_jit=False):
|
||||
def __init__(self, layers, in_chans=3, num_classes=1000, width_factor=1.0, no_aa_jit=False,
|
||||
global_pool='avg', drop_rate=0.):
|
||||
if not has_iabn:
|
||||
raise " For TResNet models, please install InplaceABN: 'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11' "
|
||||
|
||||
raise ImportError(
|
||||
"For TResNet models, please install InplaceABN: "
|
||||
"'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.11'")
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
super(TResNet, self).__init__()
|
||||
|
||||
# JIT layers
|
||||
space_to_depth = SpaceToDepthModule()
|
||||
anti_alias_layer = partial(AntiAliasDownsampleLayer, remove_aa_jit=remove_aa_jit)
|
||||
global_pool_layer = FastGlobalAvgPool2d(flatten=True)
|
||||
anti_alias_layer = partial(AntiAliasDownsampleLayer, no_jit=no_aa_jit)
|
||||
|
||||
# TResnet stages
|
||||
self.inplanes = int(64 * width_factor)
|
||||
self.planes = int(64 * width_factor)
|
||||
conv1 = conv2d_ABN(in_chans * 16, self.planes, stride=1, kernel_size=3)
|
||||
layer1 = self._make_layer(BasicBlock, self.planes, layers[0], stride=1, use_se=True,
|
||||
anti_alias_layer=anti_alias_layer) # 56x56
|
||||
layer2 = self._make_layer(BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True,
|
||||
anti_alias_layer=anti_alias_layer) # 28x28
|
||||
layer3 = self._make_layer(Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True,
|
||||
anti_alias_layer=anti_alias_layer) # 14x14
|
||||
layer4 = self._make_layer(Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False,
|
||||
anti_alias_layer=anti_alias_layer) # 7x7
|
||||
layer1 = self._make_layer(
|
||||
BasicBlock, self.planes, layers[0], stride=1, use_se=True, anti_alias_layer=anti_alias_layer) # 56x56
|
||||
layer2 = self._make_layer(
|
||||
BasicBlock, self.planes * 2, layers[1], stride=2, use_se=True, anti_alias_layer=anti_alias_layer) # 28x28
|
||||
layer3 = self._make_layer(
|
||||
Bottleneck, self.planes * 4, layers[2], stride=2, use_se=True, anti_alias_layer=anti_alias_layer) # 14x14
|
||||
layer4 = self._make_layer(
|
||||
Bottleneck, self.planes * 8, layers[3], stride=2, use_se=False, anti_alias_layer=anti_alias_layer) # 7x7
|
||||
|
||||
# body
|
||||
self.body = nn.Sequential(OrderedDict([
|
||||
@ -210,11 +228,10 @@ class TResNet(nn.Module):
|
||||
('layer4', layer4)]))
|
||||
|
||||
# head
|
||||
self.embeddings = []
|
||||
self.global_pool = nn.Sequential(OrderedDict([('global_pool_layer', global_pool_layer)]))
|
||||
self.num_features = (self.planes * 8) * Bottleneck.expansion
|
||||
fc = nn.Linear(self.num_features, num_classes)
|
||||
self.head = nn.Sequential(OrderedDict([('fc', fc)]))
|
||||
self.global_pool = None
|
||||
self.head = None
|
||||
self.reset_classifier(num_classes, global_pool)
|
||||
|
||||
# model initilization
|
||||
for m in self.modules():
|
||||
@ -239,54 +256,104 @@ class TResNet(nn.Module):
|
||||
if stride == 2:
|
||||
# avg pooling before 1x1 conv
|
||||
layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
|
||||
layers += [conv2d_ABN(self.inplanes, planes * block.expansion, kernel_size=1, stride=1,
|
||||
activation="identity")]
|
||||
layers += [conv2d_ABN(
|
||||
self.inplanes, planes * block.expansion, kernel_size=1, stride=1, activation="identity")]
|
||||
downsample = nn.Sequential(*layers)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, use_se=use_se,
|
||||
anti_alias_layer=anti_alias_layer))
|
||||
layers.append(block(
|
||||
self.inplanes, planes, stride, downsample, use_se=use_se, anti_alias_layer=anti_alias_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks): layers.append(
|
||||
block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
|
||||
for i in range(1, blocks):
|
||||
layers.append(
|
||||
block(self.inplanes, planes, use_se=use_se, anti_alias_layer=anti_alias_layer))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def get_classifier(self):
|
||||
return self.head.fc
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool='avg'):
|
||||
self.num_classes = num_classes
|
||||
if global_pool == 'avg':
|
||||
self.global_pool = FastGlobalAvgPool2d(flatten=True)
|
||||
else:
|
||||
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
|
||||
self.head = None
|
||||
if num_classes:
|
||||
self.head = nn.Sequential(OrderedDict([
|
||||
('fc', nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes))]))
|
||||
|
||||
def forward_features(self, x):
|
||||
return self.body(x)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.body(x)
|
||||
self.embeddings = self.global_pool(x)
|
||||
logits = self.head(self.embeddings)
|
||||
return logits
|
||||
|
||||
|
||||
def filter_fn(input):
|
||||
return input['model']
|
||||
x = self.forward_features(x)
|
||||
x = self.global_pool(x)
|
||||
if self.drop_rate:
|
||||
x = F.dropout(x, p=float(self.drop_rate), training=self.training)
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
@register_model
|
||||
def tresnet_m(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['tresnet_m']
|
||||
model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans)
|
||||
model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tresnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['tresnet_l']
|
||||
model = TResNet(layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2)
|
||||
model = TResNet(
|
||||
layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tresnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['tresnet_xl']
|
||||
model = TResNet(layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3)
|
||||
model = TResNet(
|
||||
layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=filter_fn)
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tresnet_m_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['tresnet_m_448']
|
||||
model = TResNet(layers=[3, 4, 11, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tresnet_l_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['tresnet_l_448']
|
||||
model = TResNet(
|
||||
layers=[4, 5, 18, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.2, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def tresnet_xl_448(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
|
||||
default_cfg = default_cfgs['tresnet_xl_448']
|
||||
model = TResNet(
|
||||
layers=[4, 5, 24, 3], num_classes=num_classes, in_chans=in_chans, width_factor=1.3, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
Loading…
x
Reference in New Issue
Block a user