Mirror of https://github.com/open-mmlab/mmdeploy.git (synced 2025-01-14 08:09:43 +08:00)
[Unittest]: Test mmcls (#135)
* Add tests
* lint
* add data
* Remove redundant code
This commit is contained in:
parent a7111eddb6
commit e240c1569f
0  tests/test_mmcls/data/imgs/ann.txt  Normal file
BIN  tests/test_mmcls/data/imgs/blank.jpg  Normal file
Binary file not shown. (After: Size: 691 B)
39  tests/test_mmcls/data/model.py  Normal file
@@ -0,0 +1,39 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    test=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')
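For orientation, a minimal sketch of how this config is consumed by the tests below, assuming mmcv and mmdeploy are importable and the repository root is the working directory; mmcv.Config.fromfile and api_utils.init_pytorch_model mirror the calls in test_mmcls_apis.py, while the assert on cfg.model.type is only an illustrative check:

import mmcv
import mmdeploy.apis.utils as api_utils
from mmdeploy.utils.constants import Codebase

model_cfg = 'tests/test_mmcls/data/model.py'
cfg = mmcv.Config.fromfile(model_cfg)       # parse the config shown above
assert cfg.model.type == 'ImageClassifier'

# Build the eager PyTorch classifier from the same config, as
# test_init_pytorch_model does in the next file.
torch_model = api_utils.init_pytorch_model(
    Codebase.MMCLS, model_cfg=model_cfg, device='cpu')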
196  tests/test_mmcls/test_mmcls_apis.py  Normal file
@@ -0,0 +1,196 @@
import importlib
import os
import tempfile

import mmcv
import numpy as np
import pytest
import torch

import mmdeploy.apis.ncnn as ncnn_apis
import mmdeploy.apis.onnxruntime as ort_apis
import mmdeploy.apis.ppl as ppl_apis
import mmdeploy.apis.tensorrt as trt_apis
import mmdeploy.apis.utils as api_utils
from mmdeploy.utils.constants import Backend, Codebase
from mmdeploy.utils.test import SwitchBackendWrapper

model_cfg = 'tests/test_mmcls/data/model.py'
deploy_cfg = mmcv.Config(
    dict(
        backend_config=dict(type='onnxruntime'),
        codebase_config=dict(type='mmcls', task='Classification'),
        onnx_config=dict(
            type='onnx',
            export_params=True,
            keep_initializers_as_inputs=False,
            opset_version=11,
            input_shape=None,
            input_names=['input'],
            output_names=['output'])))
input_img = torch.rand(1, 3, 64, 64)
input = {'img': input_img}


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('tensorrt'), reason='requires tensorrt')
def test_TensorRTClassifier():
    # force add backend wrapper regardless of plugins
    from mmdeploy.apis.tensorrt.tensorrt_utils import TRTWrapper
    trt_apis.__dict__.update({'TRTWrapper': TRTWrapper})

    # simplify backend inference
    outputs = {
        'output': torch.rand(1, 3, 64, 64).cuda(),
    }

    with SwitchBackendWrapper(TRTWrapper) as wrapper:
        wrapper.set(outputs=outputs)

        from mmdeploy.mmcls.apis.inference import TensorRTClassifier
        trt_classifier = TensorRTClassifier('', [''], 0)
        imgs = torch.rand(1, 3, 64, 64).cuda()

        results = trt_classifier.forward(imgs, return_loss=False)
        assert results is not None, ('failed to get output using '
                                     'TensorRTClassifier')


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_ONNXRuntimeClassifier():
    # force add backend wrapper regardless of plugins
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    outputs = torch.rand(1, 3, 64, 64)

    with SwitchBackendWrapper(ORTWrapper) as wrapper:
        wrapper.set(outputs=outputs)

        from mmdeploy.mmcls.apis.inference import ONNXRuntimeClassifier
        ort_classifier = ONNXRuntimeClassifier('', [''], 0)
        imgs = torch.rand(1, 3, 64, 64)

        results = ort_classifier.forward(imgs, return_loss=False)
        assert results is not None, 'failed to get output using '\
            'ONNXRuntimeClassifier'


@pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda')
@pytest.mark.skipif(
    not importlib.util.find_spec('pyppl'), reason='requires pyppl')
def test_PPLClassifier():
    # force add backend wrapper regardless of plugins
    from mmdeploy.apis.ppl.ppl_utils import PPLWrapper
    ppl_apis.__dict__.update({'PPLWrapper': PPLWrapper})

    # simplify backend inference
    outputs = torch.rand(1, 3, 64, 64)

    with SwitchBackendWrapper(PPLWrapper) as wrapper:
        wrapper.set(outputs=outputs)

        from mmdeploy.mmcls.apis.inference import PPLClassifier
        ppl_classifier = PPLClassifier('', [''], 0)
        imgs = torch.rand(1, 3, 64, 64)

        results = ppl_classifier.forward(imgs, return_loss=False)
        assert results is not None, 'failed to get output using PPLClassifier'


@pytest.mark.skipif(
    not importlib.util.find_spec('ncnn'), reason='requires ncnn')
def test_NCNNClassifier():
    # force add backend wrapper regardless of plugins
    from mmdeploy.apis.ncnn.ncnn_utils import NCNNWrapper
    ncnn_apis.__dict__.update({'NCNNWrapper': NCNNWrapper})

    # simplify backend inference
    outputs = {'output': torch.rand(1, 3, 64, 64)}

    with SwitchBackendWrapper(NCNNWrapper) as wrapper:
        wrapper.set(outputs=outputs)

        from mmdeploy.mmcls.apis.inference import NCNNClassifier
        ncnn_classifier = NCNNClassifier('', '', [''], 0)
        imgs = torch.rand(1, 3, 64, 64)

        results = ncnn_classifier.forward(imgs, return_loss=False)
        assert results is not None, 'failed to get output using NCNNClassifier'


def test_init_pytorch_model():
    model = api_utils.init_pytorch_model(
        Codebase.MMCLS, model_cfg=model_cfg, device='cpu')
    assert model is not None


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def create_backend_model():
    from mmdeploy.apis.onnxruntime.onnxruntime_utils import ORTWrapper
    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})

    # simplify backend inference
    wrapper = SwitchBackendWrapper(ORTWrapper)
    wrapper.set(outputs=[[1]], model_cfg=model_cfg, deploy_cfg=deploy_cfg)
    model = api_utils.init_backend_model([''], model_cfg, deploy_cfg)

    return model, wrapper


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_init_backend_model():
    model, wrapper = create_backend_model()
    assert model is not None

    # Recovery
    wrapper.recover()


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_run_inference():
    model, wrapper = create_backend_model()
    result = api_utils.run_inference(Codebase.MMCLS, input, model)
    assert result is not None

    # Recovery
    wrapper.recover()


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_visualize():
    numpy_img = np.random.rand(64, 64, 3)
    model, wrapper = create_backend_model()
    result = api_utils.run_inference(Codebase.MMCLS, input, model)
    with tempfile.TemporaryDirectory() as dir:
        filename = os.path.join(dir, 'tmp.jpg')
        api_utils.visualize(Codebase.MMCLS, numpy_img, result, model, filename,
                            Backend.ONNXRUNTIME)
        assert os.path.exists(filename)

    # Recovery
    wrapper.recover()


@pytest.mark.skipif(
    not importlib.util.find_spec('onnxruntime'), reason='requires onnxruntime')
def test_inference_model():
    numpy_img = np.random.rand(64, 64, 3)
    with tempfile.TemporaryDirectory() as dir:
        filename = os.path.join(dir, 'tmp.jpg')
        model, wrapper = create_backend_model()
        from mmdeploy.apis.inference import inference_model
        inference_model(model_cfg, deploy_cfg, model, numpy_img, 'cpu',
                        Backend.ONNXRUNTIME, filename, False)
        assert os.path.exists(filename)

    # Recovery
    wrapper.recover()
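Taken together, the onnxruntime-gated tests above reduce to the flow below; a condensed sketch (assuming onnxruntime is installed) that reuses the module-level model_cfg, deploy_cfg, input and the create_backend_model helper defined in this file, with 'vis.jpg' as a purely illustrative output path:

import numpy as np
import mmdeploy.apis.utils as api_utils
from mmdeploy.utils.constants import Backend, Codebase

# Stubbed ORT backend, so no real .onnx file is needed.
model, wrapper = create_backend_model()
result = api_utils.run_inference(Codebase.MMCLS, input, model)
api_utils.visualize(Codebase.MMCLS, np.random.rand(64, 64, 3), result, model,
                    'vis.jpg', Backend.ONNXRUNTIME)
wrapper.recover()  # restore the original ORTWrapper behaviour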
84  tests/test_mmcls/test_mmcls_export.py  Normal file
@@ -0,0 +1,84 @@
import mmcv
import numpy as np

from mmdeploy.apis.utils import build_dataloader, build_dataset, create_input
from mmdeploy.utils.constants import Codebase, Task


class TestCreateInput:
    task = Task.CLASSIFICATION
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    img_test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='Resize', size=(256, -1)),
        dict(type='CenterCrop', crop_size=224),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='ImageToTensor', keys=['img']),
        dict(type='Collect', keys=['img'])
    ]

    imgs = np.random.rand(32, 32, 3)
    img_path = 'tests/test_mmcls/data/imgs/blank.jpg'

    def test_create_input_static(self):
        data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline))
        model_cfg = mmcv.Config(
            dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline))
        inputs = create_input(
            Codebase.MMCLS,
            TestCreateInput.task,
            model_cfg,
            TestCreateInput.imgs,
            input_shape=(32, 32),
            device='cpu')
        assert inputs is not None, 'Failed to create input'

    def test_create_input_dynamic(self):
        data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline))
        model_cfg = mmcv.Config(
            dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline))
        inputs = create_input(
            Codebase.MMCLS,
            TestCreateInput.task,
            model_cfg,
            TestCreateInput.imgs,
            input_shape=None,
            device='cpu')
        assert inputs is not None, 'Failed to create input'

    def test_create_input_from_file(self):
        data = dict(test=dict(pipeline=TestCreateInput.img_test_pipeline))
        model_cfg = mmcv.Config(
            dict(data=data, test_pipeline=TestCreateInput.img_test_pipeline))
        inputs = create_input(
            Codebase.MMCLS,
            TestCreateInput.task,
            model_cfg,
            TestCreateInput.img_path,
            input_shape=None,
            device='cpu')
        assert inputs is not None, 'Failed to create input'


def test_build_dataset():
    data = dict(
        samples_per_gpu=1,
        workers_per_gpu=1,
        test=dict(
            type='ImageNet',
            data_prefix='tests/test_mmcls/data/imgs',
            ann_file='tests/test_mmcls/data/imgs/ann.txt',
            pipeline=[
                {
                    'type': 'LoadImageFromFile'
                },
            ]))
    dataset_cfg = mmcv.Config(dict(data=data))
    dataset = build_dataset(
        Codebase.MMCLS, dataset_cfg=dataset_cfg, dataset_type='test')
    assert dataset is not None, 'Failed to build dataset'
    dataloader = build_dataloader(Codebase.MMCLS, dataset, 1, 1)
    assert dataloader is not None, 'Failed to build dataloader'
57  tests/test_mmcls/test_mmcls_models.py  Normal file
@@ -0,0 +1,57 @@
import mmcv
import torch

from mmdeploy.core import RewriterContext
from mmdeploy.utils.test import WrapModel

input = torch.rand(1)


def test_baseclassifier_forward():
    from mmcls.models.classifiers import BaseClassifier

    class DummyClassifier(BaseClassifier):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg=init_cfg)

        def extract_feat(self, imgs):
            pass

        def forward_train(self, imgs):
            return 'train'

        def simple_test(self, img, tmp, **kwargs):
            return 'simple_test'

    model = DummyClassifier().eval()

    model_output = model(input)
    with RewriterContext(
            cfg=mmcv.Config(dict()), backend='onnxruntime'), torch.no_grad():
        backend_output = model(input)

    assert model_output == 'train'
    assert backend_output == 'simple_test'


def test_cls_head():
    from mmcls.models.heads.cls_head import ClsHead
    model = WrapModel(ClsHead(), 'post_process').eval()
    model_output = model(input)
    with RewriterContext(
            cfg=mmcv.Config(dict()), backend='onnxruntime'), torch.no_grad():
        backend_output = model(input)

    assert list(backend_output.detach().cpu().numpy()) == model_output


def test_multilabel_cls_head():
    from mmcls.models.heads.multi_label_head import MultiLabelClsHead
    model = WrapModel(MultiLabelClsHead(), 'post_process').eval()
    model_output = model(input)
    with RewriterContext(
            cfg=mmcv.Config(dict()), backend='onnxruntime'), torch.no_grad():
        backend_output = model(input)

    assert list(backend_output.detach().cpu().numpy()) == model_output
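WrapModel is what lets test_cls_head and test_multilabel_cls_head drive a single method ('post_process') through forward(), so the rewritten (backend) output can be compared against the eager output of the same method. As a rough, hypothetical stand-in for the idea, not mmdeploy's actual implementation, the class and name below are invented for illustration only:

import torch


class MethodAsForward(torch.nn.Module):
    """Hypothetical illustration of the WrapModel idea: expose one named
    method of a wrapped module as the module's forward()."""

    def __init__(self, module, method_name, **fixed_kwargs):
        super().__init__()
        self.module = module
        self.method_name = method_name
        self.fixed_kwargs = fixed_kwargs

    def forward(self, *args, **kwargs):
        method = getattr(self.module, self.method_name)
        return method(*args, **self.fixed_kwargs, **kwargs)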