fix quantization

This commit is contained in:
grimoire 2022-07-11 16:42:32 +08:00
parent d700413d30
commit ce036d547a
4 changed files with 53 additions and 60 deletions

View File

@ -67,6 +67,15 @@ class BaseTask(metaclass=ABCMeta):
""" """
pass pass
def build_data_preprocessor(self):
    """Build the data preprocessor declared in the model config.

    Returns:
        The module built from ``model_cfg.model['data_preprocessor']``
        through the mmengine ``MODELS`` registry (presumably a
        ``torch.nn.Module`` — confirm against the registered class).

    Raises:
        KeyError: if the model config has no ``data_preprocessor`` key.
    """
    from mmengine.registry import MODELS

    # Copy only the preprocessor sub-config rather than deep-copying the
    # whole model config: the rest of the model dict is never used here.
    preprocess_cfg = deepcopy(self.model_cfg.model['data_preprocessor'])
    return MODELS.build(preprocess_cfg)
def build_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,

View File

@ -113,6 +113,16 @@ class Classification(BaseTask):
super(Classification, self).__init__(model_cfg, deploy_cfg, device, super(Classification, self).__init__(model_cfg, deploy_cfg, device,
experiment_name) experiment_name)
def build_data_preprocessor(self):
    """Create the data preprocessor for classification models.

    Reads ``preprocess_cfg`` from the model config (empty dict when
    absent), defaults its ``type`` to ``mmcls.ClsDataPreprocessor``,
    and builds it through the mmengine ``MODELS`` registry.
    """
    from mmengine.registry import MODELS

    # Work on a private copy so the stored model config is never mutated.
    preprocessor_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {}))
    if 'type' not in preprocessor_cfg:
        preprocessor_cfg['type'] = 'mmcls.ClsDataPreprocessor'
    return MODELS.build(preprocessor_cfg)
def build_backend_model(self, def build_backend_model(self,
model_files: Sequence[str] = None, model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -1,9 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import argparse import argparse
import logging import logging
from copy import deepcopy
from mmcv import Config from mmengine import Config
from mmdeploy.apis.utils import build_task_processor
from mmdeploy.utils import get_root_logger, load_config from mmdeploy.utils import get_root_logger, load_config
@ -21,27 +23,36 @@ def get_table(onnx_path: str,
if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config: if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
input_shape = deploy_cfg.onnx_config.input_shape input_shape = deploy_cfg.onnx_config.input_shape
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
calib_dataloader['batch_size'] = 1
# build calibration dataloader. If img dir not specified, use val dataset. # build calibration dataloader. If img dir not specified, use val dataset.
if image_dir is not None: if image_dir is not None:
from quant_image_dataset import QuantizationImageDataset from quant_image_dataset import QuantizationImageDataset
from torch.utils.data import DataLoader
dataset = QuantizationImageDataset( dataset = QuantizationImageDataset(
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg) path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
dataloader = DataLoader(dataset, batch_size=1) calib_dataloader['dataset'] = dataset
dataloader = task_processor.build_dataloader(calib_dataloader)
# dataloader = DataLoader(dataset, batch_size=1)
else: else:
from mmdeploy.apis.utils import build_task_processor dataset = task_processor.build_dataset(calib_dataloader['dataset'])
task_processor = build_task_processor(model_cfg, deploy_cfg, device) calib_dataloader['dataset'] = dataset
dataset = task_processor.build_dataset(model_cfg, dataset_type) dataloader = task_processor.build_dataloader(calib_dataloader)
dataloader = task_processor.build_dataloader(dataset, 1, 1)
data_preprocessor = task_processor.build_data_preprocessor()
# get an available input shape randomly # get an available input shape randomly
for _, input_data in enumerate(dataloader): for _, input_data in enumerate(dataloader):
if isinstance(input_data['img'], list): input_data = data_preprocessor(input_data)
input_shape = input_data['img'][0].shape input_tensor = input_data[0]
collate_fn = lambda x: x['img'][0].to(device) # noqa: E731 if isinstance(input_tensor, list):
input_shape = input_tensor[0].shape
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
device)
else: else:
input_shape = input_data['img'].shape input_shape = input_tensor.shape
collate_fn = lambda x: x['img'].to(device) # noqa: E731 collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
device)
break break
from ppq import QuantizationSettingFactory, TargetPlatform from ppq import QuantizationSettingFactory, TargetPlatform
@ -106,13 +117,9 @@ def main():
quant_onnx_path = args.out_onnx quant_onnx_path = args.out_onnx
image_dir = args.image_dir image_dir = args.image_dir
try:
get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path, get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
quant_table_path, image_dir) quant_table_path, image_dir)
logger.info('onnx2ncnn_quant_table success.') logger.info('onnx2ncnn_quant_table success.')
except Exception as e:
logger.error(e)
logger.error('onnx2ncnn_quant_table failed.')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -3,9 +3,10 @@ from typing import Optional, Sequence
import mmcv import mmcv
from mmcv import FileClient from mmcv import FileClient
from mmengine import Config
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mmdeploy.utils import Codebase, get_codebase from mmdeploy.apis import build_task_processor
class QuantizationImageDataset(Dataset): class QuantizationImageDataset(Dataset):
@ -13,46 +14,16 @@ class QuantizationImageDataset(Dataset):
def __init__( def __init__(
self, self,
path: str, path: str,
deploy_cfg: mmcv.Config, deploy_cfg: Config,
model_cfg: mmcv.Config, model_cfg: Config,
file_client_args: Optional[dict] = None, file_client_args: Optional[dict] = None,
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
'.pgm', '.tif'), '.pgm', '.tif'),
): ):
super().__init__() super().__init__()
codebase_type = get_codebase(deploy_cfg) task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
self.exclude_pipe = ['LoadImageFromFile'] self.task_processor = task_processor
if codebase_type == Codebase.MMCLS:
from mmcls.datasets.pipelines import Compose
elif codebase_type == Codebase.MMDET:
from mmdet.datasets.pipelines import Compose
elif codebase_type == Codebase.MMDET3D:
from mmdet3d.datasets.pipelines import Compose
self.exclude_pipe.extend([
'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
])
elif codebase_type == Codebase.MMEDIT:
from mmedit.datasets.pipelines import Compose
self.exclude_pipe.extend(
['LoadImageFromFileList', 'LoadPairedImageFromFile'])
elif codebase_type == Codebase.MMOCR:
from mmocr.datasets.pipelines import Compose
self.exclude_pipe.extend(
['LoadImageFromNdarray', 'LoadImageFromLMDB'])
elif codebase_type == Codebase.MMPOSE:
from mmpose.datasets.pipelines import Compose
elif codebase_type == Codebase.MMROTATE:
from mmrotate.datasets.pipelines import Compose
elif codebase_type == Codebase.MMSEG:
from mmseg.datasets.pipelines import Compose
else:
raise Exception(
'Not supported codebase_type {}'.format(codebase_type))
pipeline = filter(lambda val: val['type'] not in self.exclude_pipe,
model_cfg.data.test.pipeline)
self.preprocess = Compose(list(pipeline))
self.samples = [] self.samples = []
self.extensions = tuple(set([i.lower() for i in extensions])) self.extensions = tuple(set([i.lower() for i in extensions]))
self.file_client = FileClient.infer_client(file_client_args, path) self.file_client = FileClient.infer_client(file_client_args, path)
@ -77,12 +48,8 @@ class QuantizationImageDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
sample = self.samples[index] sample = self.samples[index]
image = mmcv.imread(sample) image = mmcv.imread(sample)
data = dict(img=image) data = self.task_processor.create_input(image)
data = self.preprocess(data) return data[0]
from mmcv.parallel import collate
data = collate([data], samples_per_gpu=1)
return {'img': data['img'].squeeze()}
def is_valid_file(self, filename: str) -> bool: def is_valid_file(self, filename: str) -> bool:
"""Check if a file is a valid sample.""" """Check if a file is a valid sample."""