[Enhancement] Optimize C++ demos (#1715)

* optimize demos

* show text in image

* optimize demos

* fix minor

* fix minor

* fix minor

* install utils & fix demo file extensions

* rename

* parse empty flags

* antialias

* handle video complications

(cherry picked from commit 2b18596795b4028e175e0d80a5385efab317409f)
This commit is contained in:
Li Zhang 2023-02-07 19:08:46 +08:00 committed by zhangli
parent 1f56eea807
commit 682cb79bc5
20 changed files with 1493 additions and 537 deletions

View File

@ -36,7 +36,7 @@ typedef struct mmdeploy_pose_tracker_param_t {
int32_t pose_max_num_bboxes;
// threshold for visible key-points, default = 0.5
float pose_kpt_thr;
// min number of key-points for valid poses, default = -1
// min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1
int32_t pose_min_keypoints;
// scale for expanding key-points to bbox, default = 1.25
float pose_bbox_scale;

View File

@ -24,4 +24,5 @@ install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/mmdeploy/common.hpp
install(DIRECTORY ${CMAKE_SOURCE_DIR}/demo/csrc/ DESTINATION example/cpp
FILES_MATCHING
PATTERN "*.cxx"
PATTERN "*.h"
)

View File

@ -22,7 +22,7 @@ using Points = vector<cv::Point2f>;
using Score = float;
using Scores = vector<float>;
#define POSE_TRACKER_DEBUG(...) MMDEPLOY_INFO(__VA_ARGS__)
#define POSE_TRACKER_DEBUG(...) MMDEPLOY_DEBUG(__VA_ARGS__)
// opencv3 can't construct cv::Mat from std::array
template <size_t N>

View File

@ -1,31 +1,43 @@
#include "mmdeploy/classifier.hpp"
#include <string>
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "utils/argparse.h"
#include "utils/visualize.h"
DEFINE_ARG_string(model, "Model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "classifier_output.jpg", "Output image path");
int main(int argc, char* argv[]) {
if (argc != 4) {
fprintf(stderr, "usage:\n image_classification device_name model_path image_path\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto image_path = argv[3];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
mmdeploy::Model model(model_path);
mmdeploy::Classifier classifier(model, mmdeploy::Device{device_name, 0});
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
auto res = classifier.Apply(img);
// construct a classifier instance
mmdeploy::Classifier classifier(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device});
for (const auto& cls : res) {
fprintf(stderr, "label: %d, score: %.4f\n", cls.label_id, cls.score);
// apply the classifier; the result is an array-like class holding references to
// `mmdeploy_classification_t`, will be released automatically on destruction
mmdeploy::Classifier::Result result = classifier.Apply(img);
// visualize results
utils::Visualize v;
auto sess = v.get_session(img);
int count = 0;
for (const mmdeploy_classification_t& cls : result) {
sess.add_label(cls.label_id, cls.score, count++);
}
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, sess.get());
}
return 0;

View File

@ -1,113 +0,0 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <iostream>
#include "mmdeploy/detector.hpp"
#include "mmdeploy/pose_detector.hpp"
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "opencv2/imgproc/imgproc.hpp"
using std::vector;
cv::Mat Visualize(cv::Mat frame, const std::vector<mmdeploy_pose_detection_t>& poses, int size);
// Entry point of the detection + pose estimation demo.
//
// Expected command line: device_name det_model_path pose_model_path image_path
// Runs the object detector on the image, keeps detections with label 0 and score > 0.6,
// runs the pose detector on the kept boxes and writes the visualization to
// "det_pose_output.jpg". Returns 0 on success and -1 on bad usage or unreadable image.
int main(int argc, char* argv[]) {
  // validate the argument count BEFORE reading argv; argv[1..4] is out of bounds otherwise
  if (argc != 5) {
    std::cerr << "usage:\n\tpose_tracker device_name det_model_path pose_model_path image_path\n";
    return -1;
  }
  const auto device_name = argv[1];
  const auto det_model_path = argv[2];
  const auto pose_model_path = argv[3];
  const auto image_path = argv[4];

  auto img = cv::imread(image_path);
  if (!img.data) {
    std::cerr << "failed to load image: " << image_path << "\n";
    return -1;
  }

  using namespace mmdeploy;
  Context context(Device{device_name});                // create context for holding the device handle
  Detector detector(Model(det_model_path), context);   // create object detector
  PoseDetector pose(Model(pose_model_path), context);  // create pose detector

  // apply detector
  auto dets = detector.Apply(img);

  // filter detections and extract bboxes for pose model
  std::vector<mmdeploy_rect_t> bboxes;
  for (const auto& det : dets) {
    if (det.label_id == 0 && det.score > .6f) {
      bboxes.push_back(det.bbox);
    }
  }

  // apply pose detector
  auto poses = pose.Apply(img, bboxes);

  // visualize (long edge of the output is 1280 pixels)
  auto vis = Visualize(img, {poses.begin(), poses.end()}, 1280);
  cv::imwrite("det_pose_output.jpg", vis);
  return 0;
}
// Description of how a pose is drawn: which key-points are linked and in which color.
// NOTE: the member order is relied upon by the brace initializer in gCocoSkeleton().
struct Skeleton {
// pairs of key-point indices; each pair is drawn as one line segment
vector<std::pair<int, int>> skeleton;
// colors referenced by `link_color` and `point_color`
vector<cv::Scalar> palette;
// index into `palette` for each entry of `skeleton`
vector<int> link_color;
// index into `palette` for each key-point
vector<int> point_color;
};
// Returns the predefined skeleton used by Visualize(); the instance is a function-local
// static, so it is built once and shared by all callers.
const Skeleton& gCocoSkeleton() {
static const Skeleton inst{
// skeleton: 19 links between key-point indices in the range [0, 16]
{
{15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12},
{5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1},
{0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6},
},
// palette: 20 colors (presumably in OpenCV's BGR channel order -- TODO confirm)
{
{255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255},
{153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255},
{255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102},
{51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255},
},
// link_color: one palette index per link (19 entries)
{0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16},
// point_color: one palette index per key-point (17 entries)
{16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0},
};
return inst;
}
// Draws the skeleton of every pose in `poses` and returns the annotated image.
//
// frame: input image; drawing is done on a resized copy (or a clone when no resize is needed)
// poses: pose estimation results; key-point coordinates are given in `frame` pixels
// size:  long edge of the returned image, key-points are scaled accordingly
//
// Links are drawn only when both end points score above 0.5, and only key-points that take
// part in a drawn link are marked. Links or key-points that a pose does not provide (fewer
// key-points than the skeleton references) or that have no color assigned are skipped instead
// of being read out of bounds.
cv::Mat Visualize(cv::Mat frame, const vector<mmdeploy_pose_detection_t>& poses, int size) {
  const auto& [skeleton, palette, link_color, point_color] = gCocoSkeleton();
  const auto scale = (float)size / (float)std::max(frame.cols, frame.rows);
  if (scale != 1) {
    cv::resize(frame, frame, {}, scale, scale);
  } else {
    frame = frame.clone();
  }
  constexpr auto score_thr = .5f;
  for (const auto& pose : poses) {
    // an empty pose has nothing to draw; also avoids forming point[-1]
    if (pose.length <= 0) {
      continue;
    }
    const int n_kpts = pose.length;
    // scaled pixel location of key-point `k` (truncated to integer coordinates)
    auto location = [&](int k) {
      return cv::Point(static_cast<int>(pose.point[k].x * scale),
                       static_cast<int>(pose.point[k].y * scale));
    };
    vector<int> used(n_kpts);
    for (size_t i = 0; i < skeleton.size(); ++i) {
      const auto [u, v] = skeleton[i];
      // the skeleton may reference key-points this pose does not have
      if (u >= n_kpts || v >= n_kpts) {
        continue;
      }
      if (pose.score[u] > score_thr && pose.score[v] > score_thr) {
        used[u] = used[v] = 1;
        cv::line(frame, location(u), location(v), palette[link_color[i]], 1, cv::LINE_AA);
      }
    }
    // only key-points with an assigned color can be drawn
    const int n_colored = static_cast<int>(point_color.size());
    for (int k = 0; k < n_kpts && k < n_colored; ++k) {
      if (used[k]) {
        cv::circle(frame, location(k), 1, palette[point_color[k]], 2, cv::LINE_AA);
      }
    }
  }
  return frame;
}

View File

@ -0,0 +1,75 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include <iostream>
#include "mmdeploy/detector.hpp"
#include "mmdeploy/pose_detector.hpp"
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "utils/argparse.h"
#include "utils/visualize.h"
DEFINE_ARG_string(det_model, "Object detection model path");
DEFINE_ARG_string(pose_model, "Pose estimation model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "det_pose_output.jpg", "Output image path");
DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")");
DEFINE_int32(det_label, 0, "Detection label use for pose estimation");
DEFINE_double(det_thr, .5, "Detection score threshold");
DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size");
DEFINE_double(pose_thr, 0, "Pose key-point threshold");
int main(int argc, char* argv[]) {
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
mmdeploy::Device device{FLAGS_device};
// create object detector
mmdeploy::Detector detector(mmdeploy::Model(ARGS_det_model), device);
// create pose detector
mmdeploy::PoseDetector pose(mmdeploy::Model(ARGS_pose_model), device);
// apply the detector, the result is an array-like class holding references to
// `mmdeploy_detection_t`, will be released automatically on destruction
mmdeploy::Detector::Result dets = detector.Apply(img);
// filter detections and extract bboxes for pose model
std::vector<mmdeploy_rect_t> bboxes;
for (const mmdeploy_detection_t& det : dets) {
if (det.label_id == FLAGS_det_label && det.score > FLAGS_det_thr) {
bboxes.push_back(det.bbox);
}
}
// apply pose detector, if no bboxes are provided, full image will be used; the result is an
// array-like class holding references to `mmdeploy_pose_detection_t`, will be released
// automatically on destruction
mmdeploy::PoseDetector::Result poses = pose.Apply(img, bboxes);
assert(bboxes.size() == poses.size());
// visualize results
utils::Visualize v;
v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton));
auto sess = v.get_session(img);
for (size_t i = 0; i < bboxes.size(); ++i) {
sess.add_bbox(bboxes[i], -1, -1);
sess.add_pose(poses[i].point, poses[i].score, poses[i].length, FLAGS_pose_thr);
}
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, sess.get());
}
return 0;
}

View File

@ -1,69 +1,47 @@
#include "mmdeploy/detector.hpp"
#include <opencv2/imgcodecs/imgcodecs.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <string>
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "utils/argparse.h"
#include "utils/visualize.h"
DEFINE_ARG_string(model, "Model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "detector_output.jpg", "Output image path");
DEFINE_double(det_thr, .5, "Detection score threshold");
int main(int argc, char* argv[]) {
if (argc != 4) {
fprintf(stderr, "usage:\n object_detection device_name model_path image_path\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto image_path = argv[3];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
mmdeploy::Model model(model_path);
mmdeploy::Detector detector(model, mmdeploy::Device{device_name, 0});
auto dets = detector.Apply(img);
fprintf(stdout, "bbox_count=%d\n", (int)dets.size());
for (int i = 0; i < dets.size(); ++i) {
const auto& box = dets[i].bbox;
const auto& mask = dets[i].mask;
fprintf(stdout, "box %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n",
i, box.left, box.top, box.right, box.bottom, dets[i].label_id, dets[i].score);
// skip detections with invalid bbox size (bbox height or width < 1)
if ((box.right - box.left) < 1 || (box.bottom - box.top) < 1) {
continue;
}
// skip detections less than specified score threshold
if (dets[i].score < 0.3) {
continue;
}
// generate mask overlay if model exports masks
if (mask != nullptr) {
fprintf(stdout, "mask %d, height=%d, width=%d\n", i, mask->height, mask->width);
cv::Mat imgMask(mask->height, mask->width, CV_8UC1, &mask->data[0]);
auto x0 = std::max(std::floor(box.left) - 1, 0.f);
auto y0 = std::max(std::floor(box.top) - 1, 0.f);
cv::Rect roi((int)x0, (int)y0, mask->width, mask->height);
// split the RGB channels, overlay mask to a specific color channel
cv::Mat ch[3];
split(img, ch);
int col = 0; // int col = i % 3;
cv::bitwise_or(imgMask, ch[col](roi), ch[col](roi));
merge(ch, 3, img);
}
cv::rectangle(img, cv::Point{(int)box.left, (int)box.top},
cv::Point{(int)box.right, (int)box.bottom}, cv::Scalar{0, 255, 0});
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
cv::imwrite("output_detection.png", img);
// construct a detector instance
mmdeploy::Detector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device});
// apply the detector, the result is an array-like class holding references to
// `mmdeploy_detection_t`, will be released automatically on destruction
mmdeploy::Detector::Result dets = detector.Apply(img);
// visualize
utils::Visualize v;
auto sess = v.get_session(img);
int count = 0;
for (const mmdeploy_detection_t& det : dets) {
if (det.score > FLAGS_det_thr) { // filter bboxes
sess.add_det(det.bbox, det.label_id, det.score, det.mask, count++);
}
}
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, sess.get());
}
return 0;
}

View File

@ -1,180 +0,0 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/pose_tracker.hpp"
#include <iostream>
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "opencv2/videoio/videoio.hpp"
// Command line arguments of the pose tracker demo, filled in by ParseArgs().
struct Args {
// device name for execution, e.g. "cpu", "cuda"; left empty when parsing failed
std::string device;
// object detection model path
std::string det_model;
// pose estimation model path
std::string pose_model;
// video path or camera index
std::string video;
// output directory; when empty, frames are shown with cv::imshow instead of being saved
std::string output_dir;
};
Args ParseArgs(int argc, char* argv[]);
using std::vector;
using namespace mmdeploy;
bool Visualize(cv::Mat frame, const PoseTracker::Result& result, int size,
const std::string& output_dir, int frame_id, bool with_bbox);
// Entry point of the pose tracker demo.
//
// Builds the detection + pose tracking pipeline, opens the video (or camera when the argument
// is a single digit) and visualizes the tracking result of every frame until the stream ends
// or Visualize() asks to stop. Returns 0 on normal exit (including after printing the usage)
// and -1 when the video cannot be opened.
int main(int argc, char* argv[]) {
  auto args = ParseArgs(argc, argv);
  // ParseArgs leaves everything empty after printing the usage
  if (args.device.empty()) {
    return 0;
  }

  // create pose tracker pipeline
  PoseTracker tracker(Model(args.det_model), Model(args.pose_model), Context{Device{args.device}});

  // set parameters
  PoseTracker::Params params;
  params->det_min_bbox_size = 100;
  params->det_interval = 1;
  params->pose_max_num_bboxes = 6;

  // optionally use OKS for keypoints similarity comparison
  std::array<float, 17> sigmas{0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
                               0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089};
  params->keypoint_sigmas = sigmas.data();
  params->keypoint_sigmas_size = sigmas.size();

  // create a tracker state for each video
  PoseTracker::State state = tracker.CreateState(params);

  cv::VideoCapture video;
  // std::isdigit requires a value representable as unsigned char
  if (args.video.size() == 1 && std::isdigit(static_cast<unsigned char>(args.video[0]))) {
    video.open(std::stoi(args.video));  // open by camera index
  } else {
    video.open(args.video);  // open video file
  }
  if (!video.isOpened()) {
    std::cerr << "failed to open video: " << args.video << "\n";
    // nothing can be processed without an input stream; report the failure to the caller
    return -1;
  }

  cv::Mat frame;
  int frame_id = 0;
  while (true) {
    video >> frame;
    if (!frame.data) {
      break;
    }
    // apply the pipeline with the tracker state and video frame
    auto result = tracker.Apply(state, frame);
    // visualize the results
    if (!Visualize(frame, result, 1280, args.output_dir, frame_id++, false)) {
      break;
    }
  }

  return 0;
}
// Description of how a pose is drawn: which key-points are linked and in which color.
// NOTE: the member order is relied upon by the brace initializer in gCocoSkeleton().
struct Skeleton {
// pairs of key-point indices; each pair is drawn as one line segment
vector<std::pair<int, int>> skeleton;
// colors referenced by `link_color` and `point_color`
vector<cv::Scalar> palette;
// index into `palette` for each entry of `skeleton`
vector<int> link_color;
// index into `palette` for each key-point
vector<int> point_color;
};
// Returns the predefined skeleton used by Visualize(); the instance is a function-local
// static, so it is built once and shared by all callers.
const Skeleton& gCocoSkeleton() {
static const Skeleton inst{
// skeleton: 19 links between key-point indices in the range [0, 16]
{
{15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12},
{5, 6}, {5, 7}, {6, 8}, {7, 9}, {8, 10}, {1, 2}, {0, 1},
{0, 2}, {1, 3}, {2, 4}, {3, 5}, {4, 6},
},
// palette: 20 colors (presumably in OpenCV's BGR channel order -- TODO confirm)
{
{255, 128, 0}, {255, 153, 51}, {255, 178, 102}, {230, 230, 0}, {255, 153, 255},
{153, 204, 255}, {255, 102, 255}, {255, 51, 255}, {102, 178, 255}, {51, 153, 255},
{255, 153, 153}, {255, 102, 102}, {255, 51, 51}, {153, 255, 153}, {102, 255, 102},
{51, 255, 51}, {0, 255, 0}, {0, 0, 255}, {255, 0, 0}, {255, 255, 255},
},
// link_color: one palette index per link (19 entries)
{0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16},
// point_color: one palette index per key-point (17 entries)
{16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0},
};
return inst;
}
// Draws the tracked poses on a copy of `frame` and either shows or saves the result.
//
// frame:      input image; drawing is done on a resized copy (or a clone when no resize is
//             needed)
// result:     tracked targets; key-point coordinates are given in `frame` pixels
// size:       long edge of the output image, key-points and bboxes are scaled accordingly
// output_dir: when empty the image is shown with cv::imshow, otherwise it is written to
//             "<output_dir>/<frame_id>.jpg"
// frame_id:   index used for the output file name
// with_bbox:  also draw the bbox of every target
//
// Returns false only when the image is shown interactively and 'q' was pressed; true otherwise.
// Links or key-points that a target does not provide (fewer key-points than the skeleton
// references) or that have no color assigned are skipped instead of being read out of bounds.
bool Visualize(cv::Mat frame, const PoseTracker::Result& result, int size,
               const std::string& output_dir, int frame_id, bool with_bbox) {
  const auto& [skeleton, palette, link_color, point_color] = gCocoSkeleton();
  const auto scale = (float)size / (float)std::max(frame.cols, frame.rows);
  if (scale != 1) {
    cv::resize(frame, frame, {}, scale, scale);
  } else {
    frame = frame.clone();
  }
  auto draw_bbox = [&](std::array<float, 4> bbox, const cv::Scalar& color) {
    std::for_each(bbox.begin(), bbox.end(), [&](auto& x) { x *= scale; });
    cv::rectangle(frame, cv::Point2f(bbox[0], bbox[1]), cv::Point2f(bbox[2], bbox[3]), color);
  };
  constexpr auto score_thr = .5f;
  for (const auto& r : result) {
    // treat a non-positive count as "no key-points"
    const int n_kpts = r.keypoint_count > 0 ? r.keypoint_count : 0;
    // scaled pixel location of key-point `k`
    auto location = [&](int k) {
      return cv::Point2f(r.keypoints[k].x * scale, r.keypoints[k].y * scale);
    };
    vector<int> used(n_kpts);
    for (size_t i = 0; i < skeleton.size(); ++i) {
      const auto [u, v] = skeleton[i];
      // the skeleton may reference key-points this target does not have
      if (u >= n_kpts || v >= n_kpts) {
        continue;
      }
      if (r.scores[u] > score_thr && r.scores[v] > score_thr) {
        used[u] = used[v] = 1;
        cv::line(frame, location(u), location(v), palette[link_color[i]], 1, cv::LINE_AA);
      }
    }
    // only key-points with an assigned color can be drawn
    const int n_colored = static_cast<int>(point_color.size());
    for (int k = 0; k < n_kpts && k < n_colored; ++k) {
      if (used[k]) {
        cv::circle(frame, location(k), 1, palette[point_color[k]], 2, cv::LINE_AA);
      }
    }
    if (with_bbox) {
      // NOTE(review): reinterprets the bbox as four consecutive floats -- confirm this matches
      // the layout of the bbox type
      draw_bbox((std::array<float, 4>&)r.bbox, cv::Scalar(0, 255, 0));
    }
  }
  if (output_dir.empty()) {
    cv::imshow("pose_tracker", frame);
    return cv::waitKey(1) != 'q';
  }
  auto join = [](const std::string& a, const std::string& b) {
#if _MSC_VER
    return a + "\\" + b;
#else
    return a + "/" + b;
#endif
  };
  cv::imwrite(join(output_dir, std::to_string(frame_id) + ".jpg"), frame,
              {cv::IMWRITE_JPEG_QUALITY, 90});
  return true;
}
// Converts the positional command line arguments into an `Args` instance.
// Accepts four mandatory arguments plus an optional output directory; for any other argument
// count the usage text is printed and a default-constructed (all empty) `Args` is returned.
Args ParseArgs(int argc, char* argv[]) {
  const bool count_ok = (argc == 5 || argc == 6);
  if (!count_ok) {
    std::cout << R"(Usage: pose_tracker device_name det_model pose_model video [output]
device_name device name for execution, e.g. "cpu", "cuda"
det_model object detection model path
pose_model pose estimation model path
video video path or camera index
output output directory, will cv::imshow if omitted
)";
    return {};
  }

  Args parsed;
  // the output directory is the only optional argument
  if (argc == 6) {
    parsed.output_dir = argv[5];
  }
  parsed.video = argv[4];
  parsed.pose_model = argv[3];
  parsed.det_model = argv[2];
  parsed.device = argv[1];
  return parsed;
}

View File

@ -0,0 +1,68 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "mmdeploy/pose_tracker.hpp"
#include "utils/argparse.h"
#include "utils/mediaio.h"
#include "utils/visualize.h"
DEFINE_ARG_string(det_model, "Object detection model path");
DEFINE_ARG_string(pose_model, "Pose estimation model path");
DEFINE_ARG_string(input, "Input video path or camera index");
DEFINE_string(device, "cpu", "Device name, e.g. \"cpu\", \"cuda\"");
DEFINE_string(output, "", "Output video path or format string");
DEFINE_int32(output_size, 0, "Long-edge of output frames");
DEFINE_int32(show, 1, "Delay passed to `cv::waitKey` when using `cv::imshow`; -1: disable");
DEFINE_string(skeleton, "coco", R"(Path to skeleton data or name of predefined skeletons: "coco")");
DEFINE_string(background, "default",
R"(Output background, "default": original image, "black": black background)");
#include "pose_tracker_params.h"
int main(int argc, char* argv[]) {
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
// create pose tracker pipeline
mmdeploy::PoseTracker tracker(mmdeploy::Model(ARGS_det_model), mmdeploy::Model(ARGS_pose_model),
mmdeploy::Device{FLAGS_device});
mmdeploy::PoseTracker::Params params;
// initialize tracker params with program arguments
InitTrackerParams(params);
// create a tracker state for each video
mmdeploy::PoseTracker::State state = tracker.CreateState(params);
utils::mediaio::Input input(ARGS_input);
utils::mediaio::Output output(FLAGS_output, FLAGS_show);
utils::Visualize v(FLAGS_output_size);
v.set_background(FLAGS_background);
v.set_skeleton(utils::Skeleton::get(FLAGS_skeleton));
for (const cv::Mat& frame : input) {
// apply the pipeline with the tracker state and video frame; the result is an array-like class
// holding references to `mmdeploy_pose_tracker_target_t`, will be released automatically on
// destruction
mmdeploy::PoseTracker::Result result = tracker.Apply(state, frame);
// visualize results
auto sess = v.get_session(frame);
for (const mmdeploy_pose_tracker_target_t& target : result) {
sess.add_pose(target.keypoints, target.scores, target.keypoint_count, FLAGS_pose_kpt_thr);
}
// write to output stream
if (!output.write(sess.get())) {
// user requested exit by pressing 'q'
break;
}
}
return 0;
}

View File

@ -0,0 +1,39 @@
// Copyright (c) OpenMMLab. All rights reserved.
DEFINE_int32(det_interval, 1, "Detection interval");
DEFINE_int32(det_label, 0, "Detection label use for pose estimation");
DEFINE_double(det_thr, 0.5, "Detection score threshold");
DEFINE_double(det_min_bbox_size, -1, "Detection minimum bbox size");
DEFINE_double(det_nms_thr, .7,
"NMS IOU threshold for merging detected bboxes and bboxes from tracked targets");
DEFINE_int32(pose_max_num_bboxes, -1, "Max number of bboxes used for pose estimation per frame");
DEFINE_double(pose_kpt_thr, .5, "Threshold for visible key-points");
DEFINE_int32(pose_min_keypoints, -1,
"Min number of key-points for valid poses, -1 indicates ceil(n_kpts/2)");
DEFINE_double(pose_bbox_scale, 1.25, "Scale for expanding key-points to bbox");
DEFINE_double(
pose_min_bbox_size, -1,
"Min pose bbox size, tracks with bbox size smaller than the threshold will be dropped");
DEFINE_double(pose_nms_thr, 0.5,
"NMS OKS/IOU threshold for suppressing overlapped poses, useful when multiple pose "
"estimations collapse to the same target");
DEFINE_double(track_iou_thr, 0.4, "IOU threshold for associating missing tracks");
DEFINE_int32(track_max_missing, 10,
"Max number of missing frames before a missing tracks is removed");
// Copies the values of the tracker related command line flags into `params`.
//
// params: tracker parameters to fill in; every field that has a matching flag is overwritten
//
// Double-valued flags are narrowed to the type of the corresponding parameter field on
// assignment.
void InitTrackerParams(mmdeploy::PoseTracker::Params& params) {
  params->det_interval = FLAGS_det_interval;
  params->det_label = FLAGS_det_label;
  params->det_thr = FLAGS_det_thr;
  params->det_min_bbox_size = FLAGS_det_min_bbox_size;
  // the flag was declared but never forwarded, so `--det_nms_thr` had no effect
  // NOTE(review): assumes the params struct exposes a field named `det_nms_thr` -- confirm
  params->det_nms_thr = FLAGS_det_nms_thr;
  params->pose_max_num_bboxes = FLAGS_pose_max_num_bboxes;
  params->pose_kpt_thr = FLAGS_pose_kpt_thr;
  params->pose_min_keypoints = FLAGS_pose_min_keypoints;
  params->pose_bbox_scale = FLAGS_pose_bbox_scale;
  params->pose_min_bbox_size = FLAGS_pose_min_bbox_size;
  params->pose_nms_thr = FLAGS_pose_nms_thr;
  params->track_iou_thr = FLAGS_track_iou_thr;
  params->track_max_missing = FLAGS_track_max_missing;
}

View File

@ -2,33 +2,39 @@
#include "mmdeploy/restorer.hpp"
#include <opencv2/imgcodecs/imgcodecs.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <string>
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "utils/argparse.h"
DEFINE_ARG_string(model, "Super-resolution model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "restorer_output.jpg", "Output image path");
int main(int argc, char* argv[]) {
if (argc != 4) {
fprintf(stderr, "usage:\n image_restorer device_name model_path image_path\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto image_path = argv[3];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
using namespace mmdeploy;
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
Restorer restorer{Model{model_path}, Device{device_name}};
// construct a restorer instance
mmdeploy::Restorer restorer{mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}};
auto result = restorer.Apply(img);
// apply restorer to the image
mmdeploy::Restorer::Result result = restorer.Apply(img);
cv::Mat sr_img(result->height, result->width, CV_8UC3, result->data);
cv::cvtColor(sr_img, sr_img, cv::COLOR_RGB2BGR);
cv::imwrite("output_restorer.bmp", sr_img);
// convert to BGR
cv::Mat upsampled(result->height, result->width, CV_8UC3, result->data);
cv::cvtColor(upsampled, upsampled, cv::COLOR_RGB2BGR);
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, upsampled);
}
return 0;
}

View File

@ -1,51 +1,46 @@
#include "mmdeploy/rotated_detector.hpp"
#include <opencv2/imgcodecs.hpp>
#include <opencv2/imgproc.hpp>
#include <string>
#include "utils/argparse.h"
#include "utils/visualize.h"
DEFINE_ARG_string(model, "Model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "rotated_detector_output.jpg", "Output image path");
DEFINE_double(det_thr, 0.1, "Detection score threshold");
int main(int argc, char* argv[]) {
if (argc != 4) {
fprintf(stderr, "usage:\n oriented_object_detection device_name model_path image_path\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto image_path = argv[3];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
mmdeploy::Model model(model_path);
mmdeploy::RotatedDetector detector(model, mmdeploy::Device{device_name, 0});
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
auto dets = detector.Apply(img);
// construct a detector instance
mmdeploy::RotatedDetector detector(mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device});
for (const auto& det : dets) {
if (det.score < 0.1) {
continue;
// apply the detector, the result is an array-like class holding references to
// `mmdeploy_rotated_detection_t`, will be released automatically on destruction
mmdeploy::RotatedDetector::Result dets = detector.Apply(img);
// visualize results
utils::Visualize v;
auto sess = v.get_session(img);
for (const mmdeploy_rotated_detection_t& det : dets) {
if (det.score > FLAGS_det_thr) {
sess.add_rotated_det(det.rbbox, det.label_id, det.score);
}
auto& rbbox = det.rbbox;
float xc = rbbox[0];
float yc = rbbox[1];
float w = rbbox[2];
float h = rbbox[3];
float ag = rbbox[4];
float wx = w / 2 * std::cos(ag);
float wy = w / 2 * std::sin(ag);
float hx = -h / 2 * std::sin(ag);
float hy = h / 2 * std::cos(ag);
cv::Point p1 = {int(xc - wx - hx), int(yc - wy - hy)};
cv::Point p2 = {int(xc + wx - hx), int(yc + wy - hy)};
cv::Point p3 = {int(xc + wx + hx), int(yc + wy + hy)};
cv::Point p4 = {int(xc - wx + hx), int(yc - wy + hy)};
cv::drawContours(img, std::vector<std::vector<cv::Point>>{{p1, p2, p3, p4}}, -1, {0, 255, 0},
2);
}
cv::imwrite("output_rotated_detection.png", img);
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, sess.get());
}
return 0;
}

View File

@ -2,75 +2,46 @@
#include "mmdeploy/segmentor.hpp"
#include <fstream>
#include <numeric>
#include <opencv2/imgcodecs/imgcodecs.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <random>
#include <string>
#include <vector>
using namespace std;
#include "utils/argparse.h"
#include "utils/mediaio.h"
#include "utils/visualize.h"
vector<cv::Vec3b> gen_palette(int num_classes) {
std::mt19937 gen;
std::uniform_int_distribution<ushort> uniform_dist(0, 255);
vector<cv::Vec3b> palette;
palette.reserve(num_classes);
for (auto i = 0; i < num_classes; ++i) {
palette.emplace_back(uniform_dist(gen), uniform_dist(gen), uniform_dist(gen));
}
return palette;
}
DEFINE_ARG_string(model, "Model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "segmentor_output.jpg", "Output image path");
DEFINE_string(palette, "cityscapes",
R"(Path to palette data or name of predefined palettes: "cityscapes")");
int main(int argc, char* argv[]) {
if (argc != 4) {
fprintf(stderr, "usage:\n image_segmentation device_name model_path image_path\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto image_path = argv[3];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
using namespace mmdeploy;
Segmentor segmentor{Model{model_path}, Device{device_name}};
auto result = segmentor.Apply(img);
auto palette = gen_palette(result->classes + 1);
cv::Mat color_mask = cv::Mat::zeros(result->height, result->width, CV_8UC3);
int pos = 0;
int total = color_mask.rows * color_mask.cols;
std::vector<int> idxs(result->classes);
for (auto iter = color_mask.begin<cv::Vec3b>(); iter != color_mask.end<cv::Vec3b>(); ++iter) {
// output mask
if (result->mask) {
*iter = palette[result->mask[pos++]];
}
// output score
if (result->score) {
std::iota(idxs.begin(), idxs.end(), 0);
auto k =
std::max_element(idxs.begin(), idxs.end(),
[&](int i, int j) {
return result->score[pos + i * total] < result->score[pos + j * total];
}) -
idxs.begin();
*iter = palette[k];
pos += 1;
}
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
img = img * 0.5 + color_mask * 0.5;
cv::imwrite("output_segmentation.png", img);
mmdeploy::Segmentor segmentor{mmdeploy::Model{ARGS_model}, mmdeploy::Device{FLAGS_device}};
// apply the detector, the result is an array-like class holding a reference to
// `mmdeploy_segmentation_t`, will be released automatically on destruction
mmdeploy::Segmentor::Result seg = segmentor.Apply(img);
// visualize
utils::Visualize v;
v.set_palette(utils::Palette::get(FLAGS_palette));
auto sess = v.get_session(img);
sess.add_mask(seg->height, seg->width, seg->classes, seg->mask, seg->score);
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, sess.get());
}
return 0;
}

View File

@ -1,46 +1,57 @@
#include <opencv2/imgcodecs/imgcodecs.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <string>
#include "mmdeploy/text_detector.hpp"
#include "mmdeploy/text_recognizer.hpp"
#include "utils/argparse.h"
#include "utils/mediaio.h"
#include "utils/visualize.h"
DEFINE_ARG_string(det_model, "Text detection model path");
DEFINE_ARG_string(reg_model, "Text recognition model path");
DEFINE_ARG_string(image, "Input image path");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
DEFINE_string(output, "text_ocr_output.jpg", "Output image path");
using mmdeploy::TextDetector;
using mmdeploy::TextRecognizer;
int main(int argc, char* argv[]) {
if (argc != 5) {
fprintf(stderr, "usage:\n ocr device_name det_model_path reg_model_path image_path\n");
return 1;
}
const auto device_name = argv[1];
auto det_model_path = argv[2];
auto reg_model_path = argv[3];
auto image_path = argv[4];
cv::Mat img = cv::imread(image_path);
if (!img.data) {
fprintf(stderr, "failed to load image: %s\n", image_path);
return 1;
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
using namespace mmdeploy;
TextDetector detector{Model(det_model_path), Device(device_name)};
TextRecognizer recognizer{Model(reg_model_path), Device(device_name)};
auto bboxes = detector.Apply(img);
auto texts = recognizer.Apply(img, {bboxes.begin(), bboxes.size()});
for (int i = 0; i < bboxes.size(); ++i) {
fprintf(stdout, "box[%d]: %s\n", i, texts[i].text);
std::vector<cv::Point> poly_points;
for (const auto& pt : bboxes[i].bbox) {
fprintf(stdout, "x: %.2f, y: %.2f, ", pt.x, pt.y);
poly_points.emplace_back((int)pt.x, (int)pt.y);
}
fprintf(stdout, "\n");
cv::polylines(img, poly_points, true, cv::Scalar{0, 255, 0});
cv::Mat img = cv::imread(ARGS_image);
if (img.empty()) {
fprintf(stderr, "failed to load image: %s\n", ARGS_image.c_str());
return -1;
}
cv::imwrite("output_ocr.png", img);
mmdeploy::Device device(FLAGS_device);
TextDetector detector{mmdeploy::Model(ARGS_det_model), device};
TextRecognizer recognizer{mmdeploy::Model(ARGS_reg_model), device};
// apply the detector, the result is an array-like class holding references to
// `mmdeploy_text_detection_t`, will be released automatically on destruction
TextDetector::Result bboxes = detector.Apply(img);
// apply recognizer, if no bboxes are provided, full image will be used; the result is an
// array-like class holding references to `mmdeploy_text_recognition_t`, will be released
// automatically on destruction
TextRecognizer::Result texts = recognizer.Apply(img, {bboxes.begin(), bboxes.size()});
// visualize results
utils::Visualize v;
auto sess = v.get_session(img);
for (size_t i = 0; i < bboxes.size(); ++i) {
mmdeploy_text_detection_t& bbox = bboxes[i];
mmdeploy_text_recognition_t& text = texts[i];
sess.add_text_det(bbox.bbox, bbox.score, text.text, text.length, i);
}
if (!FLAGS_output.empty()) {
cv::imwrite(FLAGS_output, sess.get());
}
return 0;
}

View File

@ -0,0 +1,272 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_ARGPARSE_H
#define MMDEPLOY_ARGPARSE_H
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>
// Optional flags: `DEFINE_<type>(name, init, msg)` expands to a global `FLAGS_<name>` that starts
// out as `init` and is registered with the parser under `name`, with `msg` as its help text.
#define DEFINE_int32(name, init, msg) _MMDEPLOY_DEFINE_FLAG(int32_t, name, init, msg)
#define DEFINE_double(name, init, msg) _MMDEPLOY_DEFINE_FLAG(double, name, init, msg)
#define DEFINE_string(name, init, msg) _MMDEPLOY_DEFINE_FLAG(std::string, name, init, msg)
// Positional arguments: `DEFINE_ARG_<type>(name, msg)` expands to a global `ARGS_<name>` that is
// value-initialized and registered as a non-flag argument; the parser reports non-flag arguments
// that were not supplied as missing.
#define DEFINE_ARG_int32(name, msg) _MMDEPLOY_DEFINE_ARG(int32_t, name, msg)
#define DEFINE_ARG_double(name, msg) _MMDEPLOY_DEFINE_ARG(double, name, msg)
#define DEFINE_ARG_string(name, msg) _MMDEPLOY_DEFINE_ARG(std::string, name, msg)
namespace utils {
/// Minimal command-line parser backing the DEFINE_* / DEFINE_ARG_* macros.
///
/// Every registered entry is either an optional flag (`--name=value` or `--name value`) or a
/// positional argument. Entries are kept in registration order in a process-wide singleton and
/// parsed values are written through the `void*` supplied at registration time, interpreted
/// according to the registered type name ("int32", "double" or "string").
class ArgParse {
 public:
  /// Registers an optional flag.
  /// @param type  spelled-out C++ type ("int32_t", "double", "std::string")
  /// @param name  flag name, matched against `--name`
  /// @param init  default value; returned unchanged so it can initialize the variable
  /// @param msg   help text
  /// @param ptr   address of the variable that receives the parsed value
  template <typename T>
  static T Register(const std::string& type, const std::string& name, T init,
                    const std::string& msg, void* ptr) {
    instance()._Register(type, name, msg, true, ptr);
    return init;
  }

  /// Registers a positional argument; returns a value-initialized `T`.
  template <typename T>
  static T Register(const std::string& type, const std::string& name, const std::string& msg,
                    void* ptr) {
    instance()._Register(type, name, msg, false, ptr);
    return {};
  }

  /// Parses `argv`; on any failure (or `-h` / `--help`) prints the usage text and returns false.
  static bool ParseArguments(int argc, char* argv[]) {
    if (!instance()._Parse(argc, argv)) {
      ShowUsageWithFlags(argv[0]);
      return false;
    }
    return true;
  }

  /// Prints the usage line followed by the tables of positional arguments and flags.
  static void ShowUsageWithFlags(const char* argv0) { instance()._ShowUsageWithFlags(argv0); }

 private:
  static ArgParse& instance() {
    static ArgParse inst;
    return inst;
  }

  struct Info {
    std::string name;
    std::string type;  // normalized: "int32", "double" or "string"
    std::string msg;
    bool is_flag;
    void* ptr;  // destination variable, typed according to `type`
  };

  void _Register(std::string type, const std::string& name, const std::string& msg, bool is_flag,
                 void* ptr) {
    // normalize the stringized C++ type names produced by the macros
    if (type == "std::string") {
      type = "string";
    } else if (type == "int32_t") {
      type = "int32";
    }
    infos_.push_back({name, type, msg, is_flag, ptr});
  }

  bool _Parse(int argc, char* argv[]) {
    const int n_infos = static_cast<int>(infos_.size());
    int arg_idx{-1};  // index in `infos_` of the most recently filled positional argument
    std::vector<std::string> args(infos_.size());
    std::vector<int> used(infos_.size());
    for (int i = 1; i < argc; ++i) {
      if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) {
        return false;
      }
      if (argv[i][0] == '-' && argv[i][1] == '-') {
        // parse flag key-value pair (--x=y or --x y)
        const std::string body(argv[i] + 2);
        const auto eq = body.find('=');
        std::string key;
        std::string val;
        if (eq != std::string::npos) {
          key = body.substr(0, eq);
          val = body.substr(eq + 1);
        } else {
          key = body;
          if (i < argc - 1) {
            val = argv[++i];  // the value is the next token; stays empty if there is none
          }
        }
        bool found{};
        for (int j = 0; j < n_infos; ++j) {
          if (key == infos_[j].name) {
            args[j] = val;
            used[j] = 1;
            found = true;
            break;
          }
        }
        if (!found) {
          std::cout << "error: unknown option: " << key << std::endl;
          return false;
        }
      } else {
        // bind the token to the next positional (non-flag) entry in registration order
        for (arg_idx++; arg_idx < n_infos; ++arg_idx) {
          if (!infos_[arg_idx].is_flag) {
            args[arg_idx] = argv[i];
            used[arg_idx] = 1;
            break;
          }
        }
        if (arg_idx == n_infos) {
          std::cout << "error: unknown argument: " << argv[i] << std::endl;
          return false;
        }
      }
    }
    // every positional entry after the last one that was filled is missing
    std::vector<std::string> missing;
    for (arg_idx++; arg_idx < n_infos; ++arg_idx) {
      if (!infos_[arg_idx].is_flag) {
        missing.push_back(infos_[arg_idx].name);
      }
    }
    if (!missing.empty()) {
      std::cout << "error: the following arguments are required:";
      for (size_t i = 0; i < missing.size(); ++i) {
        std::cout << " " << missing[i];
        if (i != missing.size() - 1) {
          std::cout << ",";
        }
      }
      std::cout << "\n";
      return false;
    }
    for (int i = 0; i < n_infos; ++i) {
      if (used[i]) {
        try {
          parse_str(infos_[i], args[i]);
        } catch (...) {
          // std::stoi / std::stod throw on malformed or out-of-range input
          std::cout << "error: failed to parse " << infos_[i].name << ": " << args[i] << std::endl;
          return false;
        }
      }
    }
    return true;
  }

  static void parse_str(Info& info, const std::string& str) {
    if (info.type == "int32") {
      *static_cast<int32_t*>(info.ptr) = std::stoi(str);
    } else if (info.type == "double") {
      *static_cast<double*>(info.ptr) = std::stod(str);
    } else if (info.type == "string") {
      *static_cast<std::string*>(info.ptr) = str;
    } else {
      // unknown types are left untouched
    }
  }

  static std::string get_default_str(const Info& info) {
    if (info.type == "int32") {
      return std::to_string(*static_cast<int32_t*>(info.ptr));
    } else if (info.type == "double") {
      std::ostringstream os;
      os << std::setprecision(3) << *static_cast<double*>(info.ptr);
      return os.str();
    } else if (info.type == "string") {
      return "\"" + *(static_cast<std::string*>(info.ptr)) + "\"";
    } else {
      return "<unknown type>";
    }
  }

  void _ShowUsageWithFlags(const char* argv0) const {
    ShowUsage(argv0);
    static constexpr const auto kLineLength = 80;
    std::cout << std::endl;
    int max_name_length = 0;
    for (const auto& info : infos_) {
      max_name_length = std::max(max_name_length, (int)info.name.length());
    }
    max_name_length += 4;
    auto name_col_size = max_name_length + 1;
    auto msg_col_size = kLineLength - name_col_size;
    std::cout << "required arguments:\n";
    ShowFlags(name_col_size, msg_col_size, false);
    std::cout << std::endl;
    std::cout << "optional arguments:\n";
    ShowFlags(name_col_size, msg_col_size, true);
  }

  void ShowFlags(int name_col_size, int msg_col_size, bool is_flag) const {
    // wrapping is only meaningful for a positive column width
    const bool can_wrap = msg_col_size > 0;
    const auto width = static_cast<size_t>(msg_col_size);
    for (const auto& info : infos_) {
      if (info.is_flag != is_flag) {
        continue;
      }
      std::string name = " ";
      if (info.is_flag) {
        name.append("--");
      }
      name.append(info.name);
      if (name.length() < static_cast<size_t>(name_col_size)) {
        name.resize(static_cast<size_t>(name_col_size), ' ');
      }
      std::cout << name;
      std::string msg = info.msg;
      while (can_wrap && msg.length() > width) {  // insert line-breaks when msg is too long
        // break at the last space that still fits into the column
        const auto space = msg.rfind(' ', width - 1);
        if (space == std::string::npos) {
          // no space to break at: hard-wrap so the loop always makes progress
          std::cout << msg.substr(0, width) << std::endl;
          msg = msg.substr(width);
        } else {
          std::cout << msg.substr(0, space) << std::endl;
          msg = msg.substr(space + 1);
        }
        std::cout << std::string(name_col_size, ' ');
      }
      std::cout << msg;
      std::string type;
      type.append("[").append(info.type);
      if (info.is_flag) {
        type.append(" = ").append(get_default_str(info));
      }
      type.append("]");
      if (msg.length() + type.length() + 1 > width) {
        std::cout << std::endl << std::string(name_col_size, ' ') << type;
      } else {
        std::cout << " " << type;
      }
      std::cout << std::endl;
    }
  }

  void ShowUsage(const char* argv0) const {
    // keep only the file name: drop everything up to the last '/' or '\' path separator
    for (auto p = argv0; *p; ++p) {
      if (*p == '/' || *p == '\\') {
        argv0 = p + 1;
      }
    }
    std::cout << "Usage: " << argv0 << " [options]";
    for (const auto& info : infos_) {
      if (!info.is_flag) {
        std::cout << " " << info.name;
      }
    }
    std::cout << std::endl;
  }

 private:
  std::vector<Info> infos_;
};
// Free-function shorthand for ArgParse::ParseArguments; returns its result unchanged.
inline bool ParseArguments(int argc, char* argv[]) { return ArgParse::ParseArguments(argc, argv); }
} // namespace utils
// Defines the global `FLAGS_<name>`; its initializer registers the flag (passing the variable's
// own address as the parse destination) and evaluates to the default value `type(init)`.
#define _MMDEPLOY_DEFINE_FLAG(type, name, init, msg) \
type FLAGS_##name = ::utils::ArgParse::Register(#type, #name, type(init), msg, &FLAGS_##name)
// Defines the global `ARGS_<name>`; its initializer registers a non-flag argument and evaluates
// to a value-initialized `type`.
#define _MMDEPLOY_DEFINE_ARG(type, name, msg) \
type ARGS_##name = ::utils::ArgParse::Register<type>(#type, #name, msg, &ARGS_##name)
#endif // MMDEPLOY_ARGPARSE_H

View File

@ -0,0 +1,388 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_MEDIAIO_H
#define MMDEPLOY_MEDIAIO_H
#include <fstream>
#include <set>
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "opencv2/videoio/videoio.hpp"
namespace utils {
namespace mediaio {
// Kinds of media handled by Input / Output. kUnknown requests auto-detection from the path;
// kFmtStr marks an output image path containing a '%' conversion; kDisable turns file output off.
enum class MediaType { kUnknown, kImage, kVideo, kImageList, kWebcam, kFmtStr, kDisable };
namespace detail {
// Returns the part of `path` starting at its last '.', with the characters after the dot
// lower-cased; returns an empty string when `path` contains no '.'.
static std::string get_extension(const std::string& path) {
  const auto dot = path.rfind('.');
  if (dot == std::string::npos) {
    return {};
  }
  std::string suffix = path.substr(dot);
  for (auto it = suffix.begin() + 1; it != suffix.end(); ++it) {
    *it = (char)std::tolower((unsigned char)*it);
  }
  return suffix;
}
// Maps a lower-case file extension (with leading dot) to the FOURCC code used when opening a
// cv::VideoWriter; unknown extensions fall back to "DIVX".
// Declared `static` like the other helpers in this namespace: the function is defined in a
// header, so external linkage would produce duplicate symbols across translation units.
static int ext2fourcc(const std::string& ext) {
  auto get_fourcc = [](const char* s) { return cv::VideoWriter::fourcc(s[0], s[1], s[2], s[3]); };
  static const std::map<std::string, int> ext2fourcc{
      {".mp4", get_fourcc("mp4v")},
      {".avi", get_fourcc("DIVX")},
      {".mkv", get_fourcc("X264")},
      {".wmv", get_fourcc("WMV3")},
  };
  const auto it = ext2fourcc.find(ext);
  if (it != ext2fourcc.end()) {
    return it->second;
  }
  return get_fourcc("DIVX");
}
// True when `ext` is one of the recognized video extensions (exact, case-sensitive match).
static bool is_video(const std::string& ext) {
  static const std::set<std::string> kVideoExts{".mp4", ".avi", ".mkv", ".webm", ".mov", ".mpg", ".wmv"};
  return kVideoExts.find(ext) != kVideoExts.end();
}
// True when `ext` is an extension treated as an image-list file (exact, case-sensitive match).
static bool is_list(const std::string& ext) {
  static const std::set<std::string> kListExts{".txt"};
  return kListExts.find(ext) != kListExts.end();
}
// True when `ext` is one of the recognized image extensions (exact, case-sensitive match).
static bool is_image(const std::string& ext) {
  static const std::set<std::string> kImageExts{".jpg", ".jpeg", ".png", ".tif", ".tiff",
                                                ".bmp", ".ppm", ".pgm", ".webp"};
  return kImageExts.find(ext) != kImageExts.end();
}
// True when `str` contains at least one '%' character, i.e. looks like a printf-style pattern.
static bool is_fmtstr(const std::string& str) {
  return str.find('%') != std::string::npos;
}
} // namespace detail
class Input;
// Single-pass iterator over the frames produced by an Input.
//
// Constructing it from an Input reads the first frame; every increment reads the next one.
// A default-constructed iterator holds no frame and serves as the end sentinel. Two iterators
// compare equal when they are the same object or when both agree on being exhausted.
class InputIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using reference = cv::Mat&;
  using value_type = reference;
  using pointer = void;

 public:
  InputIterator() = default;
  explicit InputIterator(Input& input) : input_(&input) { next(); }

  InputIterator& operator++() {
    next();
    return *this;
  }

  // Returns the most recently read frame.
  reference operator*() { return frame_; }

  friend bool operator==(const InputIterator& a, const InputIterator& b) {
    return &a == &b || a.is_end() == b.is_end();
  }
  friend bool operator!=(const InputIterator& a, const InputIterator& b) { return !(a == b); }

 private:
  void next();
  // Exhausted means "holds no frame": true for the default-constructed sentinel and after a
  // read that produced an empty image. (The original tested `!= nullptr`, i.e. the opposite.)
  bool is_end() const noexcept { return frame_.data == nullptr; }

 private:
  cv::Mat frame_;
  Input* input_{};
};
// Single-pass iterator that groups frames from an InputIterator range into batches of at most
// `batch_size` frames. An iterator whose current batch is empty acts as the end sentinel.
class BatchInputIterator {
 public:
  using iterator_category = std::input_iterator_tag;
  using difference_type = std::ptrdiff_t;
  using reference = std::vector<cv::Mat>&;
  using value_type = reference;
  using pointer = void;

 public:
  BatchInputIterator() = default;

  // Takes ownership of the [iter, end) range and immediately collects the first batch.
  BatchInputIterator(InputIterator iter, InputIterator end, size_t batch_size)
      : iter_(std::move(iter)), end_(std::move(end)), batch_size_(batch_size) {
    next();
  }

  BatchInputIterator& operator++() {
    next();
    return *this;
  }

  // Returns the current batch.
  reference operator*() { return data_; }

  friend bool operator==(const BatchInputIterator& a, const BatchInputIterator& b) {
    if (&a == &b) {
      return true;
    }
    return a.is_end() == b.is_end();
  }

  friend bool operator!=(const BatchInputIterator& a, const BatchInputIterator& b) {
    return !(a == b);
  }

 private:
  // Discards the current batch and refills it until it is full or the range is exhausted.
  void next() {
    data_.clear();
    size_t taken = 0;
    while (taken < batch_size_ && iter_ != end_) {
      data_.push_back(*iter_);
      ++taken;
      ++iter_;
    }
  }

  bool is_end() const { return data_.empty(); }

 private:
  InputIterator iter_;
  InputIterator end_;
  size_t batch_size_{1};
  std::vector<cv::Mat> data_;
};
// Frame source that yields cv::Mat images from a single image, a text file listing image paths,
// a video file or a camera index. Frames are pulled with read() or iterated via begin()/end();
// batch() exposes the same stream grouped into fixed-size batches.
class Input {
public:
// Opens `path`. When `type` is kUnknown the media type is inferred, in this order: image
// extension, video extension, single-digit path (camera index), ".txt" extension or a file that
// looks like an image list, then by probing the file as an image and finally as a video.
explicit Input(const std::string& path, MediaType type = MediaType::kUnknown)
: path_(path), type_(type) {
if (type_ == MediaType::kUnknown) {
auto ext = detail::get_extension(path);
if (detail::is_image(ext)) {
type_ = MediaType::kImage;
} else if (detail::is_video(ext)) {
type_ = MediaType::kVideo;
} else if (path.size() == 1 && std::isdigit((unsigned char)path[0])) {
type_ = MediaType::kWebcam;
} else if (detail::is_list(ext) || try_image_list(path)) {
type_ = MediaType::kImageList;
} else if (try_image(path)) {
type_ = MediaType::kImage;
} else if (try_video(path)) {
type_ = MediaType::kVideo;
} else {
std::cout << "unknown file type: " << path << "\n";
}
}
// Open the underlying source; afterwards only kVideo and kImageList remain as active types.
if (type_ != MediaType::kUnknown) {
if (type_ == MediaType::kVideo) {
cap_.open(path_);
if (!cap_.isOpened()) {
std::cerr << "failed to open video file: " << path_ << "\n";
}
} else if (type_ == MediaType::kWebcam) {
cap_.open(std::stoi(path_));
if (!cap_.isOpened()) {
std::cerr << "failed to open camera index: " << path_ << "\n";
}
// a camera is read through the same capture object as a video file
type_ = MediaType::kVideo;
} else if (type_ == MediaType::kImage) {
// a single image is handled as a one-element image list
items_ = {path_};
type_ = MediaType::kImageList;
} else if (type_ == MediaType::kImageList) {
// items_ may already have been filled by try_image_list() during detection
if (items_.empty()) {
items_ = load_image_list(path);
}
}
}
}
// Creating the begin iterator reads the first frame from this input.
InputIterator begin() { return InputIterator(*this); }
InputIterator end() { return {}; } // NOLINT
// Returns the next frame. For image lists, entries that fail to load are reported on stderr and
// skipped; an empty Mat is returned once the list is exhausted. For video sources the result is
// whatever the capture object delivers (presumably an empty Mat at end of stream).
cv::Mat read() {
cv::Mat img;
if (type_ == MediaType::kVideo) {
cap_ >> img;
} else if (type_ == MediaType::kImageList) {
while (!img.data && index_ < items_.size()) {
auto path = items_[index_++];
img = cv::imread(path);
if (!img.data) {
std::cerr << "failed to load image: " << path << "\n";
}
}
}
return img;
}
// Range adaptor yielding the frames of an Input grouped into batches of `batch_size`.
class Batch {
public:
Batch(Input& input, size_t batch_size) : input_(&input), batch_size_(batch_size) {}
BatchInputIterator begin() { return {input_->begin(), input_->end(), batch_size_}; }
BatchInputIterator end() { return {}; } // NOLINT
private:
Input* input_{};
size_t batch_size_{1};
};
Batch batch(size_t batch_size) { return {*this, batch_size}; }
private:
// Probes by actually decoding / opening the file.
static bool try_image(const std::string& path) { return cv::imread(path).data; }
static bool try_video(const std::string& path) { return cv::VideoCapture(path).isOpened(); }
// Reads `path` line by line, stripping trailing whitespace and dropping empty lines. When
// `max_bytes` is non-zero and the file is larger than that, nothing is read and an empty list
// is returned.
static std::vector<std::string> load_image_list(const std::string& path, size_t max_bytes = 0) {
std::ifstream ifs(path);
ifs.seekg(0, std::ifstream::end);
auto size = ifs.tellg();
ifs.seekg(0, std::ifstream::beg);
if (max_bytes && size > max_bytes) {
return {};
}
auto strip = [](std::string& s) {
while (!s.empty() && std::isspace((unsigned char)s.back())) {
s.pop_back();
}
};
std::vector<std::string> ret;
std::string line;
while (std::getline(ifs, line)) {
strip(line);
if (!line.empty()) {
ret.push_back(std::move(line));
}
}
return ret;
}
// Heuristic: treats `path` (read with a 1 MiB size cap) as an image list when the number of
// lines with an image extension exceeds one tenth of all lines. On success the parsed lines
// are kept in items_ so the file is not read again.
bool try_image_list(const std::string& path) {
auto items = load_image_list(path, 1 << 20);
size_t count = 0;
for (const auto& item : items) {
if (detail::is_image(detail::get_extension(item)) && ++count > items.size() / 10) {
items_ = std::move(items);
return true;
}
}
return false;
}
private:
MediaType type_{MediaType::kUnknown};
std::string path_;
std::vector<std::string> items_;  // image paths when reading an image list
cv::VideoCapture cap_;  // used for video files and cameras
size_t index_{};  // position of the next entry of items_ to load
};
// Replaces the held frame with the next one read from the bound Input. Must only be called on
// an iterator constructed from an Input (input_ is asserted to be non-null).
inline void InputIterator::next() {
assert(input_);
frame_ = input_->read();
}
class Output;
// Output-iterator adaptor over an Output: assigning a cv::Mat to the iterator (or to the result
// of dereferencing it) hands the frame to the Output; dereference and both increments are
// no-ops that return the iterator itself.
class OutputIterator {
public:
using iterator_category = std::output_iterator_tag;
using difference_type = std::ptrdiff_t;
using reference = void;
using value_type = void;
using pointer = void;
public:
explicit OutputIterator(Output& output) : output_(&output) {}
// Defined after Output, since it needs the complete type.
OutputIterator& operator=(const cv::Mat& frame);
OutputIterator& operator*() { return *this; }
OutputIterator& operator++() { return *this; }
OutputIterator& operator++(int) { return *this; } // NOLINT
private:
Output* output_{};
};
// Frame sink that writes frames to a single image file, a numbered image sequence, or a video
// file, and optionally displays each frame in a window.
class Output {
public:
// `path` is the destination; when `type` is kUnknown it is inferred: an empty path disables
// file output, an image extension selects kImage (or kFmtStr if the path contains '%'), and a
// video extension selects kVideo. `show` >= 0 enables the preview window and is passed to
// cv::waitKey as the delay.
explicit Output(const std::string& path, int show, MediaType type = MediaType::kUnknown)
: path_(path), type_(type), show_(show) {
ext_ = detail::get_extension(path);
if (type_ == MediaType::kUnknown) {
if (path_.empty()) {
type_ = MediaType::kDisable;
} else if (detail::is_image(ext_)) {
if (detail::is_fmtstr(path)) {
type_ = MediaType::kFmtStr;
} else {
type_ = MediaType::kImage;
}
} else if (detail::is_video(ext_)) {
type_ = MediaType::kVideo;
} else {
std::cout << "unknown file type: " << path << "\n";
}
}
}
// Writes (and optionally shows) one frame. Returns false when the preview is enabled and the
// key read by cv::waitKey is 'q', true otherwise.
bool write(const cv::Mat& frame) {
bool exit = false;
switch (type_) {
case MediaType::kDisable:
break;
case MediaType::kImage:
// every frame overwrites the same file
cv::imwrite(path_, frame);
break;
case MediaType::kFmtStr: {
// path_ is used as a printf-style pattern and expanded with the running frame index
// NOTE(review): frame_id_ is a size_t; if the pattern uses an int conversion such as %d the
// argument type does not match the conversion - confirm the intended conversion specifier
char buf[256];
snprintf(buf, sizeof(buf), path_.c_str(), frame_id_);
cv::imwrite(buf, frame);
break;
}
case MediaType::kVideo:
write_video(frame);
break;
default:
std::cout << "unsupported output media type\n";
assert(0);
}
if (show_ >= 0) {
cv::imshow("", frame);
exit = cv::waitKey(show_) == 'q';
}
++frame_id_;
return !exit;
}
OutputIterator inserter() { return OutputIterator{*this}; }
private:
// The writer is opened on the first frame, because the frame size is only known then.
void write_video(const cv::Mat& frame) {
if (!video_.isOpened()) {
open_video(frame.size());
}
video_ << frame;
}
// Codec is chosen from the file extension; 30 is passed as the frame-rate argument.
void open_video(const cv::Size& size) { video_.open(path_, detail::ext2fourcc(ext_), 30, size); }
private:
std::string path_;
std::string ext_;  // extension of path_ as returned by detail::get_extension
MediaType type_{MediaType::kUnknown};
int show_{1};
size_t frame_id_{0};  // number of frames passed to write() so far
cv::VideoWriter video_;
};
// Forwards `frame` to the bound Output and returns the iterator so it can be used with
// standard output-iterator algorithms. The boolean result of Output::write is discarded.
// `inline` is required because this out-of-class definition lives in a header (same as
// InputIterator::next); without it every including translation unit emits its own definition.
inline OutputIterator& OutputIterator::operator=(const cv::Mat& frame) {
  assert(output_);
  output_->write(frame);
  return *this;
}
} // namespace mediaio
} // namespace utils
#endif // MMDEPLOY_MEDIAIO_H

View File

@ -0,0 +1,94 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_PALETTE_H
#define MMDEPLOY_PALETTE_H
#include <fstream>
#include <opencv2/core/core.hpp>
#include <random>
#include <string>
#include <utility>
#include <vector>
namespace utils {
// A list of 3-channel 8-bit colours used for visualization.
struct Palette {
std::vector<cv::Vec3b> data;
// Returns a built-in palette when `path` is "coco" or "cityscapes"; otherwise loads the
// palette from the file at `path`.
static Palette get(const std::string& path);
// Generates a palette with `n` colours.
static Palette get(int n);
};
// Returns the palette identified by `path`.
//
// "coco" and "cityscapes" select built-in tables; their entries are listed below with the first
// and last channel in the opposite order to the one stored, so they are swapped before
// returning. Any other value is treated as a text file containing an entry count followed by
// three integer channel values per entry, stored in file order. The process is aborted when the
// file cannot be opened.
inline Palette Palette::get(const std::string& path) {
  if (path == "coco") {
    Palette p{{
        {220, 20, 60},   {119, 11, 32},   {0, 0, 142},     {0, 0, 230},     {106, 0, 228},
        {0, 60, 100},    {0, 80, 100},    {0, 0, 70},      {0, 0, 192},     {250, 170, 30},
        {100, 170, 30},  {220, 220, 0},   {175, 116, 175}, {250, 0, 30},    {165, 42, 42},
        {255, 77, 255},  {0, 226, 252},   {182, 182, 255}, {0, 82, 0},      {120, 166, 157},
        {110, 76, 0},    {174, 57, 255},  {199, 100, 0},   {72, 0, 118},    {255, 179, 240},
        {0, 125, 92},    {209, 0, 151},   {188, 208, 182}, {0, 220, 176},   {255, 99, 164},
        {92, 0, 73},     {133, 129, 255}, {78, 180, 255},  {0, 228, 0},     {174, 255, 243},
        {45, 89, 255},   {134, 134, 103}, {145, 148, 174}, {255, 208, 186}, {197, 226, 255},
        {171, 134, 1},   {109, 63, 54},   {207, 138, 255}, {151, 0, 95},    {9, 80, 61},
        {84, 105, 51},   {74, 65, 105},   {166, 196, 102}, {208, 195, 210}, {255, 109, 65},
        {0, 143, 149},   {179, 0, 194},   {209, 99, 106},  {5, 121, 0},     {227, 255, 205},
        {147, 186, 208}, {153, 69, 1},    {3, 95, 161},    {163, 255, 0},   {119, 0, 170},
        {0, 182, 199},   {0, 165, 120},   {183, 130, 88},  {95, 32, 0},     {130, 114, 135},
        {110, 129, 133}, {166, 74, 118},  {219, 142, 185}, {79, 210, 114},  {178, 90, 62},
        {65, 70, 15},    {127, 167, 115}, {59, 105, 106},  {142, 108, 45},  {196, 172, 0},
        {95, 54, 80},    {128, 76, 255},  {201, 57, 1},    {246, 0, 122},   {191, 162, 208},
    }};
    for (auto& x : p.data) {
      std::swap(x[0], x[2]);
    }
    return p;
  } else if (path == "cityscapes") {
    Palette p{{
        {128, 64, 128},  {244, 35, 232}, {70, 70, 70}, {102, 102, 156}, {190, 153, 153},
        {153, 153, 153}, {250, 170, 30}, {220, 220, 0}, {107, 142, 35},  {152, 251, 152},
        {70, 130, 180},  {220, 20, 60},  {255, 0, 0},  {0, 0, 142},     {0, 0, 70},
        {0, 60, 100},    {0, 80, 100},   {0, 0, 230},  {119, 11, 32},
    }};
    for (auto& x : p.data) {
      std::swap(x[0], x[2]);
    }
    return p;
  }
  std::ifstream ifs(path);
  if (!ifs.is_open()) {
    std::cout << "error: failed to open palette data file: " << path << "\n";
    std::abort();
  }
  Palette p;
  int n = 0;
  ifs >> n;
  for (int i = 0; i < n; ++i) {
    // Read each channel as an integer: extracting directly into the `unsigned char` elements
    // of a cv::Vec3b would consume a single character instead of a decimal number.
    int c0 = 0;
    int c1 = 0;
    int c2 = 0;
    ifs >> c0 >> c1 >> c2;
    cv::Vec3b x{};
    x[0] = static_cast<unsigned char>(c0);
    x[1] = static_cast<unsigned char>(c1);
    x[2] = static_cast<unsigned char>(c2);
    p.data.push_back(x);
  }
  return p;
}
// Generates `n` colours by drawing n * 100 pseudo-random points with coordinates in [0, 255],
// clustering them with k-means into `n` clusters and using the cluster centres as the palette.
// The random engine is default-constructed, so the same sample sequence is drawn on every call.
inline Palette Palette::get(int n) {
  std::vector<cv::Point3f> points(n * 100);
  std::vector<int> labels(points.size());
  std::iota(labels.begin(), labels.end(), 0);

  std::mt19937 rng;  // NOLINT
  std::uniform_int_distribution<ushort> channel(0, 255);
  for (auto& pt : points) {
    // drawn one by one so the order of the three draws is explicit
    const auto c0 = (float)channel(rng);
    const auto c1 = (float)channel(rng);
    const auto c2 = (float)channel(rng);
    pt = {c0, c1, c2};
  }

  std::vector<cv::Point3f> clusters;
  const cv::TermCriteria criteria(cv::TermCriteria::Type::COUNT, 10, 0);
  cv::kmeans(points, n, labels, criteria, 1, cv::KMEANS_PP_CENTERS, clusters);

  Palette result;
  for (const auto& center : clusters) {
    result.data.emplace_back((uchar)center.x, (uchar)center.y, (uchar)center.z);
  }
  return result;
}
} // namespace utils
#endif // MMDEPLOY_PALETTE_H

View File

@ -0,0 +1,89 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_SKELETON_H
#define MMDEPLOY_SKELETON_H
#include <fstream>
#include <opencv2/core/core.hpp>
#include <string>
#include <utility>
#include <vector>
namespace utils {
// Description of how pose key-points are connected and coloured when drawn.
struct Skeleton {
// pairs of key-point indices that are joined by a line
std::vector<std::pair<int, int>> links;
// colours referenced by link_colors and point_colors
std::vector<cv::Scalar> palette;
// for each entry of `links`, an index into `palette`
std::vector<int> link_colors;
// for each key-point, an index into `palette`
std::vector<int> point_colors;
// Returns the built-in skeleton when `path` is "coco"; otherwise loads one from the file at
// `path`.
static Skeleton get(const std::string& path);
};
// Returns the built-in 17-key-point skeleton returned by Skeleton::get("coco"): 19 links, a 20-colour palette, and
// per-link / per-point palette indices. The instance is created on first use and shared.
// `inline` is required because the function is defined in a header; without it every
// translation unit that includes this file would emit its own external definition.
inline const Skeleton& gCocoSkeleton() {
  static const Skeleton inst{
      {
          {15, 13}, {13, 11}, {16, 14}, {14, 12}, {11, 12}, {5, 11}, {6, 12},
          {5, 6},   {5, 7},   {6, 8},   {7, 9},   {8, 10},  {1, 2},  {0, 1},
          {0, 2},   {1, 3},   {2, 4},   {3, 5},   {4, 6},
      },
      {
          {255, 128, 0},   {255, 153, 51},  {255, 178, 102}, {230, 230, 0},   {255, 153, 255},
          {153, 204, 255}, {255, 102, 255}, {255, 51, 255},  {102, 178, 255}, {51, 153, 255},
          {255, 153, 153}, {255, 102, 102}, {255, 51, 51},   {153, 255, 153}, {102, 255, 102},
          {51, 255, 51},   {0, 255, 0},     {0, 0, 255},     {255, 0, 0},     {255, 255, 255},
      },
      {0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16},
      {16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0},
  };
  return inst;
}
// n_links
// u0, v0, u1, v1, ..., un-1, vn-1
// n_palette
// b0, g0, r0, ..., bn-1, gn-1, rn-1
// n_link_color
// i0, i1, ..., in-1
// n_point_color
// j0, j1, ..., jn-1
// Returns the skeleton identified by `path`: the built-in one for "coco", otherwise the one
// parsed from the text file at `path`. The file holds four sections, each introduced by its
// element count: links (two indices each), palette (three values each), link colour indices and
// point colour indices. The process is aborted when the file cannot be opened.
inline Skeleton Skeleton::get(const std::string& path) {
  if (path == "coco") {
    return gCocoSkeleton();
  }

  std::ifstream file(path);
  if (!file.is_open()) {
    std::cout << "error: failed to open skeleton data file: " << path << "\n";
    std::abort();
  }

  Skeleton result;
  // one count variable is shared by all sections, exactly as the sections are read in sequence
  int n = 0;

  // reads a count followed by that many integers, appending them to `out`
  auto read_indices = [&](std::vector<int>& out) {
    file >> n;
    for (int left = n; left > 0; --left) {
      int value{};
      file >> value;
      out.push_back(value);
    }
  };

  file >> n;
  for (int left = n; left > 0; --left) {
    std::pair<int, int> link{};
    file >> link.first >> link.second;
    result.links.emplace_back(link.first, link.second);
  }

  file >> n;
  for (int left = n; left > 0; --left) {
    int channels[3]{};
    file >> channels[0] >> channels[1] >> channels[2];
    result.palette.emplace_back(channels[0], channels[1], channels[2]);
  }

  read_indices(result.link_colors);
  read_indices(result.point_colors);
  return result;
}
} // namespace utils
#endif // MMDEPLOY_SKELETON_H

View File

@ -0,0 +1,252 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_VISUALIZE_H
#define MMDEPLOY_VISUALIZE_H
#include <algorithm>
#include <iomanip>
#include <numeric>
#include <vector>
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include "palette.h"
#include "skeleton.h"
namespace utils {
class Visualize {
public:
class Session {
public:
explicit Session(Visualize& v, const cv::Mat& frame) : v_(v) {
if (v_.size_) {
scale_ = (float)v_.size_ / (float)std::max(frame.cols, frame.rows);
}
cv::Mat img;
if (v.background_ == "black") {
img = cv::Mat::zeros(frame.size(), CV_8UC3);
} else {
img = frame;
if (img.channels() == 1) {
cv::cvtColor(img, img, cv::COLOR_GRAY2BGR);
}
}
if (scale_ != 1) {
cv::resize(img, img, {}, scale_, scale_);
} else if (img.data == frame.data) {
img = img.clone();
}
img_ = std::move(img);
}
void add_label(int label_id, float score, int index) {
printf("label: %d, label_id: %d, score: %.4f\n", index, label_id, score);
auto size = .5f * static_cast<float>(img_.rows + img_.cols);
offset_ += add_text(to_text(label_id, score), {1, (float)offset_}, size) + 2;
}
int add_text(const std::string& text, const cv::Point2f& origin, float size) {
static constexpr const int font_face = cv::FONT_HERSHEY_SIMPLEX;
static constexpr const int thickness = 1;
static constexpr const auto max_font_scale = .5f;
static constexpr const auto min_font_scale = .25f;
float font_scale{};
if (size < 20) {
font_scale = min_font_scale;
} else if (size > 200) {
font_scale = max_font_scale;
} else {
font_scale = min_font_scale + (size - 20) / (200 - 20) * (max_font_scale - min_font_scale);
}
int baseline{};
auto text_size = cv::getTextSize(text, font_face, font_scale, thickness, &baseline);
cv::Rect rect(origin + cv::Point2f(0, text_size.height + 2 * thickness),
origin + cv::Point2f(text_size.width, 0));
rect &= cv::Rect({}, img_.size());
img_(rect) *= .35f;
cv::putText(img_, text, origin + cv::Point2f(0, text_size.height), font_face, font_scale,
cv::Scalar::all(255), thickness, cv::LINE_AA);
return text_size.height;
}
static std::string to_text(int label_id, float score) {
std::stringstream ss;
ss << label_id << ": " << std::fixed << std::setprecision(1) << score * 100;
return ss.str();
}
template <typename Mask>
void add_det(const mmdeploy_rect_t& rect, int label_id, float score, const Mask* mask,
int index) {
printf("bbox %d, left=%.2f, top=%.2f, right=%.2f, bottom=%.2f, label=%d, score=%.4f\n", index,
rect.left, rect.top, rect.right, rect.bottom, label_id, score);
if (mask) {
fprintf(stdout, "mask %d, height=%d, width=%d\n", index, mask->height, mask->width);
auto x0 = (int)std::max(std::floor(rect.left) - 1, 0.f);
auto y0 = (int)std::max(std::floor(rect.top) - 1, 0.f);
add_instance_mask({x0, y0}, rand(), mask->data, mask->height, mask->width);
}
add_bbox(rect, label_id, score);
}
void add_instance_mask(const cv::Point& origin, int color_id, const char* mask_data, int mask_h,
int mask_w, float alpha = .5f) {
auto color = v_.palette_.data[color_id % v_.palette_.data.size()];
auto x_end = std::min(origin.x + mask_w, img_.cols);
auto y_end = std::min(origin.y + mask_h, img_.rows);
auto img_data = img_.ptr<cv::Vec3b>();
for (int i = origin.y; i < y_end; ++i) {
for (int j = origin.x; j < x_end; ++j) {
if (mask_data[(i - origin.y) * mask_w + (j - origin.x)]) {
img_data[i * img_.cols + j] = img_data[i * img_.cols + j] * (1 - alpha) + color * alpha;
}
}
}
}
void add_bbox(mmdeploy_rect_t rect, int label_id, float score) {
rect.left *= scale_;
rect.right *= scale_;
rect.top *= scale_;
rect.bottom *= scale_;
if (label_id >= 0 && score > 0) {
auto area = std::max(0.f, (rect.right - rect.left) * (rect.bottom - rect.top));
add_text(to_text(label_id, score), {rect.left, rect.top}, std::sqrt(area));
}
cv::rectangle(img_, cv::Point2f(rect.left, rect.top), cv::Point2f(rect.right, rect.bottom),
cv::Scalar(0, 255, 0));
}
void add_text_det(mmdeploy_point_t bbox[4], float score, const char* text, size_t text_size,
int index) {
printf("bbox[%d]: (%.2f, %.2f), (%.2f, %.2f), (%.2f, %.2f), (%.2f, %.2f)\n", index, //
bbox[0].x, bbox[0].y, //
bbox[1].x, bbox[1].y, //
bbox[2].x, bbox[2].y, //
bbox[3].x, bbox[3].y);
std::vector<cv::Point> poly_points;
cv::Point2f center{};
for (int i = 0; i < 4; ++i) {
poly_points.emplace_back(bbox[i].x * scale_, bbox[i].y * scale_);
center += cv::Point2f(poly_points.back());
}
cv::polylines(img_, poly_points, true, cv::Scalar{0, 255, 0}, 1, cv::LINE_AA);
if (text) {
auto area = cv::contourArea(poly_points);
fprintf(stdout, "text[%d]: %s\n", index, text);
add_text(std::string(text, text + text_size), center / 4, std::sqrt(area));
}
}
void add_rotated_det(const float bbox[5], int label_id, float score) {
float xc = bbox[0] * scale_;
float yc = bbox[1] * scale_;
float w = bbox[2] * scale_;
float h = bbox[3] * scale_;
float ag = bbox[4];
float wx = w / 2 * std::cos(ag);
float wy = w / 2 * std::sin(ag);
float hx = -h / 2 * std::sin(ag);
float hy = h / 2 * std::cos(ag);
cv::Point2f p1{xc - wx - hx, yc - wy - hy};
cv::Point2f p2{xc + wx - hx, yc + wy - hy};
cv::Point2f p3{xc + wx + hx, yc + wy + hy};
cv::Point2f p4{xc - wx + hx, yc - wy + hy};
cv::Point2f c = .25f * (p1 + p2 + p3 + p4);
cv::drawContours(
img_,
std::vector<std::vector<cv::Point>>{{p1 * scale_, p2 * scale_, p3 * scale_, p4 * scale_}},
-1, {0, 255, 0}, 2, cv::LINE_AA);
add_text(to_text(label_id, score), c, std::sqrt(w * h));
}
void add_mask(int height, int width, int n_classes, const int* mask, const float* score) {
cv::Mat color_mask = cv::Mat::zeros(height, width, CV_8UC3);
auto n_pix = color_mask.total();
// compute top 1 idx if score (CHW) is available
cv::Mat_<int> top;
if (!mask && score) {
top = cv::Mat_<int>::zeros(height, width);
for (auto c = 1; c < n_classes; ++c) {
top.forEach([&](int& x, const int* idx) {
auto offset = idx[0] * width + idx[1];
if (score[c * n_pix + offset] > score[x * n_pix + offset]) {
x = c;
}
});
}
mask = top.ptr<int>();
}
if (mask) {
// palette look-up
color_mask.forEach<cv::Vec3b>([&](cv::Vec3b& x, const int* idx) {
auto& palette = v_.palette_.data;
x = palette[mask[idx[0] * width + idx[1]] % palette.size()];
});
if (color_mask.size() != img_.size()) {
cv::resize(color_mask, color_mask, img_.size());
}
// blend mask and background image
cv::addWeighted(img_, .5, color_mask, .5, 0., img_);
}
}
void add_pose(const mmdeploy_point_t* pts, const float* scores, int32_t pts_size, double thr) {
auto& skel = v_.skeleton_;
std::vector<int> used(pts_size);
std::vector<int> is_end_point(pts_size);
for (size_t i = 0; i < skel.links.size(); ++i) {
auto u = skel.links[i].first;
auto v = skel.links[i].second;
is_end_point[u] = is_end_point[v] = 1;
if (scores[u] > thr && scores[v] > thr) {
used[u] = used[v] = 1;
cv::Point2f p0(pts[u].x, pts[u].y);
cv::Point2f p1(pts[v].x, pts[v].y);
cv::line(img_, p0 * scale_, p1 * scale_, skel.palette[skel.link_colors[i]], 1,
cv::LINE_AA);
}
}
for (size_t i = 0; i < pts_size; ++i) {
if (!is_end_point[i] && scores[i] > thr || used[i]) {
cv::Point2f p(pts[i].x, pts[i].y);
cv::circle(img_, p * scale_, 1, skel.palette[skel.point_colors[i]], 2, cv::LINE_AA);
}
}
}
cv::Mat get() { return img_; }
private:
Visualize& v_;
float scale_{1};
int offset_{1};
cv::Mat img_;
};
explicit Visualize(int size = 0) : size_(size) { palette_ = Palette::get(32); }
Session get_session(const cv::Mat& frame) { return Session(*this, frame); }
void set_skeleton(const Skeleton& skeleton) { skeleton_ = skeleton; }
void set_palette(const Palette& palette) { palette_ = palette; }
void set_background(const std::string& background) { background_ = background; }
private:
friend Session;
Skeleton skeleton_;
Palette palette_;
std::string background_;
int size_{};
};
} // namespace utils
#endif // MMDEPLOY_VISUALIZE_H

View File

@ -3,8 +3,8 @@
#include <string>
#include "mmdeploy/video_recognizer.hpp"
#include "opencv2/imgcodecs/imgcodecs.hpp"
#include "opencv2/videoio.hpp"
#include "utils/argparse.h"
void SampleFrames(const char* video_path, std::map<int, cv::Mat>& buffer,
std::vector<mmdeploy::Mat>& clips, int clip_len, int frame_interval = 1,
@ -57,28 +57,26 @@ void SampleFrames(const char* video_path, std::map<int, cv::Mat>& buffer,
}
}
int main(int argc, char* argv[]) {
if (argc != 7) {
fprintf(stderr,
"usage:\n video_cls device_name model_path video_path video_path clip_len "
"frame_interval num_clips\n");
return 1;
}
auto device_name = argv[1];
auto model_path = argv[2];
auto video_path = argv[3];
DEFINE_ARG_string(model, "Model path");
DEFINE_ARG_string(video, "Input video path");
DEFINE_ARG_int32(clip_len, "Clip length");
DEFINE_ARG_int32(frame_interval, "Frame interval");
DEFINE_ARG_int32(num_clips, "Number of clips");
DEFINE_string(device, "cpu", R"(Device name, e.g. "cpu", "cuda")");
int clip_len = std::stoi(argv[4]);
int frame_interval = std::stoi(argv[5]);
int num_clips = std::stoi(argv[6]);
int main(int argc, char* argv[]) {
if (!utils::ParseArguments(argc, argv)) {
return -1;
}
std::map<int, cv::Mat> buffer;
std::vector<mmdeploy::Mat> clips;
mmdeploy::VideoSampleInfo clip_info = {clip_len, num_clips};
SampleFrames(video_path, buffer, clips, clip_len, frame_interval, num_clips);
mmdeploy::VideoSampleInfo clip_info = {ARGS_clip_len, ARGS_num_clips};
SampleFrames(ARGS_video.c_str(), buffer, clips, ARGS_clip_len, ARGS_frame_interval,
ARGS_num_clips);
mmdeploy::Model model(model_path);
mmdeploy::VideoRecognizer recognizer(model, mmdeploy::Device{device_name, 0});
mmdeploy::Model model(ARGS_model);
mmdeploy::VideoRecognizer recognizer(model, mmdeploy::Device{FLAGS_device});
auto res = recognizer.Apply(clips, clip_info);