improvement(ViT): use Crop to substitute Gather (#477)
* improvement(ViT): use Crop to substitute Gather
* fix(CI): code format
* fix(pytorch/ops/linear.py): bias may be None
* fix(test/test_pytorch_ops.py): op_type error
* fix(test): pytest error
* fix(test): torch version 1.8
parent ee878b539b
commit cd336eada1
@ -1,6 +1,44 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 #include "fuse_pass.h"

+void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
+                         std::map<std::string, onnx::TensorProto>& weights,
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names, int& reduced_node_count) {
+  const int node_count = mutable_graph->node_size();
+  for (int i = 0; i < node_count; ++i) {
+    onnx::NodeProto* gather = mutable_graph->mutable_node(i);
+    if (gather->op_type() != "Gather") {
+      continue;
+    }
+    auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]);
+    if (indices.size() != 1) {
+      continue;
+    }
+
+    {
+      // reconstruct node connections
+      node_reference[gather->input(1)] -= 1;
+      std::string origin_inp = gather->input(0);
+      gather->clear_input();
+      gather->add_input(origin_inp);
+    }
+
+    {
+      // update axis, starts and ends
+      int axis = get_node_attr_i(*gather, "axis", 1) - 1;
+
+      gather->set_op_type("Crop");
+      gather->clear_attribute();
+
+      int indice = indices[0];
+      set_node_attr_ai(*gather, "starts", std::vector<int>{indice});
+      set_node_attr_ai(*gather, "ends", std::vector<int>{indice + 1});
+      set_node_attr_ai(*gather, "axis", std::vector<int>{axis});
+    }
+  }
+}
+
 void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
                          std::map<std::string, int>& node_reference,
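Note on the rewrite above: a Gather that picks a single index i along axis k reads the same elements as a slice [i, i+1) on that axis, which is what ncnn's Crop layer expresses through starts/ends/axis; the `axis - 1` shift reflects that ncnn blobs carry no explicit batch dimension, and the node_reference decrement lets the now-unused indices constant be pruned later. A minimal NumPy sketch of the equivalence (illustrative shapes, not part of the patch):

import numpy as np

# e.g. ViT tokens (batch, num_tokens, embed_dim); pick the cls token
x = np.random.rand(1, 197, 768)

gathered = np.take(x, indices=0, axis=1)  # ONNX Gather, single index -> (1, 768)
cropped = x[:, 0:1, :]                    # Crop: starts=[0], ends=[1] -> (1, 1, 768)

# same data, up to the singleton axis that Crop keeps
assert np.array_equal(gathered, cropped.squeeze(axis=1))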
@ -4,6 +4,11 @@
 #include "shape_inference.h"
 #include "utils.h"

+void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
+                         std::map<std::string, onnx::TensorProto>& weights,
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names, int& reduced_node_count);
+
 void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
                          std::map<std::string, int>& node_reference,
@ -229,6 +229,7 @@ int main(int argc, char** argv) {
     fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
     fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names,
                               reduced_node_count);
+    fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
   }

   // reduce common const weight node_reference
@ -623,6 +624,8 @@ int main(int argc, char** argv) {
     }
   } else if (op == "Cos") {
     fprintf(pp, "%-16s", "UnaryOp");
+  } else if (op == "Crop") {
+    fprintf(pp, "%-16s", "Crop");
   } else if (op == "DepthToSpace") {
     fprintf(pp, "%-16s", "PixelShuffle");
   } else if (op == "DetectionOutput") {
@ -1196,6 +1199,22 @@ int main(int argc, char** argv) {
   } else if (op == "Cos") {
     int op_type = 10;
     fprintf(pp, " 0=%d", op_type);
+  } else if (op == "Crop") {
+    auto starts = get_node_attr_ai(node, "starts");
+    fprintf(pp, " -23309=%zu", starts.size());
+    for (size_t j = 0; j < starts.size(); ++j) {
+      fprintf(pp, ",%i", starts[j]);
+    }
+    auto ends = get_node_attr_ai(node, "ends");
+    fprintf(pp, " -23310=%zu", ends.size());
+    for (size_t j = 0; j < ends.size(); ++j) {
+      fprintf(pp, ",%i", ends[j]);
+    }
+    auto axis = get_node_attr_ai(node, "axis");
+    fprintf(pp, " -23311=%zu", axis.size());
+    for (size_t j = 0; j < axis.size(); ++j) {
+      fprintf(pp, ",%i", axis[j]);
+    }
   } else if (op == "DepthToSpace") {
     // pixelshuffle
     int scale_factor = get_node_attr_i(node, "blocksize", 1);
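The negative ids here follow ncnn's convention for array-valued layer parameters: the array form of parameter id is written as -23300 - id with the value serialized as "count,v0,v1,...", and ids 9/10/11 are Crop's starts/ends/axes, hence -23309/-23310/-23311. A small Python sketch of what these fprintf calls emit (hypothetical helper, assuming that convention):

def crop_param_fragment(starts, ends, axes):
    # ncnn array parameter: id -> -23300 - id, value -> "count,v0,v1,..."
    out = ''
    for pid, values in ((9, starts), (10, ends), (11, axes)):
        out += ' -%d=%d' % (23300 + pid, len(values))
        out += ''.join(',%d' % v for v in values)
    return out

print(crop_param_fragment([0], [1], [0]))
# -> ' -23309=1,0 -23310=1,1 -23311=1,0'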
@ -1287,7 +1306,7 @@ int main(int argc, char** argv) {
     }
     fprintf(pp, " 0=%d", axis);
   } else if (op == "Gelu") {
-    fprintf(pp, " 0=0");
+    fprintf(pp, " 0=1");
   } else if (op == "Gemm") {
     float alpha = get_node_attr_f(node, "alpha", 1.f);
     float beta = get_node_attr_f(node, "beta", 1.f);
@ -2,6 +2,8 @@

 #include "shape_inference.h"

+#include <algorithm>
+
 /**
  * @brief query output shape of target node
  *
@ -1,7 +1,6 @@
 // Copyright (c) OpenMMLab. All rights reserved.

 #pragma once
-#include <algorithm>

 #include "utils.h"

@ -25,7 +25,7 @@ def linear__ncnn(

     dim = input.dim()

-    if dim == 2:
+    if dim == 2 or dim == 3 and input.shape[0] == 1:
         return origin_func(input, weight, bias)
     else:
         out = origin_func(input, weight)
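The new condition parses as dim == 2 or (dim == 3 and input.shape[0] == 1), since `and` binds tighter than `or`: 2-D inputs, and 3-D inputs whose leading (batch) dimension is 1, keep the plain linear call that the symbolic rewrite below exports as Gemm, while other shapes take the bias-less fallback path. A quick sanity sketch of the gate (illustrative shapes, not part of the patch):

import torch

w = torch.rand(5, 8)
for x in (torch.rand(4, 8), torch.rand(1, 16, 8)):
    dim = x.dim()
    # same expression as in linear__ncnn; `and` is evaluated before `or`
    assert dim == 2 or dim == 3 and x.shape[0] == 1
    assert torch.nn.functional.linear(x, w).shape[-1] == 5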
@ -8,6 +8,7 @@ from .grid_sampler import grid_sampler__default
 from .hardsigmoid import hardsigmoid__default
 from .instance_norm import instance_norm__tensorrt
 from .layer_norm import layer_norm__ncnn
+from .linear import linear__ncnn
 from .lstm import generic_rnn__ncnn
 from .squeeze import squeeze__default

@ -16,5 +17,5 @@ __all__ = [
     'adaptive_avg_pool3d__default', 'grid_sampler__default',
     'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
     'squeeze__default', 'adaptive_avg_pool2d__ncnn', 'gelu__ncnn',
-    'layer_norm__ncnn'
+    'layer_norm__ncnn', 'linear__ncnn'
 ]
@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from:
+# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
+from torch.onnx.symbolic_helper import parse_args
+
+from mmdeploy.core import SYMBOLIC_REWRITER
+from mmdeploy.utils import Backend
+
+
+@parse_args('v', 'v', 'f', 'f', 'i', 'i')
+def linear_no_bias(g, input, weight):
+    """Symbolic function for `linear` without bias.
+
+    PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
+    """
+    return g.op(
+        'Gemm', input, weight, alpha_f=1.0, beta_f=1.0, transA_i=0, transB_i=1)
+
+
+@parse_args('v', 'v', 'v', 'f', 'f', 'i', 'i')
+def linear_normal(g, input, weight, bias):
+    """Symbolic function for `linear`.
+
+    PyTorch `nn.Linear` will be exported as ONNX node 'Gemm'.
+    """
+    return g.op(
+        'Gemm',
+        input,
+        weight,
+        bias,
+        alpha_f=1.0,
+        beta_f=1.0,
+        transA_i=0,
+        transB_i=1)
+
+
+@SYMBOLIC_REWRITER.register_symbolic(
+    'linear', is_pytorch=True, backend=Backend.NCNN.value)
+def linear__ncnn(ctx, g, input, weight, bias):
+    """Support export of `linear`. This rewrite enables exporting Gemm."""
+    if bias is None:
+        return linear_no_bias(g, input, weight)
+    else:
+        return linear_normal(g, input, weight, bias)
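The Gemm attributes match how nn.Linear stores its weight: ONNX Gemm computes Y = alpha * A' * B' + beta * C, so alpha = beta = 1.0, transA = 0 and transB = 1 give y = x @ w.T + b for a weight of shape (out_features, in_features). A NumPy check of that formula (a sketch, not part of the patch):

import numpy as np

x = np.random.rand(2, 3).astype(np.float32)  # (batch, in_features)
w = np.random.rand(4, 3).astype(np.float32)  # (out_features, in_features), as nn.Linear stores it
b = np.random.rand(4).astype(np.float32)

# Gemm(alpha=1.0, beta=1.0, transA=0, transB=1): Y = x @ w.T + b
y = x @ w.T + b
assert y.shape == (2, 4)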
@ -127,6 +127,41 @@ def test_instance_norm():
     assert nodes[4].domain == 'mmdeploy'


+@pytest.mark.usefixtures('prepare_symbolics_ncnn')
+class TestLinear:
+
+    def check(self, nodes):
+        print(nodes)
+
+        from packaging.version import parse as version_parse
+        version = version_parse(torch.__version__)
+        target = 'Gemm'
+        if version.major <= 1 and version.minor <= 8:
+            target = 'MatMul'
+        exist = False
+        for node in nodes:
+            if node.op_type == target:
+                exist = True
+                break
+
+        assert exist is True
+
+    def test_normal(self):
+        x = torch.rand(1, 2, 3)
+        w = torch.rand(2, 3)
+        bias = torch.rand(2)
+        model = OpModel(torch.nn.functional.linear, w, bias).eval()
+        nodes = get_model_onnx_nodes(model, x)
+        self.check(nodes)
+
+    def test_no_bias(self):
+        x = torch.rand(1, 2, 3)
+        w = torch.rand(2, 3)
+        model = OpModel(torch.nn.functional.linear, w).eval()
+        nodes = get_model_onnx_nodes(model, x)
+        self.check(nodes)
+
+
 @pytest.mark.usefixtures('prepare_symbolics')
 class TestSqueeze:
