diff --git a/.gitignore b/.gitignore
index 742b078d1..324f44bd2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -146,3 +146,5 @@ bin/
 
 # ncnn
 mmdeploy/backend/ncnn/onnx2ncnn
+
+/mmdeploy-*
diff --git a/demo/csrc/image_classification.cpp b/demo/csrc/image_classification.cpp
index 18d1e0793..034a9b9d1 100644
--- a/demo/csrc/image_classification.cpp
+++ b/demo/csrc/image_classification.cpp
@@ -35,8 +35,10 @@ int main(int argc, char *argv[]) {
     fprintf(stderr, "failed to apply classifier, code: %d\n", (int)status);
     return 1;
   }
-
-  fprintf(stderr, "label: %d, score: %.4f\n", res->label_id, res->score);
+  // index into `res` instead of advancing it: `res` is passed to the release call below
+  for (int i = 0; i < res_count[0]; ++i) {
+    fprintf(stderr, "label: %d, score: %.4f\n", res[i].label_id, res[i].score);
+  }
 
   mmdeploy_classifier_release_result(res, res_count, 1);
 
diff --git a/demo/csrc/object_detection.cpp b/demo/csrc/object_detection.cpp
index 184340753..3ed7ac4e5 100644
--- a/demo/csrc/object_detection.cpp
+++ b/demo/csrc/object_detection.cpp
@@ -52,7 +52,7 @@ int main(int argc, char *argv[]) {
     }
 
     // skip detections less than specified score threshold
-    if (bboxes[i].score < 0.1) {
+    if (bboxes[i].score < 0.3) {
       continue;
     }
 
diff --git a/demo/python/image_classification.py b/demo/python/image_classification.py
new file mode 100644
index 000000000..aae3f744b
--- /dev/null
+++ b/demo/python/image_classification.py
@@ -0,0 +1,31 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import cv2
+from mmdeploy_python import Classifier
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        'model_path', help='the directory path of mmdeploy model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def main():
+    """Classify one image and print each (label_id, score) pair."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+    classifier = Classifier(args.model_path, args.device_name, 0)
+    for label_id, score in classifier([img])[0]:
+        print(label_id, score)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/python/image_restorer.py b/demo/python/image_restorer.py
new file mode 100644
index 000000000..8b0274c62
--- /dev/null
+++ b/demo/python/image_restorer.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import cv2
+from mmdeploy_python import Restorer
+
+
+def parse_args():
+    """Parse the model directory, image path and device name options."""
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        'model_path', help='the directory path of mmdeploy model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def main():
+    """Run the restorer on one image and save it as output_restorer.bmp."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+
+    restorer = Restorer(args.model_path, args.device_name, 0)
+    # reverse the channel axis (convert to BGR) before writing with OpenCV
+    result = restorer([img])[0][..., ::-1]
+    cv2.imwrite('output_restorer.bmp', result)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/python/image_segmentation.py b/demo/python/image_segmentation.py
new file mode 100644
index 000000000..3c106a565
--- /dev/null
+++ b/demo/python/image_segmentation.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import cv2
+import numpy as np
+from mmdeploy_python import Segmentor
+
+
+def parse_args():
+    """Parse the model directory, image path and device name options."""
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        'model_path', help='the directory path of mmdeploy model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def get_palette(num_classes=256):
+    """Return ``num_classes`` reproducible random colors as 3-tuples."""
+    # A private RandomState seeded with 42 draws the same values as seeding
+    # the global generator with 42, and leaves the global state untouched.
+    rng = np.random.RandomState(42)
+    palette = rng.randint(0, 256, size=(num_classes, 3))
+    return [tuple(c) for c in palette]
+
+
+def main():
+    """Segment one image and save a 50/50 color overlay as a PNG."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+
+    segmentor = Segmentor(args.model_path, args.device_name, 0)
+    seg = segmentor([img])[0]
+
+    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
+    for label, color in enumerate(get_palette()):
+        color_seg[seg == label, :] = color
+    # convert to BGR
+    color_seg = color_seg[..., ::-1]
+
+    img = (img * 0.5 + color_seg * 0.5).astype(np.uint8)
+    cv2.imwrite('output_segmentation.png', img)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/python/object_detection.py b/demo/python/object_detection.py
new file mode 100644
index 000000000..8a9df839e
--- /dev/null
+++ b/demo/python/object_detection.py
@@ -0,0 +1,50 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import cv2
+import numpy as np
+from mmdeploy_python import Detector
+
+
+def parse_args():
+    """Parse the model directory, image path and device name options."""
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        'model_path', help='the directory path of mmdeploy model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def main():
+    """Detect objects in one image; draw boxes/masks scoring >= 0.3."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+    detector = Detector(args.model_path, args.device_name, 0)
+    bboxes, labels, masks = detector([img])[0]
+    assert (isinstance(bboxes, np.ndarray))
+    assert (isinstance(labels, np.ndarray))
+    assert (isinstance(masks, list))
+
+    for index, (bbox, label_id) in enumerate(zip(bboxes, labels)):
+        [left, top, right, bottom], score = bbox[0:4].astype(int), bbox[4]
+        if score < 0.3:
+            continue
+        cv2.rectangle(img, (left, top), (right, bottom), (0, 255, 0))
+        mask = masks[index]
+        if mask.size:
+            # OR the mask into the blue channel at the box's top-left corner
+            blue, green, red = cv2.split(img)
+            mask_img = blue[top:top + mask.shape[0], left:left + mask.shape[1]]
+            cv2.bitwise_or(mask, mask_img, mask_img)
+            img = cv2.merge([blue, green, red])
+
+    cv2.imwrite('output_detection.png', img)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/python/ocr.py b/demo/python/ocr.py
new file mode 100644
index 000000000..b6d2dda08
--- /dev/null
+++ b/demo/python/ocr.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import cv2
+from mmdeploy_python import TextDetector
+
+
+def parse_args():
+    """Parse the optional model directories, image path and device name."""
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        '--textdet',
+        default='',
+        help='the directory path of mmdeploy text-detector sdk model')
+    parser.add_argument(
+        '--textrecog',
+        default='',
+        help='the directory path of mmdeploy text-recognizer sdk model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def main():
+    """Detect text in one image and save it with the polygons drawn."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+    if args.textdet:
+        detector = TextDetector(args.textdet, args.device_name, 0)
+        bboxes = detector([img])[0]
+        # first 8 columns -> (n, 4, 2) int points; +0.5 rounds before the cast
+        pts = (bboxes[:, 0:8] + 0.5).reshape(len(bboxes), -1, 2).astype(int)
+        cv2.polylines(img, pts, True, (0, 255, 0), 2)
+        cv2.imwrite('output_ocr.png', img)
+    if args.textrecog:
+        print('API of TextRecognizer does not support bbox as argument yet')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/python/pose_detection.py b/demo/python/pose_detection.py
new file mode 100644
index 000000000..d5656b5af
--- /dev/null
+++ b/demo/python/pose_detection.py
@@ -0,0 +1,52 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import cv2
+import numpy as np
+from mmdeploy_python import PoseDetector
+
+
+def parse_args():
+    """Parse the model directory, image path, optional bbox and device."""
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        'model_path', help='the directory path of mmdeploy model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--bbox',
+        type=int,
+        nargs=4,
+        help='bounding box of an object in format (x, y, w, h)')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def main():
+    """Run the pose detector on one image and draw the returned points."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+
+    if args.bbox is None:
+        bbox = [0, 0, img.shape[1], img.shape[0]]
+    else:
+        # x, y, w, h -> left, top, right, bottom
+        bbox = np.array(args.bbox, dtype=int)
+        bbox[2:] += bbox[:2]
+
+    detector = PoseDetector(args.model_path, args.device_name, 0)
+    result = detector([img], [[bbox]])[0]
+
+    _, point_num, _ = result.shape
+    points = result[:, :, :2].reshape(point_num, 2)
+    for [x, y] in points.astype(int):
+        cv2.circle(img, (x, y), 1, (0, 255, 0), 2)
+
+    cv2.imwrite('output_pose.png', img)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/demo/python/rotated_object_detection.py b/demo/python/rotated_object_detection.py
new file mode 100644
index 000000000..3ac288cb4
--- /dev/null
+++ b/demo/python/rotated_object_detection.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+from math import cos, sin
+
+import cv2
+import numpy as np
+from mmdeploy_python import RotatedDetector
+
+
+def parse_args():
+    """Parse the model directory, image path and device name options."""
+    parser = argparse.ArgumentParser(
+        description='show how to use sdk python api')
+    parser.add_argument(
+        'model_path', help='the directory path of mmdeploy model')
+    parser.add_argument('image_path', help='the path of an image')
+    parser.add_argument(
+        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+    return parser.parse_args()
+
+
+def main():
+    """Detect rotated boxes in one image and draw those scoring >= 0.1."""
+    args = parse_args()
+    img = cv2.imread(args.image_path)
+    if img is None:
+        raise OSError('failed to read image: ' + args.image_path)
+
+    detector = RotatedDetector(args.model_path, args.device_name, 0)
+    rbboxes, labels = detector([img])[0]
+
+    for rbbox, label_id in zip(rbboxes, labels):
+        [cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1]
+        if score < 0.1:
+            continue
+        # half-extent vectors along the box's width and height directions
+        cos_a, sin_a = cos(angle), sin(angle)
+        wx, wy, hx, hy = 0.5 * np.array([w, w, -h, h]) * np.array(
+            [cos_a, sin_a, sin_a, cos_a])
+        # center +/- half extents gives the four corners, as one contour
+        points = np.array([[[int(cx - wx - hx), int(cy - wy - hy)],
+                            [int(cx + wx - hx), int(cy + wy - hy)],
+                            [int(cx + wx + hx), int(cy + wy + hy)],
+                            [int(cx - wx + hx), int(cy - wy + hy)]]])
+        cv2.drawContours(img, points, -1, (0, 255, 0), 2)
+
+    cv2.imwrite('output_detection.png', img)
+
+
+if __name__ == '__main__':
+    main()