89 lines
2.7 KiB
Python
89 lines
2.7 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import tempfile
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmrazor.models.builder import ARCHITECTURES
|
|
|
|
|
|
def test_architecture_mmcls():
    """Smoke-test ``MMClsArchitecture`` wrapping an mmcls ``ImageClassifier``.

    Builds the architecture through the ``ARCHITECTURES`` registry and
    exercises the wrapped model's ``with_neck``/``with_head`` properties,
    its ``train_step`` and ``val_step``, the architecture's loss and test
    forward passes, ``simple_test`` and ``show_result``.
    """
    # A single classifier config. (Previously this was wrapped in a redundant
    # outer ``dict(...)`` call, which only shallow-copied the inner dict.)
    model_cfg = dict(
        type='mmcls.ImageClassifier',
        backbone=dict(
            type='mmcls.ResNet_CIFAR',
            depth=50,
            num_stages=4,
            out_indices=(3, ),
            style='pytorch'),
        neck=dict(type='mmcls.GlobalAveragePooling'),
        head=dict(
            type='mmcls.LinearClsHead',
            num_classes=10,
            in_channels=2048,
            loss=dict(type='CrossEntropyLoss')))

    architecture_cfg = dict(type='MMClsArchitecture', model=model_cfg)

    # A batch of 16 random 3x32x32 images with integer labels in [0, 10),
    # matching ``num_classes=10`` in the head config.
    imgs = torch.randn(16, 3, 32, 32)
    label = torch.randint(0, 10, (16, ))

    # Build from a deep copy so ``architecture_cfg`` stays untouched for the
    # second build further down.
    architecture = ARCHITECTURES.build(deepcopy(architecture_cfg))

    # test property
    assert architecture.model.with_neck
    assert architecture.model.with_head

    # test train_step (called on the wrapped mmcls model, optimizer=None)
    outputs = architecture.model.train_step(
        {'img': imgs, 'gt_label': label}, None)
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test val_step
    outputs = architecture.model.val_step(
        {'img': imgs, 'gt_label': label}, None)
    assert outputs['loss'].item() > 0
    assert outputs['num_samples'] == 16

    # test forward in loss mode
    losses = architecture(imgs, return_loss=True, gt_label=label)
    assert losses['loss'].item() > 0

    # test forward_test on a freshly built architecture
    architecture = ARCHITECTURES.build(deepcopy(architecture_cfg))
    pred = architecture(imgs, return_loss=False, img_metas=None)
    assert isinstance(pred, list) and len(pred) == 16

    # A batch of one must still yield a one-element prediction list.
    single_img = torch.randn(1, 3, 32, 32)
    pred = architecture(single_img, return_loss=False, img_metas=None)
    assert isinstance(pred, list) and len(pred) == 1

    # test simple_test (reuses the single image from above)
    pred = architecture.simple_test(single_img, img_metas=None)
    assert isinstance(pred, list) and len(pred) == 1

    # test show_result with a dummy 224x224x3 uint8 image and a fake result
    img = np.random.randint(0, 255, (224, 224, 3)).astype(np.uint8)
    result = dict(pred_class='cat', pred_label=0, pred_score=0.9)

    with tempfile.TemporaryDirectory() as tmpdir:
        out_file = osp.join(tmpdir, 'out.png')
        architecture.show_result(img, result, out_file=out_file)
        assert osp.exists(out_file)