diff --git a/csrc/mmdeploy/codebase/mmocr/dbnet.cpp b/csrc/mmdeploy/codebase/mmocr/dbnet.cpp index c4a84006d..5b9d2f0e2 100644 --- a/csrc/mmdeploy/codebase/mmocr/dbnet.cpp +++ b/csrc/mmdeploy/codebase/mmocr/dbnet.cpp @@ -50,7 +50,9 @@ class DBHead : public MMOCR { return Status(eNotSupported); } - conf.Squeeze(); + // drop batch dimension + conf.Squeeze(0); + conf = conf.Slice(0); std::vector> contours; diff --git a/csrc/mmdeploy/codebase/mmocr/panet.cpp b/csrc/mmdeploy/codebase/mmocr/panet.cpp index 9c5ecb281..042d088be 100644 --- a/csrc/mmdeploy/codebase/mmocr/panet.cpp +++ b/csrc/mmdeploy/codebase/mmocr/panet.cpp @@ -53,7 +53,9 @@ class PANHead : public MMOCR { (int)pred.data_type()); return Status(eNotSupported); } - pred.Squeeze(); + + // drop batch dimension + pred.Squeeze(0); auto text_pred = pred.Slice(0); auto kernel_pred = pred.Slice(1); diff --git a/csrc/mmdeploy/codebase/mmocr/psenet.cpp b/csrc/mmdeploy/codebase/mmocr/psenet.cpp index 3b0f2bd5f..19ab31817 100644 --- a/csrc/mmdeploy/codebase/mmocr/psenet.cpp +++ b/csrc/mmdeploy/codebase/mmocr/psenet.cpp @@ -51,7 +51,7 @@ class PSEHead : public MMOCR { } // drop batch dimension - _preds.Squeeze(); + _preds.Squeeze(0); cv::Mat_ masks; cv::Mat_ kernel_labels; diff --git a/csrc/mmdeploy/core/tensor.cpp b/csrc/mmdeploy/core/tensor.cpp index ddbb08252..07fac1ae7 100644 --- a/csrc/mmdeploy/core/tensor.cpp +++ b/csrc/mmdeploy/core/tensor.cpp @@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) { } void Tensor::Squeeze() { - TensorShape new_shape; - new_shape.reserve(shape().size()); - std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape), - [](int64_t dim) { return dim != 1; }); - Reshape(new_shape); + desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end()); +} + +void Tensor::Squeeze(int dim) { + if (shape(dim) == 1) { + desc_.shape.erase(desc_.shape.begin() + dim); + } } Result Tensor::CopyFrom(const Tensor& tensor, Stream stream) { diff --git a/csrc/mmdeploy/core/tensor.h b/csrc/mmdeploy/core/tensor.h index 92403fe38..ef967af9e 100644 --- a/csrc/mmdeploy/core/tensor.h +++ b/csrc/mmdeploy/core/tensor.h @@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor { void Reshape(const TensorShape& shape); void Squeeze(); + void Squeeze(int dim); Tensor Slice(int start, int end); Tensor Slice(int index) { return Slice(index, index + 1); }