diff --git a/backend_ops/ncnn/ops/expand/expand.cpp b/backend_ops/ncnn/ops/expand/expand.cpp
index d23f45f57..9a4aa7a6f 100755
--- a/backend_ops/ncnn/ops/expand/expand.cpp
+++ b/backend_ops/ncnn/ops/expand/expand.cpp
@@ -115,34 +115,6 @@ int Expand::forward(const std::vector& bottom_blobs,
     }
     return 0;
   }
-  if (bottom_blob.dims == 2 && shape_blob.w == 1) {
-    int shape_0 = (int)(shape_blob[0] + 0.5);
-    if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
-      fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (1, %d)\n",
-              bottom_blob.h, bottom_blob.w, shape_0);
-    } else if (bottom_blob.w == shape_0 || shape_0 == 1) {
-      top_blob.create(bottom_blob.w, bottom_blob.h, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int j = 0; j < bottom_blob.h; j++) {
-        for (int i = 0; i < bottom_blob.w; i++) {
-          top_blob.row(j)[i] = bottom_blob.row(j)[i];
-        }
-      }
-    } else if (bottom_blob.w == 1) {
-      top_blob.create(shape_0, bottom_blob.h, elemsize, opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int j = 0; j < bottom_blob.h; j++) {
-        for (int i = 0; i < shape_0; i++) {
-          top_blob.row(j)[i] = bottom_blob.row(j)[0];
-        }
-      }
-    } else {
-      fprintf(stderr, "error case\n");
-      return -100;
-    }
-    return 0;
-  }
   if (bottom_blob.dims == 2 && shape_blob.w == 2) {
     int shape_0 = (int)(shape_blob[0] + 0.5);
     int shape_1 = (int)(shape_blob[1] + 0.5);
@@ -258,110 +230,6 @@ int Expand::forward(const std::vector& bottom_blobs,
     }
     return 0;
   }
-  if (bottom_blob.dims == 3 && shape_blob.w == 1) {
-    int shape_0 = (int)(shape_blob[0] + 0.5);
-    if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) {
-      fprintf(stderr,
-              "The broadcast rule is wrong, (%d, %d, %d) vs (1, 1, %d)\n",
-              bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0);
-    } else if (bottom_blob.w == shape_0 || shape_0 == 1) {
-      top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int k = 0; k < bottom_blob.c; k++) {
-        for (int j = 0; j < bottom_blob.h; j++) {
-          for (int i = 0; i < bottom_blob.w; i++) {
-            top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
-          }
-        }
-      }
-
-    } else if (bottom_blob.w == 1) {
-      top_blob.create(shape_0, bottom_blob.h, bottom_blob.c, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int k = 0; k < bottom_blob.c; k++) {
-        for (int j = 0; j < bottom_blob.h; j++) {
-          for (int i = 0; i < shape_0; i++) {
-            top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
-          }
-        }
-      }
-
-    } else {
-      fprintf(stderr, "error case\n");
-      return -100;
-    }
-    return 0;
-  }
-  if (bottom_blob.dims == 3 && shape_blob.w == 2) {
-    int shape_0 = (int)(shape_blob[0] + 0.5);
-    int shape_1 = (int)(shape_blob[1] + 0.5);
-    if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) {
-      fprintf(stderr,
-              "The broadcast rule is wrong, (%d, %d, %d) vs (1, %d, %d)\n",
-              bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1);
-    } else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) {
-      fprintf(stderr,
-              "The broadcast rule is wrong, (%d, %d, %d) vs (1, %d, %d)\n",
-              bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1);
-    } else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
-               (bottom_blob.h == shape_0 || shape_0 == 1)) {
-      top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int k = 0; k < bottom_blob.c; k++) {
-        for (int j = 0; j < bottom_blob.h; j++) {
-          for (int i = 0; i < bottom_blob.w; i++) {
-            top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i];
-          }
-        }
-      }
-
-    } else if ((bottom_blob.w == shape_1 || shape_1 == 1) &&
-               (bottom_blob.h == 1)) {
-      top_blob.create(bottom_blob.w, shape_0, bottom_blob.c, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int k = 0; k < bottom_blob.c; k++) {
-        for (int j = 0; j < shape_0; j++) {
-          for (int i = 0; i < bottom_blob.w; i++) {
-            top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i];
-          }
-        }
-      }
-
-    } else if ((bottom_blob.w == 1) &&
-               (bottom_blob.h == shape_0 || shape_0 == 1)) {
-      top_blob.create(shape_1, bottom_blob.h, bottom_blob.c, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int k = 0; k < bottom_blob.c; k++) {
-        for (int j = 0; j < bottom_blob.h; j++) {
-          for (int i = 0; i < shape_1; i++) {
-            top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0];
-          }
-        }
-      }
-
-    } else if (bottom_blob.h == 1 && bottom_blob.w == 1) {
-      top_blob.create(shape_1, shape_0, bottom_blob.c, elemsize,
-                      opt.blob_allocator);
-      if (top_blob.empty()) return -100;
-      for (int k = 0; k < bottom_blob.c; k++) {
-        for (int j = 0; j < shape_0; j++) {
-          for (int i = 0; i < shape_1; i++) {
-            top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0];
-          }
-        }
-      }
-
-    } else {
-      fprintf(stderr, "error case\n");
-      return -100;
-    }
-    return 0;
-  }
   if (bottom_blob.dims == 3 && shape_blob.w == 3) {
     int shape_0 = (int)(shape_blob[0] + 0.5);
     int shape_1 = (int)(shape_blob[1] + 0.5);
diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py
index 729d46589..f847788c7 100644
--- a/tests/test_ops/test_ops.py
+++ b/tests/test_ops/test_ops.py
@@ -610,3 +610,37 @@ def test_tensorslice(backend, dim, input_list=None, save_dir=None):
         input_names=['inputs'],
         output_names=['outputs'],
         save_dir=save_dir)
+
+
+@pytest.mark.parametrize('backend', [TEST_NCNN])
+@pytest.mark.parametrize('input_dim, output_dim', [(1, 1), (1, 2), (1, 3),
+                                                    (2, 2), (2, 3), (3, 3)])
+def test_expand(backend,
+                input_dim,
+                output_dim,
+                input_list=None,
+                save_dir=None):
+    backend.check_env()
+    if input_list is None:
+        input = torch.rand((1, 12, 1)[-input_dim:]).unsqueeze(0)
+        target = torch.rand((8, 12, 17)[-output_dim:]).unsqueeze(0)
+    else:
+        input = input_list[0]
+        target = input_list[1]
+    assert input.shape[0] == 1, (f'ncnn batch must be 1, \
+        but not {input.shape[0]}')
+    assert target.shape[0] == 1, (f'ncnn batch must be 1, \
+        but not {target.shape[0]}')
+    cfg = dict()
+    register_extra_symbolics(cfg=cfg, backend=backend.backend_name, opset=11)
+
+    def expand_function(input, target):
+        return input.expand_as(target)
+
+    wrapped_model = WrapFunction(expand_function)
+    backend.run_and_validate(
+        wrapped_model, [input.float(), target.float()],
+        'expand',
+        input_names=['input', 'shape'],
+        output_names=['output'],
+        save_dir=save_dir)
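
Not part of the patch: a minimal sketch, assuming only a standard PyTorch install, of the broadcast pattern the new test_expand case exercises with its default input_dim=3, output_dim=3 tensors. The shapes below mirror the ones built in the test above.

    # Illustration only: expand_as broadcasts the size-1 dims of `input`
    # so that its shape matches `target`, which is what the ncnn Expand
    # op is validated against in test_expand.
    import torch

    input = torch.rand((1, 12, 1)).unsqueeze(0)    # shape (1, 1, 12, 1)
    target = torch.rand((8, 12, 17)).unsqueeze(0)  # shape (1, 8, 12, 17)
    output = input.expand_as(target)               # size-1 dims expand to 8 and 17
    assert output.shape == target.shape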