mirror of https://github.com/JosephKJ/OWOD.git
224 lines
6.7 KiB
Python
224 lines
6.7 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|||
|
|
|||
|
import importlib.util
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import tempfile
|
|||
|
from contextlib import contextmanager
|
|||
|
import torch
|
|||
|
|
|||
|
# need an explicit import due to https://github.com/pytorch/pytorch/issues/38964
|
|||
|
from detectron2.structures import Boxes, Instances # noqa F401
|
|||
|
|
|||
|
# Module-level counter, incremented each time a patched class is generated.
# Its value is embedded in the generated class name and in the name of the
# dynamically imported module so that successive generations do not collide.
_counter = 0
|
|||
|
|
|||
|
|
|||
|
def export_torchscript_with_instances(model, fields):
    """
    Script a model that uses the :class:`Instances` class with :func:`torch.jit.script`.

    Attributes of :class:`Instances` are added "dynamically" in eager mode, which
    torchscript cannot handle out of the box. To make such a model scriptable, this
    function:

    1. Creates a scriptable ``new_Instances`` class that behaves similarly to
       ``Instances`` but whose attributes are all "static". The attributes must be
       declared up front through the ``fields`` argument.
    2. Registers ``new_Instances`` with torchscript and forces torchscript to use it
       whenever it tries to compile ``Instances``.

    The patching is reverted once this function returns, so another model can be
    scripted afterwards with a different set of fields.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:

        ::
            fields = {"proposal_boxes": "Boxes", "objectness_logits": "Tensor"}
            torchscript_model = export_torchscript_with_instances(model, fields)

    Note:
        Only models in evaluation mode are supported for now. Exporting a model in
        training mode, or running inference with a torchscript exported from a model
        in training mode, may hit unexpected errors.

    Args:
        model (nn.Module): The input model to be exported to torchscript.
        fields (Dict[str, str]): Attribute names and corresponding type annotations that
            ``Instances`` will use in the model. Every attribute used in ``Instances``
            must be listed, regardless of whether it is an input/output of the model.
            Custom data types are not supported for now.

    Returns:
        torch.jit.ScriptModule: the input model in torchscript format
    """
    in_eval_mode = not model.training
    assert (
        in_eval_mode
    ), "Currently we only support exporting models in evaluation mode to torchscript"

    # Script the model while ``Instances`` is temporarily redirected to the
    # generated static class; the patch is undone when the ``with`` block exits.
    with patch_instances(fields):
        return torch.jit.script(model)
|
|||
|
|
|||
|
|
|||
|
@contextmanager
def patch_instances(fields):
    """
    Context manager that temporarily redirects torchscript from ``Instances`` to a
    generated class with statically declared attributes.

    On entry, the source of the replacement class is generated from ``fields``,
    written to a temporary ``.py`` file, imported, and scripted. ``Instances`` is then
    tagged so that torchscript treats it as already scripted and resolves it to the
    generated class. On exit (normal or exceptional) the tags are removed again and
    the generated module is dropped from ``sys.modules``.

    Args:
        fields (Dict[str, str]): Attribute names and their type annotations, as
            accepted by ``_gen_module``.

    Yields:
        The generated replacement class.
    """
    # Stays ``None`` until the generated module has actually been imported, so the
    # cleanup below never touches an unbound name when an earlier step fails
    # (which would otherwise hide the original exception).
    module = None
    with tempfile.TemporaryDirectory(prefix="detectron2") as tmp_dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=tmp_dir, delete=False
    ) as f:
        try:
            cls_name, source = _gen_module(fields)
            f.write(source)
            f.flush()
            # Make sure the whole source is on disk and the handle is released
            # before the file is imported by path.
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            _ = torch.jit.script(new_instances)

            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
            yield new_instances
        finally:
            # Remove each marker independently: one of them may be missing if setup
            # failed half-way, and that must not prevent removing the other.
            for attr in ("__torch_script_class__", "_jit_override_qualname"):
                try:
                    delattr(Instances, attr)
                except AttributeError:
                    pass
            if module is not None:
                sys.modules.pop(module.__name__, None)
|
|||
|
|
|||
|
|
|||
|
# TODO: find a more automatic way to enable import of other classes
|
|||
|
def _gen_imports():
    """Return the import preamble placed at the top of the generated module source."""
    preamble_lines = (
        "",
        "from copy import deepcopy",
        "import torch",
        "from torch import Tensor",
        "import typing",
        "from typing import *",
        "",
        "from detectron2.structures import Boxes, Instances",
        "",
        "",
    )
    return "\n".join(preamble_lines)
|
|||
|
|
|||
|
|
|||
|
def _gen_class(fields):
    """
    Generate the source code of a scriptable stand-in class for ``Instances``.

    The generated class stores ``image_size`` plus one ``Optional`` private slot
    ``_<name>`` per entry of ``fields``, and exposes for each entry a typed property
    and setter. It also defines ``__len__``, ``has`` and a ``from_instances`` static
    method that copies the fields of an existing ``Instances`` object.

    Each call increments the module-level ``_counter`` so that the generated class
    gets a fresh name.

    Args:
        fields (Dict[str, str]): Attribute names mapped to the type annotation
            (as source text) to use for that attribute.

    Returns:
        Tuple[str, str]: the generated class name and its source code.
    """

    def indent(level, s):
        return " " * 4 * level + s

    global _counter
    _counter += 1

    cls_name = "Instances_patched{}".format(_counter)

    lines = []
    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int]):
        self.image_size = image_size
"""
    )

    # One Optional-typed private slot per declared field, initialised to None.
    for name, type_ in fields.items():
        lines.append(indent(2, f"self._{name} = torch.jit.annotate(Optional[{type_}], None)"))

    for name, type_ in fields.items():
        lines.append(
            f"""
    @property
    def {name}(self) -> {type_}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{name}
        assert t is not None
        return t

    @{name}.setter
    def {name}(self, value: {type_}) -> None:
        self._{name} = value
"""
        )

    # support function attribute `__len__`
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    for name in fields:
        lines.append(
            f"""
        t = self._{name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support function attribute `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for name in fields:
        lines.append(
            f"""
        if name == "{name}":
            return self._{name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support function attribute `from_instances`
    lines.append(
        f"""
    @torch.jit.unused
    @staticmethod
    def from_instances(instances: Instances) -> "{cls_name}":
        fields = instances.get_fields()
        image_size = instances.image_size
        new_instances = {cls_name}(image_size)
        for name, val in fields.items():
            assert hasattr(new_instances, '_{{}}'.format(name)), \\
                "No attribute named {{}} in {cls_name}".format(name)
            setattr(new_instances, name, deepcopy(val))
        return new_instances
"""
    )
    # Join with "\n" rather than os.linesep: the source is written through a
    # text-mode file, which already translates "\n" to the platform line ending.
    return cls_name, "\n".join(lines)
|
|||
|
|
|||
|
|
|||
|
def _gen_module(fields):
    """
    Assemble the complete source of the generated module.

    The result is the import preamble from ``_gen_imports`` followed by the class
    definition from ``_gen_class``.

    Args:
        fields: forwarded unchanged to ``_gen_class``.

    Returns:
        A ``(cls_name, source)`` tuple: the generated class name and the module source.
    """
    preamble = _gen_imports()
    cls_name, class_source = _gen_class(fields)
    return cls_name, preamble + class_source
|
|||
|
|
|||
|
|
|||
|
def _import(path):
    """
    Import the Python source file at ``path`` as a new module and register it.

    The module name is this module's own name with the current value of
    ``_counter`` appended. The module is inserted into ``sys.modules`` before its
    code is executed; if execution fails, the entry is removed again so that no
    half-initialised module is left behind.

    Args:
        path (str): path of the source file to import.

    Returns:
        The imported module object.

    Raises:
        ImportError: if no import spec/loader can be created for ``path``.
        Exception: whatever the module's own top-level code raises.
    """
    # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
    module_name = "{}{}".format(sys.modules[__name__].__name__, _counter)
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when it cannot pick a loader.
        raise ImportError("Cannot create an import spec for {!r}".format(path), path=path)

    module = importlib.util.module_from_spec(spec)
    sys.modules[module.__name__] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Undo the registration on any failure, then let the error propagate.
        sys.modules.pop(module.__name__, None)
        raise
    return module
|