improve shape checking (#315)
parent
ce2b778061
commit
56e32fdf63
|
@ -26,15 +26,17 @@ class LinearClsHead : public MMClassification {
|
|||
|
||||
Result<Value> operator()(const Value& infer_res) {
|
||||
DEBUG("infer_res: {}", infer_res);
|
||||
auto output_tensor = infer_res["output"].get<Tensor>();
|
||||
assert(output_tensor.shape().size() >= 2);
|
||||
auto class_num = (int)output_tensor.shape()[1];
|
||||
auto output = infer_res["output"].get<Tensor>();
|
||||
|
||||
if (output_tensor.data_type() != DataType::kFLOAT) {
|
||||
if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) {
|
||||
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(),
|
||||
(int)output.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output_tensor, kHost, stream()));
|
||||
auto class_num = (int)output.shape(1);
|
||||
|
||||
OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream()));
|
||||
OUTCOME_TRY(stream().Wait());
|
||||
|
||||
return GetLabels(_scores, class_num);
|
||||
|
|
|
@ -17,12 +17,10 @@ class ResizeInstanceMask : public ResizeBBox {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: remove duplication
|
||||
Result<Value> operator()(const Value& prep_res, const Value& infer_res) {
|
||||
DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
|
||||
try {
|
||||
assert(prep_res.contains("img_metas"));
|
||||
// Value res = prep_res;
|
||||
|
||||
auto dets = infer_res["dets"].get<Tensor>();
|
||||
auto labels = infer_res["labels"].get<Tensor>();
|
||||
auto masks = infer_res["masks"].get<Tensor>();
|
||||
|
@ -33,14 +31,25 @@ class ResizeInstanceMask : public ResizeBBox {
|
|||
|
||||
// `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number'
|
||||
// and 'channels' respectively
|
||||
assert(dets.shape().size() == 3);
|
||||
assert(dets.data_type() == DataType::kFLOAT);
|
||||
|
||||
assert(masks.data_type() == DataType::kFLOAT);
|
||||
if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
|
||||
ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(),
|
||||
(int)dets.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
// `labels` is supposed to have 2 dims, which are 'batch' and
|
||||
// 'bboxes_number'
|
||||
assert(labels.shape().size() == 2);
|
||||
if (labels.shape().size() != 2) {
|
||||
ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(),
|
||||
(int)labels.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
if (!(masks.shape().size() == 4 && masks.data_type() == DataType::kFLOAT)) {
|
||||
ERROR("unsupported `mask` tensor, shape: {}, dtype: {}", masks.shape(),
|
||||
(int)masks.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
|
||||
OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
|
||||
|
|
|
@ -19,9 +19,6 @@ ResizeBBox::ResizeBBox(const Value& cfg) : MMDetection(cfg) {
|
|||
Result<Value> ResizeBBox::operator()(const Value& prep_res, const Value& infer_res) {
|
||||
DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
|
||||
try {
|
||||
assert(prep_res.contains("img_metas"));
|
||||
// Value res = prep_res;
|
||||
|
||||
auto dets = infer_res["dets"].get<Tensor>();
|
||||
auto labels = infer_res["labels"].get<Tensor>();
|
||||
|
||||
|
@ -30,12 +27,18 @@ Result<Value> ResizeBBox::operator()(const Value& prep_res, const Value& infer_r
|
|||
|
||||
// `dets` is supposed to have 3 dims. They are 'batch', 'bboxes_number'
|
||||
// and 'channels' respectively
|
||||
assert(dets.shape().size() == 3);
|
||||
assert(dets.data_type() == DataType::kFLOAT);
|
||||
if (!(dets.shape().size() == 3 && dets.data_type() == DataType::kFLOAT)) {
|
||||
ERROR("unsupported `dets` tensor, shape: {}, dtype: {}", dets.shape(), (int)dets.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
// `labels` is supposed to have 2 dims, which are 'batch' and
|
||||
// 'bboxes_number'
|
||||
assert(labels.shape().size() == 2);
|
||||
if (labels.shape().size() != 2) {
|
||||
ERROR("unsupported `labels`, tensor, shape: {}, dtype: {}", labels.shape(),
|
||||
(int)labels.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
|
||||
OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
|
||||
|
|
|
@ -16,7 +16,7 @@ class TensorToImg : public MMEdit {
|
|||
auto upscale = input["output"].get<Tensor>();
|
||||
OUTCOME_TRY(auto upscale_cpu, MakeAvailableOnDevice(upscale, kHOST, stream()));
|
||||
OUTCOME_TRY(stream().Wait());
|
||||
if (upscale.data_type() == DataType::kFLOAT) {
|
||||
if (upscale.shape().size() == 4 && upscale.data_type() == DataType::kFLOAT) {
|
||||
auto channels = static_cast<int>(upscale.shape(1));
|
||||
auto height = static_cast<int>(upscale.shape(2));
|
||||
auto width = static_cast<int>(upscale.shape(3));
|
||||
|
@ -32,6 +32,8 @@ class TensorToImg : public MMEdit {
|
|||
mat_hwc.convertTo(rescale_uint8, CV_8UC(channels), 255.f);
|
||||
return mat;
|
||||
} else {
|
||||
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", upscale.shape(),
|
||||
(int)upscale.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -61,7 +61,9 @@ class CTCConvertor : public MMOCR {
|
|||
Result<Value> operator()(const Value& _data, const Value& _prob) {
|
||||
auto d_conf = _prob["output"].get<Tensor>();
|
||||
|
||||
if (d_conf.data_type() != DataType::kFLOAT) {
|
||||
if (!(d_conf.shape().size() == 3 && d_conf.data_type() == DataType::kFLOAT)) {
|
||||
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", d_conf.shape(),
|
||||
(int)d_conf.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
|
|
|
@ -49,6 +49,12 @@ class DBHead : public MMOCR {
|
|||
OUTCOME_TRY(stream_.Wait());
|
||||
DEBUG("shape: {}", conf.shape());
|
||||
|
||||
if (!(conf.shape().size() == 4 && conf.data_type() == DataType::kFLOAT)) {
|
||||
ERROR("unsupported `output` tensor, shape: {}, dtype: {}", conf.shape(),
|
||||
(int)conf.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
auto h = conf.shape(2);
|
||||
auto w = conf.shape(3);
|
||||
auto data = conf.buffer().GetNative();
|
||||
|
|
|
@ -24,11 +24,12 @@ class ResizeMask : public MMSegmentation {
|
|||
DEBUG("preprocess: {}\ninference: {}", preprocess_result, inference_result);
|
||||
|
||||
auto mask = inference_result["output"].get<Tensor>();
|
||||
INFO("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(),
|
||||
mask.data_type());
|
||||
assert(mask.data_type() == DataType::kINT32 || mask.data_type() == DataType::kINT64);
|
||||
assert(mask.shape(0) == 1);
|
||||
assert(mask.shape(1) == 1);
|
||||
DEBUG("tensor.name: {}, tensor.shape: {}, tensor.data_type: {}", mask.name(), mask.shape(),
|
||||
mask.data_type());
|
||||
if (!(mask.shape().size() == 4 && mask.shape(0) == 1 && mask.shape(1) == 1)) {
|
||||
ERROR("unsupported `output` tensor, shape: {}", mask.shape());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
auto height = (int)mask.shape(2);
|
||||
auto width = (int)mask.shape(3);
|
||||
|
@ -36,7 +37,7 @@ class ResizeMask : public MMSegmentation {
|
|||
auto input_width = preprocess_result["img_metas"]["ori_shape"][2].get<int>();
|
||||
Device host{"cpu"};
|
||||
OUTCOME_TRY(auto host_tensor, MakeAvailableOnDevice(mask, host, stream_));
|
||||
stream_.Wait().value();
|
||||
OUTCOME_TRY(stream_.Wait());
|
||||
if (mask.data_type() == DataType::kINT64) {
|
||||
// change kINT64 to 2 INT32
|
||||
TensorDesc desc{.device = host_tensor.device(),
|
||||
|
@ -45,8 +46,11 @@ class ResizeMask : public MMSegmentation {
|
|||
.name = host_tensor.name()};
|
||||
Tensor _host_tensor(desc, mask.buffer());
|
||||
return MaskResize(_host_tensor, input_height, input_width);
|
||||
} else {
|
||||
} else if (mask.data_type() == DataType::kINT32) {
|
||||
return MaskResize(host_tensor, input_height, input_width);
|
||||
} else {
|
||||
ERROR("unsupported `output` tensor, dtype: {}", (int)mask.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue