From 331292a99229a18c0e21ea39a57c3312e86a4f6c Mon Sep 17 00:00:00 2001 From: tpoisonooo Date: Fri, 4 Nov 2022 20:54:01 +0800 Subject: [PATCH 1/4] Feature: support mmdet3d dev-1.x 1.1.0rc1 (#1225) * feat(mmdet3d): test pointpillars and centerpoint on ort, openvino and trt passed * fix(centerpoint): mvx_two_stage input error * fix(review): remove mode decorator * fix(mmdet3d): review advices * fix(regression): update mmdet3d.yml and test ort/openvino passed * unittest(mmdet3d): fix * fix(unittest): fix * fix(mmdet3d): unittest * fix(mmdet3d): unittest * fix(CI): remove mmcv.Config * fix(mmdet3d): unittest * fix(mmdet3d): support torch1.12 * fix(CI): use bigger point cloud file * improvement(mmdet3d): align backend outputs with torch * fix(mmdet3d): remove useless * style(mmdet3d): format code * style(mmdet3d): remove useless * fix(mmdet3d): sync vis_task * unittest(mmdet3d): add test * docs(mmdet3d): add docstring * unittest(ci): add unittest data * fix(mmdet3d): review advices * feat(mmdet3d): convert fail * style(mmdet3d): docstring * style(mmdet3d): docstring --- .../voxel-detection/voxel-detection_static.py | 2 +- mmdeploy/apis/pytorch2onnx.py | 4 +- mmdeploy/apis/visualize.py | 15 +- mmdeploy/backend/tensorrt/utils.py | 1 - mmdeploy/codebase/base/task.py | 1 + mmdeploy/codebase/mmdet3d/deploy/__init__.py | 3 +- .../codebase/mmdet3d/deploy/mmdetection3d.py | 128 ------ .../mmdet3d/deploy/voxel_detection.py | 347 ++++++--------- .../mmdet3d/deploy/voxel_detection_model.py | 420 ++++++++++++------ mmdeploy/codebase/mmdet3d/models/__init__.py | 2 - mmdeploy/codebase/mmdet3d/models/base.py | 37 +- .../codebase/mmdet3d/models/centerpoint.py | 190 -------- .../codebase/mmdet3d/models/mvx_two_stage.py | 137 +++--- .../codebase/mmdet3d/models/pillar_encode.py | 3 +- .../codebase/mmdet3d/models/pillar_scatter.py | 1 + mmdeploy/codebase/mmdet3d/models/voxelnet.py | 58 --- requirements/codebases.txt | 1 + tests/regression/mmdet3d.yml | 12 +- ...02_second_secfpn_8xb4-cyclic-20e_nus-3d.py | 141 ++++++ ...n_head-circlenms_8xb4-cyclic-20e_nus-3d.py | 4 + .../centerpoint_pillar02_second_secfpn_nus.py | 90 ++++ .../test_mmdet3d/data/cyclic-20e.py | 66 +++ .../test_mmdet3d/data/cyclic-40e.py | 68 +++ .../test_mmdet3d/data/default_runtime.py | 24 + .../test_mmdet3d/data/kitti-3d-3class.py | 133 ++++++ .../data/kitti/kitti_infos_val.pkl | Bin 5109 -> 7083 bytes .../training/velodyne_reduced/000008.bin | 1 + .../test_mmdet3d/data/model_cfg.py | 349 +++------------ .../test_codebase/test_mmdet3d/data/nus-3d.py | 133 ++++++ ...-0400__LIDAR_TOP__1537287083900561.pcd.bin | Bin 0 -> 8000 bytes .../test_mmdet3d/data/pointpillars.py | 114 +++++ .../data/pointpillars_hv_secfpn_kitti.py | 99 +++++ .../test_mmdet3d/test_mmdet3d_models.py | 68 +-- .../test_mmdet3d/test_voxel_detection.py | 83 +--- .../test_voxel_detection_model.py | 40 +- .../test_mmrotate/test_mmrotate_core.py | 20 +- .../test_mmrotate/test_mmrotate_models.py | 31 +- .../test_mmrotate/test_rotated_detection.py | 6 +- .../test_rotated_detection_model.py | 6 +- 39 files changed, 1564 insertions(+), 1274 deletions(-) delete mode 100644 mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py delete mode 100644 mmdeploy/codebase/mmdet3d/models/centerpoint.py delete mode 100644 mmdeploy/codebase/mmdet3d/models/voxelnet.py create mode 100644 tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py create mode 100644 tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py create mode 100644 tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_nus.py create mode 100644 tests/test_codebase/test_mmdet3d/data/cyclic-20e.py create mode 100644 tests/test_codebase/test_mmdet3d/data/cyclic-40e.py create mode 100644 tests/test_codebase/test_mmdet3d/data/default_runtime.py create mode 100644 tests/test_codebase/test_mmdet3d/data/kitti-3d-3class.py create mode 120000 tests/test_codebase/test_mmdet3d/data/kitti/training/velodyne_reduced/000008.bin create mode 100644 tests/test_codebase/test_mmdet3d/data/nus-3d.py create mode 100644 tests/test_codebase/test_mmdet3d/data/nuscenes/n008-2018-09-18-12-07-26-0400__LIDAR_TOP__1537287083900561.pcd.bin create mode 100644 tests/test_codebase/test_mmdet3d/data/pointpillars.py create mode 100644 tests/test_codebase/test_mmdet3d/data/pointpillars_hv_secfpn_kitti.py diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_static.py b/configs/mmdet3d/voxel-detection/voxel-detection_static.py index 406c16513..bba7e819f 100644 --- a/configs/mmdet3d/voxel-detection/voxel-detection_static.py +++ b/configs/mmdet3d/voxel-detection/voxel-detection_static.py @@ -3,4 +3,4 @@ codebase_config = dict( type='mmdet3d', task='VoxelDetection', model_type='end2end') onnx_config = dict( input_names=['voxels', 'num_points', 'coors'], - output_names=['scores', 'bbox_preds', 'dir_scores']) + output_names=['cls_score', 'bbox_pred', 'dir_cls_pred']) diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py index 6b80543cd..8da387dd2 100644 --- a/mmdeploy/apis/pytorch2onnx.py +++ b/mmdeploy/apis/pytorch2onnx.py @@ -3,7 +3,6 @@ import os.path as osp from typing import Any, Optional, Union import mmengine -import torch from mmdeploy.apis.core.pipeline_manager import no_mp from mmdeploy.utils import (Backend, get_backend, get_dynamic_axes, @@ -64,7 +63,8 @@ def torch2onnx(img: Any, img, input_shape, data_preprocessor=getattr(torch_model, 'data_preprocessor', None)) - if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1: + + if isinstance(model_inputs, list) and len(model_inputs) == 1: model_inputs = model_inputs[0] data_samples = data['data_samples'] patch_metas = {'data_samples': data_samples} diff --git a/mmdeploy/apis/visualize.py b/mmdeploy/apis/visualize.py index 2ed6aee3d..947480d3a 100644 --- a/mmdeploy/apis/visualize.py +++ b/mmdeploy/apis/visualize.py @@ -71,11 +71,20 @@ def visualize_model(model_cfg: Union[str, mmengine.Config], with torch.no_grad(): result = model.test_step(model_inputs)[0] + visualize = True try: # check headless import tkinter tkinter.Tk() + except Exception as e: + from mmdeploy.utils import get_root_logger + logger = get_root_logger() + logger.warning( + f'render and display result skipped for headless device, exception {e}' # noqa: E501 + ) + visualize = False + if visualize is True: task_processor.visualize( image=img, model=model, @@ -83,9 +92,3 @@ def visualize_model(model_cfg: Union[str, mmengine.Config], output_file=output_file, window_name=backend.value, show_result=show_result) - except Exception as e: - from mmdeploy.utils import get_root_logger - logger = get_root_logger() - logger.warn( - f'render and display result skipped for headless device, exception {e}' # noqa: E501 - ) diff --git a/mmdeploy/backend/tensorrt/utils.py b/mmdeploy/backend/tensorrt/utils.py index fcd307a4d..d0d3d24e7 100644 --- a/mmdeploy/backend/tensorrt/utils.py +++ b/mmdeploy/backend/tensorrt/utils.py @@ -43,7 +43,6 @@ def load(path: str) -> trt.ICudaEngine: def search_cuda_version() -> str: """try cmd to get cuda version, then try `torch.cuda` - Returns: str: cuda version, for example 10.2 """ diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 8b2417dbf..a915e31f7 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -75,6 +75,7 @@ class BaseTask(metaclass=ABCMeta): from mmengine.registry import MODELS data_preprocessor = MODELS.build(preprocess_cfg) + data_preprocessor.to(self.device) return data_preprocessor diff --git a/mmdeploy/codebase/mmdet3d/deploy/__init__.py b/mmdeploy/codebase/mmdet3d/deploy/__init__.py index 60ef615ac..bf72ab09d 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/__init__.py +++ b/mmdeploy/codebase/mmdet3d/deploy/__init__.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .mmdetection3d import MMDetection3d -from .voxel_detection import VoxelDetection +from .voxel_detection import MMDetection3d, VoxelDetection from .voxel_detection_model import VoxelDetectionModel __all__ = ['MMDetection3d', 'VoxelDetection', 'VoxelDetectionModel'] diff --git a/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py b/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py deleted file mode 100644 index 0ce371d2f..000000000 --- a/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Union - -import mmengine -from mmcv.utils import Registry -from torch.utils.data import DataLoader, Dataset - -from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase -from mmdeploy.utils import Codebase, get_task_type - - -def __build_mmdet3d_task(model_cfg: mmengine.Config, - deploy_cfg: mmengine.Config, device: str, - registry: Registry) -> BaseTask: - task = get_task_type(deploy_cfg) - return registry.module_dict[task.value](model_cfg, deploy_cfg, device) - - -MMDET3D_TASK = Registry('mmdet3d_tasks', build_func=__build_mmdet3d_task) - - -@CODEBASE.register_module(Codebase.MMDET3D.value) -class MMDetection3d(MMCodebase): - - task_registry = MMDET3D_TASK - - def __init__(self): - super().__init__() - - @staticmethod - def build_task_processor(model_cfg: mmengine.Config, - deploy_cfg: mmengine.Config, - device: str) -> BaseTask: - """The interface to build the task processors of mmdet3d. - - Args: - model_cfg (str | mmengine.Config): Model config file. - deploy_cfg (str | mmengine.Config): Deployment config file. - device (str): A string specifying device type. - - Returns: - BaseTask: A task processor. - """ - - return MMDET3D_TASK.build(model_cfg, deploy_cfg, device) - - @classmethod - def register_deploy_modules(cls): - import mmdeploy.codebase.mmdet3d.models # noqa: F401 - - @classmethod - def register_all_modules(cls): - from mmdet3d.utils.setup_env import register_all_modules - - cls.register_deploy_modules() - register_all_modules(True) - - @staticmethod - def build_dataset(dataset_cfg: Union[str, mmengine.Config], *args, - **kwargs) -> Dataset: - """Build dataset for detection3d. - - Args: - dataset_cfg (str | mmengine.Config): The input dataset config. - - Returns: - Dataset: A PyTorch dataset. - """ - from mmdet3d.datasets import build_dataset as build_dataset_mmdet3d - - from mmdeploy.utils import load_config - dataset_cfg = load_config(dataset_cfg)[0] - data = dataset_cfg.data - - dataset = build_dataset_mmdet3d(data.test) - return dataset - - @staticmethod - def build_dataloader(dataset: Dataset, - samples_per_gpu: int, - workers_per_gpu: int, - num_gpus: int = 1, - dist: bool = False, - shuffle: bool = False, - seed: Optional[int] = None, - runner_type: str = 'EpochBasedRunner', - persistent_workers: bool = True, - **kwargs) -> DataLoader: - """Build dataloader for detection3d. - - Args: - dataset (Dataset): Input dataset. - samples_per_gpu (int): Number of training samples on each GPU, i.e. - ,batch size of each GPU. - workers_per_gpu (int): How many subprocesses to use for data - loading for each GPU. - num_gpus (int): Number of GPUs. Only used in non-distributed - training. - dist (bool): Distributed training/test or not. - Defaults to `False`. - shuffle (bool): Whether to shuffle the data at every epoch. - Defaults to `False`. - seed (int): An integer set to be seed. Default is `None`. - runner_type (str): Type of runner. Default: `EpochBasedRunner`. - persistent_workers (bool): If True, the data loader will not - shutdown the worker processes after a dataset has been consumed - once. This allows to maintain the workers `Dataset` instances - alive. This argument is only valid when PyTorch>=1.7.0. - Default: False. - kwargs: Any other keyword argument to be used to initialize - DataLoader. - - Returns: - DataLoader: A PyTorch dataloader. - """ - from mmdet3d.datasets import \ - build_dataloader as build_dataloader_mmdet3d - return build_dataloader_mmdet3d( - dataset, - samples_per_gpu, - workers_per_gpu, - num_gpus=num_gpus, - dist=dist, - shuffle=shuffle, - seed=seed, - runner_type=runner_type, - persistent_workers=persistent_workers, - **kwargs) diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py index 7a1d7c278..8854ee0f2 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py +++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py @@ -1,20 +1,60 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from copy import deepcopy +from typing import Dict, Optional, Sequence, Tuple, Union -import mmcv import mmengine import numpy as np import torch -import torch.nn as nn -from mmcv.parallel import collate, scatter -from mmdet3d.core.bbox import get_box_type -from mmdet3d.datasets.pipelines import Compose -from torch.utils.data import DataLoader, Dataset +from mmdet3d.structures import get_box_type +from mmengine import Config +from mmengine.dataset import Compose, pseudo_collate +from mmengine.model import BaseDataPreprocessor +from mmengine.registry import Registry -from mmdeploy.codebase.base import BaseTask -from mmdeploy.codebase.mmdet3d.deploy.mmdetection3d import MMDET3D_TASK -from mmdeploy.utils import Task, get_root_logger, load_config -from .voxel_detection_model import VoxelDetectionModel +from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase +from mmdeploy.utils import Codebase, Task + +MMDET3D_TASK = Registry('mmdet3d_tasks') + + +@CODEBASE.register_module(Codebase.MMDET3D.value) +class MMDetection3d(MMCodebase): + """MMDetection3d codebase class.""" + + task_registry = MMDET3D_TASK + + @classmethod + def register_deploy_modules(mmdet3d): + import mmdeploy.codebase.mmdet3d.models # noqa: F401 + + @classmethod + def register_all_modules(mmdet3d): + from mmdet3d.utils.setup_env import register_all_modules + + mmdet3d.register_deploy_modules() + register_all_modules(True) + + +def _get_dataset_metainfo(model_cfg: Config): + """Get metainfo of dataset. + + Args: + model_cfg Config: Input model Config object. + + Returns: + list[str]: A list of string specifying names of different class. + """ + + for dataloader_name in [ + 'test_dataloader', 'val_dataloader', 'train_dataloader' + ]: + if dataloader_name not in model_cfg: + continue + dataloader_cfg = model_cfg[dataloader_name] + dataset_cfg = dataloader_cfg.dataset + if 'metainfo' in dataset_cfg: + return dataset_cfg.metainfo + return None @MMDET3D_TASK.register_module(Task.VOXEL_DETECTION.value) @@ -36,168 +76,106 @@ class VoxelDetection(BaseTask): nn.Module: An initialized backend model. """ from .voxel_detection_model import build_voxel_detection_model + + data_preprocessor = deepcopy( + self.model_cfg.model.get('data_preprocessor', {})) + data_preprocessor.setdefault('type', 'mmdet3D.Det3DDataPreprocessor') + model = build_voxel_detection_model( - model_files, self.model_cfg, self.deploy_cfg, device=self.device) + model_files, + self.model_cfg, + self.deploy_cfg, + device=self.device, + data_preprocessor=data_preprocessor) + model = model.to(self.device) return model - def build_pytorch_model(self, - model_checkpoint: Optional[str] = None, - cfg_options: Optional[Dict] = None, - **kwargs) -> torch.nn.Module: - """Initialize torch model. - - Args: - model_checkpoint (str): The checkpoint file of torch model, - defaults to `None`. - cfg_options (dict): Optional config key-pair parameters. - Returns: - nn.Module: An initialized torch model generated by other OpenMMLab - codebases. - """ - from mmdet3d.apis import init_model - device = self.device - model = init_model(self.model_cfg, model_checkpoint, device) - return model.eval() - - def create_input(self, pcd: str, *args) -> Tuple[Dict, torch.Tensor]: + def create_input( + self, + pcd: str, + input_shape: Sequence[int] = None, + data_preprocessor: Optional[BaseDataPreprocessor] = None + ) -> Tuple[Dict, torch.Tensor]: """Create input for detector. Args: pcd (str): Input pcd file path. + input_shape (Sequence[int], optional): model input shape. + Defaults to None. + data_preprocessor (Optional[BaseDataPreprocessor], optional): + model input preprocess. Defaults to None. Returns: tuple: (data, input), meta information for the input pcd and model input. """ - data = VoxelDetection.read_pcd_file(pcd, self.model_cfg, self.device) - voxels, num_points, coors = VoxelDetectionModel.voxelize( - self.model_cfg, data['points'][0]) - return data, (voxels, num_points, coors) + + cfg = self.model_cfg + test_pipeline = deepcopy(cfg.test_dataloader.dataset.pipeline) + test_pipeline = Compose(test_pipeline) + box_type_3d, box_mode_3d = \ + get_box_type(cfg.test_dataloader.dataset.box_type_3d) + + data = [] + data_ = dict( + lidar_points=dict(lidar_path=pcd), + timestamp=1, + # for ScanNet demo we need axis_align_matrix + axis_align_matrix=np.eye(4), + box_type_3d=box_type_3d, + box_mode_3d=box_mode_3d) + data_ = test_pipeline(data_) + data.append(data_) + + collate_data = pseudo_collate(data) + data[0]['inputs']['points'] = data[0]['inputs']['points'].to( + self.device) + + if data_preprocessor is not None: + collate_data = data_preprocessor(collate_data, False) + voxels = collate_data['inputs']['voxels'] + inputs = [voxels['voxels'], voxels['num_points'], voxels['coors']] + else: + inputs = collate_data['inputs'] + return collate_data, inputs def visualize(self, + image: Union[str, np.ndarray], model: torch.nn.Module, - image: str, result: list, output_file: str, - window_name: str, + window_name: str = '', show_result: bool = False, - score_thr: float = 0.3): - """Visualize predictions of a model. + draw_gt: bool = False, + **kwargs): + """visualize backend output. Args: - model (nn.Module): Input model. - image (str): Pcd file to draw predictions on. - result (list): A list of predictions. - output_file (str): Output file to save result. - window_name (str): The name of visualization window. Defaults to - an empty string. - show_result (bool): Whether to show result in windows, defaults - to `False`. - score_thr (float): The score threshold to display the bbox. - Defaults to 0.3. + image (Union[str, np.ndarray]): pcd file path + result (list): output bbox, score and type + output_file (str): the directory to save output + window_name (str, optional): display window name + show_result (bool, optional): show result or not. + Defaults to False. + draw_gt (bool, optional): show gt or not. Defaults to False. """ - from mmdet3d.apis import show_result_meshlab - data = VoxelDetection.read_pcd_file(image, self.model_cfg, self.device) - show_result_meshlab( - data, - result, - output_file, - score_thr, + cfg = self.model_cfg + visualizer = super().get_visualizer(window_name, output_file) + visualizer.dataset_meta = _get_dataset_metainfo(cfg) + + # show the results + collate_data, _ = self.create_input(pcd=image) + + visualizer.add_datasample( + window_name, + dict(points=collate_data['inputs']['points'][0]), + data_sample=result, + draw_gt=draw_gt, show=show_result, - snapshot=1 - show_result, - task='det') - - @staticmethod - def read_pcd_file(pcd: str, model_cfg: Union[str, mmengine.Config], - device: str) -> Dict: - """Read data from pcd file and run test pipeline. - - Args: - pcd (str): Pcd file path. - model_cfg (str | mmengine.Config): The model config. - device (str): A string specifying device type. - - Returns: - dict: meta information for the input pcd. - """ - if isinstance(pcd, (list, tuple)): - pcd = pcd[0] - model_cfg = load_config(model_cfg)[0] - test_pipeline = Compose(model_cfg.data.test.pipeline) - box_type_3d, box_mode_3d = get_box_type( - model_cfg.data.test.box_type_3d) - data = dict( - pts_filename=pcd, - box_type_3d=box_type_3d, - box_mode_3d=box_mode_3d, - # for ScanNet demo we need axis_align_matrix - ann_info=dict(axis_align_matrix=np.eye(4)), - sweeps=[], - # set timestamp = 0 - timestamp=[0], - img_fields=[], - bbox3d_fields=[], - pts_mask_fields=[], - pts_seg_fields=[], - bbox_fields=[], - mask_fields=[], - seg_fields=[]) - data = test_pipeline(data) - data = collate([data], samples_per_gpu=1) - data['img_metas'] = [ - img_metas.data[0] for img_metas in data['img_metas'] - ] - data['points'] = [point.data[0] for point in data['points']] - if device != 'cpu': - data = scatter(data, [device])[0] - return data - - @staticmethod - def run_inference(model: nn.Module, - model_inputs: Dict[str, torch.Tensor]) -> List: - """Run inference once for a object detection model of mmdet3d. - - Args: - model (nn.Module): Input model. - model_inputs (dict): A dict containing model inputs tensor and - meta info. - - Returns: - list: The predictions of model inference. - """ - result = model( - return_loss=False, - points=model_inputs['points'], - img_metas=model_inputs['img_metas']) - return [result] - - @staticmethod - def evaluate_outputs(model_cfg, - outputs: Sequence, - dataset: Dataset, - metrics: Optional[str] = None, - out: Optional[str] = None, - metric_options: Optional[dict] = None, - format_only: bool = False, - log_file: Optional[str] = None): - if out: - logger = get_root_logger() - logger.info(f'\nwriting results to {out}') - mmcv.dump(outputs, out) - kwargs = {} if metric_options is None else metric_options - if format_only: - dataset.format_results(outputs, **kwargs) - if metrics: - eval_kwargs = model_cfg.get('evaluation', {}).copy() - # hard-code way to remove EvalHook args - for key in [ - 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', - 'rule' - ]: - eval_kwargs.pop(key, None) - eval_kwargs.pop(key, None) - eval_kwargs.update(dict(metric=metrics, **kwargs)) - dataset.evaluate(outputs, **eval_kwargs) + wait_time=0, + out_file=output_file, + pred_score_thr=0.0, + vis_task='lidar_det') def get_model_name(self, *args, **kwargs) -> str: """Get the model name. @@ -207,18 +185,6 @@ class VoxelDetection(BaseTask): """ raise NotImplementedError - def get_tensor_from_input(self, input_data: Dict[str, Any], - **kwargs) -> torch.Tensor: - """Get input tensor from input data. - - Args: - input_data (dict): Input data containing meta info and image - tensor. - Returns: - torch.Tensor: An image in `Tensor`. - """ - raise NotImplementedError - def get_partition_cfg(partition_type: str, **kwargs) -> Dict: """Get a certain partition config for mmdet. @@ -245,58 +211,3 @@ class VoxelDetection(BaseTask): dict: Composed of the preprocess information. """ raise NotImplementedError - - def single_gpu_test(self, - model: nn.Module, - data_loader: DataLoader, - show: bool = False, - out_dir: Optional[str] = None, - **kwargs) -> List: - """Run test with single gpu. - - Args: - model (nn.Module): Input model from nn.Module. - data_loader (DataLoader): PyTorch data loader. - show (bool): Specifying whether to show plotted results. Defaults - to `False`. - out_dir (str): A directory to save results, defaults to `None`. - - Returns: - list: The prediction results. - """ - model.eval() - results = [] - dataset = data_loader.dataset - - prog_bar = mmcv.ProgressBar(len(dataset)) - for i, data in enumerate(data_loader): - with torch.no_grad(): - result = model(data['points'][0].data, - data['img_metas'][0].data, False) - if show: - # Visualize the results of MMDetection3D model - # 'show_results' is MMdetection3D visualization API - if out_dir is None: - model.module.show_result( - data, - result, - out_dir='', - file_name='', - show=show, - snapshot=False, - score_thr=0.3) - else: - model.module.show_result( - data, - result, - out_dir=out_dir, - file_name=f'model_output{i}', - show=show, - snapshot=True, - score_thr=0.3) - results.extend(result) - - batch_size = len(result) - for _ in range(batch_size): - prog_bar.update() - return results diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py index 49f35071b..d4603625d 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py +++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py @@ -1,25 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Sequence, Union +from typing import Any, Dict, List, Optional, Sequence, Union import mmcv -import mmengine import torch -from mmcv.utils import Registry -from torch.nn import functional as F +from mmdet3d.structures.det3d_data_sample import SampleList +from mmengine import Config +from mmengine.model.base_model.data_preprocessor import BaseDataPreprocessor +from mmengine.registry import Registry +from mmengine.structures import BaseDataElement, InstanceData from mmdeploy.codebase.base import BaseBackendModel -from mmdeploy.core import RewriterContext from mmdeploy.utils import (Backend, get_backend, get_codebase_config, - get_root_logger, load_config) + load_config) - -def __build_backend_voxel_model(cls_name: str, registry: Registry, *args, - **kwargs): - return registry.module_dict[cls_name](*args, **kwargs) - - -__BACKEND_MODEL = mmcv.utils.Registry( - 'backend_voxel_detectors', build_func=__build_backend_voxel_model) +__BACKEND_MODEL = Registry('backend_voxel_detectors') @__BACKEND_MODEL.register_module('end2end') @@ -31,8 +25,8 @@ class VoxelDetectionModel(BaseBackendModel): backend_files (Sequence[str]): Paths to all required backend files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn). device (str): A string specifying device type. - model_cfg (str | mmengine.Config): The model config. - deploy_cfg (str|mmengine.Config): Deployment config file or loaded + model_cfg (str | Config): The model config. + deploy_cfg (str|Config): Deployment config file or loaded Config object. """ @@ -40,11 +34,15 @@ class VoxelDetectionModel(BaseBackendModel): backend: Backend, backend_files: Sequence[str], device: str, - model_cfg: mmengine.Config, - deploy_cfg: Union[str, mmengine.Config] = None): - super().__init__(deploy_cfg=deploy_cfg) - self.deploy_cfg = deploy_cfg + model_cfg: Union[str, Config], + deploy_cfg: Union[str, Config], + data_preprocessor: Optional[Union[dict, + torch.nn.Module]] = None, + **kwargs): + super().__init__( + deploy_cfg=deploy_cfg, data_preprocessor=data_preprocessor) self.model_cfg = model_cfg + self.deploy_cfg = deploy_cfg self.device = device self._init_wrapper( backend=backend, backend_files=backend_files, device=device) @@ -64,13 +62,14 @@ class VoxelDetectionModel(BaseBackendModel): backend=backend, backend_files=backend_files, device=device, + input_names=[self.input_name], output_names=output_names, deploy_cfg=self.deploy_cfg) def forward(self, - points: Sequence[torch.Tensor], - img_metas: Sequence[dict], - return_loss=False): + inputs: dict, + data_samples: Optional[List[BaseDataElement]] = None, + **kwargs) -> Any: """Run forward inference. Args: @@ -84,22 +83,25 @@ class VoxelDetectionModel(BaseBackendModel): Returns: list: A list contains predictions. """ - result_list = [] - for i in range(len(img_metas)): - voxels, num_points, coors = VoxelDetectionModel.voxelize( - self.model_cfg, points[i]) - input_dict = { - 'voxels': voxels, - 'num_points': num_points, - 'coors': coors - } - outputs = self.wrapper(input_dict) - result = VoxelDetectionModel.post_process(self.model_cfg, - self.deploy_cfg, outputs, - img_metas[i], - self.device)[0] - result_list.append(result) - return result_list + preprocessed = inputs['voxels'] + input_dict = { + 'voxels': preprocessed['voxels'].to(self.device), + 'num_points': preprocessed['num_points'].to(self.device), + 'coors': preprocessed['coors'].to(self.device) + } + + outputs = self.wrapper(input_dict) + + if data_samples is None: + return outputs + + prediction = VoxelDetectionModel.postprocess( + model_cfg=self.model_cfg, + deploy_cfg=self.deploy_cfg, + outs=outputs, + metas=data_samples) + + return prediction def show_result(self, data: Dict, @@ -132,120 +134,259 @@ class VoxelDetectionModel(BaseBackendModel): pred_labels=pred_labels) @staticmethod - def voxelize(model_cfg: Union[str, mmengine.Config], points: torch.Tensor): - """convert kitti points(N, >=3) to voxels. + def convert_to_datasample( + data_samples: SampleList, + data_instances_3d: Optional[List[InstanceData]] = None, + data_instances_2d: Optional[List[InstanceData]] = None, + ) -> SampleList: + """Convert results list to `Det3DDataSample`. + + Subclasses could override it to be compatible for some multi-modality + 3D detectors. Args: - model_cfg (str | mmengine.Config): The model config. - points (torch.Tensor): [N, ndim] float tensor. points[:, :3] - contain xyz points and points[:, 3:] contain other information - like reflectivity. + data_samples (list[:obj:`Det3DDataSample`]): The input data. + data_instances_3d (list[:obj:`InstanceData`], optional): 3D + Detection results of each sample. + data_instances_2d (list[:obj:`InstanceData`], optional): 2D + Detection results of each sample. Returns: - voxels: [M, max_points, ndim] float tensor. only contain points - and returned when max_points != -1. - coordinates: [M, 3] int32 tensor, always returned. - num_points_per_voxel: [M] int32 tensor. Only returned when - max_points != -1. + list[:obj:`Det3DDataSample`]: Detection results of the + input. Each Det3DDataSample usually contains + 'pred_instances_3d'. And the ``pred_instances_3d`` normally + contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instance, ) + - labels_3d (Tensor): Labels of 3D bboxes, has a shape + (num_instances, ). + - bboxes_3d (Tensor): Contains a tensor with shape + (num_instances, C) where C >=7. + + When there are image prediction in some models, it should + contains `pred_instances`, And the ``pred_instances`` normally + contains following keys. + + - scores (Tensor): Classification scores of image, has a shape + (num_instance, ) + - labels (Tensor): Predict Labels of 2D bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Contains a tensor with shape + (num_instances, 4). """ - from mmcv.ops import Voxelization - model_cfg = load_config(model_cfg)[0] - if 'voxel_layer' in model_cfg.model.keys(): - voxel_layer = model_cfg.model['voxel_layer'] - elif 'pts_voxel_layer' in model_cfg.model.keys(): - voxel_layer = model_cfg.model['pts_voxel_layer'] - else: - raise - voxel_layer = Voxelization(**voxel_layer) - voxels, coors, num_points = [], [], [] - for res in points: - res_voxels, res_coors, res_num_points = voxel_layer(res) - voxels.append(res_voxels) - coors.append(res_coors) - num_points.append(res_num_points) - voxels = torch.cat(voxels, dim=0) - num_points = torch.cat(num_points, dim=0) - coors_batch = [] - for i, coor in enumerate(coors): - coor_pad = F.pad(coor, (1, 0), mode='constant', value=i) - coors_batch.append(coor_pad) - coors_batch = torch.cat(coors_batch, dim=0) - return voxels, num_points, coors_batch + + assert (data_instances_2d is not None) or \ + (data_instances_3d is not None),\ + 'please pass at least one type of data_samples' + + if data_instances_2d is None: + data_instances_2d = [ + InstanceData() for _ in range(len(data_instances_3d)) + ] + if data_instances_3d is None: + data_instances_3d = [ + InstanceData() for _ in range(len(data_instances_2d)) + ] + + for i, data_sample in enumerate(data_samples): + data_sample.pred_instances_3d = data_instances_3d[i] + data_sample.pred_instances = data_instances_2d[i] + return data_samples @staticmethod - def post_process(model_cfg: Union[str, mmengine.Config], - deploy_cfg: Union[str, mmengine.Config], - outs: Dict, - img_metas: Dict, - device: str, - rescale=False): - """model post process. + def postprocess(model_cfg: Union[str, Config], + deploy_cfg: Union[str, Config], outs: Dict, metas: Dict): + """postprocess outputs to datasamples. Args: - model_cfg (str | mmengine.Config): The model config. - deploy_cfg (str|mmengine.Config): Deployment config file or loaded - Config object. - outs (Dict): Output of model's head. - img_metas(Dict): Meta info for pcd. - device (str): A string specifying device type. - rescale (list[torch.Tensor]): whether th rescale bbox. + model_cfg (Union[str, Config]): The model config from + trainning repo + deploy_cfg (Union[str, Config]): The deploy config to specify + backend and input shape + outs (Dict): output bbox, cls and score + metas (Dict): DataSample3D for bbox3d render + + Raises: + NotImplementedError: Only support mmdet3d model with `bbox_head` + Returns: - list: A list contains predictions, include bboxes, scores, labels. + DataSample3D: datatype for render """ - from mmdet3d.core import bbox3d2result - from mmdet3d.models.builder import build_head - model_cfg = load_config(model_cfg)[0] - deploy_cfg = load_config(deploy_cfg)[0] - if 'bbox_head' in model_cfg.model.keys(): - head_cfg = dict(**model_cfg.model['bbox_head']) - elif 'pts_bbox_head' in model_cfg.model.keys(): - head_cfg = dict(**model_cfg.model['pts_bbox_head']) + if 'cls_score' not in outs or 'bbox_pred' not in outs or 'dir_cls_pred' not in outs: # noqa: E501 + raise RuntimeError('output tensor not found') + + if 'test_cfg' not in model_cfg.model: + raise RuntimeError('test_cfg not found') + + from mmengine.registry import MODELS + cls_score = outs['cls_score'] + bbox_pred = outs['bbox_pred'] + dir_cls_pred = outs['dir_cls_pred'] + batch_input_metas = [data_samples.metainfo for data_samples in metas] + + head = None + cfg = None + if 'bbox_head' in model_cfg.model: + # pointpillars postprocess + head = MODELS.build(model_cfg.model['bbox_head']) + cfg = model_cfg.model.test_cfg + elif 'pts_bbox_head' in model_cfg.model: + # centerpoint postprocess + head = MODELS.build(model_cfg.model['pts_bbox_head']) + cfg = model_cfg.model.test_cfg.pts else: - raise NotImplementedError('Not supported model.') - head_cfg['train_cfg'] = None - head_cfg['test_cfg'] = model_cfg.model['test_cfg']\ - if 'pts' not in model_cfg.model['test_cfg'].keys()\ - else model_cfg.model['test_cfg']['pts'] - head = build_head(head_cfg) - if device == 'cpu': - logger = get_root_logger() - logger.warning( - 'Don\'t suggest using CPU device. Post process can\'t support.' - ) - if torch.cuda.is_available(): - device = 'cuda' - else: - raise NotImplementedError( - 'Post process don\'t support device=cpu') - cls_scores = [outs['scores'].to(device)] - bbox_preds = [outs['bbox_preds'].to(device)] - dir_scores = [outs['dir_scores'].to(device)] - with RewriterContext( - cfg=deploy_cfg, - backend=deploy_cfg.backend_config.type, - opset=deploy_cfg.onnx_config.opset_version): - bbox_list = head.get_bboxes( - cls_scores, bbox_preds, dir_scores, img_metas, rescale=False) - bbox_results = [ - bbox3d2result(bboxes, scores, labels) - for bboxes, scores, labels in bbox_list - ] - return bbox_results + raise NotImplementedError('mmdet3d model bbox_head not found') + + if not hasattr(head, 'task_heads'): + data_instances_3d = head.predict_by_feat( + cls_scores=[cls_score], + bbox_preds=[bbox_pred], + dir_cls_preds=[dir_cls_pred], + batch_input_metas=batch_input_metas, + cfg=cfg) + + data_samples = VoxelDetectionModel.convert_to_datasample( + data_samples=metas, data_instances_3d=data_instances_3d) + + else: + pts = model_cfg.model.test_cfg.pts + + rets = [] + scores_range = [0] + bbox_range = [0] + dir_range = [0] + for i, _ in enumerate(head.task_heads): + scores_range.append(scores_range[i] + head.num_classes[i]) + bbox_range.append(bbox_range[i] + 8) + dir_range.append(dir_range[i] + 2) + + for task_id in range(len(head.num_classes)): + num_class_with_bg = head.num_classes[task_id] + + batch_heatmap = cls_score[:, + scores_range[task_id]:scores_range[ + task_id + 1], ...].sigmoid() + + batch_reg = bbox_pred[:, + bbox_range[task_id]:bbox_range[task_id] + + 2, ...] + batch_hei = bbox_pred[:, bbox_range[task_id] + + 2:bbox_range[task_id] + 3, ...] + + if head.norm_bbox: + batch_dim = torch.exp(bbox_pred[:, bbox_range[task_id] + + 3:bbox_range[task_id] + 6, + ...]) + else: + batch_dim = bbox_pred[:, bbox_range[task_id] + + 3:bbox_range[task_id] + 6, ...] + + batch_vel = bbox_pred[:, bbox_range[task_id] + + 6:bbox_range[task_id + 1], ...] + + batch_rots = dir_cls_pred[:, + dir_range[task_id]:dir_range[task_id + + 1], + ...][:, 0].unsqueeze(1) + batch_rotc = dir_cls_pred[:, + dir_range[task_id]:dir_range[task_id + + 1], + ...][:, 1].unsqueeze(1) + + temp = head.bbox_coder.decode( + batch_heatmap, + batch_rots, + batch_rotc, + batch_hei, + batch_dim, + batch_vel, + reg=batch_reg, + task_id=task_id) + + assert pts['nms_type'] in ['circle', 'rotate'] + batch_reg_preds = [box['bboxes'] for box in temp] + batch_cls_preds = [box['scores'] for box in temp] + batch_cls_labels = [box['labels'] for box in temp] + if pts['nms_type'] == 'circle': + boxes3d = temp[0]['bboxes'] + scores = temp[0]['scores'] + labels = temp[0]['labels'] + centers = boxes3d[:, [0, 1]] + boxes = torch.cat([centers, scores.view(-1, 1)], dim=1) + from mmdet3d.models.layers import circle_nms + keep = torch.tensor( + circle_nms( + boxes.detach().cpu().numpy(), + pts['min_radius'][task_id], + post_max_size=pts['post_max_size']), + dtype=torch.long, + device=boxes.device) + + boxes3d = boxes3d[keep] + scores = scores[keep] + labels = labels[keep] + ret = dict(bboxes=boxes3d, scores=scores, labels=labels) + ret_task = [ret] + rets.append(ret_task) + else: + rets.append( + head.get_task_detections(num_class_with_bg, + batch_cls_preds, + batch_reg_preds, + batch_cls_labels, + batch_input_metas)) + + # Merge branches results + num_samples = len(rets[0]) + + ret_list = [] + for i in range(num_samples): + temp_instances = InstanceData() + for k in rets[0][i].keys(): + if k == 'bboxes': + bboxes = torch.cat([ret[i][k] for ret in rets]) + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 + bboxes = batch_input_metas[i]['box_type_3d']( + bboxes, head.bbox_coder.code_size) + elif k == 'scores': + scores = torch.cat([ret[i][k] for ret in rets]) + elif k == 'labels': + flag = 0 + for j, num_class in enumerate(head.num_classes): + rets[j][i][k] += flag + flag += num_class + labels = torch.cat([ret[i][k].int() for ret in rets]) + temp_instances.bboxes_3d = bboxes + temp_instances.scores_3d = scores + temp_instances.labels_3d = labels + ret_list.append(temp_instances) + + data_samples = VoxelDetectionModel.convert_to_datasample( + metas, data_instances_3d=ret_list) + + return data_samples -def build_voxel_detection_model(model_files: Sequence[str], - model_cfg: Union[str, mmengine.Config], - deploy_cfg: Union[str, mmengine.Config], - device: str): +def build_voxel_detection_model( + model_files: Sequence[str], + model_cfg: Union[str, Config], + deploy_cfg: Union[str, Config], + device: str, + data_preprocessor: Optional[Union[Config, + BaseDataPreprocessor]] = None, + **kwargs): """Build 3d voxel object detection model for different backends. Args: model_files (Sequence[str]): Input model file(s). - model_cfg (str | mmengine.Config): Input model config file or Config + model_cfg (str | Config): Input model config file or Config object. - deploy_cfg (str | mmengine.Config): Input deployment config file or + deploy_cfg (str | Config): Input deployment config file or Config object. device (str): Device to input model + data_preprocessor (BaseDataPreprocessor | Config): The data + preprocessor of the model. Returns: VoxelDetectionModel: Detector for a configured backend. @@ -256,11 +397,14 @@ def build_voxel_detection_model(model_files: Sequence[str], model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end') backend_detector = __BACKEND_MODEL.build( - model_type, - backend=backend, - backend_files=model_files, - device=device, - model_cfg=model_cfg, - deploy_cfg=deploy_cfg) + dict( + type=model_type, + backend=backend, + backend_files=model_files, + device=device, + model_cfg=model_cfg, + deploy_cfg=deploy_cfg, + data_preprocessor=data_preprocessor, + **kwargs)) return backend_detector diff --git a/mmdeploy/codebase/mmdet3d/models/__init__.py b/mmdeploy/codebase/mmdet3d/models/__init__.py index 8de0c41b3..f9cd7c328 100644 --- a/mmdeploy/codebase/mmdet3d/models/__init__.py +++ b/mmdeploy/codebase/mmdet3d/models/__init__.py @@ -1,7 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import * # noqa: F401,F403 -from .centerpoint import * # noqa: F401,F403 from .mvx_two_stage import * # noqa: F401,F403 from .pillar_encode import * # noqa: F401,F403 from .pillar_scatter import * # noqa: F401,F403 -from .voxelnet import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet3d/models/base.py b/mmdeploy/codebase/mmdet3d/models/base.py index 61fddae8e..38d35cd95 100644 --- a/mmdeploy/codebase/mmdet3d/models/base.py +++ b/mmdeploy/codebase/mmdet3d/models/base.py @@ -1,23 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import torch + from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.base.Base3DDetector.forward_test') -def base3ddetector__forward_test(ctx, - self, - voxels, - num_points, - coors, - img_metas=None, - img=None, - rescale=False): - """Rewrite this function to run simple_test directly.""" - return self.simple_test(voxels, num_points, coors, img_metas, img) + 'mmdet3d.models.detectors.Base3DDetector.forward' # noqa: E501 +) +def basedetector__forward(ctx, + self, + inputs: list, + data_samples=None, + **kwargs) -> Tuple[List[torch.Tensor]]: + """Extract features of images.""" - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.base.Base3DDetector.forward') -def base3ddetector__forward(ctx, self, *args, **kwargs): - """Rewrite this function to run the model directly.""" - return self.forward_test(*args) + batch_inputs_dict = { + 'voxels': { + 'voxels': inputs[0], + 'num_points': inputs[1], + 'coors': inputs[2] + } + } + return self._forward(batch_inputs_dict, data_samples, **kwargs) diff --git a/mmdeploy/codebase/mmdet3d/models/centerpoint.py b/mmdeploy/codebase/mmdet3d/models/centerpoint.py deleted file mode 100644 index fe6657e2b..000000000 --- a/mmdeploy/codebase/mmdet3d/models/centerpoint.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from mmdet3d.core import circle_nms - -from mmdeploy.core import FUNCTION_REWRITER - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.centerpoint.CenterPoint.extract_pts_feat') -def centerpoint__extract_pts_feat(ctx, self, voxels, num_points, coors, - img_feats, img_metas): - """Extract features from points. Rewrite this func to remove voxelize op. - - Args: - voxels (torch.Tensor): Point features or raw points in shape (N, M, C). - num_points (torch.Tensor): Number of points in each voxel. - coors (torch.Tensor): Coordinates of each voxel. - img_feats (list[torch.Tensor], optional): Image features used for - multi-modality fusion. Defaults to None. - img_metas (list[dict]): Meta information of samples. - - Returns: - torch.Tensor: Points feature. - """ - if not self.with_pts_bbox: - return None - - voxel_features = self.pts_voxel_encoder(voxels, num_points, coors) - batch_size = coors[-1, 0] + 1 - x = self.pts_middle_encoder(voxel_features, coors, batch_size) - x = self.pts_backbone(x) - if self.with_pts_neck: - x = self.pts_neck(x) - return x - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.centerpoint.CenterPoint.simple_test_pts') -def centerpoint__simple_test_pts(ctx, self, x, img_metas, rescale=False): - """Rewrite this func to format model outputs. - - Args: - x (torch.Tensor): Input points feature. - img_metas (list[dict]): Meta information of samples. - rescale (bool): Whether need rescale. - - Returns: - List: Result of model. - """ - outs = self.pts_bbox_head(x) - bbox_preds, scores, dir_scores = [], [], [] - for task_res in outs: - bbox_preds.append(task_res[0]['reg']) - bbox_preds.append(task_res[0]['height']) - bbox_preds.append(task_res[0]['dim']) - if 'vel' in task_res[0].keys(): - bbox_preds.append(task_res[0]['vel']) - scores.append(task_res[0]['heatmap']) - dir_scores.append(task_res[0]['rot']) - bbox_preds = torch.cat(bbox_preds, dim=1) - scores = torch.cat(scores, dim=1) - dir_scores = torch.cat(dir_scores, dim=1) - return scores, bbox_preds, dir_scores - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.dense_heads.centerpoint_head.CenterHead.get_bboxes') -def centerpoint__get_bbox(ctx, - self, - cls_scores, - bbox_preds, - dir_scores, - img_metas, - img=None, - rescale=False): - """Rewrite this func to format func inputs. - - Args - cls_scores (list[torch.Tensor]): Classification predicts results. - bbox_preds (list[torch.Tensor]): Bbox predicts results. - dir_scores (list[torch.Tensor]): Dir predicts results. - img_metas (list[dict]): Point cloud and image's meta info. - img (torch.Tensor): Input image. - rescale (Bool): Whether need rescale. - - Returns: - list[dict]: Decoded bbox, scores and labels after nms. - """ - rets = [] - scores_range = [0] - bbox_range = [0] - dir_range = [0] - for i, task_head in enumerate(self.task_heads): - scores_range.append(scores_range[i] + self.num_classes[i]) - bbox_range.append(bbox_range[i] + 8) - dir_range.append(dir_range[i] + 2) - for task_id in range(len(self.num_classes)): - num_class_with_bg = self.num_classes[task_id] - - batch_heatmap = cls_scores[ - 0][:, scores_range[task_id]:scores_range[task_id + 1], - ...].sigmoid() - - batch_reg = bbox_preds[0][:, - bbox_range[task_id]:bbox_range[task_id] + 2, - ...] - batch_hei = bbox_preds[0][:, bbox_range[task_id] + - 2:bbox_range[task_id] + 3, ...] - - if self.norm_bbox: - batch_dim = torch.exp(bbox_preds[0][:, bbox_range[task_id] + - 3:bbox_range[task_id] + 6, - ...]) - else: - batch_dim = bbox_preds[0][:, bbox_range[task_id] + - 3:bbox_range[task_id] + 6, ...] - - batch_vel = bbox_preds[0][:, bbox_range[task_id] + - 6:bbox_range[task_id + 1], ...] - - batch_rots = dir_scores[0][:, - dir_range[task_id]:dir_range[task_id + 1], - ...][:, 0].unsqueeze(1) - batch_rotc = dir_scores[0][:, - dir_range[task_id]:dir_range[task_id + 1], - ...][:, 1].unsqueeze(1) - - temp = self.bbox_coder.decode( - batch_heatmap, - batch_rots, - batch_rotc, - batch_hei, - batch_dim, - batch_vel, - reg=batch_reg, - task_id=task_id) - if 'pts' in self.test_cfg.keys(): - self.test_cfg = self.test_cfg.pts - assert self.test_cfg['nms_type'] in ['circle', 'rotate'] - batch_reg_preds = [box['bboxes'] for box in temp] - batch_cls_preds = [box['scores'] for box in temp] - batch_cls_labels = [box['labels'] for box in temp] - if self.test_cfg['nms_type'] == 'circle': - - boxes3d = temp[0]['bboxes'] - scores = temp[0]['scores'] - labels = temp[0]['labels'] - centers = boxes3d[:, [0, 1]] - boxes = torch.cat([centers, scores.view(-1, 1)], dim=1) - keep = torch.tensor( - circle_nms( - boxes.detach().cpu().numpy(), - self.test_cfg['min_radius'][task_id], - post_max_size=self.test_cfg['post_max_size']), - dtype=torch.long, - device=boxes.device) - - boxes3d = boxes3d[keep] - scores = scores[keep] - labels = labels[keep] - ret = dict(bboxes=boxes3d, scores=scores, labels=labels) - ret_task = [ret] - rets.append(ret_task) - else: - rets.append( - self.get_task_detections(num_class_with_bg, batch_cls_preds, - batch_reg_preds, batch_cls_labels, - img_metas)) - - # Merge branches results - num_samples = len(rets[0]) - - ret_list = [] - for i in range(num_samples): - for k in rets[0][i].keys(): - if k == 'bboxes': - bboxes = torch.cat([ret[i][k] for ret in rets]) - bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 - bboxes = img_metas[i]['box_type_3d'](bboxes, - self.bbox_coder.code_size) - elif k == 'scores': - scores = torch.cat([ret[i][k] for ret in rets]) - elif k == 'labels': - flag = 0 - for j, num_class in enumerate(self.num_classes): - rets[j][i][k] += flag - flag += num_class - labels = torch.cat([ret[i][k].int() for ret in rets]) - ret_list.append([bboxes, scores, labels]) - return ret_list diff --git a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py index b4dc388f2..1e75d08ae 100644 --- a/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py +++ b/mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py @@ -1,41 +1,33 @@ # Copyright (c) OpenMMLab. All rights reserved. +import torch + from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.simple_test') -def mvxtwostagedetector__simple_test(ctx, - self, - voxels, - num_points, - coors, - img_metas, - img=None, - rescale=False): - """Rewrite this func to remove voxelize op. - - Args: - voxels (torch.Tensor): Point features or raw points in shape (N, M, C). - num_points (torch.Tensor): Number of points in each voxel. - coors (torch.Tensor): Coordinates of each voxel. - img_metas (list[dict]): Meta information of samples. - img (torch.Tensor): Input image. - rescale (Bool): Whether need rescale. - - Returns: - list[dict]: Decoded bbox, scores and labels after nms. - """ - _, pts_feats = self.extract_feat( - voxels, num_points, coors, img=img, img_metas=img_metas) - if pts_feats and self.with_pts_bbox: - bbox_pts = self.simple_test_pts(pts_feats, img_metas, rescale=rescale) - return bbox_pts + 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_img_feat' # noqa: E501 +) +def mvxtwostagedetector__extract_img_feat(ctx, self, + img: torch.Tensor) -> dict: + """Extract features of images.""" + if self.with_img_backbone and img is not None: + if img.dim() == 5 and img.size(0) == 1: + img.squeeze_() + elif img.dim() == 5 and img.size(0) > 1: + B, N, C, H, W = img.size() + img = img.view(B * N, C, H, W) + img_feats = self.img_backbone(img) + else: + return None + if self.with_img_neck: + img_feats = self.img_neck(img_feats) + return img_feats @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.extract_feat') -def mvxtwostagedetector__extract_feat(ctx, self, voxels, num_points, coors, - img, img_metas): +def mvxtwostagedetector__extract_feat(ctx, self, + batch_inputs_dict: dict) -> tuple: """Rewrite this func to remove voxelize op. Args: @@ -44,63 +36,58 @@ def mvxtwostagedetector__extract_feat(ctx, self, voxels, num_points, coors, coors (torch.Tensor): Coordinates of each voxel. img (torch.Tensor): Input image. img_metas (list[dict]): Meta information of samples. - Returns: tuple(torch.Tensor) : image feature and points feather. """ - img_feats = self.extract_img_feat(img, img_metas) - pts_feats = self.extract_pts_feat(voxels, num_points, coors, img_feats, - img_metas) + voxel_dict = batch_inputs_dict.get('voxels', None) + imgs = batch_inputs_dict.get('imgs', None) + points = batch_inputs_dict.get('points', None) + img_feats = self.extract_img_feat(imgs) + pts_feats = self.extract_pts_feat( + voxel_dict, points=points, img_feats=img_feats) return (img_feats, pts_feats) @FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.' - 'extract_pts_feat') -def mvxtwostagedetector__extract_pts_feat(ctx, self, voxels, num_points, coors, - img_feats, img_metas): - """Extract features from points. Rewrite this func to remove voxelize op. + 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.forward') +def mvxtwostagedetector__forward(ctx, self, inputs: list, **kwargs): + """Rewrite this func to remove voxelize op. Args: - voxels (torch.Tensor): Point features or raw points in shape (N, M, C). - num_points (torch.Tensor): Number of points in each voxel. - coors (torch.Tensor): Coordinates of each voxel. - img_feats (list[torch.Tensor], optional): Image features used for - multi-modality fusion. Defaults to None. - img_metas (list[dict]): Meta information of samples. + inputs (list): voxels, num_points and coors compose the input list + data_samples (DataSample3D): intermediate format within multiple + algorithm framework Returns: - torch.Tensor: Points feature. + bbox (Tensor): Decoded bbox after nms + scores (Tensor): bbox scores + labels (Tensor): bbox labels """ - if not self.with_pts_bbox: - return None - voxel_features = self.pts_voxel_encoder(voxels, num_points, coors, - img_feats, img_metas) - batch_size = coors[-1, 0] + 1 - x = self.pts_middle_encoder(voxel_features, coors, batch_size) - x = self.pts_backbone(x) - if self.with_pts_neck: - x = self.pts_neck(x) - return x + batch_inputs_dict = { + 'voxels': { + 'voxels': inputs[0], + 'num_points': inputs[1], + 'coors': inputs[2] + } + } + _, pts_feats = self.extract_feat(batch_inputs_dict=batch_inputs_dict) + outs = self.pts_bbox_head(pts_feats) -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.' - 'simple_test_pts') -def mvxtwostagedetector__simple_test_pts(ctx, - self, - x, - img_metas, - rescale=False): - """Rewrite this func to format model outputs. - - Args: - x (torch.Tensor): Input points feature. - img_metas (list[dict]): Meta information of samples. - rescale (bool): Whether need rescale. - - Returns: - List: Result of model. - """ - bbox_preds, scores, dir_scores = self.pts_bbox_head(x) - return bbox_preds, scores, dir_scores + if type(outs[0][0]) is dict: + bbox_preds, scores, dir_scores = [], [], [] + for task_res in outs: + bbox_preds.append(task_res[0]['reg']) + bbox_preds.append(task_res[0]['height']) + bbox_preds.append(task_res[0]['dim']) + if 'vel' in task_res[0].keys(): + bbox_preds.append(task_res[0]['vel']) + scores.append(task_res[0]['heatmap']) + dir_scores.append(task_res[0]['rot']) + bbox_preds = torch.cat(bbox_preds, dim=1) + scores = torch.cat(scores, dim=1) + dir_scores = torch.cat(dir_scores, dim=1) + return scores, bbox_preds, dir_scores + else: + cls_score, bbox_pred, dir_cls_pred = outs[0][0], outs[1][0], outs[2][0] + return cls_score, bbox_pred, dir_cls_pred diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py index 23d6c8d15..4908a5707 100644 --- a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py +++ b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py @@ -7,7 +7,8 @@ from mmdeploy.core import FUNCTION_REWRITER @FUNCTION_REWRITER.register_rewriter( 'mmdet3d.models.voxel_encoders.pillar_encoder.PillarFeatureNet.forward') -def pillar_encoder__forward(ctx, self, features, num_points, coors): +def pillar_encoder__forward(ctx, self, features, num_points, coors, *args, + **kwargs): """Rewrite this func to optimize node. Modify the code at _with_voxel_center and use slice instead of the original operation. diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py index 523d6762e..5844c3909 100644 --- a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py +++ b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py @@ -30,6 +30,7 @@ def pointpillarsscatter__forward(ctx, indices = indices.long() voxels = voxel_features.t() # Now scatter the blob back to the canvas. + canvas.scatter_( dim=1, index=indices.expand(canvas.shape[0], -1), src=voxels) # Undo the column stacking to final 4-dim tensor diff --git a/mmdeploy/codebase/mmdet3d/models/voxelnet.py b/mmdeploy/codebase/mmdet3d/models/voxelnet.py deleted file mode 100644 index e5d285bb2..000000000 --- a/mmdeploy/codebase/mmdet3d/models/voxelnet.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from mmdeploy.core import FUNCTION_REWRITER - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.voxelnet.VoxelNet.simple_test') -def voxelnet__simple_test(ctx, - self, - voxels, - num_points, - coors, - img_metas=None, - imgs=None, - rescale=False): - """Test function without augmentaiton. Rewrite this func to remove model - post process. - - Args: - voxels (torch.Tensor): Point features or raw points in shape (N, M, C). - num_points (torch.Tensor): Number of points in each pillar. - coors (torch.Tensor): Coordinates of each voxel. - input_metas (list[dict]): Contain pcd meta info. - - Returns: - List: Result of model. - """ - x = self.extract_feat(voxels, num_points, coors, img_metas) - bbox_preds, scores, dir_scores = self.bbox_head(x) - return bbox_preds, scores, dir_scores - - -@FUNCTION_REWRITER.register_rewriter( - 'mmdet3d.models.detectors.voxelnet.VoxelNet.extract_feat') -def voxelnet__extract_feat(ctx, - self, - voxels, - num_points, - coors, - img_metas=None): - """Extract features from points. Rewrite this func to remove voxelize op. - - Args: - voxels (torch.Tensor): Point features or raw points in shape (N, M, C). - num_points (torch.Tensor): Number of points in each pillar. - coors (torch.Tensor): Coordinates of each voxel. - input_metas (list[dict]): Contain pcd meta info. - - Returns: - torch.Tensor: Features from points. - """ - voxel_features = self.voxel_encoder(voxels, num_points, coors) - batch_size = coors[-1, 0] + 1 # refactor - assert batch_size == 1 - x = self.middle_encoder(voxel_features, coors, batch_size) - x = self.backbone(x) - if self.with_neck: - x = self.neck(x) - return x diff --git a/requirements/codebases.txt b/requirements/codebases.txt index d08a2d4c5..f15091fa0 100644 --- a/requirements/codebases.txt +++ b/requirements/codebases.txt @@ -1,5 +1,6 @@ mmcls>=1.0.0rc2 mmdet @ git+https://github.com/open-mmlab/mmdetection.git@dev-3.x +mmdet3d @ git+https://github.com/open-mmlab/mmdetection3d.git@dev-1.x mmedit @ git+https://github.com/open-mmlab/mmediting.git@1.x mmocr @ git+https://github.com/open-mmlab/mmocr.git@dev-1.x mmpose>=1.0.0rc0 diff --git a/tests/regression/mmdet3d.yml b/tests/regression/mmdet3d.yml index 1eb046ab0..dc476aed7 100644 --- a/tests/regression/mmdet3d.yml +++ b/tests/regression/mmdet3d.yml @@ -2,8 +2,8 @@ globals: codebase_dir: ../mmdetection3d checkpoint_force_download: False images: - kitti_input: &kitti_input ../mmdetection3d/demo/data/kitti/kitti_000008.bin - nus_input: &nus_input ./tests/data/n008-2018-08-01-15-16-36-0400__LIDAR_TOP__1533151612397179.pcd.bin + kitti_input: &kitti_input ../mmdetection3d/demo/data/kitti/000008.bin + nus_input: &nus_input tests/data/n008-2018-08-01-15-16-36-0400__LIDAR_TOP__1533151612397179.pcd.bin metric_info: &metric_info AP: # named after metafile.Results.Metrics @@ -82,8 +82,8 @@ models: - name: PointPillars metafile: configs/pointpillars/metafile.yml model_configs: - - configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py - - configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py + - configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py + - configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-car.py pipelines: - *pipeline_ort_dynamic_kitti_fp32 - *pipeline_openvino_dynamic_kitti_fp32 @@ -91,7 +91,7 @@ models: - name: PointPillars metafile: configs/pointpillars/metafile.yml model_configs: - - configs/pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py + - configs/pointpillars/pointpillars_hv_secfpn_sbn-all_8xb4-2x_nus-3d.py pipelines: - *pipeline_ort_dynamic_nus_fp32 - *pipeline_openvino_dynamic_nus_fp32_64x4 @@ -99,7 +99,7 @@ models: - name: CenterPoint metafile: configs/centerpoint/metafile.yml model_configs: - - configs/centerpoint/centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py + - configs/centerpoint/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py pipelines: - *pipeline_ort_dynamic_nus_fp32 - *pipeline_openvino_dynamic_nus_fp32_20x5 diff --git a/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py b/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py new file mode 100644 index 000000000..098c1df26 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py @@ -0,0 +1,141 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = [ + 'nus-3d.py', 'centerpoint_pillar02_second_secfpn_nus.py', 'cyclic-20e.py', + 'default_runtime.py' +] + +# If point cloud range is changed, the models should also change their point +# cloud range accordingly +point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] +# For nuScenes we usually do 10-class detection +class_names = [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', + 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' +] +data_prefix = dict(pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP') +model = dict( + data_preprocessor=dict( + voxel_layer=dict(point_cloud_range=point_cloud_range)), + pts_voxel_encoder=dict(point_cloud_range=point_cloud_range), + pts_bbox_head=dict(bbox_coder=dict(pc_range=point_cloud_range[:2])), + # model training and testing settings + train_cfg=dict(pts=dict(point_cloud_range=point_cloud_range)), + test_cfg=dict(pts=dict(pc_range=point_cloud_range[:2]))) + +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + trailer=5, + construction_vehicle=5, + traffic_cone=5, + barrier=5, + motorcycle=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + construction_vehicle=7, + bus=4, + trailer=6, + barrier=2, + motorcycle=6, + bicycle=6, + pedestrian=2, + traffic_cone=2), + points_loader=dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=[0, 1, 2, 3, 4])) + +train_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + pad_empty_sweeps=True, + remove_close=True), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.3925, 0.3925], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5), + dict( + type='LoadPointsFromMultiSweeps', + sweeps_num=9, + use_dim=[0, 1, 2, 3, 4], + pad_empty_sweeps=True, + remove_close=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D') + ]), + dict(type='Pack3DDetInputs', keys=['points']) +] + +train_dataloader = dict( + _delete_=True, + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='CBGSDataset', + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='nuscenes_infos_train.pkl', + pipeline=train_pipeline, + metainfo=dict(CLASSES=class_names), + test_mode=False, + data_prefix=data_prefix, + use_valid_flag=True, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR'))) +test_dataloader = dict( + dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names))) +val_dataloader = dict( + dataset=dict(pipeline=test_pipeline, metainfo=dict(CLASSES=class_names))) + +train_cfg = dict(val_interval=20) diff --git a/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py b/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py new file mode 100644 index 000000000..37e0ea125 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = ['./centerpoint_pillar02_second_secfpn_8xb4-cyclic-20e_nus-3d.py'] + +model = dict(test_cfg=dict(pts=dict(nms_type='circle'))) diff --git a/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_nus.py b/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_nus.py new file mode 100644 index 000000000..18fe78532 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_nus.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +voxel_size = [0.2, 0.2, 8] +model = dict( + type='CenterPoint', + data_preprocessor=dict( + type='Det3DDataPreprocessor', + voxel=True, + voxel_layer=dict( + max_num_points=20, + voxel_size=voxel_size, + max_voxels=(30000, 40000))), + pts_voxel_encoder=dict( + type='PillarFeatureNet', + in_channels=5, + feat_channels=[64], + with_distance=False, + voxel_size=(0.2, 0.2, 8), + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + legacy=False), + pts_middle_encoder=dict( + type='PointPillarsScatter', in_channels=64, output_shape=(512, 512)), + pts_backbone=dict( + type='SECOND', + in_channels=64, + out_channels=[64, 128, 256], + layer_nums=[3, 5, 5], + layer_strides=[2, 2, 2], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + conv_cfg=dict(type='Conv2d', bias=False)), + pts_neck=dict( + type='SECONDFPN', + in_channels=[64, 128, 256], + out_channels=[128, 128, 128], + upsample_strides=[0.5, 1, 2], + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + upsample_cfg=dict(type='deconv', bias=False), + use_conv_for_no_stride=True), + pts_bbox_head=dict( + type='CenterHead', + in_channels=sum([128, 128, 128]), + tasks=[ + dict(num_class=1, class_names=['car']), + dict(num_class=2, class_names=['truck', 'construction_vehicle']), + dict(num_class=2, class_names=['bus', 'trailer']), + dict(num_class=1, class_names=['barrier']), + dict(num_class=2, class_names=['motorcycle', 'bicycle']), + dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), + ], + common_heads=dict( + reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)), + share_conv_channel=64, + bbox_coder=dict( + type='CenterPointBBoxCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_num=500, + score_threshold=0.1, + out_size_factor=4, + voxel_size=voxel_size[:2], + code_size=9), + separate_head=dict( + type='SeparateHead', init_bias=-2.19, final_kernel=3), + loss_cls=dict(type='mmdet.GaussianFocalLoss', reduction='mean'), + loss_bbox=dict( + type='mmdet.L1Loss', reduction='mean', loss_weight=0.25), + norm_bbox=True), + # model training and testing settings + train_cfg=dict( + pts=dict( + grid_size=[512, 512, 1], + voxel_size=voxel_size, + out_size_factor=4, + dense_reg=1, + gaussian_overlap=0.1, + max_objs=500, + min_radius=2, + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])), + test_cfg=dict( + pts=dict( + post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + max_per_img=500, + max_pool_nms=False, + min_radius=[4, 12, 10, 1, 0.85, 0.175], + score_threshold=0.1, + pc_range=[-51.2, -51.2], + out_size_factor=4, + voxel_size=voxel_size[:2], + nms_type='rotate', + pre_max_size=1000, + post_max_size=83, + nms_thr=0.2))) diff --git a/tests/test_codebase/test_mmdet3d/data/cyclic-20e.py b/tests/test_codebase/test_mmdet3d/data/cyclic-20e.py new file mode 100644 index 000000000..aa8e2059b --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/cyclic-20e.py @@ -0,0 +1,66 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# For nuScenes dataset, we usually evaluate the model at the end of training. +# Since the models are trained by 24 epochs by default, we set evaluation +# interval to be 20. Please change the interval accordingly if you do not +# use a default schedule. +# optimizer +lr = 1e-4 +# This schedule is mainly used by models on nuScenes dataset +# max_norm=10 is better for SECOND +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01), + clip_grad=dict(max_norm=35, norm_type=2)) +# learning rate +param_scheduler = [ + # learning rate scheduler + # During the first 8 epochs, learning rate increases from 0 to lr * 10 + # during the next 12 epochs, learning rate decreases from lr * 10 to + # lr * 1e-4 + dict( + type='CosineAnnealingLR', + T_max=8, + eta_min=lr * 10, + begin=0, + end=8, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=12, + eta_min=lr * 1e-4, + begin=8, + end=20, + by_epoch=True, + convert_to_iter_based=True), + # momentum scheduler + # During the first 8 epochs, momentum increases from 0 to 0.85 / 0.95 + # during the next 12 epochs, momentum increases from 0.85 / 0.95 to 1 + dict( + type='CosineAnnealingMomentum', + T_max=8, + eta_min=0.85 / 0.95, + begin=0, + end=8, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=12, + eta_min=1, + begin=8, + end=20, + by_epoch=True, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=20) +val_cfg = dict() +test_cfg = dict() + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (4 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=32) diff --git a/tests/test_codebase/test_mmdet3d/data/cyclic-40e.py b/tests/test_codebase/test_mmdet3d/data/cyclic-40e.py new file mode 100644 index 000000000..42d10b651 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/cyclic-40e.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# The schedule is usually used by models trained on KITTI dataset +# The learning rate set in the cyclic schedule is the initial learning rate +# rather than the max learning rate. Since the target_ratio is (10, 1e-4), +# the learning rate will change from 0.0018 to 0.018, than go to 0.0018*1e-4 +lr = 0.0018 +# The optimizer follows the setting in SECOND.Pytorch, but here we use +# the official AdamW optimizer implemented by PyTorch. +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=lr, betas=(0.95, 0.99), weight_decay=0.01), + clip_grad=dict(max_norm=10, norm_type=2)) +# learning rate +param_scheduler = [ + # learning rate scheduler + # During the first 16 epochs, learning rate increases from 0 to lr * 10 + # during the next 24 epochs, learning rate decreases from lr * 10 to + # lr * 1e-4 + dict( + type='CosineAnnealingLR', + T_max=16, + eta_min=lr * 10, + begin=0, + end=16, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=24, + eta_min=lr * 1e-4, + begin=16, + end=40, + by_epoch=True, + convert_to_iter_based=True), + # momentum scheduler + # During the first 16 epochs, momentum increases from 0 to 0.85 / 0.95 + # during the next 24 epochs, momentum increases from 0.85 / 0.95 to 1 + dict( + type='CosineAnnealingMomentum', + T_max=16, + eta_min=0.85 / 0.95, + begin=0, + end=16, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=24, + eta_min=1, + begin=16, + end=40, + by_epoch=True, + convert_to_iter_based=True) +] + +# Runtime settings,training schedule for 40e +# Although the max_epochs is 40, this schedule is usually used we +# RepeatDataset with repeat ratio N, thus the actual max epoch +# number could be Nx40 +train_cfg = dict(by_epoch=True, max_epochs=40, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (6 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=48) diff --git a/tests/test_codebase/test_mmdet3d/data/default_runtime.py b/tests/test_codebase/test_mmdet3d/data/default_runtime.py new file mode 100644 index 000000000..1a6843968 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/default_runtime.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +default_scope = 'mmdet3d' + +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=-1), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='Det3DVisualizationHook')) + +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True) + +log_level = 'INFO' +load_from = None +resume = False + +# TODO: support auto scaling lr diff --git a/tests/test_codebase/test_mmdet3d/data/kitti-3d-3class.py b/tests/test_codebase/test_mmdet3d/data/kitti-3d-3class.py new file mode 100644 index 000000000..342f9a953 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/kitti-3d-3class.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# dataset settings +dataset_type = 'KittiDataset' +data_root = 'tests/test_codebase/test_mmdet3d/data/kitti' +class_names = ['Pedestrian', 'Cyclist', 'Car'] +point_cloud_range = [0, -40, -3, 70.4, 40, 1] +input_modality = dict(use_lidar=True, use_camera=False) +metainfo = dict(CLASSES=class_names) + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'kitti_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), + classes=class_names, + sample_groups=dict(Car=12, Pedestrian=6, Cyclist=6), + points_loader=dict( + type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4)) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=4, # x, y, z, intensity + use_dim=4), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='ObjectNoise', + num_try=100, + translation_std=[1.0, 1.0, 0.5], + global_rot_range=[0.0, 0.0], + rot_range=[-0.78539816, 0.78539816]), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05]), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range) + ]), + dict(type='Pack3DDetInputs', keys=['points']) +] +# construct a pipeline for data and gt loading in show function +# please keep its loading function consistent with test_pipeline (e.g. client) +eval_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict(type='Pack3DDetInputs', keys=['points']) +] +train_dataloader = dict( + batch_size=6, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type='RepeatDataset', + times=2, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='kitti_infos_train.pkl', + data_prefix=dict(pts='training/velodyne_reduced'), + pipeline=train_pipeline, + modality=input_modality, + test_mode=False, + metainfo=metainfo, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR'))) +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(pts='training/velodyne_reduced'), + ann_file='kitti_infos_val.pkl', + pipeline=test_pipeline, + modality=input_modality, + test_mode=True, + metainfo=metainfo, + box_type_3d='LiDAR')) +test_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(pts='training/velodyne_reduced'), + ann_file='kitti_infos_val.pkl', + pipeline=test_pipeline, + modality=input_modality, + test_mode=True, + metainfo=metainfo, + box_type_3d='LiDAR')) +val_evaluator = dict( + type='KittiMetric', + ann_file=data_root + 'kitti_infos_val.pkl', + metric='bbox') +test_evaluator = val_evaluator + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/tests/test_codebase/test_mmdet3d/data/kitti/kitti_infos_val.pkl b/tests/test_codebase/test_mmdet3d/data/kitti/kitti_infos_val.pkl index f2acbd3dcc82ef6a0d675a5197462d0e3957b7ff..ac17637e64bc5bacb1f34b18ac97504215ba4694 100644 GIT binary patch literal 7083 zcmcIod3aM*7H=tKDHb}e%OC_-kAXO_cI%6&>>O@~YLa+aK`NP=1u!RJ4a?%;@sbQ?4kdf{H>sm5nbiA5U#V6&>;5 zb%))~Dt{THuGXq~VJt3LT5v+0E?+S`t5F?vqAmN>=pm{1OgZw(JEv!dV0coF zVK$=uP&#x*1H+SBGR%z284pzzXsi5Qw=NPt#X+Yc@h(VwLGi7XrQQw+3=u2 zB9AI}*dG`~gAE#@&_xDatk5N5tIzg(k^NAukJeusC^V_Nv72d_2HW8ZjWFm^g+_8o zjlw!zcFdBJ##TRb{Rg|hz&hD;SZ!WX_7}HZ6>9tL4KS2)44V~*Z1>eKwNKtN6o#W) zGHmA2WE~9rfU*dDCaX@XA)En%xUSXkwHmi4*54q(1qW<&lgG^YmQSu>ccYoikDM@{sb^ z)^>WNVgL-497FIfuErC`?sK%eL+$C}m6|GD)B3`7*K&P&H!gRfZZb05s#Vjrtpvm@y=j~0h zkf^WU*1!Tj;=(M`Y;t1ZCHKYf%#>-DPdX!I?JZlXOPz zNjk$HAWubS_?7Fu%~K{Fk6BeJjJxsLvVc$i{b;-@$Z3D>5Yq!9KsY?L5iu4=61_+2$IR>PxP z0}F|jQ|Bjz=8e$ZaR@LBAfoH`n=SRqY2~F*vqx7K3yFPqmN!6vVNENukBF|w9|(or z0T24Upo7R3CEe=zX3VImrI3TdIngAo)7wR(g6n0euGbIds%d>>X*SP}X~oL2kF&}> zL#?7nnbmy?)etYua?os>oGe_+*WI%H7Y?-UZCNVDy=WB4QpdE${!_;0r;_bg?_bmt zEg72x*Q=t+8YgigZY}2FzSgS-QDAk5H|z2ko&`3K8VIYpbFi0ga?l)X_Ulf$A7gp4 zH0!F7Y-L5{x=mLp>h|V|{R;&euixkMd+s;^s-Whkn4I{$%S|84o4W{Qi2DSZm)(0p=?{3-TAE*$CjwRKHIbmx z@1+H0d66hBv}ss|MaVvUz#>obO&1Z-y)MK&jAynaSX4Feu7PUz5WOJ{w7`V_8(v3>t$K$$CU%XQnC0RC)KTF-Sz2Q9Bs@!Tn&@;X-Ujm?7Wr}`E-ABQCs{Ru6G8AY1j}k|7ukN)(qxxRTKiob zQ!uX82)ZPHMT3~Ja^Cy+W?W@Mo5RY(a%GO2rm)%qR$IVYBVgT*kjvML4UJ!W<9X{S zsLfB){0FQB%r4g*FfYx(`jzKr@I%rXslX~1tWtWutOAFG6tv8ZhwxutokCMtZnJN@V z9G@L`y8J|4gk7d0DXvo zzQsEVjfQtzT38my9H~?JB>XBD=)=be=wFWq(2fzSYYMrmgHZ2*%zr67rqCArrJyzcGhyot{0+eW z5oRvh#{UF7uFw+(J*m)Bh_wYK{I-~RTBpm+CsU79O7t`vZC7ZAK|2+CCT+xjAD((; zO=`_~ed{FJ#WsIe=vjjj3MJDv*|~zVCECpOR#T5%#Gq`O0>KQzmY4vhY> zgb$$Sv5IIE<1mozUw65wfdy_9Edh<9#2S79#x0>_yNd;G6!DF=j;H2!+~Y_0_D^fa z;|O{I>)5OZHLafSxC#IH7H5Zx-BzlXj-d5;82EC)uyDhQz z9-F+c&<6&6sL)5MW---ZwRsH)DBJgM?h<2x+dQGkQpRnbi-(%Dd8|>jMshv|InCR< z#ugCdeA0xRPc!6j#SC($5;xafYp#&!zntu63Vm+S{}lQnhnmz~%9%H$ZkwY?w2zIx zRA|3JUnz7TZNz0IF3nsm(bsHqP@zKxeWTE~EfMn_n|!a(4+b4p=tzc`AAto&q=Z-~ zp;5s&3*0AK0{TQKFoZr)Xt+;|v4;CZ7IKOC3Bj88hae{1LxPx}n-KF$h8X@;Y~J24 z^+2(?+}vW$m*`i{_cw)pH|VIM=izU>98`W^Y5P%~e3G8e`oV_Y2H%QafV-i$6>=eu zSS2ZfKW*;lZ!R5SE;9{DZ^wghYcEVY;BM$0gHcbc(TQ&pDg0F z(3^tRbpINO>Zh<(&Vi?&Dik?e@(J@kb3?78zqwe_yI_!!bj7!#pN6}kpDyHXJYsoq zt)^DaH5d7(^;IRkJCDHU3}JUB?uOn&$Y=3LBV`?fX6(AFQuECvl3v0?aOx=xOK~^! zvxR&Pk2E@UPCaIBFdx41teKL2E)T(}moPjJcSAp4$h~pT@N9#Oazn!$6C8f-^SM}% z&@jh^fQC870yoT(5YR9;2muW z6S5um49gN^u#EmahDmWkteRHJ0{8Di!~MI9HQcs0G!a^==FwKX=B+zePPg|FH5742 G2K*nw&%#dt literal 5109 zcmd5=iC%Q zB5G9PhN4l4H4&Gf2%>S?bGz^RzAt}Ad(M0BK1Ty3zqaYGcYZVPy}O+Aecw5Enb$CU z3pnavn9DQL;c!H%eU%{?UI$q&PmX%_MuIgkqBPs%v}@u%Jq#l~lc+*Z_#)9rv~rFy z);n)bA(x_pc(f88^W@q?l3`yw1fv4cRCTDV_-dWrL?{r8CiO%r zpgSSk?B$FG%~r@Mbu@u9;0?rrp+M3Zit>#Uay`Syz8^+=9;0iOPme_BLS8eB5%96S zFjm0FU8TcHhd03|EQ3M4HcsOn_Y7xT3t(K!zP<2-fG1t0S%$5}Fn;aYwX6R@#rP?8 z=rFI1V@rL-AYGITj*godX#!r*VTr|PJ}gy?mNBE{0#?|JN=!y=Lomw9!TtGxqSTPV zDe&OCxixrdm`z-vGp~dZduEp4@yE7&7&kFPm(8XwZhjr7Q53JYDZz zBCpkrlzSKkriT;T6w98v6kdsA8q;g=B(}*tWq8QFgt(r;UDzZ~m$>n5 zc^Qx5-HGP#=MyjFq_72VVp`rT!v=X!o|X5BpX`_-&&l*sd8FP=75C*??0s!LSsNwq z%}UAkpgi(GuYpBi@*A1_CIOpu*kY59RHv9bYXGk(@?IvtRe;YX?>EWcBOlb^9V%$Y;@#dCjmDBtSqEX4hc-1w)K?vXh*IEHR^T-ORjh!s z397B_@`g&P!S3L7Mz~WzoeuSWPc|8e1p-y6AOZZ60@%O+cL`{;0PePu zC{hG_IEk7BG^-?f_|&!6NToPg&9dN8RvG2dXRF%2(B$}!KCq98wFua+L#yAMDOK^X z4-P0+2bt9&0dH8W4x6$^2*NnLRwB>u7$+}bT43K)TIQ~-b7)ZPlqc|FoK}wXN^a+F zJ+u^XCAZ1D?h2#GekJ^LwF3@#14kI;Q2}i_yy-XQuiqc5fny5faRzxpz)1rn9|%vG zF#DB|1Q$CE@3~=3<;(P1%}{ddb`9I{l=+cYbAZ($vbYP$5gtvf?F(ayRDSqxzg)45@9|~2Ybrg zv6iZ#-DTRhK=iq-I6BGpBXI1w+W+r3gNq!9P62Q0aLHW%s$#T0Qyk#SDh!u74DSed z*ABxKa}utjFl53RDI4slWRRz@SQ88TWH+YO8es#@SE;mPj7sHV+5ZsJ33e(?mG8Of zNOi#e^tsgpIX~N;AF^k%_DHWxq*_MDX(|)$xz+5FJ@d^6nSWrl>eb$b%XTejHuGq`cJ!O6jQR<5R1g4&G( zuGr={YJwhyd0RlQ4t?eZljIj7a7Q8RXM}eJ+_MnAZz3!+P5f7X;XZ>&3;2o-Up2u* zqQOuNpaLTq%+~~b-2(HC@RR&6yU2pUy9vKe;cb~uk$kcoY6FIEveRz~__hw;F;6#@ ziC8L5ubPtZUFG5_Q?*)F zR73*w>Zulfrab+eJ^ez!FD*~MGCgTzK~Z6EnfLt6Mb!;0{uS2+H_IQ}T$PZo|pn>e1MK~IyeRYVO)XKW+E7g8l&_%I0m!U+B< O;BPwo-JeSN3jPVk`NVzz diff --git a/tests/test_codebase/test_mmdet3d/data/kitti/training/velodyne_reduced/000008.bin b/tests/test_codebase/test_mmdet3d/data/kitti/training/velodyne_reduced/000008.bin new file mode 120000 index 000000000..35857639c --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/kitti/training/velodyne_reduced/000008.bin @@ -0,0 +1 @@ +../../kitti_000008.bin \ No newline at end of file diff --git a/tests/test_codebase/test_mmdet3d/data/model_cfg.py b/tests/test_codebase/test_mmdet3d/data/model_cfg.py index da02e4b39..599dcaa70 100644 --- a/tests/test_codebase/test_mmdet3d/data/model_cfg.py +++ b/tests/test_codebase/test_mmdet3d/data/model_cfg.py @@ -1,75 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. -voxel_size = [0.16, 0.16, 4] +_base_ = [ + 'pointpillars_hv_secfpn_kitti.py', 'kitti-3d-3class.py', 'cyclic-40e.py', + 'default_runtime.py' +] -model = dict( - type='VoxelNet', - voxel_layer=dict( - max_num_points=32, # max_points_per_voxel - point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1], - voxel_size=voxel_size, - max_voxels=(16000, 40000) # (training, testing) max_voxels - ), - voxel_encoder=dict( - type='PillarFeatureNet', - in_channels=4, - feat_channels=[64], - with_distance=False, - voxel_size=voxel_size, - point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1]), - middle_encoder=dict( - type='PointPillarsScatter', in_channels=64, output_shape=[496, 432]), - backbone=dict( - type='SECOND', - in_channels=64, - layer_nums=[3, 5, 5], - layer_strides=[2, 2, 2], - out_channels=[64, 128, 256]), - neck=dict( - type='SECONDFPN', - in_channels=[64, 128, 256], - upsample_strides=[1, 2, 4], - out_channels=[128, 128, 128]), - test_cfg=dict( - use_rotate_nms=True, - nms_across_levels=False, - nms_thr=0.01, - score_thr=0.1, - min_bbox_size=0, - nms_pre=100, - max_num=50), - bbox_head=dict( - type='Anchor3DHead', - num_classes=3, - in_channels=384, - feat_channels=384, - use_direction_classifier=True, - anchor_generator=dict( - type='AlignedAnchor3DRangeGenerator', - ranges=[ - [0, -39.68, -0.6, 69.12, 39.68, -0.6], - [0, -39.68, -0.6, 69.12, 39.68, -0.6], - [0, -39.68, -1.78, 69.12, 39.68, -1.78], - ], - sizes=[[0.6, 0.8, 1.73], [0.6, 1.76, 1.73], [1.6, 3.9, 1.56]], - rotations=[0, 1.57], - reshape_out=False), - diff_rad_by_sin=True, - bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), - loss_dir=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2))) point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1] # dataset settings -data_root = 'tests/test_codebase/test_mmdet3d/data/kitti/' -dataset_type = 'KittiDataset' +data_root = 'data/kitti/' class_names = ['Pedestrian', 'Cyclist', 'Car'] -input_modality = dict(use_lidar=True, use_camera=False) +metainfo = dict(CLASSES=class_names) + # PointPillars adopted a different sampling strategies among classes db_sampler = dict( data_root=data_root, @@ -77,19 +17,17 @@ db_sampler = dict( rate=1.0, prepare=dict( filter_by_difficulty=[-1], - filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), + filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)), classes=class_names, - sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10)) + sample_groups=dict(Car=15, Pedestrian=15, Cyclist=15), + points_loader=dict( + type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4)) + +# PointPillars uses different augmentation hyper parameters train_pipeline = [ dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), - dict(type='ObjectSample', db_sampler=db_sampler), - dict( - type='ObjectNoise', - num_try=100, - translation_std=[0.25, 0.25, 0.25], - global_rot_range=[0.0, 0.0], - rot_range=[-0.15707963267, 0.15707963267]), + dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True), dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), dict( type='GlobalRotScaleTrans', @@ -98,8 +36,9 @@ train_pipeline = [ dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointShuffle'), - dict(type='DefaultFormatBundle3D', class_names=class_names), - dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_labels_3d', 'gt_bboxes_3d']) ] test_pipeline = [ dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), @@ -116,204 +55,60 @@ test_pipeline = [ translation_std=[0, 0, 0]), dict(type='RandomFlip3D'), dict( - type='PointsRangeFilter', point_cloud_range=point_cloud_range), - dict( - type='DefaultFormatBundle3D', - class_names=class_names, - with_label=False), - dict(type='Collect3D', keys=['points']) - ]) + type='PointsRangeFilter', point_cloud_range=point_cloud_range) + ]), + dict(type='Pack3DDetInputs', keys=['points']) ] -data = dict( - train=dict( - dataset=dict( - pipeline=train_pipeline, classes=class_names, - box_type_3d='LiDAR')), - val=dict(pipeline=test_pipeline, classes=class_names, box_type_3d='LiDAR'), - test=dict( - type=dataset_type, - data_root=data_root, - ann_file=data_root + 'kitti_infos_val.pkl', - split='training', - pts_prefix='velodyne_reduced', - pipeline=test_pipeline, - modality=input_modality, - classes=class_names, - test_mode=True, - box_type_3d='LiDAR')) -point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] -centerpoint_model = dict( - pts_voxel_layer=dict( - max_num_points=20, - voxel_size=voxel_size, - max_voxels=(30000, 40000), - point_cloud_range=point_cloud_range), - pts_voxel_encoder=dict( - type='PillarFeatureNet', - in_channels=4, - feat_channels=[64], - with_distance=False, - voxel_size=(0.2, 0.2, 8), - norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), - legacy=False), - pts_middle_encoder=dict( - type='PointPillarsScatter', in_channels=64, output_shape=(512, 512)), - pts_backbone=dict( - type='SECOND', - in_channels=64, - out_channels=[64, 128, 256], - layer_nums=[3, 5, 5], - layer_strides=[2, 2, 2], - norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), - conv_cfg=dict(type='Conv2d', bias=False)), - pts_neck=dict( - type='SECONDFPN', - in_channels=[64, 128, 256], - out_channels=[128, 128, 128], - upsample_strides=[0.5, 1, 2], - norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), - upsample_cfg=dict(type='deconv', bias=False), - use_conv_for_no_stride=True), - pts_bbox_head=dict( - type='CenterHead', - in_channels=sum([128, 128, 128]), - tasks=[ - dict(num_class=1, class_names=['car']), - dict(num_class=2, class_names=['truck', 'construction_vehicle']), - dict(num_class=2, class_names=['bus', 'trailer']), - dict(num_class=1, class_names=['barrier']), - dict(num_class=2, class_names=['motorcycle', 'bicycle']), - dict(num_class=2, class_names=['pedestrian', 'traffic_cone']), - ], - common_heads=dict( - reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), vel=(2, 2)), - share_conv_channel=64, - bbox_coder=dict( - type='CenterPointBBoxCoder', - post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], - pc_range=point_cloud_range[:2], - max_num=500, - score_threshold=0.1, - out_size_factor=4, - voxel_size=voxel_size[:2], - code_size=9), - separate_head=dict( - type='SeparateHead', init_bias=-2.19, final_kernel=3), - loss_cls=dict(type='GaussianFocalLoss', reduction='mean'), - loss_bbox=dict(type='L1Loss', reduction='mean', loss_weight=0.25), - norm_bbox=True), - # model training and testing settings - train_cfg=dict( - pts=dict( - grid_size=[512, 512, 1], - voxel_size=voxel_size, - out_size_factor=4, - dense_reg=1, - gaussian_overlap=0.1, - max_objs=500, - min_radius=2, - code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2])), - test_cfg=dict( - pts=dict( - post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], - max_per_img=500, - max_pool_nms=False, - min_radius=[4, 12, 10, 1, 0.85, 0.175], - pc_range=point_cloud_range[:2], - score_threshold=0.1, - out_size_factor=4, - voxel_size=voxel_size[:2], - nms_type='circle', - pre_max_size=1000, - post_max_size=83, - nms_thr=0.2))) -voxel_size = [0.25, 0.25, 8] -pointpillars_nus_model = dict( - pts_voxel_layer=dict( - max_num_points=64, - point_cloud_range=[-50, -50, -5, 50, 50, 3], - voxel_size=voxel_size, - max_voxels=(30000, 40000)), - pts_voxel_encoder=dict( - type='HardVFE', - in_channels=4, - feat_channels=[64, 64], - with_distance=False, - voxel_size=voxel_size, - with_cluster_center=True, - with_voxel_center=True, - point_cloud_range=[-50, -50, -5, 50, 50, 3], - norm_cfg=dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01)), - pts_middle_encoder=dict( - type='PointPillarsScatter', in_channels=64, output_shape=[400, 400]), - pts_backbone=dict( - type='SECOND', - in_channels=64, - norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01), - layer_nums=[3, 5, 5], - layer_strides=[2, 2, 2], - out_channels=[64, 128, 256]), - pts_neck=dict( - type='FPN', - norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01), - act_cfg=dict(type='ReLU'), - in_channels=[64, 128, 256], - out_channels=256, - start_level=0, - num_outs=3), - pts_bbox_head=dict( - type='Anchor3DHead', - num_classes=10, - in_channels=256, - feat_channels=256, - use_direction_classifier=True, - anchor_generator=dict( - type='AlignedAnchor3DRangeGenerator', - ranges=[[-50, -50, -1.8, 50, 50, -1.8]], - scales=[1, 2, 4], - sizes=[ - [2.5981, 0.8660, 1.], # 1.5 / sqrt(3) - [1.7321, 0.5774, 1.], # 1 / sqrt(3) - [1., 1., 1.], - [0.4, 0.4, 1], - ], - custom_values=[0, 0], - rotations=[0, 1.57], - reshape_out=True), - assigner_per_size=False, - diff_rad_by_sin=True, - dir_offset=-0.7854, # -pi / 4 - bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9), - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0), - loss_dir=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)), - # model training and testing settings - train_cfg=dict( - pts=dict( - assigner=dict( - type='MaxIoUAssigner', - iou_calculator=dict(type='BboxOverlapsNearest3D'), - pos_iou_thr=0.6, - neg_iou_thr=0.3, - min_pos_iou=0.3, - ignore_iof_thr=-1), - allowed_border=0, - code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2], - pos_weight=-1, - debug=False)), - test_cfg=dict( - pts=dict( - use_rotate_nms=True, - nms_across_levels=False, - nms_pre=1000, - nms_thr=0.2, - score_thr=0.05, - min_bbox_size=0, - max_num=500))) +train_dataloader = dict( + dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo))) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +# In practice PointPillars also uses a different schedule +# optimizer +lr = 0.001 +epoch_num = 80 +optim_wrapper = dict( + optimizer=dict(lr=lr), clip_grad=dict(max_norm=35, norm_type=2)) +param_scheduler = [ + dict( + type='CosineAnnealingLR', + T_max=epoch_num * 0.4, + eta_min=lr * 10, + begin=0, + end=epoch_num * 0.4, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=epoch_num * 0.6, + eta_min=lr * 1e-4, + begin=epoch_num * 0.4, + end=epoch_num * 1, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=epoch_num * 0.4, + eta_min=0.85 / 0.95, + begin=0, + end=epoch_num * 0.4, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=epoch_num * 0.6, + eta_min=1, + begin=epoch_num * 0.4, + end=epoch_num * 1, + convert_to_iter_based=True) +] +# max_norm=35 is slightly better than 10 for PointPillars in the earlier +# development of the codebase thus we keep the setting. But we does not +# specifically tune this parameter. +# PointPillars usually need longer schedule than second, we simply double +# the training schedule. Do remind that since we use RepeatDataset and +# repeat factor is 2, so we actually train 160 epochs. +train_cfg = dict(by_epoch=True, max_epochs=epoch_num, val_interval=2) +val_cfg = dict() +test_cfg = dict() diff --git a/tests/test_codebase/test_mmdet3d/data/nus-3d.py b/tests/test_codebase/test_mmdet3d/data/nus-3d.py new file mode 100644 index 000000000..692c2a84f --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/nus-3d.py @@ -0,0 +1,133 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# If point cloud range is changed, the models should also change their point +# cloud range accordingly +point_cloud_range = [-50, -50, -5, 50, 50, 3] +# For nuScenes we usually do 10-class detection +class_names = [ + 'car', 'truck', 'trailer', 'bus', 'construction_vehicle', 'bicycle', + 'motorcycle', 'pedestrian', 'traffic_cone', 'barrier' +] +metainfo = dict(CLASSES=class_names) +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +# Input modality for nuScenes dataset, this is consistent with the submission +# format which requires the information in input_modality. +input_modality = dict(use_lidar=True, use_camera=False) +data_prefix = dict(pts='samples/LIDAR_TOP', img='', sweeps='sweeps/LIDAR_TOP') + +file_client_args = dict(backend='disk') +# Uncomment the following if use ceph or other file clients. +# See https://mmcv.readthedocs.io/en/latest/api.html#mmcv.fileio.FileClient +# for more details. +# file_client_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/nuscenes/': 's3://nuscenes/nuscenes/', +# 'data/nuscenes/': 's3://nuscenes/nuscenes/' +# })) + +train_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5), + dict(type='LoadPointsFromMultiSweeps', sweeps_num=10), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.3925, 0.3925], + scale_ratio_range=[0.95, 1.05], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5), + dict(type='LoadPointsFromMultiSweeps', sweeps_num=10, test_mode=True), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range) + ]), + dict(type='Pack3DDetInputs', keys=['points']) +] +# construct a pipeline for data and gt loading in show function +# please keep its loading function consistent with test_pipeline (e.g. client) +eval_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=5, use_dim=5), + dict(type='LoadPointsFromMultiSweeps', sweeps_num=10, test_mode=True), + dict(type='Pack3DDetInputs', keys=['points']) +] +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='nuscenes_infos_train.pkl', + pipeline=train_pipeline, + metainfo=metainfo, + modality=input_modality, + test_mode=False, + data_prefix=data_prefix, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR')) +test_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='nuscenes_infos_val.pkl', + pipeline=test_pipeline, + metainfo=metainfo, + modality=input_modality, + data_prefix=data_prefix, + test_mode=True, + box_type_3d='LiDAR')) +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='nuscenes_infos_val.pkl', + pipeline=test_pipeline, + metainfo=metainfo, + modality=input_modality, + test_mode=True, + data_prefix=data_prefix, + box_type_3d='LiDAR')) + +val_evaluator = dict( + type='NuScenesMetric', + data_root=data_root, + ann_file=data_root + 'nuscenes_infos_val.pkl', + metric='bbox') +test_evaluator = val_evaluator + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/tests/test_codebase/test_mmdet3d/data/nuscenes/n008-2018-09-18-12-07-26-0400__LIDAR_TOP__1537287083900561.pcd.bin b/tests/test_codebase/test_mmdet3d/data/nuscenes/n008-2018-09-18-12-07-26-0400__LIDAR_TOP__1537287083900561.pcd.bin new file mode 100644 index 0000000000000000000000000000000000000000..ee29b47ec2b01e8a1ad35b8ab77fb9facd585384 GIT binary patch literal 8000 zcmZ9Rd0dUzAIC4Hl%&w&jFL%2ZWvj@O{)qamHpOOV)EQ; ziCo>&IVe$$G9uZswoFWEjD(-xIn{O2dA*)Lp3m!h&Uw!Fyzl3Gp65WkNGt zoh6YdcnN7&=gClqip%iHCY`3>H^1>f0oIpc_IRq52_#GyHDD<^(m?|~L#SpBh$L+7 zz83_CFgE9Pv!)b$XKY7x-7i41BpOE>5*}VpMER%B!}SUO zG;3u-$6mjqnc`Vk^`%i0NFX>TQ=_iMYMB3`ra8_O{LdSq_^u*6cI9oeW)22h{DrvN zr(o)Xe~CudW(_QjD$yq83FvbDzvegrY}wHkA1O?Md9sFPO$HX7IIju#`6{%T9hOP@ zl6k#aX#&yYKsnC=2PT|^Q+z)(_f>#8ZyA0Yc?^EpL)WV&d5HoqPs#B|yTh>S*NNsuG$ojwE@G)N6sm(Fh!#h*RG@!F z#4D3S;DXM^M5DP^gO^`K9BI1|ju-W`bgdFe1~7ji;{Hok!|#3TxTZJ+G{b`N?Sfdi zt)m?;k;q6~F6dXI!abF~Fz_5AS}xHF!KWFUv1NxS2ply;QH(hUjAReaLBUhUnnZGMyDQqV z;VfKz@b9dqeN*tO&Uzrv(Pv<$0ga>cN(5P~AFAJ^f@fdTI8&ktg889!=%D;GT#&D0 zO(9IU6Naj6|A5g3lp)H2Qjm|@jug{R!tTE^{tSx~kz?p_*j8lx$!GtL?3eux-P|Z^ zw66r@98jYIHzkalR!cP6Ou^6A2$k6$gz=G#ztr#|l(sby`s>#<$615#j#i@hySw0m z74Mrh0YofM*JONm6=vHuvPR~Rwblg6!nbiv{7o9+fbTa%LeZM>w|kxp+m&yEd!uO% z5Xnmv=rnpImVNPshE>mqmPE8T(0SSc+-Ivbd~uQHPas+nSbI^#F4djjoPCVHzA6#- zH86&o)=>WBBrX*UDkS{9Zv!v&ru@;qYH+lT3V%3Mr!7ol{5^Xq;;kS5)~;Jl`ST)i z2=p5gjGs++gl`)8qnze~2~Smc=d>Hzx5p@dWn{gDAktEea}Q@|HH+JjdnLI_fK8bk zkG?hrI+Z>maWn^Iz&LpkUh{;5Cl-Aq8l9s8xLf&R&jUZfq{wb0hCW{nSSVeMKee}l zbuo5CQ;@436m4)pr-z<`hF^J+NJ?#7(ZkKDQ0c(V?(aVi!^^2AA>&B|k0*X8 zYTQwH%$fOLrrtWF+LjD=@{GS14q@o|+rx0kN1dh=Vy?xaPQ?dc+&QYzh6L~MMD*@+ zA{_Y><1gsn6m(>AJhZ>9V^;!fqSPp3=}tJQb`mIKDF9#;I$?h)s&E&qw}7x}x0?+%}y z2_qid6AmW6GS_g1~@dqD?*eraN zZHH?G=oL~q@DyPl9Z3DL_y*1DDW-P3gKtmpd8aTP#}@x^a*ur?^4@<;pDw94PI zpDBO+44qMZw`8d3$^38ceODBI>@YO>#P};b;ekd@J_uj8qjA*hB*G1aADXskAADI& zD*+K(r*_zF)-0 z7gR{UgfRY0Q$)OcOuck?AmvX+;!;7zSrN~3mT=L=tp3VsM7*f9Jy()U`J-_eK)55~ z&EBS5us8F+zQrQG?rz2vUkWB`rO%QJ9z9p#HX;4Ftj79Fc_{>^oYZ)9+CZ*LDC6(j zywEyb27Xv9$I|DMG~4{!kvu8MRnaO}N1Z5t9XLNMFS~NN)0`0`hVoJmK3#D^_cIdV zfsr~kB$DJ{SJZ~z3r8Jb{H>Yffl54f!%iouRz~_ugoPe{$ftM*YIs6b|)E#ZyM zj6Z|#(ls_wS76%*jK5YIS%(++R-C7ZIGEALGw4hny#GMsWi~${#(8=7RSnDqNGZ zgIla({`YBu8n@rSi?fcU{Lwux0dc=?#qkenrL~ncjr&OX`!+94=fA)&-q__bF zyT;s6{(!z8Hp<>7{o9)IXWGbLT8T62ZMFk;v(#}Qk&H}pMdmM~AuMJ5{m6Tu;aei% ztW>Jez7nCEjUP%_5DM?98Gk2l1|h?pn_&3>-FOP&xL+7bxv>iREY>lu6eN+c=&h3< z?9+n3e=-x$0b@^S+MW4d-rN*a%1wg_U(b#bFlV+JU0dM@_jP3Tcg?yiWYS|O>~uw^ znS=Xli%?L{@1Vwv`JbCz1cpV8fPjmfMlaLy^lUn{LKhu>;V{LRfAhj$(~ zbj-WN{BL|;8P0d|mj1Sw=78>R6!7<0fhWzYmtLCklxT58ivtk_`*50kJC_|_Nwg;a zt3M**%Cbz(FNyje^&utjB${6V;_!Ovf2l-E1wpwYwi$4hTQ-RDCn1^|EUOi9*MRF> zuU(A49)v&Bu6dlvTFRe-#369GLBtFF3pg*MO}x(kj%BHEzX5-7rH2@QvRP{EyZSD- z=K$qTLaq`}H%g5MeP7CTerQ4NDF=kV8%wazuRnM61dT!DBK7ysLL57*Gk0cg7jjSM ztpOJ0%kjFgHk`>TYogKEdQc9Wk+FR^lx}48_p#O$h5Cg+c?0E&#wqyf?jFdr*G6c# zjcNiJPa=#R?}ui_1j3>M#^0#LK}dSQ7upxl@f!OIJH5lulN@(g>%;gnER97@bEm_N zpXnaahJ;`I6Omy92kmT`{{_aRpdYu7gzpwI|0`6e(C|ZjVR3^_GX)20AhPdn08Ixm z{sjLbR2pBRb#B35ANO)3o=MX3cJG?!6~Gb2MNPiN6Q{Hr29eanFV`#P(?<8oAZ$tpolv4${DnR9{U$CpoLhi3}Q*uw=w`yADs*q^)U`6~} zm^X!v8b%CEz~?`{9QYhfHvk>!lB)m|8@O21?l%J`FQ zzMx5ahv7_r#$PLqDZr-KgRg$MQU<&TXTN= z7RKK?@=m|#&gySZa@KQ_4dmOeXZ*>DPkwY5#0Nw({uX42*sjkIJ~@}w-x$K* zv(v+P8_4*>-fA4XcLd+#BCEes>r{Am)d)WBf(6M5^`SEGZLXT?Z+n+ta7TV0cQ=ai z=SAXbfTc!_#U8f&<4IvevnE)Kv`+8KKYe+phg+Wz;hh^uvn6>d6DRPeqZ z+>!IIGhnYq{%Ecwg6Hlf$f493rjBL&S!@nMnZ4}cDt(%JI-Wwv)(=AqzgWSLM>SAuV_dL8fW@ai{9wGQv!UlRmf^gy4GwR<1h6B zL1ZV7{*U!{2l$ZD?WPa zz4XoN&tx66d4X@^1cCbBtGQ$Gh}v1)@lLw$ISMei*cAsv?&RVd&%jOm75Xg43!@P? zBZl#pyDAA+z3;_$OD`vL(7j9o|3-*-%S4`6Dk*B7BSXBR!W|vlB!wp;sn*0_gJua@lV}ai)=-V|EC}5$twpDI41nuBzOpMUR)wOn zQd1Z^Psf^4sF)FpVg^az+sQf&2~)!o&`Fnj+WDWTMq8P1_hT~hYIjom3ugS8x~ouh z-(}jhA9Quv6g-#@k=dYUnHTpn{v2Q-GD_>{sJ~VBPE5Z44K7Ee-ZQ0DHjKX`gD+|( zC1GeXoblI66DWUOx(>&Ol4H2!)y)4o*W2R{=cBpc6y5ub0{DJ*#b!${aC0`(b$F4y zM1h_){BiaByshi*z7fw+&l||p2UCkKc73( zkz<7|JUxc-hwiu|m-$w3L<{}~j9P*orS*n+pR1blCkS0Xu0_+A8$n*h_}hs?(aySm zw7*`R&}VM)zr6YW z*r;CzUUvQo(Uc_jabSpJB7SBj<8K;J{>+J%1P1I7vA&Ul_q)OB@38G6?x`Qb8|`QO z%{wh(OY^O~KF8{>Ir;x%*4wT8k{14VQ7humKSFt1$oQLBD&n!(Vf;WTd;iEvP~qN& z5&Yd2{`dV%H4eTW$rt2P{*>e@0W<8?xc*cW-*JE?xo_g{5m<(=m3QL<9NQ9&dR+zh zHh*zsZ8aeKmK*-1gCRfsmIaBSY}djYy|VacVm{I)z9_e)!tO2^9q7O5`+f*`a=8U=GVT$-t(a-Egfpxw%~7y zODr0HXs_0_@r*=076}K=C7?lnP0$vvV)eI+Dj9inxt7^^6Z5}wqV)ht$jKL(NS zD(1WR@C3%6PpXJ}&)Cf$a$)`#Y^=g1(mi~yK<0mPHSyCsdw7&h`J;>@Fr4)Dy&lg$ z>BQ>q(p@SXI4yy{oy+|1v8Nh$8=b^kJ!AYG{!xuRKPK_FttA< z+7$dB+3slDjQ?r%k5Sg-WDbdN*KG+p^ZJq2W{!?YLGXRG7VRH&LF?0kze66OXouT& zt)79-!<0h)!x*&UnX}e7k8(;I5;_w9%i3@ub8sGwqpeIBbtoBaZgblqF1V3Nfrvmo zRYaG!^^%?`(s61E78Pkx@X3SHwPP56m!IB3{ow~`R37DzGHwmx<6fblT03sq6UN`S a>!oor&`KlodV%bMAFxx`Y7YA{{{9b?zf-^f literal 0 HcmV?d00001 diff --git a/tests/test_codebase/test_mmdet3d/data/pointpillars.py b/tests/test_codebase/test_mmdet3d/data/pointpillars.py new file mode 100644 index 000000000..599dcaa70 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/pointpillars.py @@ -0,0 +1,114 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = [ + 'pointpillars_hv_secfpn_kitti.py', 'kitti-3d-3class.py', 'cyclic-40e.py', + 'default_runtime.py' +] + +point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1] +# dataset settings +data_root = 'data/kitti/' +class_names = ['Pedestrian', 'Cyclist', 'Car'] +metainfo = dict(CLASSES=class_names) + +# PointPillars adopted a different sampling strategies among classes +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'kitti_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)), + classes=class_names, + sample_groups=dict(Car=15, Pedestrian=15, Cyclist=15), + points_loader=dict( + type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4)) + +# PointPillars uses different augmentation hyper parameters +train_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler, use_ground_plane=True), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05]), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_labels_3d', 'gt_bboxes_3d']) +] +test_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range) + ]), + dict(type='Pack3DDetInputs', keys=['points']) +] + +train_dataloader = dict( + dataset=dict(dataset=dict(pipeline=train_pipeline, metainfo=metainfo))) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +# In practice PointPillars also uses a different schedule +# optimizer +lr = 0.001 +epoch_num = 80 +optim_wrapper = dict( + optimizer=dict(lr=lr), clip_grad=dict(max_norm=35, norm_type=2)) +param_scheduler = [ + dict( + type='CosineAnnealingLR', + T_max=epoch_num * 0.4, + eta_min=lr * 10, + begin=0, + end=epoch_num * 0.4, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=epoch_num * 0.6, + eta_min=lr * 1e-4, + begin=epoch_num * 0.4, + end=epoch_num * 1, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=epoch_num * 0.4, + eta_min=0.85 / 0.95, + begin=0, + end=epoch_num * 0.4, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=epoch_num * 0.6, + eta_min=1, + begin=epoch_num * 0.4, + end=epoch_num * 1, + convert_to_iter_based=True) +] +# max_norm=35 is slightly better than 10 for PointPillars in the earlier +# development of the codebase thus we keep the setting. But we does not +# specifically tune this parameter. +# PointPillars usually need longer schedule than second, we simply double +# the training schedule. Do remind that since we use RepeatDataset and +# repeat factor is 2, so we actually train 160 epochs. +train_cfg = dict(by_epoch=True, max_epochs=epoch_num, val_interval=2) +val_cfg = dict() +test_cfg = dict() diff --git a/tests/test_codebase/test_mmdet3d/data/pointpillars_hv_secfpn_kitti.py b/tests/test_codebase/test_mmdet3d/data/pointpillars_hv_secfpn_kitti.py new file mode 100644 index 000000000..82bb467f9 --- /dev/null +++ b/tests/test_codebase/test_mmdet3d/data/pointpillars_hv_secfpn_kitti.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +voxel_size = [0.16, 0.16, 4] + +model = dict( + type='VoxelNet', + data_preprocessor=dict( + type='Det3DDataPreprocessor', + voxel=True, + voxel_layer=dict( + max_num_points=32, # max_points_per_voxel + point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1], + voxel_size=voxel_size, + max_voxels=(16000, 40000))), + voxel_encoder=dict( + type='PillarFeatureNet', + in_channels=4, + feat_channels=[64], + with_distance=False, + voxel_size=voxel_size, + point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1]), + middle_encoder=dict( + type='PointPillarsScatter', in_channels=64, output_shape=[496, 432]), + backbone=dict( + type='SECOND', + in_channels=64, + layer_nums=[3, 5, 5], + layer_strides=[2, 2, 2], + out_channels=[64, 128, 256]), + neck=dict( + type='SECONDFPN', + in_channels=[64, 128, 256], + upsample_strides=[1, 2, 4], + out_channels=[128, 128, 128]), + bbox_head=dict( + type='Anchor3DHead', + num_classes=3, + in_channels=384, + feat_channels=384, + use_direction_classifier=True, + assign_per_class=True, + anchor_generator=dict( + type='AlignedAnchor3DRangeGenerator', + ranges=[ + [0, -39.68, -0.6, 69.12, 39.68, -0.6], + [0, -39.68, -0.6, 69.12, 39.68, -0.6], + [0, -39.68, -1.78, 69.12, 39.68, -1.78], + ], + sizes=[[0.8, 0.6, 1.73], [1.76, 0.6, 1.73], [3.9, 1.6, 1.56]], + rotations=[0, 1.57], + reshape_out=False), + diff_rad_by_sin=True, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict( + type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), + loss_dir=dict( + type='mmdet.CrossEntropyLoss', use_sigmoid=False, + loss_weight=0.2)), + # model training and testing settings + train_cfg=dict( + assigner=[ + dict( # for Pedestrian + type='Max3DIoUAssigner', + iou_calculator=dict(type='mmdet3d.BboxOverlapsNearest3D'), + pos_iou_thr=0.5, + neg_iou_thr=0.35, + min_pos_iou=0.35, + ignore_iof_thr=-1), + dict( # for Cyclist + type='Max3DIoUAssigner', + iou_calculator=dict(type='mmdet3d.BboxOverlapsNearest3D'), + pos_iou_thr=0.5, + neg_iou_thr=0.35, + min_pos_iou=0.35, + ignore_iof_thr=-1), + dict( # for Car + type='Max3DIoUAssigner', + iou_calculator=dict(type='mmdet3d.BboxOverlapsNearest3D'), + pos_iou_thr=0.6, + neg_iou_thr=0.45, + min_pos_iou=0.45, + ignore_iof_thr=-1), + ], + allowed_border=0, + pos_weight=-1, + debug=False), + test_cfg=dict( + use_rotate_nms=True, + nms_across_levels=False, + nms_thr=0.01, + score_thr=0.1, + min_bbox_size=0, + nms_pre=100, + max_num=50)) diff --git a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py index a30c7ba2f..669df8da0 100644 --- a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py +++ b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -import mmcv +import mmengine import numpy as np import pytest import torch +from mmdeploy.apis import build_task_processor from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase, Task, load_config from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs @@ -46,7 +47,7 @@ def test_pillar_encoder(backend_type: Backend): model = get_pillar_encoder() model.cpu().eval() - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict( @@ -58,7 +59,7 @@ def test_pillar_encoder(backend_type: Backend): features = torch.rand(3945, 32, 4) * 100 num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32) coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32) - model_outputs = [model.forward(features, num_points, coors)] + model_outputs = model.forward(features, num_points, coors) wrapped_model = WrapModel(model, 'forward') rewrite_inputs = { 'features': features, @@ -84,7 +85,7 @@ def test_pointpillars_scatter(backend_type: Backend): model = get_pointpillars_scatter() model.cpu().eval() - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict( @@ -95,7 +96,7 @@ def test_pointpillars_scatter(backend_type: Backend): type=Codebase.MMDET3D.value, task=Task.VOXEL_DETECTION.value))) voxel_features = torch.rand(16 * 16, 64) * 100 coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32) - model_outputs = [model.forward_batch(voxel_features, coors, 1)] + model_outputs = model.forward_batch(voxel_features, coors, 1) wrapped_model = WrapModel(model, 'forward_batch') rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors} rewrite_outputs, is_backend_output = get_rewrite_outputs( @@ -129,13 +130,13 @@ def get_centerpoint_head(): @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) -def test_centerpoint(backend_type: Backend): - from mmdeploy.codebase.mmdet3d.deploy.voxel_detection import VoxelDetection +def test_pointpillars(backend_type: Backend): from mmdeploy.core import RewriterContext check_backend(backend_type, True) - model = get_centerpoint() - model.cpu().eval() - deploy_cfg = mmcv.Config( + + model_cfg = load_config( + 'tests/test_codebase/test_mmdet3d/data/model_cfg.py')[0] + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict( @@ -145,19 +146,22 @@ def test_centerpoint(backend_type: Backend): output_names=['outputs']), codebase_config=dict( type=Codebase.MMDET3D.value, task=Task.VOXEL_DETECTION.value))) - voxeldetection = VoxelDetection(model_cfg, deploy_cfg, 'cpu') - inputs, data = voxeldetection.create_input( - 'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin') + + task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu') + model = task_processor.build_pytorch_model(None) + model.eval() + + preproc = task_processor.build_data_preprocessor() + _, data = task_processor.create_input( + pcd='tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin', + data_preprocessor=preproc) with RewriterContext( cfg=deploy_cfg, backend=deploy_cfg.backend_config.type, opset=deploy_cfg.onnx_config.opset_version): - outputs = model.forward(*data) - head = get_centerpoint_head() - rewrite_outputs = head.get_bboxes(*[[i] for i in outputs], - inputs['img_metas'][0]) - assert rewrite_outputs is not None + outputs = model.forward(data) + assert len(outputs) == 3 def get_pointpillars_nus(): @@ -169,13 +173,15 @@ def get_pointpillars_nus(): @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) -def test_pointpillars_nus(backend_type: Backend): - from mmdeploy.codebase.mmdet3d.deploy.voxel_detection import VoxelDetection +def test_centerpoint(backend_type: Backend): from mmdeploy.core import RewriterContext check_backend(backend_type, True) - model = get_pointpillars_nus() - model.cpu().eval() - deploy_cfg = mmcv.Config( + + centerpoint_model_cfg = load_config( + 'tests/test_codebase/test_mmdet3d/data/centerpoint_pillar02_second_secfpn_head-circlenms_8xb4-cyclic-20e_nus-3d.py' # noqa: E501 + )[0] + + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict( @@ -185,13 +191,21 @@ def test_pointpillars_nus(backend_type: Backend): output_names=['outputs']), codebase_config=dict( type=Codebase.MMDET3D.value, task=Task.VOXEL_DETECTION.value))) - voxeldetection = VoxelDetection(model_cfg, deploy_cfg, 'cpu') - inputs, data = voxeldetection.create_input( - 'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin') + + task_processor = build_task_processor(centerpoint_model_cfg, deploy_cfg, + 'cpu') + model = task_processor.build_pytorch_model(None) + model.eval() + + preproc = task_processor.build_data_preprocessor() + _, data = task_processor.create_input( + pcd= # noqa: E251 + 'tests/test_codebase/test_mmdet3d/data/nuscenes/n008-2018-09-18-12-07-26-0400__LIDAR_TOP__1537287083900561.pcd.bin', # noqa: E501 + data_preprocessor=preproc) with RewriterContext( cfg=deploy_cfg, backend=deploy_cfg.backend_config.type, opset=deploy_cfg.onnx_config.opset_version): - outputs = model.forward(*data) + outputs = model.forward(data) assert outputs is not None diff --git a/tests/test_codebase/test_mmdet3d/test_voxel_detection.py b/tests/test_codebase/test_mmdet3d/test_voxel_detection.py index b43f8b45c..5d60e81aa 100644 --- a/tests/test_codebase/test_mmdet3d/test_voxel_detection.py +++ b/tests/test_codebase/test_mmdet3d/test_voxel_detection.py @@ -1,12 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os from tempfile import NamedTemporaryFile, TemporaryDirectory -import mmcv +import mmengine import pytest import torch -from torch.utils.data import DataLoader -from torch.utils.data.dataset import Dataset import mmdeploy.backend.onnxruntime as ort_apis from mmdeploy.apis import build_task_processor @@ -22,8 +19,9 @@ except ImportError: model_cfg_path = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py' pcd_path = 'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin' + model_cfg = load_config(model_cfg_path)[0] -deploy_cfg = mmcv.Config( +deploy_cfg = mmengine.Config( dict( backend_config=dict(type='onnxruntime'), codebase_config=dict(type='mmdet3d', task='VoxelDetection'), @@ -34,7 +32,7 @@ deploy_cfg = mmcv.Config( opset_version=11, input_shape=None, input_names=['voxels', 'num_points', 'coors'], - output_names=['scores', 'bbox_preds', 'dir_scores']))) + output_names=['cls_score', 'bbox_pred', 'dir_cls_pred']))) onnx_file = NamedTemporaryFile(suffix='.onnx').name task_processor = None @@ -58,9 +56,9 @@ def backend_model(): wrapper = SwitchBackendWrapper(ORTWrapper) wrapper.set( outputs={ - 'scores': torch.rand(1, 18, 32, 32), - 'bbox_preds': torch.rand(1, 42, 32, 32), - 'dir_scores': torch.rand(1, 12, 32, 32) + 'cls_score': torch.rand(1, 18, 32, 32), + 'bbox_pred': torch.rand(1, 42, 32, 32), + 'dir_cls_pred': torch.rand(1, 12, 32, 32) }) yield task_processor.build_backend_model(['']) @@ -85,74 +83,15 @@ def test_create_input(device): task_processor.device = original_device -@pytest.mark.skipif( - reason='Only support GPU test', condition=not torch.cuda.is_available()) -def test_run_inference(backend_model): - task_processor.device = 'cuda:0' - torch_model = task_processor.build_pytorch_model(None) - input_dict, _ = task_processor.create_input(pcd_path) - torch_results = task_processor.run_inference(torch_model, input_dict) - backend_results = task_processor.run_inference(backend_model, input_dict) - assert torch_results is not None - assert backend_results is not None - assert len(torch_results[0]) == len(backend_results[0]) - task_processor.device = 'cpu' - - -@pytest.mark.skipif( - reason='Only support GPU test', condition=not torch.cuda.is_available()) -def test_visualize(): - task_processor.device = 'cuda:0' - input_dict, _ = task_processor.create_input(pcd_path) - torch_model = task_processor.build_pytorch_model(None) - results = task_processor.run_inference(torch_model, input_dict) - with TemporaryDirectory() as dir: - filename = dir + 'tmp.bin' - task_processor.visualize(torch_model, pcd_path, results[0], filename, - 'test', False) - assert os.path.exists(filename) - task_processor.device = 'cpu' - - -def test_build_dataset_and_dataloader(): - dataset = task_processor.build_dataset( - dataset_cfg=model_cfg, dataset_type='test') - assert isinstance(dataset, Dataset), 'Failed to build dataset' - dataloader = task_processor.build_dataloader(dataset, 1, 1) - assert isinstance(dataloader, DataLoader), 'Failed to build dataloader' - - @pytest.mark.skipif( reason='Only support GPU test', condition=not torch.cuda.is_available()) def test_single_gpu_test_and_evaluate(): - from mmcv.parallel import MMDataParallel task_processor.device = 'cuda:0' - class DummyDataset(Dataset): - - def __getitem__(self, index): - return 0 - - def __len__(self): - return 0 - - def evaluate(self, *args, **kwargs): - return 0 - - def format_results(self, *args, **kwargs): - return 0 - - dataset = DummyDataset() - # Prepare dataloader - dataloader = DataLoader(dataset) - # Prepare dummy model model = DummyModel(outputs=[torch.rand([1, 10, 5]), torch.rand([1, 10])]) - model = MMDataParallel(model, device_ids=[0]) + + assert model is not None # Run test - outputs = task_processor.single_gpu_test(model, dataloader) - assert isinstance(outputs, list) - output_file = NamedTemporaryFile(suffix='.pkl').name - task_processor.evaluate_outputs( - model_cfg, outputs, dataset, 'bbox', out=output_file, format_only=True) - task_processor.device = 'cpu' + with TemporaryDirectory() as dir: + task_processor.build_test_runner(model, dir) diff --git a/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py b/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py index 5946f7b76..99af5ee84 100644 --- a/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py +++ b/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py @@ -1,7 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -import os.path as osp - -import mmcv +import mmengine import pytest import torch @@ -15,7 +13,8 @@ try: except ImportError: pytest.skip( f'{Codebase.MMDET3D} is not installed.', allow_module_level=True) -from mmdeploy.codebase.mmdet3d.deploy.voxel_detection import VoxelDetection +from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \ + VoxelDetectionModel pcd_path = 'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin' model_cfg = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py' @@ -33,27 +32,25 @@ class TestVoxelDetectionModel: # simplify backend inference cls.wrapper = SwitchBackendWrapper(ORTWrapper) cls.outputs = { - 'scores': torch.rand(1, 18, 32, 32), - 'bbox_preds': torch.rand(1, 42, 32, 32), - 'dir_scores': torch.rand(1, 12, 32, 32) + 'cls_score': torch.rand(1, 18, 32, 32), + 'bbox_pred': torch.rand(1, 42, 32, 32), + 'dir_cls_pred': torch.rand(1, 12, 32, 32) } cls.wrapper.set(outputs=cls.outputs) - deploy_cfg = mmcv.Config({ + deploy_cfg = mmengine.Config({ 'onnx_config': { 'input_names': ['voxels', 'num_points', 'coors'], - 'output_names': ['scores', 'bbox_preds', 'dir_scores'], + 'output_names': ['cls_score', 'bbox_pred', 'dir_cls_pred'], 'opset_version': 11 }, 'backend_config': { - 'type': 'tensorrt' + 'type': 'onnxruntime' } }) from mmdeploy.utils import load_config model_cfg_path = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py' model_cfg = load_config(model_cfg_path)[0] - from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \ - VoxelDetectionModel cls.end2end_model = VoxelDetectionModel( Backend.ONNXRUNTIME, [''], device='cuda', @@ -64,14 +61,15 @@ class TestVoxelDetectionModel: reason='Only support GPU test', condition=not torch.cuda.is_available()) def test_forward_and_show_result(self): - data = VoxelDetection.read_pcd_file(pcd_path, model_cfg, 'cuda') - results = self.end2end_model.forward(data['points'], data['img_metas']) + inputs = { + 'voxels': { + 'voxels': torch.rand((3945, 32, 4)), + 'num_points': torch.ones((3945), dtype=torch.int32), + 'coors': torch.ones((3945, 4), dtype=torch.int32) + } + } + results = self.end2end_model.forward(inputs=inputs) assert results is not None - from tempfile import TemporaryDirectory - with TemporaryDirectory() as dir: - self.end2end_model.show_result( - data, results, dir, 'backend_output.bin', show=False) - assert osp.exists(dir + '/backend_output.bin') @backend_checker(Backend.ONNXRUNTIME) @@ -79,11 +77,11 @@ def test_build_voxel_detection_model(): from mmdeploy.utils import load_config model_cfg_path = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py' model_cfg = load_config(model_cfg_path)[0] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=Backend.ONNXRUNTIME.value), onnx_config=dict( - output_names=['scores', 'bbox_preds', 'dir_scores']), + output_names=['cls_score', 'bbox_pred', 'dir_cls_pred']), codebase_config=dict(type=Codebase.MMDET3D.value))) from mmdeploy.backend.onnxruntime import ORTWrapper diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py index 3afb681a5..d774b7510 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -import mmcv +import mmengine import numpy as np import pytest import torch @@ -20,7 +20,7 @@ except ImportError: @backend_checker(Backend.ONNXRUNTIME) def test_multiclass_nms_rotated(): from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -72,7 +72,7 @@ def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k): from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated keep_top_k = 15 - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict( output_names=None, @@ -140,7 +140,7 @@ def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend, max_shape: tuple, proj_xy: bool, edge_swap: bool): check_backend(backend_type) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict(type=backend_type.value, model_inputs=None), @@ -189,7 +189,7 @@ def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend, @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend): check_backend(backend_type) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict(type=backend_type.value, model_inputs=None), @@ -227,7 +227,7 @@ def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend): @backend_checker(Backend.ONNXRUNTIME) def test_fake_multiclass_nms_rotated(): from mmdeploy.codebase.mmrotate.core import fake_multiclass_nms_rotated - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -277,7 +277,7 @@ def test_fake_multiclass_nms_rotated(): def test_poly2obb_le90(backend_type: Backend): check_backend(backend_type) polys = torch.rand(1, 10, 8) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -316,7 +316,7 @@ def test_poly2obb_le90(backend_type: Backend): def test_poly2obb_le135(backend_type: Backend): check_backend(backend_type) polys = torch.rand(1, 10, 8) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -351,7 +351,7 @@ def test_poly2obb_le135(backend_type: Backend): def test_obb2poly_le135(backend_type: Backend): check_backend(backend_type) rboxes = torch.rand(1, 10, 5) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -386,7 +386,7 @@ def test_obb2poly_le135(backend_type: Backend): def test_gvfixcoder__decode(backend_type: Backend): check_backend(backend_type) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( onnx_config=dict(output_names=['output'], input_shape=None), backend_config=dict(type=backend_type.value), diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py index 491832655..d3534e8d9 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py @@ -5,6 +5,7 @@ import random from typing import Dict, List import mmcv +import mmengine import numpy as np import pytest import torch @@ -49,7 +50,7 @@ def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List: def get_anchor_head_model(): """AnchorHead Config.""" - test_cfg = mmcv.Config( + test_cfg = mmengine.Config( dict( nms_pre=2000, min_bbox_size=0, @@ -80,7 +81,7 @@ def _replace_r50_with_r18(model): ['tests/test_codebase/test_mmrotate/data/single_stage_model.json']) def test_forward_of_base_detector(model_cfg_path, backend): check_backend(backend) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend.value), onnx_config=dict( @@ -95,7 +96,7 @@ def test_forward_of_base_detector(model_cfg_path, backend): keep_top_k=100, )))) - model_cfg = mmcv.Config(dict(model=mmcv.load(model_cfg_path))) + model_cfg = mmengine.Config(dict(model=mmcv.load(model_cfg_path))) model_cfg.model = _replace_r50_with_r18(model_cfg.model) from mmrotate.models import build_detector @@ -117,7 +118,7 @@ def test_forward_of_base_detector(model_cfg_path, backend): def get_deploy_cfg(backend_type: Backend, ir_type: str): - return mmcv.Config( + return mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict( @@ -221,7 +222,7 @@ def test_rotated_single_roi_extractor(backend_type: Backend): single_roi_extractor = get_single_roi_extractor() output_names = ['roi_feat'] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -264,7 +265,7 @@ def test_rotated_single_roi_extractor(backend_type: Backend): def get_oriented_rpn_head_model(): """Oriented RPN Head Config.""" - test_cfg = mmcv.Config( + test_cfg = mmengine.Config( dict( nms_pre=2000, min_bbox_size=0, @@ -295,7 +296,7 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend): }] output_names = ['dets', 'labels'] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -336,7 +337,7 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend): def get_rotated_rpn_head_model(): """Oriented RPN Head Config.""" - test_cfg = mmcv.Config( + test_cfg = mmengine.Config( dict( nms_pre=2000, min_bbox_size=0, @@ -376,7 +377,7 @@ def test_get_bboxes_of_rotated_rpn_head(backend_type: Backend): }] output_names = ['dets', 'labels'] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -420,7 +421,7 @@ def test_rotate_standard_roi_head__simple_test(backend_type: Backend): check_backend(backend_type) from mmrotate.models.roi_heads import OrientedStandardRoIHead output_names = ['dets', 'labels'] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -433,7 +434,7 @@ def test_rotate_standard_roi_head__simple_test(backend_type: Backend): pre_top_k=2000, keep_top_k=2000)))) angle_version = 'le90' - test_cfg = mmcv.Config( + test_cfg = mmengine.Config( dict( nms_pre=2000, min_bbox_size=0, @@ -488,7 +489,7 @@ def test_gv_ratio_roi_head__simple_test(backend_type: Backend): check_backend(backend_type) from mmrotate.models.roi_heads import GVRatioRoIHead output_names = ['dets', 'labels'] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -502,7 +503,7 @@ def test_gv_ratio_roi_head__simple_test(backend_type: Backend): keep_top_k=2000, max_output_boxes_per_class=1000)))) angle_version = 'le90' - test_cfg = mmcv.Config( + test_cfg = mmengine.Config( dict( nms_pre=2000, min_bbox_size=0, @@ -615,7 +616,7 @@ def get_roi_trans_roi_head_model(): type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) ] - test_cfg = mmcv.Config( + test_cfg = mmengine.Config( dict( nms_pre=2000, min_bbox_size=0, @@ -659,7 +660,7 @@ def test_simple_test_of_roi_trans_roi_head(backend_type: Backend): } output_names = ['det_bboxes', 'det_labels'] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), diff --git a/tests/test_codebase/test_mmrotate/test_rotated_detection.py b/tests/test_codebase/test_mmrotate/test_rotated_detection.py index e20b3f3f5..df2677d80 100644 --- a/tests/test_codebase/test_mmrotate/test_rotated_detection.py +++ b/tests/test_codebase/test_mmrotate/test_rotated_detection.py @@ -2,7 +2,7 @@ import os from tempfile import NamedTemporaryFile, TemporaryDirectory -import mmcv +import mmengine import numpy as np import pytest import torch @@ -23,7 +23,7 @@ except ImportError: model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' model_cfg = load_config(model_cfg_path)[0] -deploy_cfg = mmcv.Config( +deploy_cfg = mmengine.Config( dict( backend_config=dict(type='onnxruntime'), codebase_config=dict( @@ -123,7 +123,6 @@ def test_build_dataset_and_dataloader(): def test_single_gpu_test_and_evaluate(): - from mmcv.parallel import MMDataParallel class DummyDataset(Dataset): @@ -145,7 +144,6 @@ def test_single_gpu_test_and_evaluate(): # Prepare dummy model model = DummyModel(outputs=[torch.rand([1, 10, 6]), torch.rand([1, 10])]) - model = MMDataParallel(model, device_ids=[0]) # Run test outputs = task_processor.single_gpu_test(model, dataloader) assert isinstance(outputs, list) diff --git a/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py b/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py index d13617488..48b2558e6 100644 --- a/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py +++ b/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py @@ -2,7 +2,7 @@ import os.path as osp from tempfile import NamedTemporaryFile -import mmcv +import mmengine import numpy as np import pytest import torch @@ -37,7 +37,7 @@ class TestEnd2EndModel: 'labels': torch.rand(1, 10) } cls.wrapper.set(outputs=cls.outputs) - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( {'onnx_config': { 'output_names': ['dets', 'labels'] }}) @@ -90,7 +90,7 @@ class TestEnd2EndModel: def test_build_rotated_detection_model(): model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' model_cfg = load_config(model_cfg_path)[0] - deploy_cfg = mmcv.Config( + deploy_cfg = mmengine.Config( dict( backend_config=dict(type='onnxruntime'), onnx_config=dict(output_names=['dets', 'labels']), From c5edb8555015ee6b24de8d907dcd32cb2499586d Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Mon, 7 Nov 2022 10:19:22 +0800 Subject: [PATCH 2/4] Sync rv1126 to dev-1.x by cherry-pick (#1295) * remove imports (#1207) * remove imports * update doc * detailed docstring * rephrase * Add model conversion support to rv1126 (#1203) * WIP * fix interpolate * support yolov3 and retinanet * support seg * support ssd * supports both partition types for retinanet and ssd * mean std doc * update doc, add UT * support FSAF * rename configs * update dump info * update * python package installation doc * update doc * update doc * doc * fix * docstring * remote partition config --- configs/_base_/backends/rknn.py | 8 +- ... => classification_rknn_static-224x224.py} | 0 .../detection_rknn_static-320x320.py | 29 ++++ .../mmdet/detection/detection_rknn_static.py | 17 --- ...py => segmentation_rknn_static-320x320.py} | 4 +- docs/en/01-how-to-build/rockchip.md | 58 +++++++- docs/en/03-benchmark/supported_models.md | 6 +- docs/en/05-supported-backends/rknn.md | 6 +- docs/zh_cn/01-how-to-build/rockchip.md | 52 ++++++- docs/zh_cn/03-benchmark/supported_models.md | 6 +- docs/zh_cn/05-supported-backends/rknn.md | 6 +- mmdeploy/apis/onnx/partition.py | 6 +- mmdeploy/backend/sdk/export_info.py | 7 +- .../codebase/mmdet/deploy/object_detection.py | 13 +- .../mmdet/deploy/object_detection_model.py | 32 ++++- .../models/dense_heads/base_dense_head.py | 128 +++++++++++++++++- .../single_stage_text_detector.py | 2 +- .../models/text_recognition/base_decoder.py | 2 +- .../encoder_decoder_recognizer.py | 9 +- mmdeploy/core/optimizers/__init__.py | 4 +- mmdeploy/core/optimizers/optimize.py | 21 +++ mmdeploy/pytorch/functions/interpolate.py | 32 +++++ .../test_mmdet/test_mmdet_models.py | 74 ++++++++++ tests/test_pytorch/test_pytorch_functions.py | 24 ++++ 24 files changed, 482 insertions(+), 64 deletions(-) rename configs/mmcls/{classification_rknn_static.py => classification_rknn_static-224x224.py} (100%) create mode 100644 configs/mmdet/detection/detection_rknn_static-320x320.py delete mode 100644 configs/mmdet/detection/detection_rknn_static.py rename configs/mmseg/{segmentation_rknn_static.py => segmentation_rknn_static-320x320.py} (53%) diff --git a/configs/_base_/backends/rknn.py b/configs/_base_/backends/rknn.py index 3dcbbce1c..640a98265 100644 --- a/configs/_base_/backends/rknn.py +++ b/configs/_base_/backends/rknn.py @@ -1,8 +1,8 @@ backend_config = dict( type='rknn', common_config=dict( - mean_values=None, - std_values=None, - target_platform='rk3588', - optimization_level=3), + mean_values=None, # [[103.53, 116.28, 123.675]], + std_values=None, # [[57.375, 57.12, 58.395]], + target_platform='rv1126', # 'rk3588' + optimization_level=1), quantization_config=dict(do_quantization=False, dataset=None)) diff --git a/configs/mmcls/classification_rknn_static.py b/configs/mmcls/classification_rknn_static-224x224.py similarity index 100% rename from configs/mmcls/classification_rknn_static.py rename to configs/mmcls/classification_rknn_static-224x224.py diff --git a/configs/mmdet/detection/detection_rknn_static-320x320.py b/configs/mmdet/detection/detection_rknn_static-320x320.py new file mode 100644 index 000000000..63f221b68 --- /dev/null +++ b/configs/mmdet/detection/detection_rknn_static-320x320.py @@ -0,0 +1,29 @@ +_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py'] + +onnx_config = dict(input_shape=[320, 320]) + +codebase_config = dict(model_type='rknn') + +backend_config = dict(input_size_list=[[3, 320, 320]]) + +# yolov3, yolox +# partition_config = dict( +# type='rknn', # the partition policy name +# apply_marks=True, # should always be set to True +# partition_cfg=[ +# dict( +# save_file='model.onnx', # name to save the partitioned onnx +# start=['detector_forward:input'], # [mark_name:input, ...] +# end=['yolo_head:input']) # [mark_name:output, ...] +# ]) + +# # retinanet, ssd, fsaf +# partition_config = dict( +# type='rknn', # the partition policy name +# apply_marks=True, +# partition_cfg=[ +# dict( +# save_file='model.onnx', +# start='detector_forward:input', +# end=['BaseDenseHead:output']) +# ]) diff --git a/configs/mmdet/detection/detection_rknn_static.py b/configs/mmdet/detection/detection_rknn_static.py deleted file mode 100644 index 4e543ea49..000000000 --- a/configs/mmdet/detection/detection_rknn_static.py +++ /dev/null @@ -1,17 +0,0 @@ -_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py'] - -onnx_config = dict(input_shape=[640, 640]) - -codebase_config = dict(model_type='rknn') - -backend_config = dict(input_size_list=[[3, 640, 640]]) - -partition_config = dict( - type='rknn', # the partition policy name - apply_marks=True, # should always be set to True - partition_cfg=[ - dict( - save_file='model.onnx', # name to save the partitioned onnx model - start=['detector_forward:input'], # [mark_name:input/output, ...] - end=['yolo_head:input']) # [mark_name:input/output, ...] - ]) diff --git a/configs/mmseg/segmentation_rknn_static.py b/configs/mmseg/segmentation_rknn_static-320x320.py similarity index 53% rename from configs/mmseg/segmentation_rknn_static.py rename to configs/mmseg/segmentation_rknn_static-320x320.py index cd99fb614..2bb908234 100644 --- a/configs/mmseg/segmentation_rknn_static.py +++ b/configs/mmseg/segmentation_rknn_static-320x320.py @@ -1,7 +1,7 @@ _base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py'] -onnx_config = dict(input_shape=[512, 512]) +onnx_config = dict(input_shape=[320, 320]) codebase_config = dict(model_type='rknn') -backend_config = dict(input_size_list=[[3, 512, 512]]) +backend_config = dict(input_size_list=[[3, 320, 320]]) diff --git a/docs/en/01-how-to-build/rockchip.md b/docs/en/01-how-to-build/rockchip.md index d099914ba..6c22a2fbd 100644 --- a/docs/en/01-how-to-build/rockchip.md +++ b/docs/en/01-how-to-build/rockchip.md @@ -1,18 +1,26 @@ # Build for RKNN -This tutorial is based on Linux systems like Ubuntu-18.04 and Rockchip NPU like `rk3588`. +This tutorial is based on Ubuntu-18.04 and Rockchip NPU `rk3588`. For different NPU devices, you may have to use different rknn packages. +Below is a table describing the relationship: + +| Device | Python Package | c/c++ SDK | +| -------------------- | ---------------------------------------------------------------- | -------------------------------------------------- | +| RK1808/RK1806 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) | +| RV1109/RV1126 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) | +| RK3566/RK3568/RK3588 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) | +| RV1103/RV1106 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) | ## Installation It is recommended to create a virtual environment for the project. -1. get RKNN-Toolkit2 through: +1. Get RKNN-Toolkit2 or RKNN-Toolkit through git. RKNN-Toolkit2 for example: ``` git clone git@github.com:rockchip-linux/rknn-toolkit2.git ``` -2. install RKNN python package following [official doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc). In our testing, we used the rknn-toolkit2 1.2.0 with commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`. When installing rknn-toolkit2, it is better to append `--no-deps` after the commands to avoid dependency conflicts. For example: +2. Install RKNN python package following [rknn-toolkit2 doc](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc) or [rknn-toolkit doc](https://github.com/rockchip-linux/rknn-toolkit/tree/master/doc). When installing rknn python package, it is better to append `--no-deps` after the commands to avoid dependency conflicts. RKNN-Toolkit2 package for example: ``` pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps @@ -67,17 +75,19 @@ backend_config = dict( ``` -The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`. +The contents of `common_config` are for `rknn.config()`. The contents of `quantization_config` are used to control `rknn.build()`. You may have to modify `target_platform` for your own preference. ## Build SDK with Rockchip NPU -1. get rknpu2 through: +### Build SDK with RKNPU2 + +1. Get rknpu2 through git: ``` git clone git@github.com:rockchip-linux/rknpu2.git ``` -2. for linux, download gcc cross compiler. The download link of the compiler from the official user guide of `rknpu2` was deprecated. You may use another verified [link](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). After download and unzip the compiler, you may open the terminal, set `RKNN_TOOL_CHAIN` and `RKNPU2_DEVICE_DIR` by `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`. +2. For linux, download gcc cross compiler. The download link of the compiler from the official user guide of `rknpu2` was deprecated. You may use another verified [link](https://github.com/Caesar-github/gcc-buildroot-9.3.0-2020.03-x86_64_aarch64-rockchip-linux-gnu). After download and unzip the compiler, you may open the terminal, set `RKNN_TOOL_CHAIN` and `RKNPU2_DEVICE_DIR` by `export RKNN_TOOL_CHAIN=/path/to/gcc/usr;export RKNPU2_DEVICE_DIR=/path/to/rknpu2/runtime/RK3588`. 3. after the above preparition, run the following commands: @@ -144,4 +154,38 @@ label: 65, score: 0.95 mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) ``` - Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[123.675, 116.28, 103.53]` and `std_values=[58.395, 57.12, 57.375]`. + Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`. + +- MMDet models. + + YOLOV3 & YOLOX: you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py): + + ```python + # yolov3, yolox + partition_config = dict( + type='rknn', # the partition policy name + apply_marks=True, # should always be set to True + partition_cfg=[ + dict( + save_file='model.onnx', # name to save the partitioned onnx + start=['detector_forward:input'], # [mark_name:input, ...] + end=['yolo_head:input']) # [mark_name:output, ...] + ]) + ``` + + RetinaNet & SSD & FSAF with rknn-toolkit2, you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py). Users with rknn-toolkit can directly use default config. + + ```python + # retinanet, ssd + partition_config = dict( + type='rknn', # the partition policy name + apply_marks=True, + partition_cfg=[ + dict( + save_file='model.onnx', + start='detector_forward:input', + end=['BaseDenseHead:output']) + ]) + ``` + +- SDK only supports int8 rknn model, which require `do_quantization=True` when converting models. diff --git a/docs/en/03-benchmark/supported_models.md b/docs/en/03-benchmark/supported_models.md index 9c42ef271..51ac08065 100644 --- a/docs/en/03-benchmark/supported_models.md +++ b/docs/en/03-benchmark/supported_models.md @@ -4,14 +4,14 @@ The table below lists the models that are guaranteed to be exportable to other b | Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Ascend | RKNN | Model config | | :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :----: | :--: | :---------------------------------------------------------------------------------------------: | -| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | +| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | | Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | | YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) | | YOLOX | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) | | FCOS | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) | -| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | +| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | | Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | -| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | +| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | | FoveaBox | MMDetection | Y | Y | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | | ATSS | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | | GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | diff --git a/docs/en/05-supported-backends/rknn.md b/docs/en/05-supported-backends/rknn.md index 4c6c50f9e..cc4b8cbe9 100644 --- a/docs/en/05-supported-backends/rknn.md +++ b/docs/en/05-supported-backends/rknn.md @@ -1,9 +1,9 @@ # Supported RKNN feature -Currently, MMDeploy only tests rk3588 with linux platform. +Currently, MMDeploy only tests rk3588 and rv1126 with linux platform. The following features cannot be automatically enabled by mmdeploy and you need to manually modify the configuration in MMDeploy like [here](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py). -- target_platform other than `3588` +- target_platform other than default - quantization settings -- optimization level other than 3 +- optimization level other than 1 diff --git a/docs/zh_cn/01-how-to-build/rockchip.md b/docs/zh_cn/01-how-to-build/rockchip.md index 0161a972a..7fd36efe2 100644 --- a/docs/zh_cn/01-how-to-build/rockchip.md +++ b/docs/zh_cn/01-how-to-build/rockchip.md @@ -1,18 +1,26 @@ # 支持 RKNN -本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。 +本教程基于 Ubuntu-18.04 和 Rockchip `rk3588` NPU。对于不同的 NPU 设备,您需要使用不同的 rknn 包. +这是设备和安装包的关系表: + +| Device | Python Package | c/c++ SDK | +| -------------------- | ---------------------------------------------------------------- | -------------------------------------------------- | +| RK1808/RK1806 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) | +| RV1109/RV1126 | [rknn-toolkit](https://github.com/rockchip-linux/rknn-toolkit) | [rknpu](https://github.com/rockchip-linux/rknpu) | +| RK3566/RK3568/RK3588 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) | +| RV1103/RV1106 | [rknn-toolkit2](https://github.com/rockchip-linux/rknn-toolkit2) | [rknpu2](https://github.com/rockchip-linux/rknpu2) | ## 安装 建议为项目创建一个虚拟环境。 -1. 获取 RKNN-Toolkit2: +1. 使用 git 获取 RKNN-Toolkit2 或者 RKNN-Toolkit。以 RKNN-Toolkit2 为例: ``` git clone git@github.com:rockchip-linux/rknn-toolkit2.git ``` -2. 通过 [官方文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc),安装 RKNN python 安装包. 在我们的测试中, 使用的 rknn-toolkit 版本是 1.2.0,commit id `834ba0b0a1ab8ee27024443d77b02b5ba48b67fc`。安装 rknn-toolkit2 时,最好在安装命令后添加`--no-deps`,以避免依赖包的冲突。比如: +2. 通过 [rknn-toolkit2 文档](https://github.com/rockchip-linux/rknn-toolkit2/tree/master/doc) 或者 [rknn-toolkit 文档](https://github.com/rockchip-linux/rknn-toolkit/tree/master/doc)安装 RKNN python 安装包。安装 rknn python 包时,最好在安装命令后添加`--no-deps`,以避免依赖包的冲突。以rknn-toolkit2为例: ``` pip install packages/rknn_toolkit2-1.2.0_f7bb160f-cp36-cp36m-linux_x86_64.whl --no-deps @@ -71,6 +79,8 @@ backend_config = dict( ## 安装 SDK +### RKNPU2 编译 MMDeploy SDK + 1. 获取 rknpu2: ``` @@ -144,4 +154,38 @@ label: 65, score: 0.95 mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) ``` - 此外, deploy_cfg 的 `mean_values` 和 `std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[123.675, 116.28, 103.53]`, `std_values=[58.395, 57.12, 57.375]`。 + 此外, deploy_cfg 的 `mean_values` 和 `std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[[103.53, 116.28, 123.675]]`, `std_values=[[57.375, 57.12, 58.395]]`。 + +- MMDet 模型. + + YOLOV3 & YOLOX: 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py): + + ```python + # yolov3, yolox + partition_config = dict( + type='rknn', # the partition policy name + apply_marks=True, # should always be set to True + partition_cfg=[ + dict( + save_file='model.onnx', # name to save the partitioned onnx + start=['detector_forward:input'], # [mark_name:input, ...] + end=['yolo_head:input']) # [mark_name:output, ...] + ]) + ``` + + RetinaNet & SSD & FSAF with rknn-toolkit2, 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py)。使用 rknn-toolkit 的用户则不用。 + + ```python + # retinanet, ssd + partition_config = dict( + type='rknn', # the partition policy name + apply_marks=True, + partition_cfg=[ + dict( + save_file='model.onnx', + start='detector_forward:input', + end=['BaseDenseHead:output']) + ]) + ``` + +- SDK 只支持 int8 的 rknn 模型,这需要在转换模型时设置 `do_quantization=True`。 diff --git a/docs/zh_cn/03-benchmark/supported_models.md b/docs/zh_cn/03-benchmark/supported_models.md index 77e327d4c..66a5c9279 100644 --- a/docs/zh_cn/03-benchmark/supported_models.md +++ b/docs/zh_cn/03-benchmark/supported_models.md @@ -4,14 +4,14 @@ | Model | Codebase | TorchScript | OnnxRuntime | TensorRT | ncnn | PPLNN | OpenVINO | Ascend | RKNN | Model config | | :-------------------------- | :--------------- | :---------: | :---------: | :------: | :--: | :---: | :------: | :----: | :--: | :---------------------------------------------------------------------------------------------: | -| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | +| RetinaNet | MMDetection | Y | Y | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/retinanet) | | Faster R-CNN | MMDetection | Y | Y | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | | YOLOv3 | MMDetection | Y | Y | Y | Y | N | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolo) | | YOLOX | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox) | | FCOS | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fcos) | -| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | +| FSAF | MMDetection | Y | Y | Y | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/fsaf) | | Mask R-CNN | MMDetection | Y | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | -| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | +| SSD[\*](#note) | MMDetection | Y | Y | Y | Y | N | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | | FoveaBox | MMDetection | Y | Y | N | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | | ATSS | MMDetection | N | Y | Y | N | N | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | | GFL | MMDetection | N | Y | Y | N | ? | Y | N | N | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | diff --git a/docs/zh_cn/05-supported-backends/rknn.md b/docs/zh_cn/05-supported-backends/rknn.md index baf41e540..66f646906 100644 --- a/docs/zh_cn/05-supported-backends/rknn.md +++ b/docs/zh_cn/05-supported-backends/rknn.md @@ -1,9 +1,9 @@ # 支持的 RKNN 特征 -目前, MMDeploy 只在 rk3588 的 linux 平台上测试过. +目前, MMDeploy 只在 rk3588 和 rv1126 的 linux 平台上测试过. 以下特性需要手动在 MMDeploy 自行配置,如[这里](https://github.com/open-mmlab/mmdeploy/blob/master/configs/_base_/backends/rknn.py). -- target_platform != `3588` +- target_platform != default - quantization settings -- optimization level != 3 +- optimization level != 1 diff --git a/mmdeploy/apis/onnx/partition.py b/mmdeploy/apis/onnx/partition.py index 31e0663db..1360360b1 100644 --- a/mmdeploy/apis/onnx/partition.py +++ b/mmdeploy/apis/onnx/partition.py @@ -8,7 +8,8 @@ import onnx.utils from mmdeploy.apis.core import PIPELINE_MANAGER from mmdeploy.core.optimizers import (attribute_to_dict, create_extractor, get_new_name, parse_extractor_io_string, - remove_identity, rename_value) + remove_identity, remove_imports, + rename_value) from mmdeploy.utils import get_root_logger @@ -198,6 +199,9 @@ def extract_partition(model: Union[str, onnx.ModelProto], dim.dim_value = 0 dim.dim_param = f'dim_{idx}' + # remove mmdeploy domain if useless + remove_imports(extracted_model) + # save extract_model if save_file is given if save_file is not None: onnx.save(extracted_model, save_file) diff --git a/mmdeploy/backend/sdk/export_info.py b/mmdeploy/backend/sdk/export_info.py index 1011d259f..fcf978d78 100644 --- a/mmdeploy/backend/sdk/export_info.py +++ b/mmdeploy/backend/sdk/export_info.py @@ -132,7 +132,7 @@ def get_inference_info(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config, name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device) ir_config = get_ir_config(deploy_cfg) backend = get_backend(deploy_cfg=deploy_cfg) - if backend == Backend.TORCHSCRIPT: + if backend in (Backend.TORCHSCRIPT, Backend.RKNN): output_names = ir_config.get('output_names', None) input_map = dict(img='#0') output_map = {name: f'#{i}' for i, name in enumerate(output_names)} @@ -159,6 +159,11 @@ def get_preprocess(deploy_cfg: mmengine.Config, model_cfg: mmengine.Config, task_processor = build_task_processor( model_cfg=model_cfg, deploy_cfg=deploy_cfg, device=device) transforms = task_processor.get_preprocess() + if get_backend(deploy_cfg) == Backend.RKNN: + del transforms[-2] + for transform in transforms: + if transform['type'] == 'Normalize': + transform['to_float'] = False assert transforms[0]['type'] == 'LoadImageFromFile', 'The first item'\ ' type of pipeline should be LoadImageFromFile' return dict( diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index c431ae0f1..9ccb6fc8f 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -10,8 +10,9 @@ from mmengine.model import BaseDataPreprocessor from mmengine.registry import Registry from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase -from mmdeploy.utils import Codebase, Task -from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape +from mmdeploy.utils import Backend, Codebase, Task +from mmdeploy.utils.config_utils import (get_backend, get_input_shape, + is_dynamic_shape) MMDET_TASK = Registry('mmdet_tasks') @@ -278,6 +279,14 @@ class ObjectDetection(BaseTask): if 'mask_thr_binary' in params['rcnn']: params['mask_thr_binary'] = params['rcnn']['mask_thr_binary'] type = 'ResizeInstanceMask' # for instance-seg + if get_backend(self.deploy_cfg) == Backend.RKNN: + if 'YOLO' in self.model_cfg.model.type: + bbox_head = self.model_cfg.model.bbox_head + type = bbox_head.type + params['anchor_generator'] = bbox_head.get( + 'anchor_generator', None) + else: # default using base_dense_head + type = 'BaseDenseHead' return dict(type=type, params=params) def get_model_name(self, *args, **kwargs) -> str: diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 645cd6974..251f26f0a 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -657,10 +657,11 @@ class RKNNModel(End2EndModel): head_cfg = self.model_cfg._cfg_dict.model.bbox_head head = build_head(head_cfg) if head_cfg.type == 'YOLOXHead': + divisor = round(len(outputs) / 3) ret = head.predict_by_feat( - outputs[:3], - outputs[3:6], - outputs[6:9], + outputs[:divisor], + outputs[divisor:2 * divisor], + outputs[2 * divisor:], metainfos, cfg=self.model_cfg._cfg_dict.model.test_cfg, rescale=True) @@ -670,6 +671,31 @@ class RKNNModel(End2EndModel): metainfos, cfg=self.model_cfg._cfg_dict.model.test_cfg, rescale=True) + elif head_cfg.type in ('RetinaHead', 'SSDHead', 'FSAFHead'): + partition_cfgs = get_partition_config(self.deploy_cfg) + if partition_cfgs is None: # bbox decoding done in rknn model + from mmdet.structures.bbox import scale_boxes + + from ..models.layers.bbox_nms import _multiclass_nms + dets, labels = _multiclass_nms(outputs[0], outputs[1]) + ret = [InstanceData() for i in range(dets.shape[0])] + for i, instance_data in enumerate(ret): + instance_data.bboxes = dets[i, :, :4] + instance_data.scores = dets[i, :, 4] + instance_data.labels = labels[i] + scale_factor = [ + 1 / s for s in metainfos[i]['scale_factor'] + ] + instance_data.bboxes = scale_boxes(instance_data.bboxes, + scale_factor) + return ret + divisor = round(len(outputs) / 2) + ret = head.predict_by_feat( + outputs[:divisor], + outputs[divisor:], + batch_img_metas=metainfos, + rescale=True, + cfg=self.model_cfg._cfg_dict.model.test_cfg) else: raise NotImplementedError(f'{head_cfg.type} not supported yet.') return ret diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index bf5a9a411..84feac5e9 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -14,7 +14,7 @@ from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params, pad_with_value_if_necessary) from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward -from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.utils import Backend, is_dynamic_shape @@ -192,6 +192,132 @@ def base_dense_head__predict_by_feat( @FUNCTION_REWRITER.register_rewriter( func_name='mmdet.models.dense_heads.base_dense_head.' 'BaseDenseHead.predict_by_feat', + backend=Backend.RKNN.value) +def base_dense_head__predict_by_feat__rknn( + ctx, + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigDict] = None, + rescale: bool = False, + with_nms: bool = True, + **kwargs): + """Rewrite `predict_by_feat` of `BaseDenseHead` for default backend. + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + Args: + ctx (ContextCaller): The context with additional information. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. + Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, + batch_mlvl_scores, batch_mlvl_centerness + """ + # mark nodes for partition + @mark('BaseDenseHead', outputs=['BaseDenseHead.cls', 'BaseDenseHead.loc']) + def __mark_dense_head(cls_scores, bbox_preds): + return cls_scores, bbox_preds + + cls_scores, bbox_preds = __mark_dense_head(cls_scores, bbox_preds) + + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + num_levels = len(cls_scores) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device) + mlvl_priors = [priors.unsqueeze(0) for priors in mlvl_priors] + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + if score_factors is None: + with_score_factors = False + mlvl_score_factor = [None for _ in range(num_levels)] + else: + with_score_factors = True + mlvl_score_factor = [ + score_factors[i].detach() for i in range(num_levels) + ] + mlvl_score_factors = [] + assert batch_img_metas is not None + img_shape = batch_img_metas[0]['img_shape'] + + assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors) + batch_size = cls_scores[0].shape[0] + + mlvl_valid_bboxes = [] + mlvl_valid_scores = [] + mlvl_valid_priors = [] + + for cls_score, bbox_pred, score_factors, priors in zip( + mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, + self.cls_out_channels) + if self.use_sigmoid_cls: + scores = scores.sigmoid() + else: + scores = scores.softmax(-1)[:, :, :-1] + if with_score_factors: + score_factors = score_factors.permute(0, 2, 3, + 1).reshape(batch_size, + -1).sigmoid() + score_factors = score_factors.unsqueeze(2) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) + if not is_dynamic_flag: + priors = priors.data + + mlvl_valid_bboxes.append(bbox_pred) + mlvl_valid_scores.append(scores) + mlvl_valid_priors.append(priors) + if with_score_factors: + mlvl_score_factors.append(score_factors) + + batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1) + batch_scores = torch.cat(mlvl_valid_scores, dim=1) + batch_priors = torch.cat(mlvl_valid_priors, dim=1) + batch_bboxes = self.bbox_coder.decode( + batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape) + if with_score_factors: + batch_score_factors = torch.cat(mlvl_score_factors, dim=1) + if not self.use_sigmoid_cls: + batch_scores = batch_scores[..., :self.num_classes] + + if with_score_factors: + batch_scores = batch_scores * batch_score_factors + if isinstance(self, PAAHead): + batch_scores = batch_scores.sqrt() + return batch_bboxes, batch_scores + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead' + '.get_bboxes', backend=Backend.NCNN.value) def base_dense_head__predict_by_feat__ncnn( ctx, diff --git a/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py b/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py index 3e2a5400f..ea72eae89 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/single_stage_text_detector.py @@ -20,7 +20,7 @@ def single_stage_text_detector__forward( Args: batch_inputs (torch.Tensor): Images of shape (N, C, H, W). - batch_data_samples (list[TextDetDataSample]): A list of N + data_samples (list[TextDetDataSample]): A list of N datasamples, containing meta information and gold annotations for each of the images. diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py index 9d5c6e0e1..26adccaec 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/base_decoder.py @@ -16,7 +16,7 @@ def base_decoder__forward( out_enc: Optional[torch.Tensor] = None, data_samples: Optional[Sequence[TextRecogDataSample]] = None ) -> Sequence[TextRecogDataSample]: - """Perform forward propagation of the decoder and postprocessor. + """Rewrite `predict` of `BaseDecoder` to skip post-process. Args: feat (Tensor, optional): Features from the backbone. Defaults diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py b/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py index 2da011470..155ece62c 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/encoder_decoder_recognizer.py @@ -20,13 +20,10 @@ def encoder_decoder_recognizer__forward(ctx, self, batch_inputs: torch.Tensor, ctx (ContextCaller): The context with additional information. self: The instance of the class EncoderDecoderRecognizer. - img (Tensor): Input images of shape (N, C, H, W). + batch_inputs (Tensor): Input images of shape (N, C, H, W). Typically these should be mean centered and std scaled. - img_metas (list[dict]): A list of image info dict where each dict - has: 'img_shape', 'scale_factor', 'flip', and may also contain - 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. - For details on the values of these keys, see - :class:`mmdet.datasets.pipelines.Collect`. + data_samples (TextRecogDataSample): Containing meta information + and gold annotations for each of the images. Defaults to None. Returns: out_dec (Tensor): A feature map output from a decoder. The tensor shape diff --git a/mmdeploy/core/optimizers/__init__.py b/mmdeploy/core/optimizers/__init__.py index 53145fd45..eb06c9aec 100644 --- a/mmdeploy/core/optimizers/__init__.py +++ b/mmdeploy/core/optimizers/__init__.py @@ -2,10 +2,10 @@ from .extractor import create_extractor, parse_extractor_io_string from .function_marker import mark, reset_mark_function_count from .optimize import (attribute_to_dict, get_new_name, remove_identity, - rename_value) + remove_imports, rename_value) __all__ = [ 'mark', 'reset_mark_function_count', 'create_extractor', 'parse_extractor_io_string', 'remove_identity', 'attribute_to_dict', - 'rename_value', 'get_new_name' + 'rename_value', 'get_new_name', 'remove_imports' ] diff --git a/mmdeploy/core/optimizers/optimize.py b/mmdeploy/core/optimizers/optimize.py index 9ad84f6dd..4587987ae 100644 --- a/mmdeploy/core/optimizers/optimize.py +++ b/mmdeploy/core/optimizers/optimize.py @@ -206,3 +206,24 @@ def remove_identity(model: onnx.ModelProto): pass remove_nodes(model, is_identity) + + +def remove_imports(model: onnx.ModelProto): + """Remove useless imports from an ONNX model. + + The domain like `mmdeploy` might influence model conversion for + some backends. + + Args: + model (onnx.ModelProto): Input onnx model. + """ + logger = get_root_logger() + dst_domain = [''] + for node in model.graph.node: + if hasattr(node, 'module') and (node.module not in dst_domain): + dst_domain.append(node.module) + src_domains = [oi.domain for oi in model.opset_import] + for i, src_domain in enumerate(src_domains): + if src_domain not in dst_domain: + logger.info(f'remove opset_import {src_domain}') + model.opset_import.pop(i) diff --git a/mmdeploy/pytorch/functions/interpolate.py b/mmdeploy/pytorch/functions/interpolate.py index 253b59316..a335792f0 100644 --- a/mmdeploy/pytorch/functions/interpolate.py +++ b/mmdeploy/pytorch/functions/interpolate.py @@ -40,6 +40,38 @@ def interpolate__ncnn(ctx, recompute_scale_factor=recompute_scale_factor) +@FUNCTION_REWRITER.register_rewriter( + func_name='torch.nn.functional.interpolate', backend='rknn') +def interpolate__rknn(ctx, + input: torch.Tensor, + size: Optional[Union[int, Tuple[int], Tuple[int, int], + Tuple[int, int, int]]] = None, + scale_factor: Optional[Union[float, + Tuple[float]]] = None, + mode: str = 'nearest', + align_corners: Optional[bool] = None, + recompute_scale_factor: Optional[bool] = None): + """Rewrite `interpolate` for rknn backend. + + rknn require `size` should be constant in ONNX Node. We use `scale_factor` + instead of `size` to avoid dynamic size. + """ + input_size = input.shape + if scale_factor is None: + scale_factor = [(s_out / s_in) + for s_out, s_in in zip(size, input_size[2:])] + if isinstance(scale_factor[0], torch.Tensor): + scale_factor = [i.item() for i in scale_factor] + + return ctx.origin_func( + input, + None, + scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor) + + @FUNCTION_REWRITER.register_rewriter( 'torch.nn.functional.interpolate', is_pytorch=True, diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index b3b94d28a..de03c9c6f 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -1821,6 +1821,80 @@ def test_base_dense_head_predict_by_feat__ncnn(): assert rewrite_outputs.shape[-1] == 6 +@backend_checker(Backend.RKNN) +def test_base_dense_head_get_bboxes__rknn(): + """Test get_bboxes rewrite of ssd head for rknn.""" + ssd_head = get_ssd_head_model() + ssd_head.cpu().eval() + s = 128 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + output_names = ['output'] + input_names = [] + for i in range(6): + input_names.append('cls_scores_' + str(i)) + input_names.append('bbox_preds_' + str(i)) + dynamic_axes = None + deploy_cfg = mmengine.Config( + dict( + backend_config=dict(type=Backend.RKNN.value), + onnx_config=dict( + input_names=input_names, + output_names=output_names, + input_shape=None, + dynamic_axes=dynamic_axes), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='rknn', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )))) + + # For the ssd_head: + # the cls_score's size: (1, 30, 20, 20), (1, 30, 10, 10), + # (1, 30, 5, 5), (1, 30, 3, 3), (1, 30, 2, 2), (1, 30, 1, 1) + # the bboxes's size: (1, 24, 20, 20), (1, 24, 10, 10), + # (1, 24, 5, 5), (1, 24, 3, 3), (1, 24, 2, 2), (1, 24, 1, 1) + feat_shape = [20, 10, 5, 3, 2, 1] + num_prior = 6 + seed_everything(1234) + cls_score = [ + torch.rand(1, 30, feat_shape[i], feat_shape[i]) + for i in range(num_prior) + ] + seed_everything(5678) + bboxes = [ + torch.rand(1, 24, feat_shape[i], feat_shape[i]) + for i in range(num_prior) + ] + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = [s, s] + wrapped_model = WrapModel( + ssd_head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg, + run_with_backend=False) + + # output should be of shape [1, N, 4] + assert rewrite_outputs[0].shape[-1] == 4 + + @pytest.mark.parametrize('backend_type, ir_type', [(Backend.OPENVINO, 'onnx')]) def test_reppoints_head_predict_by_feat(backend_type: Backend, ir_type: str): """Test predict_by_feat rewrite of base dense head.""" diff --git a/tests/test_pytorch/test_pytorch_functions.py b/tests/test_pytorch/test_pytorch_functions.py index 51540f0cd..b63830b22 100644 --- a/tests/test_pytorch/test_pytorch_functions.py +++ b/tests/test_pytorch/test_pytorch_functions.py @@ -119,6 +119,30 @@ def test_interpolate_static(): assert np.allclose(model_output, rewrite_output[0], rtol=1e-03, atol=1e-05) +@backend_checker(Backend.RKNN) +def test_interpolate__rknn(): + input = torch.rand([1, 2, 2, 2]) + model_output = F.interpolate(input, scale_factor=[2, 2]) + + def interpolate_caller(*arg, **kwargs): + return F.interpolate(*arg, **kwargs) + + deploy_cfg = Config( + dict( + onnx_config=dict(input_shape=None), + backend_config=dict(type='rknn', model_inputs=None), + codebase_config=dict(type='mmdet', task='ObjectDetection'))) + + wrapped_func = WrapFunction(interpolate_caller, size=[4, 4]) + rewrite_output, _ = get_rewrite_outputs( + wrapped_func, + model_inputs={'input': input}, + deploy_cfg=deploy_cfg, + run_with_backend=False) + + assert np.allclose(model_output, rewrite_output[0], rtol=1e-03, atol=1e-05) + + @backend_checker(Backend.NCNN) def test_linear_ncnn(): input = torch.rand([1, 2, 2]) From 6420e2044515ff2052960c0f8bb9e351e6a7f2c2 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 8 Nov 2022 10:37:59 +0800 Subject: [PATCH 3/4] [Refactor] Ease rewriter import for dev-1.x (#1170) * Update rewriter import * remove root import * add interface --- mmdeploy/codebase/base/mmcodebase.py | 8 ++- mmdeploy/codebase/mmaction/__init__.py | 1 - mmdeploy/codebase/mmaction/deploy/mmaction.py | 5 ++ mmdeploy/codebase/mmaction/models/__init__.py | 2 +- .../mmaction/models/recognizers/__init__.py | 4 +- mmdeploy/codebase/mmcls/__init__.py | 1 - mmdeploy/codebase/mmcls/models/__init__.py | 8 +-- .../mmcls/models/backbones/__init__.py | 9 +-- .../mmcls/models/classifiers/__init__.py | 4 +- .../codebase/mmcls/models/necks/__init__.py | 5 +- .../codebase/mmcls/models/utils/__init__.py | 10 +--- mmdeploy/codebase/mmdet/__init__.py | 3 - .../mmdet/deploy/object_detection_model.py | 3 +- mmdeploy/codebase/mmdet/models/__init__.py | 16 +++--- .../mmdet/models/dense_heads/__init__.py | 31 +++-------- .../models/dense_heads/base_dense_head.py | 5 +- .../mmdet/models/dense_heads/fovea_head.py | 2 +- .../mmdet/models/dense_heads/gfl_head.py | 5 +- .../models/dense_heads/reppoints_head.py | 5 +- .../mmdet/models/dense_heads/rpn_head.py | 5 +- .../mmdet/models/dense_heads/ssd_head.py | 2 +- .../mmdet/models/dense_heads/yolo_head.py | 4 +- .../mmdet/models/dense_heads/yolox_head.py | 2 +- .../codebase/mmdet/models/layers/__init__.py | 4 +- .../codebase/mmdet/models/layers/bbox_nms.py | 12 ++-- .../mmdet/models/roi_heads/__init__.py | 23 ++------ .../mmdet/models/roi_heads/bbox_head.py | 2 +- .../mmdet/models/roi_heads/fcn_mask_head.py | 2 +- .../mmdet/models/task_modules/__init__.py | 4 +- .../models/task_modules/coders/__init__.py | 6 +- .../task_modules/prior_generators/__init__.py | 4 +- .../codebase/mmdet/structures/__init__.py | 2 +- .../mmdet/structures/bbox/__init__.py | 2 +- mmdeploy/codebase/mmdet3d/__init__.py | 1 - mmdeploy/codebase/mmdet3d/models/__init__.py | 8 +-- mmdeploy/codebase/mmedit/__init__.py | 3 +- mmdeploy/codebase/mmedit/models/__init__.py | 2 +- .../mmedit/models/base_models/__init__.py | 4 +- mmdeploy/codebase/mmocr/__init__.py | 1 - mmdeploy/codebase/mmocr/models/__init__.py | 4 +- .../mmocr/models/text_detection/__init__.py | 11 +--- .../mmocr/models/text_recognition/__init__.py | 19 ++----- mmdeploy/codebase/mmpose/__init__.py | 1 - mmdeploy/codebase/mmpose/models/__init__.py | 4 +- .../models/dense_heads/oriented_rpn_head.py | 4 +- .../models/dense_heads/rotated_anchor_head.py | 4 +- .../models/dense_heads/rotated_rpn_head.py | 6 +- .../mmrotate/models/roi_heads/gv_bbox_head.py | 2 +- .../models/roi_heads/rotated_bbox_head.py | 2 +- mmdeploy/codebase/mmseg/__init__.py | 1 - mmdeploy/codebase/mmseg/models/__init__.py | 6 +- .../codebase/mmseg/models/utils/__init__.py | 4 +- mmdeploy/mmcv/__init__.py | 4 +- mmdeploy/mmcv/cnn/__init__.py | 2 +- mmdeploy/mmcv/ops/__init__.py | 21 ++++--- mmdeploy/pytorch/__init__.py | 4 +- mmdeploy/pytorch/functions/__init__.py | 55 +++++++------------ mmdeploy/pytorch/ops/__init__.py | 17 ------ mmdeploy/pytorch/symbolics/__init__.py | 11 ++++ .../{ops => symbolics}/adaptive_pool.py | 0 mmdeploy/pytorch/{ops => symbolics}/gelu.py | 0 .../{ops => symbolics}/grid_sampler.py | 0 .../pytorch/{ops => symbolics}/hardsigmoid.py | 0 .../{ops => symbolics}/instance_norm.py | 0 .../pytorch/{ops => symbolics}/layer_norm.py | 0 mmdeploy/pytorch/{ops => symbolics}/linear.py | 0 mmdeploy/pytorch/{ops => symbolics}/lstm.py | 0 mmdeploy/pytorch/{ops => symbolics}/roll.py | 0 .../pytorch/{ops => symbolics}/squeeze.py | 0 .../test_mmdet/test_mmdet_utils.py | 7 ++- .../test_mmdet3d/test_mmdet3d_models.py | 3 + .../test_mmrotate/test_mmrotate_core.py | 20 +++---- .../test_mmrotate/test_mmrotate_models.py | 32 +++++------ .../test_mmrotate/test_rotated_detection.py | 4 +- .../test_rotated_detection_model.py | 6 +- 75 files changed, 200 insertions(+), 274 deletions(-) delete mode 100644 mmdeploy/pytorch/ops/__init__.py create mode 100644 mmdeploy/pytorch/symbolics/__init__.py rename mmdeploy/pytorch/{ops => symbolics}/adaptive_pool.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/gelu.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/grid_sampler.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/hardsigmoid.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/instance_norm.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/layer_norm.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/linear.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/lstm.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/roll.py (100%) rename mmdeploy/pytorch/{ops => symbolics}/squeeze.py (100%) diff --git a/mmdeploy/codebase/base/mmcodebase.py b/mmdeploy/codebase/base/mmcodebase.py index 2217aee9b..5130ba37c 100644 --- a/mmdeploy/codebase/base/mmcodebase.py +++ b/mmdeploy/codebase/base/mmcodebase.py @@ -49,9 +49,15 @@ class MMCodebase(metaclass=ABCMeta): deploy_cfg=deploy_cfg, device=device)) + @classmethod + def register_deploy_modules(cls): + """register deploy module.""" + raise NotImplementedError('register_deploy_modules not implemented.') + @classmethod def register_all_modules(cls): - pass + """register codebase module.""" + raise NotImplementedError('register_all_modules not implemented.') # Note that the build function returns the class instead of its instance. diff --git a/mmdeploy/codebase/mmaction/__init__.py b/mmdeploy/codebase/mmaction/__init__.py index 1cdbf57bd..daa1fe60a 100644 --- a/mmdeploy/codebase/mmaction/__init__.py +++ b/mmdeploy/codebase/mmaction/__init__.py @@ -1,4 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import * # noqa: F401,F403 -from .models import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmaction/deploy/mmaction.py b/mmdeploy/codebase/mmaction/deploy/mmaction.py index 4f018901f..dae180571 100644 --- a/mmdeploy/codebase/mmaction/deploy/mmaction.py +++ b/mmdeploy/codebase/mmaction/deploy/mmaction.py @@ -13,7 +13,12 @@ class MMACTION(MMCodebase): task_registry = MMACTION_TASK + @classmethod + def register_deploy_modules(cls): + import mmdeploy.codebase.mmaction.models # noqa: F401 + @classmethod def register_all_modules(cls): from mmaction.utils.setup_env import register_all_modules + cls.register_deploy_modules() register_all_modules(True) diff --git a/mmdeploy/codebase/mmaction/models/__init__.py b/mmdeploy/codebase/mmaction/models/__init__.py index db721b1f3..0a5e58dfe 100644 --- a/mmdeploy/codebase/mmaction/models/__init__.py +++ b/mmdeploy/codebase/mmaction/models/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .recognizers import * # noqa: F401,F403 +from . import recognizers # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmaction/models/recognizers/__init__.py b/mmdeploy/codebase/mmaction/models/recognizers/__init__.py index ff8a52482..43ddb8adb 100644 --- a/mmdeploy/codebase/mmaction/models/recognizers/__init__.py +++ b/mmdeploy/codebase/mmaction/models/recognizers/__init__.py @@ -1,5 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import base_recognizer__forward - -__all__ = ['base_recognizer__forward'] +from . import base # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/__init__.py b/mmdeploy/codebase/mmcls/__init__.py index 0683b42a3..75c5c8ebd 100644 --- a/mmdeploy/codebase/mmcls/__init__.py +++ b/mmdeploy/codebase/mmcls/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import * # noqa: F401,F403 -from .models import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/models/__init__.py b/mmdeploy/codebase/mmcls/models/__init__.py index a489c1edc..75b589fca 100644 --- a/mmdeploy/codebase/mmcls/models/__init__.py +++ b/mmdeploy/codebase/mmcls/models/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .backbones import * # noqa: F401,F403 -from .classifiers import * # noqa: F401,F403 -from .necks import * # noqa: F401,F403 -from .utils import * # noqa: F401,F403 +from . import backbones # noqa: F401,F403 +from . import classifiers # noqa: F401,F403 +from . import necks # noqa: F401,F403 +from . import utils # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/models/backbones/__init__.py b/mmdeploy/codebase/mmcls/models/backbones/__init__.py index fd9d7d3d4..62d47810c 100644 --- a/mmdeploy/codebase/mmcls/models/backbones/__init__.py +++ b/mmdeploy/codebase/mmcls/models/backbones/__init__.py @@ -1,8 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .shufflenet_v2 import shufflenetv2_backbone__forward__default -from .vision_transformer import visiontransformer__forward__ncnn - -__all__ = [ - 'shufflenetv2_backbone__forward__default', - 'visiontransformer__forward__ncnn' -] +from . import shufflenet_v2 # noqa: F401,F403 +from . import vision_transformer # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/models/classifiers/__init__.py b/mmdeploy/codebase/mmcls/models/classifiers/__init__.py index 630dd251b..94502952b 100644 --- a/mmdeploy/codebase/mmcls/models/classifiers/__init__.py +++ b/mmdeploy/codebase/mmcls/models/classifiers/__init__.py @@ -1,4 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import base_classifier__forward - -__all__ = ['base_classifier__forward'] +from . import base # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/models/necks/__init__.py b/mmdeploy/codebase/mmcls/models/necks/__init__.py index 5bdebc563..882128e61 100644 --- a/mmdeploy/codebase/mmcls/models/necks/__init__.py +++ b/mmdeploy/codebase/mmcls/models/necks/__init__.py @@ -1,5 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. - -from .gap import gap__forward - -__all__ = ['gap__forward'] +from . import gap # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmcls/models/utils/__init__.py b/mmdeploy/codebase/mmcls/models/utils/__init__.py index 3d0a17994..2e003bba7 100644 --- a/mmdeploy/codebase/mmcls/models/utils/__init__.py +++ b/mmdeploy/codebase/mmcls/models/utils/__init__.py @@ -1,10 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .attention import (multiheadattention__forward__ncnn, - shift_window_msa__forward__default, - shift_window_msa__get_attn_mask__default) - -__all__ = [ - 'multiheadattention__forward__ncnn', - 'shift_window_msa__get_attn_mask__default', - 'shift_window_msa__forward__default' -] +from . import attention # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/__init__.py b/mmdeploy/codebase/mmdet/__init__.py index 6d69a5a9d..61e8c0794 100644 --- a/mmdeploy/codebase/mmdet/__init__.py +++ b/mmdeploy/codebase/mmdet/__init__.py @@ -2,9 +2,6 @@ from .deploy import (ObjectDetection, clip_bboxes, gather_topk, get_post_processing_params, pad_with_value, pad_with_value_if_necessary) -from .models import * # noqa: F401,F403 -from .ops import * # noqa: F401,F403 -from .structures import * # noqa: F401, F403 __all__ = [ 'get_post_processing_params', 'clip_bboxes', 'pad_with_value', diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index 251f26f0a..50379fd2f 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -13,7 +13,8 @@ from torch import Tensor, nn from mmdeploy.backend.base import get_backend_file_count from mmdeploy.codebase.base import BaseBackendModel -from mmdeploy.codebase.mmdet import get_post_processing_params, multiclass_nms +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params +from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.utils import (Backend, get_backend, get_codebase_config, get_partition_config, load_config) diff --git a/mmdeploy/codebase/mmdet/models/__init__.py b/mmdeploy/codebase/mmdet/models/__init__.py index d0282caa9..38b7e336d 100644 --- a/mmdeploy/codebase/mmdet/models/__init__.py +++ b/mmdeploy/codebase/mmdet/models/__init__.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .backbones import * # noqa: F401, F403 -from .dense_heads import * # noqa: F401,F403 -from .detectors import * # noqa: F401,F403 -from .layers import * # noqa: F401,F403 -from .necks import * # noqa: F401,F403 -from .roi_heads import * # noqa: F401,F403 -from .task_modules import * # noqa: F401,F403 -from .transformer import * # noqa: F401,F403 +from . import backbones # noqa: F401, F403 +from . import dense_heads # noqa: F401,F403 +from . import detectors # noqa: F401,F403 +from . import layers # noqa: F401,F403 +from . import necks # noqa: F401,F403 +from . import roi_heads # noqa: F401,F403 +from . import task_modules # noqa: F401,F403 +from . import transformer # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 39dacd5b1..57feeceaf 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -1,23 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import detr_head -from .base_dense_head import (base_dense_head__predict_by_feat, - base_dense_head__predict_by_feat__ncnn) -from .fovea_head import fovea_head__predict_by_feat -from .gfl_head import gfl_head__predict_by_feat -from .reppoints_head import reppoints_head__predict_by_feat -from .rpn_head import rpn_head__get_bboxes__ncnn, rpn_head__predict_by_feat -from .rtmdet_head import rtmdet_head__predict_by_feat -from .yolo_head import (yolov3_head__predict_by_feat, - yolov3_head__predict_by_feat__ncnn) -from .yolox_head import (yolox_head__predict_by_feat, - yolox_head__predict_by_feat__ncnn) - -__all__ = [ - 'rpn_head__predict_by_feat', 'rpn_head__get_bboxes__ncnn', - 'yolov3_head__predict_by_feat', 'yolov3_head__predict_by_feat__ncnn', - 'yolox_head__predict_by_feat', 'base_dense_head__predict_by_feat', - 'fovea_head__predict_by_feat', 'base_dense_head__predict_by_feat__ncnn', - 'yolox_head__predict_by_feat__ncnn', 'gfl_head__predict_by_feat', - 'reppoints_head__predict_by_feat', 'detr_head', - 'rtmdet_head__predict_by_feat' -] +from . import base_dense_head # noqa: F401,F403 +from . import detr_head # noqa: F401,F403 +from . import fovea_head # noqa: F401,F403 +from . import gfl_head # noqa: F401,F403 +from . import reppoints_head # noqa: F401,F403 +from . import rpn_head # noqa: F401,F403 +from . import rtmdet_head # noqa: F401,F403 +from . import yolo_head # noqa: F401,F403 +from . import yolox_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py index 84feac5e9..43bdf919e 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py @@ -10,8 +10,9 @@ from mmdet.structures.bbox.transforms import distance2bbox from mmengine import ConfigDict from torch import Tensor -from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (gather_topk, + get_post_processing_params, + pad_with_value_if_necessary) from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward from mmdeploy.core import FUNCTION_REWRITER, mark diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py index 1c1e95141..110a4045f 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/fovea_head.py @@ -6,7 +6,7 @@ from mmengine.config import ConfigDict from mmengine.structures import InstanceData from torch import Tensor -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py index 56c002ea2..43258b7db 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py @@ -7,8 +7,9 @@ from mmengine.config import ConfigDict from mmengine.structures import InstanceData from torch import Tensor -from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params, - pad_with_value) +from mmdeploy.codebase.mmdet.deploy import (gather_topk, + get_post_processing_params, + pad_with_value) from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import Backend, get_backend, is_dynamic_shape diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py index f718936c8..dfc5e0ee3 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/reppoints_head.py @@ -6,8 +6,9 @@ from mmengine.config import ConfigDict from mmengine.structures import InstanceData from torch import Tensor -from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (gather_topk, + get_post_processing_params, + pad_with_value_if_necessary) from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import is_dynamic_shape diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py index 27be63c65..53a3c29e5 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py @@ -5,8 +5,9 @@ import torch from mmengine import ConfigDict from torch import Tensor -from mmdeploy.codebase.mmdet import (gather_topk, get_post_processing_params, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (gather_topk, + get_post_processing_params, + pad_with_value_if_necessary) from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import Backend, is_dynamic_shape diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py index b5dade2a9..b34571add 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/ssd_head.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.codebase.mmdet.ops import (ncnn_detection_output_forward, ncnn_prior_box_forward) from mmdeploy.core import FUNCTION_REWRITER diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py index cdb164ab9..87a12645b 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py @@ -6,8 +6,8 @@ import torch from mmdet.utils.typing import OptConfigType from torch import Tensor -from mmdeploy.codebase.mmdet import (get_post_processing_params, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (get_post_processing_params, + pad_with_value_if_necessary) from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.utils import Backend, is_dynamic_shape diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py index 439e6f55b..47a696fbb 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/yolox_head.py @@ -6,7 +6,7 @@ from mmengine.config import ConfigDict from mmengine.structures import InstanceData from torch import Tensor -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.utils import Backend diff --git a/mmdeploy/codebase/mmdet/models/layers/__init__.py b/mmdeploy/codebase/mmdet/models/layers/__init__.py index b62cf581d..5d18f72f9 100644 --- a/mmdeploy/codebase/mmdet/models/layers/__init__.py +++ b/mmdeploy/codebase/mmdet/models/layers/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .bbox_nms import _multiclass_nms, multiclass_nms +from .bbox_nms import multiclass_nms -__all__ = ['multiclass_nms', '_multiclass_nms'] +__all__ = ['multiclass_nms'] diff --git a/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py b/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py index 855708583..b66f1f765 100644 --- a/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py +++ b/mmdeploy/codebase/mmdet/models/layers/bbox_nms.py @@ -2,7 +2,6 @@ import torch from torch import Tensor -import mmdeploy from mmdeploy.core import FUNCTION_REWRITER, mark from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop from mmdeploy.utils import IR, is_dynamic_batch @@ -166,7 +165,7 @@ def _multiclass_nms_single(boxes: Tensor, @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.codebase.mmdet.models.layers._multiclass_nms') + func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms') def multiclass_nms__default(ctx, boxes: Tensor, scores: Tensor, @@ -223,7 +222,7 @@ def multiclass_nms__default(ctx, @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.codebase.mmdet.models.layers._multiclass_nms', + func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms', backend='tensorrt') def multiclass_nms_static(ctx, boxes: Tensor, @@ -274,12 +273,11 @@ def multiclass_nms_static(ctx, @mark('multiclass_nms', inputs=['boxes', 'scores'], outputs=['dets', 'labels']) def multiclass_nms(*args, **kwargs): """Wrapper function for `_multiclass_nms`.""" - return mmdeploy.codebase.mmdet.models.layers._multiclass_nms( - *args, **kwargs) + return _multiclass_nms(*args, **kwargs) @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.codebase.mmdet.models.layers._multiclass_nms', + func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms', backend=Backend.COREML.value) def multiclass_nms__coreml(ctx, boxes: Tensor, @@ -340,7 +338,7 @@ def multiclass_nms__coreml(ctx, @FUNCTION_REWRITER.register_rewriter( - func_name='mmdeploy.codebase.mmdet.models.layers._multiclass_nms', + func_name='mmdeploy.codebase.mmdet.models.layers.bbox_nms._multiclass_nms', ir=IR.TORCHSCRIPT) def multiclass_nms__torchscript(ctx, boxes: Tensor, diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py b/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py index 94cb3c645..f12a70dc6 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/__init__.py @@ -1,19 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .bbox_head import bbox_head__forward, bbox_head__predict_by_feat -from .cascade_roi_head import (cascade_roi_head__predict_bbox, - cascade_roi_head__predict_mask) -from .fcn_mask_head import fcn_mask_head__predict_by_feat -from .single_level_roi_extractor import ( - single_roi_extractor__forward, single_roi_extractor__forward__openvino, - single_roi_extractor__forward__tensorrt) -from .standard_roi_head import (standard_roi_head__predict_bbox, - standard_roi_head__predict_mask) - -__all__ = [ - 'bbox_head__predict_by_feat', 'bbox_head__forward', - 'cascade_roi_head__predict_bbox', 'cascade_roi_head__predict_mask', - 'fcn_mask_head__predict_by_feat', 'single_roi_extractor__forward', - 'single_roi_extractor__forward__openvino', - 'single_roi_extractor__forward__tensorrt', - 'standard_roi_head__predict_bbox', 'standard_roi_head__predict_mask' -] +from . import bbox_head # noqa: F401,F403 +from . import cascade_roi_head # noqa: F401,F403 +from . import fcn_mask_head # noqa: F401,F403 +from . import single_level_roi_extractor # noqa: F401,F403 +from . import standard_roi_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py index f9581e7e9..28dbeb326 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/bbox_head.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from mmengine import ConfigDict from torch import Tensor -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER, mark diff --git a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py index 4e98ebf80..360faeb1a 100644 --- a/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py +++ b/mmdeploy/codebase/mmdet/models/roi_heads/fcn_mask_head.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from mmengine import ConfigDict from torch import Tensor -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import Backend, get_backend diff --git a/mmdeploy/codebase/mmdet/models/task_modules/__init__.py b/mmdeploy/codebase/mmdet/models/task_modules/__init__.py index 9804a53db..a23b35805 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/__init__.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .coders import * # noqa: F401,F403 -from .prior_generators import * # noqa: F401,F403 +from . import coders # noqa: F401,F403 +from . import prior_generators # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/task_modules/coders/__init__.py b/mmdeploy/codebase/mmdet/models/task_modules/coders/__init__.py index 8b52d7277..26b705120 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/coders/__init__.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/coders/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .delta_xywh_bbox_coder import * # noqa: F401,F403 -from .distance_point_bbox_coder import * # noqa: F401,F403 -from .tblr_bbox_coder import * # noqa: F401,F403 +from . import delta_xywh_bbox_coder # noqa: F401,F403 +from . import distance_point_bbox_coder # noqa: F401,F403 +from . import tblr_bbox_coder # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/__init__.py b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/__init__.py index 9985860a1..ada48c478 100644 --- a/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/__init__.py +++ b/mmdeploy/codebase/mmdet/models/task_modules/prior_generators/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .anchor import * # noqa: F401,F403 -from .point_generator import * # noqa: F401,F403 +from . import anchor # noqa: F401,F403 +from . import point_generator # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/structures/__init__.py b/mmdeploy/codebase/mmdet/structures/__init__.py index 1b12e624b..fbd05bdf3 100644 --- a/mmdeploy/codebase/mmdet/structures/__init__.py +++ b/mmdeploy/codebase/mmdet/structures/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .bbox import * # noqa: F401,F403 +from . import bbox # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/structures/bbox/__init__.py b/mmdeploy/codebase/mmdet/structures/bbox/__init__.py index 16a63592c..a3aa67574 100644 --- a/mmdeploy/codebase/mmdet/structures/bbox/__init__.py +++ b/mmdeploy/codebase/mmdet/structures/bbox/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .transforms import * # noqa: F401,F403 +from . import transforms # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet3d/__init__.py b/mmdeploy/codebase/mmdet3d/__init__.py index 1974ef569..70aab6bd3 100644 --- a/mmdeploy/codebase/mmdet3d/__init__.py +++ b/mmdeploy/codebase/mmdet3d/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import MMDetection3d, VoxelDetection -from .models import * # noqa: F401,F403 __all__ = ['MMDetection3d', 'VoxelDetection'] diff --git a/mmdeploy/codebase/mmdet3d/models/__init__.py b/mmdeploy/codebase/mmdet3d/models/__init__.py index f9cd7c328..494e1100e 100644 --- a/mmdeploy/codebase/mmdet3d/models/__init__.py +++ b/mmdeploy/codebase/mmdet3d/models/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import * # noqa: F401,F403 -from .mvx_two_stage import * # noqa: F401,F403 -from .pillar_encode import * # noqa: F401,F403 -from .pillar_scatter import * # noqa: F401,F403 +from . import base # noqa: F401,F403 +from . import mvx_two_stage # noqa: F401,F403 +from . import pillar_encode # noqa: F401,F403 +from . import pillar_scatter # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmedit/__init__.py b/mmdeploy/codebase/mmedit/__init__.py index 077510cc9..55855b48d 100644 --- a/mmdeploy/codebase/mmedit/__init__.py +++ b/mmdeploy/codebase/mmedit/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import MMEditing, SuperResolution -from .models import base_edit_model__forward -__all__ = ['MMEditing', 'SuperResolution', 'base_edit_model__forward'] +__all__ = ['MMEditing', 'SuperResolution'] diff --git a/mmdeploy/codebase/mmedit/models/__init__.py b/mmdeploy/codebase/mmedit/models/__init__.py index 2340a632a..83760925d 100644 --- a/mmdeploy/codebase/mmedit/models/__init__.py +++ b/mmdeploy/codebase/mmedit/models/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base_models import * # noqa F401, F403 +from . import base_models # noqa F401, F403 diff --git a/mmdeploy/codebase/mmedit/models/base_models/__init__.py b/mmdeploy/codebase/mmedit/models/base_models/__init__.py index 793bf1e38..5f1ddc82a 100644 --- a/mmdeploy/codebase/mmedit/models/base_models/__init__.py +++ b/mmdeploy/codebase/mmedit/models/base_models/__init__.py @@ -1,4 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base_edit_model import base_edit_model__forward - -__all__ = ['base_edit_model__forward'] +from . import base_edit_model # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmocr/__init__.py b/mmdeploy/codebase/mmocr/__init__.py index 0683b42a3..75c5c8ebd 100644 --- a/mmdeploy/codebase/mmocr/__init__.py +++ b/mmdeploy/codebase/mmocr/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import * # noqa: F401,F403 -from .models import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmocr/models/__init__.py b/mmdeploy/codebase/mmocr/models/__init__.py index fb561ceac..577a6fb85 100644 --- a/mmdeploy/codebase/mmocr/models/__init__.py +++ b/mmdeploy/codebase/mmocr/models/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .text_detection import * # noqa: F401,F403 -from .text_recognition import * # noqa: F401,F403 +from . import text_detection # noqa: F401,F403 +from . import text_recognition # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmocr/models/text_detection/__init__.py b/mmdeploy/codebase/mmocr/models/text_detection/__init__.py index ef7029ca5..abaed2654 100644 --- a/mmdeploy/codebase/mmocr/models/text_detection/__init__.py +++ b/mmdeploy/codebase/mmocr/models/text_detection/__init__.py @@ -1,9 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .fpn_cat import fpnc__forward__tensorrt -from .heads import base_text_det_head__predict, db_head__predict -from .single_stage_text_detector import single_stage_text_detector__forward - -__all__ = [ - 'fpnc__forward__tensorrt', 'base_text_det_head__predict', - 'single_stage_text_detector__forward', 'db_head__predict' -] +from . import fpn_cat # noqa: F401,F403 +from . import heads # noqa: F401,F403 +from . import single_stage_text_detector # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/__init__.py b/mmdeploy/codebase/mmocr/models/text_recognition/__init__.py index fc5113a39..f3ed1e797 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/__init__.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/__init__.py @@ -1,14 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -# from .base import base_recognizer__forward -from .base_decoder import base_decoder__forward -from .crnn_decoder import crnndecoder__forward_train__ncnn -from .encoder_decoder_recognizer import encoder_decoder_recognizer__forward -from .lstm_layer import bidirectionallstm__forward__ncnn -from .sar_decoder import * # noqa: F401,F403 -from .sar_encoder import sar_encoder__forward - -__all__ = [ - 'base_decoder__forward', 'crnndecoder__forward_train__ncnn', - 'encoder_decoder_recognizer__forward', 'bidirectionallstm__forward__ncnn', - 'sar_encoder__forward' -] +from . import base_decoder # noqa: F401,F403 +from . import crnn_decoder # noqa: F401,F403 +from . import encoder_decoder_recognizer # noqa: F401,F403 +from . import lstm_layer # noqa: F401,F403 +from . import sar_decoder # noqa: F401,F403 +from . import sar_encoder # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmpose/__init__.py b/mmdeploy/codebase/mmpose/__init__.py index 4d28baa62..d5141aa11 100644 --- a/mmdeploy/codebase/mmpose/__init__.py +++ b/mmdeploy/codebase/mmpose/__init__.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import PoseDetection -from .models import * # noqa: F401,F403 __all__ = ['PoseDetection'] diff --git a/mmdeploy/codebase/mmpose/models/__init__.py b/mmdeploy/codebase/mmpose/models/__init__.py index c579ec094..304096f9e 100644 --- a/mmdeploy/codebase/mmpose/models/__init__.py +++ b/mmdeploy/codebase/mmpose/models/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .heads import * # noqa: F401,F403 -from .pose_estimators import * # noqa: F401,F403 +from . import heads # noqa: F401,F403 +from . import pose_estimators # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py index d28472094..33268394d 100644 --- a/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/oriented_rpn_head.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdeploy.codebase.mmdet import (get_post_processing_params, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (get_post_processing_params, + pad_with_value_if_necessary) from mmdeploy.codebase.mmrotate.core.post_processing import \ fake_multiclass_nms_rotated from mmdeploy.core import FUNCTION_REWRITER diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_anchor_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_anchor_head.py index 953702f9d..0cfcf308f 100644 --- a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_anchor_head.py +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_anchor_head.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdeploy.codebase.mmdet import (get_post_processing_params, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (get_post_processing_params, + pad_with_value_if_necessary) from mmdeploy.codebase.mmrotate.core.post_processing import \ multiclass_nms_rotated from mmdeploy.core import FUNCTION_REWRITER diff --git a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py index eaa81fe02..9d3dc0a84 100644 --- a/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py +++ b/mmdeploy/codebase/mmrotate/models/dense_heads/rotated_rpn_head.py @@ -1,9 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch -from mmdeploy.codebase.mmdet import (get_post_processing_params, - pad_with_value_if_necessary) -from mmdeploy.codebase.mmdet.core.post_processing import multiclass_nms +from mmdeploy.codebase.mmdet.deploy import (get_post_processing_params, + pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.models.layers import multiclass_nms from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.utils import is_dynamic_shape diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py index 3d777218d..a5afe7b94 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/gv_bbox_head.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch.nn.functional as F -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.codebase.mmrotate.core.post_processing import \ multiclass_nms_rotated from mmdeploy.core import FUNCTION_REWRITER diff --git a/mmdeploy/codebase/mmrotate/models/roi_heads/rotated_bbox_head.py b/mmdeploy/codebase/mmrotate/models/roi_heads/rotated_bbox_head.py index 103ebf2f4..99239dce9 100644 --- a/mmdeploy/codebase/mmrotate/models/roi_heads/rotated_bbox_head.py +++ b/mmdeploy/codebase/mmrotate/models/roi_heads/rotated_bbox_head.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F -from mmdeploy.codebase.mmdet import get_post_processing_params +from mmdeploy.codebase.mmdet.deploy import get_post_processing_params from mmdeploy.codebase.mmrotate.core.post_processing import \ multiclass_nms_rotated from mmdeploy.core import FUNCTION_REWRITER diff --git a/mmdeploy/codebase/mmseg/__init__.py b/mmdeploy/codebase/mmseg/__init__.py index 0683b42a3..75c5c8ebd 100644 --- a/mmdeploy/codebase/mmseg/__init__.py +++ b/mmdeploy/codebase/mmseg/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. from .deploy import * # noqa: F401,F403 -from .models import * # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmseg/models/__init__.py b/mmdeploy/codebase/mmseg/models/__init__.py index f8c63589a..6e76b282f 100644 --- a/mmdeploy/codebase/mmseg/models/__init__.py +++ b/mmdeploy/codebase/mmseg/models/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .decode_heads import * # noqa: F401,F403 -from .segmentors import * # noqa: F401,F403 -from .utils import * # noqa: F401,F403 +from . import decode_heads # noqa: F401,F403 +from . import segmentors # noqa: F401,F403 +from . import utils # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmseg/models/utils/__init__.py b/mmdeploy/codebase/mmseg/models/utils/__init__.py index 954eaa348..0b872f220 100644 --- a/mmdeploy/codebase/mmseg/models/utils/__init__.py +++ b/mmdeploy/codebase/mmseg/models/utils/__init__.py @@ -1,4 +1,2 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .up_conv_block import up_conv_block__forward - -__all__ = ['up_conv_block__forward'] +from . import up_conv_block # noqa: F401,F403 diff --git a/mmdeploy/mmcv/__init__.py b/mmdeploy/mmcv/__init__.py index a5896f0c3..aedc41397 100644 --- a/mmdeploy/mmcv/__init__.py +++ b/mmdeploy/mmcv/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .cnn import * # noqa: F401,F403 -from .ops import * # noqa: F401,F403 +from . import cnn # noqa: F401,F403 +from . import ops # noqa: F401,F403 diff --git a/mmdeploy/mmcv/cnn/__init__.py b/mmdeploy/mmcv/cnn/__init__.py index 987f190a6..3b777d8b0 100644 --- a/mmdeploy/mmcv/cnn/__init__.py +++ b/mmdeploy/mmcv/cnn/__init__.py @@ -2,4 +2,4 @@ from . import conv2d_adaptive_padding # noqa: F401,F403 from .transformer import MultiHeadAttentionop -__all__ = ['conv2d_adaptive_padding', 'MultiHeadAttentionop'] +__all__ = ['MultiHeadAttentionop'] diff --git a/mmdeploy/mmcv/ops/__init__.py b/mmdeploy/mmcv/ops/__init__.py index bdcf7347d..7a70fc321 100644 --- a/mmdeploy/mmcv/ops/__init__.py +++ b/mmdeploy/mmcv/ops/__init__.py @@ -1,15 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .deform_conv import deform_conv_openvino -from .modulated_deform_conv import modulated_deform_conv_default -from .nms import * # noqa: F401,F403 -from .nms_rotated import * # noqa: F401,F403 -from .point_sample import * # noqa: F401,F403 -from .roi_align import roi_align_default -from .roi_align_rotated import roi_align_rotated_default -from .transformer import patch_embed__forward__ncnn +from . import deform_conv # noqa: F401,F403 +from . import modulated_deform_conv # noqa: F401,F403 +from . import point_sample # noqa: F401,F403 +from . import roi_align # noqa: F401,F403 +from . import roi_align_rotated # noqa: F401,F403 +from . import transformer # noqa: F401,F403 +from .nms import ONNXNMSop, TRTBatchedNMSop +from .nms_rotated import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop __all__ = [ - 'roi_align_default', 'modulated_deform_conv_default', - 'deform_conv_openvino', 'roi_align_rotated_default', - 'patch_embed__forward__ncnn' + 'ONNXNMSop', 'TRTBatchedNMSop', 'TRTBatchedRotatedNMSop', + 'ONNXNMSRotatedOp' ] diff --git a/mmdeploy/pytorch/__init__.py b/mmdeploy/pytorch/__init__.py index c086abb19..83c3288e3 100644 --- a/mmdeploy/pytorch/__init__.py +++ b/mmdeploy/pytorch/__init__.py @@ -1,3 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .functions import * # noqa: F401,F403 -from .ops import * # noqa: F401,F403 +from . import functions # noqa: F401,F403 +from . import symbolics # noqa: F401,F403 diff --git a/mmdeploy/pytorch/functions/__init__.py b/mmdeploy/pytorch/functions/__init__.py index 966946c49..f3e145b4d 100644 --- a/mmdeploy/pytorch/functions/__init__.py +++ b/mmdeploy/pytorch/functions/__init__.py @@ -1,35 +1,22 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import multi_head_attention_forward -from .adaptive_pool import (adaptive_avg_pool2d__default, - adaptive_avg_pool2d__ncnn) -from .atan2 import atan2__default -from .chunk import chunk__ncnn, chunk__torchscript -from .clip import clip__coreml -from .expand import expand__ncnn -from .flatten import flatten__coreml -from .getattribute import tensor__getattribute__ncnn -from .group_norm import group_norm__ncnn -from .interpolate import interpolate__ncnn, interpolate__tensorrt -from .linear import linear__ncnn -from .masked_fill import masked_fill__onnxruntime -from .mod import mod__tensorrt -from .normalize import normalize__ncnn -from .pad import _prepare_onnx_paddings__tensorrt -from .repeat import tensor__repeat__tensorrt -from .size import tensor__size__ncnn -from .tensor_getitem import tensor__getitem__ascend -from .tensor_setitem import tensor__setitem__default -from .topk import topk__dynamic, topk__tensorrt -from .triu import triu__default - -__all__ = [ - 'tensor__getattribute__ncnn', 'group_norm__ncnn', 'interpolate__ncnn', - 'interpolate__tensorrt', 'linear__ncnn', 'tensor__repeat__tensorrt', - 'tensor__size__ncnn', 'topk__dynamic', 'topk__tensorrt', 'chunk__ncnn', - 'triu__default', 'atan2__default', 'normalize__ncnn', 'expand__ncnn', - 'chunk__torchscript', 'masked_fill__onnxruntime', - 'tensor__setitem__default', 'tensor__getitem__ascend', - 'adaptive_avg_pool2d__default', 'adaptive_avg_pool2d__ncnn', - 'multi_head_attention_forward', 'flatten__coreml', 'clip__coreml', - 'mod__tensorrt', '_prepare_onnx_paddings__tensorrt' -] +from . import adaptive_pool # noqa: F401,F403 +from . import atan2 # noqa: F401,F403 +from . import chunk # noqa: F401,F403 +from . import clip # noqa: F401,F403 +from . import expand # noqa: F401,F403 +from . import flatten # noqa: F401,F403 +from . import getattribute # noqa: F401,F403 +from . import group_norm # noqa: F401,F403 +from . import interpolate # noqa: F401,F403 +from . import linear # noqa: F401,F403 +from . import masked_fill # noqa: F401,F403 +from . import mod # noqa: F401,F403 +from . import multi_head_attention_forward # noqa: F401,F403 +from . import normalize # noqa: F401,F403 +from . import pad # noqa: F401,F403 +from . import repeat # noqa: F401,F403 +from . import size # noqa: F401,F403 +from . import tensor_getitem # noqa: F401,F403 +from . import tensor_setitem # noqa: F401,F403 +from . import topk # noqa: F401,F403 +from . import triu # noqa: F401,F403 diff --git a/mmdeploy/pytorch/ops/__init__.py b/mmdeploy/pytorch/ops/__init__.py deleted file mode 100644 index 1fa2cf04f..000000000 --- a/mmdeploy/pytorch/ops/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .adaptive_pool import adaptive_avg_pool2d__ncnn -from .gelu import gelu__ncnn -from .grid_sampler import grid_sampler__default -from .hardsigmoid import hardsigmoid__default -from .instance_norm import instance_norm__tensorrt -from .layer_norm import layer_norm__ncnn -from .linear import linear__ncnn -from .lstm import generic_rnn__ncnn -from .roll import roll_default -from .squeeze import squeeze__default - -__all__ = [ - 'grid_sampler__default', 'hardsigmoid__default', 'instance_norm__tensorrt', - 'generic_rnn__ncnn', 'squeeze__default', 'adaptive_avg_pool2d__ncnn', - 'gelu__ncnn', 'layer_norm__ncnn', 'linear__ncnn', 'roll_default' -] diff --git a/mmdeploy/pytorch/symbolics/__init__.py b/mmdeploy/pytorch/symbolics/__init__.py new file mode 100644 index 000000000..38190b93d --- /dev/null +++ b/mmdeploy/pytorch/symbolics/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from . import adaptive_pool # noqa: F401,F403 +from . import gelu # noqa: F401,F403 +from . import grid_sampler # noqa: F401,F403 +from . import hardsigmoid # noqa: F401,F403 +from . import instance_norm # noqa: F401,F403 +from . import layer_norm # noqa: F401,F403 +from . import linear # noqa: F401,F403 +from . import lstm # noqa: F401,F403 +from . import roll # noqa: F401,F403 +from . import squeeze # noqa: F401,F403 diff --git a/mmdeploy/pytorch/ops/adaptive_pool.py b/mmdeploy/pytorch/symbolics/adaptive_pool.py similarity index 100% rename from mmdeploy/pytorch/ops/adaptive_pool.py rename to mmdeploy/pytorch/symbolics/adaptive_pool.py diff --git a/mmdeploy/pytorch/ops/gelu.py b/mmdeploy/pytorch/symbolics/gelu.py similarity index 100% rename from mmdeploy/pytorch/ops/gelu.py rename to mmdeploy/pytorch/symbolics/gelu.py diff --git a/mmdeploy/pytorch/ops/grid_sampler.py b/mmdeploy/pytorch/symbolics/grid_sampler.py similarity index 100% rename from mmdeploy/pytorch/ops/grid_sampler.py rename to mmdeploy/pytorch/symbolics/grid_sampler.py diff --git a/mmdeploy/pytorch/ops/hardsigmoid.py b/mmdeploy/pytorch/symbolics/hardsigmoid.py similarity index 100% rename from mmdeploy/pytorch/ops/hardsigmoid.py rename to mmdeploy/pytorch/symbolics/hardsigmoid.py diff --git a/mmdeploy/pytorch/ops/instance_norm.py b/mmdeploy/pytorch/symbolics/instance_norm.py similarity index 100% rename from mmdeploy/pytorch/ops/instance_norm.py rename to mmdeploy/pytorch/symbolics/instance_norm.py diff --git a/mmdeploy/pytorch/ops/layer_norm.py b/mmdeploy/pytorch/symbolics/layer_norm.py similarity index 100% rename from mmdeploy/pytorch/ops/layer_norm.py rename to mmdeploy/pytorch/symbolics/layer_norm.py diff --git a/mmdeploy/pytorch/ops/linear.py b/mmdeploy/pytorch/symbolics/linear.py similarity index 100% rename from mmdeploy/pytorch/ops/linear.py rename to mmdeploy/pytorch/symbolics/linear.py diff --git a/mmdeploy/pytorch/ops/lstm.py b/mmdeploy/pytorch/symbolics/lstm.py similarity index 100% rename from mmdeploy/pytorch/ops/lstm.py rename to mmdeploy/pytorch/symbolics/lstm.py diff --git a/mmdeploy/pytorch/ops/roll.py b/mmdeploy/pytorch/symbolics/roll.py similarity index 100% rename from mmdeploy/pytorch/ops/roll.py rename to mmdeploy/pytorch/symbolics/roll.py diff --git a/mmdeploy/pytorch/ops/squeeze.py b/mmdeploy/pytorch/symbolics/squeeze.py similarity index 100% rename from mmdeploy/pytorch/ops/squeeze.py rename to mmdeploy/pytorch/symbolics/squeeze.py diff --git a/tests/test_codebase/test_mmdet/test_mmdet_utils.py b/tests/test_codebase/test_mmdet/test_mmdet_utils.py index 7d6539c84..2045ad93f 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_utils.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_utils.py @@ -12,9 +12,10 @@ try: except ImportError: pytest.skip(f'{Codebase.MMDET} is not installed.', allow_module_level=True) -from mmdeploy.codebase.mmdet import (clip_bboxes, get_post_processing_params, - pad_with_value, - pad_with_value_if_necessary) +from mmdeploy.codebase.mmdet.deploy import (clip_bboxes, + get_post_processing_params, + pad_with_value, + pad_with_value_if_necessary) def test_clip_bboxes(): diff --git a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py index 669df8da0..48b41ebb3 100644 --- a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py +++ b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py @@ -9,6 +9,7 @@ from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase, Task, load_config from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs +import_codebase(Codebase.MMDET3D) try: import_codebase(Codebase.MMDET3D) except ImportError: @@ -60,6 +61,7 @@ def test_pillar_encoder(backend_type: Backend): num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32) coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32) model_outputs = model.forward(features, num_points, coors) + model_outputs = [model_outputs] wrapped_model = WrapModel(model, 'forward') rewrite_inputs = { 'features': features, @@ -97,6 +99,7 @@ def test_pointpillars_scatter(backend_type: Backend): voxel_features = torch.rand(16 * 16, 64) * 100 coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32) model_outputs = model.forward_batch(voxel_features, coors, 1) + model_outputs = [model_outputs] wrapped_model = WrapModel(model, 'forward_batch') rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors} rewrite_outputs, is_backend_output = get_rewrite_outputs( diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py index d774b7510..a83f79278 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_core.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_core.py @@ -1,8 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -import mmengine import numpy as np import pytest import torch +from mmengine import Config from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase @@ -20,7 +20,7 @@ except ImportError: @backend_checker(Backend.ONNXRUNTIME) def test_multiclass_nms_rotated(): from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -72,7 +72,7 @@ def test_multiclass_nms_rotated_with_keep_top_k(pre_top_k): from mmdeploy.codebase.mmrotate.core import multiclass_nms_rotated keep_top_k = 15 - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict( output_names=None, @@ -140,7 +140,7 @@ def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend, max_shape: tuple, proj_xy: bool, edge_swap: bool): check_backend(backend_type) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict(type=backend_type.value, model_inputs=None), @@ -189,7 +189,7 @@ def test_delta_xywha_rbbox_coder_delta2bbox(backend_type: Backend, @pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend): check_backend(backend_type) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict(type=backend_type.value, model_inputs=None), @@ -227,7 +227,7 @@ def test_delta_midpointoffset_rbbox_delta2bbox(backend_type: Backend): @backend_checker(Backend.ONNXRUNTIME) def test_fake_multiclass_nms_rotated(): from mmdeploy.codebase.mmrotate.core import fake_multiclass_nms_rotated - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -277,7 +277,7 @@ def test_fake_multiclass_nms_rotated(): def test_poly2obb_le90(backend_type: Backend): check_backend(backend_type) polys = torch.rand(1, 10, 8) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -316,7 +316,7 @@ def test_poly2obb_le90(backend_type: Backend): def test_poly2obb_le135(backend_type: Backend): check_backend(backend_type) polys = torch.rand(1, 10, 8) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -351,7 +351,7 @@ def test_poly2obb_le135(backend_type: Backend): def test_obb2poly_le135(backend_type: Backend): check_backend(backend_type) rboxes = torch.rand(1, 10, 5) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=None, input_shape=None), backend_config=dict( @@ -386,7 +386,7 @@ def test_obb2poly_le135(backend_type: Backend): def test_gvfixcoder__decode(backend_type: Backend): check_backend(backend_type) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( onnx_config=dict(output_names=['output'], input_shape=None), backend_config=dict(type=backend_type.value), diff --git a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py index d3534e8d9..241d7e7d7 100644 --- a/tests/test_codebase/test_mmrotate/test_mmrotate_models.py +++ b/tests/test_codebase/test_mmrotate/test_mmrotate_models.py @@ -4,11 +4,11 @@ import os import random from typing import Dict, List -import mmcv import mmengine import numpy as np import pytest import torch +from mmengine import Config from mmdeploy.codebase import import_codebase from mmdeploy.utils import Backend, Codebase @@ -50,7 +50,7 @@ def convert_to_list(rewrite_output: Dict, output_names: List[str]) -> List: def get_anchor_head_model(): """AnchorHead Config.""" - test_cfg = mmengine.Config( + test_cfg = Config( dict( nms_pre=2000, min_bbox_size=0, @@ -81,7 +81,7 @@ def _replace_r50_with_r18(model): ['tests/test_codebase/test_mmrotate/data/single_stage_model.json']) def test_forward_of_base_detector(model_cfg_path, backend): check_backend(backend) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend.value), onnx_config=dict( @@ -96,7 +96,7 @@ def test_forward_of_base_detector(model_cfg_path, backend): keep_top_k=100, )))) - model_cfg = mmengine.Config(dict(model=mmcv.load(model_cfg_path))) + model_cfg = Config(dict(model=mmengine.load(model_cfg_path))) model_cfg.model = _replace_r50_with_r18(model_cfg.model) from mmrotate.models import build_detector @@ -118,7 +118,7 @@ def test_forward_of_base_detector(model_cfg_path, backend): def get_deploy_cfg(backend_type: Backend, ir_type: str): - return mmengine.Config( + return Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict( @@ -222,7 +222,7 @@ def test_rotated_single_roi_extractor(backend_type: Backend): single_roi_extractor = get_single_roi_extractor() output_names = ['roi_feat'] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -265,7 +265,7 @@ def test_rotated_single_roi_extractor(backend_type: Backend): def get_oriented_rpn_head_model(): """Oriented RPN Head Config.""" - test_cfg = mmengine.Config( + test_cfg = Config( dict( nms_pre=2000, min_bbox_size=0, @@ -296,7 +296,7 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend): }] output_names = ['dets', 'labels'] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -337,7 +337,7 @@ def test_get_bboxes_of_oriented_rpn_head(backend_type: Backend): def get_rotated_rpn_head_model(): """Oriented RPN Head Config.""" - test_cfg = mmengine.Config( + test_cfg = Config( dict( nms_pre=2000, min_bbox_size=0, @@ -377,7 +377,7 @@ def test_get_bboxes_of_rotated_rpn_head(backend_type: Backend): }] output_names = ['dets', 'labels'] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -421,7 +421,7 @@ def test_rotate_standard_roi_head__simple_test(backend_type: Backend): check_backend(backend_type) from mmrotate.models.roi_heads import OrientedStandardRoIHead output_names = ['dets', 'labels'] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -434,7 +434,7 @@ def test_rotate_standard_roi_head__simple_test(backend_type: Backend): pre_top_k=2000, keep_top_k=2000)))) angle_version = 'le90' - test_cfg = mmengine.Config( + test_cfg = Config( dict( nms_pre=2000, min_bbox_size=0, @@ -489,7 +489,7 @@ def test_gv_ratio_roi_head__simple_test(backend_type: Backend): check_backend(backend_type) from mmrotate.models.roi_heads import GVRatioRoIHead output_names = ['dets', 'labels'] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), @@ -503,7 +503,7 @@ def test_gv_ratio_roi_head__simple_test(backend_type: Backend): keep_top_k=2000, max_output_boxes_per_class=1000)))) angle_version = 'le90' - test_cfg = mmengine.Config( + test_cfg = Config( dict( nms_pre=2000, min_bbox_size=0, @@ -616,7 +616,7 @@ def get_roi_trans_roi_head_model(): type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1.0)) ] - test_cfg = mmengine.Config( + test_cfg = Config( dict( nms_pre=2000, min_bbox_size=0, @@ -660,7 +660,7 @@ def test_simple_test_of_roi_trans_roi_head(backend_type: Backend): } output_names = ['det_bboxes', 'det_labels'] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type=backend_type.value), onnx_config=dict(output_names=output_names, input_shape=None), diff --git a/tests/test_codebase/test_mmrotate/test_rotated_detection.py b/tests/test_codebase/test_mmrotate/test_rotated_detection.py index df2677d80..7c802be8e 100644 --- a/tests/test_codebase/test_mmrotate/test_rotated_detection.py +++ b/tests/test_codebase/test_mmrotate/test_rotated_detection.py @@ -2,10 +2,10 @@ import os from tempfile import NamedTemporaryFile, TemporaryDirectory -import mmengine import numpy as np import pytest import torch +from mmengine import Config from torch.utils.data import DataLoader from torch.utils.data.dataset import Dataset @@ -23,7 +23,7 @@ except ImportError: model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' model_cfg = load_config(model_cfg_path)[0] -deploy_cfg = mmengine.Config( +deploy_cfg = Config( dict( backend_config=dict(type='onnxruntime'), codebase_config=dict( diff --git a/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py b/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py index 48b2558e6..464415b4c 100644 --- a/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py +++ b/tests/test_codebase/test_mmrotate/test_rotated_detection_model.py @@ -2,10 +2,10 @@ import os.path as osp from tempfile import NamedTemporaryFile -import mmengine import numpy as np import pytest import torch +from mmengine import Config import mmdeploy.backend.onnxruntime as ort_apis from mmdeploy.codebase import import_codebase @@ -37,7 +37,7 @@ class TestEnd2EndModel: 'labels': torch.rand(1, 10) } cls.wrapper.set(outputs=cls.outputs) - deploy_cfg = mmengine.Config( + deploy_cfg = Config( {'onnx_config': { 'output_names': ['dets', 'labels'] }}) @@ -90,7 +90,7 @@ class TestEnd2EndModel: def test_build_rotated_detection_model(): model_cfg_path = 'tests/test_codebase/test_mmrotate/data/model.py' model_cfg = load_config(model_cfg_path)[0] - deploy_cfg = mmengine.Config( + deploy_cfg = Config( dict( backend_config=dict(type='onnxruntime'), onnx_config=dict(output_names=['dets', 'labels']), From d330e17af341d0a31ffa89fe985d9286d4cacf66 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Tue, 8 Nov 2022 11:17:59 +0800 Subject: [PATCH 4/4] fix reg of dev-1.x(#1317) --- tests/regression/mmdet.yml | 2 +- tools/regression_test.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/regression/mmdet.yml b/tests/regression/mmdet.yml index 863717be3..cf78b9920 100644 --- a/tests/regression/mmdet.yml +++ b/tests/regression/mmdet.yml @@ -103,7 +103,7 @@ tensorrt: pipeline_seg_trt_dynamic_fp32: &pipeline_seg_trt_dynamic_fp32 convert_image: *convert_image backend_test: *default_backend_test - # sdk_config: *sdk_seg_dynamic + sdk_config: *sdk_seg_dynamic deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py pipeline_seg_trt_dynamic_fp16: &pipeline_seg_trt_dynamic_fp16 diff --git a/tools/regression_test.py b/tools/regression_test.py index 8b45ae5eb..c53e30f9f 100644 --- a/tools/regression_test.py +++ b/tools/regression_test.py @@ -429,13 +429,13 @@ def get_fps_metric(shell_res: int, pytorch_metric: dict, metric_info: dict, metric_key = metric_info[metric_name]['metric_key'] tolerance = metric_info[metric_name]['tolerance'] multi_value = metric_info[metric_name].get('multi_value', 1.0) - compare_flag = True - output_result[metric_name] = '-' + compare_flag = False + output_result[metric_name] = 'x' if metric_key in backend_results: backend_value = backend_results[metric_key] * multi_value output_result[metric_name] = backend_value - if backend_value < metric_value - tolerance: - compare_flag = False + if backend_value >= metric_value - tolerance: + compare_flag = True compare_results[metric_name] = compare_flag if len(compare_results): @@ -489,7 +489,7 @@ def get_backend_fps_metric(deploy_cfg_path: str, model_cfg_path: Path, fps, backend_metric, test_pass = get_fps_metric(return_code, pytorch_metric, metric_info, work_dir) - logger.info(f'test_pass={test_pass}, results{backend_metric}') + logger.info(f'test_pass= {test_pass}, results= {backend_metric}') metric_list = [] for metric in metric_info: value = '-'