RE-OWOD/detectron2/export/torchscript.py

224 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib.util
import os
import sys
import tempfile
from contextlib import contextmanager
import torch
# need an explicit import due to https://github.com/pytorch/pytorch/issues/38964
from detectron2.structures import Boxes, Instances # noqa F401
_counter = 0
def export_torchscript_with_instances(model, fields):
    """
    Script a model that uses the :class:`Instances` class via :func:`torch.jit.script`.

    Attributes of :class:`Instances` are added "dynamically" in eager mode, which makes
    it difficult for torchscript to support the class out of the box. To make such a
    model scriptable, this function:

    1. Creates a scriptable ``new_Instances`` class which behaves similarly to
       ``Instances``, but with all attributes being "static". The attributes need to
       be statically declared in the ``fields`` argument.
    2. Registers ``new_Instances`` to torchscript, and forces torchscript to use it
       when trying to compile ``Instances``.

    After this function returns, the patching is reverted, so the user should be able
    to script another model using different fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:

        ::
            fields = {"proposal_boxes": "Boxes", "objectness_logits": "Tensor"}
            torchscript_model = export_torchscript_with_instances(model, fields)

    Note:
        Currently we only support models in evaluation mode. Exporting models in
        training mode, or running inference with torchscripts that were exported from
        models in training mode, may encounter unexpected errors.

    Args:
        model (nn.Module): The input model to be exported to torchscript.
        fields (Dict[str, str]): Attribute names and corresponding type annotations that
            ``Instances`` will use in the model. Note that all attributes used in
            ``Instances`` need to be added, regardless of whether they are
            inputs/outputs of the model. Custom data types are not supported for now.

    Returns:
        torch.jit.ScriptModule: the input model in torchscript format
    """
    assert not model.training, (
        "Currently we only support exporting models in evaluation mode to torchscript"
    )

    # Scripting must happen while ``Instances`` is patched; the patch is undone
    # as soon as the ``with`` block is left.
    with patch_instances(fields):
        return torch.jit.script(model)
@contextmanager
def patch_instances(fields):
    """
    Temporarily make torchscript compile :class:`Instances` as a generated class.

    A module defining a statically-typed replacement class is generated from
    ``fields``, written to a temporary ``.py`` file, imported, and scripted. While
    the context is active, two marker attributes are set on ``Instances`` so that
    torchscript resolves it to the generated class. On exit (normal or via an
    exception) the markers are removed and the generated module is dropped from
    ``sys.modules``.

    Args:
        fields (Dict[str, str]): Attribute names and their type annotations, used to
            generate the replacement class.

    Yields:
        type: the generated replacement class for ``Instances``.
    """
    with tempfile.TemporaryDirectory(prefix="detectron2") as tmp_dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=tmp_dir, delete=False
    ) as f:
        # Stays None until the import succeeds, so that the cleanup below never reads
        # an unbound name (which would mask the real error) when generating or
        # importing the module fails.
        module = None
        try:
            cls_name, s = _gen_module(fields)
            f.write(s)
            f.flush()
            # Close before importing so the file content is fully on disk.
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            _ = torch.jit.script(new_instances)

            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
            yield new_instances
        finally:
            # Best-effort removal of each marker independently: one of them may never
            # have been set if the setup above failed part-way through.
            for attr in ("__torch_script_class__", "_jit_override_qualname"):
                try:
                    delattr(Instances, attr)
                except AttributeError:
                    pass
            if module is not None:
                sys.modules.pop(module.__name__, None)
# TODO: find a more automatic way to enable import of other classes
def _gen_imports():
    """
    Return the import preamble placed at the top of the generated module.

    The preamble brings into scope the names that the generated class source refers
    to (``deepcopy``, ``torch``, ``Tensor``, the ``typing`` names, ``Boxes`` and
    ``Instances``).

    Returns:
        str: source text that starts and ends with a newline.
    """
    import_lines = (
        "",
        "from copy import deepcopy",
        "import torch",
        "from torch import Tensor",
        "import typing",
        "from typing import *",
        "from detectron2.structures import Boxes, Instances",
        "",
    )
    return "\n".join(import_lines)
def _gen_class(fields):
    """
    Generate the source code of a statically-typed replacement for ``Instances``.

    The generated class stores every field in an ``Optional`` attribute named
    ``_<field>``, exposes it through a property/setter pair named ``<field>``, and
    provides ``__len__``, ``has`` and a ``from_instances`` static method.

    Each call increments the module-level ``_counter`` so that every generated class
    gets a unique name.

    Args:
        fields (Dict[str, str]): Attribute names mapped to their type annotations
            (as source text), e.g. ``{"proposal_boxes": "Boxes"}``.

    Returns:
        Tuple[str, str]: the generated class name and the class source code.
    """

    def indent(level, s):
        return " " * 4 * level + s

    lines = []

    global _counter
    _counter += 1

    cls_name = "Instances_patched{}".format(_counter)

    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int]):
        self.image_size = image_size
"""
    )

    # One Optional-typed storage slot per declared field, initialised to None
    # inside ``__init__`` (hence two indentation levels).
    for name, type_ in fields.items():
        lines.append(indent(2, f"self._{name} = torch.jit.annotate(Optional[{type_}], None)"))

    for name, type_ in fields.items():
        lines.append(
            f"""
    @property
    def {name}(self) -> {type_}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{name}
        assert t is not None
        return t

    @{name}.setter
    def {name}(self, value: {type_}) -> None:
        self._{name} = value
"""
        )

    # support function attribute `__len__`
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    # The length is taken from the first field that has been set.
    for name in fields:
        lines.append(
            f"""
        t = self._{name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support function attribute `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for name in fields:
        lines.append(
            f"""
        if name == "{name}":
            return self._{name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support function attribute `from_instances`
    lines.append(
        f"""
    @torch.jit.unused
    @staticmethod
    def from_instances(instances: Instances) -> "{cls_name}":
        fields = instances.get_fields()
        image_size = instances.image_size
        new_instances = {cls_name}(image_size)
        for name, val in fields.items():
            assert hasattr(new_instances, '_{{}}'.format(name)), \\
                "No attribute named {{}} in {cls_name}".format(name)
            setattr(new_instances, name, deepcopy(val))
        return new_instances
"""
    )
    # Join with a plain "\n": the source is later written through a text-mode file,
    # which already translates newlines for the platform. Using os.linesep here
    # would produce "\r\r\n" line endings on Windows.
    return cls_name, "\n".join(lines)
def _gen_module(fields):
    """
    Build the full source of the generated module: import preamble plus class.

    Args:
        fields (Dict[str, str]): Attribute names mapped to their type annotations,
            forwarded to ``_gen_class``.

    Returns:
        Tuple[str, str]: the generated class name and the complete module source.
    """
    header = _gen_imports()
    cls_name, cls_def = _gen_class(fields)
    return cls_name, header + cls_def
def _import(path):
    """
    Import the Python source file at ``path`` as a uniquely named module.

    The module name is this module's ``__name__`` followed by the current value of
    the module-level ``_counter``. The module is registered in ``sys.modules`` before
    it is executed; if execution fails, the registration is rolled back so no
    half-initialised module is left behind.

    Args:
        path (str): Path to the ``.py`` file to import.

    Returns:
        module: the imported module object.

    Raises:
        ImportError: if no import spec/loader can be created for ``path``.
    """
    # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
    module_name = "{}{}".format(__name__, _counter)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError("Cannot create an import spec for {!r}".format(path))
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Mirror the regular import machinery: a module whose execution failed
        # must not stay registered.
        sys.modules.pop(module_name, None)
        raise
    return module