204 lines
7.1 KiB
Python
204 lines
7.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import pytest
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
from mmcls.models.backbones import TIMMBackbone
|
|
|
|
|
|
def check_norm_state(modules, train_state):
    """Tell whether all batch-norm layers are in the expected train state.

    Args:
        modules: Iterable of modules to inspect. Entries that are not
            ``_BatchNorm`` instances are ignored.
        train_state: Value each norm layer's ``training`` flag must have.

    Returns:
        bool: ``False`` as soon as one ``_BatchNorm`` module has a different
        ``training`` flag; ``True`` otherwise (also when no norm layer is
        present at all).
    """
    # all() stops at the first mismatch, like an early ``return False``.
    return all(
        layer.training == train_state
        for layer in modules
        if isinstance(layer, _BatchNorm))
|
|
|
|
|
|
def test_timm_backbone():
    """Test timm backbones, features_only=False (default)."""

    def forward_and_check(backbone, expected_shape):
        # Feed one random 224x224 RGB image through the backbone and verify
        # that exactly one feature with the expected shape comes back.
        outputs = backbone(torch.randn(1, 3, 224, 224))
        assert len(outputs) == 1
        assert outputs[0].shape == torch.Size(expected_shape)

    # TIMMBackbone has 1 required positional argument: 'model_name'
    with pytest.raises(TypeError):
        TIMMBackbone(pretrained=True)

    # pretrained must be bool
    with pytest.raises(TypeError):
        TIMMBackbone(model_name='resnet18', pretrained='model.pth')

    # Test resnet18 from timm
    net = TIMMBackbone(model_name='resnet18')
    net.init_weights()
    net.train()
    assert check_norm_state(net.modules(), True)
    # Both the pooling layer and the classifier are expected to be identity.
    assert isinstance(net.timm_model.global_pool.pool, nn.Identity)
    assert isinstance(net.timm_model.fc, nn.Identity)
    forward_and_check(net, (1, 512, 7, 7))

    # Test efficientnet_b1 with pretrained weights
    net = TIMMBackbone(model_name='efficientnet_b1', pretrained=True)
    net.init_weights()
    net.train()
    assert isinstance(net.timm_model.global_pool.pool, nn.Identity)
    assert isinstance(net.timm_model.classifier, nn.Identity)
    forward_and_check(net, (1, 1280, 7, 7))

    # Test vit_tiny_patch16_224 with pretrained weights
    net = TIMMBackbone(model_name='vit_tiny_patch16_224', pretrained=True)
    net.init_weights()
    net.train()
    assert isinstance(net.timm_model.head, nn.Identity)
    forward_and_check(net, (1, 192))
|
|
|
|
|
|
def test_timm_backbone_features_only():
    """Test timm backbones, features_only=True."""

    def assert_shapes(outputs, expected):
        # One output per expected entry, each with exactly that shape.
        assert len(outputs) == len(expected)
        for output, shape in zip(outputs, expected):
            assert output.shape == torch.Size(shape)

    # Keyword arguments shared by every resnet18 case below.
    resnet18_cfg = dict(
        model_name='resnet18', features_only=True, pretrained=False)

    # Test different norm_layer, can be: 'SyncBN', 'BN2d', 'GN', 'LN', 'IN'
    # Only construction is exercised for these two variants.
    # Test resnet18 from timm, norm_layer='BN2d'
    net = TIMMBackbone(output_stride=32, norm_layer='BN2d', **resnet18_cfg)

    # Test resnet18 from timm, norm_layer='SyncBN'
    net = TIMMBackbone(output_stride=32, norm_layer='SyncBN', **resnet18_cfg)

    # Test resnet18 from timm, output_stride=32
    net = TIMMBackbone(output_stride=32, **resnet18_cfg)
    net.init_weights()
    net.train()
    assert check_norm_state(net.modules(), True)
    assert_shapes(
        net(torch.randn(1, 3, 224, 224)),
        [
            (1, 64, 112, 112),
            (1, 64, 56, 56),
            (1, 128, 28, 28),
            (1, 256, 14, 14),
            (1, 512, 7, 7),
        ])

    # Test resnet18 from timm, output_stride=32, out_indices=(1, 2, 3)
    net = TIMMBackbone(
        output_stride=32, out_indices=(1, 2, 3), **resnet18_cfg)
    assert_shapes(
        net(torch.randn(1, 3, 224, 224)),
        [
            (1, 64, 56, 56),
            (1, 128, 28, 28),
            (1, 256, 14, 14),
        ])

    # Test resnet18 from timm, output_stride=16
    net = TIMMBackbone(output_stride=16, **resnet18_cfg)
    assert_shapes(
        net(torch.randn(1, 3, 224, 224)),
        [
            (1, 64, 112, 112),
            (1, 64, 56, 56),
            (1, 128, 28, 28),
            (1, 256, 14, 14),
            (1, 512, 14, 14),
        ])

    # Test resnet18 from timm, output_stride=8
    net = TIMMBackbone(output_stride=8, **resnet18_cfg)
    assert_shapes(
        net(torch.randn(1, 3, 224, 224)),
        [
            (1, 64, 112, 112),
            (1, 64, 56, 56),
            (1, 128, 28, 28),
            (1, 256, 28, 28),
            (1, 512, 28, 28),
        ])

    # Test efficientnet_b1 with pretrained weights
    net = TIMMBackbone(
        model_name='efficientnet_b1', features_only=True, pretrained=True)
    assert_shapes(
        net(torch.randn(1, 3, 64, 64)),
        [
            (1, 16, 32, 32),
            (1, 24, 16, 16),
            (1, 40, 8, 8),
            (1, 112, 4, 4),
            (1, 320, 2, 2),
        ])

    # Test resnetv2_50x1_bitm, resnetv2_50x3_bitm and resnetv2_101x1_bitm
    # from timm, output_stride=8. The expected spatial sizes for the 8x8
    # input are the same for all three; only the channel counts differ.
    bitm_sides = (4, 2, 1, 1, 1)
    bitm_channels = (
        ('resnetv2_50x1_bitm', (64, 256, 512, 1024, 2048)),
        ('resnetv2_50x3_bitm', (192, 768, 1536, 3072, 6144)),
        ('resnetv2_101x1_bitm', (64, 256, 512, 1024, 2048)),
    )
    for name, channels in bitm_channels:
        net = TIMMBackbone(
            model_name=name,
            features_only=True,
            pretrained=False,
            output_stride=8)
        assert_shapes(
            net(torch.randn(1, 3, 8, 8)),
            [(1, chan, side, side)
             for chan, side in zip(channels, bitm_sides)])
|