mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add dim param for Tensor::Squeeze
(#603)
This commit is contained in:
parent
a822ba7330
commit
ac0b52f12a
@ -50,7 +50,9 @@ class DBHead : public MMOCR {
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
|
||||
conf.Squeeze();
|
||||
// drop batch dimension
|
||||
conf.Squeeze(0);
|
||||
|
||||
conf = conf.Slice(0);
|
||||
|
||||
std::vector<std::vector<cv::Point>> contours;
|
||||
|
@ -53,7 +53,9 @@ class PANHead : public MMOCR {
|
||||
(int)pred.data_type());
|
||||
return Status(eNotSupported);
|
||||
}
|
||||
pred.Squeeze();
|
||||
|
||||
// drop batch dimension
|
||||
pred.Squeeze(0);
|
||||
|
||||
auto text_pred = pred.Slice(0);
|
||||
auto kernel_pred = pred.Slice(1);
|
||||
|
@ -51,7 +51,7 @@ class PSEHead : public MMOCR {
|
||||
}
|
||||
|
||||
// drop batch dimension
|
||||
_preds.Squeeze();
|
||||
_preds.Squeeze(0);
|
||||
|
||||
cv::Mat_<uint8_t> masks;
|
||||
cv::Mat_<int> kernel_labels;
|
||||
|
@ -88,11 +88,13 @@ void Tensor::Reshape(const TensorShape& shape) {
|
||||
}
|
||||
|
||||
void Tensor::Squeeze() {
|
||||
TensorShape new_shape;
|
||||
new_shape.reserve(shape().size());
|
||||
std::copy_if(begin(shape()), end(shape()), std::back_inserter(new_shape),
|
||||
[](int64_t dim) { return dim != 1; });
|
||||
Reshape(new_shape);
|
||||
desc_.shape.erase(std::remove(desc_.shape.begin(), desc_.shape.end(), 1), desc_.shape.end());
|
||||
}
|
||||
|
||||
void Tensor::Squeeze(int dim) {
|
||||
if (shape(dim) == 1) {
|
||||
desc_.shape.erase(desc_.shape.begin() + dim);
|
||||
}
|
||||
}
|
||||
|
||||
Result<void> Tensor::CopyFrom(const Tensor& tensor, Stream stream) {
|
||||
|
@ -47,6 +47,7 @@ class MMDEPLOY_API Tensor {
|
||||
void Reshape(const TensorShape& shape);
|
||||
|
||||
void Squeeze();
|
||||
void Squeeze(int dim);
|
||||
|
||||
Tensor Slice(int start, int end);
|
||||
Tensor Slice(int index) { return Slice(index, index + 1); }
|
||||
|
Loading…
x
Reference in New Issue
Block a user