diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index 07f817eaf..f3b5b6ded 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -197,15 +197,15 @@ class PoseDetection(BaseTask): from mmcv.transforms import Compose from mmpose.registry import TRANSFORMS cfg = self.model_cfg - if isinstance(imgs, str): - imgs = [mmcv.imread(imgs)] - elif isinstance(imgs, (list, tuple)): - img_data = [] - for img in imgs: - if isinstance(img, str): - img_data.append(mmcv.imread(img)) - else: - img_data.append(img) + if isinstance(imgs, (list, tuple)): + if not isinstance(imgs[0], (np.ndarray, str)): + raise AssertionError('imgs must be strings or numpy arrays') + elif isinstance(imgs, (np.ndarray, str)): + imgs = [imgs] + else: + raise AssertionError('imgs must be strings or numpy arrays') + if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str): + img_data = [mmcv.imread(img) for img in imgs] imgs = img_data person_results = [] bboxes = [] diff --git a/tools/onnx2ncnn_quant_table.py b/tools/onnx2ncnn_quant_table.py index fb959ecd3..a3c3dbd99 100644 --- a/tools/onnx2ncnn_quant_table.py +++ b/tools/onnx2ncnn_quant_table.py @@ -4,6 +4,7 @@ import logging from copy import deepcopy from mmengine import Config +from torch.utils.data import DataLoader from mmdeploy.apis.utils import build_task_processor from mmdeploy.utils import get_root_logger, load_config @@ -31,9 +32,11 @@ def get_table(onnx_path: str, from quant_image_dataset import QuantizationImageDataset dataset = QuantizationImageDataset( path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg) - calib_dataloader['dataset'] = dataset - dataloader = task_processor.build_dataloader(calib_dataloader) - # dataloader = DataLoader(dataset, batch_size=1) + + def collate(data_batch): + return data_batch[0] + + dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate) else: dataset = task_processor.build_dataset(calib_dataloader['dataset']) calib_dataloader['dataset'] = dataset @@ -44,16 +47,10 @@ def get_table(onnx_path: str, # get an available input shape randomly for _, input_data in enumerate(dataloader): input_data = data_preprocessor(input_data) - input_tensor = input_data[0] - if isinstance(input_tensor, list): - input_shape = input_tensor[0].shape - collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731 - device) - else: - input_shape = input_tensor.shape - collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731 - device) - break + input_tensor = input_data['inputs'] + input_shape = input_tensor.shape + collate_fn = lambda x: data_preprocessor(x)['inputs'].to( # noqa: E731 + device) from ppq import QuantizationSettingFactory, TargetPlatform from ppq.api import export_ppq_graph, quantize_onnx_model