add dim param for Tensor::Squeeze (#603)

This commit is contained in:
Li Zhang 2022-06-17 14:06:35 +08:00 committed by GitHub
parent a822ba7330
commit ac0b52f12a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 8 deletions

View File

@ -50,7 +50,9 @@ class DBHead : public MMOCR {
return Status(eNotSupported);
}
conf.Squeeze();
// drop batch dimension
conf.Squeeze(0);
conf = conf.Slice(0);
std::vector<std::vector<cv::Point>> contours;

View File

@ -53,7 +53,9 @@ class PANHead : public MMOCR {
(int)pred.data_type());
return Status(eNotSupported);
}
pred.Squeeze();
// drop batch dimension
pred.Squeeze(0);
auto text_pred = pred.Slice(0);
auto kernel_pred = pred.Slice(1);

View File

@ -51,7 +51,7 @@ class PSEHead : public MMOCR {
}
// drop batch dimension
_preds.Squeeze();
_preds.Squeeze(0);
cv::Mat_<uint8_t> masks;
cv::Mat_<int> kernel_labels;

View File

@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) {
}
void Tensor::Squeeze() {
TensorShape new_shape;
new_shape.reserve(shape().size());
std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape),
[](int64_t dim) { return dim != 1; });
Reshape(new_shape);
desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end());
}
void Tensor::Squeeze(int dim) {
if (shape(dim) == 1) {
desc_.shape.erase(desc_.shape.begin() + dim);
}
}
Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) {

View File

@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor {
void Reshape(const TensorShape& shape);
void Squeeze();
void Squeeze(int dim);
Tensor Slice(int start, int end);
Tensor Slice(int index) { return Slice(index, index + 1); }