[Enhancement] support kwargs in SDK python bindings (#794)

* support-kwargs

* make '__call__' as single image inference and add 'batch' API to deal with batch images inference

* fix linting error and typo

* fix lint
This commit is contained in:
lvhan028 2022-07-28 21:32:42 -07:00 committed by GitHub
parent f80c90ed47
commit 2c18fbd2c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 196 additions and 95 deletions

View File

@ -57,9 +57,12 @@ class PyClassifier {
static void register_python_classifier(py::module &m) {
py::class_<PyClassifier>(m, "Classifier")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyClassifier>(model_path, device_name, device_id);
}))
.def("__call__", &PyClassifier::Apply);
return std::make_unique<PyClassifier>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyClassifier *self, const PyImage &img) { return self->Apply(std::vector{img})[0]; })
.def("batch", &PyClassifier::Apply);
}
class PythonClassifierRegisterer {

View File

@ -68,9 +68,14 @@ class PyDetector {
static void register_python_detector(py::module &m) {
py::class_<PyDetector>(m, "Detector")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyDetector>(model_path, device_name, device_id);
}))
.def("__call__", &PyDetector::Apply);
return std::make_unique<PyDetector>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyDetector *self, const PyImage &img) -> py::tuple {
return self->Apply(std::vector{img})[0];
})
.def("batch", &PyDetector::Apply);
}
class PythonDetectorRegisterer {

View File

@ -11,22 +11,22 @@ namespace mmdeploy {
using Rect = std::array<float, 4>;
class PyPoseDedector {
class PyPoseDetector {
public:
PyPoseDedector(const char *model_path, const char *device_name, int device_id) {
PyPoseDetector(const char *model_path, const char *device_name, int device_id) {
auto status =
mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, &detector_);
if (status != MMDEPLOY_SUCCESS) {
throw std::runtime_error("failed to create pose_detector");
}
}
py::list Apply(const std::vector<PyImage> &imgs, const std::vector<std::vector<Rect>> &vboxes) {
if (imgs.size() == 0 && vboxes.size() == 0) {
py::list Apply(const std::vector<PyImage> &imgs, const std::vector<std::vector<Rect>> &bboxes) {
if (imgs.size() == 0 && bboxes.size() == 0) {
return py::list{};
}
if (vboxes.size() != 0 && vboxes.size() != imgs.size()) {
if (bboxes.size() != 0 && bboxes.size() != imgs.size()) {
std::ostringstream os;
os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << vboxes.size() << "]";
os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << bboxes.size() << "]";
throw std::invalid_argument(os.str());
}
@ -39,7 +39,7 @@ class PyPoseDedector {
mats.push_back(mat);
}
for (auto _boxes : vboxes) {
for (auto _boxes : bboxes) {
for (auto _box : _boxes) {
mmdeploy_rect_t box = {_box[0], _box[1], _box[2], _box[3]};
boxes.push_back(box);
@ -48,7 +48,7 @@ class PyPoseDedector {
}
// full image
if (vboxes.size() == 0) {
if (bboxes.size() == 0) {
for (int i = 0; i < mats.size(); i++) {
mmdeploy_rect_t box = {0.f, 0.f, mats[i].width - 1.f, mats[i].height - 1.f};
boxes.push_back(box);
@ -89,7 +89,7 @@ class PyPoseDedector {
mmdeploy_pose_detector_release_result(detection, total);
return output;
}
~PyPoseDedector() {
~PyPoseDetector() {
mmdeploy_pose_detector_destroy(detector_);
detector_ = {};
}
@ -99,12 +99,33 @@ class PyPoseDedector {
};
static void register_python_pose_detector(py::module &m) {
py::class_<PyPoseDedector>(m, "PoseDetector")
py::class_<PyPoseDetector>(m, "PoseDetector")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyPoseDedector>(model_path, device_name, device_id);
}))
.def("__call__", &PyPoseDedector::Apply, py::arg("imgs"),
py::arg("vboxes") = std::vector<std::vector<Rect>>());
return std::make_unique<PyPoseDetector>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyPoseDetector *self, const PyImage &img) -> py::array {
return self->Apply({img}, {})[0];
})
.def(
"__call__",
[](PyPoseDetector *self, const PyImage &img, const Rect &box) -> py::array {
std::vector<std::vector<Rect>> bboxes;
bboxes.push_back({box});
return self->Apply({img}, bboxes)[0];
},
py::arg("img"), py::arg("box"))
.def(
"__call__",
[](PyPoseDetector *self, const PyImage &img, const std::vector<Rect> &bboxes) {
std::vector<std::vector<Rect>> _bboxes;
_bboxes.push_back(bboxes);
return self->Apply({img}, _bboxes);
},
py::arg("img"), py::arg("bboxes"))
.def("batch", &PyPoseDetector::Apply, py::arg("imgs"),
py::arg("bboxes") = std::vector<std::vector<Rect>>());
}
class PythonPoseDetectorRegisterer {

View File

@ -49,9 +49,14 @@ class PyRestorer {
static void register_python_restorer(py::module &m) {
py::class_<PyRestorer>(m, "Restorer")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyRestorer>(model_path, device_name, device_id);
}))
.def("__call__", &PyRestorer::Apply);
return std::make_unique<PyRestorer>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyRestorer *self, const PyImage &img) -> py::array {
return self->Apply(std::vector{img})[0];
})
.def("batch", &PyRestorer::Apply);
}
class PythonRestorerRegisterer {

View File

@ -64,9 +64,14 @@ class PyRotatedDetector {
static void register_python_rotated_detector(py::module &m) {
py::class_<PyRotatedDetector>(m, "RotatedDetector")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
}))
.def("__call__", &PyRotatedDetector::Apply);
return std::make_unique<PyRotatedDetector>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyRotatedDetector *self, const PyImage &img) -> py::tuple {
return self->Apply(std::vector{img})[0];
})
.def("batch", &PyRotatedDetector::Apply);
}
class PythonRotatedDetectorRegisterer {

View File

@ -50,9 +50,14 @@ class PySegmentor {
static void register_python_segmentor(py::module &m) {
py::class_<PySegmentor>(m, "Segmentor")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PySegmentor>(model_path, device_name, device_id);
}))
.def("__call__", &PySegmentor::Apply);
return std::make_unique<PySegmentor>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PySegmentor *self, const PyImage &img) -> py::array {
return self->Apply(std::vector{img})[0];
})
.def("batch", &PySegmentor::Apply);
}
class PythonSegmentorRegisterer {

View File

@ -58,9 +58,14 @@ class PyTextDetector {
static void register_python_text_detector(py::module &m) {
py::class_<PyTextDetector>(m, "TextDetector")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyTextDetector>(model_path, device_name, device_id);
}))
.def("__call__", &PyTextDetector::Apply);
return std::make_unique<PyTextDetector>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__",
[](PyTextDetector *self, const PyImage &img) -> py::array {
return self->Apply(std::vector{img})[0];
})
.def("batch", &PyTextDetector::Apply);
}
class PythonTextDetectorRegisterer {

View File

@ -36,6 +36,27 @@ class PyTextRecognizer {
mmdeploy_text_recognizer_release_result(results, (int)mats.size());
return output;
}
std::vector<std::tuple<std::string, std::vector<float>>> Apply(const PyImage &img,
const std::vector<float> &bboxes) {
if (bboxes.size() * sizeof(float) % sizeof(mmdeploy_text_detection_t)) {
throw std::invalid_argument("bboxes is not a list of 'mmdeploy_text_detection_t'");
}
auto mat = GetMat(img);
int bbox_count = bboxes.size() * sizeof(float) / sizeof(mmdeploy_text_detection_t);
mmdeploy_text_recognition_t *results{};
auto status = mmdeploy_text_recognizer_apply_bbox(
recognizer_, &mat, 1, (mmdeploy_text_detection_t *)bboxes.data(), &bbox_count, &results);
if (status != MMDEPLOY_SUCCESS) {
throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status));
}
auto output = std::vector<std::tuple<std::string, std::vector<float>>>{};
for (int i = 0; i < bbox_count; ++i) {
std::vector<float> score(results[i].score, results[i].score + results[i].length);
output.emplace_back(results[i].text, std::move(score));
}
mmdeploy_text_recognizer_release_result(results, bbox_count);
return output;
}
~PyTextRecognizer() {
mmdeploy_text_recognizer_destroy(recognizer_);
recognizer_ = {};
@ -48,9 +69,14 @@ class PyTextRecognizer {
static void register_python_text_recognizer(py::module &m) {
py::class_<PyTextRecognizer>(m, "TextRecognizer")
.def(py::init([](const char *model_path, const char *device_name, int device_id) {
return std::make_unique<PyTextRecognizer>(model_path, device_name, device_id);
}))
.def("__call__", &PyTextRecognizer::Apply);
return std::make_unique<PyTextRecognizer>(model_path, device_name, device_id);
}),
py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0)
.def("__call__", [](PyTextRecognizer *self,
const PyImage &img) { return self->Apply(std::vector{img})[0]; })
.def("__call__", [](PyTextRecognizer *self, const PyImage &img,
const std::vector<float> &bboxes) { return self->Apply(img, bboxes); })
.def("batch", py::overload_cast<const std::vector<PyImage> &>(&PyTextRecognizer::Apply));
}
class PythonTextRecognizerRegisterer {

View File

@ -8,11 +8,11 @@ from mmdeploy_python import Classifier
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument(
'model_path', help='the directory path of mmdeploy model')
parser.add_argument('image_path', help='the path of an image')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
'model_path',
help='path of mmdeploy SDK model dumped by model converter')
parser.add_argument('image_path', help='path of an image')
args = parser.parse_args()
return args
@ -21,9 +21,10 @@ def main():
args = parse_args()
img = cv2.imread(args.image_path)
classifier = Classifier(args.model_path, args.device_name, 0)
result = classifier([img])
for label_id, score in result[0]:
classifier = Classifier(
model_path=args.model_path, device_name=args.device_name, device_id=0)
result = classifier(img)
for label_id, score in result:
print(label_id, score)

View File

@ -8,11 +8,10 @@ from mmdeploy_python import Restorer
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument(
'model_path', help='the directory path of mmdeploy model')
parser.add_argument('image_path', help='the path of an image')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
'model_path', help='path of SDK model dumped by model converter')
parser.add_argument('image_path', help='path of an image')
args = parser.parse_args()
return args
@ -22,8 +21,9 @@ def main():
img = cv2.imread(args.image_path)
restorer = Restorer(args.model_path, args.device_name, 0)
result = restorer([img])[0]
restorer = Restorer(
model_path=args.model_path, device_name=args.device_name, device_id=0)
result = restorer(img)
# convert to BGR
result = result[..., ::-1]

View File

@ -9,11 +9,11 @@ from mmdeploy_python import Segmentor
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument(
'model_path', help='the directory path of mmdeploy model')
parser.add_argument('image_path', help='the path of an image')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
'model_path',
help='path of mmdeploy SDK model dumped by model converter')
parser.add_argument('image_path', help='path of an image')
args = parser.parse_args()
return args
@ -32,8 +32,9 @@ def main():
img = cv2.imread(args.image_path)
segmentor = Segmentor(args.model_path, args.device_name, 0)
seg = segmentor([img])[0]
segmentor = Segmentor(
model_path=args.model_path, device_name=args.device_name, device_id=0)
seg = segmentor(img)
palette = get_palette()
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)

View File

@ -2,18 +2,17 @@
import argparse
import cv2
import numpy as np
from mmdeploy_python import Detector
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument(
'model_path', help='the directory path of mmdeploy model')
parser.add_argument('image_path', help='the path of an image')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
'model_path',
help='path of mmdeploy SDK model dumped by model converter')
parser.add_argument('image_path', help='path of an image')
args = parser.parse_args()
return args
@ -22,11 +21,9 @@ def main():
args = parse_args()
img = cv2.imread(args.image_path)
detector = Detector(args.model_path, args.device_name, 0)
bboxes, labels, masks = detector([img])[0]
assert (isinstance(bboxes, np.ndarray))
assert (isinstance(labels, np.ndarray))
assert (isinstance(masks, list))
detector = Detector(
model_path=args.model_path, device_name=args.device_name, device_id=0)
bboxes, labels, masks = detector(img)
indices = [i for i in range(len(bboxes))]
for index, bbox, label_id in zip(indices, bboxes, labels):

View File

@ -2,23 +2,26 @@
import argparse
import cv2
from mmdeploy_python import TextDetector
from mmdeploy_python import TextDetector, TextRecognizer
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument('image_path', help='path of an image')
parser.add_argument(
'--textdet',
default='',
help='the directory path of mmdeploy text-detector sdk model')
help='path of mmdeploy text-detector SDK model dumped by'
'model converter',
)
parser.add_argument(
'--textrecog',
default='',
help='the directory path of mmdeploy text-recognizer sdk model')
parser.add_argument('image_path', help='the path of an image')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
help='path of mmdeploy text-recognizer SDK model dumped by'
'model converter',
)
args = parser.parse_args()
return args
@ -29,15 +32,37 @@ def main():
img = cv2.imread(args.image_path)
if args.textdet:
detector = TextDetector(args.textdet, args.device_name, 0)
bboxes = detector([img])[0]
detector = TextDetector(
model_path=args.textdet, device_name=args.device_name, device_id=0)
bboxes = detector(img)
print(f'bboxes.shape={bboxes.shape}')
print(f'bboxes={bboxes}')
if len(bboxes) > 0:
pts = ((bboxes[:, 0:8] + 0.5).reshape(len(bboxes), -1,
2).astype(int))
cv2.polylines(img, pts, True, (0, 255, 0), 2)
cv2.imwrite('output_ocr.png', img)
pts = (bboxes[:, 0:8] + 0.5).reshape(len(bboxes), -1, 2).astype(int)
cv2.polylines(img, pts, True, (0, 255, 0), 2)
cv2.imwrite('output_ocr.png', img)
if len(bboxes) > 0 and args.textrecog:
recognizer = TextRecognizer(
model_path=args.textrecog,
device_name=args.device_name,
device_id=0,
)
texts = recognizer(img, bboxes.flatten().tolist())
print(texts)
if args.textrecog:
print('API of TextRecognizer does not support bbox as argument yet')
elif args.textrecog:
recognizer = TextRecognizer(
model_path=args.textrecog,
device_name=args.device_name,
device_id=0,
)
texts = recognizer(img)
print(texts)
else:
print('do nothing since neither text detection sdk model or '
'text recognition sdk model in input')
if __name__ == '__main__':

View File

@ -9,16 +9,17 @@ from mmdeploy_python import PoseDetector
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument(
'model_path', help='the directory path of mmdeploy model')
parser.add_argument('image_path', help='the path of an image')
'model_path',
help='path of mmdeploy SDK model dumped by model converter')
parser.add_argument('image_path', help='path of an image')
parser.add_argument(
'--bbox',
default=None,
nargs='+',
type=int,
help='bounding box of an object in format (x, y, w, h)')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
args = parser.parse_args()
return args
@ -28,17 +29,18 @@ def main():
img = cv2.imread(args.image_path)
bboxes = []
detector = PoseDetector(
model_path=args.model_path, device_name=args.device_name, device_id=0)
if args.bbox is None:
bbox = [0, 0, img.shape[1], img.shape[0]]
result = detector(img)
else:
# x, y, w, h -> left, top, right, bottom
# converter (x, y, w, h) -> (left, top, right, bottom)
print(args.bbox)
bbox = np.array(args.bbox, dtype=int)
bbox[2:] += bbox[:2]
bboxes.append(bbox)
detector = PoseDetector(args.model_path, args.device_name, 0)
result = detector([img], [bboxes])[0]
result = detector(img, bbox)
print(result)
_, point_num, _ = result.shape
points = result[:, :, :2].reshape(point_num, 2)

View File

@ -10,11 +10,10 @@ from mmdeploy_python import RotatedDetector
def parse_args():
parser = argparse.ArgumentParser(
description='show how to use sdk python api')
parser.add_argument('device_name', help='name of device, cuda or cpu')
parser.add_argument(
'model_path', help='the directory path of mmdeploy model')
parser.add_argument('image_path', help='the path of an image')
parser.add_argument(
'--device-name', default='cpu', help='the name of device, cuda or cpu')
'model_path', help='path of SDK model dumped by model converter')
parser.add_argument('image_path', help='path of an image')
args = parser.parse_args()
return args
@ -23,9 +22,10 @@ def main():
args = parse_args()
img = cv2.imread(args.image_path)
detector = RotatedDetector(args.model_path, args.device_name, 0)
rbboxes, labels = detector([img])[0]
# print(rbboxes, labels)
detector = RotatedDetector(
model_path=args.model_path, device_name=args.device_name, device_id=0)
rbboxes, labels = detector(img)
indices = [i for i in range(len(rbboxes))]
for index, rbbox, label_id in zip(indices, rbboxes, labels):
[cx, cy, w, h, angle], score = rbbox[0:5], rbbox[-1]