# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import importlib.util
import os
import sys
import tempfile
from contextlib import contextmanager
import torch

# need an explicit import due to https://github.com/pytorch/pytorch/issues/38964
from detectron2.structures import Boxes, Instances  # noqa F401

# Module-level counter incremented by every `_gen_class` call. Its value is appended to the
# generated class name and to the name of the temporary module loaded by `_import`.
_counter = 0


def export_torchscript_with_instances(model, fields):
    """
    Script a model that uses the :class:`Instances` class with :func:`torch.jit.script`.

    In eager mode the attributes of :class:`Instances` are added "dynamically", which
    torchscript cannot handle out of the box. To work around that, this function:

    1. Creates a scriptable ``new_Instances`` class that behaves similarly to ``Instances``
       but whose attributes are all "static". The attributes have to be declared up front
       through the ``fields`` argument.
    2. Registers ``new_Instances`` with torchscript and forces torchscript to use it
       whenever it tries to compile ``Instances``.

    The patch is reverted when this function returns, so another model using a different
    set of fields can be scripted afterwards.

    Example:
        Assume that ``Instances`` in the model consist of two attributes named
        ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and
        :class:`Tensor` respectively during inference. You can call this function like:

        ::
            fields = {"proposal_boxes": "Boxes", "objectness_logits": "Tensor"}
            torchscript_model = export_torchscript_with_instances(model, fields)

    Note:
        Only models in evaluation mode are supported for now. Exporting a model in training
        mode, or running inference with a torchscript exported from a model in training
        mode, may produce unexpected errors.

    Args:
        model (nn.Module): The input model to be exported to torchscript.
        fields (Dict[str, str]): Attribute names and the corresponding type annotations
            that ``Instances`` will use in the model. Every attribute used in ``Instances``
            must be listed, regardless of whether it is an input/output of the model.
            Custom data types are not supported for now.

    Returns:
        torch.jit.ScriptModule: the input model in torchscript format
    """
    assert not model.training, (
        "Currently we only support exporting models in evaluation mode to torchscript"
    )

    # Scripting has to happen while ``Instances`` is patched.
    with patch_instances(fields):
        return torch.jit.script(model)


@contextmanager
def patch_instances(fields):
    """
    Context manager that temporarily makes torchscript compile ``Instances`` as a
    generated, statically-typed replacement class.

    On entry, the source of the replacement class is generated from ``fields``, written to a
    temporary ``.py`` file, imported and scripted. ``Instances`` is then tagged so that
    torchscript treats it as already scripted and resolves it to the replacement class.
    On exit (normal or exceptional) the tags are removed from ``Instances`` and the
    temporary module is dropped from ``sys.modules``.

    Args:
        fields (Dict[str, str]): attribute names and their type annotations, forwarded to
            the source generator.

    Yields:
        type: the generated replacement class.
    """
    with tempfile.TemporaryDirectory(prefix="detectron2") as tmp_dir, tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", suffix=".py", dir=tmp_dir, delete=False
    ) as f:
        # Bound up front so the cleanup below never hits an unbound name (which would
        # mask the real error) when generation or import fails.
        module = None
        try:
            cls_name, s = _gen_module(fields)
            f.write(s)
            # flush and close before importing so the source is completely written
            # when it is read back through its path
            f.flush()
            f.close()

            module = _import(f.name)
            new_instances = getattr(module, cls_name)
            _ = torch.jit.script(new_instances)

            # let torchscript think Instances was scripted already
            Instances.__torch_script_class__ = True
            # let torchscript find new_instances when looking for the jit type of Instances
            Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances)
            yield new_instances
        finally:
            # Remove each marker independently: a missing one must not prevent the
            # removal of the other.
            for attr in ("__torch_script_class__", "_jit_override_qualname"):
                try:
                    delattr(Instances, attr)
                except AttributeError:
                    pass
            if module is not None:
                sys.modules.pop(module.__name__, None)


# TODO: find a more automatic way to enable import of other classes
def _gen_imports():
    """
    Return the import header placed at the top of the generated module.

    The header brings into scope the names referenced by the generated class source:
    ``deepcopy``, ``torch``, ``Tensor``, the ``typing`` names, ``Boxes`` and ``Instances``.
    """
    header_lines = (
        "",
        "from copy import deepcopy",
        "import torch",
        "from torch import Tensor",
        "import typing",
        "from typing import *",
        "",
        "from detectron2.structures import Boxes, Instances",
        "",
        "",
    )
    return "\n".join(header_lines)


def _gen_class(fields):
    """
    Generate the source code of a statically-typed stand-in class for ``Instances``.

    The generated class has an ``image_size`` attribute, one ``Optional`` private slot plus
    a property/setter pair per entry of ``fields``, ``__len__``, ``has`` and a
    ``from_instances`` static method marked ``torch.jit.unused``.

    Each call increments the module-level ``_counter`` and uses the new value as a suffix
    of the class name.

    Args:
        fields (Dict[str, str]): attribute names mapped to their type annotations, given as
            source text that is pasted verbatim into the generated code.

    Returns:
        Tuple[str, str]: the generated class name and the class source code. Lines are
        separated by ``"\\n"`` on every platform, because the caller writes the source to a
        text-mode file, which performs the platform newline translation itself.
    """

    def indent(level, s):
        return " " * 4 * level + s

    lines = []

    global _counter
    _counter += 1

    cls_name = "Instances_patched{}".format(_counter)

    lines.append(
        f"""
class {cls_name}:
    def __init__(self, image_size: Tuple[int, int]):
        self.image_size = image_size
"""
    )

    # every field starts out unset (None) and is declared Optional for torchscript
    for name, type_ in fields.items():
        lines.append(indent(2, f"self._{name} = torch.jit.annotate(Optional[{type_}], None)"))

    for name, type_ in fields.items():
        lines.append(
            f"""
    @property
    def {name}(self) -> {type_}:
        # has to use a local for type refinement
        # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement
        t = self._{name}
        assert t is not None
        return t

    @{name}.setter
    def {name}(self, value: {type_}) -> None:
        self._{name} = value
"""
        )

    # support function attribute `__len__`: length of the first field that is set
    lines.append(
        """
    def __len__(self) -> int:
"""
    )
    for name in fields:
        lines.append(
            f"""
        t = self._{name}
        if t is not None:
            return len(t)
"""
        )
    lines.append(
        """
        raise NotImplementedError("Empty Instances does not support __len__!")
"""
    )

    # support function attribute `has`
    lines.append(
        """
    def has(self, name: str) -> bool:
"""
    )
    for name in fields:
        lines.append(
            f"""
        if name == "{name}":
            return self._{name} is not None
"""
        )
    lines.append(
        """
        return False
"""
    )

    # support function attribute `from_instances`
    lines.append(
        f"""
    @torch.jit.unused
    @staticmethod
    def from_instances(instances: Instances) -> "{cls_name}":
        fields = instances.get_fields()
        image_size = instances.image_size
        new_instances = {cls_name}(image_size)
        for name, val in fields.items():
            assert hasattr(new_instances, '_{{}}'.format(name)), \\
                "No attribute named {{}} in {cls_name}".format(name)
            setattr(new_instances, name, deepcopy(val))
        return new_instances
"""
    )
    # Join with "\n" rather than os.linesep: the result is written to a text-mode file,
    # and os.linesep would be translated a second time on platforms where it is not "\n".
    return cls_name, "\n".join(lines)


def _gen_module(fields):
    """
    Build the complete source of the temporary module for ``fields``.

    Returns:
        Tuple[str, str]: the name of the generated class and the module source, i.e. the
        import header followed by the class definition.
    """
    header = _gen_imports()
    cls_name, cls_def = _gen_class(fields)
    return cls_name, "".join((header, cls_def))


def _import(path):
    """
    Import the Python source file at ``path`` as a new, uniquely named module.

    The module is named after this module with the current value of ``_counter`` appended,
    and it is registered in ``sys.modules`` before its code is executed. If executing the
    module raises, the registration is undone so that no half-initialized module is left
    behind, and the exception is re-raised unchanged.

    Args:
        path (str): path of the source file to import.

    Returns:
        module: the imported module object.
    """
    # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
    module_name = f"{__name__}{_counter}"
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # mirror what the regular import system does on failure
        sys.modules.pop(module_name, None)
        raise
    return module