mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
add model splitting support (#1)
* add function marker and model extractor * add fsaf split & partial mask rcnn split, import extract.py * 1. add value renaming 2. add apply_marks in config to turn on/off marks * rewind changes on pytorch2onnx.py Co-authored-by: q.yao <streetyao@live.com>
This commit is contained in:
parent
5998d24766
commit
ef41f69553
4
configs/mmdet/split.py
Normal file
4
configs/mmdet/split.py
Normal file
@ -0,0 +1,4 @@
|
||||
_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
|
||||
|
||||
backend = 'default'
|
||||
apply_marks = True
|
@ -62,4 +62,4 @@ def torch2onnx(img: Any,
|
||||
keep_initializers_as_inputs=pytorch2onnx_cfg[
|
||||
'keep_initializers_as_inputs'])
|
||||
|
||||
ret_value.value = 0
|
||||
ret_value.value = 0
|
@ -1,3 +1,5 @@
|
||||
from .anchor_head import AnchorHead
|
||||
from .rpn_head import rpn_head_forward
|
||||
from .fsaf_head import fsaf_head_forward
|
||||
|
||||
__all__ = ['AnchorHead']
|
||||
__all__ = ['AnchorHead', 'rpn_head_forward', 'fsaf_head_forward']
|
||||
|
7
mmdeploy/mmdet/models/dense_heads/fsaf_head.py
Normal file
7
mmdeploy/mmdet/models/dense_heads/fsaf_head.py
Normal file
@ -0,0 +1,7 @@
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.FSAFHead.forward')
|
||||
@mark('rpn_forward')
|
||||
def fsaf_head_forward(rewriter, *args):
|
||||
return rewriter.origin_func(*args)
|
7
mmdeploy/mmdet/models/dense_heads/rpn_head.py
Normal file
7
mmdeploy/mmdet/models/dense_heads/rpn_head.py
Normal file
@ -0,0 +1,7 @@
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.RPNHead.forward')
|
||||
@mark('rpn_forward')
|
||||
def rpn_head_forward(rewriter, self, feats):
|
||||
return rewriter.origin_func(self, feats)
|
@ -1,3 +1,4 @@
|
||||
from .single_stage import SingleStageDetector
|
||||
from .two_stage import extract_feat
|
||||
|
||||
__all__ = ['SingleStageDetector']
|
||||
__all__ = ['SingleStageDetector', 'extract_feat']
|
||||
|
19
mmdeploy/mmdet/models/detectors/two_stage.py
Normal file
19
mmdeploy/mmdet/models/detectors/two_stage.py
Normal file
@ -0,0 +1,19 @@
|
||||
from mmdeploy.utils import FUNCTION_REWRITERS, mark
|
||||
from mmdeploy.utils import SYMBOLICS_REGISTER
|
||||
from mmcv.onnx.symbolic import grid_sampler
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.extract_feat')
|
||||
@mark('extract_feat')
|
||||
def extract_feat(rewriter, self, img):
|
||||
return rewriter.origin_func(self, img)
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.forward')
|
||||
def two_stage_forward(rewriter, self, img, *args):
|
||||
return rewriter.origin_func(self, [img], img_metas=[[{}]], return_loss=False, *args)
|
||||
|
||||
|
||||
@SYMBOLICS_REGISTER.register_symbolic('grid_sampler', is_pytorch=True)
|
||||
def symbolic_grid_sample(symbolic_wrapper, *args):
|
||||
return grid_sampler(*args)
|
@ -1,8 +1,9 @@
|
||||
from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
|
||||
from .module_rewriter import MODULE_REWRITERS, patch_model
|
||||
from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics
|
||||
from .function_marker import mark
|
||||
|
||||
__all__ = [
|
||||
'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
|
||||
'SYMBOLICS_REGISTER', 'register_extra_symbolics'
|
||||
'SYMBOLICS_REGISTER', 'register_extra_symbolics', 'mark'
|
||||
]
|
||||
|
61
mmdeploy/utils/function_marker.py
Normal file
61
mmdeploy/utils/function_marker.py
Normal file
@ -0,0 +1,61 @@
|
||||
import inspect
|
||||
import torch
|
||||
from .function_rewriter import FUNCTION_REWRITERS
|
||||
|
||||
|
||||
class Mark(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def symbolic(g, x, type, name, id, attrs):
|
||||
n = g.op("mmcv::Mark", x, type_s=type, name_s=name, id_i=id, **attrs)
|
||||
return n
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, *args):
|
||||
return x
|
||||
|
||||
|
||||
@FUNCTION_REWRITERS.register_rewriter("mmdeploy.utils.function_marker.Mark.symbolic")
|
||||
def mark_symbolic(rewriter, g, x, *args):
|
||||
if rewriter.cfg.get("apply_marks", False):
|
||||
return rewriter.origin_func(g, x, *args)
|
||||
return x
|
||||
|
||||
|
||||
def mark_tensors(xs, type, name, attrs):
|
||||
index = 0
|
||||
visit = set()
|
||||
|
||||
def impl(ys, prefix):
|
||||
nonlocal index
|
||||
if isinstance(ys, torch.Tensor):
|
||||
if ys not in visit:
|
||||
visit.add(ys)
|
||||
index += 1
|
||||
return Mark.apply(ys, type, prefix, index - 1, attrs)
|
||||
return ys
|
||||
elif isinstance(ys, list):
|
||||
return [impl(y, f'{prefix}/{i}') for i, y in enumerate(ys)]
|
||||
elif isinstance(ys, tuple):
|
||||
return tuple(impl(y, f'{prefix}/{i}') for i, y in enumerate(ys))
|
||||
elif isinstance(ys, dict):
|
||||
return {k: impl(v, f'{prefix}/{k}') for k, v in ys.items()}
|
||||
return ys
|
||||
return impl(xs, name)
|
||||
|
||||
|
||||
def mark(func, **attrs):
|
||||
attrs['func_s'] = func
|
||||
|
||||
def decorator(f):
|
||||
params = inspect.signature(f).parameters.keys()
|
||||
def g(*args, **kwargs):
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
args = [mark_tensors(arg, 'input', name, attrs)
|
||||
for name, arg in zip(params, args)]
|
||||
rets = f(*args, **kwargs)
|
||||
# TODO: maybe we can traverse the AST to get the retval names?
|
||||
return mark_tensors(rets, 'output', func, attrs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
return g
|
||||
return decorator
|
234
tools/extract.py
Normal file
234
tools/extract.py
Normal file
@ -0,0 +1,234 @@
|
||||
import argparse
|
||||
import os.path as osp
|
||||
import onnx
|
||||
import onnx.utils
|
||||
import onnx.helper
|
||||
from onnx import AttributeProto
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Extract model based on markers.')
|
||||
parser.add_argument('input_model', help='Input ONNX model')
|
||||
parser.add_argument('output_model', help='Output ONNX model')
|
||||
parser.add_argument(
|
||||
'--start', help='Start markers, format: func:type, e.g. backbone:input')
|
||||
parser.add_argument('--end', help='End markers')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.start = args.start.split(',') if args.start else []
|
||||
args.end = args.end.split(',') if args.end else []
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def remove_markers(model):
|
||||
shortcut = []
|
||||
success = True
|
||||
while success:
|
||||
success = False
|
||||
for i, node in enumerate(model.graph.node):
|
||||
if node.op_type == 'Mark':
|
||||
for input in node.input:
|
||||
shortcut.append((input, node.output))
|
||||
del model.graph.node[i]
|
||||
success = True
|
||||
break
|
||||
for src, dsts in shortcut:
|
||||
for curr in model.graph.node:
|
||||
for k, input in enumerate(curr.input):
|
||||
if input in dsts:
|
||||
curr.input[k] = src
|
||||
# TODO: handle duplicated case?
|
||||
for k, output in enumerate(model.graph.output):
|
||||
print(output.name, dsts)
|
||||
if output.name in dsts:
|
||||
output.name = src
|
||||
return model
|
||||
|
||||
|
||||
def attribute_to_dict(attribute):
|
||||
ret = {}
|
||||
for a in attribute:
|
||||
name = a.name
|
||||
if a.type == AttributeProto.AttributeType.STRING:
|
||||
ret[name] = str(a.s, 'utf-8')
|
||||
elif a.type == AttributeProto.AttributeType.INT:
|
||||
ret[name] = a.i
|
||||
return ret
|
||||
|
||||
|
||||
def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes, reachable_nodes):
|
||||
outputs = {}
|
||||
for index, node in enumerate(self.graph.node):
|
||||
for name in node.output:
|
||||
if name not in outputs:
|
||||
outputs[name] = set()
|
||||
outputs[name].add(index)
|
||||
|
||||
def impl(node_output_name, graph_input_nodes, reachable_nodes):
|
||||
if node_output_name in graph_input_nodes:
|
||||
return
|
||||
if node_output_name not in outputs:
|
||||
return
|
||||
for index in outputs[node_output_name]:
|
||||
node = self.graph.node[index]
|
||||
if node in reachable_nodes:
|
||||
continue
|
||||
reachable_nodes.append(node)
|
||||
for name in node.input:
|
||||
impl(name, graph_input_nodes, reachable_nodes)
|
||||
impl(node_output_name, graph_input_nodes, reachable_nodes)
|
||||
|
||||
|
||||
def get_new_name(attrs):
|
||||
if 'name' in attrs:
|
||||
return attrs['name']
|
||||
return '_'.join((attrs['func'], attrs['type'], str(attrs['id'])))
|
||||
|
||||
|
||||
def rename_value(model, old_name, new_name):
|
||||
for n in model.graph.node:
|
||||
for i, output in enumerate(n.output):
|
||||
if output == old_name:
|
||||
n.output[i] = new_name
|
||||
for i, input in enumerate(n.input):
|
||||
if input == old_name:
|
||||
n.input[i] = new_name
|
||||
for v in model.graph.value_info:
|
||||
if v.name == old_name:
|
||||
v.name = new_name
|
||||
for i, name in enumerate(model.graph.input):
|
||||
if name == old_name:
|
||||
model.graph.input[i] = new_name
|
||||
for i, name in enumerate(model.graph.output):
|
||||
if name == old_name:
|
||||
model.graph.output[i] = new_name
|
||||
|
||||
|
||||
def extract_model(model, start, end):
|
||||
inputs = []
|
||||
outputs = []
|
||||
if not isinstance(start, (list, tuple)):
|
||||
start = [start]
|
||||
for s in start:
|
||||
start_name, start_type = s.split(':')
|
||||
assert start_type in ['input', 'output']
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
if attr['func'] == start_name and attr['type'] == start_type:
|
||||
name = node.output[0] if start_type == 'input' else node.input[0]
|
||||
if name not in inputs:
|
||||
new_name = get_new_name(attr)
|
||||
rename_value(model, name, new_name)
|
||||
inputs.append(new_name)
|
||||
|
||||
print(f'inputs: {inputs}')
|
||||
|
||||
# collect outputs
|
||||
# outputs = []
|
||||
if not isinstance(end, (list, tuple)):
|
||||
end = [end]
|
||||
for e in end:
|
||||
end_name, end_type = e.split(':')
|
||||
assert end_type in ['input', 'output']
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
if attr['func'] == end_name and attr['type'] == end_type:
|
||||
name = node.input[0] if end_type == 'input' else node.output[0]
|
||||
if name not in outputs:
|
||||
new_name = get_new_name(attr)
|
||||
rename_value(model, name, new_name)
|
||||
outputs.append(new_name)
|
||||
|
||||
print(f'outputs: {outputs}')
|
||||
|
||||
# replace Mark with Identity
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
del node.attribute[:]
|
||||
node.domain = ''
|
||||
node.op_type = 'Identity'
|
||||
|
||||
# patch extractor
|
||||
onnx.utils.Extractor._dfs_search_reachable_nodes = _dfs_search_reacable_nodes_fast
|
||||
|
||||
extractor = onnx.utils.Extractor(model)
|
||||
extracted_model = extractor.extract_model(inputs, outputs)
|
||||
|
||||
# collect all used inputs
|
||||
used = set()
|
||||
for node in extracted_model.graph.node:
|
||||
for input in node.input:
|
||||
used.add(input)
|
||||
|
||||
for output in extracted_model.graph.output:
|
||||
used.add(output.name)
|
||||
|
||||
# delete unused inputs
|
||||
success = True
|
||||
while success:
|
||||
success = False
|
||||
for i, input in enumerate(extracted_model.graph.input):
|
||||
if input.name not in used:
|
||||
del extracted_model.graph.input[i]
|
||||
success = True
|
||||
break
|
||||
|
||||
# eliminate output without shape
|
||||
for xs in [extracted_model.graph.output]:
|
||||
for x in xs:
|
||||
if not x.type.tensor_type.shape.dim:
|
||||
print(f'fixing output shape: {x.name}')
|
||||
x.CopyFrom(onnx.helper.make_tensor_value_info(
|
||||
x.name, x.type.tensor_type.elem_type, []))
|
||||
|
||||
# eliminate 0-batch dimension, dirty workaround for two-stage detectors
|
||||
for input in extracted_model.graph.input:
|
||||
if input.name in inputs:
|
||||
if input.type.tensor_type.shape.dim[0].dim_value == 0:
|
||||
input.type.tensor_type.shape.dim[0].dim_value = 1
|
||||
|
||||
# eliminate duplicated value_info for inputs
|
||||
success = True
|
||||
while success:
|
||||
success = False
|
||||
for i, x in enumerate(extracted_model.graph.value_info):
|
||||
if x.name in inputs:
|
||||
del extracted_model.graph.value_info[i]
|
||||
success = True
|
||||
break
|
||||
|
||||
return extracted_model
|
||||
|
||||
|
||||
def collect_avaiable_marks(model):
|
||||
marks = []
|
||||
for node in model.graph.node:
|
||||
if node.op_type == 'Mark':
|
||||
attr = attribute_to_dict(node.attribute)
|
||||
func = attr['func']
|
||||
if func not in marks:
|
||||
marks.append(func)
|
||||
return marks
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
model = onnx.load(args.input_model)
|
||||
marks = collect_avaiable_marks(model)
|
||||
print("Available marks:\n {}".format('\n '.join(marks)))
|
||||
|
||||
extracted_model = extract_model(model, args.start, args.end)
|
||||
|
||||
if osp.splitext(args.output_model)[-1] != '.onnx':
|
||||
args.output_model += '.onnx'
|
||||
onnx.save(extracted_model, args.output_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user