feat(codebase/cls): support vision_transformer (#403)

* feat(codebase/cls): support vision_transformer
* style(onnx2ncnn): format cpp code, upgrade mmcls version
* fix(CI): upgrade mmcv to 1.4.2
* fix(onnx2ncnn): offset out of range during fuse conv reshape
* docs(vision_transformer.py): update VisionTransformer desc
* docs(onnx2ncnn.cpp): add more comment
* feat(onnx2ncnn.cpp): revert fuse weight
* docs(onnx2ncnn.cpp): add more comment
* test(vision_transformer): add test case
* refactor(vision_transformer.py): use symbol rewrite layer_norm
* refactor(vision_transformer): fix review
* fix(attention): add missing files

parent 01a44c00c9
commit 2c2d1e5ad9
@@ -23,7 +23,7 @@ jobs:
     matrix:
       python-version: [3.7]
       torch: [1.8.0, 1.9.0]
-      mmcv: [1.4.0]
+      mmcv: [1.4.2]
       include:
         - torch: 1.8.0
           torch_version: torch1.8
@@ -65,7 +65,7 @@ jobs:
     matrix:
       python-version: [3.7]
       torch: [1.9.0+cu102]
-      mmcv: [1.4.0]
+      mmcv: [1.4.2]
       include:
         - torch: 1.9.0+cu102
           torch_version: torch1.9
@@ -108,7 +108,7 @@ jobs:
     matrix:
       python-version: [3.7]
       torch: [1.8.0+cu111]
-      mmcv: [1.4.0]
+      mmcv: [1.4.2]
      include:
         - torch: 1.8.0+cu111
           torch_version: torch1.8
@@ -28,6 +28,7 @@
 #include <iostream>
 #include <limits>
 #include <set>
+#include <tuple>
 
 #include "onnx.pb.h"
 
@@ -73,6 +74,17 @@ static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char
   return v;
 }
 
+static void set_node_attr_ai(onnx::NodeProto& node, const char* key,
+                             const std::vector<int>& value) {
+  onnx::AttributeProto* attr_group = node.add_attribute();
+  attr_group->set_name(key);
+  for (auto v : value) {
+    attr_group->add_ints(v);
+  }
+
+  return;
+}
+
 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key) {
   std::vector<float> v;
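The new helper writes an INTS attribute onto a node; the fuse_conv_reshape pass below uses it to stash a statically computed shape directly on a Reshape node (a private convention of this converter — standard opset >= 5 Reshape takes the shape as an input, not an attribute). For reference, a rough Python equivalent of the same attribute write, assuming only the stock onnx package:

# Sketch only: the Python counterpart of set_node_attr_ai above.
import onnx
from onnx import helper

node = helper.make_node('Reshape', inputs=['x', 'shape'], outputs=['y'])
# make_attribute infers AttributeProto.INTS from a list of Python ints,
# mirroring attr_group->add_ints(v) in the C++ helper.
node.attribute.append(helper.make_attribute('shape', [1, 768, 12, 12]))
print(node)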
@@ -137,8 +149,9 @@ static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const
   return onnx::TensorProto();
 }
 
-static float get_node_attr_from_input_f(const onnx::TensorProto& tp) {
-  float v = 0.f;
+template <typename T>
+static T get_node_attr_from_input(const onnx::TensorProto& tp) {
+  T v = 0.f;
 
   // float
   if (tp.data_type() == 1) {
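Templating the old float-only helper lets the same code return float or integer scalars from an initializer, which the new Slice handling in query_shape needs. In the ONNX Python API, onnx.numpy_helper covers the same ground; a quick hedged sketch:

# Sketch only: reading a scalar initializer, as get_node_attr_from_input<T> does.
import numpy as np
import onnx
from onnx import numpy_helper

tp = numpy_helper.from_array(np.array(3.0, dtype=np.float32), name='add_three')
value = float(numpy_helper.to_array(tp))  # handles float/int storage uniformly
assert value == 3.0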
@@ -183,7 +196,7 @@ static float get_node_attr_from_input_f(const onnx::TensorProto& tp) {
   } else {
     // fprintf(stderr, "tp.name: %s\n", tp.name().c_str());
     fprintf(stderr, "Unknown data type %d\n", tp.data_type());
-    fprintf(stderr, "get_node_attr_from_input_f\n");
+    fprintf(stderr, "get_node_attr_from_input\n");
     abort();
   }
 
@@ -680,7 +693,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& add_three = weights[node->input(1)];
     if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
 
-    float constant_add_three = get_node_attr_from_input_f(add_three);
+    float constant_add_three = get_node_attr_from_input<float>(add_three);
     if (constant_add_three != 3.f) continue;
 
     onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
@@ -708,8 +721,8 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph,
       const onnx::TensorProto& min_tp = weights[node2->input(1)];
       const onnx::TensorProto& max_tp = weights[node2->input(2)];
 
-      relu6_min = get_node_attr_from_input_f(min_tp);
-      relu6_max = get_node_attr_from_input_f(max_tp);
+      relu6_min = get_node_attr_from_input<float>(min_tp);
+      relu6_max = get_node_attr_from_input<float>(max_tp);
     }
     if (relu6_min != 0.f || relu6_max != 6.f) continue;
 
@@ -722,7 +735,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& div_six = weights[node4->input(1)];
     if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
 
-    float constant_div_six = get_node_attr_from_input_f(div_six);
+    float constant_div_six = get_node_attr_from_input<float>(div_six);
     if (node4->op_type() == "Div" && constant_div_six != 6.f) continue;
     if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
 
@@ -831,7 +844,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& add_three = weights[node->input(1)];
     if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
 
-    float constant_add_three = get_node_attr_from_input_f(add_three);
+    float constant_add_three = get_node_attr_from_input<float>(add_three);
     if (constant_add_three != 3.f) continue;
 
     onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
@@ -857,8 +870,8 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
       const onnx::TensorProto& min_tp = weights[node2->input(1)];
       const onnx::TensorProto& max_tp = weights[node2->input(2)];
 
-      relu6_min = get_node_attr_from_input_f(min_tp);
-      relu6_max = get_node_attr_from_input_f(max_tp);
+      relu6_min = get_node_attr_from_input<float>(min_tp);
+      relu6_max = get_node_attr_from_input<float>(max_tp);
     }
     if (relu6_min != 0.f || relu6_max != 6.f) continue;
 
@@ -867,7 +880,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& div_six = weights[node3->input(1)];
     if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
 
-    float constant_div_six = get_node_attr_from_input_f(div_six);
+    float constant_div_six = get_node_attr_from_input<float>(div_six);
     if (node3->op_type() == "Div" && constant_div_six != 6.f) continue;
     if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
 
@@ -1090,7 +1103,7 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph,
     } else {
       const onnx::TensorProto& min_tp = weights[node2->input(1)];
 
-      clip_min = get_node_attr_from_input_f(min_tp);
+      clip_min = get_node_attr_from_input<float>(min_tp);
     }
 
     // reduce
@@ -1343,7 +1356,7 @@ static void fuse_layernorm(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& pow_two = weights[node3->input(1)];
     if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue;
 
-    float constant_pow_two = get_node_attr_from_input_f(pow_two);
+    float constant_pow_two = get_node_attr_from_input<float>(pow_two);
     if (constant_pow_two != 2.f) continue;
 
     std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");
@@ -1360,7 +1373,7 @@ static void fuse_layernorm(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& add_eps = weights[node5->input(1)];
     if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue;
 
-    float eps = get_node_attr_from_input_f(add_eps);
+    float eps = get_node_attr_from_input<float>(add_eps);
 
     int affine = 0;
     while (i + 8 < node_count) {
@@ -2546,6 +2559,320 @@ static void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
   }
 }
 
+/**
+ * @brief find graph node by output name
+ *
+ * @param graph
+ * @param name
+ * @return onnx::NodeProto*
+ */
+static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph,
+                                                 const std::string& name) {
+  const int input_count = mutable_graph->node_size();
+  for (int i = 0; i < input_count; ++i) {
+    onnx::NodeProto* node = mutable_graph->mutable_node(i);
+
+    for (int j = 0; j < node->output_size(); ++j) {
+      auto output = node->output(j);
+      if (output == name) {
+        return node;
+      }
+    }
+  }
+
+  return nullptr;
+}
+
+/**
+ * @brief query output shape of target node
+ *
+ * @param mutable_graph
+ * @param target
+ * @param weights
+ * @param context <tensor name, shape>
+ * @return std::tuple<bool, std::vector<int>>
+ */
+static std::tuple<bool, std::vector<int>> query_shape(
+    onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
+    const std::map<std::string, onnx::TensorProto>& weights,
+    std::map<std::string, std::vector<int>>& context) {
+  // emplace all input nodes
+  const int input_count = mutable_graph->input_size();
+  for (int i = 0; i < input_count; i++) {
+    auto inp = mutable_graph->input(i);
+    onnx::TypeProto inp_type = inp.type();
+    onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();
+
+    auto dim_size = shape_proto.dim_size();
+    std::vector<int> shape(dim_size);
+    for (int index = 0; index < dim_size; ++index) {
+      shape[index] = shape_proto.dim(index).dim_value();
+    }
+
+    context.emplace(inp.name(), shape);
+  }
+
+  // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes
+  std::vector<onnx::NodeProto*> serial = {target};
+  {
+    std::set<std::string> mark_as_appended = {};
+    while (true) {
+      int start = 0, end = serial.size();
+      for (int i = start; i < end; ++i) {
+        auto node_ptr = serial[i];
+        auto len = node_ptr->input_size();
+
+        for (int j = 0; j < len; ++j) {
+          std::string name = node_ptr->input(j);
+          if (context.find(name) != context.end()) {
+            // input already in context, skip
+            continue;
+          }
+
+          if (weights.find(name) != weights.end()) {
+            // found in weights, extract shape to context
+            auto weight = weights.at(name);
+            std::vector<int> shape;
+            for (auto index = 0; index < weight.dims_size(); ++index) {
+              shape.emplace_back(weight.dims(index));
+            }
+            context.emplace(name, shape);
+            continue;
+          }
+
+          if (mark_as_appended.find(name) != mark_as_appended.end()) {
+            // already marked as appended, skip
+            continue;
+          }
+          // else append it to serialization list
+          auto depend_ptr = find_node_by_output_name(mutable_graph, name);
+          if (depend_ptr == nullptr) {
+            fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
+            return std::make_tuple(false, std::vector<int>{});
+          }
+          mark_as_appended.insert(name);
+          serial.emplace_back(depend_ptr);
+        }
+      }
+
+      if (serial.size() <= end) {
+        // no new node added, quit
+        break;
+      }
+
+      // update start and end position, continue BFS the tree
+      start = end;
+      end = serial.size();
+    }
+  }
+
+  // for each node in serialization list, calculate the output shape
+  {
+    std::reverse(serial.begin(), serial.end());
+    for (auto node : serial) {
+      if (node->op_type() == "Conv") {
+        auto inp = context[node->input(0)];
+        auto weight = context[node->input(1)];
+        assert(inp.size() == 4 and weight.size() == 4);
+
+        int group = get_node_attr_i(*node, "group", 1);
+        assert(group == 1);
+
+        // treat multiple spatial attr as single one
+#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT)        \
+  int ATTR = DEFAULT;                                      \
+  {                                                        \
+    std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
+    if (not _vec.empty()) {                                \
+      ATTR = _vec[0];                                      \
+    }                                                      \
+  }
+
+        EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
+        EXTRACT_REPEATED_PARAM("pads", pad, 0);
+        EXTRACT_REPEATED_PARAM("strides", stride, 1);
+
+#undef EXTRACT_REPEATED_PARAM
+
+        int on = inp[0];
+        int oc = weight[0];
+        int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
+        int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
+        context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});
+
+      } else if (node->op_type() == "Shape") {
+        auto inp = context[node->input(0)];
+        context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});
+
+      } else if (node->op_type() == "Slice") {
+        assert(node->input_size() >= 4);
+
+        auto inp = context[node->input(0)];
+        int start = get_node_attr_from_input<int>(weights.at(node->input(1)));
+        int end = get_node_attr_from_input<int>(weights.at(node->input(2)));
+        int axes = get_node_attr_from_input<int>(weights.at(node->input(3)));
+
+        if (axes != 0) {
+          fprintf(stderr, "Not support axes=%d !\n", axes);
+          return std::make_tuple(false, std::vector<int>{});
+        }
+
+        assert(inp.size() >= end - start);
+        context.emplace(node->output(0), std::vector<int>{inp.begin() + start, inp.begin() + end});
+
+      } else if (node->op_type() == "Concat") {
+        assert(node->input_size() >= 2);
+
+        auto axis = get_node_attr_i(*node, "axis", 0);
+        if (axis != 0) {
+          fprintf(stderr, "Not support axes=%d !\n", axis);
+          return std::make_tuple(false, std::vector<int>{});
+        }
+
+        std::vector<int> inp = context[node->input(0)];
+        std::vector<int> w_data = get_node_attr_from_input_ai(weights.at(node->input(1)));
+
+        // concat data on axis 0
+        inp.insert(inp.end(), w_data.begin(), w_data.end());
+        context.emplace(node->output(0), inp);
+
+      } else {
+        fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str());
+        return std::make_tuple(false, std::vector<int>{});
+      }
+    }
+  }
+
+  assert(context.find(target->output(0)) != context.end());
+  auto target_shape = context[target->output(0)];
+  return std::make_tuple(true, target_shape);
+}
+
+/**
+ * @brief fuse subgraph
+ *
+ * conv - - - - - - - - - - - -> reshape
+ *     \                        /
+ *       shape - slice - concat
+ *
+ * to
+ *
+ * conv --> reshape
+ *
+ * @param mutable_graph
+ * @param weights
+ * @param node_reference
+ * @param blob_names
+ * @param reduced_node_count
+ */
+static void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
+                              std::map<std::string, onnx::TensorProto>& weights,
+                              std::map<std::string, int>& node_reference,
+                              std::set<std::string>& blob_names, int& reduced_node_count) {
+  std::map<std::string, std::vector<int>> shape_context;
+  const int node_count = mutable_graph->node_size();
+
+  for (int i = 0; i < node_count; i++) {
+    onnx::NodeProto* conv = mutable_graph->mutable_node(i);
+
+    if (conv->op_type() != "Conv") {
+      continue;
+    }
+
+    if (i + 4 >= node_count) {
+      continue;
+    }
+
+    onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
+
+    // match [Shape ... Slice, Concat ... Reshape] from near sequence, skip useless Constant
+    std::vector<std::tuple<std::string, onnx::NodeProto**>> candidates = {
+        {"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}};
+
+    int MAX = std::min(10, node_count - i - 1);
+    int pos_candidate = 0;
+
+    for (int j = 0; j < MAX; ++j) {
+      auto node_ptr = mutable_graph->mutable_node(j + i + 1);
+      if (node_ptr->op_type() == "Constant") {
+        continue;
+      }
+      if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) {
+        *(std::get<1>(candidates[pos_candidate])) = node_ptr;
+        pos_candidate++;
+      }
+    }
+
+    if (pos_candidate != candidates.size()) {
+      // the sequence does not match
+      continue;
+    }
+
+    if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 ||
+        node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 ||
+        node_reference[reshape->output(0)] != 1) {
+      continue;
+    }
+
+    // check the connections
+    if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) {
+      continue;
+    }
+    if (slice->input(0) != shape->output(0)) {
+      continue;
+    }
+    if (concat->input(0) != slice->output(0)) {
+      continue;
+    }
+    if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) {
+      continue;
+    }
+
+    // add reshape attr
+    auto result = query_shape(mutable_graph, concat, weights, shape_context);
+    if (!std::get<0>(result)) {
+      continue;
+    }
+    set_node_attr_ai(*reshape, "shape", std::get<1>(result));
+
+    // reconstruct graph
+    {
+      // remove reference
+      node_reference[reshape->input(1)] -= 1;
+      node_reference[concat->input(0)] -= 1;
+      node_reference[slice->input(0)] -= 1;
+      node_reference[shape->input(0)] -= 1;
+
+      // remove tensor/blob on edge
+      blob_names.erase(slice->input(0));
+      blob_names.erase(slice->input(1));
+      blob_names.erase(slice->input(2));
+      blob_names.erase(slice->input(3));
+      weights.erase(slice->input(1));
+      weights.erase(slice->input(2));
+      weights.erase(slice->input(3));
+
+      blob_names.erase(concat->input(0));
+      blob_names.erase(concat->input(1));
+      weights.erase(concat->input(1));
+
+      blob_names.erase(reshape->input(0));
+
+      // update edge
+      shape->clear_input();
+      reshape->clear_input();
+      reshape->add_input(conv->output(0));
+
+      shape->set_op_type("noop_reducedncnn");
+      slice->set_op_type("noop_reducedncnn");
+      concat->set_op_type("noop_reducedncnn");
+
+      reduced_node_count += 3;
+    }
+    i += 3;
+  }
+}
+
 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
                                       std::map<std::string, onnx::TensorProto>& weights,
                                       std::map<std::string, int>& node_reference,
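The pass above replaces a data-dependent reshape with a static one: when the Conv output shape can be computed offline via query_shape, the Shape-Slice-Concat chain that rebuilds it at runtime is dead weight (ncnn cannot execute it anyway), so the fused Reshape carries the target shape as an attribute instead. A rough Python sketch of the same pattern match on a toy node list (illustrative structures only, not the real onnx2ncnn types; the interleaved Constant nodes that the C++ matcher skips are ignored here):

# Illustrative sketch of the Shape->Slice->Concat->Reshape fusion, on a toy IR.
# A node is a dict {'op', 'in', 'out'}; static_shape_of stands in for query_shape.
def fuse_conv_reshape(nodes, static_shape_of):
    fused, i = [], 0
    while i < len(nodes):
        window = nodes[i:i + 5]
        ops = [n['op'] for n in window]
        if ops == ['Conv', 'Shape', 'Slice', 'Concat', 'Reshape'] and \
                window[1]['in'][0] == window[0]['out'][0] and \
                window[4]['in'][0] == window[0]['out'][0]:
            conv, reshape = window[0], dict(window[4])
            # bake the offline-computed shape into the Reshape, drop the chain
            reshape['attrs'] = {'shape': static_shape_of(conv['out'][0])}
            reshape['in'] = [conv['out'][0]]
            fused += [conv, reshape]
            i += 5
        else:
            fused.append(nodes[i])
            i += 1
    return fused

nodes = [
    {'op': 'Conv', 'in': ['x', 'w'], 'out': ['c']},
    {'op': 'Shape', 'in': ['c'], 'out': ['s']},
    {'op': 'Slice', 'in': ['s', 'st', 'en', 'ax'], 'out': ['sl']},
    {'op': 'Concat', 'in': ['sl', 'k'], 'out': ['shape']},
    {'op': 'Reshape', 'in': ['c', 'shape'], 'out': ['y']},
]
print(fuse_conv_reshape(nodes, lambda name: [1, 768, 144]))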
@@ -2563,7 +2890,7 @@ static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
     const onnx::TensorProto& scalar_b = weights[node->input(1)];
     if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue;
 
-    float b = get_node_attr_from_input_f(scalar_b);
+    float b = get_node_attr_from_input<float>(scalar_b);
 
     node_reference[node->input(1)] -= 1;
 
@@ -2763,6 +3090,7 @@ int main(int argc, char** argv) {
 
   // op chain fusion
   int reduced_node_count = 0;
+  fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
   fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
   fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
   fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
@@ -3200,6 +3528,8 @@ int main(int argc, char** argv) {
       fprintf(pp, "%-16s", "UnaryOp");
     } else if (op == "Gather") {
       fprintf(pp, "%-16s", "Gather");
+    } else if (op == "Gelu") {
+      fprintf(pp, "%-16s", "GELU");
     } else if (op == "Gemm") {
       float alpha = get_node_attr_f(node, "alpha", 1.f);
       float beta = get_node_attr_f(node, "beta", 1.f);
@@ -3542,10 +3872,10 @@ int main(int argc, char** argv) {
         max = get_node_attr_f(node, "max", FLT_MAX);
       } else {
         min = weights.find(node.input(1)) != weights.end()
-                  ? get_node_attr_from_input_f(weights[node.input(1)])
+                  ? get_node_attr_from_input<float>(weights[node.input(1)])
                   : -FLT_MAX;
         max = weights.find(node.input(2)) != weights.end()
-                  ? get_node_attr_from_input_f(weights[node.input(2)])
+                  ? get_node_attr_from_input<float>(weights[node.input(2)])
                   : FLT_MAX;
       }
 
@@ -3835,6 +4165,8 @@ int main(int argc, char** argv) {
         fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1);
       }
       fprintf(pp, " 0=%d", axis);
+    } else if (op == "Gelu") {
+      fprintf(pp, " 0=0");
     } else if (op == "Gemm") {
       float alpha = get_node_attr_f(node, "alpha", 1.f);
       float beta = get_node_attr_f(node, "beta", 1.f);
@@ -4405,7 +4737,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim * 3 + j];
+            float vb = wptr[j * embed_dim * 3 + k];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
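This hunk and the seven that follow fix the same bug in the MultiHeadAttention weight serialization: swapping j and k in the flat index is equivalent to transposing the matrix being written. A small numpy illustration of why w[k * N + j] under a "for j: for k:" loop emits W.T while w[j * N + k] emits W (toy sizes, not the actual qkv layout):

import numpy as np

N = 4
W = np.arange(N * N, dtype=np.float32).reshape(N, N)
flat = W.ravel()  # row-major buffer, like wptr

# old indexing: inner k walks down rows for a fixed column j -> transposed
old = np.array([flat[k * N + j] for j in range(N) for k in range(N)]).reshape(N, N)
# new indexing: inner k walks along one row j -> original matrix
new = np.array([flat[j * N + k] for j in range(N) for k in range(N)]).reshape(N, N)

assert (old == W.T).all() and (new == W).all()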
@@ -4424,7 +4756,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim * 3 + j + embed_dim];
+            float vb = wptr[j * embed_dim * 3 + k + embed_dim];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4443,7 +4775,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
+            float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4459,7 +4791,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim + j];
+            float vb = wptr[j * embed_dim + k];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4489,7 +4821,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim + j];
+            float vb = wptr[j * embed_dim + k];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4504,7 +4836,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim + j];
+            float vb = wptr[j * embed_dim + k];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4519,7 +4851,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim + j];
+            float vb = wptr[j * embed_dim + k];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4534,7 +4866,7 @@ int main(int argc, char** argv) {
 
         for (int j = 0; j < embed_dim; j++) {
           for (int k = 0; k < embed_dim; k++) {
-            float vb = wptr[k * embed_dim + j];
+            float vb = wptr[j * embed_dim + k];
             fwrite(&vb, sizeof(float), 1, bp);
           }
         }
@@ -4552,17 +4884,17 @@ int main(int argc, char** argv) {
       // fprintf(stderr, "node.input_size(): %d\n", node.input_size());
       if (node.input_size() >= 3) {
         // fprintf(stderr, "ok12!\n");
-        max_dets = (int)(get_node_attr_from_input_f(weights[node.input(2)]) + 0.5);
+        max_dets = (int)(get_node_attr_from_input<float>(weights[node.input(2)]) + 0.5);
       }
       if (node.input_size() >= 4) {
         // fprintf(stderr, "iou_thre: %f\n",
-        // get_node_attr_from_input_f(weights[node.input(3)]));
-        iou_thre = get_node_attr_from_input_f(weights[node.input(3)]);
+        // get_node_attr_from_input<float>(weights[node.input(3)]));
+        iou_thre = get_node_attr_from_input<float>(weights[node.input(3)]);
       }
       if (node.input_size() >= 5) {
         // fprintf(stderr, "score_thre: %f\n",
-        // get_node_attr_from_input_f(weights[node.input(4)]));
-        score_thre = get_node_attr_from_input_f(weights[node.input(4)]);
+        // get_node_attr_from_input<float>(weights[node.input(4)]));
+        score_thre = get_node_attr_from_input<float>(weights[node.input(4)]);
       }
       fprintf(pp, " 0=%d", max_dets);
       fprintf(pp, " 1=%f", iou_thre);
@@ -4736,8 +5068,10 @@ int main(int argc, char** argv) {
 
       if (node.input_size() == 1) {
         shape = get_node_attr_ai(node, "shape");
-      } else {
+      } else if (weights.find(node.input(1)) != weights.end()) {
         shape = get_node_attr_from_input_ai(weights[node.input(1)]);
+      } else {
+        fprintf(stderr, "Unsupported reshape weight ! \n");
       }
 
       if (shape.size() == 1) {
@@ -21,27 +21,27 @@ int Shape::forward(const Mat &bottom_blob, Mat &top_blob, const Option &opt) con
     return -100;
   }
   float *outptr = top_blob;
 
   if (dims == 1) {
     outptr[0] = 1.0f;
     outptr[1] = w;
     return 0;
-  }
-  if (dims == 2) {
+  } else if (dims == 2) {
     int h = bottom_blob.h;
     outptr[0] = 1.0f;
     outptr[1] = h;
     outptr[2] = w;
     return 0;
-  }
-  if (dims == 3) {
+  } else if (dims == 3) {
     int h = bottom_blob.h;
     int channels = bottom_blob.c;
     outptr[0] = 1.0f;
     outptr[1] = channels;
     outptr[2] = h;
     outptr[3] = w;
     return 0;
+  } else {
+    fprintf(stdout, "Unsupported dims=%d\n", dims);
   }
 
   return 0;
 }
 
 }  // namespace mmdeploy
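ncnn blobs carry no explicit batch axis, so this custom Shape layer always reports a leading 1 followed by the channel/height/width extents. A compact restatement of the branch logic above (plain Python, toy helper, not part of the patch):

# Expected Shape-layer output per input dims, mirroring the C++ above.
def shape_layer(dims, w, h=None, c=None):
    if dims == 1:
        return [1.0, float(w)]
    elif dims == 2:
        return [1.0, float(h), float(w)]
    elif dims == 3:
        return [1.0, float(c), float(h), float(w)]
    raise ValueError(f'Unsupported dims={dims}')

assert shape_layer(3, w=12, h=12, c=768) == [1.0, 768.0, 12.0, 12.0]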
@@ -2,3 +2,4 @@
 from .backbones import *  # noqa: F401,F403
 from .classifiers import *  # noqa: F401,F403
 from .heads import *  # noqa: F401,F403
+from .utils import *  # noqa: F401,F403
@@ -1,4 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .shufflenet_v2 import shufflenetv2_backbone__forward__ncnn
+from .vision_transformer import visiontransformer__forward__ncnn
 
-__all__ = ['shufflenetv2_backbone__forward__ncnn']
+__all__ = [
+    'shufflenetv2_backbone__forward__ncnn',
+    'visiontransformer__forward__ncnn',
+]
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcls.models.utils import resize_pos_embed
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import Backend
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name=  # noqa: E251
+    'mmcls.models.backbones.vision_transformer.VisionTransformer.forward',
+    backend=Backend.NCNN.value)
+def visiontransformer__forward__ncnn(ctx, self, x):
+    """Rewrite `forward` of VisionTransformer for ncnn backend.
+
+    The chunk in the original VisionTransformer.forward converts
+    `self.cls_token` to a `Where` operator in ONNX, which raises an
+    error in ncnn.
+
+    Args:
+        ctx (ContextCaller): The context with additional information.
+        self (VisionTransformer): The instance of the class
+            VisionTransformer.
+        x (Tensor): Input features of shape (N, Cin, H, W).
+    Returns:
+        out (Tensor): A feature map output from VisionTransformer. The tensor
+        shape (N, Cout, H, W).
+    """
+    B = x.shape[0]
+    x, patch_resolution = self.patch_embed(x)
+
+    # cls_tokens = self.cls_token.expand(B, -1, -1)
+    x = torch.cat((self.cls_token, x), dim=1)
+    x = x + resize_pos_embed(
+        self.pos_embed,
+        self.patch_resolution,
+        patch_resolution,
+        mode=self.interpolate_mode,
+        num_extra_tokens=self.num_extra_tokens)
+    x = self.drop_after_pos(x)
+
+    if not self.with_cls_token:
+        # Remove class token for transformer encoder input
+        x = x[:, 1:]
+
+    outs = []
+    for i, layer in enumerate(self.layers):
+        x = layer(x)
+
+        if i == len(self.layers) - 1 and self.final_norm:
+            x = self.norm1(x)
+
+        if i in self.out_indices:
+            B, _, C = x.shape
+            if self.with_cls_token:
+                patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
+                patch_token = patch_token.permute(0, 3, 1, 2)
+                cls_token = x[:, 0]
+            else:
+                patch_token = x.reshape(B, *patch_resolution, C)
+                patch_token = patch_token.permute(0, 3, 1, 2)
+                cls_token = None
+            if self.output_cls_token:
+                out = [patch_token, cls_token]
+            else:
+                out = patch_token
+            outs.append(out)
+
+    return tuple(outs)
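The commented-out expand is the key change: cls_token.expand(B, -1, -1) traces to ONNX shape arithmetic (Expand, typically with Where/ConstantOfShape nodes) that ncnn cannot run, while a direct torch.cat works because ncnn models are effectively batch-1. A hedged sketch of the difference with toy shapes (plain PyTorch, not mmdeploy API):

# Sketch: batch-broadcasting the cls token vs concatenating it directly.
import torch

cls_token = torch.zeros(1, 1, 768)   # (1, num_extra_tokens, embed_dims)
patches = torch.rand(1, 144, 768)    # (B, num_patches, embed_dims), B == 1

# Original: explicit broadcast over the batch dim; exporting this shape
# arithmetic is what produces the Expand/Where nodes ncnn rejects.
x1 = torch.cat((cls_token.expand(patches.shape[0], -1, -1), patches), dim=1)

# Rewrite: rely on B == 1 and concatenate the parameter as-is.
x2 = torch.cat((cls_token, patches), dim=1)

assert torch.equal(x1, x2)  # identical when the batch size is 1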
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .attention import multiheadattention__forward__ncnn
+
+__all__ = ['multiheadattention__forward__ncnn']
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import torch
+from torch import Tensor
+
+from mmdeploy.core import FUNCTION_REWRITER
+from mmdeploy.utils import Backend
+
+
+class MultiHeadAttentionop(torch.autograd.Function):
+    """Create mmdeploy::MultiHeadAttention op."""
+
+    @staticmethod
+    def forward(ctx, q: Tensor, k: Tensor, v: Tensor, q_weight: Tensor,
+                q_bias: Tensor, k_weight: Tensor, k_bias: Tensor,
+                v_weight: Tensor, v_bias: Tensor, o_weight: Tensor,
+                o_bias: Tensor, embed_dims: int, num_heads: int) -> Tensor:
+        return torch.rand_like(q)
+
+    @staticmethod
+    def symbolic(g, q: torch._C.Value, k: torch._C.Value, v: torch._C.Value,
+                 q_weight: torch._C.Value, q_bias: torch._C.Value,
+                 k_weight: torch._C.Value, k_bias: torch._C.Value,
+                 v_weight: torch._C.Value, v_bias: torch._C.Value,
+                 o_weight: torch._C.Value, o_bias: torch._C.Value,
+                 embed_dims: int, num_heads: int):
+
+        q_weight.setDebugName('q_weight')
+        q_bias.setDebugName('q_bias')
+
+        k_weight.setDebugName('k_weight')
+        k_bias.setDebugName('k_bias')
+
+        v_weight.setDebugName('v_weight')
+        v_bias.setDebugName('v_bias')
+
+        o_weight.setDebugName('o_weight')
+        o_bias.setDebugName('o_bias')
+
+        return g.op(
+            'mmdeploy::MultiHeadAttention',
+            q,
+            k,
+            v,
+            q_weight,
+            q_bias,
+            k_weight,
+            k_bias,
+            v_weight,
+            v_bias,
+            o_weight,
+            o_bias,
+            embed_dim_i=embed_dims,
+            num_heads_i=num_heads)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='mmcls.models.utils.attention.MultiheadAttention.forward',
+    backend=Backend.NCNN.value)
+def multiheadattention__forward__ncnn(ctx, self, qkv_input):
+    """Rewrite `forward` of MultiheadAttention used in vision_transformer for
+    ncnn backend.
+
+    Args:
+        ctx (ContextCaller): The context with additional information.
+        self (MultiheadAttention): The instance of the class
+            MultiheadAttention.
+        qkv_input (Tensor): Input features of shape (N, Cin, H, W).
+    Returns:
+        out (Tensor): A feature map output from MultiHeadAttention. The tensor
+        shape (N, Cout, H, W).
+    """
+
+    # split qkv weight and bias
+    qkv_weight = self.qkv.weight.data.reshape(3, self.input_dims,
+                                              self.embed_dims)
+
+    q_weight = qkv_weight[0]
+    k_weight = qkv_weight[1]
+    v_weight = qkv_weight[2]
+
+    qkv_bias = self.qkv.bias.data.reshape(3, self.embed_dims)
+    q_bias = qkv_bias[0]
+    k_bias = qkv_bias[1]
+    v_bias = qkv_bias[2]
+
+    # out weight and bias
+    o_weight = self.proj.weight.data
+    o_bias = self.proj.bias.data
+
+    out = MultiHeadAttentionop.apply(qkv_input, qkv_input, qkv_input, q_weight,
+                                     q_bias, k_weight, k_bias, v_weight,
+                                     v_bias, o_weight, o_bias, self.embed_dims,
+                                     self.num_heads)
+    return out
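MultiHeadAttentionop follows a common mmdeploy pattern: forward produces a placeholder (here random) tensor purely so tracing can proceed, while symbolic emits one opaque custom-domain node that the ncnn converter later recognizes and lowers to its own attention layer. A minimal sketch of that pattern with a made-up op (FakeNorm is hypothetical, and exact export flags can vary across torch versions):

# Sketch: the "placeholder forward + custom symbolic" export pattern.
import io
import torch


class FakeNorm(torch.autograd.Function):
    """Exports as a single custom ONNX node instead of a decomposition."""

    @staticmethod
    def forward(ctx, x):
        # Numeric result is irrelevant for export; the backend supplies
        # the real kernel at inference time.
        return x

    @staticmethod
    def symbolic(g, x):
        # One opaque node in a custom domain; epsilon_f is a float attribute.
        return g.op('mmdeploy::FakeNorm', x, epsilon_f=1e-5)


class Net(torch.nn.Module):
    def forward(self, x):
        return FakeNorm.apply(x)


buf = io.BytesIO()
torch.onnx.export(Net(), torch.rand(1, 4), buf, opset_version=11)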
@@ -2,9 +2,11 @@
 from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
                                 adaptive_avg_pool2d__default,
                                 adaptive_avg_pool3d__default)
+from .gelu import gelu__ncnn
 from .grid_sampler import grid_sampler__default
 from .hardsigmoid import hardsigmoid__default
 from .instance_norm import instance_norm__tensorrt
+from .layer_norm import layer_norm__ncnn
 from .lstm import generic_rnn__ncnn
 from .squeeze import squeeze__default
 
@@ -12,5 +14,5 @@ __all__ = [
     'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
     'adaptive_avg_pool3d__default', 'grid_sampler__default',
     'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
-    'squeeze__default'
+    'squeeze__default', 'gelu__ncnn', 'layer_norm__ncnn'
 ]
@@ -0,0 +1,11 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from mmdeploy.core import SYMBOLIC_REWRITER
+from mmdeploy.utils import Backend
+
+
+@SYMBOLIC_REWRITER.register_symbolic(
+    'gelu', is_pytorch=True, arg_descriptors=['v'], backend=Backend.NCNN.value)
+def gelu__ncnn(ctx, g, self):
+    """Support exporting GELU for the ncnn backend."""
+    return g.op('mmdeploy::Gelu', self)
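Without this symbolic, the exporter decomposes GELU into Erf-based arithmetic that onnx2ncnn would have to pattern-match; with it, a single mmdeploy::Gelu node reaches the converter (matching the "Gelu"/"GELU" handling added to onnx2ncnn above). For reference, the function being delegated, checked against PyTorch:

# GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
import math
import torch
import torch.nn.functional as F

x = torch.randn(8)
manual = 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))
assert torch.allclose(manual, F.gelu(x), atol=1e-6)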
@@ -0,0 +1,32 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Modified from:
+# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
+
+from torch.onnx.symbolic_helper import parse_args
+
+from mmdeploy.core import SYMBOLIC_REWRITER
+from mmdeploy.utils import Backend
+
+
+@parse_args('v', 'is', 'v', 'v', 'f', 'i')
+def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
+    """Symbolic function for `layer_norm`.
+
+    PyTorch does not support exporting `layer_norm` to ONNX by default; we
+    add that support here. `layer_norm` will be exported as the ONNX node
+    'mmdeploy::LayerNorm'.
+    """
+    weight.setDebugName('layernorm_weight')
+    bias.setDebugName('layernorm_bias')
+    return g.op(
+        'mmdeploy::LayerNorm', input, weight, bias, affine_i=1, epsilon_f=eps)
+
+
+@SYMBOLIC_REWRITER.register_symbolic(
+    'layer_norm', is_pytorch=True, backend=Backend.NCNN.value)
+def layer_norm__ncnn(ctx, *args):
+    """Register the symbolic function for `layer_norm` on ncnn.
+
+    Adds `layer_norm` support for ONNX export.
+    """
+    return layer_norm(*args)
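The fused node carries only the affine flag and epsilon as attributes; weight and bias travel as tagged inputs. The computation the ncnn LayerNorm layer is expected to reproduce, checked against plain PyTorch:

# LayerNorm(x) = (x - mean) / sqrt(var + eps) * weight + bias,
# with statistics taken over the normalized (last) dimension.
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 768)
weight, bias, eps = torch.rand(768), torch.rand(768), 1e-5

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
manual = (x - mean) / torch.sqrt(var + eps) * weight + bias

assert torch.allclose(
    manual, F.layer_norm(x, (768, ), weight, bias, eps), atol=1e-5)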
@@ -1,4 +1,4 @@
-mmcls>=0.15.0,<=0.19.0
+mmcls>=0.21.0,<=0.22.1
 mmdet>=2.19.0,<=2.20.0
 mmedit
 mmocr>=0.3.0,<=0.4.1
@@ -14,7 +14,7 @@ import_codebase(Codebase.MMCLS)
 input = torch.rand(1)
 
 
-def get_invertedresudual_model():
+def get_invertedresidual_model():
     from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
     model = InvertedResidual(16, 16)
 
@@ -22,6 +22,43 @@ def get_invertedresidual_model():
     return model
 
 
+def get_vit_model():
+    from mmcls.models.classifiers.image import ImageClassifier
+    model = ImageClassifier(
+        backbone={
+            'type':
+            'VisionTransformer',
+            'arch':
+            'b',
+            'img_size':
+            384,
+            'patch_size':
+            32,
+            'drop_rate':
+            0.1,
+            'init_cfg': [{
+                'type': 'Kaiming',
+                'layer': 'Conv2d',
+                'mode': 'fan_in',
+                'nonlinearity': 'linear'
+            }]
+        },
+        head={
+            'type': 'VisionTransformerClsHead',
+            'num_classes': 1000,
+            'in_channels': 768,
+            'loss': {
+                'type': 'CrossEntropyLoss',
+                'loss_weight': 1.0
+            },
+            'topk': (1, 5)
+        },
+    )
+    model.requires_grad_(False)
+
+    return model
+
+
 def test_baseclassifier_forward():
     from mmcls.models.classifiers import BaseClassifier
 
@@ -78,7 +115,7 @@ def test_multilabel_cls_head():
 def test_shufflenetv2_backbone__forward(backend_type: Backend):
 
     check_backend(backend_type, True)
-    model = get_invertedresudual_model()
+    model = get_invertedresidual_model()
     model.cpu().eval()
     if backend_type.value == 'tensorrt':
         deploy_cfg = mmcv.Config(
@@ -121,3 +158,37 @@ def test_shufflenetv2_backbone__forward(backend_type: Backend):
         rewrite_output = rewrite_output.cpu().numpy()
         assert np.allclose(
             model_output, rewrite_output, rtol=1e-03, atol=1e-05)
+
+
+@pytest.mark.parametrize('backend_type', [Backend.NCNN])
+def test_vision_transformer_backbone__forward(backend_type: Backend):
+
+    check_backend(backend_type, True)
+    model = get_vit_model()
+    model.eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type.value),
+            onnx_config=dict(input_shape=None, output_names=['output']),
+            codebase_config=dict(type='mmcls', task='Classification')))
+
+    imgs = torch.rand((1, 3, 384, 384))
+    model_outputs = model.forward(imgs, return_loss=False)
+    wrapped_model = WrapModel(model, 'forward')
+    rewrite_inputs = {'img': imgs}
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+
+    if isinstance(rewrite_outputs, dict):
+        rewrite_outputs = rewrite_outputs['output']
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        if isinstance(rewrite_output, torch.Tensor):
+            rewrite_output = rewrite_output.cpu().numpy()
+        assert np.allclose(
+            model_output.reshape(-1),
+            rewrite_output.reshape(-1),
+            rtol=1e-03,
+            atol=1e-05)