mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Unittest] Add NCNN tensorslice unittest and fix tensorslice.cpp bugs. (#115)
* add tensorslice unittest * reply code review * fix lint * fix typo
This commit is contained in:
parent
f56a30025a
commit
cba43e4c22
@ -53,12 +53,17 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
|
||||
size_t elemsize = bottom_blob.elemsize;
|
||||
const int* start_ptr = starts;
|
||||
const int* end_ptr = ends;
|
||||
const float* axes_ptr = axes;
|
||||
const int* axes_ptr = axes;
|
||||
const int* step_ptr = steps;
|
||||
if (starts.w > dims || ends.w > dims) {
|
||||
fprintf(stderr, "start/end attributes shape error!\n");
|
||||
return -100;
|
||||
}
|
||||
if (axes.w != 1) {
|
||||
fprintf(stderr,
|
||||
"axes.w must be 1 because any of multiaxes slice is regarded as "
|
||||
"multi-staged onnx slice in pytorch2onnx.");
|
||||
}
|
||||
if (dims == 1) {
|
||||
for (int i = 0; i < axes.w; i++) {
|
||||
int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i];
|
||||
@ -106,6 +111,8 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
|
||||
int start = start_ptr[i];
|
||||
int end = end_ptr[i];
|
||||
int dim_shape = get_shape_by_axes(bottom_blob, positive_axis, dims);
|
||||
int dim_shape_test =
|
||||
get_shape_by_axes(bottom_blob, positive_axis, dims - 1);
|
||||
if (dim_shape < 0) {
|
||||
return -1;
|
||||
}
|
||||
@ -127,6 +134,7 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
|
||||
return -100;
|
||||
}
|
||||
active_indice[positive_axis - 1] = temp_indice;
|
||||
active_indice[positive_axis - 1].resize(temp_indice.size());
|
||||
}
|
||||
top_blob.create((int)active_indice[1].size(), (int)active_indice[0].size(),
|
||||
elemsize, opt.blob_allocator);
|
||||
@ -138,6 +146,7 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (dims == 3) {
|
||||
std::vector<std::vector<int> > active_indice;
|
||||
std::vector<int> indices;
|
||||
@ -177,7 +186,8 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
|
||||
fprintf(stderr, "step should not be 0!\n");
|
||||
return -100;
|
||||
}
|
||||
active_indice[positive_axis] = temp_indice;
|
||||
active_indice[positive_axis - 1] = temp_indice;
|
||||
active_indice[positive_axis - 1].resize(temp_indice.size());
|
||||
}
|
||||
top_blob.create((int)active_indice[2].size(), (int)active_indice[1].size(),
|
||||
(int)active_indice[0].size(), elemsize, opt.blob_allocator);
|
||||
@ -192,6 +202,7 @@ int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob,
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
@ -545,3 +545,38 @@ def test_constantofshape(backend,
|
||||
ncnn_outputs = ncnn_model(dict(zip(input_names, [input.float()])))
|
||||
ncnn_outputs = [ncnn_outputs[name] for name in output_names]
|
||||
assert_allclose(model_outputs, ncnn_outputs, tolerate_small_mismatch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_NCNN])
|
||||
@pytest.mark.parametrize('dim', [1, 2, 3])
|
||||
def test_tensorslice(backend, dim, input_list=None, save_dir=None):
|
||||
backend.check_env()
|
||||
|
||||
if input_list is None:
|
||||
input = torch.rand((8, 12, 17)[-dim:]).unsqueeze(0)
|
||||
else:
|
||||
input = input_list[0]
|
||||
assert input.dim() == dim + 1, f'input.dim() must equal to \
|
||||
dim + 1, expected: {dim + 1}, got: {input.dim()}'
|
||||
|
||||
assert input.shape[0] == 1, (f'ncnn batch must be 1, \
|
||||
but got {input.shape[0]}')
|
||||
cfg = dict()
|
||||
register_extra_symbolics(cfg=cfg, backend=backend.backend_name, opset=11)
|
||||
|
||||
def tensorslice_function(inputs):
|
||||
if dim == 1:
|
||||
return inputs[:, 2:17:7]
|
||||
if dim == 2:
|
||||
return inputs[:, 3:12:4, 2:15:3]
|
||||
if dim == 3:
|
||||
return inputs[:, 0:8:2, 2:12:4, 2:17:7]
|
||||
|
||||
wrapped_model = WrapFunction(tensorslice_function)
|
||||
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input.float()],
|
||||
'tensorslice',
|
||||
input_names=['inputs'],
|
||||
output_names=['outputs'],
|
||||
save_dir=save_dir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user