# mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import numpy as np
import torch
from mmcv.parallel import collate, scatter
from numpy.testing import assert_array_almost_equal
from torchvision.transforms import Compose

from easycv.datasets.registry import PIPELINES
from easycv.datasets.utils import replace_ImageToTensor
from easycv.models import build_model
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.registry import build_from_cfg


|
class FCOSTest(unittest.TestCase):
    """Integration test for the FCOS detector.

    Builds an FCOS model from a config file, loads pretrained weights
    (from a remote checkpoint), runs inference on a demo image and checks
    the detection outputs against stored reference values.

    NOTE(review): this test downloads the checkpoint and the demo image
    over the network, so it requires connectivity to the OSS bucket.
    """

    def setUp(self):
        # Announce which test method is running (useful in CI logs).
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))

    def init_fcos(self, model_path, config_path):
        """Build the FCOS model and load its checkpoint.

        Stores ``model``, ``cfg``, ``device``, ``ckpt`` and ``CLASSES``
        on the instance for use by :meth:`predict`.

        Args:
            model_path (str): Local path or URL of the model checkpoint.
            config_path (str): Path of the model config file.
        """
        self.model_path = model_path
        self.cfg = mmcv_config_fromfile(config_path)

        # Cap the number of detections per image so the test compares a
        # small, fixed number of boxes/scores/classes.
        if self.cfg.model.head.test_cfg.get('max_per_img', None):
            self.cfg.model.head.test_cfg.max_per_img = 10

        self.model = build_model(self.cfg.model)

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        map_location = 'cpu' if self.device == 'cpu' else 'cuda'
        self.ckpt = load_checkpoint(
            self.model, self.model_path, map_location=map_location)

        self.model.to(self.device)
        self.model.eval()

        self.CLASSES = self.cfg.CLASSES

    def predict(self, imgs):
        """Inference image(s) with the detector.

        Args:
            imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
                Either image files or loaded images.

        Returns:
            The detection results produced by the model's test-mode
            forward pass.
        """
        # Normalize the input to a list of images.
        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]

        device = next(self.model.parameters()).device  # model device

        # Work on a copy so self.cfg is not mutated between calls, and
        # prepend the appropriate loading transform: raw ndarrays use the
        # webcam-style loader, paths/URLs are loaded from file.
        cfg = self.cfg.copy()
        load_type = ('LoadImageFromWebcam'
                     if isinstance(imgs[0], np.ndarray) else
                     'LoadImageFromFile')
        cfg.data.val.pipeline.insert(
            0, dict(type=load_type, file_client_args=dict(backend='http')))

        cfg.data.val.pipeline = replace_ImageToTensor(cfg.data.val.pipeline)

        # Build the test pipeline from the (dict or callable) transforms.
        transforms = []
        for transform in cfg.data.val.pipeline:
            # mmcv transforms expect img_scale as a tuple, not a list.
            if 'img_scale' in transform:
                transform['img_scale'] = tuple(transform['img_scale'])
            if isinstance(transform, dict):
                transforms.append(build_from_cfg(transform, PIPELINES))
            elif callable(transform):
                transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')
        test_pipeline = Compose(transforms)

        # Run every image through the pipeline.
        datas = []
        for img in imgs:
            if isinstance(img, np.ndarray):
                # Directly wrap the already-loaded image.
                data = dict(img=img)
            else:
                # Add file information for the loading transform.
                data = dict(img_info=dict(filename=img), img_prefix=None)
            datas.append(test_pipeline(data))

        data = collate(datas, samples_per_gpu=len(imgs))
        # Unwrap the actual tensors/metas from their DataContainers.
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [img.data[0] for img in data['img']]
        if next(self.model.parameters()).is_cuda:
            # Scatter to the specified GPU.
            data = scatter(data, [device])[0]

        # Forward the model without tracking gradients.
        with torch.no_grad():
            results = self.model(mode='test', **data)

        return results

    def test_fcos(self):
        """End-to-end check of FCOS predictions on the demo image."""
        model_path = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/detection/fcos/epoch_12.pth'
        config_path = 'configs/detection/fcos/fcos_center-normbbox-centeronreg-giou_r50_caffe_fpn_gn-head_1x_coco.py'
        img = 'https://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/demo/demo.jpg'
        self.init_fcos(model_path, config_path)
        output = self.predict(img)

        # The result dict must expose all expected keys, and each per-image
        # result must honor the max_per_img=10 cap set in init_fcos.
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes', 'img_metas'):
            self.assertIn(key, output)
        for key in ('detection_boxes', 'detection_scores',
                    'detection_classes'):
            self.assertEqual(len(output[key][0]), 10)

        print(output['detection_boxes'][0].tolist())
        print(output['detection_scores'][0].tolist())
        print(output['detection_classes'][0].tolist())

        self.assertListEqual(
            output['detection_classes'][0].tolist(),
            np.array([2, 2, 2, 2, 2, 2, 2, 2, 2, 13], dtype=np.int32).tolist())

        assert_array_almost_equal(
            output['detection_scores'][0],
            np.array([
                0.6641181707382202, 0.6135501265525818, 0.5985610485076904,
                0.5694775581359863, 0.5586040616035461, 0.5209507942199707,
                0.5056729912757874, 0.4943872094154358, 0.4850597083568573,
                0.45443734526634216
            ],
                     dtype=np.float32),
            decimal=2)

        assert_array_almost_equal(
            output['detection_boxes'][0],
            np.array([[
                295.5196228027344, 116.56035614013672, 380.0883483886719,
                150.24908447265625
            ],
                      [
                          190.57131958007812, 108.96343231201172,
                          297.7738037109375, 154.69515991210938
                      ],
                      [
                          480.5726013183594, 110.4341812133789,
                          522.8551635742188, 129.9452667236328
                      ],
                      [
                          431.1232604980469, 105.17676544189453,
                          483.89617919921875, 131.85870361328125
                      ],
                      [
                          398.6544494628906, 110.90837860107422,
                          432.6370849609375, 132.89173889160156
                      ],
                      [
                          609.3126831054688, 111.62432861328125,
                          635.4577026367188, 137.03529357910156
                      ],
                      [
                          98.66332244873047, 89.88417053222656,
                          118.9398422241211, 101.25397491455078
                      ],
                      [
                          167.9045867919922, 109.57560729980469,
                          209.74375915527344, 139.98898315429688
                      ],
                      [
                          591.0496826171875, 110.55867767333984,
                          619.4395751953125, 126.65755462646484
                      ],
                      [
                          218.92051696777344, 177.0509033203125,
                          455.8321838378906, 385.0356140136719
                      ]]),
            decimal=1)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
|