[Refactor] Refactor to use new fileio API in MMEngine. (#1176)

* [Refactor] Refactor to use new fileio API in MMEngine.

* Add comment about why use `backend`
pull/1143/head
Ma Zerun 2022-11-21 10:56:35 +08:00 committed by GitHub
parent 4969830c8a
commit 940a06f645
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 83 additions and 112 deletions

View File

@ -4,7 +4,8 @@ from typing import List, Optional
import mmengine.dist as dist
import numpy as np
from mmengine import FileClient
from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
join_path)
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
@ -73,21 +74,17 @@ class CIFAR10(BaseDataset):
def load_data_list(self):
"""Load images and ground truth labels."""
root_prefix = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root_prefix)
root = self.data_prefix['root']
backend = get_file_backend(root, enable_singleton=True)
if dist.is_main_process() and not self._check_integrity():
if file_client.name != 'HardDiskBackend':
raise RuntimeError(
f'The dataset on {root_prefix} is not integrated, '
f'please manually handle it.')
if not isinstance(backend, LocalBackend):
raise RuntimeError(f'The dataset on {root} is not integrated, '
f'please manually handle it.')
if self.download:
download_and_extract_archive(
self.url,
root_prefix,
filename=self.filename,
md5=self.tgz_md5)
self.url, root, filename=self.filename, md5=self.tgz_md5)
else:
raise RuntimeError(
f'Cannot find {self.__class__.__name__} dataset in '
@ -109,10 +106,8 @@ class CIFAR10(BaseDataset):
# load the picked numpy arrays
for file_name, _ in downloaded_list:
file_path = file_client.join_path(root_prefix, self.base_folder,
file_name)
content = file_client.get(file_path)
entry = pickle.loads(content, encoding='latin1')
file_path = join_path(root, self.base_folder, file_name)
entry = pickle.loads(get(file_path), encoding='latin1')
imgs.append(entry['data'])
if 'labels' in entry:
gt_labels.extend(entry['labels'])
@ -136,32 +131,26 @@ class CIFAR10(BaseDataset):
def _load_meta(self):
"""Load categories information from metafile."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
path = file_client.join_path(root, self.base_folder,
self.meta['filename'])
path = join_path(root, self.base_folder, self.meta['filename'])
md5 = self.meta.get('md5', None)
if not file_client.exists(path) or (md5 is not None
and not check_md5(path, md5)):
if not exists(path) or (md5 is not None and not check_md5(path, md5)):
raise RuntimeError(
'Dataset metadata file not found or corrupted.' +
' You can use `download=True` to download it')
content = file_client.get(path)
data = pickle.loads(content, encoding='latin1')
data = pickle.loads(get(path), encoding='latin1')
self._metainfo.setdefault('classes', data[self.meta['key']])
def _check_integrity(self):
"""Check the integrity of data files."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
for fentry in (self.train_list + self.test_list):
filename, md5 = fentry[0], fentry[1]
fpath = file_client.join_path(root, self.base_folder, filename)
if not file_client.exists(fpath):
fpath = join_path(root, self.base_folder, filename)
if not exists(fpath):
return False
if md5 is not None and not check_md5(
fpath, md5, file_client=file_client):
if md5 is not None and not check_md5(fpath, md5):
return False
return True

View File

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine import FileClient, list_from_file
from mmengine import get_file_backend, list_from_file
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
@ -78,10 +78,10 @@ class CUB(BaseDataset):
image_class_labels_file: str = 'image_class_labels.txt',
train_test_split_file: str = 'train_test_split.txt',
**kwargs):
self.file_client = FileClient.infer_client(uri=data_root)
self.image_class_labels_file = self.file_client.join_path(
self.backend = get_file_backend(data_root, enable_singleton=True)
self.image_class_labels_file = self.backend.join_path(
data_root, image_class_labels_file)
self.train_test_split_file = self.file_client.join_path(
self.train_test_split_file = self.backend.join_path(
data_root, train_test_split_file)
super(CUB, self).__init__(
ann_file=ann_file,
@ -123,8 +123,8 @@ class CUB(BaseDataset):
# skip test samples when test_mode=False
continue
img_path = self.file_client.join_path(self.img_prefix,
sample_dict[sample_id])
img_path = self.backend.join_path(self.img_prefix,
sample_dict[sample_id])
gt_label = int(label_dict[sample_id]) - 1
info = dict(img_path=img_path, gt_label=gt_label)
data_list.append(info)

View File

@ -1,20 +1,24 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
import mmengine
from mmengine import FileClient
from mmengine.fileio import (BaseStorageBackend, get_file_backend,
list_from_file)
from mmengine.logging import MMLogger
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
def find_folders(root: str,
file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
def find_folders(
root: str,
backend: Optional[BaseStorageBackend] = None
) -> Tuple[List[str], Dict[str, int]]:
"""Find classes by folders under a root.
Args:
root (string): root directory of folders
backend (BaseStorageBackend | None): The file backend of the root.
If None, auto infer backend from the root path. Defaults to None.
Returns:
Tuple[List[str], Dict[str, int]]:
@ -22,8 +26,10 @@ def find_folders(root: str,
- folders: The name of sub folders under the root.
- folder_to_idx: The map from folder name to class idx.
"""
# Pre-build file backend to prevent verbose file backend inference.
backend = backend or get_file_backend(root, enable_singleton=True)
folders = list(
file_client.list_dir_or_file(
backend.list_dir_or_file(
root,
list_dir=True,
list_file=False,
@ -34,8 +40,12 @@ def find_folders(root: str,
return folders, folder_to_idx
def get_samples(root: str, folder_to_idx: Dict[str, int],
is_valid_file: Callable, file_client: FileClient):
def get_samples(
root: str,
folder_to_idx: Dict[str, int],
is_valid_file: Callable,
backend: Optional[BaseStorageBackend] = None,
):
"""Make dataset by walking all images under a root.
Args:
@ -43,6 +53,8 @@ def get_samples(root: str, folder_to_idx: Dict[str, int],
folder_to_idx (dict): the map from class name to class idx
is_valid_file (Callable): A function that takes path of a file
and check if the file is a valid sample file.
backend (BaseStorageBackend | None): The file backend of the root.
If None, auto infer backend from the root path. Defaults to None.
Returns:
Tuple[list, set]:
@ -52,19 +64,20 @@ def get_samples(root: str, folder_to_idx: Dict[str, int],
"""
samples = []
available_classes = set()
# Pre-build file backend to prevent verbose file backend inference.
backend = backend or get_file_backend(root, enable_singleton=True)
for folder_name in sorted(list(folder_to_idx.keys())):
_dir = file_client.join_path(root, folder_name)
files = list(
file_client.list_dir_or_file(
_dir,
list_dir=False,
list_file=True,
recursive=True,
))
_dir = backend.join_path(root, folder_name)
files = backend.list_dir_or_file(
_dir,
list_dir=False,
list_file=True,
recursive=True,
)
for file in sorted(list(files)):
if is_valid_file(file):
path = file_client.join_path(folder_name, file)
path = backend.join_path(folder_name, file)
item = (path, folder_to_idx[folder_name])
samples.append(item)
available_classes.add(folder_name)
@ -169,14 +182,13 @@ class CustomDataset(BaseDataset):
if not lazy_init:
self.full_init()
def _find_samples(self, file_client):
def _find_samples(self):
"""find samples from ``data_prefix``."""
classes, folder_to_idx = find_folders(self.img_prefix, file_client)
classes, folder_to_idx = find_folders(self.img_prefix)
samples, empty_classes = get_samples(
self.img_prefix,
folder_to_idx,
is_valid_file=self.is_valid_file,
file_client=file_client,
)
if len(samples) == 0:
@ -205,24 +217,17 @@ class CustomDataset(BaseDataset):
def load_data_list(self):
"""Load image paths and gt_labels."""
if self.img_prefix:
file_client = FileClient.infer_client(uri=self.img_prefix)
if not self.ann_file:
samples = self._find_samples(file_client)
samples = self._find_samples()
else:
lines = mmengine.list_from_file(self.ann_file)
lines = list_from_file(self.ann_file)
samples = [x.strip().rsplit(' ', 1) for x in lines]
def add_prefix(filename, prefix=''):
if not prefix:
return filename
else:
return file_client.join_path(prefix, filename)
# Pre-build file backend to prevent verbose file backend inference.
backend = get_file_backend(self.img_prefix, enable_singleton=True)
data_list = []
for filename, gt_label in samples:
img_path = add_prefix(filename, self.img_prefix)
img_path = backend.join_path(self.img_prefix, filename)
info = {'img_path': img_path, 'gt_label': int(gt_label)}
data_list.append(info)
return data_list

View File

@ -1,11 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import codecs
from typing import List, Optional
from urllib.parse import urljoin
import mmengine.dist as dist
import numpy as np
import torch
from mmengine import FileClient
from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
@ -67,10 +68,10 @@ class MNIST(BaseDataset):
def load_data_list(self):
"""Load images and ground truth labels."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
backend = get_file_backend(root, enable_singleton=True)
if dist.is_main_process() and not self._check_exists():
if file_client.name != 'HardDiskBackend':
if not isinstance(backend, LocalBackend):
raise RuntimeError(f'The dataset on {root} is not integrated, '
f'please manually handle it.')
@ -93,10 +94,9 @@ class MNIST(BaseDataset):
file_list = self.test_list
# load data from SN3 files
imgs = read_image_file(
file_client.join_path(root, rm_suffix(file_list[0][0])))
imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
gt_labels = read_label_file(
file_client.join_path(root, rm_suffix(file_list[1][0])))
join_path(root, rm_suffix(file_list[1][0])))
data_infos = []
for img, gt_label in zip(imgs, gt_labels):
@ -108,28 +108,23 @@ class MNIST(BaseDataset):
def _check_exists(self):
"""Check the exists of data files."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
for filename, _ in (self.train_list + self.test_list):
# get extracted filename of data
extract_filename = rm_suffix(filename)
fpath = file_client.join_path(root, extract_filename)
if not file_client.exists(fpath):
fpath = join_path(root, extract_filename)
if not exists(fpath):
return False
return True
def _download(self):
"""Download and extract data files."""
root = self.data_prefix['root']
file_client = FileClient.infer_client(uri=root)
for filename, md5 in (self.train_list + self.test_list):
url = file_client.join_path(self.url_prefix, filename)
url = urljoin(self.url_prefix, filename)
download_and_extract_archive(
url,
download_root=self.data_prefix['root'],
filename=filename,
md5=md5)
url, download_root=root, filename=filename, md5=md5)
def extra_repr(self) -> List[str]:
"""The extra repr information of the dataset."""

View File

@ -10,7 +10,7 @@ import urllib.error
import urllib.request
import zipfile
from mmengine.fileio.file_client import FileClient
from mmengine.fileio import LocalBackend, get_file_backend
__all__ = [
'rm_suffix', 'check_integrity', 'download_and_extract_archive',
@ -25,16 +25,16 @@ def rm_suffix(s, suffix=None):
return s[:s.rfind(suffix)]
def calculate_md5(fpath: str,
file_client: FileClient = None,
chunk_size: int = 1024 * 1024):
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
md5 = hashlib.md5()
if file_client is None or file_client.name == 'HardDiskBackend':
backend = get_file_backend(fpath, enable_singleton=True)
if isinstance(backend, LocalBackend):
# Enable chunk update for local file.
with open(fpath, 'rb') as f:
for chunk in iter(lambda: f.read(chunk_size), b''):
md5.update(chunk)
else:
md5.update(file_client.get(fpath))
md5.update(backend.get(fpath))
return md5.hexdigest()

View File

@ -2,7 +2,7 @@
import xml.etree.ElementTree as ET
from typing import List, Optional, Union
from mmengine import FileClient, list_from_file
from mmengine import get_file_backend, list_from_file
from mmcls.registry import DATASETS
from .base_dataset import expanduser
@ -72,9 +72,8 @@ class VOC(MultiLabelDataset):
' False.'
self.data_root = data_root
self.file_client = FileClient.infer_client(uri=data_root)
self.image_set_path = self.file_client.join_path(
data_root, image_set_path)
self.backend = get_file_backend(data_root, enable_singleton=True)
self.image_set_path = self.backend.join_path(data_root, image_set_path)
super().__init__(
ann_file='',
@ -94,8 +93,8 @@ class VOC(MultiLabelDataset):
def _get_labels_from_xml(self, img_id):
"""Get gt_labels and labels_difficult from xml file."""
xml_path = self.file_client.join_path(self.ann_prefix, f'{img_id}.xml')
content = self.file_client.get(xml_path)
xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml')
content = self.backend.get(xml_path)
root = ET.fromstring(content)
labels, labels_difficult = set(), set()
@ -120,8 +119,7 @@ class VOC(MultiLabelDataset):
img_ids = list_from_file(self.image_set_path)
for img_id in img_ids:
img_path = self.file_client.join_path(self.img_prefix,
f'{img_id}.jpg')
img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg')
labels, labels_difficult = None, None
if self.ann_prefix is not None:

View File

@ -3,7 +3,7 @@ import math
import os.path as osp
from typing import Optional, Sequence
from mmengine import FileClient
from mmengine.fileio import join_path
from mmengine.hooks import Hook
from mmengine.runner import EpochBasedTrainLoop, Runner
from mmengine.visualization import Visualizer
@ -45,10 +45,6 @@ class VisualizationHook(Hook):
self.interval = interval
self.show = show
self.out_dir = out_dir
if out_dir is not None:
self.file_client = FileClient.infer_client(uri=out_dir)
else:
self.file_client = None
self.draw_args = {**kwargs, 'show': show}
@ -89,8 +85,8 @@ class VisualizationHook(Hook):
draw_args = self.draw_args
if self.out_dir is not None:
draw_args['out_file'] = self.file_client.join_path(
self.out_dir, f'{sample_name}_{step}.png')
draw_args['out_file'] = join_path(self.out_dir,
f'{sample_name}_{step}.png')
self._visualizer.add_datasample(
sample_name,

View File

@ -32,18 +32,6 @@ class TestVisualizationHook(TestCase):
self.tmpdir = tempfile.TemporaryDirectory()
def test_initialize(self):
# test file_client
cfg = dict(type='VisualizationHook')
hook = HOOKS.build(cfg)
self.assertIsNone(hook.file_client)
cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name)
hook = HOOKS.build(cfg)
self.assertIsNotNone(hook.file_client)
# test draw_args
def test_draw_samples(self):
# test enable=False
cfg = dict(type='VisualizationHook', enable=False)