mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Unittest] add ncnn unittest expand and fix expand.cpp bugs. (#118)
* add ncnn unittest expand * rollback utils.py * remove figures
This commit is contained in:
parent
d4828c7836
commit
07cb78bb7c
@ -115,34 +115,6 @@ int Expand::forward(const std::vector<Mat>& bottom_blobs,
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
if (bottom_blob.dims == 2 && shape_blob.w == 1) {
|
||||
int shape_0 = (int)(shape_blob[0] + 0.5);
|
||||
if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
|
||||
fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (1, %d)\n",
|
||||
bottom_blob.h, bottom_blob.w, shape_0);
|
||||
} else if (bottom_blob.w == shape_0 || shape_0 == 1) {
|
||||
top_blob.create(bottom_blob.w, bottom_blob.h, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int j = 0; j < bottom_blob.h; j++) {
|
||||
for (int i = 0; i < bottom_blob.w; i++) {
|
||||
top_blob.row(j)[i] = bottom_blob.row(j)[i];
|
||||
}
|
||||
}
|
||||
} else if (bottom_blob.w == 1) {
|
||||
top_blob.create(shape_0, bottom_blob.h, elemsize, opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int j = 0; j < bottom_blob.h; j++) {
|
||||
for (int i = 0; i < shape_0; i++) {
|
||||
top_blob.row(j)[i] = bottom_blob.row(j)[0];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fprintf(stderr, "error case\n");
|
||||
return -100;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
if (bottom_blob.dims == 2 && shape_blob.w == 2) {
|
||||
int shape_0 = (int)(shape_blob[0] + 0.5);
|
||||
int shape_1 = (int)(shape_blob[1] + 0.5);
|
||||
@ -258,110 +230,6 @@ int Expand::forward(const std::vector<Mat>& bottom_blobs,
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
if (bottom_blob.dims == 3 && shape_blob.w == 1) {
|
||||
int shape_0 = (int)(shape_blob[0] + 0.5);
|
||||
if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
|
||||
fprintf(stderr,
|
||||
"The broadcast rule is wrong, (%d, %d, %d) vs (1, 1, %d)\n",
|
||||
bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0);
|
||||
} else if (bottom_blob.w == shape_0 || shape_0 == 1) {
|
||||
top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int k = 0; k < bottom_blob.c; k++) {
|
||||
for (int j = 0; j < bottom_blob.h; j++) {
|
||||
for (int i = 0; i < bottom_blob.w; i++) {
|
||||
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if (bottom_blob.w == 1) {
|
||||
top_blob.create(shape_0, bottom_blob.h, bottom_blob.c, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int k = 0; k < bottom_blob.c; k++) {
|
||||
for (int j = 0; j < bottom_blob.h; j++) {
|
||||
for (int i = 0; i < shape_0; i++) {
|
||||
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
fprintf(stderr, "error case\n");
|
||||
return -100;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
if (bottom_blob.dims == 3 && shape_blob.w == 2) {
|
||||
int shape_0 = (int)(shape_blob[0] + 0.5);
|
||||
int shape_1 = (int)(shape_blob[1] + 0.5);
|
||||
if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) {
|
||||
fprintf(stderr,
|
||||
"The broadcast rule is wrong, (%d, %d, %d) vs (1, %d, %d)\n",
|
||||
bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1);
|
||||
} else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) {
|
||||
fprintf(stderr,
|
||||
"The broadcast rule is wrong, (%d, %d, %d) vs (1, %d, %d)\n",
|
||||
bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1);
|
||||
} else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
|
||||
(bottom_blob.h == shape_0 || shape_0 == 1)) {
|
||||
top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int k = 0; k < bottom_blob.c; k++) {
|
||||
for (int j = 0; j < bottom_blob.h; j++) {
|
||||
for (int i = 0; i < bottom_blob.w; i++) {
|
||||
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
|
||||
(bottom_blob.h == 1)) {
|
||||
top_blob.create(bottom_blob.w, shape_0, bottom_blob.c, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int k = 0; k < bottom_blob.c; k++) {
|
||||
for (int j = 0; j < shape_0; j++) {
|
||||
for (int i = 0; i < bottom_blob.w; i++) {
|
||||
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if ((bottom_blob.w == 1) &&
|
||||
(bottom_blob.h == shape_0 || shape_0 == 1)) {
|
||||
top_blob.create(shape_1, bottom_blob.h, bottom_blob.c, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int k = 0; k < bottom_blob.c; k++) {
|
||||
for (int j = 0; j < bottom_blob.h; j++) {
|
||||
for (int i = 0; i < shape_1; i++) {
|
||||
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else if (bottom_blob.h == 1 && bottom_blob.w == 1) {
|
||||
top_blob.create(shape_1, shape_0, bottom_blob.c, elemsize,
|
||||
opt.blob_allocator);
|
||||
if (top_blob.empty()) return -100;
|
||||
for (int k = 0; k < bottom_blob.c; k++) {
|
||||
for (int j = 0; j < shape_0; j++) {
|
||||
for (int i = 0; i < shape_1; i++) {
|
||||
top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
fprintf(stderr, "error case\n");
|
||||
return -100;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
if (bottom_blob.dims == 3 && shape_blob.w == 3) {
|
||||
int shape_0 = (int)(shape_blob[0] + 0.5);
|
||||
int shape_1 = (int)(shape_blob[1] + 0.5);
|
||||
|
@ -610,3 +610,37 @@ def test_tensorslice(backend, dim, input_list=None, save_dir=None):
|
||||
input_names=['inputs'],
|
||||
output_names=['outputs'],
|
||||
save_dir=save_dir)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('backend', [TEST_NCNN])
|
||||
@pytest.mark.parametrize('input_dim, output_dim', [(1, 1), (1, 2), (1, 3),
|
||||
(2, 2), (2, 3), (3, 3)])
|
||||
def test_expand(backend,
|
||||
input_dim,
|
||||
output_dim,
|
||||
input_list=None,
|
||||
save_dir=None):
|
||||
backend.check_env()
|
||||
if input_list is None:
|
||||
input = torch.rand((1, 12, 1)[-input_dim:]).unsqueeze(0)
|
||||
target = torch.rand((8, 12, 17)[-output_dim:]).unsqueeze(0)
|
||||
else:
|
||||
input = input_list[0]
|
||||
target = input_list[1]
|
||||
assert input.shape[0] == 1, (f'ncnn batch must be 1, \
|
||||
but not {input.shape[0]}')
|
||||
assert target.shape[0] == 1, (f'ncnn batch must be 1, \
|
||||
but not {target.shape[0]}')
|
||||
cfg = dict()
|
||||
register_extra_symbolics(cfg=cfg, backend=backend.backend_name, opset=11)
|
||||
|
||||
def expand_function(input, target):
|
||||
return input.expand_as(target)
|
||||
|
||||
wrapped_model = WrapFunction(expand_function)
|
||||
backend.run_and_validate(
|
||||
wrapped_model, [input.float(), target.float()],
|
||||
'expand',
|
||||
input_names=['input', 'shape'],
|
||||
output_names=['output'],
|
||||
save_dir=save_dir)
|
||||
|
Loading…
x
Reference in New Issue
Block a user