DenseNet, DPN, VoVNet, Aligned Xception weights on HF hub. DenseNet grad_checkpointing using timm API

pull/1789/head
Ross Wightman 2023-04-21 16:56:44 -07:00
parent 864bfd43d0
commit 7ad7ddb7ad
4 changed files with 65 additions and 69 deletions

View File

@ -28,7 +28,7 @@ class DenseLayer(nn.Module):
bn_size,
norm_layer=BatchNormAct2d,
drop_rate=0.,
memory_efficient=False,
grad_checkpointing=False,
):
super(DenseLayer, self).__init__()
self.add_module('norm1', norm_layer(num_input_features)),
@ -38,7 +38,7 @@ class DenseLayer(nn.Module):
self.add_module('conv2', nn.Conv2d(
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient
self.grad_checkpointing = grad_checkpointing
def bottleneck_fn(self, xs):
# type: (List[torch.Tensor]) -> torch.Tensor
@ -80,7 +80,7 @@ class DenseLayer(nn.Module):
else:
prev_features = x
if self.memory_efficient and self.any_requires_grad(prev_features):
if self.grad_checkpointing and self.any_requires_grad(prev_features):
if torch.jit.is_scripting():
raise Exception("Memory Efficient not supported in JIT")
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
@ -104,7 +104,7 @@ class DenseBlock(nn.ModuleDict):
growth_rate,
norm_layer=BatchNormAct2d,
drop_rate=0.,
memory_efficient=False,
grad_checkpointing=False,
):
super(DenseBlock, self).__init__()
for i in range(num_layers):
@ -114,7 +114,7 @@ class DenseBlock(nn.ModuleDict):
bn_size=bn_size,
norm_layer=norm_layer,
drop_rate=drop_rate,
memory_efficient=memory_efficient,
grad_checkpointing=grad_checkpointing,
)
self.add_module('denselayer%d' % (i + 1), layer)
@ -153,7 +153,8 @@ class DenseNet(nn.Module):
block_config (list of 4 ints) - how many layers in each pooling block
bn_size (int) - multiplicative factor for number of bottle neck layers
(i.e. bn_size * k features in the bottleneck layer)
drop_rate (float) - dropout rate after each dense layer
drop_rate (float) - dropout rate before classifier layer
proj_drop_rate (float) - dropout rate after each dense layer
num_classes (int) - number of classification classes
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
@ -171,12 +172,12 @@ class DenseNet(nn.Module):
act_layer='relu',
norm_layer='batchnorm2d',
aa_layer=None,
drop_rate=0,
drop_rate=0.,
proj_drop_rate=0.,
memory_efficient=False,
aa_stem_only=True,
):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
@ -222,8 +223,8 @@ class DenseNet(nn.Module):
bn_size=bn_size,
growth_rate=growth_rate,
norm_layer=norm_layer,
drop_rate=drop_rate,
memory_efficient=memory_efficient
drop_rate=proj_drop_rate,
grad_checkpointing=memory_efficient,
)
module_name = f'denseblock{(i + 1)}'
self.features.add_module(module_name, block)
@ -249,8 +250,14 @@ class DenseNet(nn.Module):
self.num_features = num_features
# Linear layer
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
global_pool, classifier = create_classifier(
self.num_features,
self.num_classes,
pool_type=global_pool,
)
self.global_pool = global_pool
self.head_drop = nn.Dropout(drop_rate)
self.classifier = classifier
# Official init from torch repo.
for m in self.modules():
@ -273,6 +280,12 @@ class DenseNet(nn.Module):
)
return matcher
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for b in self.features.modules():
if isinstance(b, DenseLayer):
b.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.classifier
@ -288,9 +301,7 @@ class DenseNet(nn.Module):
def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x)
# both classifier and block drop?
# if self.drop_rate > 0.:
# x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.head_drop(x)
x = self.classifier(x)
return x
@ -315,32 +326,34 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
DenseNet,
variant,
pretrained,
feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
feature_cfg=dict(flatten_sequential=True),
pretrained_filter_fn=_filter_torchvision_pretrained,
**kwargs,
)
def _cfg(url=''):
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'features.conv0', 'classifier': 'classifier',
'first_conv': 'features.conv0', 'classifier': 'classifier', **kwargs,
}
default_cfgs = {
default_cfgs = generate_default_cfgs({
'densenet121.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
'densenet121d': _cfg(url=''),
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'densenetblur121d.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
'densenet169.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
'densenet201.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
'densenet161.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
'densenet264.untrained': _cfg(url=''),
'densenet121.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
}
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'densenet264d.untrained': _cfg(),
'densenet121.tv_in1k': _cfg(hf_hub_id='timm/'),
'densenet169.tv_in1k': _cfg(hf_hub_id='timm/'),
'densenet201.tv_in1k': _cfg(hf_hub_id='timm/'),
'densenet161.tv_in1k': _cfg(hf_hub_id='timm/'),
})
@register_model
@ -355,7 +368,7 @@ def densenet121(pretrained=False, **kwargs):
@register_model
def densenetblur121d(pretrained=False, **kwargs):
r"""Densenet-121 model from
r"""Densenet-121 w/ blur-pooling & 3-layer 3x3 stem
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _create_densenet(
@ -364,17 +377,6 @@ def densenetblur121d(pretrained=False, **kwargs):
return model
@register_model
def densenet121d(pretrained=False, **kwargs):
r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _create_densenet(
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
pretrained=pretrained, **kwargs)
return model
@register_model
def densenet169(pretrained=False, **kwargs):
r"""Densenet-169 model from
@ -406,11 +408,11 @@ def densenet161(pretrained=False, **kwargs):
@register_model
def densenet264(pretrained=False, **kwargs):
def densenet264d(pretrained=False, **kwargs):
r"""Densenet-264 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _create_densenet(
'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
'densenet264d', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', pretrained=pretrained, **kwargs)
return model

View File

@ -302,21 +302,16 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
'dpn48b.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'dpn68.mx_in1k': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
'dpn68.mx_in1k': _cfg(hf_hub_id='timm/'),
'dpn68b.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'dpn68b.mx_in1k': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
'dpn92.mx_in1k': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
'dpn98.mx_in1k': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
'dpn131.mx_in1k': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
'dpn107.mx_in1k': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'dpn68b.mx_in1k': _cfg(hf_hub_id='timm/'),
'dpn92.mx_in1k': _cfg(hf_hub_id='timm/'),
'dpn98.mx_in1k': _cfg(hf_hub_id='timm/'),
'dpn131.mx_in1k': _cfg(hf_hub_id='timm/'),
'dpn107.mx_in1k': _cfg(hf_hub_id='timm/')
})

View File

@ -389,12 +389,12 @@ def _create_vovnet(variant, pretrained=False, **kwargs):
)
def _cfg(url=''):
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
'first_conv': 'stem.0.conv', 'classifier': 'head.fc', **kwargs,
}
@ -403,10 +403,12 @@ default_cfgs = generate_default_cfgs({
'vovnet57a.untrained': _cfg(url=''),
'ese_vovnet19b_slim_dw.untrained': _cfg(url=''),
'ese_vovnet19b_dw.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'ese_vovnet19b_slim.untrained': _cfg(url=''),
'ese_vovnet39b.ra_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'ese_vovnet57b.untrained': _cfg(url=''),
'ese_vovnet99b.untrained': _cfg(url=''),
'eca_vovnet39b.untrained': _cfg(url=''),

View File

@ -300,23 +300,20 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
'xception65.ra3_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
hf_hub_id='timm/',
crop_pct=0.94,
),
'xception41.tf_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
'xception65.tf_in1k': _cfg(
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
'xception71.tf_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
'xception41.tf_in1k': _cfg(hf_hub_id='timm/'),
'xception65.tf_in1k': _cfg(hf_hub_id='timm/'),
'xception71.tf_in1k': _cfg(hf_hub_id='timm/'),
'xception41p.ra3_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
hf_hub_id='timm/',
crop_pct=0.94,
),
'xception65p.ra3_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
hf_hub_id='timm/',
crop_pct=0.94,
),
})