68 lines
1.9 KiB
Python
68 lines
1.9 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
import json
|
|
|
|
import cv2
|
|
from mmdeploy_python import Context, Device, Model, Pipeline
|
|
|
|
|
|
def parse_args():
    """Parse the command-line arguments of the pipeline demo.

    All four arguments are required positionals, consumed in this order:
    device, detection model path, classification model path, image path.

    Returns:
        argparse.Namespace: with ``device``, ``det_model_path``,
        ``cls_model_path`` and ``image_path`` attributes (plain strings).
    """
    positionals = (
        ('device', 'name of device, cuda or cpu'),
        ('det_model_path', 'path of detection model'),
        ('cls_model_path', 'path of classification model'),
        ('image_path', 'path to test image'),
    )
    cli = argparse.ArgumentParser(
        description='Demo of MMDeploy SDK pipeline API')
    for dest, help_text in positionals:
        cli.add_argument(dest, help=help_text)
    return cli.parse_args()
|
|
|
|
|
|
def _build_config(det_model, cls_model):
    """Return the SDK pipeline config: detect boxes, then classify each crop.

    Args:
        det_model: SDK ``Model`` used by the top-level detection
            ``Inference`` task (input ``img``, output ``dets``).
        cls_model: SDK ``Model`` used by the nested per-box ``Inference``
            task (input ``patches``, output ``labels``).

    Returns:
        dict: a ``Pipeline`` config whose single input is ``img`` and whose
        outputs are ``dets`` and ``labels``.
    """
    return dict(
        type='Pipeline',
        input='img',
        tasks=[
            dict(
                type='Inference',
                input='img',
                output='dets',
                params=dict(model=det_model)),
            dict(
                type='Pipeline',
                # flatten dets ([[a]] -> [a]) and broadcast img
                input=['boxes=*dets', 'imgs=+img'],
                tasks=[
                    dict(
                        type='Task',
                        module='CropBox',
                        input=['imgs', 'boxes'],
                        output='patches'),
                    dict(
                        type='Inference',
                        input='patches',
                        output='labels',
                        params=dict(model=cls_model))
                ],
                # unflatten labels ([a] -> [[a]])
                output='*labels')
        ],
        output=['dets', 'labels'])


def main():
    """Run detection + per-box classification on one image and print JSON.

    Reads the CLI arguments, loads the test image, builds the two-stage SDK
    pipeline on the requested device, and prints the pipeline output as
    indented JSON.

    Raises:
        SystemExit: if ``image_path`` cannot be read as an image.
    """
    args = parse_args()

    # cv2.imread reports failure by returning None rather than raising.
    # Check it before loading models so a bad path fails fast with a clear
    # message instead of handing None to the SDK pipeline.
    img = cv2.imread(args.image_path)
    if img is None:
        raise SystemExit(f'failed to read image: {args.image_path}')

    det_model = Model(args.det_model_path)
    cls_model = Model(args.cls_model_path)
    config = _build_config(det_model, cls_model)

    device = Device(args.device)
    pipeline = Pipeline(config, Context(device))

    # NOTE(review): the pipeline's 'img' input receives a dict keyed by
    # 'ori_img'; presumably this is the image-wrapping convention the SDK
    # expects -- confirm against the MMDeploy SDK documentation.
    output = pipeline(dict(ori_img=img))

    print(json.dumps(output, indent=4))
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
|