DenseNet, DPN, VoVNet, Aligned Xception weights on HF hub. DenseNet grad_checkpointing using timm API
parent
864bfd43d0
commit
7ad7ddb7ad
|
@ -28,7 +28,7 @@ class DenseLayer(nn.Module):
|
|||
bn_size,
|
||||
norm_layer=BatchNormAct2d,
|
||||
drop_rate=0.,
|
||||
memory_efficient=False,
|
||||
grad_checkpointing=False,
|
||||
):
|
||||
super(DenseLayer, self).__init__()
|
||||
self.add_module('norm1', norm_layer(num_input_features)),
|
||||
|
@ -38,7 +38,7 @@ class DenseLayer(nn.Module):
|
|||
self.add_module('conv2', nn.Conv2d(
|
||||
bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
|
||||
self.drop_rate = float(drop_rate)
|
||||
self.memory_efficient = memory_efficient
|
||||
self.grad_checkpointing = grad_checkpointing
|
||||
|
||||
def bottleneck_fn(self, xs):
|
||||
# type: (List[torch.Tensor]) -> torch.Tensor
|
||||
|
@ -80,7 +80,7 @@ class DenseLayer(nn.Module):
|
|||
else:
|
||||
prev_features = x
|
||||
|
||||
if self.memory_efficient and self.any_requires_grad(prev_features):
|
||||
if self.grad_checkpointing and self.any_requires_grad(prev_features):
|
||||
if torch.jit.is_scripting():
|
||||
raise Exception("Memory Efficient not supported in JIT")
|
||||
bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
|
||||
|
@ -104,7 +104,7 @@ class DenseBlock(nn.ModuleDict):
|
|||
growth_rate,
|
||||
norm_layer=BatchNormAct2d,
|
||||
drop_rate=0.,
|
||||
memory_efficient=False,
|
||||
grad_checkpointing=False,
|
||||
):
|
||||
super(DenseBlock, self).__init__()
|
||||
for i in range(num_layers):
|
||||
|
@ -114,7 +114,7 @@ class DenseBlock(nn.ModuleDict):
|
|||
bn_size=bn_size,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
memory_efficient=memory_efficient,
|
||||
grad_checkpointing=grad_checkpointing,
|
||||
)
|
||||
self.add_module('denselayer%d' % (i + 1), layer)
|
||||
|
||||
|
@ -153,7 +153,8 @@ class DenseNet(nn.Module):
|
|||
block_config (list of 4 ints) - how many layers in each pooling block
|
||||
bn_size (int) - multiplicative factor for number of bottle neck layers
|
||||
(i.e. bn_size * k features in the bottleneck layer)
|
||||
drop_rate (float) - dropout rate after each dense layer
|
||||
drop_rate (float) - dropout rate before classifier layer
|
||||
proj_drop_rate (float) - dropout rate after each dense layer
|
||||
num_classes (int) - number of classification classes
|
||||
memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient,
|
||||
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
|
||||
|
@ -171,12 +172,12 @@ class DenseNet(nn.Module):
|
|||
act_layer='relu',
|
||||
norm_layer='batchnorm2d',
|
||||
aa_layer=None,
|
||||
drop_rate=0,
|
||||
drop_rate=0.,
|
||||
proj_drop_rate=0.,
|
||||
memory_efficient=False,
|
||||
aa_stem_only=True,
|
||||
):
|
||||
self.num_classes = num_classes
|
||||
self.drop_rate = drop_rate
|
||||
super(DenseNet, self).__init__()
|
||||
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
|
||||
|
||||
|
@ -222,8 +223,8 @@ class DenseNet(nn.Module):
|
|||
bn_size=bn_size,
|
||||
growth_rate=growth_rate,
|
||||
norm_layer=norm_layer,
|
||||
drop_rate=drop_rate,
|
||||
memory_efficient=memory_efficient
|
||||
drop_rate=proj_drop_rate,
|
||||
grad_checkpointing=memory_efficient,
|
||||
)
|
||||
module_name = f'denseblock{(i + 1)}'
|
||||
self.features.add_module(module_name, block)
|
||||
|
@ -249,8 +250,14 @@ class DenseNet(nn.Module):
|
|||
self.num_features = num_features
|
||||
|
||||
# Linear layer
|
||||
self.global_pool, self.classifier = create_classifier(
|
||||
self.num_features, self.num_classes, pool_type=global_pool)
|
||||
global_pool, classifier = create_classifier(
|
||||
self.num_features,
|
||||
self.num_classes,
|
||||
pool_type=global_pool,
|
||||
)
|
||||
self.global_pool = global_pool
|
||||
self.head_drop = nn.Dropout(drop_rate)
|
||||
self.classifier = classifier
|
||||
|
||||
# Official init from torch repo.
|
||||
for m in self.modules():
|
||||
|
@ -273,6 +280,12 @@ class DenseNet(nn.Module):
|
|||
)
|
||||
return matcher
|
||||
|
||||
@torch.jit.ignore
|
||||
def set_grad_checkpointing(self, enable=True):
|
||||
for b in self.features.modules():
|
||||
if isinstance(b, DenseLayer):
|
||||
b.grad_checkpointing = enable
|
||||
|
||||
@torch.jit.ignore
|
||||
def get_classifier(self):
|
||||
return self.classifier
|
||||
|
@ -288,9 +301,7 @@ class DenseNet(nn.Module):
|
|||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
x = self.global_pool(x)
|
||||
# both classifier and block drop?
|
||||
# if self.drop_rate > 0.:
|
||||
# x = F.dropout(x, p=self.drop_rate, training=self.training)
|
||||
x = self.head_drop(x)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
|
@ -315,32 +326,34 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
|
|||
DenseNet,
|
||||
variant,
|
||||
pretrained,
|
||||
feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
|
||||
feature_cfg=dict(flatten_sequential=True),
|
||||
pretrained_filter_fn=_filter_torchvision_pretrained,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _cfg(url=''):
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'features.conv0', 'classifier': 'classifier',
|
||||
'first_conv': 'features.conv0', 'classifier': 'classifier', **kwargs,
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'densenet121.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenet121_ra-50efcf5c.pth'),
|
||||
'densenet121d': _cfg(url=''),
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||
'densenetblur121d.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/densenetblur121d_ra-100dcfbc.pth'),
|
||||
'densenet169.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet169-b2777c0a.pth'),
|
||||
'densenet201.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet201-c1103571.pth'),
|
||||
'densenet161.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet161-8d451a50.pth'),
|
||||
'densenet264.untrained': _cfg(url=''),
|
||||
'densenet121.tv_in1k': _cfg(url='https://download.pytorch.org/models/densenet121-a639ec97.pth'),
|
||||
}
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||
'densenet264d.untrained': _cfg(),
|
||||
'densenet121.tv_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'densenet169.tv_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'densenet201.tv_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'densenet161.tv_in1k': _cfg(hf_hub_id='timm/'),
|
||||
})
|
||||
|
||||
|
||||
@register_model
|
||||
|
@ -355,7 +368,7 @@ def densenet121(pretrained=False, **kwargs):
|
|||
|
||||
@register_model
|
||||
def densenetblur121d(pretrained=False, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
r"""Densenet-121 w/ blur-pooling & 3-layer 3x3 stem
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
model = _create_densenet(
|
||||
|
@ -364,17 +377,6 @@ def densenetblur121d(pretrained=False, **kwargs):
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def densenet121d(pretrained=False, **kwargs):
|
||||
r"""Densenet-121 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
model = _create_densenet(
|
||||
'densenet121d', growth_rate=32, block_config=(6, 12, 24, 16), stem_type='deep',
|
||||
pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def densenet169(pretrained=False, **kwargs):
|
||||
r"""Densenet-169 model from
|
||||
|
@ -406,11 +408,11 @@ def densenet161(pretrained=False, **kwargs):
|
|||
|
||||
|
||||
@register_model
|
||||
def densenet264(pretrained=False, **kwargs):
|
||||
def densenet264d(pretrained=False, **kwargs):
|
||||
r"""Densenet-264 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
model = _create_densenet(
|
||||
'densenet264', growth_rate=48, block_config=(6, 12, 64, 48), pretrained=pretrained, **kwargs)
|
||||
'densenet264d', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', pretrained=pretrained, **kwargs)
|
||||
return model
|
||||
|
||||
|
|
|
@ -302,21 +302,16 @@ def _cfg(url='', **kwargs):
|
|||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'dpn48b.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'dpn68.mx_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
|
||||
'dpn68.mx_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'dpn68b.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/dpn68b_ra-a31ca160.pth',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
|
||||
'dpn68b.mx_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
|
||||
'dpn92.mx_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
|
||||
'dpn98.mx_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
|
||||
'dpn131.mx_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
|
||||
'dpn107.mx_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
|
||||
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
|
||||
'dpn68b.mx_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'dpn92.mx_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'dpn98.mx_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'dpn131.mx_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'dpn107.mx_in1k': _cfg(hf_hub_id='timm/')
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -389,12 +389,12 @@ def _create_vovnet(variant, pretrained=False, **kwargs):
|
|||
)
|
||||
|
||||
|
||||
def _cfg(url=''):
|
||||
def _cfg(url='', **kwargs):
|
||||
return {
|
||||
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
||||
'first_conv': 'stem.0.conv', 'classifier': 'head.fc',
|
||||
'first_conv': 'stem.0.conv', 'classifier': 'head.fc', **kwargs,
|
||||
}
|
||||
|
||||
|
||||
|
@ -403,10 +403,12 @@ default_cfgs = generate_default_cfgs({
|
|||
'vovnet57a.untrained': _cfg(url=''),
|
||||
'ese_vovnet19b_slim_dw.untrained': _cfg(url=''),
|
||||
'ese_vovnet19b_dw.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet19b_dw-a8741004.pth'),
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||
'ese_vovnet19b_slim.untrained': _cfg(url=''),
|
||||
'ese_vovnet39b.ra_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ese_vovnet39b-f912fe73.pth'),
|
||||
hf_hub_id='timm/',
|
||||
test_input_size=(3, 288, 288), test_crop_pct=0.95),
|
||||
'ese_vovnet57b.untrained': _cfg(url=''),
|
||||
'ese_vovnet99b.untrained': _cfg(url=''),
|
||||
'eca_vovnet39b.untrained': _cfg(url=''),
|
||||
|
|
|
@ -300,23 +300,20 @@ def _cfg(url='', **kwargs):
|
|||
|
||||
default_cfgs = generate_default_cfgs({
|
||||
'xception65.ra3_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65_ra3-1447db8d.pth',
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.94,
|
||||
),
|
||||
|
||||
'xception41.tf_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
|
||||
'xception65.tf_in1k': _cfg(
|
||||
url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
|
||||
'xception71.tf_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
|
||||
'xception41.tf_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'xception65.tf_in1k': _cfg(hf_hub_id='timm/'),
|
||||
'xception71.tf_in1k': _cfg(hf_hub_id='timm/'),
|
||||
|
||||
'xception41p.ra3_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception41p_ra3-33195bc8.pth',
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.94,
|
||||
),
|
||||
'xception65p.ra3_in1k': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/xception65p_ra3-3c6114e4.pth',
|
||||
hf_hub_id='timm/',
|
||||
crop_pct=0.94,
|
||||
),
|
||||
})
|
||||
|
|
Loading…
Reference in New Issue