mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add AvgPool2d anti-aliasing support to ResNet arch (as per OpenAI CLIP models), add a few blur aa models as well
This commit is contained in:
parent
f0f9eccda8
commit
1aa617cb3b
@ -251,6 +251,21 @@ default_cfgs = {
|
||||
'resnetblur50': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth',
|
||||
interpolation='bicubic'),
|
||||
'resnetblur50d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0'),
|
||||
'resnetblur101d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0'),
|
||||
'resnetaa50d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0'),
|
||||
'resnetaa101d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0'),
|
||||
'seresnetaa50d': _cfg(
|
||||
url='',
|
||||
interpolation='bicubic', first_conv='conv1.0'),
|
||||
|
||||
# ResNet-RS models
|
||||
'resnetrs50': _cfg(
|
||||
@ -289,6 +304,12 @@ def get_padding(kernel_size, stride, dilation=1):
|
||||
return padding
|
||||
|
||||
|
||||
def create_aa(aa_layer, channels, stride=2, enable=True):
    """Instantiate an anti-aliasing layer, or return None when none is wanted.

    Args:
        aa_layer: Layer class or factory callable. A falsy value disables
            anti-aliasing.
        channels: Channel count, passed as the ``channels`` keyword to layers
            that are not ``nn.AvgPool2d`` subclasses.
        stride: Stride for the layer. For ``nn.AvgPool2d`` subclasses it is
            passed as the single positional argument.
        enable: When false, no layer is created even if ``aa_layer`` is set.

    Returns:
        The constructed layer, or ``None`` if ``aa_layer`` is falsy or
        ``enable`` is false.
    """
    if not aa_layer or not enable:
        return None
    # issubclass() raises TypeError for non-class callables (e.g. a
    # functools.partial or factory function), so only probe real classes;
    # any other callable is invoked with the channels/stride keyword signature.
    if isinstance(aa_layer, type) and issubclass(aa_layer, nn.AvgPool2d):
        return aa_layer(stride)
    return aa_layer(channels=channels, stride=stride)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
@ -309,7 +330,7 @@ class BasicBlock(nn.Module):
|
||||
dilation=first_dilation, bias=False)
|
||||
self.bn1 = norm_layer(first_planes)
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
|
||||
self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
|
||||
@ -380,7 +401,7 @@ class Bottleneck(nn.Module):
|
||||
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.act2 = act_layer(inplace=True)
|
||||
self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
|
||||
self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa)
|
||||
|
||||
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
|
||||
self.bn3 = norm_layer(outplanes)
|
||||
@ -617,16 +638,19 @@ class ResNet(nn.Module):
|
||||
self.act1 = act_layer(inplace=True)
|
||||
self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
|
||||
|
||||
# Stem Pooling
|
||||
# Stem pooling. The name 'maxpool' remains for weight compatibility.
|
||||
if replace_stem_pool:
|
||||
self.maxpool = nn.Sequential(*filter(None, [
|
||||
nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
|
||||
aa_layer(channels=inplanes, stride=2) if aa_layer else None,
|
||||
create_aa(aa_layer, channels=inplanes, stride=2),
|
||||
norm_layer(inplanes),
|
||||
act_layer(inplace=True)
|
||||
]))
|
||||
else:
|
||||
if aa_layer is not None:
|
||||
if issubclass(aa_layer, nn.AvgPool2d):
|
||||
self.maxpool = aa_layer(2)
|
||||
else:
|
||||
self.maxpool = nn.Sequential(*[
|
||||
nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
|
||||
aa_layer(channels=inplanes, stride=2)])
|
||||
@ -1342,6 +1366,56 @@ def resnetblur50(pretrained=False, **kwargs):
|
||||
return _create_resnet('resnetblur50', pretrained, **model_args)
|
||||
|
||||
|
||||
@register_model
def resnetblur50d(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model with blur anti-aliasing.

    Extra keyword arguments are forwarded to ``_create_resnet`` together with
    the fixed settings below; a keyword that repeats one of those settings
    raises ``TypeError``.
    """
    preset = {
        'block': Bottleneck,
        'layers': [3, 4, 6, 3],
        'aa_layer': BlurPool2d,
        'stem_width': 32,
        'stem_type': 'deep',
        'avg_down': True,
    }
    return _create_resnet('resnetblur50d', pretrained, **preset, **kwargs)
|
||||
|
||||
|
||||
@register_model
def resnetblur101d(pretrained=False, **kwargs):
    """Constructs a ResNet-101-D model with blur anti-aliasing.

    Extra keyword arguments are forwarded to ``_create_resnet`` together with
    the fixed settings below; a keyword that repeats one of those settings
    raises ``TypeError``.
    """
    preset = {
        'block': Bottleneck,
        'layers': [3, 4, 23, 3],
        'aa_layer': BlurPool2d,
        'stem_width': 32,
        'stem_type': 'deep',
        'avg_down': True,
    }
    return _create_resnet('resnetblur101d', pretrained, **preset, **kwargs)
|
||||
|
||||
|
||||
@register_model
def resnetaa50d(pretrained=False, **kwargs):
    """Constructs a ResNet-50-D model with avgpool anti-aliasing.

    Extra keyword arguments are forwarded to ``_create_resnet`` together with
    the fixed settings below; a keyword that repeats one of those settings
    raises ``TypeError``.
    """
    preset = {
        'block': Bottleneck,
        'layers': [3, 4, 6, 3],
        'aa_layer': nn.AvgPool2d,
        'stem_width': 32,
        'stem_type': 'deep',
        'avg_down': True,
    }
    return _create_resnet('resnetaa50d', pretrained, **preset, **kwargs)
|
||||
|
||||
|
||||
@register_model
def resnetaa101d(pretrained=False, **kwargs):
    """Constructs a ResNet-101-D model with avgpool anti-aliasing.

    Extra keyword arguments are forwarded to ``_create_resnet`` together with
    the fixed settings below; a keyword that repeats one of those settings
    raises ``TypeError``.
    """
    preset = {
        'block': Bottleneck,
        'layers': [3, 4, 23, 3],
        'aa_layer': nn.AvgPool2d,
        'stem_width': 32,
        'stem_type': 'deep',
        'avg_down': True,
    }
    return _create_resnet('resnetaa101d', pretrained, **preset, **kwargs)
|
||||
|
||||
|
||||
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
    """Constructs a SE-ResNet-50-D model with avgpool anti-aliasing.

    Extra keyword arguments are forwarded to ``_create_resnet`` together with
    the fixed settings below; a keyword that repeats one of those settings
    raises ``TypeError``.
    """
    preset = {
        'block': Bottleneck,
        'layers': [3, 4, 6, 3],
        'aa_layer': nn.AvgPool2d,
        'stem_width': 32,
        'stem_type': 'deep',
        'avg_down': True,
        'block_args': {'attn_layer': 'se'},
    }
    return _create_resnet('seresnetaa50d', pretrained, **preset, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def seresnet18(pretrained=False, **kwargs):
|
||||
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user