mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add model splitting support (#1)
* add function marker and model extractor * add fsaf split & partial mask rcnn split, import extract.py * 1. add value renaming 2. add apply_marks in config to turn on/off marks * rewind changes on pytorch2onnx.py Co-authored-by: q.yao <streetyao@live.com>
This commit is contained in:
parent
5998d24766
commit
ef41f69553
4
configs/mmdet/split.py
Normal file
4
configs/mmdet/split.py
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
|
||||||
|
|
||||||
|
backend = 'default'
|
||||||
|
apply_marks = True
|
@ -1,3 +1,5 @@
|
|||||||
from .anchor_head import AnchorHead
|
from .anchor_head import AnchorHead
|
||||||
|
from .rpn_head import rpn_head_forward
|
||||||
|
from .fsaf_head import fsaf_head_forward
|
||||||
|
|
||||||
__all__ = ['AnchorHead']
|
__all__ = ['AnchorHead', 'rpn_head_forward', 'fsaf_head_forward']
|
||||||
|
7
mmdeploy/mmdet/models/dense_heads/fsaf_head.py
Normal file
7
mmdeploy/mmdet/models/dense_heads/fsaf_head.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.FSAFHead.forward')
|
||||||
|
@mark('rpn_forward')
|
||||||
|
def fsaf_head_forward(rewriter, *args):
|
||||||
|
return rewriter.origin_func(*args)
|
7
mmdeploy/mmdet/models/dense_heads/rpn_head.py
Normal file
7
mmdeploy/mmdet/models/dense_heads/rpn_head.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.RPNHead.forward')
|
||||||
|
@mark('rpn_forward')
|
||||||
|
def rpn_head_forward(rewriter, self, feats):
|
||||||
|
return rewriter.origin_func(self, feats)
|
@ -1,3 +1,4 @@
|
|||||||
from .single_stage import SingleStageDetector
|
from .single_stage import SingleStageDetector
|
||||||
|
from .two_stage import extract_feat
|
||||||
|
|
||||||
__all__ = ['SingleStageDetector']
|
__all__ = ['SingleStageDetector', 'extract_feat']
|
||||||
|
19
mmdeploy/mmdet/models/detectors/two_stage.py
Normal file
19
mmdeploy/mmdet/models/detectors/two_stage.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||||
|
from mmdeploy.utils import SYMBOLICS_REGISTER
|
||||||
|
from mmcv.onnx.symbolic import grid_sampler
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.extract_feat')
|
||||||
|
@mark('extract_feat')
|
||||||
|
def extract_feat(rewriter, self, img):
|
||||||
|
return rewriter.origin_func(self, img)
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.forward')
|
||||||
|
def two_stage_forward(rewriter, self, img, *args):
|
||||||
|
return rewriter.origin_func(self, [img], img_metas=[[{}]], return_loss=False, *args)
|
||||||
|
|
||||||
|
|
||||||
|
@SYMBOLICS_REGISTER.register_symbolic('grid_sampler', is_pytorch=True)
|
||||||
|
def symbolic_grid_sample(symbolic_wrapper, *args):
|
||||||
|
return grid_sampler(*args)
|
@ -1,8 +1,9 @@
|
|||||||
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
|
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
|
||||||
from .module_rewriter import MODULE_REWRITERS, patch_model
|
from .module_rewriter import MODULE_REWRITERS, patch_model
|
||||||
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics
|
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics
|
||||||
|
from .function_marker import mark
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
|
'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
|
||||||
'SYMBOLICS_REGISTER', 'register_extra_symbolics'
|
'SYMBOLICS_REGISTER', 'register_extra_symbolics', 'mark'
|
||||||
]
|
]
|
||||||
|
61
mmdeploy/utils/function_marker.py
Normal file
61
mmdeploy/utils/function_marker.py
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import inspect
|
||||||
|
import torch
|
||||||
|
from .function_rewriter import FUNCTION_REWRITERS
|
||||||
|
|
||||||
|
|
||||||
|
class Mark(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def symbolic(g, x, type, name, id, attrs):
|
||||||
|
n = g.op("mmcv::Mark", x, type_s=type, name_s=name, id_i=id, **attrs)
|
||||||
|
return n
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, x, *args):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@FUNCTION_REWRITERS.register_rewriter("mmdeploy.utils.function_marker.Mark.symbolic")
|
||||||
|
def mark_symbolic(rewriter, g, x, *args):
|
||||||
|
if rewriter.cfg.get("apply_marks", False):
|
||||||
|
return rewriter.origin_func(g, x, *args)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def mark_tensors(xs, type, name, attrs):
|
||||||
|
index = 0
|
||||||
|
visit = set()
|
||||||
|
|
||||||
|
def impl(ys, prefix):
|
||||||
|
nonlocal index
|
||||||
|
if isinstance(ys, torch.Tensor):
|
||||||
|
if ys not in visit:
|
||||||
|
visit.add(ys)
|
||||||
|
index += 1
|
||||||
|
return Mark.apply(ys, type, prefix, index - 1, attrs)
|
||||||
|
return ys
|
||||||
|
elif isinstance(ys, list):
|
||||||
|
return [impl(y, f'{prefix}/{i}') for i, y in enumerate(ys)]
|
||||||
|
elif isinstance(ys, tuple):
|
||||||
|
return tuple(impl(y, f'{prefix}/{i}') for i, y in enumerate(ys))
|
||||||
|
elif isinstance(ys, dict):
|
||||||
|
return {k: impl(v, f'{prefix}/{k}') for k, v in ys.items()}
|
||||||
|
return ys
|
||||||
|
return impl(xs, name)
|
||||||
|
|
||||||
|
|
||||||
|
def mark(func, **attrs):
|
||||||
|
attrs['func_s'] = func
|
||||||
|
|
||||||
|
def decorator(f):
|
||||||
|
params = inspect.signature(f).parameters.keys()
|
||||||
|
def g(*args, **kwargs):
|
||||||
|
if torch.onnx.is_in_onnx_export():
|
||||||
|
args = [mark_tensors(arg, 'input', name, attrs)
|
||||||
|
for name, arg in zip(params, args)]
|
||||||
|
rets = f(*args, **kwargs)
|
||||||
|
# TODO: maybe we can traverse the AST to get the retval names?
|
||||||
|
return mark_tensors(rets, 'output', func, attrs)
|
||||||
|
else:
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
return g
|
||||||
|
return decorator
|
234
tools/extract.py
Normal file
234
tools/extract.py
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
import argparse
|
||||||
|
import os.path as osp
|
||||||
|
import onnx
|
||||||
|
import onnx.utils
|
||||||
|
import onnx.helper
|
||||||
|
from onnx import AttributeProto
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Extract model based on markers.')
|
||||||
|
parser.add_argument('input_model', help='Input ONNX model')
|
||||||
|
parser.add_argument('output_model', help='Output ONNX model')
|
||||||
|
parser.add_argument(
|
||||||
|
'--start', help='Start markers, format: func:type, e.g. backbone:input')
|
||||||
|
parser.add_argument('--end', help='End markers')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args.start = args.start.split(',') if args.start else []
|
||||||
|
args.end = args.end.split(',') if args.end else []
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def remove_markers(model):
|
||||||
|
shortcut = []
|
||||||
|
success = True
|
||||||
|
while success:
|
||||||
|
success = False
|
||||||
|
for i, node in enumerate(model.graph.node):
|
||||||
|
if node.op_type == 'Mark':
|
||||||
|
for input in node.input:
|
||||||
|
shortcut.append((input, node.output))
|
||||||
|
del model.graph.node[i]
|
||||||
|
success = True
|
||||||
|
break
|
||||||
|
for src, dsts in shortcut:
|
||||||
|
for curr in model.graph.node:
|
||||||
|
for k, input in enumerate(curr.input):
|
||||||
|
if input in dsts:
|
||||||
|
curr.input[k] = src
|
||||||
|
# TODO: handle duplicated case?
|
||||||
|
for k, output in enumerate(model.graph.output):
|
||||||
|
print(output.name, dsts)
|
||||||
|
if output.name in dsts:
|
||||||
|
output.name = src
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def attribute_to_dict(attribute):
|
||||||
|
ret = {}
|
||||||
|
for a in attribute:
|
||||||
|
name = a.name
|
||||||
|
if a.type == AttributeProto.AttributeType.STRING:
|
||||||
|
ret[name] = str(a.s, 'utf-8')
|
||||||
|
elif a.type == AttributeProto.AttributeType.INT:
|
||||||
|
ret[name] = a.i
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes, reachable_nodes):
|
||||||
|
outputs = {}
|
||||||
|
for index, node in enumerate(self.graph.node):
|
||||||
|
for name in node.output:
|
||||||
|
if name not in outputs:
|
||||||
|
outputs[name] = set()
|
||||||
|
outputs[name].add(index)
|
||||||
|
|
||||||
|
def impl(node_output_name, graph_input_nodes, reachable_nodes):
|
||||||
|
if node_output_name in graph_input_nodes:
|
||||||
|
return
|
||||||
|
if node_output_name not in outputs:
|
||||||
|
return
|
||||||
|
for index in outputs[node_output_name]:
|
||||||
|
node = self.graph.node[index]
|
||||||
|
if node in reachable_nodes:
|
||||||
|
continue
|
||||||
|
reachable_nodes.append(node)
|
||||||
|
for name in node.input:
|
||||||
|
impl(name, graph_input_nodes, reachable_nodes)
|
||||||
|
impl(node_output_name, graph_input_nodes, reachable_nodes)
|
||||||
|
|
||||||
|
|
||||||
|
def get_new_name(attrs):
|
||||||
|
if 'name' in attrs:
|
||||||
|
return attrs['name']
|
||||||
|
return '_'.join((attrs['func'], attrs['type'], str(attrs['id'])))
|
||||||
|
|
||||||
|
|
||||||
|
def rename_value(model, old_name, new_name):
|
||||||
|
for n in model.graph.node:
|
||||||
|
for i, output in enumerate(n.output):
|
||||||
|
if output == old_name:
|
||||||
|
n.output[i] = new_name
|
||||||
|
for i, input in enumerate(n.input):
|
||||||
|
if input == old_name:
|
||||||
|
n.input[i] = new_name
|
||||||
|
for v in model.graph.value_info:
|
||||||
|
if v.name == old_name:
|
||||||
|
v.name = new_name
|
||||||
|
for i, name in enumerate(model.graph.input):
|
||||||
|
if name == old_name:
|
||||||
|
model.graph.input[i] = new_name
|
||||||
|
for i, name in enumerate(model.graph.output):
|
||||||
|
if name == old_name:
|
||||||
|
model.graph.output[i] = new_name
|
||||||
|
|
||||||
|
|
||||||
|
def extract_model(model, start, end):
|
||||||
|
inputs = []
|
||||||
|
outputs = []
|
||||||
|
if not isinstance(start, (list, tuple)):
|
||||||
|
start = [start]
|
||||||
|
for s in start:
|
||||||
|
start_name, start_type = s.split(':')
|
||||||
|
assert start_type in ['input', 'output']
|
||||||
|
for node in model.graph.node:
|
||||||
|
if node.op_type == 'Mark':
|
||||||
|
attr = attribute_to_dict(node.attribute)
|
||||||
|
if attr['func'] == start_name and attr['type'] == start_type:
|
||||||
|
name = node.output[0] if start_type == 'input' else node.input[0]
|
||||||
|
if name not in inputs:
|
||||||
|
new_name = get_new_name(attr)
|
||||||
|
rename_value(model, name, new_name)
|
||||||
|
inputs.append(new_name)
|
||||||
|
|
||||||
|
print(f'inputs: {inputs}')
|
||||||
|
|
||||||
|
# collect outputs
|
||||||
|
# outputs = []
|
||||||
|
if not isinstance(end, (list, tuple)):
|
||||||
|
end = [end]
|
||||||
|
for e in end:
|
||||||
|
end_name, end_type = e.split(':')
|
||||||
|
assert end_type in ['input', 'output']
|
||||||
|
for node in model.graph.node:
|
||||||
|
if node.op_type == 'Mark':
|
||||||
|
attr = attribute_to_dict(node.attribute)
|
||||||
|
if attr['func'] == end_name and attr['type'] == end_type:
|
||||||
|
name = node.input[0] if end_type == 'input' else node.output[0]
|
||||||
|
if name not in outputs:
|
||||||
|
new_name = get_new_name(attr)
|
||||||
|
rename_value(model, name, new_name)
|
||||||
|
outputs.append(new_name)
|
||||||
|
|
||||||
|
print(f'outputs: {outputs}')
|
||||||
|
|
||||||
|
# replace Mark with Identity
|
||||||
|
for node in model.graph.node:
|
||||||
|
if node.op_type == 'Mark':
|
||||||
|
del node.attribute[:]
|
||||||
|
node.domain = ''
|
||||||
|
node.op_type = 'Identity'
|
||||||
|
|
||||||
|
# patch extractor
|
||||||
|
onnx.utils.Extractor._dfs_search_reachable_nodes = _dfs_search_reacable_nodes_fast
|
||||||
|
|
||||||
|
extractor = onnx.utils.Extractor(model)
|
||||||
|
extracted_model = extractor.extract_model(inputs, outputs)
|
||||||
|
|
||||||
|
# collect all used inputs
|
||||||
|
used = set()
|
||||||
|
for node in extracted_model.graph.node:
|
||||||
|
for input in node.input:
|
||||||
|
used.add(input)
|
||||||
|
|
||||||
|
for output in extracted_model.graph.output:
|
||||||
|
used.add(output.name)
|
||||||
|
|
||||||
|
# delete unused inputs
|
||||||
|
success = True
|
||||||
|
while success:
|
||||||
|
success = False
|
||||||
|
for i, input in enumerate(extracted_model.graph.input):
|
||||||
|
if input.name not in used:
|
||||||
|
del extracted_model.graph.input[i]
|
||||||
|
success = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# eliminate output without shape
|
||||||
|
for xs in [extracted_model.graph.output]:
|
||||||
|
for x in xs:
|
||||||
|
if not x.type.tensor_type.shape.dim:
|
||||||
|
print(f'fixing output shape: {x.name}')
|
||||||
|
x.CopyFrom(onnx.helper.make_tensor_value_info(
|
||||||
|
x.name, x.type.tensor_type.elem_type, []))
|
||||||
|
|
||||||
|
# eliminate 0-batch dimension, dirty workaround for two-stage detectors
|
||||||
|
for input in extracted_model.graph.input:
|
||||||
|
if input.name in inputs:
|
||||||
|
if input.type.tensor_type.shape.dim[0].dim_value == 0:
|
||||||
|
input.type.tensor_type.shape.dim[0].dim_value = 1
|
||||||
|
|
||||||
|
# eliminate duplicated value_info for inputs
|
||||||
|
success = True
|
||||||
|
while success:
|
||||||
|
success = False
|
||||||
|
for i, x in enumerate(extracted_model.graph.value_info):
|
||||||
|
if x.name in inputs:
|
||||||
|
del extracted_model.graph.value_info[i]
|
||||||
|
success = True
|
||||||
|
break
|
||||||
|
|
||||||
|
return extracted_model
|
||||||
|
|
||||||
|
|
||||||
|
def collect_avaiable_marks(model):
|
||||||
|
marks = []
|
||||||
|
for node in model.graph.node:
|
||||||
|
if node.op_type == 'Mark':
|
||||||
|
attr = attribute_to_dict(node.attribute)
|
||||||
|
func = attr['func']
|
||||||
|
if func not in marks:
|
||||||
|
marks.append(func)
|
||||||
|
return marks
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
model = onnx.load(args.input_model)
|
||||||
|
marks = collect_avaiable_marks(model)
|
||||||
|
print("Available marks:\n {}".format('\n '.join(marks)))
|
||||||
|
|
||||||
|
extracted_model = extract_model(model, args.start, args.end)
|
||||||
|
|
||||||
|
if osp.splitext(args.output_model)[-1] != '.onnx':
|
||||||
|
args.output_model += '.onnx'
|
||||||
|
onnx.save(extracted_model, args.output_model)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
x
Reference in New Issue
Block a user