mmdeploy/tools/quant_image_dataset.py

90 lines
3.3 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
import mmcv
from mmcv import FileClient
from torch.utils.data import Dataset
from mmdeploy.utils import Codebase, get_codebase
class QuantizationImageDataset(Dataset):
    """Image-folder dataset that yields preprocessed calibration samples.

    Every file directly inside ``path`` whose name ends with one of
    ``extensions`` (compared case-insensitively) becomes one sample. Each
    sample is read through the configured file client and decoded. It then
    goes through the model config's test pipeline, with its loading
    transforms removed because the image is decoded here. Finally it is
    collated into a batch of one and returned with that batch dimension
    dropped.

    Args:
        path (str): Directory holding the images. Only its top level is
            scanned (no recursion).
        deploy_cfg (mmcv.Config): Deployment config. Only used to determine
            the codebase, which selects the ``Compose`` implementation.
        model_cfg (mmcv.Config): Model config whose
            ``data.test.pipeline`` is used for preprocessing.
        file_client_args (dict, optional): Arguments forwarded to
            ``FileClient.infer_client``. Defaults to None.
        extensions (Sequence[str]): Accepted file-name suffixes, matched
            case-insensitively.

    Raises:
        Exception: If the codebase from ``deploy_cfg`` is not supported.
        NotADirectoryError: If ``path`` is not a directory according to
            the file client.
    """

    def __init__(
        self,
        path: str,
        deploy_cfg: mmcv.Config,
        model_cfg: mmcv.Config,
        file_client_args: Optional[dict] = None,
        extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
                                     '.pgm', '.tif'),
    ):
        super().__init__()
        codebase_type = get_codebase(deploy_cfg)
        # Loading transforms are removed from the test pipeline because
        # __getitem__ decodes the image itself and passes an in-memory
        # array to the pipeline.
        self.exclude_pipe = ['LoadImageFromFile']
        if codebase_type == Codebase.MMCLS:
            from mmcls.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMDET:
            from mmdet.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMDET3D:
            from mmdet3d.datasets.pipelines import Compose
            self.exclude_pipe.extend([
                'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
                'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
            ])
        elif codebase_type == Codebase.MMEDIT:
            from mmedit.datasets.pipelines import Compose
            self.exclude_pipe.extend(
                ['LoadImageFromFileList', 'LoadPairedImageFromFile'])
        elif codebase_type == Codebase.MMOCR:
            from mmocr.datasets.pipelines import Compose
            self.exclude_pipe.extend(
                ['LoadImageFromNdarray', 'LoadImageFromLMDB'])
        elif codebase_type == Codebase.MMPOSE:
            from mmpose.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMROTATE:
            from mmrotate.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMSEG:
            from mmseg.datasets.pipelines import Compose
        else:
            raise Exception(
                'Not supported codebase_type {}'.format(codebase_type))

        pipeline = [
            transform for transform in model_cfg.data.test.pipeline
            if transform['type'] not in self.exclude_pipe
        ]
        self.preprocess = Compose(pipeline)

        self.extensions = tuple({ext.lower() for ext in extensions})
        self.file_client = FileClient.infer_client(file_client_args, path)
        self.path = path
        # Explicit check rather than ``assert`` so the validation still runs
        # under ``python -O``.
        if not self.file_client.isdir(path):
            raise NotADirectoryError(
                'Expected a directory of images, got {}'.format(path))
        names = self.file_client.list_dir_or_file(
            path,
            list_dir=False,
            list_file=True,
            recursive=False,
        )
        # Sorted so that sample indices do not depend on the listing order
        # returned by the storage backend.
        self.samples = [
            self.file_client.join_path(self.path, name)
            for name in sorted(names) if self.is_valid_file(name)
        ]

    def __len__(self):
        """Return the number of image files found in ``self.path``."""
        return len(self.samples)

    def __getitem__(self, index):
        """Load, preprocess and return the sample at ``index``.

        Returns:
            dict: ``{'img': ...}`` with the preprocessed image. The leading
            batch dimension added by ``collate`` has been removed.
        """
        from mmcv.parallel import collate
        sample = self.samples[index]
        # Read through the file client so that backends configured via
        # ``file_client_args`` work, and not only local disk paths.
        image = mmcv.imfrombytes(self.file_client.get(sample))
        data = self.preprocess(dict(img=image))
        data = collate([data], samples_per_gpu=1)
        # squeeze(0) removes only the batch-of-one dimension. A bare
        # squeeze() would also drop any other size-1 axis, e.g. the channel
        # axis of a single-channel image.
        return {'img': data['img'].squeeze(0)}

    def is_valid_file(self, filename: str) -> bool:
        """Return True if ``filename`` ends with an accepted extension.

        The comparison is case-insensitive: ``filename`` is lower-cased and
        ``self.extensions`` was lower-cased at construction.
        """
        return filename.lower().endswith(self.extensions)