[Feature] Support TIMMBackbone (#998)

* add TIMMBackbone and unittests

* add timm to tests requirements

* deprecate pt1.3.1

* reduce the unittests input of timm backbone

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* fix ci

* remove unittests of large models of timm backbone

* generate coverage report for all unittests env

* reduce the unittests input of timm backbone

* reduce the unittests input of timm backbone
This commit is contained in:
Junjun2016 2021-11-02 12:51:11 +08:00 committed by GitHub
parent ddce375977
commit 54435fb149
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 215 additions and 1 deletions

View File

@ -71,9 +71,17 @@ jobs:
run: rm -rf .eggs && pip install -e . run: rm -rf .eggs && pip install -e .
- name: Run unittests and generate coverage report - name: Run unittests and generate coverage report
run: | run: |
pip install timm
coverage run --branch --source mmseg -m pytest tests/ coverage run --branch --source mmseg -m pytest tests/
coverage xml coverage xml
coverage report -m coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
build_cuda101: build_cuda101:
runs-on: ubuntu-18.04 runs-on: ubuntu-18.04
@ -142,9 +150,17 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install . TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report - name: Run unittests and generate coverage report
run: | run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/ coverage run --branch --source mmseg -m pytest tests/
coverage xml coverage xml
coverage report -m coverage report -m
if: ${{matrix.torch >= '1.5.0'}}
- name: Skip timm unittests and generate coverage report
run: |
coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
coverage xml
coverage report -m
if: ${{matrix.torch < '1.5.0'}}
- name: Upload coverage to Codecov - name: Upload coverage to Codecov
uses: codecov/codecov-action@v1.0.10 uses: codecov/codecov-action@v1.0.10
with: with:
@ -198,6 +214,7 @@ jobs:
TORCH_CUDA_ARCH_LIST=7.0 pip install . TORCH_CUDA_ARCH_LIST=7.0 pip install .
- name: Run unittests and generate coverage report - name: Run unittests and generate coverage report
run: | run: |
python -m pip install timm
coverage run --branch --source mmseg -m pytest tests/ coverage run --branch --source mmseg -m pytest tests/
coverage xml coverage xml
coverage report -m coverage report -m

View File

@ -12,6 +12,7 @@ from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt from .resnext import ResNeXt
from .swin import SwinTransformer from .swin import SwinTransformer
from .timm_backbone import TIMMBackbone
from .unet import UNet from .unet import UNet
from .vit import VisionTransformer from .vit import VisionTransformer
@ -19,5 +20,5 @@ __all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet' 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone'
] ]

View File

@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
import timm
except ImportError:
timm = None
from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule
from ..builder import BACKBONES
@BACKBONES.register_module()
class TIMMBackbone(BaseModule):
"""Wrapper to use backbones from timm library. More details can be found in
`timm <https://github.com/rwightman/pytorch-image-models>`_ .
Args:
model_name (str): Name of timm model to instantiate.
pretrained (bool): Load pretrained weights if True.
checkpoint_path (str): Path of checkpoint to load after
model is initialized.
in_channels (int): Number of input image channels. Default: 3.
init_cfg (dict, optional): Initialization config dict
**kwargs: Other timm & model specific arguments.
"""
def __init__(
self,
model_name,
features_only=True,
pretrained=True,
checkpoint_path='',
in_channels=3,
init_cfg=None,
**kwargs,
):
if timm is None:
raise RuntimeError('timm is not installed')
super(TIMMBackbone, self).__init__(init_cfg)
if 'norm_layer' in kwargs:
kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
self.timm_model = timm.create_model(
model_name=model_name,
features_only=features_only,
pretrained=pretrained,
in_chans=in_channels,
checkpoint_path=checkpoint_path,
**kwargs,
)
# Make unused parameters None
self.timm_model.global_pool = None
self.timm_model.fc = None
self.timm_model.classifier = None
# Hack to use pretrained weights from timm
if pretrained or checkpoint_path:
self._is_init = True
def forward(self, x):
features = self.timm_model(x)
return features

View File

@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
from mmseg.models.backbones import TIMMBackbone
from .utils import check_norm_state
def test_timm_backbone():
with pytest.raises(TypeError):
# pretrained must be a string path
model = TIMMBackbone()
model.init_weights(pretrained=0)
# Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
# Test resnet18 from timm, norm_layer='BN2d'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='BN2d')
# Test resnet18 from timm, norm_layer='SyncBN'
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32,
norm_layer='SyncBN')
# Test resnet18 from timm, features_only=True, output_stride=32
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=32)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 7, 7))
# Test resnet18 from timm, features_only=True, output_stride=16
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=16)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 14, 14))
assert feats[4] == torch.Size((1, 512, 14, 14))
# Test resnet18 from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnet18',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 112, 112))
assert feats[1] == torch.Size((1, 64, 56, 56))
assert feats[2] == torch.Size((1, 128, 28, 28))
assert feats[3] == torch.Size((1, 256, 28, 28))
assert feats[4] == torch.Size((1, 512, 28, 28))
# Test efficientnet_b1 with pretrained weights
model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)
# Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))
# Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_50x3_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 192, 4, 4))
assert feats[1] == torch.Size((1, 768, 2, 2))
assert feats[2] == torch.Size((1, 1536, 1, 1))
assert feats[3] == torch.Size((1, 3072, 1, 1))
assert feats[4] == torch.Size((1, 6144, 1, 1))
# Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8
model = TIMMBackbone(
model_name='resnetv2_101x1_bitm',
features_only=True,
pretrained=False,
output_stride=8)
imgs = torch.randn(1, 3, 8, 8)
feats = model(imgs)
feats = [feat.shape for feat in feats]
assert len(feats) == 5
assert feats[0] == torch.Size((1, 64, 4, 4))
assert feats[1] == torch.Size((1, 256, 2, 2))
assert feats[2] == torch.Size((1, 512, 1, 1))
assert feats[3] == torch.Size((1, 1024, 1, 1))
assert feats[4] == torch.Size((1, 2048, 1, 1))