fix quantization

This commit is contained in:
grimoire 2022-07-11 16:42:32 +08:00
parent d700413d30
commit ce036d547a
4 changed files with 53 additions and 60 deletions

View File

@@ -67,6 +67,15 @@ class BaseTask(metaclass=ABCMeta):
         """
         pass
 
+    def build_data_preprocessor(self):
+        model = deepcopy(self.model_cfg.model)
+        preprocess_cfg = model['data_preprocessor']
+        from mmengine.registry import MODELS
+        data_preprocessor = MODELS.build(preprocess_cfg)
+        return data_preprocessor
+
     def build_pytorch_model(self,
                             model_checkpoint: Optional[str] = None,
                             cfg_options: Optional[Dict] = None,
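For reference, a minimal sketch of how the new `build_data_preprocessor` hook would be called from user code, assuming an existing mmdeploy setup; the two config paths below are placeholders, not files touched by this commit:

```python
from mmengine import Config

from mmdeploy.apis.utils import build_task_processor

# Placeholder configs -- substitute a real deploy config and model config.
deploy_cfg = Config.fromfile('configs/mmcls/classification_ncnn_static.py')
model_cfg = Config.fromfile('resnet18_8xb32_in1k.py')

task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')

# New in this commit: build the model's data preprocessor from its config so
# raw dataloader batches can be normalized/stacked before calibration.
data_preprocessor = task_processor.build_data_preprocessor()
```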

View File

@@ -113,6 +113,16 @@ class Classification(BaseTask):
         super(Classification, self).__init__(model_cfg, deploy_cfg, device,
                                              experiment_name)
 
+    def build_data_preprocessor(self):
+        model_cfg = deepcopy(self.model_cfg)
+        data_preprocessor = deepcopy(model_cfg.get('preprocess_cfg', {}))
+        data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
+        from mmengine.registry import MODELS
+        data_preprocessor = MODELS.build(data_preprocessor)
+        return data_preprocessor
+
     def build_backend_model(self,
                             model_files: Sequence[str] = None,
                             **kwargs) -> torch.nn.Module:
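The classification override tolerates model configs that carry no `preprocess_cfg` by falling back to `mmcls.ClsDataPreprocessor`. A small illustration of that `deepcopy`/`get`/`setdefault` pattern, using a plain dict as a stand-in for the config:

```python
from copy import deepcopy

# Stand-in for a classification model config; values are the usual ImageNet stats.
model_cfg = dict(preprocess_cfg=dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]))

preprocessor_cfg = deepcopy(model_cfg.get('preprocess_cfg', {}))
preprocessor_cfg.setdefault('type', 'mmcls.ClsDataPreprocessor')
# -> mean/std are kept and 'type' is filled in; with no 'preprocess_cfg' at all
#    the result is just {'type': 'mmcls.ClsDataPreprocessor'}, which
#    MODELS.build can still resolve to a default preprocessor.
```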

View File

@@ -1,9 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import argparse
 import logging
+from copy import deepcopy
 
-from mmcv import Config
+from mmengine import Config
 
+from mmdeploy.apis.utils import build_task_processor
 from mmdeploy.utils import get_root_logger, load_config
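The import swap above follows the mmcv-to-mmengine migration; `mmengine.Config` keeps the same basic entry points, so call sites such as the one below continue to work (the path is a placeholder):

```python
from mmengine import Config

# Load a python-style config exactly as mmcv.Config.fromfile used to do.
model_cfg = Config.fromfile('path/to/model_config.py')  # placeholder path
print(model_cfg.get('preprocess_cfg', None))
```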
@@ -21,27 +23,36 @@ def get_table(onnx_path: str,
     if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
         input_shape = deploy_cfg.onnx_config.input_shape
 
+    task_processor = build_task_processor(model_cfg, deploy_cfg, device)
+    calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
+    calib_dataloader['batch_size'] = 1
+
     # build calibration dataloader. If img dir not specified, use val dataset.
     if image_dir is not None:
         from quant_image_dataset import QuantizationImageDataset
-        from torch.utils.data import DataLoader
         dataset = QuantizationImageDataset(
             path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
-        dataloader = DataLoader(dataset, batch_size=1)
+        calib_dataloader['dataset'] = dataset
+        dataloader = task_processor.build_dataloader(calib_dataloader)
+        # dataloader = DataLoader(dataset, batch_size=1)
     else:
-        from mmdeploy.apis.utils import build_task_processor
-        task_processor = build_task_processor(model_cfg, deploy_cfg, device)
-        dataset = task_processor.build_dataset(model_cfg, dataset_type)
-        dataloader = task_processor.build_dataloader(dataset, 1, 1)
+        dataset = task_processor.build_dataset(calib_dataloader['dataset'])
+        calib_dataloader['dataset'] = dataset
+        dataloader = task_processor.build_dataloader(calib_dataloader)
+    data_preprocessor = task_processor.build_data_preprocessor()
 
     # get an available input shape randomly
     for _, input_data in enumerate(dataloader):
-        if isinstance(input_data['img'], list):
-            input_shape = input_data['img'][0].shape
-            collate_fn = lambda x: x['img'][0].to(device)  # noqa: E731
+        input_data = data_preprocessor(input_data)
+        input_tensor = input_data[0]
+        if isinstance(input_tensor, list):
+            input_shape = input_tensor[0].shape
+            collate_fn = lambda x: data_preprocessor(x[0])[0].to(  # noqa: E731
+                device)
         else:
-            input_shape = input_data['img'].shape
-            collate_fn = lambda x: x['img'].to(device)  # noqa: E731
+            input_shape = input_tensor.shape
+            collate_fn = lambda x: data_preprocessor(x)[0].to(  # noqa: E731
+                device)
         break
 
     from ppq import QuantizationSettingFactory, TargetPlatform
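Roughly, the rewritten loop derives the calibration input shape and the `collate_fn` from the preprocessor output instead of a raw `'img'` field. A hedged sketch of the same probe as a standalone function (the name and signature are illustrative, not part of the commit):

```python
def probe_input_shape(dataloader, data_preprocessor, device):
    """Peek at one calibration batch to recover an input shape (sketch only).

    Assumes, like the code above, that the preprocessor returns an indexable
    pair whose first element is the batched input tensor or a list of tensors.
    """
    batch = next(iter(dataloader))        # raw batch from the rebuilt dataloader
    processed = data_preprocessor(batch)  # normalize / stack per the model config
    input_tensor = processed[0]
    if isinstance(input_tensor, list):    # some tasks yield a list of tensors
        input_tensor = input_tensor[0]
    return tuple(input_tensor.to(device).shape)
```

ppq calls the chosen `collate_fn` on every batch it draws from the dataloader, which is why the lambdas above repeat the preprocessing and move the result to `device`.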
@@ -106,13 +117,9 @@ def main():
     quant_onnx_path = args.out_onnx
     image_dir = args.image_dir
 
-    try:
-        get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
-                  quant_table_path, image_dir)
-        logger.info('onnx2ncnn_quant_table success.')
-    except Exception as e:
-        logger.error(e)
-        logger.error('onnx2ncnn_quant_table failed.')
+    get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
+              quant_table_path, image_dir)
+    logger.info('onnx2ncnn_quant_table success.')
 
 
 if __name__ == '__main__':

View File

@@ -3,9 +3,10 @@ from typing import Optional, Sequence
 import mmcv
 from mmcv import FileClient
+from mmengine import Config
 from torch.utils.data import Dataset
 
-from mmdeploy.utils import Codebase, get_codebase
+from mmdeploy.apis import build_task_processor
 
 
 class QuantizationImageDataset(Dataset):
@@ -13,46 +14,16 @@ class QuantizationImageDataset(Dataset):
     def __init__(
         self,
         path: str,
-        deploy_cfg: mmcv.Config,
-        model_cfg: mmcv.Config,
+        deploy_cfg: Config,
+        model_cfg: Config,
         file_client_args: Optional[dict] = None,
         extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
                                      '.pgm', '.tif'),
     ):
         super().__init__()
-        codebase_type = get_codebase(deploy_cfg)
-        self.exclude_pipe = ['LoadImageFromFile']
-        if codebase_type == Codebase.MMCLS:
-            from mmcls.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMDET:
-            from mmdet.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMDET3D:
-            from mmdet3d.datasets.pipelines import Compose
-            self.exclude_pipe.extend([
-                'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
-                'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
-            ])
-        elif codebase_type == Codebase.MMEDIT:
-            from mmedit.datasets.pipelines import Compose
-            self.exclude_pipe.extend(
-                ['LoadImageFromFileList', 'LoadPairedImageFromFile'])
-        elif codebase_type == Codebase.MMOCR:
-            from mmocr.datasets.pipelines import Compose
-            self.exclude_pipe.extend(
-                ['LoadImageFromNdarray', 'LoadImageFromLMDB'])
-        elif codebase_type == Codebase.MMPOSE:
-            from mmpose.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMROTATE:
-            from mmrotate.datasets.pipelines import Compose
-        elif codebase_type == Codebase.MMSEG:
-            from mmseg.datasets.pipelines import Compose
-        else:
-            raise Exception(
-                'Not supported codebase_type {}'.format(codebase_type))
-        pipeline = filter(lambda val: val['type'] not in self.exclude_pipe,
-                          model_cfg.data.test.pipeline)
+        task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
+        self.task_processor = task_processor
-        self.preprocess = Compose(list(pipeline))
         self.samples = []
         self.extensions = tuple(set([i.lower() for i in extensions]))
         self.file_client = FileClient.infer_client(file_client_args, path)
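With the per-codebase `Compose` dispatch gone, constructing the dataset only needs the two configs and works the same way for any codebase that `build_task_processor` supports. A usage sketch with placeholder paths:

```python
from mmengine import Config

from quant_image_dataset import QuantizationImageDataset

# Placeholder configs -- any supported codebase (mmcls, mmdet, ...) now goes
# through the same task-processor path.
deploy_cfg = Config.fromfile('configs/mmcls/classification_ncnn_static.py')
model_cfg = Config.fromfile('resnet18_8xb32_in1k.py')

dataset = QuantizationImageDataset(
    path='data/calib_images',  # placeholder directory of calibration images
    deploy_cfg=deploy_cfg,
    model_cfg=model_cfg)
```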
@@ -77,12 +48,8 @@ class QuantizationImageDataset(Dataset):
     def __getitem__(self, index):
         sample = self.samples[index]
         image = mmcv.imread(sample)
-        data = dict(img=image)
-        data = self.preprocess(data)
-        from mmcv.parallel import collate
-        data = collate([data], samples_per_gpu=1)
-        return {'img': data['img'].squeeze()}
+        data = self.task_processor.create_input(image)
+        return data[0]
 
     def is_valid_file(self, filename: str) -> bool:
         """Check if a file is a valid sample."""