feat(codebase/cls): support vision_transformer (#403)

* feat(codebase/cls): support vision_transformer

* style(onnx2ncnn): format cpp code, upgrade mmcls version

* fix(CI): upgrade mmcv to 1.4.2

* fix(onnx2ncnn): offset out of range during fuse conv reshape

* docs(vision_transformer.py): update VisionTransformer desc

* docs(onnx2ncnn.cpp): add more comment

* feat(onnx2ncnn.cpp): revert fuse weight

* docs(onnx2ncnn.cpp): add more comment

* test(vision_transformer): add test case

* refactor(vision_transformer.py): use symbol rewrite layer_norm

* refactor(vision_transformer): fix review

* fix(attention): add missing files
tpoisonooo 2022-04-26 18:00:38 +08:00 committed by GitHub
parent 01a44c00c9
commit 2c2d1e5ad9
13 changed files with 668 additions and 46 deletions

View File

@ -23,7 +23,7 @@ jobs:
matrix:
python-version: [3.7]
torch: [1.8.0, 1.9.0]
mmcv: [1.4.0]
mmcv: [1.4.2]
include:
- torch: 1.8.0
torch_version: torch1.8
@ -65,7 +65,7 @@ jobs:
matrix:
python-version: [3.7]
torch: [1.9.0+cu102]
mmcv: [1.4.0]
mmcv: [1.4.2]
include:
- torch: 1.9.0+cu102
torch_version: torch1.9
@ -108,7 +108,7 @@ jobs:
matrix:
python-version: [3.7]
torch: [1.8.0+cu111]
mmcv: [1.4.0]
mmcv: [1.4.2]
include:
- torch: 1.8.0+cu111
torch_version: torch1.8

View File

@ -28,6 +28,7 @@
#include <iostream>
#include <limits>
#include <set>
#include <tuple>
#include "onnx.pb.h"
@ -73,6 +74,17 @@ static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char
return v;
}
static void set_node_attr_ai(onnx::NodeProto& node, const char* key,
const std::vector<int>& value) {
onnx::AttributeProto* attr_group = node.add_attribute();
attr_group->set_name(key);
for (auto v : value) {
attr_group->add_ints(v);
}
return;
}
static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key) {
std::vector<float> v;
@ -137,8 +149,9 @@ static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const
return onnx::TensorProto();
}
static float get_node_attr_from_input_f(const onnx::TensorProto& tp) {
float v = 0.f;
template <typename T>
static T get_node_attr_from_input(const onnx::TensorProto& tp) {
T v = 0.f;
// float
if (tp.data_type() == 1) {
@ -183,7 +196,7 @@ static float get_node_attr_from_input_f(const onnx::TensorProto& tp) {
} else {
// fprintf(stderr, "tp.name: %s\n", tp.name().c_str());
fprintf(stderr, "Unknown data type %d\n", tp.data_type());
fprintf(stderr, "get_node_attr_from_input_f\n");
fprintf(stderr, "get_node_attr_from_input\n");
abort();
}
@ -680,7 +693,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& add_three = weights[node->input(1)];
if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
float constant_add_three = get_node_attr_from_input_f(add_three);
float constant_add_three = get_node_attr_from_input<float>(add_three);
if (constant_add_three != 3.f) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
@ -708,8 +721,8 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& min_tp = weights[node2->input(1)];
const onnx::TensorProto& max_tp = weights[node2->input(2)];
relu6_min = get_node_attr_from_input_f(min_tp);
relu6_max = get_node_attr_from_input_f(max_tp);
relu6_min = get_node_attr_from_input<float>(min_tp);
relu6_max = get_node_attr_from_input<float>(max_tp);
}
if (relu6_min != 0.f || relu6_max != 6.f) continue;
@ -722,7 +735,7 @@ static void fuse_hardswish(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& div_six = weights[node4->input(1)];
if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
float constant_div_six = get_node_attr_from_input_f(div_six);
float constant_div_six = get_node_attr_from_input<float>(div_six);
if (node4->op_type() == "Div" && constant_div_six != 6.f) continue;
if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
@ -831,7 +844,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& add_three = weights[node->input(1)];
if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue;
float constant_add_three = get_node_attr_from_input_f(add_three);
float constant_add_three = get_node_attr_from_input<float>(add_three);
if (constant_add_three != 3.f) continue;
onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
@ -857,8 +870,8 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& min_tp = weights[node2->input(1)];
const onnx::TensorProto& max_tp = weights[node2->input(2)];
relu6_min = get_node_attr_from_input_f(min_tp);
relu6_max = get_node_attr_from_input_f(max_tp);
relu6_min = get_node_attr_from_input<float>(min_tp);
relu6_max = get_node_attr_from_input<float>(max_tp);
}
if (relu6_min != 0.f || relu6_max != 6.f) continue;
@ -867,7 +880,7 @@ static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& div_six = weights[node3->input(1)];
if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue;
float constant_div_six = get_node_attr_from_input_f(div_six);
float constant_div_six = get_node_attr_from_input<float>(div_six);
if (node3->op_type() == "Div" && constant_div_six != 6.f) continue;
if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue;
@ -1090,7 +1103,7 @@ static void fuse_normalize(onnx::GraphProto* mutable_graph,
} else {
const onnx::TensorProto& min_tp = weights[node2->input(1)];
clip_min = get_node_attr_from_input_f(min_tp);
clip_min = get_node_attr_from_input<float>(min_tp);
}
// reduce
@ -1343,7 +1356,7 @@ static void fuse_layernorm(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& pow_two = weights[node3->input(1)];
if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue;
float constant_pow_two = get_node_attr_from_input_f(pow_two);
float constant_pow_two = get_node_attr_from_input<float>(pow_two);
if (constant_pow_two != 2.f) continue;
std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");
@ -1360,7 +1373,7 @@ static void fuse_layernorm(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& add_eps = weights[node5->input(1)];
if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue;
float eps = get_node_attr_from_input_f(add_eps);
float eps = get_node_attr_from_input<float>(add_eps);
int affine = 0;
while (i + 8 < node_count) {
@ -2546,6 +2559,320 @@ static void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
}
}
/**
* @brief find graph node by output name
*
* @param mutable_graph
* @param name
* @return onnx::NodeProto*
*/
static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph,
const std::string& name) {
const int input_count = mutable_graph->node_size();
for (int i = 0; i < input_count; ++i) {
onnx::NodeProto* node = mutable_graph->mutable_node(i);
for (int j = 0; j < node->output_size(); ++j) {
auto output = node->output(j);
if (output == name) {
return node;
}
}
}
return nullptr;
}
/**
* @brief query output shape of target node
*
* @param mutable_graph
* @param target
* @param weights
* @param context <tensor name, shape>
* @return std::tuple<bool, std::vector<int>>
*/
static std::tuple<bool, std::vector<int>> query_shape(
onnx::GraphProto* mutable_graph, onnx::NodeProto* target,
const std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, std::vector<int>>& context) {
// record the shapes of all graph inputs
const int input_count = mutable_graph->input_size();
for (int i = 0; i < input_count; i++) {
auto inp = mutable_graph->input(i);
onnx::TypeProto inp_type = inp.type();
onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();
auto dim_size = shape_proto.dim_size();
std::vector<int> shape(dim_size);
for (int index = 0; index < dim_size; ++index) {
shape[index] = shape_proto.dim(index).dim_value();
}
context.emplace(inp.name(), shape);
}
// BFS the graph with `target` as the root; onnx graph inputs and weights are the leaf nodes
std::vector<onnx::NodeProto*> serial = {target};
{
std::set<std::string> mark_as_appended = {};
while (true) {
int start = 0, end = serial.size();
for (int i = start; i < end; ++i) {
auto node_ptr = serial[i];
auto len = node_ptr->input_size();
for (int j = 0; j < len; ++j) {
std::string name = node_ptr->input(j);
if (context.find(name) != context.end()) {
// input already resolved, skip it
continue;
}
if (weights.find(name) != weights.end()) {
// if found in weights, extract its shape into the context
auto weight = weights.at(name);
std::vector<int> shape;
for (auto index = 0; index < weight.dims_size(); ++index) {
shape.emplace_back(weight.dims(index));
}
context.emplace(name, shape);
continue;
}
if (mark_as_appended.find(name) != mark_as_appended.end()) {
// already marked as appended, skip it
continue;
}
// else append it to serialization list
auto depend_ptr = find_node_by_output_name(mutable_graph, name);
if (depend_ptr == nullptr) {
fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
return std::make_tuple(false, std::vector<int>{});
}
mark_as_appended.insert(name);
serial.emplace_back(depend_ptr);
}
}
if (serial.size() <= end) {
// no new node was added, stop
break;
}
// update the start and end positions and continue the BFS
start = end;
end = serial.size();
}
}
// for each node in serialization list, calculate the output shape
{
std::reverse(serial.begin(), serial.end());
for (auto node : serial) {
if (node->op_type() == "Conv") {
auto inp = context[node->input(0)];
auto weight = context[node->input(1)];
assert(inp.size() == 4 and weight.size() == 4);
int group = get_node_attr_i(*node, "group", 1);
assert(group == 1);
// treat the repeated spatial attributes as a single value
#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \
int ATTR = DEFAULT; \
{ \
std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
if (not _vec.empty()) { \
ATTR = _vec[0]; \
} \
}
EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
EXTRACT_REPEATED_PARAM("pads", pad, 0);
EXTRACT_REPEATED_PARAM("strides", stride, 1);
#undef EXTRACT_REPEATED_PARAM
int on = inp[0];
int oc = weight[0];
int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});
} else if (node->op_type() == "Shape") {
auto inp = context[node->input(0)];
context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});
} else if (node->op_type() == "Slice") {
assert(node->input_size() >= 4);
auto inp = context[node->input(0)];
int start = get_node_attr_from_input<int>(weights.at(node->input(1)));
int end = get_node_attr_from_input<int>(weights.at(node->input(2)));
int axes = get_node_attr_from_input<int>(weights.at(node->input(3)));
if (axes != 0) {
fprintf(stderr, "Not support axes=%d !\n", axes);
return std::make_tuple(false, std::vector<int>{});
}
assert(inp.size() >= end - start);
context.emplace(node->output(0), std::vector<int>{inp.begin() + start, inp.begin() + end});
} else if (node->op_type() == "Concat") {
assert(node->input_size() >= 2);
auto axis = get_node_attr_i(*node, "axis", 0);
if (axis != 0) {
fprintf(stderr, "Not support axes=%d !\n", axis);
return std::make_tuple(false, std::vector<int>{});
}
std::vector<int> inp = context[node->input(0)];
std::vector<int> w_data = get_node_attr_from_input_ai(weights.at(node->input(1)));
// concat data on axis 0
inp.insert(inp.end(), w_data.begin(), w_data.end());
context.emplace(node->output(0), inp);
} else {
fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str());
return std::make_tuple(false, std::vector<int>{});
}
}
}
assert(context.find(target->output(0)) != context.end());
auto target_shape = context[target->output(0)];
return std::make_tuple(true, target_shape);
}
/**
* @brief fuse subgraph
*
* conv - - - - - - - - - - - -> reshape
*     \                       /
*       shape - slice - concat
*
* to
*
* conv --> reshape
*
* @param mutable_graph
* @param weights
* @param node_reference
* @param blob_names
* @param reduced_node_count
*/
static void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
std::set<std::string>& blob_names, int& reduced_node_count) {
std::map<std::string, std::vector<int>> shape_context;
const int node_count = mutable_graph->node_size();
for (int i = 0; i < node_count; i++) {
onnx::NodeProto* conv = mutable_graph->mutable_node(i);
if (conv->op_type() != "Conv") {
continue;
}
if (i + 4 >= node_count) {
continue;
}
onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
// match the nearby sequence [Shape ... Slice, Concat ... Reshape], skipping interleaved Constant nodes
std::vector<std::tuple<std::string, onnx::NodeProto**>> candidates = {
{"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}};
int MAX = std::min(10, node_count - i - 1);
int pos_candidate = 0;
for (int j = 0; j < MAX; ++j) {
auto node_ptr = mutable_graph->mutable_node(j + i + 1);
if (node_ptr->op_type() == "Constant") {
continue;
}
if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) {
*(std::get<1>(candidates[pos_candidate])) = node_ptr;
pos_candidate++;
}
}
if (pos_candidate != candidates.size()) {
// the sequence does not match
continue;
}
if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 ||
node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 ||
node_reference[reshape->output(0)] != 1) {
continue;
}
// check the connections
if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) {
continue;
}
if (slice->input(0) != shape->output(0)) {
continue;
}
if (concat->input(0) != slice->output(0)) {
continue;
}
if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) {
continue;
}
// add reshape attr
auto result = query_shape(mutable_graph, concat, weights, shape_context);
if (!std::get<0>(result)) {
continue;
}
set_node_attr_ai(*reshape, "shape", std::get<1>(result));
// reconstruct graph
{
// remove reference
node_reference[reshape->input(1)] -= 1;
node_reference[concat->input(0)] -= 1;
node_reference[slice->input(0)] -= 1;
node_reference[shape->input(0)] -= 1;
// remove tensor/blob on edge
blob_names.erase(slice->input(0));
blob_names.erase(slice->input(1));
blob_names.erase(slice->input(2));
blob_names.erase(slice->input(3));
weights.erase(slice->input(1));
weights.erase(slice->input(2));
weights.erase(slice->input(3));
blob_names.erase(concat->input(0));
blob_names.erase(concat->input(1));
weights.erase(concat->input(1));
blob_names.erase(reshape->input(0));
// update edge
shape->clear_input();
reshape->clear_input();
reshape->add_input(conv->output(0));
shape->set_op_type("noop_reducedncnn");
slice->set_op_type("noop_reducedncnn");
concat->set_op_type("noop_reducedncnn");
reduced_node_count += 3;
}
i += 3;
}
}
static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
std::map<std::string, onnx::TensorProto>& weights,
std::map<std::string, int>& node_reference,
@ -2563,7 +2890,7 @@ static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
const onnx::TensorProto& scalar_b = weights[node->input(1)];
if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue;
float b = get_node_attr_from_input_f(scalar_b);
float b = get_node_attr_from_input<float>(scalar_b);
node_reference[node->input(1)] -= 1;
@ -2763,6 +3090,7 @@ int main(int argc, char** argv) {
// op chain fusion
int reduced_node_count = 0;
fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
@ -3200,6 +3528,8 @@ int main(int argc, char** argv) {
fprintf(pp, "%-16s", "UnaryOp");
} else if (op == "Gather") {
fprintf(pp, "%-16s", "Gather");
} else if (op == "Gelu") {
fprintf(pp, "%-16s", "GELU");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
@ -3542,10 +3872,10 @@ int main(int argc, char** argv) {
max = get_node_attr_f(node, "max", FLT_MAX);
} else {
min = weights.find(node.input(1)) != weights.end()
? get_node_attr_from_input_f(weights[node.input(1)])
? get_node_attr_from_input<float>(weights[node.input(1)])
: -FLT_MAX;
max = weights.find(node.input(2)) != weights.end()
? get_node_attr_from_input_f(weights[node.input(2)])
? get_node_attr_from_input<float>(weights[node.input(2)])
: FLT_MAX;
}
@ -3835,6 +4165,8 @@ int main(int argc, char** argv) {
fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1);
}
fprintf(pp, " 0=%d", axis);
} else if (op == "Gelu") {
fprintf(pp, " 0=0");
} else if (op == "Gemm") {
float alpha = get_node_attr_f(node, "alpha", 1.f);
float beta = get_node_attr_f(node, "beta", 1.f);
@ -4405,7 +4737,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim * 3 + j];
float vb = wptr[j * embed_dim * 3 + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4424,7 +4756,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim * 3 + j + embed_dim];
float vb = wptr[j * embed_dim * 3 + k + embed_dim];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4443,7 +4775,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4459,7 +4791,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim + j];
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4489,7 +4821,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim + j];
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4504,7 +4836,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim + j];
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4519,7 +4851,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim + j];
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4534,7 +4866,7 @@ int main(int argc, char** argv) {
for (int j = 0; j < embed_dim; j++) {
for (int k = 0; k < embed_dim; k++) {
float vb = wptr[k * embed_dim + j];
float vb = wptr[j * embed_dim + k];
fwrite(&vb, sizeof(float), 1, bp);
}
}
@ -4552,17 +4884,17 @@ int main(int argc, char** argv) {
// fprintf(stderr, "node.input_size(): %d\n", node.input_size());
if (node.input_size() >= 3) {
// fprintf(stderr, "ok12!\n");
max_dets = (int)(get_node_attr_from_input_f(weights[node.input(2)]) + 0.5);
max_dets = (int)(get_node_attr_from_input<float>(weights[node.input(2)]) + 0.5);
}
if (node.input_size() >= 4) {
// fprintf(stderr, "iou_thre: %f\n",
// get_node_attr_from_input_f(weights[node.input(3)]));
iou_thre = get_node_attr_from_input_f(weights[node.input(3)]);
// get_node_attr_from_input<float>(weights[node.input(3)]));
iou_thre = get_node_attr_from_input<float>(weights[node.input(3)]);
}
if (node.input_size() >= 5) {
// fprintf(stderr, "score_thre: %f\n",
// get_node_attr_from_input_f(weights[node.input(4)]));
score_thre = get_node_attr_from_input_f(weights[node.input(4)]);
// get_node_attr_from_input<float>(weights[node.input(4)]));
score_thre = get_node_attr_from_input<float>(weights[node.input(4)]);
}
fprintf(pp, " 0=%d", max_dets);
fprintf(pp, " 1=%f", iou_thre);
@ -4736,8 +5068,10 @@ int main(int argc, char** argv) {
if (node.input_size() == 1) {
shape = get_node_attr_ai(node, "shape");
} else {
} else if (weights.find(node.input(1)) != weights.end()) {
shape = get_node_attr_from_input_ai(weights[node.input(1)]);
} else {
fprintf(stderr, "Unsupported reshape weight ! \n");
}
if (shape.size() == 1) {
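
For context on the new fuse_conv_reshape pass above, here is a minimal sketch (my own, not part of the patch) of the PyTorch idiom that typically emits the Conv -> Shape -> Slice -> Concat -> Reshape subgraph: a convolution followed by flatten(2), roughly what a ViT patch embedding does. Tracing the data-dependent shape produces the Shape/Slice/Concat chain feeding Reshape's second input, which the pass folds into a constant `shape` attribute computed by query_shape. The model and filenames below are illustrative assumptions.

import torch
import torch.nn as nn


class PatchEmbedLike(nn.Module):
    """Conv followed by flatten(2), loosely mimicking ViT patch embedding."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Conv2d(3, 768, kernel_size=32, stride=32)

    def forward(self, x):
        x = self.proj(x)       # Conv
        return x.flatten(2)    # usually traced as Shape/Slice/Concat + Reshape


# exporting at opset 11 should show the dynamic-reshape subgraph that
# fuse_conv_reshape rewrites into a single static Reshape
torch.onnx.export(
    PatchEmbedLike().eval(), torch.rand(1, 3, 384, 384),
    'patch_embed_like.onnx', opset_version=11)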

View File

@ -21,27 +21,27 @@ int Shape::forward(const Mat &bottom_blob, Mat &top_blob, const Option &opt) con
return -100;
}
float *outptr = top_blob;
if (dims == 1) {
outptr[0] = 1.0f;
outptr[1] = w;
return 0;
}
if (dims == 2) {
} else if (dims == 2) {
int h = bottom_blob.h;
outptr[0] = 1.0f;
outptr[1] = h;
outptr[2] = w;
return 0;
}
if (dims == 3) {
} else if (dims == 3) {
int h = bottom_blob.h;
int channels = bottom_blob.c;
outptr[0] = 1.0f;
outptr[1] = channels;
outptr[2] = h;
outptr[3] = w;
return 0;
} else {
fprintf(stdout, "Unsupported dims=%d\n", dims);
}
return 0;
}
} // namespace mmdeploy

View File

@ -2,3 +2,4 @@
from .backbones import * # noqa: F401,F403
from .classifiers import * # noqa: F401,F403
from .heads import * # noqa: F401,F403
from .utils import * # noqa: F401,F403

View File

@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .shufflenet_v2 import shufflenetv2_backbone__forward__ncnn
from .vision_transformer import visiontransformer__forward__ncnn
__all__ = ['shufflenetv2_backbone__forward__ncnn']
__all__ = [
'shufflenetv2_backbone__forward__ncnn',
'visiontransformer__forward__ncnn',
]

View File

@ -0,0 +1,68 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models.utils import resize_pos_embed
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend
@FUNCTION_REWRITER.register_rewriter(
func_name= # noqa: E251
'mmcls.models.backbones.vision_transformer.VisionTransformer.forward',
backend=Backend.NCNN.value)
def visiontransformer__forward__ncnn(ctx, self, x):
"""Rewrite `forward` of VisionTransformer for ncnn backend.
The chunk in original VisionTransformer.forward will convert
`self.cls_token` to `where` operator in ONNX, which will raise
error in ncnn.
Args:
ctx (ContextCaller): The context with additional information.
self (VisionTransformer): The instance of the class InvertedResidual.
x (Tensor): Input features of shape (N, Cin, H, W).
Returns:
out (Tensor): A feature map output from InvertedResidual. The tensor
shape (N, Cout, H, W).
"""
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
# cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((self.cls_token, x), dim=1)
x = x + resize_pos_embed(
self.pos_embed,
self.patch_resolution,
patch_resolution,
mode=self.interpolate_mode,
num_extra_tokens=self.num_extra_tokens)
x = self.drop_after_pos(x)
if not self.with_cls_token:
# Remove class token for transformer encoder input
x = x[:, 1:]
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
B, _, C = x.shape
if self.with_cls_token:
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
else:
patch_token = x.reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = None
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
return tuple(outs)
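
A quick numerical check of the rewrite above (my own sketch, not part of the patch): for the fixed batch size 1 that the ncnn graphs use, expanding the class token is an identity, so concatenating `self.cls_token` directly is equivalent to the original expand-then-cat and keeps the exported graph free of the problematic operator.

import torch

embed_dims, num_patches = 768, 144
cls_token = torch.rand(1, 1, embed_dims)          # as registered in ViT
patches = torch.rand(1, num_patches, embed_dims)  # batch size 1

original = torch.cat((cls_token.expand(patches.shape[0], -1, -1), patches), dim=1)
rewritten = torch.cat((cls_token, patches), dim=1)
assert torch.equal(original, rewritten)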

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import multiheadattention__forward__ncnn
__all__ = ['multiheadattention__forward__ncnn']

View File

@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend
class MultiHeadAttentionop(torch.autograd.Function):
"""Create onnx::MultiHeadAttention op."""
@staticmethod
def forward(ctx, q: Tensor, k: Tensor, v: Tensor, q_weight: Tensor,
q_bias: Tensor, k_weight: Tensor, k_bias: Tensor,
v_weight: Tensor, v_bias: Tensor, o_weight: Tensor,
o_bias: Tensor, embed_dims: int, num_heads: int) -> Tensor:
return torch.rand_like(q)
@staticmethod
def symbolic(g, q: torch._C.Value, k: torch._C.Value, v: torch._C.Value,
q_weight: torch._C.Value, q_bias: torch._C.Value,
k_weight: torch._C.Value, k_bias: torch._C.Value,
v_weight: torch._C.Value, v_bias: torch._C.Value,
o_weight: torch._C.Value, o_bias: torch._C.Value,
embed_dims: int, num_heads: int):
q_weight.setDebugName('q_weight')
q_bias.setDebugName('q_bias')
k_weight.setDebugName('k_weight')
k_bias.setDebugName('k_bias')
v_weight.setDebugName('v_weight')
v_bias.setDebugName('v_bias')
o_weight.setDebugName('o_weight')
o_bias.setDebugName('o_bias')
return g.op(
'mmdeploy::MultiHeadAttention',
q,
k,
v,
q_weight,
q_bias,
k_weight,
k_bias,
v_weight,
v_bias,
o_weight,
o_bias,
embed_dim_i=embed_dims,
num_heads_i=num_heads)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.utils.attention.MultiheadAttention.forward',
backend=Backend.NCNN.value)
def multiheadattention__forward__ncnn(ctx, self, qkv_input):
"""Rewrite `forward` of MultiheadAttention used in vision_transformer for
ncnn backend.
Args:
ctx (ContextCaller): The context with additional information.
self (MultiheadAttention): The instance of the class
MultiheadAttention.
qkv_input (Tensor): Input features of shape (N, L, embed_dims).
Returns:
out (Tensor): Output features with the same shape as `qkv_input`.
"""
# split qkv weight and bias
qkv_weight = self.qkv.weight.data.reshape(3, self.input_dims,
self.embed_dims)
q_weight = qkv_weight[0]
k_weight = qkv_weight[1]
v_weight = qkv_weight[2]
qkv_bias = self.qkv.bias.data.reshape(3, self.embed_dims)
q_bias = qkv_bias[0]
k_bias = qkv_bias[1]
v_bias = qkv_bias[2]
# out weight and bias
o_weight = self.proj.weight.data
o_bias = self.proj.bias.data
out = MultiHeadAttentionop.apply(qkv_input, qkv_input, qkv_input, q_weight,
q_bias, k_weight, k_bias, v_weight,
v_bias, o_weight, o_bias, self.embed_dims,
self.num_heads)
return out
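
To make the weight splitting above concrete, a small sanity sketch (mine, not from the patch): since input_dims equals embed_dims for ViT, reshaping the fused qkv Linear weight to (3, embed_dims, embed_dims) slices out the q/k/v projection matrices row-block by row-block before they are handed to the custom op.

import torch

embed_dims = 768
qkv = torch.nn.Linear(embed_dims, embed_dims * 3)

# weight has shape (3 * embed_dims, embed_dims); the reshape splits it into
# the three (embed_dims, embed_dims) row blocks used as q/k/v weights
qkv_weight = qkv.weight.data.reshape(3, embed_dims, embed_dims)
q_weight, k_weight, v_weight = qkv_weight[0], qkv_weight[1], qkv_weight[2]

assert torch.equal(q_weight, qkv.weight.data[:embed_dims])
assert torch.equal(k_weight, qkv.weight.data[embed_dims:2 * embed_dims])
assert torch.equal(v_weight, qkv.weight.data[2 * embed_dims:])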

View File

@ -2,9 +2,11 @@
from .adaptive_avg_pool import (adaptive_avg_pool1d__default,
adaptive_avg_pool2d__default,
adaptive_avg_pool3d__default)
from .gelu import gelu__ncnn
from .grid_sampler import grid_sampler__default
from .hardsigmoid import hardsigmoid__default
from .instance_norm import instance_norm__tensorrt
from .layer_norm import layer_norm__ncnn
from .lstm import generic_rnn__ncnn
from .squeeze import squeeze__default
@ -12,5 +14,5 @@ __all__ = [
'adaptive_avg_pool1d__default', 'adaptive_avg_pool2d__default',
'adaptive_avg_pool3d__default', 'grid_sampler__default',
'hardsigmoid__default', 'instance_norm__tensorrt', 'generic_rnn__ncnn',
'squeeze__default'
'squeeze__default', 'gelu__ncnn', 'layer_norm__ncnn'
]

View File

@ -0,0 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend
@SYMBOLIC_REWRITER.register_symbolic(
'gelu', is_pytorch=True, arg_descriptors=['v'], backend=Backend.NCNN.value)
def gelu__ncnn(ctx, g, self):
"""Support export GELU with ncnn backend."""
return g.op('mmdeploy::Gelu', self)

View File

@ -0,0 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from:
# https://github.com/pytorch/pytorch/blob/9ade03959392e5a90b74261012de1d806cab2253/torch/onnx/symbolic_opset9.py
from torch.onnx.symbolic_helper import parse_args
from mmdeploy.core import SYMBOLIC_REWRITER
from mmdeploy.utils import Backend
@parse_args('v', 'is', 'v', 'v', 'f', 'i')
def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
"""Symbolic function for `layer_norm`.
PyTorch does not support exporting `layer_norm` to ONNX by default; we add
the support here. `layer_norm` will be exported as the ONNX node
'mmdeploy::LayerNorm'.
"""
weight.setDebugName('layernorm_weight')
bias.setDebugName('layernorm_bias')
return g.op(
'mmdeploy::LayerNorm', input, weight, bias, affine_i=1, epsilon_f=eps)
@SYMBOLIC_REWRITER.register_symbolic(
'layer_norm', is_pytorch=True, backend=Backend.NCNN.value)
def layer_norm__ncnn(ctx, *args):
"""Register default symbolic function for `layer_norm`.
Add support for exporting `layer_norm` to ONNX.
"""
return layer_norm(*args)
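
A hedged usage sketch (not part of the patch) of how this symbolic, together with the gelu one above, kicks in at export time: with the ncnn rewriters enabled, exporting a plain nn.LayerNorm should emit a single mmdeploy::LayerNorm node instead of the decomposed ReduceMean/Sub/Pow/Sqrt subgraph that fuse_layernorm would otherwise have to recover. The RewriterContext arguments and the output filename below are assumptions about mmdeploy's API.

import torch
from mmdeploy.core import RewriterContext  # assumed entry point for rewrites

model = torch.nn.LayerNorm(768).eval()
dummy = torch.rand(1, 145, 768)

# Assumption: an empty config plus backend='ncnn' is enough to activate the
# ncnn-specific symbolic rewriters registered in this patch.
with RewriterContext(cfg=dict(), backend='ncnn'), torch.no_grad():
    torch.onnx.export(model, dummy, 'layer_norm_ncnn.onnx', opset_version=11)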

View File

@ -1,4 +1,4 @@
mmcls>=0.15.0,<=0.19.0
mmcls>=0.21.0,<=0.22.1
mmdet>=2.19.0,<=2.20.0
mmedit
mmocr>=0.3.0,<=0.4.1

View File

@ -14,7 +14,7 @@ import_codebase(Codebase.MMCLS)
input = torch.rand(1)
def get_invertedresudual_model():
def get_invertedresidual_model():
from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
model = InvertedResidual(16, 16)
@ -22,6 +22,43 @@ def get_invertedresudual_model():
return model
def get_vit_model():
from mmcls.models.classifiers.image import ImageClassifier
model = ImageClassifier(
backbone={
'type':
'VisionTransformer',
'arch':
'b',
'img_size':
384,
'patch_size':
32,
'drop_rate':
0.1,
'init_cfg': [{
'type': 'Kaiming',
'layer': 'Conv2d',
'mode': 'fan_in',
'nonlinearity': 'linear'
}]
},
head={
'type': 'VisionTransformerClsHead',
'num_classes': 1000,
'in_channels': 768,
'loss': {
'type': 'CrossEntropyLoss',
'loss_weight': 1.0
},
'topk': (1, 5)
},
)
model.requires_grad_(False)
return model
def test_baseclassifier_forward():
from mmcls.models.classifiers import BaseClassifier
@ -78,7 +115,7 @@ def test_multilabel_cls_head():
def test_shufflenetv2_backbone__forward(backend_type: Backend):
check_backend(backend_type, True)
model = get_invertedresudual_model()
model = get_invertedresidual_model()
model.cpu().eval()
if backend_type.value == 'tensorrt':
deploy_cfg = mmcv.Config(
@ -121,3 +158,37 @@ def test_shufflenetv2_backbone__forward(backend_type: Backend):
rewrite_output = rewrite_output.cpu().numpy()
assert np.allclose(
model_output, rewrite_output, rtol=1e-03, atol=1e-05)
@pytest.mark.parametrize('backend_type', [Backend.NCNN])
def test_vision_transformer_backbone__forward(backend_type: Backend):
check_backend(backend_type, True)
model = get_vit_model()
model.eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(input_shape=None, output_names=['output']),
codebase_config=dict(type='mmcls', task='Classification')))
imgs = torch.rand((1, 3, 384, 384))
model_outputs = model.forward(imgs, return_loss=False)
wrapped_model = WrapModel(model, 'forward')
rewrite_inputs = {'img': imgs}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if isinstance(rewrite_outputs, dict):
rewrite_outputs = rewrite_outputs['output']
for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
if isinstance(rewrite_output, torch.Tensor):
rewrite_output = rewrite_output.cpu().numpy()
assert np.allclose(
model_output.reshape(-1),
rewrite_output.reshape(-1),
rtol=1e-03,
atol=1e-05)