Mirror of https://github.com/FoundationVision/GLEE
153 lines
4.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
|
|
import contextlib
|
|
from unittest import mock
|
|
import torch
|
|
|
|
from detectron2.modeling import poolers
|
|
from detectron2.modeling.proposal_generator import rpn
|
|
from detectron2.modeling.roi_heads import keypoint_head, mask_head
|
|
from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
|
|
|
|
from .c10 import (
|
|
Caffe2Compatible,
|
|
Caffe2FastRCNNOutputsInference,
|
|
Caffe2KeypointRCNNInference,
|
|
Caffe2MaskRCNNInference,
|
|
Caffe2ROIPooler,
|
|
Caffe2RPN,
|
|
)
|
|
|
|
|
|
class GenericMixin:
    """
    Marker base class for replacement classes that should be *mixed into* a
    module's existing class rather than replacing it outright.

    ``Caffe2CompatibleConverter.create_from`` checks
    ``issubclass(replaceCls, GenericMixin)``. When that holds, it builds a new
    class whose bases are ``(replaceCls, original_module_class)``. Otherwise
    it assigns ``replaceCls`` to the module's ``__class__`` directly.
    """
|
|
|
|
|
|
class Caffe2CompatibleConverter:
    """
    Updater implementing the ``create_from`` interface expected by ``patch``.

    It converts a module *in place* by reassigning the module's ``__class__``
    to ``replaceCls``. If ``replaceCls`` is a ``GenericMixin`` subclass, a new
    class combining ``replaceCls`` with the module's current class is
    generated on the fly and used instead.
    """

    def __init__(self, replaceCls):
        # Class to swap in (or mix in, for GenericMixin subclasses).
        self.replaceCls = replaceCls

    def create_from(self, module):
        """
        Rewrite ``module``'s class and return the same (mutated) object.

        Args:
            module (torch.nn.Module): module whose class is replaced.

        Returns:
            ``module`` itself, now an instance of the new class. If the new
            class is ``Caffe2Compatible``, ``tensor_mode`` is reset to False.
        """
        assert isinstance(module, torch.nn.Module)
        swap_in = self.replaceCls
        current_cls = module.__class__

        if issubclass(swap_in, GenericMixin):
            # Mixin mode: keep the original behaviour reachable via the MRO by
            # creating "<Mixin>MixedWith<Original>" with both as bases.
            mixed_name = "{}MixedWith{}".format(swap_in.__name__, current_cls.__name__)
            swap_in = type(mixed_name, (swap_in, current_cls), {})

        # Either the freshly built mixed class or replaceCls itself
        # (an arbitrary, complete class swap).
        module.__class__ = swap_in

        if isinstance(module, Caffe2Compatible):
            module.tensor_mode = False

        return module
|
|
|
|
|
|
def patch(model, target, updater, *args, **kwargs):
    """
    Recursively rewrite ``model`` and its submodules through ``updater``.

    The traversal is post-order: every child is processed before its parent,
    and each child slot is overwritten with whatever the recursive call
    returns. A module that is an instance of ``target`` (subclasses included)
    is passed to ``updater.create_from(module, *args, **kwargs)`` and replaced
    by the result. This allows re-initialization, composition, class swaps,
    and similar rewrites.

    Args:
        model (torch.nn.Module): root of the module tree to traverse.
        target (type): module type matched with ``isinstance``.
        updater: object exposing ``create_from(module, *args, **kwargs)``.
        *args, **kwargs: forwarded unchanged to every ``create_from`` call.

    Returns:
        ``updater.create_from(model, ...)`` if ``model`` itself matches
        ``target``; otherwise ``model`` (with its children updated).
    """
    for child_name, child in model.named_children():
        model._modules[child_name] = patch(child, target, updater, *args, **kwargs)

    if not isinstance(model, target):
        return model
    return updater.create_from(model, *args, **kwargs)
|
|
|
|
|
|
def patch_generalized_rcnn(model):
    """
    Convert the RPN and ROI pooler submodules of ``model`` to their Caffe2
    counterparts.

    Every ``rpn.RPN`` instance becomes ``Caffe2RPN``, then every
    ``poolers.ROIPooler`` instance becomes ``Caffe2ROIPooler``, by swapping
    module classes in place through ``Caffe2CompatibleConverter``.

    Returns:
        The patched model.
    """
    replacements = (
        (rpn.RPN, Caffe2RPN),
        (poolers.ROIPooler, Caffe2ROIPooler),
    )
    for source_cls, caffe2_cls in replacements:
        model = patch(model, source_cls, Caffe2CompatibleConverter(caffe2_cls))
    return model
|
|
|
|
|
|
@contextlib.contextmanager
def mock_fastrcnn_outputs_inference(
    tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers
):
    """
    Temporarily route ``box_predictor_type.inference`` to the Caffe2
    implementation.

    Inside the ``with`` block, ``inference`` is replaced by an autospecced mock
    whose ``side_effect`` is ``Caffe2FastRCNNOutputsInference(tensor_mode)``.

    Args:
        tensor_mode (bool): forwarded to ``Caffe2FastRCNNOutputsInference``.
        check (bool): if True, assert on normal exit of the ``with`` body that
            the patched method was called at least once. The check is skipped
            if the body raises.
        box_predictor_type (type): class whose ``inference`` attribute is
            patched.
    """
    caffe2_inference = Caffe2FastRCNNOutputsInference(tensor_mode)
    patcher = mock.patch.object(
        box_predictor_type,
        "inference",
        autospec=True,
        side_effect=caffe2_inference,
    )
    with patcher as patched_inference:
        yield
        if check:
            assert patched_inference.call_count > 0
|
|
|
|
|
|
@contextlib.contextmanager
def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True):
    """
    Temporarily replace ``<patched_module>.mask_rcnn_inference`` with a mock
    whose ``side_effect`` is ``Caffe2MaskRCNNInference()``.

    Args:
        tensor_mode (bool): accepted for signature parity with the other
            ``mock_*`` helpers; it is not used by this function.
        patched_module (str): dotted name of the module whose
            ``mask_rcnn_inference`` global is patched.
        check (bool): if True, assert on normal exit of the ``with`` body that
            the mock was called at least once.
    """
    target_path = "{}.mask_rcnn_inference".format(patched_module)
    with mock.patch(target_path, side_effect=Caffe2MaskRCNNInference()) as patched_fn:
        yield
        if check:
            assert patched_fn.call_count > 0
|
|
|
|
|
|
@contextlib.contextmanager
def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True):
    """
    Temporarily replace ``<patched_module>.keypoint_rcnn_inference`` with a
    mock whose ``side_effect`` is
    ``Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)``.

    Args:
        tensor_mode (bool): accepted for signature parity with the other
            ``mock_*`` helpers; it is not used by this function.
        patched_module (str): dotted name of the module whose
            ``keypoint_rcnn_inference`` global is patched.
        use_heatmap_max_keypoint (bool): forwarded to
            ``Caffe2KeypointRCNNInference``.
        check (bool): if True, assert on normal exit of the ``with`` body that
            the mock was called at least once.
    """
    target_path = "{}.keypoint_rcnn_inference".format(patched_module)
    caffe2_inference = Caffe2KeypointRCNNInference(use_heatmap_max_keypoint)
    with mock.patch(target_path, side_effect=caffe2_inference) as patched_fn:
        yield
        if check:
            assert patched_fn.call_count > 0
|
|
|
|
|
|
class ROIHeadsPatcher:
    """
    Wrap an ROI-heads module and expose a context manager that swaps its
    inference helpers for the Caffe2-compatible implementations.
    """

    def __init__(self, heads, use_heatmap_max_keypoint):
        """
        Args:
            heads: the ROI heads to patch (presumably an ``ROIHeads`` instance;
                confirm with callers). It must expose ``box_predictor``. The
                optional ``keypoint_on`` / ``mask_on`` flags are read with a
                default of False.
            use_heatmap_max_keypoint (bool): forwarded to the keypoint
                inference mock when keypoint inference is patched.
        """
        self.heads = heads
        self.use_heatmap_max_keypoint = use_heatmap_max_keypoint

    @contextlib.contextmanager
    def mock_roi_heads(self, tensor_mode=True):
        """
        Patch the inference functions used by the ROI heads (and subclasses)
        for the duration of the ``with`` block.

        The box predictor's ``inference`` method is always patched. Keypoint
        and mask inference are patched only when ``heads.keypoint_on`` /
        ``heads.mask_on`` are truthy.

        Args:
            tensor_mode (bool): whether the inputs/outputs use Caffe2's tensor
                format. Defaults to True.
        """
        # NOTE: mock.patch replaces a name inside one module's namespace. This
        # therefore requires `keypoint_rcnn_inference` / `mask_rcnn_inference`
        # to be called from the same module that defines BaseKeypointRCNNHead /
        # BaseMaskRCNNHead respectively.
        keypoint_module = keypoint_head.BaseKeypointRCNNHead.__module__
        mask_module = mask_head.BaseMaskRCNNHead.__module__

        managers = [
            mock_fastrcnn_outputs_inference(
                tensor_mode=tensor_mode,
                check=True,
                box_predictor_type=type(self.heads.box_predictor),
            )
        ]
        if getattr(self.heads, "keypoint_on", False):
            managers.append(
                mock_keypoint_rcnn_inference(
                    tensor_mode, keypoint_module, self.use_heatmap_max_keypoint
                )
            )
        if getattr(self.heads, "mask_on", False):
            managers.append(mock_mask_rcnn_inference(tensor_mode, mask_module))

        # Enter every selected patch; ExitStack unwinds them in reverse order.
        with contextlib.ExitStack() as stack:
            for manager in managers:
                stack.enter_context(manager)
            yield
|