[Fix] Fix dataloader for ncnn quantization (#2018)
* fix mmpose create_input when feed ndarray * fix dataloader for ncnn-ppqpull/2070/head
parent
2d85be9aa6
commit
49103cb72e
|
@ -197,15 +197,15 @@ class PoseDetection(BaseTask):
|
|||
from mmcv.transforms import Compose
|
||||
from mmpose.registry import TRANSFORMS
|
||||
cfg = self.model_cfg
|
||||
if isinstance(imgs, str):
|
||||
imgs = [mmcv.imread(imgs)]
|
||||
elif isinstance(imgs, (list, tuple)):
|
||||
img_data = []
|
||||
for img in imgs:
|
||||
if isinstance(img, str):
|
||||
img_data.append(mmcv.imread(img))
|
||||
else:
|
||||
img_data.append(img)
|
||||
if isinstance(imgs, (list, tuple)):
|
||||
if not isinstance(imgs[0], (np.ndarray, str)):
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
elif isinstance(imgs, (np.ndarray, str)):
|
||||
imgs = [imgs]
|
||||
else:
|
||||
raise AssertionError('imgs must be strings or numpy arrays')
|
||||
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
|
||||
img_data = [mmcv.imread(img) for img in imgs]
|
||||
imgs = img_data
|
||||
person_results = []
|
||||
bboxes = []
|
||||
|
|
|
@ -4,6 +4,7 @@ import logging
|
|||
from copy import deepcopy
|
||||
|
||||
from mmengine import Config
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
from mmdeploy.utils import get_root_logger, load_config
|
||||
|
@ -31,9 +32,11 @@ def get_table(onnx_path: str,
|
|||
from quant_image_dataset import QuantizationImageDataset
|
||||
dataset = QuantizationImageDataset(
|
||||
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
|
||||
calib_dataloader['dataset'] = dataset
|
||||
dataloader = task_processor.build_dataloader(calib_dataloader)
|
||||
# dataloader = DataLoader(dataset, batch_size=1)
|
||||
|
||||
def collate(data_batch):
|
||||
return data_batch[0]
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate)
|
||||
else:
|
||||
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
|
||||
calib_dataloader['dataset'] = dataset
|
||||
|
@ -44,16 +47,10 @@ def get_table(onnx_path: str,
|
|||
# get an available input shape randomly
|
||||
for _, input_data in enumerate(dataloader):
|
||||
input_data = data_preprocessor(input_data)
|
||||
input_tensor = input_data[0]
|
||||
if isinstance(input_tensor, list):
|
||||
input_shape = input_tensor[0].shape
|
||||
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
|
||||
device)
|
||||
else:
|
||||
input_shape = input_tensor.shape
|
||||
collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
|
||||
device)
|
||||
break
|
||||
input_tensor = input_data['inputs']
|
||||
input_shape = input_tensor.shape
|
||||
collate_fn = lambda x: data_preprocessor(x)['inputs'].to( # noqa: E731
|
||||
device)
|
||||
|
||||
from ppq import QuantizationSettingFactory, TargetPlatform
|
||||
from ppq.api import export_ppq_graph, quantize_onnx_model
|
||||
|
|
Loading…
Reference in New Issue