mirror of https://github.com/open-mmlab/mmyolo.git
153 lines
4.9 KiB
Python
153 lines
4.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from projects.easydeploy.model import ORTWrapper, TRTWrapper # isort:skip
|
|
import os
|
|
import random
|
|
from argparse import ArgumentParser
|
|
|
|
import cv2
|
|
import mmcv
|
|
import numpy as np
|
|
import torch
|
|
from mmcv.transforms import Compose
|
|
from mmdet.utils import get_test_pipeline_cfg
|
|
from mmengine.config import Config, ConfigDict
|
|
from mmengine.utils import ProgressBar, path
|
|
|
|
from mmyolo.utils import register_all_modules
|
|
from mmyolo.utils.misc import get_file_list
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the image demo.

    Returns:
        argparse.Namespace: Namespace with the attributes ``img``,
        ``config``, ``checkpoint``, ``out_dir``, ``device`` and ``show``.
    """
    cli = ArgumentParser()

    # Required positional arguments, registered in command-line order.
    positionals = (
        ('img', 'Image path, include image file, dir and URL.'),
        ('config', 'Config file'),
        ('checkpoint', 'Checkpoint file'),
    )
    for arg_name, help_text in positionals:
        cli.add_argument(arg_name, help=help_text)

    # Optional flags.
    cli.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    cli.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    cli.add_argument(
        '--show', action='store_true', help='Show the detection results')

    return cli.parse_args()
|
|
|
|
|
|
def preprocess(config):
    """Build a module that normalizes one image tensor for inference.

    The per-channel ``mean`` and ``std`` are read from
    ``config['model']['data_preprocessor']``. Missing entries fall back to
    a zero mean and a unit std, which leaves the pixel values unchanged.

    Args:
        config: Config object supporting nested ``get`` lookups; a plain
            ``dict`` works as well.

    Returns:
        torch.nn.Module: Module in eval mode. Its ``forward`` takes an
        un-batched image tensor that broadcasts against a
        ``(1, 3, 1, 1)`` statistic (e.g. ``(3, H, W)``) and returns a new
        float32 tensor with a leading batch axis, normalized as
        ``(x - mean) / std``. The input tensor is never modified.
    """
    data_preprocess = config.get('model', {}).get('data_preprocessor', {})
    mean = data_preprocess.get('mean', [0., 0., 0.])
    std = data_preprocess.get('std', [1., 1., 1.])
    # Shape the statistics so they broadcast over batch and spatial axes.
    mean = torch.tensor(mean, dtype=torch.float32).reshape(1, 3, 1, 1)
    std = torch.tensor(std, dtype=torch.float32).reshape(1, 3, 1, 1)

    class PreProcess(torch.nn.Module):
        """Add a batch axis, cast to float32 and normalize."""

        def __init__(self):
            super().__init__()

        def forward(self, x):
            x = x[None].float()
            # Out-of-place on purpose: for a float32 input, ``x[None]`` is a
            # view and ``.float()`` returns it unchanged, so in-place ``-=``
            # and ``/=`` would overwrite the caller's tensor.
            x = x - mean.to(x.device)
            x = x / std.to(x.device)
            return x

    return PreProcess().eval()
|
|
|
|
|
|
def _build_backend(checkpoint, device):
    """Wrap an exported model file in the runtime matching its extension.

    Raises:
        NotImplementedError: If the file does not end in ``.onnx``,
            ``.engine`` or ``.plan``.
    """
    if checkpoint.endswith('.onnx'):
        return ORTWrapper(checkpoint, device)
    if checkpoint.endswith(('.engine', '.plan')):
        return TRTWrapper(checkpoint, device)
    raise NotImplementedError(
        f'Unsupported checkpoint "{checkpoint}": expected a file ending in '
        '.onnx, .engine or .plan')


def _draw_detections(image, bboxes, scores, labels, class_names, colors):
    """Draw every box and its ``cls:..._score:...`` caption on ``image``.

    ``image`` is modified in place. ``bboxes``, ``scores`` and ``labels``
    are tensors holding one entry per detection.
    """
    # Convert once to plain Python values instead of per-element tensors.
    for bbox, score, label in zip(bboxes.tolist(), scores.tolist(),
                                  labels.tolist()):
        color = colors[label]

        if class_names is not None:
            label_name = class_names[label]
            name = f'cls:{label_name}_score:{score:0.4f}'
        else:
            name = f'cls:{label}_score:{score:0.4f}'

        cv2.rectangle(image, bbox[:2], bbox[2:], color, 2)
        cv2.putText(
            image,
            name, (bbox[0], bbox[1] - 2),
            cv2.FONT_HERSHEY_SIMPLEX,
            2.0, [225, 255, 255],
            thickness=3)


def main():
    """Run an exported detector on images and draw the detections.

    The checkpoint given on the command line is loaded through
    ``ORTWrapper`` (``.onnx``) or ``TRTWrapper`` (``.engine`` / ``.plan``).
    Every image resolved from the ``img`` argument goes through the
    config's test pipeline and the normalization module, is fed to the
    model, and the predicted boxes are mapped back to the original image
    before being drawn. Results are shown in a window with ``--show``,
    otherwise written to ``--out-dir``.

    Raises:
        NotImplementedError: If the checkpoint extension is not supported.
    """
    args = parse_args()

    # register all modules into the registries
    register_all_modules()

    # One random color per class index (up to 1000 classes).
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(1000)]

    # build the inference backend from the exported checkpoint file
    model = _build_backend(args.checkpoint, args.device)
    model.to(args.device)

    cfg = Config.fromfile(args.config)
    class_names = cfg.get('class_name')

    # Images are decoded here, so the pipeline's first step is replaced by
    # the ndarray loader.
    test_pipeline = get_test_pipeline_cfg(cfg)
    test_pipeline[0] = ConfigDict({'type': 'mmdet.LoadImageFromNDArray'})
    test_pipeline = Compose(test_pipeline)

    pre_pipeline = preprocess(cfg)

    if not args.show:
        path.mkdir_or_exist(args.out_dir)

    # get file list
    files, source_type = get_file_list(args.img)

    # start detector inference
    progress_bar = ProgressBar(len(files))
    for i, file in enumerate(files):
        bgr = mmcv.imread(file)
        rgb = mmcv.imconvert(bgr, 'bgr', 'rgb')
        data, samples = test_pipeline(dict(img=rgb, img_id=i)).values()

        pad_param = samples.get('pad_param',
                                np.array([0, 0, 0, 0], dtype=np.float32))
        h, w = samples.get('ori_shape', rgb.shape[:2])
        # NOTE(review): presumably pad_param is (top, bottom, left, right),
        # so indices 2 and 0 are the x / y offsets of an (x1, y1, x2, y2)
        # box - confirm against the resize transform in use.
        pad_param = torch.asarray(
            [pad_param[2], pad_param[0], pad_param[2], pad_param[0]],
            device=args.device)
        scale_factor = samples.get('scale_factor', [1., 1])
        # ``list(...) * 2`` repeats the two factors for both box corners even
        # when the pipeline stores them as an ndarray, where a bare ``* 2``
        # would double the values instead.
        scale_factor = torch.asarray(
            list(scale_factor) * 2, device=args.device)
        data = pre_pipeline(data).to(args.device)

        result = model(data)

        if source_type['is_dir']:
            # Flatten the relative path into one file name; ``os.sep`` keeps
            # this working where the separator is not '/'.
            filename = os.path.relpath(file, args.img).replace(os.sep, '_')
        else:
            filename = os.path.basename(file)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # Get candidate predict info by num_dets
        num_dets, bboxes, scores, labels = result
        scores = scores[0, :num_dets]
        bboxes = bboxes[0, :num_dets]
        labels = labels[0, :num_dets]

        # Undo padding, then resizing, to return to original image coords.
        bboxes -= pad_param
        bboxes /= scale_factor

        # Keep the boxes inside the original image before rounding.
        bboxes[:, 0::2].clamp_(0, w)
        bboxes[:, 1::2].clamp_(0, h)
        bboxes = bboxes.round().int()

        _draw_detections(bgr, bboxes, scores, labels, class_names, colors)

        if args.show:
            mmcv.imshow(bgr, 'result', 0)
        else:
            mmcv.imwrite(bgr, out_file)
        progress_bar.update()
|
|
|
|
|
|
# Script entry point: run the demo only when this file is executed directly.
if __name__ == '__main__':
    main()
|