improvement(ViT): use Crop to subtitude Gather (#477)

* improvement(ViT): use Crop to subtitude Gather

* fix(CI): code format

* fix(pytorch/ops/linear.py): bias maybe None

* fix(test/test_pytorch_ops.py): op_type error

* fix(test): pytest error

* fix(test): torch version 1.8
pull/551/head
tpoisonooo 2022-06-02 19:39:15 +08:00 committed by GitHub
parent ee878b539b
commit cd336eada1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 147 additions and 4 deletions

View File

@ -1,6 +1,44 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "fuse_pass.h"
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; ++i) {
onnx::NodeProto* gather = mutable_graph->mutable_node(i);
if (gather->op_type() != "Gather") {
continue;
}
auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
if (indices.size() != 1) {
continue;
}
{
// reconstruct node connections
node_reference[gather->input(1)] -= 1;
std::string origin_inp = gather->input(0);
gather->clear_input();
gather->add_input(origin_inp);
}
{
// update axis, starts and ends
int axis = get_node_attr_i(*gather, "axis", 1) - 1;
gather->set_op_type("Crop");
gather->clear_attribute();
int indice = indices[0];
set_node_attr_ai(*gather, "starts", std::vector<int>{indice});
set_node_attr_ai(*gather, "ends", std::vector<int>{indice + 1});
set_node_attr_ai(*gather, "axis", std::vector<int>{axis});
}
}
}
void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,

View File

@ -4,6 +4,11 @@
#include "shape_inference.h"
#include "utils.h"
void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count);
void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,

View File

@ -229,6 +229,7 @@ int main(int argc, char** argv) {
fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names,
reduced_node_count);
fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
}
// reduce common const weight node_reference
@ -623,6 +624,8 @@ int main(int argc, char** argv) {
}
} else if (op == "Cos") {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Crop") {
fprintf(pp, "%-16s", "Crop");
} else if (op == "DepthToSpace") {
fprintf(pp, "%-16s", "PixelShuffle");
} else if (op == "DetectionOutput") {
@ -1196,6 +1199,22 @@ int main(int argc, char** argv) {
} else if (op == "Cos") {
int op_type = 10;
fprintf(pp, " 0=%d", op_type);
} else if (op == "Crop") {
auto starts = get_node_attr_ai(node, "starts");
fprintf(pp, " -23309=%zu", starts.size());
for (size_t j = 0; j < starts.size(); ++j) {
fprintf(pp, ",%i", starts[j]);
}
auto ends = get_node_attr_ai(node, "ends");
fprintf(pp, " -23310=%zu", ends.size());
for (size_t j = 0; j < ends.size(); ++j) {
fprintf(pp, ",%i", ends[j]);
}
auto axis = get_node_attr_ai(node, "axis");
fprintf(pp, " -23311=%zu", axis.size());
for (size_t j = 0; j < axis.size(); ++j) {
fprintf(pp, ",%i", axis[j]);
}
} else if (op == "DepthToSpace") {
// pixelshuffle
int scale_factor = get_node_attr_i(node, "blocksize", 1);
@ -1287,7 +1306,7 @@ int main(int argc, char** argv) {
}
fprintf(pp, " 0=%d", axis);
} else if (op == "Gelu") {
fprintf(pp, " 0=0");
fprintf(pp, " 0=1");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);

View File

@ -2,6 +2,8 @@
#include "shape_inference.h"
#include <algorithm>
/**
* @brief query output shape of target node
*

View File

@ -1,7 +1,6 @@
// Copyright (c) OpenMMLab. All rights reserved.
#pragma once
#include <algorithm>
#include "utils.h"

View File

@ -25,7 +25,7 @@ def linear__ncnn(
dim = input.dim()
if dim == 2:
if dim == 2 or dim == 3 and input.shape[0] == 1:
return origin_func(input, weight, bias)
else:
out = origin_func(input, weight)

View File

@ -8,6 +8,7 @@ from .grid_sampler import grid_sampler__default
from .hardsigmoid import hardsigmoid__default
from .instance_norm import instance_norm__tensorrt
from .layer_norm import layer_norm__ncnn
from .linear import linear__ncnn
from .lstm import generic_rnn__ncnn
from .squeeze import squeeze__default
@ -16,5 +17,5 @@ __all__ = [
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
'layer_norm__ncnn'
'layer_norm__ncnn', 'linear__ncnn'
]

View File

@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from torch.onnx.symbolic_helper import parse_args
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend
@parse_args('v', 'v', 'f', 'f', 'i', 'i')
def linear_no_bias(g, input, weight):
"""Symbolic function for `linear` without bias.
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
return g.op(
'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1)
@parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i')
def linear_normal(g, input, weight, bias):
"""Symbolic function for `linear`.
PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
"""
return g.op(
'Gemm',
input,
weight,
bias,
alpha_f=1.0,
beta_f=1.0,
transA_i=0,
transB_i=1)
@SYMBOLIC_REWRITER.register_symbolic(
'linear', is_pytorch=True, backend=Backend.NCNN.value)
def linear__ncnn(ctx, g, input, weight, bias):
"""Support export linear This rewrite enable export Gemm."""
if bias is None:
return linear_no_bias(g, input, weight)
else:
return linear_normal(g, input, weight, bias)

View File

@ -127,6 +127,41 @@ def test_instance_norm():
assert nodes[4].domain == 'mmdeploy'
@pytest.mark.usefixtures('prepare_symbolics_ncnn')
class TestLinear:
def check(self, nodes):
print(nodes)
from packaging.version import parse as version_parse
version = version_parse(torch.__version__)
target = 'Gemm'
if version.major <= 1 and version.minor <= 8:
target = 'MatMul'
exist = False
for node in nodes:
if node.op_type == target:
exist = True
break
assert exist is True
def test_normal(self):
x = torch.rand(1, 2, 3)
w = torch.rand(2, 3)
bias = torch.rand(2)
model = OpModel(torch.nn.functional.linear, w, bias).eval()
nodes = get_model_onnx_nodes(model, x)
self.check(nodes)
def test_no_bias(self):
x = torch.rand(1, 2, 3)
w = torch.rand(2, 3)
model = OpModel(torch.nn.functional.linear, w).eval()
nodes = get_model_onnx_nodes(model, x)
self.check(nodes)
@pytest.mark.usefixtures('prepare_symbolics')
class TestSqueeze: