add sdk python demo (#554)
* check in python demos * check in text detector python demo * check in roatated object python demo * check in pose python demo * ignore the output class number when testing metrics with sdk as a backend * fix object_detection * rollback segmentation_model and python/segmentor.cpppull/557/head
parent
74243dc98b
commit
0d609701df
|
@ -146,3 +146,5 @@ bin/
|
|||
|
||||
# ncnn
|
||||
mmdeploy/backend/ncnn/onnx2ncnn
|
||||
|
||||
/mmdeploy-*
|
||||
|
|
|
@ -35,8 +35,10 @@ int main(int argc, char *argv[]) {
|
|||
fprintf(stderr, "failed to apply classifier, code: %d\n", (int)status);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "label: %d, score: %.4f\n", res->label_id, res->score);
|
||||
for (int i = 0; i < res_count[0]; ++i) {
|
||||
fprintf(stderr, "label: %d, score: %.4f\n", res->label_id, res->score);
|
||||
++res;
|
||||
}
|
||||
|
||||
mmdeploy_classifier_release_result(res, res_count, 1);
|
||||
|
||||
|
|
|
@ -52,7 +52,7 @@ int main(int argc, char *argv[]) {
|
|||
}
|
||||
|
||||
// skip detections less than specified score threshold
|
||||
if (bboxes[i].score < 0.1) {
|
||||
if (bboxes[i].score < 0.3) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
from mmdeploy_python import Classifier
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'model_path', help='the directory path of mmdeploy model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
classifier = Classifier(args.model_path, args.device_name, 0)
|
||||
result = classifier([img])
|
||||
for label_id, score in result[0]:
|
||||
print(label_id, score)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
from mmdeploy_python import Restorer
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'model_path', help='the directory path of mmdeploy model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
|
||||
restorer = Restorer(args.model_path, args.device_name, 0)
|
||||
result = restorer([img])[0]
|
||||
|
||||
# convert to BGR
|
||||
result = result[..., ::-1]
|
||||
cv2.imwrite('output_restorer.bmp', result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmdeploy_python import Segmentor
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'model_path', help='the directory path of mmdeploy model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def get_palette(num_classes=256):
|
||||
state = np.random.get_state()
|
||||
# random color
|
||||
np.random.seed(42)
|
||||
palette = np.random.randint(0, 256, size=(num_classes, 3))
|
||||
np.random.set_state(state)
|
||||
return [tuple(c) for c in palette]
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
|
||||
segmentor = Segmentor(args.model_path, args.device_name, 0)
|
||||
seg = segmentor([img])[0]
|
||||
|
||||
palette = get_palette()
|
||||
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
|
||||
for label, color in enumerate(palette):
|
||||
color_seg[seg == label, :] = color
|
||||
# convert to BGR
|
||||
color_seg = color_seg[..., ::-1]
|
||||
|
||||
img = img * 0.5 + color_seg * 0.5
|
||||
img = img.astype(np.uint8)
|
||||
cv2.imwrite('output_segmentation.png', img)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmdeploy_python import Detector
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'model_path', help='the directory path of mmdeploy model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
detector = Detector(args.model_path, args.device_name, 0)
|
||||
bboxes, labels, masks = detector([img])[0]
|
||||
assert (isinstance(bboxes, np.ndarray))
|
||||
assert (isinstance(labels, np.ndarray))
|
||||
assert (isinstance(masks, list))
|
||||
|
||||
indices = [i for i in range(len(bboxes))]
|
||||
for index, bbox, label_id in zip(indices, bboxes, labels):
|
||||
[left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
|
||||
if score < 0.3:
|
||||
continue
|
||||
|
||||
cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0))
|
||||
|
||||
if masks[index].size:
|
||||
mask = masks[index]
|
||||
blue, green, red = cv2.split(img)
|
||||
mask_img = blue[top:top + mask.shape[0], left:left + mask.shape[1]]
|
||||
cv2.bitwise_or(mask, mask_img, mask_img)
|
||||
img = cv2.merge([blue, green, red])
|
||||
|
||||
cv2.imwrite('output_detection.png', img)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
from mmdeploy_python import TextDetector
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'--textdet',
|
||||
default='',
|
||||
help='the directory path of mmdeploy text-detector sdk model')
|
||||
parser.add_argument(
|
||||
'--textrecog',
|
||||
default='',
|
||||
help='the directory path of mmdeploy text-recognizer sdk model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
|
||||
if args.textdet:
|
||||
detector = TextDetector(args.textdet, args.device_name, 0)
|
||||
bboxes = detector([img])[0]
|
||||
|
||||
pts = (bboxes[:, 0:8] + 0.5).reshape(len(bboxes), -1, 2).astype(int)
|
||||
cv2.polylines(img, pts, True, (0, 255, 0), 2)
|
||||
cv2.imwrite('output_ocr.png', img)
|
||||
|
||||
if args.textrecog:
|
||||
print('API of TextRecognizer does not support bbox as argument yet')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmdeploy_python import PoseDetector
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'model_path', help='the directory path of mmdeploy model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--bbox',
|
||||
default=None,
|
||||
nargs='+',
|
||||
help='bounding box of an object in format (x, y, w, h)')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
|
||||
bboxes = []
|
||||
if args.bbox is None:
|
||||
bbox = [0, 0, img.shape[1], img.shape[0]]
|
||||
else:
|
||||
# x, y, w, h -> left, top, right, bottom
|
||||
bbox = np.array(args.bbox, dtype=int)
|
||||
bbox[2:] += bbox[:2]
|
||||
bboxes.append(bbox)
|
||||
|
||||
detector = PoseDetector(args.model_path, args.device_name, 0)
|
||||
result = detector([img], [bboxes])[0]
|
||||
|
||||
_, point_num, _ = result.shape
|
||||
points = result[:, :, :2].reshape(point_num, 2)
|
||||
for [x, y] in points.astype(int):
|
||||
cv2.circle(img, (x, y), 1, (0, 255, 0), 2)
|
||||
|
||||
cv2.imwrite('output_pose.png', img)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
from math import cos, sin
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from mmdeploy_python import RotatedDetector
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='show how to use sdk python api')
|
||||
parser.add_argument(
|
||||
'model_path', help='the directory path of mmdeploy model')
|
||||
parser.add_argument('image_path', help='the path of an image')
|
||||
parser.add_argument(
|
||||
'--device-name', default='cpu', help='the name of device, cuda or cpu')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
img = cv2.imread(args.image_path)
|
||||
detector = RotatedDetector(args.model_path, args.device_name, 0)
|
||||
rbboxes, labels = detector([img])[0]
|
||||
# print(rbboxes, labels)
|
||||
indices = [i for i in range(len(rbboxes))]
|
||||
for index, rbbox, label_id in zip(indices, rbboxes, labels):
|
||||
[cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1]
|
||||
if score < 0.1:
|
||||
continue
|
||||
[wx, wy, hx, hy] = \
|
||||
0.5 * np.array([w, w, -h, h]) * \
|
||||
np.array([cos(angle), sin(angle), sin(angle), cos(angle)])
|
||||
points = np.array([[[int(cx - wx - hx),
|
||||
int(cy - wy - hy)],
|
||||
[int(cx + wx - hx),
|
||||
int(cy + wy - hy)],
|
||||
[int(cx + wx + hx),
|
||||
int(cy + wy + hy)],
|
||||
[int(cx - wx + hx),
|
||||
int(cy - wy + hy)]]])
|
||||
cv2.drawContours(img, points, -1, (0, 255, 0), 2)
|
||||
|
||||
cv2.imwrite('output_detection.png', img)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in New Issue