mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add dim param for Tensor::Squeeze
(#603)
This commit is contained in:
parent
a822ba7330
commit
ac0b52f12a
@ -50,7 +50,9 @@ class DBHead : public MMOCR {
|
|||||||
return Status(eNotSupported);
|
return Status(eNotSupported);
|
||||||
}
|
}
|
||||||
|
|
||||||
conf.Squeeze();
|
// drop batch dimension
|
||||||
|
conf.Squeeze(0);
|
||||||
|
|
||||||
conf = conf.Slice(0);
|
conf = conf.Slice(0);
|
||||||
|
|
||||||
std::vector<std::vector<cv::Point>> contours;
|
std::vector<std::vector<cv::Point>> contours;
|
||||||
|
@ -53,7 +53,9 @@ class PANHead : public MMOCR {
|
|||||||
(int)pred.data_type());
|
(int)pred.data_type());
|
||||||
return Status(eNotSupported);
|
return Status(eNotSupported);
|
||||||
}
|
}
|
||||||
pred.Squeeze();
|
|
||||||
|
// drop batch dimension
|
||||||
|
pred.Squeeze(0);
|
||||||
|
|
||||||
auto text_pred = pred.Slice(0);
|
auto text_pred = pred.Slice(0);
|
||||||
auto kernel_pred = pred.Slice(1);
|
auto kernel_pred = pred.Slice(1);
|
||||||
|
@ -51,7 +51,7 @@ class PSEHead : public MMOCR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// drop batch dimension
|
// drop batch dimension
|
||||||
_preds.Squeeze();
|
_preds.Squeeze(0);
|
||||||
|
|
||||||
cv::Mat_<uint8_t> masks;
|
cv::Mat_<uint8_t> masks;
|
||||||
cv::Mat_<int> kernel_labels;
|
cv::Mat_<int> kernel_labels;
|
||||||
|
@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Tensor::Squeeze() {
|
void Tensor::Squeeze() {
|
||||||
TensorShape new_shape;
|
desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end());
|
||||||
new_shape.reserve(shape().size());
|
}
|
||||||
std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape),
|
|
||||||
[](int64_t dim) { return dim != 1; });
|
void Tensor::Squeeze(int dim) {
|
||||||
Reshape(new_shape);
|
if (shape(dim) == 1) {
|
||||||
|
desc_.shape.erase(desc_.shape.begin() + dim);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) {
|
Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) {
|
||||||
|
@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor {
|
|||||||
void Reshape(const TensorShape& shape);
|
void Reshape(const TensorShape& shape);
|
||||||
|
|
||||||
void Squeeze();
|
void Squeeze();
|
||||||
|
void Squeeze(int dim);
|
||||||
|
|
||||||
Tensor Slice(int start, int end);
|
Tensor Slice(int start, int end);
|
||||||
Tensor Slice(int index) { return Slice(index, index + 1); }
|
Tensor Slice(int index) { return Slice(index, index + 1); }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user