# Copyright (c) OpenMMLab. All rights reserved.
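"""Smoke tests for pretrained models registered in the mmpretrain model hub:
each configured model is built from its name, run forward and backward, and
queried through ``extract_feat``."""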
from dataclasses import dataclass

import pytest
import torch

import mmpretrain.models
from mmpretrain.apis import ModelHub, get_model


@dataclass
class Cfg:
    """Test configuration for one pretrained model.

    The boolean flags switch individual checks on and off for models where
    a check does not apply, e.g. checkpoints without a classification head.
    """
    name: str  # model name registered in the model hub
    backbone: type  # expected class of the built model's backbone
    num_classes: int = 1000  # expected width of the classification output
    build: bool = True
    forward: bool = True
    backward: bool = True
    extract_feat: bool = True
    input_shape: tuple = (1, 3, 224, 224)  # (N, C, H, W) of the random input


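# Models to exercise. The SAM- and DINOv2-pre-trained checkpoints ship
# without a classification head, so the head-dependent forward/backward
# checks are disabled for them.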
test_list = [
    Cfg(name='xcit-small-12-p16_3rdparty_in1k',
        backbone=mmpretrain.models.XCiT),
    Cfg(name='xcit-nano-12-p8_3rdparty-dist_in1k-384px',
        backbone=mmpretrain.models.XCiT,
        input_shape=(1, 3, 384, 384)),
    Cfg(name='vit-base-p16_sam-pre_3rdparty_sa1b-1024px',
        backbone=mmpretrain.models.ViTSAM,
        forward=False,
        backward=False),
    Cfg(name='vit-base-p14_dinov2-pre_3rdparty',
        backbone=mmpretrain.models.VisionTransformer,
        forward=False,
        backward=False),
    Cfg(name='hivit-tiny-p16_16xb64_in1k', backbone=mmpretrain.models.HiViT),
]


@pytest.mark.parametrize('cfg', test_list)
def test_build(cfg: Cfg):
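    """The model name resolves in the hub and the built model uses the
    expected backbone class."""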
    if not cfg.build:
        return

    ModelHub._register_mmpretrain_models()
    assert ModelHub.has(cfg.name)

    model = get_model(cfg.name)
    assert isinstance(model.backbone, cfg.backbone)


@pytest.mark.parametrize('cfg', test_list)
def test_forward(cfg: Cfg):
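    """A forward pass on a random input yields classification scores of
    shape (1, num_classes)."""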
    if not cfg.forward:
        return

    model = get_model(cfg.name)
    inputs = torch.rand(*cfg.input_shape)
    outputs = model(inputs)
    assert outputs.shape == (1, cfg.num_classes)


@pytest.mark.parametrize('cfg', test_list)
def test_extract_feat(cfg: Cfg):
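    """``extract_feat`` returns a tuple holding a single entry."""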
    if not cfg.extract_feat:
        return

    model = get_model(cfg.name)
    inputs = torch.rand(*cfg.input_shape)
    feats = model.extract_feat(inputs)
    assert isinstance(feats, tuple)
    assert len(feats) == 1


@pytest.mark.parametrize('cfg', test_list)
def test_backward(cfg: Cfg):
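    """Backpropagating the mean output reaches every parameter, and the
    outputs are finite."""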
    if not cfg.backward:
        return

    model = get_model(cfg.name)
    inputs = torch.rand(*cfg.input_shape)
    outputs = model(inputs)
    assert outputs.shape[-1] == cfg.num_classes
    assert not torch.isnan(outputs).any(), 'Output included NaNs'

    outputs.mean().backward()

    # Every parameter should receive a gradient with one element per weight.
    for name, param in model.named_parameters():
        assert param.grad is not None, f'No gradient for {name}'
    num_grad = sum(p.grad.numel() for p in model.parameters()
                   if p.grad is not None)
    num_params = sum(p.numel() for p in model.parameters())
    assert num_params == num_grad, 'Some parameters are missing gradients'