mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix quanti
This commit is contained in:
parent
d700413d30
commit
ce036d547a
@ -67,6 +67,15 @@ class BaseTask(metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def build_data_preprocessor(self):
|
||||||
|
model = deepcopy(self.model_cfg.model)
|
||||||
|
preprocess_cfg = model['data_preprocessor']
|
||||||
|
|
||||||
|
from mmengine.registry import MODELS
|
||||||
|
data_preprocessor = MODELS.build(preprocess_cfg)
|
||||||
|
|
||||||
|
return data_preprocessor
|
||||||
|
|
||||||
def build_pytorch_model(self,
|
def build_pytorch_model(self,
|
||||||
model_checkpoint: Optional[str] = None,
|
model_checkpoint: Optional[str] = None,
|
||||||
cfg_options: Optional[Dict] = None,
|
cfg_options: Optional[Dict] = None,
|
||||||
|
@ -113,6 +113,16 @@ class Classification(BaseTask):
|
|||||||
super(Classification, self).__init__(model_cfg, deploy_cfg, device,
|
super(Classification, self).__init__(model_cfg, deploy_cfg, device,
|
||||||
experiment_name)
|
experiment_name)
|
||||||
|
|
||||||
|
def build_data_preprocessor(self):
|
||||||
|
model_cfg = deepcopy(self.model_cfg)
|
||||||
|
data_preprocessor = deepcopy(model_cfg.get('preprocess_cfg', {}))
|
||||||
|
data_preprocessor.setdefault('type', 'mmcls.ClsDataPreprocessor')
|
||||||
|
|
||||||
|
from mmengine.registry import MODELS
|
||||||
|
data_preprocessor = MODELS.build(data_preprocessor)
|
||||||
|
|
||||||
|
return data_preprocessor
|
||||||
|
|
||||||
def build_backend_model(self,
|
def build_backend_model(self,
|
||||||
model_files: Sequence[str] = None,
|
model_files: Sequence[str] = None,
|
||||||
**kwargs) -> torch.nn.Module:
|
**kwargs) -> torch.nn.Module:
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
from mmcv import Config
|
from mmengine import Config
|
||||||
|
|
||||||
|
from mmdeploy.apis.utils import build_task_processor
|
||||||
from mmdeploy.utils import get_root_logger, load_config
|
from mmdeploy.utils import get_root_logger, load_config
|
||||||
|
|
||||||
|
|
||||||
@ -21,27 +23,36 @@ def get_table(onnx_path: str,
|
|||||||
if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
|
if 'onnx_config' in deploy_cfg and 'input_shape' in deploy_cfg.onnx_config:
|
||||||
input_shape = deploy_cfg.onnx_config.input_shape
|
input_shape = deploy_cfg.onnx_config.input_shape
|
||||||
|
|
||||||
|
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||||
|
calib_dataloader = deepcopy(model_cfg[f'{dataset_type}_dataloader'])
|
||||||
|
calib_dataloader['batch_size'] = 1
|
||||||
# build calibration dataloader. If img dir not specified, use val dataset.
|
# build calibration dataloader. If img dir not specified, use val dataset.
|
||||||
if image_dir is not None:
|
if image_dir is not None:
|
||||||
from quant_image_dataset import QuantizationImageDataset
|
from quant_image_dataset import QuantizationImageDataset
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
dataset = QuantizationImageDataset(
|
dataset = QuantizationImageDataset(
|
||||||
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
|
path=image_dir, deploy_cfg=deploy_cfg, model_cfg=model_cfg)
|
||||||
dataloader = DataLoader(dataset, batch_size=1)
|
calib_dataloader['dataset'] = dataset
|
||||||
|
dataloader = task_processor.build_dataloader(calib_dataloader)
|
||||||
|
# dataloader = DataLoader(dataset, batch_size=1)
|
||||||
else:
|
else:
|
||||||
from mmdeploy.apis.utils import build_task_processor
|
dataset = task_processor.build_dataset(calib_dataloader['dataset'])
|
||||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
calib_dataloader['dataset'] = dataset
|
||||||
dataset = task_processor.build_dataset(model_cfg, dataset_type)
|
dataloader = task_processor.build_dataloader(calib_dataloader)
|
||||||
dataloader = task_processor.build_dataloader(dataset, 1, 1)
|
|
||||||
|
data_preprocessor = task_processor.build_data_preprocessor()
|
||||||
|
|
||||||
# get an available input shape randomly
|
# get an available input shape randomly
|
||||||
for _, input_data in enumerate(dataloader):
|
for _, input_data in enumerate(dataloader):
|
||||||
if isinstance(input_data['img'], list):
|
input_data = data_preprocessor(input_data)
|
||||||
input_shape = input_data['img'][0].shape
|
input_tensor = input_data[0]
|
||||||
collate_fn = lambda x: x['img'][0].to(device) # noqa: E731
|
if isinstance(input_tensor, list):
|
||||||
|
input_shape = input_tensor[0].shape
|
||||||
|
collate_fn = lambda x: data_preprocessor(x[0])[0].to( # noqa: E731
|
||||||
|
device)
|
||||||
else:
|
else:
|
||||||
input_shape = input_data['img'].shape
|
input_shape = input_tensor.shape
|
||||||
collate_fn = lambda x: x['img'].to(device) # noqa: E731
|
collate_fn = lambda x: data_preprocessor(x)[0].to( # noqa: E731
|
||||||
|
device)
|
||||||
break
|
break
|
||||||
|
|
||||||
from ppq import QuantizationSettingFactory, TargetPlatform
|
from ppq import QuantizationSettingFactory, TargetPlatform
|
||||||
@ -106,13 +117,9 @@ def main():
|
|||||||
quant_onnx_path = args.out_onnx
|
quant_onnx_path = args.out_onnx
|
||||||
image_dir = args.image_dir
|
image_dir = args.image_dir
|
||||||
|
|
||||||
try:
|
|
||||||
get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
|
get_table(onnx_path, deploy_cfg, model_cfg, quant_onnx_path,
|
||||||
quant_table_path, image_dir)
|
quant_table_path, image_dir)
|
||||||
logger.info('onnx2ncnn_quant_table success.')
|
logger.info('onnx2ncnn_quant_table success.')
|
||||||
except Exception as e:
|
|
||||||
logger.error(e)
|
|
||||||
logger.error('onnx2ncnn_quant_table failed.')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -3,9 +3,10 @@ from typing import Optional, Sequence
|
|||||||
|
|
||||||
import mmcv
|
import mmcv
|
||||||
from mmcv import FileClient
|
from mmcv import FileClient
|
||||||
|
from mmengine import Config
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from mmdeploy.utils import Codebase, get_codebase
|
from mmdeploy.apis import build_task_processor
|
||||||
|
|
||||||
|
|
||||||
class QuantizationImageDataset(Dataset):
|
class QuantizationImageDataset(Dataset):
|
||||||
@ -13,46 +14,16 @@ class QuantizationImageDataset(Dataset):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: str,
|
||||||
deploy_cfg: mmcv.Config,
|
deploy_cfg: Config,
|
||||||
model_cfg: mmcv.Config,
|
model_cfg: Config,
|
||||||
file_client_args: Optional[dict] = None,
|
file_client_args: Optional[dict] = None,
|
||||||
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
|
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
|
||||||
'.pgm', '.tif'),
|
'.pgm', '.tif'),
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
codebase_type = get_codebase(deploy_cfg)
|
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
|
||||||
self.exclude_pipe = ['LoadImageFromFile']
|
self.task_processor = task_processor
|
||||||
if codebase_type == Codebase.MMCLS:
|
|
||||||
from mmcls.datasets.pipelines import Compose
|
|
||||||
elif codebase_type == Codebase.MMDET:
|
|
||||||
from mmdet.datasets.pipelines import Compose
|
|
||||||
elif codebase_type == Codebase.MMDET3D:
|
|
||||||
from mmdet3d.datasets.pipelines import Compose
|
|
||||||
self.exclude_pipe.extend([
|
|
||||||
'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
|
|
||||||
'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
|
|
||||||
])
|
|
||||||
elif codebase_type == Codebase.MMEDIT:
|
|
||||||
from mmedit.datasets.pipelines import Compose
|
|
||||||
self.exclude_pipe.extend(
|
|
||||||
['LoadImageFromFileList', 'LoadPairedImageFromFile'])
|
|
||||||
elif codebase_type == Codebase.MMOCR:
|
|
||||||
from mmocr.datasets.pipelines import Compose
|
|
||||||
self.exclude_pipe.extend(
|
|
||||||
['LoadImageFromNdarray', 'LoadImageFromLMDB'])
|
|
||||||
elif codebase_type == Codebase.MMPOSE:
|
|
||||||
from mmpose.datasets.pipelines import Compose
|
|
||||||
elif codebase_type == Codebase.MMROTATE:
|
|
||||||
from mmrotate.datasets.pipelines import Compose
|
|
||||||
elif codebase_type == Codebase.MMSEG:
|
|
||||||
from mmseg.datasets.pipelines import Compose
|
|
||||||
else:
|
|
||||||
raise Exception(
|
|
||||||
'Not supported codebase_type {}'.format(codebase_type))
|
|
||||||
pipeline = filter(lambda val: val['type'] not in self.exclude_pipe,
|
|
||||||
model_cfg.data.test.pipeline)
|
|
||||||
|
|
||||||
self.preprocess = Compose(list(pipeline))
|
|
||||||
self.samples = []
|
self.samples = []
|
||||||
self.extensions = tuple(set([i.lower() for i in extensions]))
|
self.extensions = tuple(set([i.lower() for i in extensions]))
|
||||||
self.file_client = FileClient.infer_client(file_client_args, path)
|
self.file_client = FileClient.infer_client(file_client_args, path)
|
||||||
@ -77,12 +48,8 @@ class QuantizationImageDataset(Dataset):
|
|||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
sample = self.samples[index]
|
sample = self.samples[index]
|
||||||
image = mmcv.imread(sample)
|
image = mmcv.imread(sample)
|
||||||
data = dict(img=image)
|
data = self.task_processor.create_input(image)
|
||||||
data = self.preprocess(data)
|
return data[0]
|
||||||
from mmcv.parallel import collate
|
|
||||||
data = collate([data], samples_per_gpu=1)
|
|
||||||
|
|
||||||
return {'img': data['img'].squeeze()}
|
|
||||||
|
|
||||||
def is_valid_file(self, filename: str) -> bool:
|
def is_valid_file(self, filename: str) -> bool:
|
||||||
"""Check if a file is a valid sample."""
|
"""Check if a file is a valid sample."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user