# How to support new models

We provide several tools to support model conversion.

## Function Rewriter

PyTorch networks are written in Python, which eases algorithm development. However, Python control flow and third-party libraries make it difficult to export a network to an intermediate representation. We provide a "monkey patch" tool that rewrites an unsupported function into one that can be exported. Here is an example:
```python
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='torch.Tensor.repeat', backend='tensorrt')
def repeat_static(ctx, input, *size):
    # `ctx.origin_func` is the original, un-patched `torch.Tensor.repeat`.
    origin_func = ctx.origin_func
    if input.dim() == 1 and len(size) == 1:
        # Repeat a 2-D view of the 1-D input, then drop the extra dimension.
        return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
    else:
        return origin_func(input, *size)
```
The function rewriter is easy to use. Just add a decorator with these arguments:

- `func_name` is the function to override. It can be either a PyTorch function or a custom function. Methods in modules can also be overridden by this tool.
- `backend` is the inference engine. The function will be overridden when the model is exported to this engine. If it is not given, this rewrite is the default rewrite, which is used when no rewrite exists for the given backend.

The rewritten function takes the same arguments as the original function, plus a context `ctx` as the first argument. The context provides useful information such as the deployment config `ctx.cfg` and the original (overridden) function `ctx.origin_func`.
## Module Rewriter

If you want to replace a whole module with another one, use the module rewriter as follows:
```python
import torch.nn as nn

from mmdeploy.core import MODULE_REWRITER


@MODULE_REWRITER.register_rewrite_module(
    'mmedit.models.backbones.sr_backbones.SRCNN', backend='tensorrt')
class SRCNNWrapper(nn.Module):

    def __init__(self,
                 module,
                 cfg,
                 channels=(3, 64, 32, 3),
                 kernel_sizes=(9, 1, 5),
                 upscale_factor=4):
        super(SRCNNWrapper, self).__init__()

        # Keep the original module and delegate calls to it.
        self._module = module

        # Swap the upsampler of the original module for a bilinear one.
        module.img_upsampler = nn.Upsample(
            scale_factor=module.upscale_factor,
            mode='bilinear',
            align_corners=False)

    def forward(self, *args, **kwargs):
        """Run forward."""
        return self._module(*args, **kwargs)

    def init_weights(self, *args, **kwargs):
        """Initialize weights."""
        return self._module.init_weights(*args, **kwargs)
```
As with the function rewriter, add a decorator with these arguments:

- `module_type` is the module class to rewrite.
- `backend` is the inference engine. The module will be replaced when the model is exported to this engine. If it is not given, this rewrite is the default rewrite, which is used when no rewrite exists for the given backend.

All instances of the module in the network will be replaced with instances of this new class. The original module and the deployment config are passed as the first two arguments.
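The replacement step described above can be sketched without PyTorch. The following is a simplified illustration, not mmdeploy's code: `Module`, `Upsampler`, `Net`, `UpsamplerWrapper`, and `replace_modules` are all hypothetical, and a real implementation would walk the children of an `nn.Module` instead of plain attributes. It shows that every instance of the registered type is swapped for a wrapper that receives the original module and the deployment config as its first two arguments.

```python
class Module:
    """Minimal stand-in for nn.Module: children are plain attributes."""

    def children(self):
        return {k: v for k, v in vars(self).items() if isinstance(v, Module)}


class Upsampler(Module):
    def __init__(self, mode):
        self.mode = mode


class Net(Module):
    def __init__(self):
        self.head = Upsampler('bicubic')
        self.tail = Upsampler('bicubic')


class UpsamplerWrapper(Module):
    """Rewrite class: receives the original module and the config first."""

    def __init__(self, module, cfg):
        self._module = module
        self.mode = cfg.get('mode', 'bilinear')


def replace_modules(model, rewrites, cfg):
    """Recursively replace instances whose exact type is in `rewrites`."""
    for name, child in model.children().items():
        setattr(model, name, replace_modules(child, rewrites, cfg))
    wrapper_cls = rewrites.get(type(model))
    return wrapper_cls(model, cfg) if wrapper_cls else model


net = replace_modules(
    Net(), {Upsampler: UpsamplerWrapper}, cfg={'mode': 'bilinear'})
```

After the call, both `net.head` and `net.tail` are `UpsamplerWrapper` instances that still hold the original `Upsampler` in `_module`, while `net` itself is unchanged because its type has no registered rewrite.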
## Custom Symbolic

The mappings between PyTorch and ONNX are defined in PyTorch with symbolic functions. A custom symbolic function can help us bypass ONNX nodes that are unsupported by the inference engine.
```python
import torch.onnx.symbolic_helper as sym_help

from mmdeploy.core import SYMBOLIC_REWRITER


@SYMBOLIC_REWRITER.register_symbolic('squeeze', is_pytorch=True)
def squeeze_default(ctx, g, self, dim=None):
    if dim is None:
        # No axis given: squeeze every axis whose static size is 1.
        dims = []
        for i, size in enumerate(self.type().sizes()):
            if size == 1:
                dims.append(i)
    else:
        dims = [sym_help._get_const(dim, 'i', 'dim')]
    return g.op('Squeeze', self, axes_i=dims)
```
The decorator arguments:

- `func_name` is the name of the function to add a symbolic for. Use the full path if it is a custom `torch.autograd.Function`, or just the name if it is a PyTorch built-in function.
- `backend` is the inference engine. The symbolic will be overridden when the model is exported to this engine. If it is not given, this rewrite is the default rewrite, which is used when no rewrite exists for the given backend.
- `is_pytorch` is `True` if the function is a PyTorch built-in function.
- `arg_descriptors` are the descriptors of the symbolic function arguments. They will be fed to `torch.onnx.symbolic_helper._parse_arg`.

As with the function rewriter, there is a context `ctx` as the first argument. The context provides useful information such as the deployment config `ctx.cfg` and the original (overridden) function `ctx.origin_func`. Note that `ctx.origin_func` can be used only when `is_pytorch==False`.
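The axis-selection logic of the `squeeze_default` example can be checked in isolation. The sketch below is a pure-Python illustration under assumptions: `FakeGraph` and `squeeze_symbolic` are hypothetical helpers that imitate `g.op` by recording the emitted node, and the tensor is reduced to a name plus a tuple of static sizes. They are not part of PyTorch or mmdeploy.

```python
class FakeGraph:
    """Hypothetical stand-in for the ONNX graph `g`; records emitted nodes."""

    def __init__(self):
        self.nodes = []

    def op(self, op_type, *inputs, **attrs):
        node = (op_type, inputs, attrs)
        self.nodes.append(node)
        return node


def squeeze_symbolic(g, name, sizes, dim=None):
    """Mimic squeeze_default: pick static axes, then emit one Squeeze node."""
    if dim is None:
        # No axis given: squeeze every axis whose static size is 1.
        dims = [i for i, size in enumerate(sizes) if size == 1]
    else:
        dims = [dim]
    return g.op('Squeeze', name, axes_i=dims)


g = FakeGraph()
all_axes = squeeze_symbolic(g, 'x', sizes=(1, 3, 1, 5))
one_axis = squeeze_symbolic(g, 'x', sizes=(1, 3, 1, 5), dim=2)
```

For the shape `(1, 3, 1, 5)`, the first call records a `Squeeze` node with `axes_i=[0, 2]` and the second one with `axes_i=[2]`, so the axes are always written into the node as constants.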