mmdeploy/tests/test_csrc/preprocess/test_collect.cpp
Chen Xin 6b01a2e649
[Feature] Add option to fuse transform. (#741)
* add collect_impl.cpp to cuda device

* add dummy compute node with device elena

* add compiler & dynamic library loader

* add code to compile with gen code(elena)

* move folder

* fix lint

* add tracer module

* add license

* update type id

* add fuse kernel registry

* remove compiler & dynamic_library

* update fuse kernel interface

* Add elena-mmdeploy project in 3rd-party

* Fix README.md

* fix cmake file

* Support CUDA device and clang-format all files

* Add cudaStreamSynchronize for cudaFree

* fix cudaStreamSynchronize

* rename to __tracer__

* remove unused code

* update kernel

* update extract elena script

* update gitignore

* fix ci

* Change the crop_size to crop_h and crop_w in arglist

* update Tracer

* remove cond

* avoid allocate memory

* add build.sh for elena

* remove code

* update test

* Support bilinear resize with float input

* Rename elena-mmdeploy to delete

* Introduce public submodule

* use get_ref

* update elena

* update tools

* update tools

* update fuse transform docs

* add fuse transform doc link to get_started

* fix shape in crop

* remove fuse_transform_ == true check

* remove fuse_transform_ member

* remove elena_int.h

* don't dump transform_static.json

* update tracer

* update CVFusion to remove compile warning

* remove mmcv version > 1.5.1 dep

* fix tests

* update docs

* add elena use option

* remove submodule of CVFusion

* update doc

* use auto

* use throw_exception(eEntryNotFound);

* update

Co-authored-by: cx <cx@ubuntu20.04>
Co-authored-by: miraclezqc <969226879@qq.com>
2022-09-05 20:29:18 +08:00

105 lines
3.0 KiB
C++

// Copyright (c) OpenMMLab. All rights reserved.
#include "catch.hpp"
#include "mmdeploy/core/tensor.h"
#include "mmdeploy/preprocess/transform/transform.h"
using namespace mmdeploy;
using namespace std;
// Checks how the registered "Collect" transform creator validates its config:
// - a config that carries only the execution context (no 'keys') is rejected;
// - 'keys' given as a scalar string is rejected, while an array is accepted;
// - 'meta_keys' given as a scalar string is rejected, while an array is accepted.
TEST_CASE("test collect constructor", "[collect]") {
  Device cpu{"cpu"};
  Stream cpu_stream{cpu};

  // Base config: only the device/stream context, deliberately without 'keys'.
  Value base_cfg;
  base_cfg["context"]["device"] = cpu;
  base_cfg["context"]["stream"] = cpu_stream;

  const std::string collect_name{"Collect"};
  auto collect_creator = Registry<Transform>::Get().GetCreator(collect_name, 1);
  REQUIRE(collect_creator != nullptr);

  // Missing 'keys' must make construction fail.
  REQUIRE_THROWS(collect_creator->Create(base_cfg));

  SECTION("args with 'keys' which is not an array") {
    Value config = base_cfg;
    config["keys"] = "img";
    REQUIRE_THROWS(collect_creator->Create(config));
  }

  SECTION("args with keys in array") {
    Value config = base_cfg;
    config["keys"] = {"img"};
    auto collect = collect_creator->Create(config);
    REQUIRE(collect != nullptr);
  }

  SECTION("args with meta_keys that is not an array") {
    Value config = base_cfg;
    config["keys"] = {"img"};
    config["meta_keys"] = "ori_img";
    REQUIRE_THROWS(collect_creator->Create(config));
  }

  SECTION("args with meta_keys in array") {
    Value config = base_cfg;
    config["keys"] = {"img"};
    config["meta_keys"] = {"ori_img"};
    auto collect = collect_creator->Create(config);
    REQUIRE(collect != nullptr);
  }
}
// Exercises Collect::Process on a module built with keys {"img"} and a typical
// set of meta keys:
// - non-conforming inputs (empty, or only 'ori_img'/'attribute' without the
//   required 'img' key) must fail with eInvalidArgument;
// - a well-formed dict input must succeed and carry the collected 'img' key.
TEST_CASE("test collect", "[collect]") {
  std::string transform_type{"Collect"};
  vector<std::string> keys{"img"};
  vector<std::string> meta_keys{"filename", "ori_filename", "ori_shape", "img_shape",
                                "flip",     "flip_direction", "img_norm_cfg"};
  Value args;
  Device device{"cpu"};
  Stream stream{device};
  args["context"]["device"] = device;
  args["context"]["stream"] = stream;
  for (const auto& key : keys) {
    args["keys"].push_back(key);
  }
  for (const auto& meta_key : meta_keys) {
    args["meta_keys"].push_back(meta_key);
  }
  auto creator = Registry<Transform>::Get().GetCreator(transform_type, 1);
  REQUIRE(creator != nullptr);
  auto module = creator->Create(args);
  REQUIRE(module != nullptr);

  SECTION("input is empty") {
    Value input;
    auto ret = module->Process(input);
    REQUIRE(ret.has_error());
    REQUIRE(ret.error() == eInvalidArgument);
  }

  SECTION("input has 'ori_img' and 'attribute'") {
    // Neither field is one of the configured 'keys', so 'img' is still missing.
    Value input;
    input["ori_img"] = Tensor{};
    input["attribute"] = "this is a faked image";
    auto ret = module->Process(input);
    REQUIRE(ret.has_error());
    REQUIRE(ret.error() == eInvalidArgument);
  }

  // The input here is a dict (object), not an array; each section owns its
  // own `input` so nothing is shadowed.
  SECTION("object input with correct keys and meta keys") {
    Tensor tensor;
    Value input{{"img", tensor},
                {"filename", "test.jpg"},
                {"ori_filename", "/the/path/of/test.jpg"},
                {"ori_shape", {1000, 1000, 3}},
                {"img_shape", {1, 3, 224, 224}},
                {"flip", false},
                {"flip_direction", "horizontal"},
                {"img_norm_cfg",
                 {{"mean", {123.675, 116.28, 103.53}},
                  {"std", {58.395, 57.12, 57.375}},
                  {"to_rgb", true}}}};
    auto ret = module->Process(input);
    REQUIRE(ret.has_value());
    // The configured key must be forwarded to the output.
    REQUIRE(ret.value().contains("img"));
  }
}