[Unittest] add ncnn unittest expand and fix expand.cpp bugs. (#118)

* add ncnn unittest expand

* rollback utils.py

* remove figures
This commit is contained in:
hanrui1sensetime 2021-10-14 10:24:14 +08:00 committed by GitHub
parent d4828c7836
commit 07cb78bb7c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 132 deletions

View File

@ -115,34 +115,6 @@ int Expand::forward(const std::vector<Mat>& bottom_blobs,
}
return 0;
}
if (bottom_blob.dims == 2 && shape_blob.w == 1) {
int shape_0 = (int)(shape_blob[0] + 0.5);
if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (1, %d)\n",
bottom_blob.h, bottom_blob.w, shape_0);
} else if (bottom_blob.w == shape_0 || shape_0 == 1) {
top_blob.create(bottom_blob.w, bottom_blob.h, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.row(j)[i] = bottom_blob.row(j)[i];
}
}
} else if (bottom_blob.w == 1) {
top_blob.create(shape_0, bottom_blob.h, elemsize, opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_0; i++) {
top_blob.row(j)[i] = bottom_blob.row(j)[0];
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
}
if (bottom_blob.dims == 2 && shape_blob.w == 2) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
@ -258,110 +230,6 @@ int Expand::forward(const std::vector<Mat>& bottom_blobs,
}
return 0;
}
if (bottom_blob.dims == 3 && shape_blob.w == 1) {
int shape_0 = (int)(shape_blob[0] + 0.5);
if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
fprintf(stderr,
"The broadcast rule is wrong, (%d, %d, %d) vs (1, 1, %d)\n",
bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0);
} else if (bottom_blob.w == shape_0 || shape_0 == 1) {
top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
}
}
}
} else if (bottom_blob.w == 1) {
top_blob.create(shape_0, bottom_blob.h, bottom_blob.c, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_0; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
}
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
}
if (bottom_blob.dims == 3 && shape_blob.w == 2) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);
if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) {
fprintf(stderr,
"The broadcast rule is wrong, (%d, %d, %d) vs (1, %d, %d)\n",
bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1);
} else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) {
fprintf(stderr,
"The broadcast rule is wrong, (%d, %d, %d) vs (1, %d, %d)\n",
bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1);
} else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
(bottom_blob.h == shape_0 || shape_0 == 1)) {
top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
}
}
}
} else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
(bottom_blob.h == 1)) {
top_blob.create(bottom_blob.w, shape_0, bottom_blob.c, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < shape_0; j++) {
for (int i = 0; i < bottom_blob.w; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i];
}
}
}
} else if ((bottom_blob.w == 1) &&
(bottom_blob.h == shape_0 || shape_0 == 1)) {
top_blob.create(shape_1, bottom_blob.h, bottom_blob.c, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < bottom_blob.h; j++) {
for (int i = 0; i < shape_1; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
}
}
}
} else if (bottom_blob.h == 1 && bottom_blob.w == 1) {
top_blob.create(shape_1, shape_0, bottom_blob.c, elemsize,
opt.blob_allocator);
if (top_blob.empty()) return -100;
for (int k = 0; k < bottom_blob.c; k++) {
for (int j = 0; j < shape_0; j++) {
for (int i = 0; i < shape_1; i++) {
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0];
}
}
}
} else {
fprintf(stderr, "error case\n");
return -100;
}
return 0;
}
if (bottom_blob.dims == 3 && shape_blob.w == 3) {
int shape_0 = (int)(shape_blob[0] + 0.5);
int shape_1 = (int)(shape_blob[1] + 0.5);

View File

@ -610,3 +610,37 @@ def test_tensorslice(backend, dim, input_list=None, save_dir=None):
input_names=['inputs'],
output_names=['outputs'],
save_dir=save_dir)
@pytest.mark.parametrize('backend', [TEST_NCNN])
@pytest.mark.parametrize('input_dim, output_dim', [(1, 1), (1, 2), (1, 3),
(2, 2), (2, 3), (3, 3)])
def test_expand(backend,
input_dim,
output_dim,
input_list=None,
save_dir=None):
backend.check_env()
if input_list is None:
input = torch.rand((1, 12, 1)[-input_dim:]).unsqueeze(0)
target = torch.rand((8, 12, 17)[-output_dim:]).unsqueeze(0)
else:
input = input_list[0]
target = input_list[1]
assert input.shape[0] == 1, (f'ncnn batch must be 1, \
but not {input.shape[0]}')
assert target.shape[0] == 1, (f'ncnn batch must be 1, \
but not {target.shape[0]}')
cfg = dict()
register_extra_symbolics(cfg=cfg, backend=backend.backend_name, opset=11)
def expand_function(input, target):
return input.expand_as(target)
wrapped_model = WrapFunction(expand_function)
backend.run_and_validate(
wrapped_model, [input.float(), target.float()],
'expand',
input_names=['input', 'shape'],
output_names=['output'],
save_dir=save_dir)