diff --git a/csrc/mmdeploy/apis/python/classifier.cpp b/csrc/mmdeploy/apis/python/classifier.cpp
index cb7188a03..7467a5706 100644
--- a/csrc/mmdeploy/apis/python/classifier.cpp
+++ b/csrc/mmdeploy/apis/python/classifier.cpp
@@ -57,9 +57,12 @@ class PyClassifier {
 static void register_python_classifier(py::module &m) {
   py::class_<PyClassifier>(m, "Classifier")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyClassifier>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyClassifier::Apply);
+             return std::make_unique<PyClassifier>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PyClassifier *self, const PyImage &img) { return self->Apply(std::vector<PyImage>{img})[0]; })
+      .def("batch", &PyClassifier::Apply);
 }
 
 class PythonClassifierRegisterer {
diff --git a/csrc/mmdeploy/apis/python/detector.cpp b/csrc/mmdeploy/apis/python/detector.cpp
index e47fac301..645ec820f 100644
--- a/csrc/mmdeploy/apis/python/detector.cpp
+++ b/csrc/mmdeploy/apis/python/detector.cpp
@@ -68,9 +68,14 @@ class PyDetector {
 static void register_python_detector(py::module &m) {
   py::class_<PyDetector>(m, "Detector")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyDetector>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyDetector::Apply);
+             return std::make_unique<PyDetector>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PyDetector *self, const PyImage &img) -> py::tuple {
+             return self->Apply(std::vector<PyImage>{img})[0];
+           })
+      .def("batch", &PyDetector::Apply);
 }
 
 class PythonDetectorRegisterer {
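Note: for readers skimming the patch, a minimal usage sketch of the bindings above, assuming a build of `mmdeploy_python` that includes these changes; the model directories and image path are placeholders:

```python
import cv2
from mmdeploy_python import Classifier, Detector

img = cv2.imread('demo.jpg')  # placeholder image

# constructor arguments can now be passed by keyword, and device_id defaults to 0
classifier = Classifier(model_path='mmdeploy_models/resnet', device_name='cpu')
for label_id, score in classifier(img):  # __call__ takes a single image
    print(label_id, score)

detector = Detector(model_path='mmdeploy_models/faster-rcnn', device_name='cuda', device_id=0)
bboxes, labels, masks = detector(img)       # per-image result, no [0] indexing needed
batch_results = detector.batch([img, img])  # the old list-in/list-out form lives on as `batch`
```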
diff --git a/csrc/mmdeploy/apis/python/pose_detector.cpp b/csrc/mmdeploy/apis/python/pose_detector.cpp
index b19cc1027..db7617131 100644
--- a/csrc/mmdeploy/apis/python/pose_detector.cpp
+++ b/csrc/mmdeploy/apis/python/pose_detector.cpp
@@ -11,22 +11,22 @@ namespace mmdeploy {
 
 using Rect = std::array<float, 4>;
 
-class PyPoseDedector {
+class PyPoseDetector {
  public:
-  PyPoseDedector(const char *model_path, const char *device_name, int device_id) {
+  PyPoseDetector(const char *model_path, const char *device_name, int device_id) {
    auto status =
        mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, &detector_);
    if (status != MMDEPLOY_SUCCESS) {
      throw std::runtime_error("failed to create pose_detector");
    }
  }
-  py::list Apply(const std::vector<PyImage> &imgs, const std::vector<std::vector<Rect>> &vboxes) {
-    if (imgs.size() == 0 && vboxes.size() == 0) {
+  py::list Apply(const std::vector<PyImage> &imgs, const std::vector<std::vector<Rect>> &bboxes) {
+    if (imgs.size() == 0 && bboxes.size() == 0) {
      return py::list{};
    }
-    if (vboxes.size() != 0 && vboxes.size() != imgs.size()) {
+    if (bboxes.size() != 0 && bboxes.size() != imgs.size()) {
      std::ostringstream os;
-      os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << vboxes.size() << "]";
+      os << "imgs length not equal with bboxes [" << imgs.size() << " vs " << bboxes.size() << "]";
      throw std::invalid_argument(os.str());
    }
 
@@ -39,7 +39,7 @@ class PyPoseDedector {
       mats.push_back(mat);
     }
 
-    for (auto _boxes : vboxes) {
+    for (auto _boxes : bboxes) {
       for (auto _box : _boxes) {
         mmdeploy_rect_t box = {_box[0], _box[1], _box[2], _box[3]};
         boxes.push_back(box);
@@ -48,7 +48,7 @@ class PyPoseDedector {
     }
 
     // full image
-    if (vboxes.size() == 0) {
+    if (bboxes.size() == 0) {
       for (int i = 0; i < mats.size(); i++) {
         mmdeploy_rect_t box = {0.f, 0.f, mats[i].width - 1.f, mats[i].height - 1.f};
         boxes.push_back(box);
@@ -89,7 +89,7 @@ class PyPoseDedector {
     mmdeploy_pose_detector_release_result(detection, total);
     return output;
   }
-  ~PyPoseDedector() {
+  ~PyPoseDetector() {
     mmdeploy_pose_detector_destroy(detector_);
     detector_ = {};
   }
@@ -99,12 +99,33 @@ class PyPoseDedector {
 };
 
 static void register_python_pose_detector(py::module &m) {
-  py::class_<PyPoseDedector>(m, "PoseDetector")
+  py::class_<PyPoseDetector>(m, "PoseDetector")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyPoseDedector>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyPoseDedector::Apply, py::arg("imgs"),
-           py::arg("vboxes") = std::vector<std::vector<Rect>>());
+             return std::make_unique<PyPoseDetector>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PyPoseDetector *self, const PyImage &img) -> py::array {
+             return self->Apply({img}, {})[0];
+           })
+      .def(
+          "__call__",
+          [](PyPoseDetector *self, const PyImage &img, const Rect &box) -> py::array {
+            std::vector<std::vector<Rect>> bboxes;
+            bboxes.push_back({box});
+            return self->Apply({img}, bboxes)[0];
+          },
+          py::arg("img"), py::arg("box"))
+      .def(
+          "__call__",
+          [](PyPoseDetector *self, const PyImage &img, const std::vector<Rect> &bboxes) {
+            std::vector<std::vector<Rect>> _bboxes;
+            _bboxes.push_back(bboxes);
+            return self->Apply({img}, _bboxes);
+          },
+          py::arg("img"), py::arg("bboxes"))
+      .def("batch", &PyPoseDetector::Apply, py::arg("imgs"),
+           py::arg("bboxes") = std::vector<std::vector<Rect>>());
 }
 
 class PythonPoseDetectorRegisterer {
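Note: the three `__call__` overloads registered above accept, respectively, a bare image, an image plus one box, and an image plus a list of boxes; boxes use the `(left, top, right, bottom)` layout of `Rect`/`mmdeploy_rect_t`, as the full-image default above suggests. A sketch with placeholder paths and illustrative coordinates:

```python
import cv2
from mmdeploy_python import PoseDetector

img = cv2.imread('human.jpg')  # placeholder image
detector = PoseDetector(model_path='mmdeploy_models/hrnet', device_name='cpu')  # placeholder path

result = detector(img)                          # whole image used as the only box
result = detector(img, [100, 50, 300, 400])     # one (left, top, right, bottom) box
results = detector(img, [[100, 50, 300, 400],   # several boxes in the same image
                         [320, 60, 500, 420]])
batch = detector.batch([img, img])              # list of images; per-image box lists optional
```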
diff --git a/csrc/mmdeploy/apis/python/restorer.cpp b/csrc/mmdeploy/apis/python/restorer.cpp
index 33ff52f4e..4a345be2d 100644
--- a/csrc/mmdeploy/apis/python/restorer.cpp
+++ b/csrc/mmdeploy/apis/python/restorer.cpp
@@ -49,9 +49,14 @@ class PyRestorer {
 static void register_python_restorer(py::module &m) {
   py::class_<PyRestorer>(m, "Restorer")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyRestorer>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyRestorer::Apply);
+             return std::make_unique<PyRestorer>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PyRestorer *self, const PyImage &img) -> py::array {
+             return self->Apply(std::vector<PyImage>{img})[0];
+           })
+      .def("batch", &PyRestorer::Apply);
 }
 
 class PythonRestorerRegisterer {
diff --git a/csrc/mmdeploy/apis/python/rotated_detector.cpp b/csrc/mmdeploy/apis/python/rotated_detector.cpp
index 1624359ff..df2c9ea7c 100644
--- a/csrc/mmdeploy/apis/python/rotated_detector.cpp
+++ b/csrc/mmdeploy/apis/python/rotated_detector.cpp
@@ -64,9 +64,14 @@ class PyRotatedDetector {
 static void register_python_rotated_detector(py::module &m) {
   py::class_<PyRotatedDetector>(m, "RotatedDetector")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyRotatedDetector::Apply);
+             return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PyRotatedDetector *self, const PyImage &img) -> py::tuple {
+             return self->Apply(std::vector<PyImage>{img})[0];
+           })
+      .def("batch", &PyRotatedDetector::Apply);
 }
 
 class PythonRotatedDetectorRegisterer {
diff --git a/csrc/mmdeploy/apis/python/segmentor.cpp b/csrc/mmdeploy/apis/python/segmentor.cpp
index 459298ec1..2132e4c03 100644
--- a/csrc/mmdeploy/apis/python/segmentor.cpp
+++ b/csrc/mmdeploy/apis/python/segmentor.cpp
@@ -50,9 +50,14 @@ class PySegmentor {
 static void register_python_segmentor(py::module &m) {
   py::class_<PySegmentor>(m, "Segmentor")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PySegmentor>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PySegmentor::Apply);
+             return std::make_unique<PySegmentor>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PySegmentor *self, const PyImage &img) -> py::array {
+             return self->Apply(std::vector<PyImage>{img})[0];
+           })
+      .def("batch", &PySegmentor::Apply);
 }
 
 class PythonSegmentorRegisterer {
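Note: Restorer, RotatedDetector, and Segmentor follow the same pattern as Classifier/Detector: `__call__` now takes one image and returns one result, while `batch` keeps the original list-in/list-out behaviour. A short sketch with placeholder paths and image names:

```python
import cv2
from mmdeploy_python import Segmentor

imgs = [cv2.imread('street_1.jpg'), cv2.imread('street_2.jpg')]  # placeholder images
segmentor = Segmentor(model_path='mmdeploy_models/fcn', device_name='cuda', device_id=0)

seg = segmentor(imgs[0])      # one image in, one label map out
segs = segmentor.batch(imgs)  # list in, list out
assert len(segs) == len(imgs)
```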
diff --git a/csrc/mmdeploy/apis/python/text_detector.cpp b/csrc/mmdeploy/apis/python/text_detector.cpp
index 1181363e7..fb1975370 100644
--- a/csrc/mmdeploy/apis/python/text_detector.cpp
+++ b/csrc/mmdeploy/apis/python/text_detector.cpp
@@ -58,9 +58,14 @@ class PyTextDetector {
 static void register_python_text_detector(py::module &m) {
   py::class_<PyTextDetector>(m, "TextDetector")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyTextDetector>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyTextDetector::Apply);
+             return std::make_unique<PyTextDetector>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__",
+           [](PyTextDetector *self, const PyImage &img) -> py::array {
+             return self->Apply(std::vector<PyImage>{img})[0];
+           })
+      .def("batch", &PyTextDetector::Apply);
 }
 
 class PythonTextDetectorRegisterer {
diff --git a/csrc/mmdeploy/apis/python/text_recognizer.cpp b/csrc/mmdeploy/apis/python/text_recognizer.cpp
index 2ac1fdb97..4b6b13434 100644
--- a/csrc/mmdeploy/apis/python/text_recognizer.cpp
+++ b/csrc/mmdeploy/apis/python/text_recognizer.cpp
@@ -36,6 +36,27 @@ class PyTextRecognizer {
     mmdeploy_text_recognizer_release_result(results, (int)mats.size());
     return output;
   }
+  std::vector<std::tuple<std::string, std::vector<float>>> Apply(const PyImage &img,
+                                                                 const std::vector<float> &bboxes) {
+    if (bboxes.size() * sizeof(float) % sizeof(mmdeploy_text_detection_t)) {
+      throw std::invalid_argument("bboxes is not a list of 'mmdeploy_text_detection_t'");
+    }
+    auto mat = GetMat(img);
+    int bbox_count = bboxes.size() * sizeof(float) / sizeof(mmdeploy_text_detection_t);
+    mmdeploy_text_recognition_t *results{};
+    auto status = mmdeploy_text_recognizer_apply_bbox(
+        recognizer_, &mat, 1, (mmdeploy_text_detection_t *)bboxes.data(), &bbox_count, &results);
+    if (status != MMDEPLOY_SUCCESS) {
+      throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status));
+    }
+    auto output = std::vector<std::tuple<std::string, std::vector<float>>>{};
+    for (int i = 0; i < bbox_count; ++i) {
+      std::vector<float> score(results[i].score, results[i].score + results[i].length);
+      output.emplace_back(results[i].text, std::move(score));
+    }
+    mmdeploy_text_recognizer_release_result(results, bbox_count);
+    return output;
+  }
   ~PyTextRecognizer() {
     mmdeploy_text_recognizer_destroy(recognizer_);
     recognizer_ = {};
@@ -48,9 +69,14 @@ class PyTextRecognizer {
 static void register_python_text_recognizer(py::module &m) {
   py::class_<PyTextRecognizer>(m, "TextRecognizer")
       .def(py::init([](const char *model_path, const char *device_name, int device_id) {
-        return std::make_unique<PyTextRecognizer>(model_path, device_name, device_id);
-      }))
-      .def("__call__", &PyTextRecognizer::Apply);
+             return std::make_unique<PyTextRecognizer>(model_path, device_name, device_id);
+           }),
+           py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
+      .def("__call__", [](PyTextRecognizer *self,
+                          const PyImage &img) { return self->Apply(std::vector<PyImage>{img})[0]; })
+      .def("__call__", [](PyTextRecognizer *self, const PyImage &img,
+                          const std::vector<float> &bboxes) { return self->Apply(img, bboxes); })
+      .def("batch", py::overload_cast<const std::vector<PyImage> &>(&PyTextRecognizer::Apply));
 }
 
 class PythonTextRecognizerRegisterer {
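Note: the second `TextRecognizer.__call__` overload takes the detections as a flat float list and reinterprets it as `mmdeploy_text_detection_t` records, which is why the OCR demo below passes `bboxes.flatten().tolist()`. A sketch of both call forms; paths are placeholders, and the remark about each detection flattening to 9 floats (8 polygon coordinates plus a score) is an assumption about that struct's layout rather than something stated in this patch:

```python
import cv2
from mmdeploy_python import TextDetector, TextRecognizer

img = cv2.imread('text.jpg')  # placeholder image
detector = TextDetector(model_path='mmdeploy_models/dbnet', device_name='cpu')     # placeholder
recognizer = TextRecognizer(model_path='mmdeploy_models/crnn', device_name='cpu')  # placeholder

text, scores = recognizer(img)  # whole image treated as one text region
bboxes = detector(img)          # per-image array; columns 0:8 are the polygon corners
if len(bboxes) > 0:
    # flatten so each detection becomes a run of plain floats (assumed 9 per box)
    results = recognizer(img, bboxes.flatten().tolist())
    for text, scores in results:
        print(text)
```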
diff --git a/demo/python/image_classification.py b/demo/python/image_classification.py
index aae3f744b..9ef5ce103 100644
--- a/demo/python/image_classification.py
+++ b/demo/python/image_classification.py
@@ -8,11 +8,11 @@ from mmdeploy_python import Classifier
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
     parser.add_argument(
-        'model_path', help='the directory path of mmdeploy model')
-    parser.add_argument('image_path', help='the path of an image')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+        'model_path',
+        help='path of mmdeploy SDK model dumped by model converter')
+    parser.add_argument('image_path', help='path of an image')
     args = parser.parse_args()
     return args
 
@@ -21,9 +21,10 @@ def main():
     args = parse_args()
 
     img = cv2.imread(args.image_path)
-    classifier = Classifier(args.model_path, args.device_name, 0)
-    result = classifier([img])
-    for label_id, score in result[0]:
+    classifier = Classifier(
+        model_path=args.model_path, device_name=args.device_name, device_id=0)
+    result = classifier(img)
+    for label_id, score in result:
         print(label_id, score)
 
 
diff --git a/demo/python/image_restorer.py b/demo/python/image_restorer.py
index 8b0274c62..ed10b153f 100644
--- a/demo/python/image_restorer.py
+++ b/demo/python/image_restorer.py
@@ -8,11 +8,10 @@ from mmdeploy_python import Restorer
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
     parser.add_argument(
-        'model_path', help='the directory path of mmdeploy model')
-    parser.add_argument('image_path', help='the path of an image')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+        'model_path', help='path of SDK model dumped by model converter')
+    parser.add_argument('image_path', help='path of an image')
     args = parser.parse_args()
     return args
 
@@ -22,8 +21,9 @@ def main():
 
     img = cv2.imread(args.image_path)
 
-    restorer = Restorer(args.model_path, args.device_name, 0)
-    result = restorer([img])[0]
+    restorer = Restorer(
+        model_path=args.model_path, device_name=args.device_name, device_id=0)
+    result = restorer(img)
 
     # convert to BGR
     result = result[..., ::-1]
diff --git a/demo/python/image_segmentation.py b/demo/python/image_segmentation.py
index 3c106a565..32391f434 100644
--- a/demo/python/image_segmentation.py
+++ b/demo/python/image_segmentation.py
@@ -9,11 +9,11 @@ from mmdeploy_python import Segmentor
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
     parser.add_argument(
-        'model_path', help='the directory path of mmdeploy model')
-    parser.add_argument('image_path', help='the path of an image')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+        'model_path',
+        help='path of mmdeploy SDK model dumped by model converter')
+    parser.add_argument('image_path', help='path of an image')
     args = parser.parse_args()
     return args
 
@@ -32,8 +32,9 @@ def main():
 
     img = cv2.imread(args.image_path)
 
-    segmentor = Segmentor(args.model_path, args.device_name, 0)
-    seg = segmentor([img])[0]
+    segmentor = Segmentor(
+        model_path=args.model_path, device_name=args.device_name, device_id=0)
+    seg = segmentor(img)
 
     palette = get_palette()
     color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
diff --git a/demo/python/object_detection.py b/demo/python/object_detection.py
index 8a9df839e..a584d4dd4 100644
--- a/demo/python/object_detection.py
+++ b/demo/python/object_detection.py
@@ -2,18 +2,17 @@
 import argparse
 
 import cv2
-import numpy as np
 from mmdeploy_python import Detector
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
     parser.add_argument(
-        'model_path', help='the directory path of mmdeploy model')
-    parser.add_argument('image_path', help='the path of an image')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+        'model_path',
+        help='path of mmdeploy SDK model dumped by model converter')
+    parser.add_argument('image_path', help='path of an image')
     args = parser.parse_args()
     return args
 
@@ -22,11 +21,9 @@ def main():
     args = parse_args()
 
     img = cv2.imread(args.image_path)
-    detector = Detector(args.model_path, args.device_name, 0)
-    bboxes, labels, masks = detector([img])[0]
-    assert (isinstance(bboxes, np.ndarray))
-    assert (isinstance(labels, np.ndarray))
-    assert (isinstance(masks, list))
+    detector = Detector(
+        model_path=args.model_path, device_name=args.device_name, device_id=0)
+    bboxes, labels, masks = detector(img)
 
     indices = [i for i in range(len(bboxes))]
     for index, bbox, label_id in zip(indices, bboxes, labels):
diff --git a/demo/python/ocr.py b/demo/python/ocr.py
index b6d2dda08..6f02b5b04 100644
--- a/demo/python/ocr.py
+++ b/demo/python/ocr.py
@@ -2,23 +2,26 @@
 import argparse
 
 import cv2
-from mmdeploy_python import TextDetector
+from mmdeploy_python import TextDetector, TextRecognizer
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
+    parser.add_argument('image_path', help='path of an image')
     parser.add_argument(
         '--textdet',
         default='',
-        help='the directory path of mmdeploy text-detector sdk model')
+        help='path of mmdeploy text-detector SDK model dumped by '
+        'model converter',
+    )
     parser.add_argument(
         '--textrecog',
         default='',
-        help='the directory path of mmdeploy text-recognizer sdk model')
-    parser.add_argument('image_path', help='the path of an image')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+        help='path of mmdeploy text-recognizer SDK model dumped by '
+        'model converter',
+    )
     args = parser.parse_args()
     return args
 
@@ -29,15 +32,37 @@ def main():
     img = cv2.imread(args.image_path)
 
     if args.textdet:
-        detector = TextDetector(args.textdet, args.device_name, 0)
-        bboxes = detector([img])[0]
+        detector = TextDetector(
+            model_path=args.textdet, device_name=args.device_name, device_id=0)
+        bboxes = detector(img)
+        print(f'bboxes.shape={bboxes.shape}')
+        print(f'bboxes={bboxes}')
+        if len(bboxes) > 0:
+            pts = ((bboxes[:, 0:8] + 0.5).reshape(len(bboxes), -1,
+                                                  2).astype(int))
+            cv2.polylines(img, pts, True, (0, 255, 0), 2)
+            cv2.imwrite('output_ocr.png', img)
 
-        pts = (bboxes[:, 0:8] + 0.5).reshape(len(bboxes), -1, 2).astype(int)
-        cv2.polylines(img, pts, True, (0, 255, 0), 2)
-        cv2.imwrite('output_ocr.png', img)
+        if len(bboxes) > 0 and args.textrecog:
+            recognizer = TextRecognizer(
+                model_path=args.textrecog,
+                device_name=args.device_name,
+                device_id=0,
+            )
+            texts = recognizer(img, bboxes.flatten().tolist())
+            print(texts)
 
-    if args.textrecog:
-        print('API of TextRecognizer does not support bbox as argument yet')
+    elif args.textrecog:
+        recognizer = TextRecognizer(
+            model_path=args.textrecog,
+            device_name=args.device_name,
+            device_id=0,
+        )
+        texts = recognizer(img)
+        print(texts)
+    else:
+        print('do nothing since neither a text detection SDK model nor '
+              'a text recognition SDK model was provided')
 
 
 if __name__ == '__main__':
diff --git a/demo/python/pose_detection.py b/demo/python/pose_detection.py
index d5656b5af..2eebd12bb 100644
--- a/demo/python/pose_detection.py
+++ b/demo/python/pose_detection.py
@@ -9,16 +9,17 @@ from mmdeploy_python import PoseDetector
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
     parser.add_argument(
-        'model_path', help='the directory path of mmdeploy model')
-    parser.add_argument('image_path', help='the path of an image')
+        'model_path',
+        help='path of mmdeploy SDK model dumped by model converter')
+    parser.add_argument('image_path', help='path of an image')
     parser.add_argument(
         '--bbox',
         default=None,
         nargs='+',
+        type=int,
         help='bounding box of an object in format (x, y, w, h)')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
     args = parser.parse_args()
     return args
 
@@ -28,17 +29,18 @@ def main():
 
     img = cv2.imread(args.image_path)
 
-    bboxes = []
+    detector = PoseDetector(
+        model_path=args.model_path, device_name=args.device_name, device_id=0)
+
     if args.bbox is None:
-        bbox = [0, 0, img.shape[1], img.shape[0]]
+        result = detector(img)
     else:
-        # x, y, w, h -> left, top, right, bottom
+        # convert (x, y, w, h) -> (left, top, right, bottom)
+        print(args.bbox)
         bbox = np.array(args.bbox, dtype=int)
         bbox[2:] += bbox[:2]
-    bboxes.append(bbox)
-
-    detector = PoseDetector(args.model_path, args.device_name, 0)
-    result = detector([img], [bboxes])[0]
+        result = detector(img, bbox)
+    print(result)
 
     _, point_num, _ = result.shape
     points = result[:, :, :2].reshape(point_num, 2)
diff --git a/demo/python/rotated_object_detection.py b/demo/python/rotated_object_detection.py
index 3ac288cb4..4f02d5d1f 100644
--- a/demo/python/rotated_object_detection.py
+++ b/demo/python/rotated_object_detection.py
@@ -10,11 +10,10 @@ from mmdeploy_python import RotatedDetector
 def parse_args():
     parser = argparse.ArgumentParser(
         description='show how to use sdk python api')
+    parser.add_argument('device_name', help='name of device, cuda or cpu')
     parser.add_argument(
-        'model_path', help='the directory path of mmdeploy model')
-    parser.add_argument('image_path', help='the path of an image')
-    parser.add_argument(
-        '--device-name', default='cpu', help='the name of device, cuda or cpu')
+        'model_path', help='path of SDK model dumped by model converter')
+    parser.add_argument('image_path', help='path of an image')
     args = parser.parse_args()
     return args
 
@@ -23,9 +22,10 @@ def main():
     args = parse_args()
 
     img = cv2.imread(args.image_path)
-    detector = RotatedDetector(args.model_path, args.device_name, 0)
-    rbboxes, labels = detector([img])[0]
-    # print(rbboxes, labels)
+    detector = RotatedDetector(
+        model_path=args.model_path, device_name=args.device_name, device_id=0)
+    rbboxes, labels = detector(img)
+
     indices = [i for i in range(len(rbboxes))]
     for index, rbbox, label_id in zip(indices, rbboxes, labels):
         [cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1]
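Note: the pose demo's `--bbox` handling converts from the CLI's `(x, y, w, h)` to the corner form the binding expects, as noted in its comment. A minimal standalone sketch of that step with illustrative numbers:

```python
import numpy as np

# --bbox arrives as (x, y, w, h); PoseDetector wants (left, top, right, bottom)
bbox = np.array([50, 40, 200, 300], dtype=int)
bbox[2:] += bbox[:2]
print(bbox)  # [ 50  40 250 340]
```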