mirror of https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Support TIMMBackbone (#998)
* add TIMMBackbone and unittests
* add timm to tests requirements
* deprecate pt1.3.1
* reduce the unittests input of timm backbone
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* fix ci
* remove unittests of large models of timm backbone
* generate coverage report for all unittests env
* reduce the unittests input of timm backbone
* reduce the unittests input of timm backbone
This commit is contained in:
parent ddce375977
commit 54435fb149
17 .github/workflows/build.yml (vendored)
@@ -71,9 +71,17 @@ jobs:
         run: rm -rf .eggs && pip install -e .
       - name: Run unittests and generate coverage report
         run: |
+          pip install timm
           coverage run --branch --source mmseg -m pytest tests/
           coverage xml
           coverage report -m
+        if: ${{matrix.torch >= '1.5.0'}}
+      - name: Skip timm unittests and generate coverage report
+        run: |
+          coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
+          coverage xml
+          coverage report -m
+        if: ${{matrix.torch < '1.5.0'}}

   build_cuda101:
     runs-on: ubuntu-18.04
@@ -142,9 +150,17 @@ jobs:
           TORCH_CUDA_ARCH_LIST=7.0 pip install .
       - name: Run unittests and generate coverage report
         run: |
+          python -m pip install timm
           coverage run --branch --source mmseg -m pytest tests/
           coverage xml
           coverage report -m
+        if: ${{matrix.torch >= '1.5.0'}}
+      - name: Skip timm unittests and generate coverage report
+        run: |
+          coverage run --branch --source mmseg -m pytest tests/ --ignore tests/test_models/test_backbones/test_timm_backbone.py
+          coverage xml
+          coverage report -m
+        if: ${{matrix.torch < '1.5.0'}}
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1.0.10
         with:
@@ -198,6 +214,7 @@ jobs:
           TORCH_CUDA_ARCH_LIST=7.0 pip install .
       - name: Run unittests and generate coverage report
         run: |
+          python -m pip install timm
           coverage run --branch --source mmseg -m pytest tests/
           coverage xml
           coverage report -m
mmseg/models/backbones/__init__.py
@@ -12,6 +12,7 @@ from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1c, ResNetV1d
 from .resnext import ResNeXt
 from .swin import SwinTransformer
+from .timm_backbone import TIMMBackbone
 from .unet import UNet
 from .vit import VisionTransformer

@@ -19,5 +20,5 @@ __all__ = [
     'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
-    'BiSeNetV1', 'BiSeNetV2', 'ICNet'
+    'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone'
 ]
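With the import added above and the `@BACKBONES.register_module()` decorator in the new module below, the wrapper can be built through mmseg's registry like any other backbone. A minimal sketch (not part of this commit), assuming timm is installed and using resnet18 only as an example model name:

    from mmseg.models import build_backbone

    # Build the timm wrapper through the BACKBONES registry.
    backbone = build_backbone(
        dict(
            type='TIMMBackbone',
            model_name='resnet18',  # any valid timm model name
            features_only=True,
            pretrained=False))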
mmseg/models/backbones/timm_backbone.py (new file, 63 lines)
@@ -0,0 +1,63 @@
# Copyright (c) OpenMMLab. All rights reserved.
try:
    import timm
except ImportError:
    timm = None

from mmcv.cnn.bricks.registry import NORM_LAYERS
from mmcv.runner import BaseModule

from ..builder import BACKBONES


@BACKBONES.register_module()
class TIMMBackbone(BaseModule):
    """Wrapper to use backbones from the timm library. More details can be
    found in `timm <https://github.com/rwightman/pytorch-image-models>`_.

    Args:
        model_name (str): Name of timm model to instantiate.
        features_only (bool): Whether to return the multi-scale feature maps
            of the stages instead of classification logits. Default: True.
        pretrained (bool): Load pretrained weights if True.
        checkpoint_path (str): Path of checkpoint to load after
            model is initialized.
        in_channels (int): Number of input image channels. Default: 3.
        init_cfg (dict, optional): Initialization config dict.
        **kwargs: Other timm & model specific arguments.
    """

    def __init__(
        self,
        model_name,
        features_only=True,
        pretrained=True,
        checkpoint_path='',
        in_channels=3,
        init_cfg=None,
        **kwargs,
    ):
        if timm is None:
            raise RuntimeError('timm is not installed')
        super(TIMMBackbone, self).__init__(init_cfg)
        if 'norm_layer' in kwargs:
            kwargs['norm_layer'] = NORM_LAYERS.get(kwargs['norm_layer'])
        self.timm_model = timm.create_model(
            model_name=model_name,
            features_only=features_only,
            pretrained=pretrained,
            in_chans=in_channels,
            checkpoint_path=checkpoint_path,
            **kwargs,
        )

        # Make unused parameters None
        self.timm_model.global_pool = None
        self.timm_model.fc = None
        self.timm_model.classifier = None

        # Hack to use pretrained weights from timm
        if pretrained or checkpoint_path:
            self._is_init = True

    def forward(self, x):
        features = self.timm_model(x)
        return features
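For orientation (not part of this commit), a hedged sketch of how the wrapper might be dropped into a segmentation config: the backbone entry becomes a TIMMBackbone dict, and the decode head's in_channels/in_index must match the feature maps the chosen timm model returns. EncoderDecoder and FCNHead are standard mmseg components; the channel counts below are those of resnet18, as asserted in the unit tests that follow, and the rest is illustrative.

    # Hypothetical config fragment: resnet18 from timm as the backbone.
    model = dict(
        type='EncoderDecoder',
        backbone=dict(
            type='TIMMBackbone',
            model_name='resnet18',
            features_only=True,
            pretrained=True,
            output_stride=32),
        decode_head=dict(
            type='FCNHead',
            in_channels=512,  # deepest resnet18 feature map (see tests below)
            in_index=4,       # last of the 5 returned feature maps
            channels=256,
            num_classes=19))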
tests/test_models/test_backbones/test_timm_backbone.py (new file, 133 lines)
@@ -0,0 +1,133 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import TIMMBackbone
from .utils import check_norm_state


def test_timm_backbone():
    with pytest.raises(TypeError):
        # model_name is a required argument
        model = TIMMBackbone()
        model.init_weights(pretrained=0)

    # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
    # Test resnet18 from timm, norm_layer='BN2d'
    model = TIMMBackbone(
        model_name='resnet18',
        features_only=True,
        pretrained=False,
        output_stride=32,
        norm_layer='BN2d')

    # Test resnet18 from timm, norm_layer='SyncBN'
    model = TIMMBackbone(
        model_name='resnet18',
        features_only=True,
        pretrained=False,
        output_stride=32,
        norm_layer='SyncBN')

    # Test resnet18 from timm, features_only=True, output_stride=32
    model = TIMMBackbone(
        model_name='resnet18',
        features_only=True,
        pretrained=False,
        output_stride=32)
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), True)

    imgs = torch.randn(1, 3, 224, 224)
    feats = model(imgs)
    feats = [feat.shape for feat in feats]
    assert len(feats) == 5
    assert feats[0] == torch.Size((1, 64, 112, 112))
    assert feats[1] == torch.Size((1, 64, 56, 56))
    assert feats[2] == torch.Size((1, 128, 28, 28))
    assert feats[3] == torch.Size((1, 256, 14, 14))
    assert feats[4] == torch.Size((1, 512, 7, 7))

    # Test resnet18 from timm, features_only=True, output_stride=16
    model = TIMMBackbone(
        model_name='resnet18',
        features_only=True,
        pretrained=False,
        output_stride=16)
    imgs = torch.randn(1, 3, 224, 224)
    feats = model(imgs)
    feats = [feat.shape for feat in feats]
    assert len(feats) == 5
    assert feats[0] == torch.Size((1, 64, 112, 112))
    assert feats[1] == torch.Size((1, 64, 56, 56))
    assert feats[2] == torch.Size((1, 128, 28, 28))
    assert feats[3] == torch.Size((1, 256, 14, 14))
    assert feats[4] == torch.Size((1, 512, 14, 14))

    # Test resnet18 from timm, features_only=True, output_stride=8
    model = TIMMBackbone(
        model_name='resnet18',
        features_only=True,
        pretrained=False,
        output_stride=8)
    imgs = torch.randn(1, 3, 224, 224)
    feats = model(imgs)
    feats = [feat.shape for feat in feats]
    assert len(feats) == 5
    assert feats[0] == torch.Size((1, 64, 112, 112))
    assert feats[1] == torch.Size((1, 64, 56, 56))
    assert feats[2] == torch.Size((1, 128, 28, 28))
    assert feats[3] == torch.Size((1, 256, 28, 28))
    assert feats[4] == torch.Size((1, 512, 28, 28))

    # Test efficientnet_b1 with pretrained weights
    model = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)

    # Test resnetv2_50x1_bitm from timm, features_only=True, output_stride=8
    model = TIMMBackbone(
        model_name='resnetv2_50x1_bitm',
        features_only=True,
        pretrained=False,
        output_stride=8)
    imgs = torch.randn(1, 3, 8, 8)
    feats = model(imgs)
    feats = [feat.shape for feat in feats]
    assert len(feats) == 5
    assert feats[0] == torch.Size((1, 64, 4, 4))
    assert feats[1] == torch.Size((1, 256, 2, 2))
    assert feats[2] == torch.Size((1, 512, 1, 1))
    assert feats[3] == torch.Size((1, 1024, 1, 1))
    assert feats[4] == torch.Size((1, 2048, 1, 1))

    # Test resnetv2_50x3_bitm from timm, features_only=True, output_stride=8
    model = TIMMBackbone(
        model_name='resnetv2_50x3_bitm',
        features_only=True,
        pretrained=False,
        output_stride=8)
    imgs = torch.randn(1, 3, 8, 8)
    feats = model(imgs)
    feats = [feat.shape for feat in feats]
    assert len(feats) == 5
    assert feats[0] == torch.Size((1, 192, 4, 4))
    assert feats[1] == torch.Size((1, 768, 2, 2))
    assert feats[2] == torch.Size((1, 1536, 1, 1))
    assert feats[3] == torch.Size((1, 3072, 1, 1))
    assert feats[4] == torch.Size((1, 6144, 1, 1))

    # Test resnetv2_101x1_bitm from timm, features_only=True, output_stride=8
    model = TIMMBackbone(
        model_name='resnetv2_101x1_bitm',
        features_only=True,
        pretrained=False,
        output_stride=8)
    imgs = torch.randn(1, 3, 8, 8)
    feats = model(imgs)
    feats = [feat.shape for feat in feats]
    assert len(feats) == 5
    assert feats[0] == torch.Size((1, 64, 4, 4))
    assert feats[1] == torch.Size((1, 256, 2, 2))
    assert feats[2] == torch.Size((1, 512, 1, 1))
    assert feats[3] == torch.Size((1, 1024, 1, 1))
    assert feats[4] == torch.Size((1, 2048, 1, 1))
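The check_norm_state helper imported from the local .utils module is not part of this diff; presumably it walks the model's modules and checks that every batch-norm layer's training flag matches the expected state. A minimal sketch under that assumption:

    from torch.nn.modules.batchnorm import _BatchNorm

    def check_norm_state(modules, train_state):
        """Return True if every BatchNorm layer matches the expected train
        state (assumed behavior of the test utils helper)."""
        for mod in modules:
            if isinstance(mod, _BatchNorm) and mod.training != train_state:
                return False
        return True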