add dim param for Tensor::Squeeze (#603)

This commit is contained in:
Li Zhang 2022-06-17 14:06:35 +08:00 committed by GitHub
parent a822ba7330
commit ac0b52f12a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 15 additions and 8 deletions

View File

@ -50,7 +50,9 @@ class DBHead : public MMOCR {
return Status(eNotSupported); return Status(eNotSupported);
} }
conf.Squeeze(); // drop batch dimension
conf.Squeeze(0);
conf = conf.Slice(0); conf = conf.Slice(0);
std::vector<std::vector<cv::Point>> contours; std::vector<std::vector<cv::Point>> contours;

View File

@ -53,7 +53,9 @@ class PANHead : public MMOCR {
(int)pred.data_type()); (int)pred.data_type());
return Status(eNotSupported); return Status(eNotSupported);
} }
pred.Squeeze();
// drop batch dimension
pred.Squeeze(0);
auto text_pred = pred.Slice(0); auto text_pred = pred.Slice(0);
auto kernel_pred = pred.Slice(1); auto kernel_pred = pred.Slice(1);

View File

@ -51,7 +51,7 @@ class PSEHead : public MMOCR {
} }
// drop batch dimension // drop batch dimension
_preds.Squeeze(); _preds.Squeeze(0);
cv::Mat_<uint8_t> masks; cv::Mat_<uint8_t> masks;
cv::Mat_<int> kernel_labels; cv::Mat_<int> kernel_labels;

View File

@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) {
} }
void Tensor::Squeeze() { void Tensor::Squeeze() {
TensorShape new_shape; desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end());
new_shape.reserve(shape().size()); }
std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape),
[](int64_t dim) { return dim != 1; }); void Tensor::Squeeze(int dim) {
Reshape(new_shape); if (shape(dim) == 1) {
desc_.shape.erase(desc_.shape.begin() + dim);
}
} }
Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) { Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) {

View File

@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor {
void Reshape(const TensorShape& shape); void Reshape(const TensorShape& shape);
void Squeeze(); void Squeeze();
void Squeeze(int dim);
Tensor Slice(int start, int end); Tensor Slice(int start, int end);
Tensor Slice(int index) { return Slice(index, index + 1); } Tensor Slice(int index) { return Slice(index, index + 1); }