mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix quanti
This commit is contained in:
parent
d700413d30
commit
ce036d547a
@ -67,6 +67,15 @@ class BaseTask(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def build_data_preprocessor(self):
|
||||
model = deepcopy(self.model_cfg.model)
|
||||
preprocess_cfg = model['data_preprocessor']
|
||||
|
||||
from mmengine.registry import MODELS
|
||||
data_preprocessor = MODELS.build(preprocess_cfg)
|
||||
|
||||
return data_preprocessor
|
||||
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
|
@ -113,6 +113,16 @@ class Classification(BaseTask):
|
||||
super(Classification, self).__init__(model_cfg, deploy_cfg, device,
|
||||
experiment_name)
|
||||
|
||||
def build_data_preprocessor(self):
|
||||
model_cfg = deepcopy(self.model_cfg)
|
||||
data_preprocessor = deepcopy(model_cfg.get('preprocess_cfg', {}))
|
||||
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
|
||||
|
||||
from mmengine.registry import MODELS
|
||||
data_preprocessor = MODELS.build(data_preprocessor)
|
||||
|
||||
return data_preprocessor
|
||||
|
||||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
|
@ -1,9 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
from mmcv import Config
|
||||
from mmengine import Config
|
||||
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
from mmdeploy.utils import get_root_logger, load_config
|
||||
|
||||
|
||||
@ -21,27 +23,36 @@ def get_table(onnx_path: str,
|
||||
if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
|
||||
input_shape = deploy_cfg.onnx_config.input_shape
|
||||
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
|
||||
calib_dataloader['batch_size'] = 1
|
||||
# build calibration dataloader. If img dir not specified, use val dataset.
|
||||
if image_dir is not None:
|
||||
from quant_image_dataset import QuantizationImageDataset
|
||||
from torch.utils.data import DataLoader
|
||||
dataset = QuantizationImageDataset(
|
||||
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
|
||||
dataloader = DataLoader(dataset, batch_size=1)
|
||||
calib_dataloader['dataset'] = dataset
|
||||
dataloader = task_processor.build_dataloader(calib_dataloader)
|
||||
# dataloader = DataLoader(dataset, batch_size=1)
|
||||
else:
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
dataset = task_processor.build_dataset(model_cfg, dataset_type)
|
||||
dataloader = task_processor.build_dataloader(dataset, 1, 1)
|
||||
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
|
||||
calib_dataloader['dataset'] = dataset
|
||||
dataloader = task_processor.build_dataloader(calib_dataloader)
|
||||
|
||||
data_preprocessor = task_processor.build_data_preprocessor()
|
||||
|
||||
# get an available input shape randomly
|
||||
for _, input_data in enumerate(dataloader):
|
||||
if isinstance(input_data['img'], list):
|
||||
input_shape = input_data['img'][0].shape
|
||||
collate_fn = lambda x: x['img'][0].to(device) # noqa: E731
|
||||
input_data = data_preprocessor(input_data)
|
||||
input_tensor = input_data[0]
|
||||
if isinstance(input_tensor, list):
|
||||
input_shape = input_tensor[0].shape
|
||||
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
|
||||
device)
|
||||
else:
|
||||
input_shape = input_data['img'].shape
|
||||
collate_fn = lambda x: x['img'].to(device) # noqa: E731
|
||||
input_shape = input_tensor.shape
|
||||
collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
|
||||
device)
|
||||
break
|
||||
|
||||
from ppq import QuantizationSettingFactory, TargetPlatform
|
||||
@ -106,13 +117,9 @@ def main():
|
||||
quant_onnx_path = args.out_onnx
|
||||
image_dir = args.image_dir
|
||||
|
||||
try:
|
||||
get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
|
||||
quant_table_path, image_dir)
|
||||
logger.info('onnx2ncnn_quant_table success.')
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error('onnx2ncnn_quant_table failed.')
|
||||
get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
|
||||
quant_table_path, image_dir)
|
||||
logger.info('onnx2ncnn_quant_table success.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -3,9 +3,10 @@ from typing import Optional, Sequence
|
||||
|
||||
import mmcv
|
||||
from mmcv import FileClient
|
||||
from mmengine import Config
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmdeploy.utils import Codebase, get_codebase
|
||||
from mmdeploy.apis import build_task_processor
|
||||
|
||||
|
||||
class QuantizationImageDataset(Dataset):
|
||||
@ -13,46 +14,16 @@ class QuantizationImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
deploy_cfg: mmcv.Config,
|
||||
model_cfg: mmcv.Config,
|
||||
deploy_cfg: Config,
|
||||
model_cfg: Config,
|
||||
file_client_args: Optional[dict] = None,
|
||||
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
|
||||
'.pgm', '.tif'),
|
||||
):
|
||||
super().__init__()
|
||||
codebase_type = get_codebase(deploy_cfg)
|
||||
self.exclude_pipe = ['LoadImageFromFile']
|
||||
if codebase_type == Codebase.MMCLS:
|
||||
from mmcls.datasets.pipelines import Compose
|
||||
elif codebase_type == Codebase.MMDET:
|
||||
from mmdet.datasets.pipelines import Compose
|
||||
elif codebase_type == Codebase.MMDET3D:
|
||||
from mmdet3d.datasets.pipelines import Compose
|
||||
self.exclude_pipe.extend([
|
||||
'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
|
||||
'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
|
||||
])
|
||||
elif codebase_type == Codebase.MMEDIT:
|
||||
from mmedit.datasets.pipelines import Compose
|
||||
self.exclude_pipe.extend(
|
||||
['LoadImageFromFileList', 'LoadPairedImageFromFile'])
|
||||
elif codebase_type == Codebase.MMOCR:
|
||||
from mmocr.datasets.pipelines import Compose
|
||||
self.exclude_pipe.extend(
|
||||
['LoadImageFromNdarray', 'LoadImageFromLMDB'])
|
||||
elif codebase_type == Codebase.MMPOSE:
|
||||
from mmpose.datasets.pipelines import Compose
|
||||
elif codebase_type == Codebase.MMROTATE:
|
||||
from mmrotate.datasets.pipelines import Compose
|
||||
elif codebase_type == Codebase.MMSEG:
|
||||
from mmseg.datasets.pipelines import Compose
|
||||
else:
|
||||
raise Exception(
|
||||
'Not supported codebase_type {}'.format(codebase_type))
|
||||
pipeline = filter(lambda val: val['type'] not in self.exclude_pipe,
|
||||
model_cfg.data.test.pipeline)
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
|
||||
self.task_processor = task_processor
|
||||
|
||||
self.preprocess = Compose(list(pipeline))
|
||||
self.samples = []
|
||||
self.extensions = tuple(set([i.lower() for i in extensions]))
|
||||
self.file_client = FileClient.infer_client(file_client_args, path)
|
||||
@ -77,12 +48,8 @@ class QuantizationImageDataset(Dataset):
|
||||
def __getitem__(self, index):
|
||||
sample = self.samples[index]
|
||||
image = mmcv.imread(sample)
|
||||
data = dict(img=image)
|
||||
data = self.preprocess(data)
|
||||
from mmcv.parallel import collate
|
||||
data = collate([data], samples_per_gpu=1)
|
||||
|
||||
return {'img': data['img'].squeeze()}
|
||||
data = self.task_processor.create_input(image)
|
||||
return data[0]
|
||||
|
||||
def is_valid_file(self, filename: str) -> bool:
|
||||
"""Check if a file is a valid sample."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user