[Feature] Support TIMMBackbone (#998)

* add TIMMBackbone and unittests

* add timm to tests requirements

* deprecate pt1.3.1

* reduce the unittests input of timm backbone

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* remove unittests of large models of timm backbone

* generate coverage report for all unittests env

* reduce the unittests input of timm backbone

* reduce the unittests input of timm backbone
This commit is contained in:
Junjun2016 2021-11-02 12:51:11 +08:00 committed by GitHub
parent ddce375977
commit 54435fb149
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 215 additions and 1 deletions

View File

@ -71,9 +71,17 @@ jobs:
run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report
run: |
pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
build_cuda101:
runs-on: ubuntu-18.04
@ -142,9 +150,17 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.10
with:
@ -198,6 +214,7 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report
run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/
coverage xml
coverage report -m

View File

@ -12,6 +12,7 @@ from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .unet import UNet
from .vit import VisionTransformer
@ -19,5 +20,5 @@ __all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet'
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone'
]

View File

@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None
from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule
from ..builder import BACKBONES
@BACKBONES.register_module()
class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
Args:
model_name (str): Name of timm model to instantiate.
pretrained (bool): Load pretrained weights if True.
checkpoint_path (str): Path of checkpoint to load after
model is initialized.
in_channels (int): Number of input image channels. Default: 3.
init_cfg (dict, optional): Initialization config dict
**kwargs: Other timm & model specific arguments.
"""
def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs,
):
if timm is None:
raise RuntimeError('timm is not installed')
super(TIMMBackbone, self).__init__(init_cfg)
if 'norm_layer' in kwargs:
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)
# Make unused parameters None
self.timm_model.global_pool = None
self.timm_model.fc = None
self.timm_model.classifier = None
# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True
def forward(self, x):
features = self.timm_model(x)
return features

View File

@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import TIMMBackbone
from .utils import check_norm_state
def test_timm_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = TIMMBackbone()
model.init_weights(pretrained=0)
# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
# Test resnet18 from timm, norm_layer='BN2d'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='BN2d')
# Test resnet18 from timm, norm_layer='SyncBN'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='SyncBN')
# Test resnet18 from timm, features_only=True, output_stride=32
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 7, 7))
# Test resnet18 from timm, features_only=True, output_stride=16
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=16)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 14, 14))
# Test resnet18 from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 28, 28))
assert feats[4] == torch.Size((1, 512, 28, 28))
# Test efficientnet_b1 with pretrained weights
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)
# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))
# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x3_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 192, 4, 4))
assert feats[1] == torch.Size((1, 768, 2, 2))
assert feats[2] == torch.Size((1, 1536, 1, 1))
assert feats[3] == torch.Size((1, 3072, 1, 1))
assert feats[4] == torch.Size((1, 6144, 1, 1))
# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_101x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))