From 940a06f645a9b62632a411e7bff1f6dedcb0e9cb Mon Sep 17 00:00:00 2001 From: Ma Zerun Date: Mon, 21 Nov 2022 10:56:35 +0800 Subject: [PATCH] [Refactor] Refactor to use new fileio API in MMEngine. (#1176) * [Refactor] Refactor to use new fileio API in MMEngine. * Add comment about why use `backend` --- mmcls/datasets/cifar.py | 43 +++++------- mmcls/datasets/cub.py | 12 ++-- mmcls/datasets/custom.py | 67 ++++++++++--------- mmcls/datasets/mnist.py | 25 +++---- mmcls/datasets/utils.py | 12 ++-- mmcls/datasets/voc.py | 14 ++-- mmcls/engine/hooks/visualization_hook.py | 10 +-- .../test_hooks/test_visualization_hook.py | 12 ---- 8 files changed, 83 insertions(+), 112 deletions(-) diff --git a/mmcls/datasets/cifar.py b/mmcls/datasets/cifar.py index da4fe3dd..25d9d058 100644 --- a/mmcls/datasets/cifar.py +++ b/mmcls/datasets/cifar.py @@ -4,7 +4,8 @@ from typing import List, Optional import mmengine.dist as dist import numpy as np -from mmengine import FileClient +from mmengine.fileio import (LocalBackend, exists, get, get_file_backend, + join_path) from mmcls.registry import DATASETS from .base_dataset import BaseDataset @@ -73,21 +74,17 @@ class CIFAR10(BaseDataset): def load_data_list(self): """Load images and ground truth labels.""" - root_prefix = self.data_prefix['root'] - file_client = FileClient.infer_client(uri=root_prefix) + root = self.data_prefix['root'] + backend = get_file_backend(root, enable_singleton=True) if dist.is_main_process() and not self._check_integrity(): - if file_client.name != 'HardDiskBackend': - raise RuntimeError( - f'The dataset on {root_prefix} is not integrated, ' - f'please manually handle it.') + if not isinstance(backend, LocalBackend): + raise RuntimeError(f'The dataset on {root} is not integrated, ' + f'please manually handle it.') if self.download: download_and_extract_archive( - self.url, - root_prefix, - filename=self.filename, - md5=self.tgz_md5) + self.url, root, filename=self.filename, md5=self.tgz_md5) else: raise RuntimeError( f'Cannot find {self.__class__.__name__} dataset in ' @@ -109,10 +106,8 @@ class CIFAR10(BaseDataset): # load the picked numpy arrays for file_name, _ in downloaded_list: - file_path = file_client.join_path(root_prefix, self.base_folder, - file_name) - content = file_client.get(file_path) - entry = pickle.loads(content, encoding='latin1') + file_path = join_path(root, self.base_folder, file_name) + entry = pickle.loads(get(file_path), encoding='latin1') imgs.append(entry['data']) if 'labels' in entry: gt_labels.extend(entry['labels']) @@ -136,32 +131,26 @@ class CIFAR10(BaseDataset): def _load_meta(self): """Load categories information from metafile.""" root = self.data_prefix['root'] - file_client = FileClient.infer_client(uri=root) - path = file_client.join_path(root, self.base_folder, - self.meta['filename']) + path = join_path(root, self.base_folder, self.meta['filename']) md5 = self.meta.get('md5', None) - if not file_client.exists(path) or (md5 is not None - and not check_md5(path, md5)): + if not exists(path) or (md5 is not None and not check_md5(path, md5)): raise RuntimeError( 'Dataset metadata file not found or corrupted.' 
                ' You can use `download=True` to download it')
+ backend = backend or get_file_backend(root, enable_singleton=True) folders = list( - file_client.list_dir_or_file( + backend.list_dir_or_file( root, list_dir=True, list_file=False, @@ -34,8 +40,12 @@ def find_folders(root: str, return folders, folder_to_idx -def get_samples(root: str, folder_to_idx: Dict[str, int], - is_valid_file: Callable, file_client: FileClient): +def get_samples( + root: str, + folder_to_idx: Dict[str, int], + is_valid_file: Callable, + backend: Optional[BaseStorageBackend] = None, +): """Make dataset by walking all images under a root. Args: @@ -43,6 +53,8 @@ def get_samples(root: str, folder_to_idx: Dict[str, int], folder_to_idx (dict): the map from class name to class idx is_valid_file (Callable): A function that takes path of a file and check if the file is a valid sample file. + backend (BaseStorageBackend | None): The file backend of the root. + If None, auto infer backend from the root path. Defaults to None. Returns: Tuple[list, set]: @@ -52,19 +64,20 @@ def get_samples(root: str, folder_to_idx: Dict[str, int], """ samples = [] available_classes = set() + # Pre-build file backend to prevent verbose file backend inference. + backend = backend or get_file_backend(root, enable_singleton=True) for folder_name in sorted(list(folder_to_idx.keys())): - _dir = file_client.join_path(root, folder_name) - files = list( - file_client.list_dir_or_file( - _dir, - list_dir=False, - list_file=True, - recursive=True, - )) + _dir = backend.join_path(root, folder_name) + files = backend.list_dir_or_file( + _dir, + list_dir=False, + list_file=True, + recursive=True, + ) for file in sorted(list(files)): if is_valid_file(file): - path = file_client.join_path(folder_name, file) + path = backend.join_path(folder_name, file) item = (path, folder_to_idx[folder_name]) samples.append(item) available_classes.add(folder_name) @@ -169,14 +182,13 @@ class CustomDataset(BaseDataset): if not lazy_init: self.full_init() - def _find_samples(self, file_client): + def _find_samples(self): """find samples from ``data_prefix``.""" - classes, folder_to_idx = find_folders(self.img_prefix, file_client) + classes, folder_to_idx = find_folders(self.img_prefix) samples, empty_classes = get_samples( self.img_prefix, folder_to_idx, is_valid_file=self.is_valid_file, - file_client=file_client, ) if len(samples) == 0: @@ -205,24 +217,17 @@ class CustomDataset(BaseDataset): def load_data_list(self): """Load image paths and gt_labels.""" - if self.img_prefix: - file_client = FileClient.infer_client(uri=self.img_prefix) - if not self.ann_file: - samples = self._find_samples(file_client) + samples = self._find_samples() else: - lines = mmengine.list_from_file(self.ann_file) + lines = list_from_file(self.ann_file) samples = [x.strip().rsplit(' ', 1) for x in lines] - def add_prefix(filename, prefix=''): - if not prefix: - return filename - else: - return file_client.join_path(prefix, filename) - + # Pre-build file backend to prevent verbose file backend inference. + backend = get_file_backend(self.img_prefix, enable_singleton=True) data_list = [] for filename, gt_label in samples: - img_path = add_prefix(filename, self.img_prefix) + img_path = backend.join_path(self.img_prefix, filename) info = {'img_path': img_path, 'gt_label': int(gt_label)} data_list.append(info) return data_list diff --git a/mmcls/datasets/mnist.py b/mmcls/datasets/mnist.py index 4250bfb1..71d980df 100644 --- a/mmcls/datasets/mnist.py +++ b/mmcls/datasets/mnist.py @@ -1,11 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import codecs from typing import List, Optional +from urllib.parse import urljoin import mmengine.dist as dist import numpy as np import torch -from mmengine import FileClient +from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path from mmcls.registry import DATASETS from .base_dataset import BaseDataset @@ -67,10 +68,10 @@ class MNIST(BaseDataset): def load_data_list(self): """Load images and ground truth labels.""" root = self.data_prefix['root'] - file_client = FileClient.infer_client(uri=root) + backend = get_file_backend(root, enable_singleton=True) if dist.is_main_process() and not self._check_exists(): - if file_client.name != 'HardDiskBackend': + if not isinstance(backend, LocalBackend): raise RuntimeError(f'The dataset on {root} is not integrated, ' f'please manually handle it.') @@ -93,10 +94,9 @@ class MNIST(BaseDataset): file_list = self.test_list # load data from SN3 files - imgs = read_image_file( - file_client.join_path(root, rm_suffix(file_list[0][0]))) + imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0]))) gt_labels = read_label_file( - file_client.join_path(root, rm_suffix(file_list[1][0]))) + join_path(root, rm_suffix(file_list[1][0]))) data_infos = [] for img, gt_label in zip(imgs, gt_labels): @@ -108,28 +108,23 @@ class MNIST(BaseDataset): def _check_exists(self): """Check the exists of data files.""" root = self.data_prefix['root'] - file_client = FileClient.infer_client(uri=root) for filename, _ in (self.train_list + self.test_list): # get extracted filename of data extract_filename = rm_suffix(filename) - fpath = file_client.join_path(root, extract_filename) - if not file_client.exists(fpath): + fpath = join_path(root, extract_filename) + if not exists(fpath): return False return True def _download(self): """Download and extract data files.""" root = self.data_prefix['root'] - file_client = FileClient.infer_client(uri=root) for filename, md5 in (self.train_list + self.test_list): - url = file_client.join_path(self.url_prefix, filename) + url = urljoin(self.url_prefix, filename) download_and_extract_archive( - url, - download_root=self.data_prefix['root'], - filename=filename, - md5=md5) + url, download_root=root, filename=filename, md5=md5) def extra_repr(self) -> List[str]: """The extra repr information of the dataset.""" diff --git a/mmcls/datasets/utils.py b/mmcls/datasets/utils.py index 1a1cef93..fcb60e43 100644 --- a/mmcls/datasets/utils.py +++ b/mmcls/datasets/utils.py @@ -10,7 +10,7 @@ import urllib.error import urllib.request import zipfile -from mmengine.fileio.file_client import FileClient +from mmengine.fileio import LocalBackend, get_file_backend __all__ = [ 'rm_suffix', 'check_integrity', 'download_and_extract_archive', @@ -25,16 +25,16 @@ def rm_suffix(s, suffix=None): return s[:s.rfind(suffix)] -def calculate_md5(fpath: str, - file_client: FileClient = None, - chunk_size: int = 1024 * 1024): +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024): md5 = hashlib.md5() - if file_client is None or file_client.name == 'HardDiskBackend': + backend = get_file_backend(fpath, enable_singleton=True) + if isinstance(backend, LocalBackend): + # Enable chunk update for local file. 
with open(fpath, 'rb') as f: for chunk in iter(lambda: f.read(chunk_size), b''): md5.update(chunk) else: - md5.update(file_client.get(fpath)) + md5.update(backend.get(fpath)) return md5.hexdigest() diff --git a/mmcls/datasets/voc.py b/mmcls/datasets/voc.py index feeb27ab..346ff330 100644 --- a/mmcls/datasets/voc.py +++ b/mmcls/datasets/voc.py @@ -2,7 +2,7 @@ import xml.etree.ElementTree as ET from typing import List, Optional, Union -from mmengine import FileClient, list_from_file +from mmengine import get_file_backend, list_from_file from mmcls.registry import DATASETS from .base_dataset import expanduser @@ -72,9 +72,8 @@ class VOC(MultiLabelDataset): ' False.' self.data_root = data_root - self.file_client = FileClient.infer_client(uri=data_root) - self.image_set_path = self.file_client.join_path( - data_root, image_set_path) + self.backend = get_file_backend(data_root, enable_singleton=True) + self.image_set_path = self.backend.join_path(data_root, image_set_path) super().__init__( ann_file='', @@ -94,8 +93,8 @@ class VOC(MultiLabelDataset): def _get_labels_from_xml(self, img_id): """Get gt_labels and labels_difficult from xml file.""" - xml_path = self.file_client.join_path(self.ann_prefix, f'{img_id}.xml') - content = self.file_client.get(xml_path) + xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml') + content = self.backend.get(xml_path) root = ET.fromstring(content) labels, labels_difficult = set(), set() @@ -120,8 +119,7 @@ class VOC(MultiLabelDataset): img_ids = list_from_file(self.image_set_path) for img_id in img_ids: - img_path = self.file_client.join_path(self.img_prefix, - f'{img_id}.jpg') + img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg') labels, labels_difficult = None, None if self.ann_prefix is not None: diff --git a/mmcls/engine/hooks/visualization_hook.py b/mmcls/engine/hooks/visualization_hook.py index 852aa54f..921804fe 100644 --- a/mmcls/engine/hooks/visualization_hook.py +++ b/mmcls/engine/hooks/visualization_hook.py @@ -3,7 +3,7 @@ import math import os.path as osp from typing import Optional, Sequence -from mmengine import FileClient +from mmengine.fileio import join_path from mmengine.hooks import Hook from mmengine.runner import EpochBasedTrainLoop, Runner from mmengine.visualization import Visualizer @@ -45,10 +45,6 @@ class VisualizationHook(Hook): self.interval = interval self.show = show self.out_dir = out_dir - if out_dir is not None: - self.file_client = FileClient.infer_client(uri=out_dir) - else: - self.file_client = None self.draw_args = {**kwargs, 'show': show} @@ -89,8 +85,8 @@ class VisualizationHook(Hook): draw_args = self.draw_args if self.out_dir is not None: - draw_args['out_file'] = self.file_client.join_path( - self.out_dir, f'{sample_name}_{step}.png') + draw_args['out_file'] = join_path(self.out_dir, + f'{sample_name}_{step}.png') self._visualizer.add_datasample( sample_name, diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 04e77d9b..e922a983 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -32,18 +32,6 @@ class TestVisualizationHook(TestCase): self.tmpdir = tempfile.TemporaryDirectory() - def test_initialize(self): - # test file_client - cfg = dict(type='VisualizationHook') - hook = HOOKS.build(cfg) - self.assertIsNone(hook.file_client) - - cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name) - hook = HOOKS.build(cfg) - 
-        self.assertIsNotNone(hook.file_client)
-
-        # test draw_args
-
     def test_draw_samples(self):
         # test enable=False
         cfg = dict(type='VisualizationHook', enable=False)
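
Note (illustrative sketch, not part of the applied patch): every file in this
change follows the same migration pattern, summarized below. The `root` value
and the `batches.meta` filename are hypothetical; `get_file_backend`,
`join_path`, `exists`, `get`, and `LocalBackend` are the mmengine.fileio names
the patch itself imports.

    # Minimal sketch of the new fileio pattern, assuming an mmengine version
    # that ships the functional fileio API used by this patch.
    from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
                                 join_path)

    root = 'data/cifar10'  # hypothetical; could also be a remote URI

    # Infer the backend from the URI once and cache it process-wide. This is
    # the reason for the "Pre-build file backend to prevent verbose file
    # backend inference" comments added before loops in this patch.
    backend = get_file_backend(root, enable_singleton=True)

    # An isinstance check replaces the old string comparison against
    # `file_client.name != 'HardDiskBackend'`.
    if not isinstance(backend, LocalBackend):
        raise RuntimeError(f'{root} is not a local path, handle it manually.')

    # The functional helpers infer the backend themselves, so no client
    # object has to be threaded through helper functions.
    meta_path = join_path(root, 'batches.meta')
    if exists(meta_path):
        content: bytes = get(meta_path)  # raw bytes from any backend

With the singleton cached, per-sample loops such as the one in
`CustomDataset.load_data_list` no longer pay backend-inference cost on each
path operation, which is why the explicit `FileClient.infer_client` plumbing
could be dropped from the method signatures above.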