[Enhancement]: Support VFNet from MMDetection for OpenVINO and ONNX Runtime. (#195)

* Add deform_conv_openvino.

* Add get_bboxes_of_vfnet_head.

* Fix vfnet and add test_get_bboxes_of_vfnet_head.

* Update docs.

* Fix test_shufflenetv2_backbone__forward for openvino.

* Fixes.
This commit is contained in:
Semyon Bevzyuk 2021-11-16 05:59:59 +03:00 committed by GitHub
parent e142c72663
commit 49dd1cf678
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 350 additions and 29 deletions

View File

@ -24,7 +24,6 @@ python tools/deploy.py \
### List of supported models exportable to OpenVINO from MMDetection ### List of supported models exportable to OpenVINO from MMDetection
The table below lists the models that are guaranteed to be exportable to OpenVINO from MMDetection. The table below lists the models that are guaranteed to be exportable to OpenVINO from MMDetection.
| Model name | Config | Dynamic Shape | | Model name | Config | Dynamic Shape |
| :----------------: | :-----------------------------------------------------------------------: | :-----------: | | :----------------: | :-----------------------------------------------------------------------: | :-----------: |
| ATSS | `configs/atss/atss_r50_fpn_1x_coco.py` | Y | | ATSS | `configs/atss/atss_r50_fpn_1x_coco.py` | Y |
@ -39,7 +38,10 @@ The table below lists the models that are guaranteed to be exportable to OpenVIN
| SSD | `configs/ssd/ssd300_coco.py` | Y | | SSD | `configs/ssd/ssd300_coco.py` | Y |
| YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | | YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y |
| YOLOX | `configs/yolox/yolox_tiny_8x8_300e_coco.py` | Y | | YOLOX | `configs/yolox/yolox_tiny_8x8_300e_coco.py` | Y |
| Faster R-CNN + DCN | `configs/dcn/faster_rcnn_r50_fpn_dconv_c3-c5_1x_coco.py` | Y |
| VFNet | `configs/vfnet/vfnet_r50_fpn_1x_coco.py` | Y |
Notes: Notes:
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models - For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph. the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.

View File

@ -100,6 +100,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| ATSS | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py | | ATSS | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/atss/atss_r50_fpn_1x_coco.py |
| Cascade R-CNN | MMDetection | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py | | Cascade R-CNN | MMDetection | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py |
| Cascade Mask R-CNN | MMDetection | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py | | Cascade Mask R-CNN | MMDetection | Y | ? | ? | Y | Y | $MMDET_DIR/configs/cascade_rcnn/cascade_mask_rcnn_r50_fpn_1x_coco.py |
| VFNet | MMDetection | Y | ? | ? | ? | Y | $MMDET_DIR/configs/vfnet/vfnet_r50_fpn_1x_coco.py |
| ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnet/resnet18_b32x8_imagenet.py | | ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnet/resnet18_b32x8_imagenet.py |
| ResNeXt | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnext/resnext50_32x4d_b32x8_imagenet.py | | ResNeXt | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/resnext/resnext50_32x4d_b32x8_imagenet.py |
| SE-ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/seresnet/seresnet50_b32x8_imagenet.py | | SE-ResNet | MMClassification | Y | Y | Y | Y | N | $MMCLS_DIR/configs/seresnet/seresnet50_b32x8_imagenet.py |

View File

@ -1,4 +1,5 @@
from .deform_conv import deform_conv_openvino
from .nms import * # noqa: F401,F403 from .nms import * # noqa: F401,F403
from .roi_align import roi_align_default from .roi_align import roi_align_default
__all__ = ['roi_align_default'] __all__ = ['roi_align_default', 'deform_conv_openvino']

View File

@ -0,0 +1,34 @@
from mmdeploy.core import SYMBOLIC_REWRITER
@SYMBOLIC_REWRITER.register_symbolic(
'mmcv.ops.deform_conv.DeformConv2dFunction', backend='openvino')
def deform_conv_openvino(ctx,
g,
input,
offset,
weight,
stride,
padding,
dilation,
groups,
deform_groups,
bias=False,
im2col_step=32):
"""Rewrite symbolic function for OpenVINO backend."""
assert not bias, 'The "bias" parameter should be False.'
assert groups == 1, 'The "groups" parameter should be 1.'
kh, kw = weight.type().sizes()[2:]
domain = 'org.openvinotoolkit'
op_name = 'DeformableConv2D'
return g.op(
f'{domain}::{op_name}',
input,
offset,
weight,
strides_i=stride,
pads_i=[p for pair in zip(padding, padding) for p in pair],
dilations_i=dilation,
groups_i=groups,
deformable_groups_i=deform_groups,
kernel_shape_i=[kh, kw])

View File

@ -3,12 +3,13 @@ from .atss_head import atss_head__get_bboxes
from .fcos_head import fcos_head__get_bboxes from .fcos_head import fcos_head__get_bboxes
from .fovea_head import fovea_head__get_bboxes from .fovea_head import fovea_head__get_bboxes
from .rpn_head import rpn_head__get_bboxes from .rpn_head import rpn_head__get_bboxes
from .vfnet_head import vfnet_head__get_bboxes
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
from .yolox_head import yolox_head__get_bboxes from .yolox_head import yolox_head__get_bboxes
__all__ = [ __all__ = [
'anchor_head__get_bboxes', 'atss_head__get_bboxes', 'anchor_head__get_bboxes', 'atss_head__get_bboxes',
'fcos_head__get_bboxes', 'fovea_head__get_bboxes', 'rpn_head__get_bboxes', 'fcos_head__get_bboxes', 'fovea_head__get_bboxes', 'rpn_head__get_bboxes',
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn', 'vfnet_head__get_bboxes', 'yolov3_head__get_bboxes',
'yolox_head__get_bboxes' 'yolov3_head__get_bboxes__ncnn', 'yolox_head__get_bboxes'
] ]

View File

@ -0,0 +1,113 @@
import torch
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmdet.core import distance2bbox, multiclass_nms
from mmdeploy.utils import get_mmdet_params
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.VFNetHead.get_bboxes')
def vfnet_head__get_bboxes(ctx,
self,
cls_scores,
bbox_preds,
bbox_preds_refine,
img_metas,
cfg=None,
rescale=None,
with_nms=True):
"""Rewrite `get_bboxes` of VFNetHead for default backend.
Rewrite this function to deploy model, transform network output for a
batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box iou-aware scores for each scale
level with shape (N, num_points * num_classes, H, W).
bbox_preds (list[Tensor]): Box offsets for each scale
level with shape (N, num_points * 4, H, W).
bbox_preds_refine (list[Tensor]): Refined Box offsets for
each scale level with shape (N, num_points * 4, H, W).
img_metas (dict): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config): Test / postprocessing configuration,
if None, test_cfg would be used. Default: None.
rescale (bool): If True, return boxes in original image space.
Default: False.
with_nms (bool): If True, do nms before returning boxes.
Default: True.
Returns:
If with_nms == True:
tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels),
`dets` of shape [N, num_det, 5] and `labels` of shape
[N, num_det].
Else:
tuple[Tensor, Tensor]: batch_mlvl_bboxes, batch_mlvl_scores
"""
assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine), \
'The lengths of lists "cls_scores", "bbox_preds", "bbox_preds_refine"'\
' should be the same.'
num_levels = len(cls_scores)
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
points_list = self.get_points(featmap_sizes, bbox_preds[0].dtype,
bbox_preds[0].device)
cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
bbox_pred_list = [bbox_preds_refine[i].detach() for i in range(num_levels)]
cfg = self.test_cfg if cfg is None else cfg
batch_size = cls_scores[0].shape[0]
pre_topk = cfg.get('nms_pre', -1)
# loop over features, decode boxes
mlvl_bboxes = []
mlvl_scores = []
mlvl_points = []
for cls_score, bbox_pred, points, in zip(cls_score_list, bbox_pred_list,
points_list):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:], \
'The Height and Width should be the same.'
scores = cls_score.permute(0, 2, 3,
1).reshape(batch_size, -1,
self.cls_out_channels).sigmoid()
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
points = points.expand(batch_size, -1, 2)
if pre_topk > 0:
max_scores, _ = scores.max(-1)
_, topk_inds = max_scores.topk(pre_topk)
batch_inds = torch.arange(batch_size).view(-1,
1).expand_as(topk_inds)
points = points[batch_inds, topk_inds, :]
bbox_pred = bbox_pred[batch_inds, topk_inds, :]
scores = scores[batch_inds, topk_inds, :]
mlvl_bboxes.append(bbox_pred)
mlvl_scores.append(scores)
mlvl_points.append(points)
batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
batch_mlvl_points = torch.cat(mlvl_points, dim=1)
batch_mlvl_bboxes = distance2bbox(
batch_mlvl_points, batch_mlvl_bboxes, max_shape=img_metas['img_shape'])
if rescale:
batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
img_metas['scale_factor'])
if not with_nms:
return batch_mlvl_bboxes, batch_mlvl_scores
deploy_cfg = ctx.cfg
post_params = get_mmdet_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = cfg.get('nms_pre', post_params.pre_top_k)
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
return multiclass_nms(batch_mlvl_bboxes, batch_mlvl_scores,
max_output_boxes_per_class, iou_threshold,
score_threshold, pre_top_k, keep_top_k)

View File

@ -94,7 +94,7 @@ def test_shufflenetv2_backbone__forward(backend_type):
deploy_cfg = mmcv.Config( deploy_cfg = mmcv.Config(
dict( dict(
backend_config=dict(type=backend_type), backend_config=dict(type=backend_type),
onnx_config=dict(input_shape=None), onnx_config=dict(input_shape=None, output_names=['output']),
codebase_config=dict(type='mmcls', task='Classification'))) codebase_config=dict(type='mmcls', task='Classification')))
imgs = torch.rand((1, 16, 28, 28)) imgs = torch.rand((1, 16, 28, 28))

View File

@ -46,3 +46,48 @@ def test_ONNXNMSop(iou_threshold, score_threshold, max_output_boxes_per_class):
opset_version=11) opset_version=11)
model = onnx.load(onnx_file_path) model = onnx.load(onnx_file_path)
assert model.graph.node[3].op_type == 'NonMaxSuppression' assert model.graph.node[3].op_type == 'NonMaxSuppression'
def test_deform_conv_openvino():
pytest.importorskip('openvino', reason='requires openvino')
input = torch.Tensor([[[[1., 2., 3.], [0., 1., 2.], [3., 5., 2.]]]])
offset = torch.Tensor([[[[1.7000, 2.9000], [3.4000, 4.8000]],
[[1.1000, 2.0000], [2.1000, 1.9000]],
[[3.1000, 5.1000], [5.9000, 4.9000]],
[[2.0000, 4.1000], [4.0000, 6.6000]],
[[1.6000, 2.7000], [3.8000, 3.1000]],
[[2.5000, 4.3000], [4.2000, 5.3000]],
[[1.7000, 3.3000], [3.6000, 4.5000]],
[[1.7000, 3.4000], [5.2000, 6.1000]]]])
expected_output = torch.Tensor([[[[1.6500, 0.0000], [0.0000, 0.0000]]]])
from mmcv.ops.deform_conv import DeformConv2dFunction
def wrapped_function(input, offset):
weight = torch.Tensor([[[[0.4000, 0.2000], [0.1000, 0.9000]]]])
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
groups = 1
deform_groups = 1
return DeformConv2dFunction.apply(input, offset, weight, stride,
padding, dilation, groups,
deform_groups)
wrapped_model = WrapFunction(wrapped_function).eval()
model_output = wrapped_model(input, offset)
assert torch.allclose(expected_output, model_output)
onnx_file_path = tempfile.NamedTemporaryFile().name
with RewriterContext({}, backend='openvino'), torch.no_grad():
torch.onnx.export(
wrapped_model, (input, offset),
onnx_file_path,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input', 'offset'],
output_names=['result'],
opset_version=11)
model = onnx.load(onnx_file_path)
assert model.graph.node[1].op_type == 'DeformableConv2D'
assert model.graph.node[1].domain == 'org.openvinotoolkit'

View File

@ -3,6 +3,7 @@ import importlib
import os import os
import random import random
import tempfile import tempfile
from typing import Dict, List
import mmcv import mmcv
import numpy as np import numpy as np
@ -27,6 +28,18 @@ def seed_everything(seed=1029):
torch.backends.cudnn.enabled = False torch.backends.cudnn.enabled = False
def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List:
"""Converts output from a dictionary to a list.
The new list will contain only those output values, whose names are in list
'output_names'.
"""
outputs = [
value for name, value in rewrite_output.items() if name in output_names
]
return outputs
def get_anchor_head_model(): def get_anchor_head_model():
"""AnchorHead Config.""" """AnchorHead Config."""
test_cfg = mmcv.Config( test_cfg = mmcv.Config(
@ -152,10 +165,7 @@ def test_anchor_head_get_bboxes(backend_type):
if is_backend_output: if is_backend_output:
if isinstance(rewrite_outputs, dict): if isinstance(rewrite_outputs, dict):
rewrite_outputs = [ rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
value for name, value in rewrite_outputs.items()
if name in output_names
]
for model_output, rewrite_output in zip(model_outputs[0], for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs): rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy() model_output = model_output.squeeze().cpu().numpy()
@ -507,10 +517,8 @@ def test_cascade_roi_head(backend_type):
deploy_cfg=deploy_cfg) deploy_cfg=deploy_cfg)
processed_backend_outputs = [] processed_backend_outputs = []
if isinstance(backend_outputs, dict): if isinstance(backend_outputs, dict):
processed_backend_outputs = [ processed_backend_outputs = convert_to_list(backend_outputs,
backend_outputs[name] for name in output_names output_names)
if name in backend_outputs
]
elif isinstance(backend_outputs, (list, tuple)) and \ elif isinstance(backend_outputs, (list, tuple)) and \
backend_outputs[0].shape == (1, 0, 5): backend_outputs[0].shape == (1, 0, 5):
processed_backend_outputs = np.zeros((1, 80, 5)) processed_backend_outputs = np.zeros((1, 80, 5))
@ -601,10 +609,7 @@ def test_get_bboxes_of_fovea_head(backend_type):
deploy_cfg=deploy_cfg) deploy_cfg=deploy_cfg)
if is_backend_output: if is_backend_output:
if isinstance(rewrite_outputs, dict): if isinstance(rewrite_outputs, dict):
rewrite_outputs = [ rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
value for name, value in rewrite_outputs.items()
if name in output_names
]
for model_output, rewrite_output in zip(model_outputs[0], for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs): rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy() model_output = model_output.squeeze().cpu().numpy()
@ -716,10 +721,7 @@ def test_get_bboxes_of_atss_head(backend_type):
deploy_cfg=deploy_cfg) deploy_cfg=deploy_cfg)
if is_backend_output: if is_backend_output:
if isinstance(rewrite_outputs, dict): if isinstance(rewrite_outputs, dict):
rewrite_outputs = [ rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
value for name, value in rewrite_outputs.items()
if name in output_names
]
for model_output, rewrite_output in zip(model_outputs[0], for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs): rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy() model_output = model_output.squeeze().cpu().numpy()
@ -863,10 +865,7 @@ def test_yolov3_head_get_bboxes(backend_type):
if is_backend_output: if is_backend_output:
if isinstance(rewrite_outputs, dict): if isinstance(rewrite_outputs, dict):
rewrite_outputs = [ rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
value for name, value in rewrite_outputs.items()
if name in output_names
]
for model_output, rewrite_output in zip(model_outputs[0], for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs): rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy() model_output = model_output.squeeze().cpu().numpy()
@ -965,10 +964,7 @@ def test_yolox_head_get_bboxes(backend_type):
if is_backend_output: if is_backend_output:
if isinstance(rewrite_outputs, dict): if isinstance(rewrite_outputs, dict):
rewrite_outputs = [ rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
value for name, value in rewrite_outputs.items()
if name in output_names
]
for model_output, rewrite_output in zip(model_outputs[0], for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs): rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy() model_output = model_output.squeeze().cpu().numpy()
@ -983,3 +979,131 @@ def test_yolox_head_get_bboxes(backend_type):
atol=1e-05) atol=1e-05)
else: else:
assert rewrite_outputs is not None assert rewrite_outputs is not None
def get_vfnet_head_model():
"""VFNet Head Config."""
test_cfg = mmcv.Config(
dict(
deploy_nms_pre=0,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.5),
max_per_img=100))
from mmdet.models import VFNetHead
model = VFNetHead(num_classes=4, in_channels=1, test_cfg=test_cfg)
model.requires_grad_(False)
model.cpu().eval()
return model
@pytest.mark.parametrize('backend_type', ['openvino'])
def test_get_bboxes_of_vfnet_head(backend_type):
"""Test get_bboxes rewrite of VFNet head."""
pytest.importorskip(backend_type, reason=f'requires {backend_type}')
class TestModel(torch.nn.Module):
"""Stub for VFNetHead with fake bbox_preds operations.
Then bbox_preds will be one of the inputs to the ONNX graph.
"""
def __init__(self, vfnet_head):
super().__init__()
self.vfnet_head = vfnet_head
def get_bboxes(self,
cls_scores,
bbox_preds,
bbox_preds_refine,
img_metas,
cfg=None,
rescale=None,
with_nms=True):
tmp_bbox_pred_refine = []
for bbox_pred, bbox_pred_refine in zip(bbox_preds,
bbox_preds_refine):
tmp = bbox_pred_refine + bbox_pred
tmp = tmp - bbox_pred
tmp_bbox_pred_refine.append(tmp)
bbox_preds_refine = tmp_bbox_pred_refine
return self.vfnet_head.get_bboxes(cls_scores, bbox_preds,
bbox_preds_refine, img_metas,
cfg, rescale, with_nms)
test_model = TestModel(get_vfnet_head_model())
test_model.requires_grad_(False)
test_model.cpu().eval()
s = 16
img_metas = [{
'scale_factor': np.ones(4),
'pad_shape': (s, s, 3),
'img_shape': (s, s, 3)
}]
output_names = ['dets', 'labels']
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type),
onnx_config=dict(output_names=output_names, input_shape=None),
codebase_config=dict(
type='mmdet',
task='ObjectDetection',
post_processing=dict(
score_threshold=0.05,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=-1,
keep_top_k=100,
background_label_id=-1,
))))
seed_everything(1234)
cls_score = [
torch.rand(1, test_model.vfnet_head.num_classes, pow(2, i), pow(2, i))
for i in range(5, 0, -1)
]
seed_everything(5678)
bboxes = [torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)]
seed_everything(9101)
bbox_preds_refine = [
torch.rand(1, 4, pow(2, i), pow(2, i)) for i in range(5, 0, -1)
]
model_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'bbox_preds_refine': bbox_preds_refine,
'img_metas': img_metas
}
model_outputs = get_model_outputs(test_model, 'get_bboxes', model_inputs)
img_metas[0]['img_shape'] = torch.Tensor([s, s])
wrapped_model = WrapModel(
test_model, 'get_bboxes', img_metas=img_metas[0], with_nms=True)
rewrite_inputs = {
'cls_scores': cls_score,
'bbox_preds': bboxes,
'bbox_preds_refine': bbox_preds_refine
}
rewrite_outputs, is_backend_output = get_rewrite_outputs(
wrapped_model=wrapped_model,
model_inputs=rewrite_inputs,
deploy_cfg=deploy_cfg)
if is_backend_output:
if isinstance(rewrite_outputs, dict):
rewrite_outputs = convert_to_list(rewrite_outputs, output_names)
for model_output, rewrite_output in zip(model_outputs[0],
rewrite_outputs):
model_output = model_output.squeeze().cpu().numpy()
rewrite_output = rewrite_output.squeeze()
min_shape = min(model_output.shape[0], rewrite_output.shape[0])
assert np.allclose(
model_output[:min_shape],
rewrite_output[:min_shape],
rtol=1e-03,
atol=1e-05)
else:
assert rewrite_outputs is not None