improve shape checking (#315)

pull/1/head
lzhangzz 2021-12-21 20:16:40 +08:00 committed by GitHub
parent ce2b778061
commit 56e32fdf63
7 changed files with 56 additions and 28 deletions

View File

@@ -26,15 +26,17 @@ class LinearClsHead : public MMClassification {
   Result<Value> operator()(const Value& infer_res) {
     DEBUG("infer_res: {}", infer_res);
-    auto output_tensor = infer_res["output"].get<Tensor>();
-    assert(output_tensor.shape().size() >= 2);
-    auto class_num = (int)output_tensor.shape()[1];
-    if (output_tensor.data_type() != DataType::kFLOAT) {
+    auto output = infer_res["output"].get<Tensor>();
+    if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) {
+      ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(),
+            (int)output.data_type());
       return Status(eNotSupported);
     }
-    OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output_tensor, kHost, stream()));
+    auto class_num = (int)output.shape(1);
+    OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream()));
     OUTCOME_TRY(stream().Wait());
     return GetLabels(_scores, class_num);
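
The pattern repeated in every file of this commit is the same: an assert on the inference output is replaced by an explicit check that logs the offending shape/dtype and returns Status(eNotSupported). A minimal standalone sketch of why that matters (not part of the patch; is_supported is an invented stand-in): release builds define NDEBUG, assert() then compiles to nothing, and a malformed tensor would reach shape(1) unchecked.

// Sketch only: contrast assert with an explicit runtime check.
#include <cassert>
#include <cstdio>
#include <vector>

// stand-in for the shape/dtype predicate used in LinearClsHead above
bool is_supported(const std::vector<int64_t>& shape, bool is_float) {
  return shape.size() >= 2 && is_float;
}

int main() {
  std::vector<int64_t> bad_shape{8};     // e.g. a squeezed, rank-1 "output"
  assert(bad_shape.size() >= 2);         // aborts in debug builds, vanishes under NDEBUG
  if (!is_supported(bad_shape, true)) {  // always evaluated, like the new check
    std::printf("unsupported `output` tensor, rank: %zu\n", bad_shape.size());
    return 1;                            // the analogue of returning Status(eNotSupported)
  }
  return 0;
}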

View File

@@ -17,12 +17,10 @@ class ResizeInstanceMask : public ResizeBBox {
     }
   }
   // TODO: remove duplication
   Result<Value> operator()(const Value& prep_res, const Value& infer_res) {
     DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
     try {
-      assert(prep_res.contains("img_metas"));
-      // Value res = prep_res;
       auto dets = infer_res["dets"].get<Tensor>();
       auto labels = infer_res["labels"].get<Tensor>();
       auto masks = infer_res["masks"].get<Tensor>();
@@ -33,14 +31,25 @@ class ResizeInstanceMask : public ResizeBBox {
       // `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number'
       // and 'channels' respectively
-      assert(dets.shape().size() == 3);
-      assert(dets.data_type() == DataType::kFLOAT);
-      assert(masks.data_type() == DataType::kFLOAT);
+      if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
+        ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(),
+              (int)dets.data_type());
+        return Status(eNotSupported);
+      }
       // `labels` is supposed to have 2 dims, which are 'batch' and
       // 'bboxes_number'
-      assert(labels.shape().size() == 2);
+      if (labels.shape().size() != 2) {
+        ERROR("unsupported `labels` tensor, shape: {}, dtype: {}", labels.shape(),
+              (int)labels.data_type());
+        return Status(eNotSupported);
+      }
+      if (!(masks.shape().size() == 4 && masks.data_type() == DataType::kFLOAT)) {
+        ERROR("unsupported `mask` tensor, shape: {}, dtype: {}", masks.shape(),
+              (int)masks.data_type());
+        return Status(eNotSupported);
+      }
       OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
       OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
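
Both detection post-processors repeat this rank/dtype validation block (the "// TODO: remove duplication" above already flags the copy-paste). A hedged sketch of a shared helper the blocks could collapse into; check_tensor is invented for illustration and assumes the project's outcome-based Result<void>/success() behave as shown:

// Hypothetical helper, written only against what is visible in this diff
// (Tensor::shape(), Tensor::data_type(), ERROR, Status, eNotSupported).
#include <optional>
#include <string_view>

Result<void> check_tensor(const Tensor& t, std::string_view name, size_t expected_rank,
                          std::optional<DataType> expected_dtype = std::nullopt) {
  bool ok = t.shape().size() == expected_rank &&
            (!expected_dtype || t.data_type() == *expected_dtype);
  if (!ok) {
    ERROR("unsupported `{}` tensor, shape: {}, dtype: {}", name, t.shape(),
          (int)t.data_type());
    return Status(eNotSupported);
  }
  return success();
}

// the three checks above would then shrink to something like:
//   OUTCOME_TRY(check_tensor(dets, "dets", 3, DataType::kFLOAT));
//   OUTCOME_TRY(check_tensor(labels, "labels", 2));
//   OUTCOME_TRY(check_tensor(masks, "masks", 4, DataType::kFLOAT));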

View File

@@ -19,9 +19,6 @@ ResizeBBox::ResizeBBox(const Value& cfg) : MMDetection(cfg) {
 Result<Value> ResizeBBox::operator()(const Value& prep_res, const Value& infer_res) {
   DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
   try {
-    assert(prep_res.contains("img_metas"));
-    // Value res = prep_res;
     auto dets = infer_res["dets"].get<Tensor>();
     auto labels = infer_res["labels"].get<Tensor>();
@@ -30,12 +27,18 @@ Result<Value> ResizeBBox::operator()(const Value& prep_res, const Value& infer_r
     // `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number'
     // and 'channels' respectively
-    assert(dets.shape().size() == 3);
-    assert(dets.data_type() == DataType::kFLOAT);
+    if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
+      ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(), (int)dets.data_type());
+      return Status(eNotSupported);
+    }
     // `labels` is supposed to have 2 dims, which are 'batch' and
     // 'bboxes_number'
-    assert(labels.shape().size() == 2);
+    if (labels.shape().size() != 2) {
+      ERROR("unsupported `labels` tensor, shape: {}, dtype: {}", labels.shape(),
+            (int)labels.data_type());
+      return Status(eNotSupported);
+    }
     OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
     OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
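
For concreteness, a typical single-image detection output that passes these checks looks roughly as follows; the sizes are illustrative, not something the code above mandates:

// Illustrative shapes only (actual values depend on the deployed model):
//   dets:   kFLOAT, rank 3: [batch, bboxes_number, channels], e.g. [1, 100, 5]
//           with the channels usually being x1, y1, x2, y2, score
//   labels: rank 2:         [batch, bboxes_number],           e.g. [1, 100]
//   masks:  kFLOAT, rank 4 (ResizeInstanceMask only),         e.g. [1, 100, 28, 28]
// A squeezed export such as dets = [100, 5] now fails fast with eNotSupported
// instead of tripping an assert (or passing silently in a release build).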

View File

@@ -16,7 +16,7 @@ class TensorToImg : public MMEdit {
     auto upscale = input["output"].get<Tensor>();
     OUTCOME_TRY(auto upscale_cpu, MakeAvailableOnDevice(upscale, kHOST, stream()));
     OUTCOME_TRY(stream().Wait());
-    if (upscale.data_type() == DataType::kFLOAT) {
+    if (upscale.shape().size() == 4 && upscale.data_type() == DataType::kFLOAT) {
       auto channels = static_cast<int>(upscale.shape(1));
       auto height = static_cast<int>(upscale.shape(2));
       auto width = static_cast<int>(upscale.shape(3));
@@ -32,6 +32,8 @@ class TensorToImg : public MMEdit {
       mat_hwc.convertTo(rescale_uint8, CV_8UC(channels), 255.f);
       return mat;
     } else {
+      ERROR("unsupported `output` tensor, shape: {}, dtype: {}", upscale.shape(),
+            (int)upscale.data_type());
       return Status(eNotSupported);
     }
   }
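
The new rank-4 condition guards the shape(1..3) indexing above: the head treats "output" as an NCHW float image in [0, 1] and rescales it to 8-bit. A standalone hedged sketch of that last step, mirroring only the convertTo call visible above (the names and the 3-channel assumption are illustrative):

// Sketch: float [0, 1] HWC image -> 8-bit image, as in mat_hwc.convertTo(..., 255.f)
#include <opencv2/core.hpp>

cv::Mat to_uint8(const cv::Mat& float_hwc /* CV_32FC3, values in [0, 1] */) {
  cv::Mat u8;
  float_hwc.convertTo(u8, CV_8UC3, 255.0);  // scale by 255 and saturate-cast to uchar
  return u8;
}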

View File

@@ -61,7 +61,9 @@ class CTCConvertor : public MMOCR {
   Result<Value> operator()(const Value& _data, const Value& _prob) {
     auto d_conf = _prob["output"].get<Tensor>();
-    if (d_conf.data_type() != DataType::kFLOAT) {
+    if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) {
+      ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(),
+            (int)d_conf.data_type());
       return Status(eNotSupported);
     }
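
The recognizer now also requires rank 3, which matches the score volume that greedy CTC decoding walks over. A hedged sketch of that decoding, assuming the usual [batch, sequence_length, num_classes] layout and blank index 0 (both are assumptions here, not read from MMOCR):

// Sketch: greedy CTC decode over one sequence of a [batch, seq_len, num_classes] tensor.
#include <vector>

std::vector<int> greedy_ctc(const float* scores, int seq_len, int num_classes, int blank = 0) {
  std::vector<int> out;
  int prev = blank;
  for (int t = 0; t < seq_len; ++t) {
    const float* step = scores + t * num_classes;
    int best = 0;
    for (int c = 1; c < num_classes; ++c) {
      if (step[c] > step[best]) best = c;
    }
    if (best != blank && best != prev) out.push_back(best);  // collapse repeats, drop blanks
    prev = best;
  }
  return out;
}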

View File

@@ -49,6 +49,12 @@ class DBHead : public MMOCR {
     OUTCOME_TRY(stream_.Wait());
     DEBUG("shape: {}", conf.shape());
+    if (!(conf.shape().size() == 4 && conf.data_type() == DataType::kFLOAT)) {
+      ERROR("unsupported `output` tensor, shape: {}, dtype: {}", conf.shape(),
+            (int)conf.data_type());
+      return Status(eNotSupported);
+    }
     auto h = conf.shape(2);
     auto w = conf.shape(3);
     auto data = conf.buffer().GetNative();
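
The check protects the shape(2)/shape(3) indexing right after it: conf is read as an NCHW probability map whose post-processing starts by thresholding it into a text bitmap. A hedged sketch of that first step (the 0.3 threshold and the names are illustrative, not DBHead's actual parameters):

// Sketch: binarize an h x w float probability map.
#include <opencv2/core.hpp>

cv::Mat binarize(const float* data, int h, int w, float thr = 0.3f) {
  cv::Mat prob(h, w, CV_32FC1, const_cast<float*>(data));  // wraps the buffer, no copy
  return prob > thr;  // fresh CV_8U mask, 255 where prob > thr
}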

View File

@@ -24,11 +24,12 @@ class ResizeMask : public MMSegmentation {
     DEBUG("preprocess: {}\ninference: {}", preprocess_result, inference_result);
     auto mask = inference_result["output"].get<Tensor>();
-    INFO("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(),
-         mask.data_type());
-    assert(mask.data_type() == DataType::kINT32 || mask.data_type() == DataType::kINT64);
-    assert(mask.shape(0) == 1);
-    assert(mask.shape(1) == 1);
+    DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(),
+          mask.data_type());
+    if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
+      ERROR("unsupported `output` tensor, shape: {}", mask.shape());
+      return Status(eNotSupported);
+    }
     auto height = (int)mask.shape(2);
     auto width = (int)mask.shape(3);
@@ -36,7 +37,7 @@ class ResizeMask : public MMSegmentation {
     auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get<int>();
     Device host{"cpu"};
     OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_));
-    stream_.Wait().value();
+    OUTCOME_TRY(stream_.Wait());
     if (mask.data_type() == DataType::kINT64) {
       // change kINT64 to 2 INT32
       TensorDesc desc{.device = host_tensor.device(),
@@ -45,8 +46,11 @@ class ResizeMask : public MMSegmentation {
                       .name = host_tensor.name()};
       Tensor _host_tensor(desc, mask.buffer());
       return MaskResize(_host_tensor, input_height, input_width);
-    } else {
+    } else if (mask.data_type() == DataType::kINT32) {
       return MaskResize(host_tensor, input_height, input_width);
+    } else {
+      ERROR("unsupported `output` tensor, dtype: {}", (int)mask.data_type());
+      return Status(eNotSupported);
     }
   }
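
The "// change kINT64 to 2 INT32" branch above relies on the host being little-endian: each int64 class id occupies two int32 slots, with the low half (the actual small label value) in the even slot and zero in the odd slot. A standalone hedged sketch of just that fact, independent of how MaskResize consumes the reinterpreted tensor:

// Sketch: viewing an int64 label buffer as int32 pairs on a little-endian host.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t labels[4] = {0, 1, 2, 21};  // small, non-negative class ids
  auto* as_int32 = reinterpret_cast<int32_t*>(labels);
  for (int i = 0; i < 4; ++i) {
    std::printf("label %lld -> low %d, high %d\n", (long long)labels[i],
                as_int32[2 * i], as_int32[2 * i + 1]);
  }
  return 0;
}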