230 lines
8.2 KiB
Python
230 lines
8.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import copy
|
|
import os
|
|
import os.path as osp
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmengine.infer import BaseInferencer
|
|
from mmengine.registry import VISUALIZERS, DefaultScope
|
|
from mmengine.testing import RunnerTestCase
|
|
from mmengine.utils import is_list_of
|
|
from mmengine.visualization import Visualizer
|
|
|
|
|
|
def is_imported(package):
    """Tell whether ``package`` can be imported in this environment.

    Args:
        package (str): Importable module name, e.g. ``'mmdet'``.

    Returns:
        bool: ``True`` if ``__import__`` succeeds, ``False`` if it raises
        ``ImportError``.
    """
    try:
        __import__(package)
    except ImportError:
        return False
    else:
        return True
|
|
|
|
|
|
class ToyInferencer(BaseInferencer):
    """Minimal ``BaseInferencer`` subclass used by the tests in this file.

    Each stage accepts one extra keyword (``pre_arg``, ``for_arg``,
    ``vis_arg``, ``pos_arg``) that is declared in the matching ``*_kwargs``
    class attribute and otherwise ignored, so that kwargs dispatching can be
    exercised.
    """

    preprocess_kwargs = {'pre_arg'}
    forward_kwargs = {'for_arg'}
    visualize_kwargs = {'vis_arg'}
    postprocess_kwargs = {'pos_arg'}

    def preprocess(self, inputs, batch_size=1, pre_arg=None, **kwargs):
        # ``pre_arg`` is accepted but not forwarded to the base class.
        return super().preprocess(inputs, batch_size, **kwargs)

    def forward(self, inputs, for_arg=None, **kwargs):
        # ``for_arg`` is accepted but not forwarded to the base class.
        return super().forward(inputs, **kwargs)

    def visualize(self, inputs, preds, vis_arg=None, **kwargs):
        # No drawing: the raw inputs are handed back as the visualization.
        return inputs

    def postprocess(self, preds, imgs, return_datasamples, pos_arg=None,
                    **kwargs):
        # Return the visualization and the predictions untouched.
        return imgs, preds

    def _init_pipeline(self, cfg):
        # ``cfg`` is unused; the pipeline turns every input into a float
        # tensor.

        def pipeline(img):
            if isinstance(img, str):
                # Input files are written with ``ndarray.dump`` (a pickle),
                # so loading them requires ``allow_pickle=True``.
                loaded = np.load(img, allow_pickle=True)
                return torch.from_numpy(loaded).float()
            if isinstance(img, np.ndarray):
                return torch.from_numpy(img).float()
            return torch.tensor(img).float()

        return pipeline
|
|
|
|
|
|
class ToyVisualizer(Visualizer):
    """Behavior-free ``Visualizer`` subclass; the test setUp registers it in
    ``VISUALIZERS`` under the name ``'ToyVisualizer'``."""
|
|
|
|
|
|
class TestBaseInferencer(RunnerTestCase):
    """Tests for ``BaseInferencer``, exercised through ``ToyInferencer``."""

    def setUp(self) -> None:
        """Train a toy runner and register ``ToyVisualizer``.

        ``self.cfg_path`` and ``self.ckpt_path`` point into the runner's
        ``work_dir`` (the timestamped config dump and ``epoch_1.pth``); most
        tests construct their inferencer from these two paths.
        """
        super().setUp()
        runner = self.build_runner(copy.deepcopy(self.epoch_based_cfg))
        runner.train()
        self.cfg_path = osp.join(runner.work_dir, f'{runner.timestamp}.py')
        self.ckpt_path = osp.join(runner.work_dir, 'epoch_1.pth')
        VISUALIZERS.register_module(module=ToyVisualizer, name='ToyVisualizer')

    def tearDown(self):
        """Remove the ``ToyVisualizer`` registration added in ``setUp``."""
        VISUALIZERS._module_dict.pop('ToyVisualizer')
        return super().tearDown()

    def test_custom_inferencer(self):
        """Duplicate keys across ``*_kwargs`` sets fail at class creation."""
        with self.assertRaisesRegex(AssertionError, 'Class define error'):

            class CustomInferencer(BaseInferencer):
                preprocess_kwargs = {'a'}
                forward_kwargs = {'a'}

    def test_init(self):
        """``model`` may be a Config, ConfigDict, dict, path or model name."""
        cfg = copy.deepcopy(self.epoch_based_cfg)
        # Pass model as Config.
        ToyInferencer(cfg, self.ckpt_path)
        # Pass model as ConfigDict.
        ToyInferencer(cfg._cfg_dict, self.ckpt_path)
        # Pass model as a plain dict.
        ToyInferencer(dict(cfg._cfg_dict), self.ckpt_path)
        # Pass model as a path to a config file.
        ToyInferencer(self.cfg_path, self.ckpt_path)

        # A ``pretrained`` entry in the model config is removed.
        cfg.model.pretrained = 'fake_path'
        inferencer = ToyInferencer(cfg, self.ckpt_path)
        self.assertNotIn('pretrained', inferencer.cfg.model)

        # A list is not an accepted ``model`` type.
        with self.assertRaisesRegex(TypeError, 'model must'):
            ToyInferencer([self.epoch_based_cfg], self.ckpt_path)

        # Pass model as a model name defined in a metafile (mmdet only; the
        # weights argument here is a remote URL).
        if is_imported('mmdet'):
            from mmdet.utils import register_all_modules

            register_all_modules()
            ToyInferencer(
                'faster-rcnn_s50_fpn_syncbn-backbone+head_ms-range-1x_coco',
                'https://download.openmmlab.com/mmdetection/v2.0/resnest/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco/faster_rcnn_s50_fpn_syncbn-backbone%2Bhead_mstrain-range_1x_coco_20200926_125502-20289c16.pth',  # noqa: E501
            )

        # Only ``weights`` is given; no ``model``.
        ToyInferencer(weights=self.ckpt_path)

    def test_call(self):
        """Calling the inferencer works on arrays and on ``.npy`` paths."""
        num_imgs = 12
        imgs = []
        img_paths = []
        for i in range(num_imgs):
            img = np.random.random((1, 2))
            img_path = osp.join(self.temp_dir.name, f'{i}.npy')
            img.dump(img_path)
            imgs.append(img)
            img_paths.append(img_path)

        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        inferencer(imgs)
        inferencer(img_paths)

    @pytest.mark.skipif(
        not is_imported('mmdet'), reason='mmdet is not installed')
    def test_load_model_from_meta(self):
        """Known metafile names load; unknown names raise ``ValueError``."""
        from mmdet.utils import register_all_modules

        register_all_modules()
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        inferencer._load_model_from_metafile('retinanet_r18_fpn_1x_coco')
        with self.assertRaisesRegex(ValueError, 'Cannot find model'):
            inferencer._load_model_from_metafile('fake_model')
        # TODO: Test alias

    def test_init_model(self):
        """``_init_model`` returns a model that is not in training mode."""
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        model = inferencer._init_model(self.iter_based_cfg, self.ckpt_path)
        self.assertFalse(model.training)

    def test_get_chunk_data(self):
        """Data is split into consecutive chunks; the last may be shorter."""
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        data = list(range(1, 11))
        chunk_data = inferencer._get_chunk_data(data, 3)
        self.assertEqual(
            list(chunk_data), [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]])

    def test_init_visualizer(self):
        """A visualizer is built only when the config declares one."""
        cfg = copy.deepcopy(self.epoch_based_cfg)
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        # The base config is expected to yield no visualizer.
        visualizer = inferencer._init_visualizer(cfg)
        self.assertIsNone(visualizer)

        cfg.visualizer = dict(type='ToyVisualizer')
        visualizer = inferencer._init_visualizer(cfg)
        self.assertIsInstance(visualizer, ToyVisualizer)

        # Visualizer could be built with the same name repeatedly.
        cfg.visualizer = dict(type='ToyVisualizer', name='toy')
        first = inferencer._init_visualizer(cfg)
        second = inferencer._init_visualizer(cfg)
        self.assertIsInstance(first, ToyVisualizer)
        self.assertIsInstance(second, ToyVisualizer)

    def test_dispatch_kwargs(self):
        """Each kwarg is routed to the stage whose ``*_kwargs`` declares it."""
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        kwargs = dict(
            pre_arg=dict(a=1),
            for_arg=dict(c=2),
            vis_arg=dict(b=3),
            pos_arg=dict(d=4))
        pre_arg, for_arg, vis_arg, pos_arg = inferencer._dispatch_kwargs(
            **kwargs)
        self.assertEqual(pre_arg, dict(pre_arg=dict(a=1)))
        self.assertEqual(for_arg, dict(for_arg=dict(c=2)))
        self.assertEqual(vis_arg, dict(vis_arg=dict(b=3)))
        self.assertEqual(pos_arg, dict(pos_arg=dict(d=4)))
        # A kwarg no stage declares is rejected.
        kwargs = dict(return_datasample=dict())
        with self.assertRaisesRegex(ValueError, 'unknown'):
            inferencer._dispatch_kwargs(**kwargs)

    def test_preprocess(self):
        """``preprocess`` batches pipeline outputs, incl. from a directory."""
        inferencer = ToyInferencer(self.cfg_path, self.ckpt_path)
        data = list(range(1, 11))
        pre_data = inferencer.preprocess(data, batch_size=3)
        target_data = [
            [torch.tensor(1), torch.tensor(2), torch.tensor(3)],
            [torch.tensor(4), torch.tensor(5), torch.tensor(6)],
            [torch.tensor(7), torch.tensor(8), torch.tensor(9)],
            [torch.tensor(10)],
        ]
        self.assertEqual(list(pre_data), target_data)

        img_dir = osp.join(self.temp_dir.name, 'imgs')
        os.mkdir(img_dir)
        for i in range(1, 11):
            np.array(1).dump(osp.join(img_dir, f'{i}.npy'))
        # Passing a directory of images.
        inputs = inferencer._inputs_to_list(img_dir)
        self.assertEqual(len(inputs), 10)
        batches = list(inferencer.preprocess(inputs, batch_size=3))
        # 10 inputs in chunks of 3 -> 4 batches; this also keeps the loop
        # below from passing vacuously on an empty dataloader.
        self.assertEqual(len(batches), 4)
        for batch in batches:
            self.assertTrue(is_list_of(batch, torch.Tensor))

    @pytest.mark.skipif(
        not is_imported('mmdet'), reason='mmdet is not installed')
    def test_list_models(self):
        """``list_models`` lists mmdet models and validates ``scope``."""
        model_list = BaseInferencer.list_models('mmdet')
        self.assertGreater(len(model_list), 0)
        # With the default-scope instances cleared, omitting ``scope`` fails.
        DefaultScope._instance_dict.clear()
        with self.assertRaisesRegex(AssertionError, 'scope should be'):
            BaseInferencer.list_models()
        with self.assertRaisesRegex(AssertionError, 'unknown not in'):
            BaseInferencer.list_models('unknown')
|