# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import mmcv
from mmcv import FileClient
from torch.utils.data import Dataset

from mmdeploy.utils import Codebase, get_codebase


class QuantizationImageDataset(Dataset):
    """Dataset of preprocessed images read from a single directory.

    The files directly inside ``path`` (no recursion) whose names end with
    one of ``extensions`` become the samples. Each sample is decoded and
    passed through the test pipeline taken from ``model_cfg``, with the
    image-loading transforms removed because the image is loaded here.

    Args:
        path (str): Directory that holds the images.
        deploy_cfg (mmcv.Config): Deployment config; only used to determine
            the codebase, which selects the ``Compose`` implementation.
        model_cfg (mmcv.Config): Model config; ``data.test.pipeline`` is
            read from it.
        file_client_args (dict, optional): Arguments forwarded to
            ``FileClient.infer_client``. Defaults to None.
        extensions (Sequence[str]): Accepted file name suffixes, compared
            case-insensitively.

    Raises:
        Exception: If the codebase of ``deploy_cfg`` is not supported.
        AssertionError: If ``path`` is not a directory for the file client.
    """

    def __init__(
        self,
        path: str,
        deploy_cfg: mmcv.Config,
        model_cfg: mmcv.Config,
        file_client_args: Optional[dict] = None,
        extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp',
                                     '.pgm', '.tif'),
    ):
        super().__init__()
        codebase_type = get_codebase(deploy_cfg)
        # Pipeline steps that load data themselves; they are dropped below
        # because ``__getitem__`` supplies the decoded image directly.
        self.exclude_pipe = ['LoadImageFromFile']
        # The codebase packages are imported lazily so that only the one
        # actually in use has to be installed.
        if codebase_type == Codebase.MMCLS:
            from mmcls.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMDET:
            from mmdet.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMDET3D:
            from mmdet3d.datasets.pipelines import Compose
            self.exclude_pipe.extend([
                'LoadMultiViewImageFromFiles', 'LoadImageFromFileMono3D',
                'LoadPointsFromMultiSweeps', 'LoadPointsFromDict'
            ])
        elif codebase_type == Codebase.MMEDIT:
            from mmedit.datasets.pipelines import Compose
            self.exclude_pipe.extend(
                ['LoadImageFromFileList', 'LoadPairedImageFromFile'])
        elif codebase_type == Codebase.MMOCR:
            from mmocr.datasets.pipelines import Compose
            self.exclude_pipe.extend(
                ['LoadImageFromNdarray', 'LoadImageFromLMDB'])
        elif codebase_type == Codebase.MMPOSE:
            from mmpose.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMROTATE:
            from mmrotate.datasets.pipelines import Compose
        elif codebase_type == Codebase.MMSEG:
            from mmseg.datasets.pipelines import Compose
        else:
            raise Exception(
                'Not supported codebase_type {}'.format(codebase_type))

        pipeline = [
            step for step in model_cfg.data.test.pipeline
            if step['type'] not in self.exclude_pipe
        ]
        self.preprocess = Compose(pipeline)

        # Lower-cased and de-duplicated; a tuple so that it can be handed
        # to ``str.endswith`` as-is.
        self.extensions = tuple({ext.lower() for ext in extensions})
        self.file_client = FileClient.infer_client(file_client_args, path)
        self.path = path
        assert self.file_client.isdir(path), \
            '{} is not a directory'.format(path)

        file_names = self.file_client.list_dir_or_file(
            path,
            list_dir=False,
            list_file=True,
            recursive=False,
        )
        # NOTE(review): the single-argument ``join_path`` call is kept from
        # the original; presumably it only normalizes the name - confirm.
        self.samples = [
            self.file_client.join_path(self.path, file_name)
            for file_name in file_names
            if self.is_valid_file(self.file_client.join_path(file_name))
        ]

    def __len__(self) -> int:
        """Return the number of image files found in the directory."""
        return len(self.samples)

    def __getitem__(self, index: int) -> dict:
        """Load, preprocess and collate the image at ``index``.

        Args:
            index (int): Position of the sample in ``self.samples``.

        Returns:
            dict: A dict with the single key ``'img'`` holding the collated
            and squeezed result of the preprocessing pipeline.
        """
        sample = self.samples[index]
        # Read through the same file client that listed the directory, so
        # the storage backend used for listing is also used for loading.
        image_bytes = self.file_client.get(sample)
        image = mmcv.imfrombytes(image_bytes)
        data = dict(img=image)
        data = self.preprocess(data)
        from mmcv.parallel import collate
        data = collate([data], samples_per_gpu=1)
        return {'img': data['img'].squeeze()}

    def is_valid_file(self, filename: str) -> bool:
        """Check if a file is a valid sample.

        Args:
            filename (str): File name or path to test.

        Returns:
            bool: True if the lower-cased name ends with one of
            ``self.extensions``.
        """
        return filename.lower().endswith(self.extensions)