[Fix] Fix dataloader for ncnn quantization (#2018)

* fix mmpose create_input when feed ndarray

* fix dataloader for ncnn-ppq
pull/2070/head
Chen Xin 2023-05-08 15:03:25 +08:00 committed by GitHub
parent 2d85be9aa6
commit 49103cb72e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 22 deletions

View File

@ -197,15 +197,15 @@ class PoseDetection(BaseTask):
from mmcv.transforms import Compose
from mmpose.registry import TRANSFORMS
cfg = self.model_cfg
if isinstance(imgs, str):
imgs = [mmcv.imread(imgs)]
elif isinstance(imgs, (list, tuple)):
img_data = []
for img in imgs:
if isinstance(img, str):
img_data.append(mmcv.imread(img))
else:
img_data.append(img)
if isinstance(imgs, (list, tuple)):
if not isinstance(imgs[0], (np.ndarray, str)):
raise AssertionError('imgs must be strings or numpy arrays')
elif isinstance(imgs, (np.ndarray, str)):
imgs = [imgs]
else:
raise AssertionError('imgs must be strings or numpy arrays')
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
img_data = [mmcv.imread(img) for img in imgs]
imgs = img_data
person_results = []
bboxes = []

View File

@ -4,6 +4,7 @@ import logging
from copy import deepcopy
from mmengine import Config
from torch.utils.data import DataLoader
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_root_logger, load_config
@ -31,9 +32,11 @@ def get_table(onnx_path: str,
from quant_image_dataset import QuantizationImageDataset
dataset = QuantizationImageDataset(
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
calib_dataloader['dataset'] = dataset
dataloader = task_processor.build_dataloader(calib_dataloader)
# dataloader = DataLoader(dataset, batch_size=1)
def collate(data_batch):
return data_batch[0]
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate)
else:
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
calib_dataloader['dataset'] = dataset
@ -44,16 +47,10 @@ def get_table(onnx_path: str,
# get an available input shape randomly
for _, input_data in enumerate(dataloader):
input_data = data_preprocessor(input_data)
input_tensor = input_data[0]
if isinstance(input_tensor, list):
input_shape = input_tensor[0].shape
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
device)
else:
input_shape = input_tensor.shape
collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
device)
break
input_tensor = input_data['inputs']
input_shape = input_tensor.shape
collate_fn = lambda x: data_preprocessor(x)['inputs'].to( # noqa: E731
device)
from ppq import QuantizationSettingFactory, TargetPlatform
from ppq.api import export_ppq_graph, quantize_onnx_model