diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py
index c9d4f75d5..6714d7a81 100644
--- a/mmdeploy/codebase/base/task.py
+++ b/mmdeploy/codebase/base/task.py
@@ -67,6 +67,15 @@ class BaseTask(metaclass=ABCMeta):
         """
         pass
 
+    def build_data_preprocessor(self):
+        model = deepcopy(self.model_cfg.model)
+        preprocess_cfg = model['data_preprocessor']
+
+        from mmengine.registry import MODELS
+        data_preprocessor = MODELS.build(preprocess_cfg)
+
+        return data_preprocessor
+
     def build_pytorch_model(self,
                             model_checkpoint: Optional[str] = None,
                             cfg_options: Optional[Dict] = None,
diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py
index e5a5ec5f5..016c4a1ae 100644
--- a/mmdeploy/codebase/mmcls/deploy/classification.py
+++ b/mmdeploy/codebase/mmcls/deploy/classification.py
@@ -113,6 +113,16 @@ class Classification(BaseTask):
         super(Classification, self).__init__(model_cfg, deploy_cfg, device,
                                              experiment_name)
 
+    def build_data_preprocessor(self):
+        model_cfg = deepcopy(self.model_cfg)
+        data_preprocessor = deepcopy(model_cfg.get('preprocess_cfg', {}))
+        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
+
+        from mmengine.registry import MODELS
+        data_preprocessor = MODELS.build(data_preprocessor)
+
+        return data_preprocessor
+
     def build_backend_model(self,
                             model_files: Sequence[str] = None,
                             **kwargs) -> torch.nn.Module:
diff --git a/tools/onnx2ncnn_quant_table.py b/tools/onnx2ncnn_quant_table.py
index e55a5a70d..fb959ecd3 100644
--- a/tools/onnx2ncnn_quant_table.py
+++ b/tools/onnx2ncnn_quant_table.py
@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import logging
+from copy import deepcopy
 
-from mmcv import Config
+from mmengine import Config
 
+from mmdeploy.apis.utils import build_task_processor
 from mmdeploy.utils import get_root_logger, load_config
 
 
@@ -21,27 +23,36 @@ def get_table(onnx_path: str,
     if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
         input_shape = deploy_cfg.onnx_config.input_shape
 
+    task_processor = build_task_processor(model_cfg, deploy_cfg, device)
+    calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
+    calib_dataloader['batch_size'] = 1
     # build calibration dataloader. If img dir not specified, use val dataset.
     if image_dir is not None:
         from quant_image_dataset import QuantizationImageDataset
-        from torch.utils.data import DataLoader
         dataset = QuantizationImageDataset(
             path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
-        dataloader = DataLoader(dataset, batch_size=1)
+        calib_dataloader['dataset'] = dataset
+        dataloader = task_processor.build_dataloader(calib_dataloader)
+        # dataloader = DataLoader(dataset, batch_size=1)
     else:
-        from mmdeploy.apis.utils import build_task_processor
-        task_processor = build_task_processor(model_cfg, deploy_cfg, device)
-        dataset = task_processor.build_dataset(model_cfg, dataset_type)
-        dataloader = task_processor.build_dataloader(dataset, 1, 1)
+        dataset = task_processor.build_dataset(calib_dataloader['dataset'])
+        calib_dataloader['dataset'] = dataset
+        dataloader = task_processor.build_dataloader(calib_dataloader)
+
+    data_preprocessor = task_processor.build_data_preprocessor()
 
     # get an available input shape randomly
     for _, input_data in enumerate(dataloader):
-        if isinstance(input_data['img'], list):
-            input_shape = input_data['img'][0].shape
-            collate_fn = lambda x: x['img'][0].to(device)  # noqa: E731
+        input_data = data_preprocessor(input_data)
+        input_tensor = input_data[0]
+        if isinstance(input_tensor, list):
+            input_shape = input_tensor[0].shape
+            collate_fn = lambda x: data_preprocessor(x[0])[0].to(  # noqa: E731
+                device)
         else:
-            input_shape = input_data['img'].shape
-            collate_fn = lambda x: x['img'].to(device)  # noqa: E731
+            input_shape = input_tensor.shape
+            collate_fn = lambda x: data_preprocessor(x)[0].to(  # noqa: E731
+                device)
         break
 
     from ppq import QuantizationSettingFactory, TargetPlatform
@@ -106,13 +117,9 @@ def main():
     quant_onnx_path = args.out_onnx
     image_dir = args.image_dir
 
-    try:
-        get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
-                  quant_table_path, image_dir)
-        logger.info('onnx2ncnn_quant_table success.')
-    except Exception as e:
-        logger.error(e)
-        logger.error('onnx2ncnn_quant_table failed.')
+    get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
+              quant_table_path, image_dir)
+    logger.info('onnx2ncnn_quant_table success.')
 
 
 if __name__ == '__main__':
diff --git a/tools/quant_image_dataset.py b/tools/quant_image_dataset.py
index 549d24b2b..d6a37f4dd 100644
--- a/tools/quant_image_dataset.py
+++ b/tools/quant_image_dataset.py
@@ -3,9 +3,10 @@ from typing import Optional, Sequence
 
 import mmcv
 from mmcv import FileClient
+from mmengine import Config
 from torch.utils.data import Dataset
 
-from mmdeploy.utils import Codebase, get_codebase
+from mmdeploy.apis import build_task_processor
 
 
 class QuantizationImageDataset(Dataset):
@@ -13,46 +14,16 @@ class QuantizationImageDataset(Dataset):
     def __init__(
         self,
         path: str,
-        deploy_cfg: mmcv.Config,
-        model_cfg: mmcv.Config,
+        deploy_cfg: Config,
+        model_cfg: Config,
         file_client_args: Optional[dict] = None,
         extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
                                      '.pgm', '.tif'),
     ):
         super().__init__()
-        codebase_type = get_codebase(deploy_cfg)
-        self.exclude_pipe = ['LoadImageFromFile']
-        if codebase_type == Codebase.MMCLS:
-            from mmcls.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMDET:
-            from mmdet.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMDET3D:
-            from mmdet3d.datasets.pipelines import Compose
-            self.exclude_pipe.extend([
-                'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
-                'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
-            ])
-        elif codebase_type == Codebase.MMEDIT:
-            from mmedit.datasets.pipelines import Compose
-            self.exclude_pipe.extend(
-                ['LoadImageFromFileList', 'LoadPairedImageFromFile'])
-        elif codebase_type == Codebase.MMOCR:
-            from mmocr.datasets.pipelines import Compose
-            self.exclude_pipe.extend(
-                ['LoadImageFromNdarray', 'LoadImageFromLMDB'])
-        elif codebase_type == Codebase.MMPOSE:
-            from mmpose.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMROTATE:
-            from mmrotate.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMSEG:
-            from mmseg.datasets.pipelines import Compose
-        else:
-            raise Exception(
-                'Not supported codebase_type {}'.format(codebase_type))
-        pipeline = filter(lambda val: val['type'] not in self.exclude_pipe,
-                          model_cfg.data.test.pipeline)
+        task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
+        self.task_processor = task_processor
 
-        self.preprocess = Compose(list(pipeline))
         self.samples = []
         self.extensions = tuple(set([i.lower() for i in extensions]))
         self.file_client = FileClient.infer_client(file_client_args, path)
@@ -77,12 +48,8 @@ class QuantizationImageDataset(Dataset):
     def __getitem__(self, index):
         sample = self.samples[index]
         image = mmcv.imread(sample)
-        data = dict(img=image)
-        data = self.preprocess(data)
-        from mmcv.parallel import collate
-        data = collate([data], samples_per_gpu=1)
-
-        return {'img': data['img'].squeeze()}
+        data = self.task_processor.create_input(image)
+        return data[0]
 
     def is_valid_file(self, filename: str) -> bool:
         """Check if a file is a valid sample."""
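Usage sketch (not part of the patch): a minimal example of how the refactored calibration path is exercised after this change. It only reuses calls introduced in the diff above (build_task_processor, build_dataset, build_dataloader, build_data_preprocessor); the config file paths and the 'val' dataloader key are placeholder assumptions, not files or settings added by this PR.

# Sketch only: drives the new task-processor based calibration flow.
# Paths below are placeholders, not part of this patch.
from copy import deepcopy

from mmengine import Config

from mmdeploy.apis.utils import build_task_processor

deploy_cfg = Config.fromfile('deploy_cfg.py')  # placeholder path
model_cfg = Config.fromfile('model_cfg.py')  # placeholder path
device = 'cpu'

task_processor = build_task_processor(model_cfg, deploy_cfg, device)

# Batch-size-1 calibration dataloader built from the model config,
# mirroring the no-image-dir branch of get_table().
calib_dataloader = deepcopy(model_cfg['val_dataloader'])
calib_dataloader['batch_size'] = 1
calib_dataloader['dataset'] = task_processor.build_dataset(
    calib_dataloader['dataset'])
dataloader = task_processor.build_dataloader(calib_dataloader)

# The data preprocessor replaces the old per-codebase pipeline handling;
# following the patch, index 0 of its output is the batched input.
data_preprocessor = task_processor.build_data_preprocessor()
batch = next(iter(dataloader))
input_tensor = data_preprocessor(batch)[0]
# input_tensor may be a list for some codebases; get_table() handles both.
print(input_tensor.shape)

When --image-dir is given, the same dataloader is instead built around QuantizationImageDataset, whose __getitem__ now simply delegates to task_processor.create_input(image) and returns the first element.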