[Refactor] Refactor to use new fileio API in MMEngine. (#1176)

* [Refactor] Refactor to use new fileio API in MMEngine.
* Add comment about why use `backend`.

parent 4969830c8a
commit 940a06f645
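The diff below swaps mmengine's `FileClient` abstraction for the new functional fileio API across the datasets, utils, and the visualization hook. A minimal before/after sketch of the pattern, assuming a recent mmengine with the new fileio module (the dataset root is a placeholder):

```python
# Minimal sketch of the migration pattern; `data/cifar10` is a placeholder.
from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
                             join_path)

root = 'data/cifar10'

# Before: file_client = FileClient.infer_client(uri=root)
# After: infer the backend once; the singleton flag caches it for reuse.
backend = get_file_backend(root, enable_singleton=True)

# Before: file_client.name != 'HardDiskBackend'
# After: an explicit type check against the backend class.
is_local = isinstance(backend, LocalBackend)

# Path joining and reads no longer go through a client object.
meta_path = join_path(root, 'batches.meta')
if exists(meta_path):
    content = get(meta_path)  # file content as bytes
```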
--- a/mmcls/datasets/cifar.py
+++ b/mmcls/datasets/cifar.py
@@ -4,7 +4,8 @@ from typing import List, Optional
 import mmengine.dist as dist
 import numpy as np
-from mmengine import FileClient
+from mmengine.fileio import (LocalBackend, exists, get, get_file_backend,
+                             join_path)
 
 from mmcls.registry import DATASETS
 from .base_dataset import BaseDataset
@@ -73,21 +74,17 @@ class CIFAR10(BaseDataset):
     def load_data_list(self):
         """Load images and ground truth labels."""
-        root_prefix = self.data_prefix['root']
-        file_client = FileClient.infer_client(uri=root_prefix)
+        root = self.data_prefix['root']
+        backend = get_file_backend(root, enable_singleton=True)
 
         if dist.is_main_process() and not self._check_integrity():
-            if file_client.name != 'HardDiskBackend':
-                raise RuntimeError(
-                    f'The dataset on {root_prefix} is not integrated, '
-                    f'please manually handle it.')
+            if not isinstance(backend, LocalBackend):
+                raise RuntimeError(f'The dataset on {root} is not integrated, '
+                                   f'please manually handle it.')
 
             if self.download:
                 download_and_extract_archive(
-                    self.url,
-                    root_prefix,
-                    filename=self.filename,
-                    md5=self.tgz_md5)
+                    self.url, root, filename=self.filename, md5=self.tgz_md5)
             else:
                 raise RuntimeError(
                     f'Cannot find {self.__class__.__name__} dataset in '
@@ -109,10 +106,8 @@ class CIFAR10(BaseDataset):
         # load the picked numpy arrays
         for file_name, _ in downloaded_list:
-            file_path = file_client.join_path(root_prefix, self.base_folder,
-                                              file_name)
-            content = file_client.get(file_path)
-            entry = pickle.loads(content, encoding='latin1')
+            file_path = join_path(root, self.base_folder, file_name)
+            entry = pickle.loads(get(file_path), encoding='latin1')
             imgs.append(entry['data'])
             if 'labels' in entry:
                 gt_labels.extend(entry['labels'])
@@ -136,32 +131,26 @@ class CIFAR10(BaseDataset):
     def _load_meta(self):
         """Load categories information from metafile."""
         root = self.data_prefix['root']
-        file_client = FileClient.infer_client(uri=root)
 
-        path = file_client.join_path(root, self.base_folder,
-                                     self.meta['filename'])
+        path = join_path(root, self.base_folder, self.meta['filename'])
         md5 = self.meta.get('md5', None)
-        if not file_client.exists(path) or (md5 is not None
-                                            and not check_md5(path, md5)):
+        if not exists(path) or (md5 is not None and not check_md5(path, md5)):
             raise RuntimeError(
                 'Dataset metadata file not found or corrupted.' +
                 ' You can use `download=True` to download it')
-        content = file_client.get(path)
-        data = pickle.loads(content, encoding='latin1')
+        data = pickle.loads(get(path), encoding='latin1')
         self._metainfo.setdefault('classes', data[self.meta['key']])
 
     def _check_integrity(self):
         """Check the integrity of data files."""
         root = self.data_prefix['root']
-        file_client = FileClient.infer_client(uri=root)
 
         for fentry in (self.train_list + self.test_list):
             filename, md5 = fentry[0], fentry[1]
-            fpath = file_client.join_path(root, self.base_folder, filename)
-            if not file_client.exists(fpath):
+            fpath = join_path(root, self.base_folder, filename)
+            if not exists(fpath):
                 return False
-            if md5 is not None and not check_md5(
-                    fpath, md5, file_client=file_client):
+            if md5 is not None and not check_md5(fpath, md5):
                 return False
         return True
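Worth noting from the hunks above: the functional `get()` returns raw bytes, so the pickled CIFAR batches deserialize without an intermediate `content` variable or client object. A short sketch with a placeholder path:

```python
# get() returns bytes, which feed pickle.loads directly; the path below is
# a placeholder for a CIFAR batch file.
import pickle

from mmengine.fileio import get, join_path

file_path = join_path('data/cifar10', 'cifar-10-batches-py', 'data_batch_1')
entry = pickle.loads(get(file_path), encoding='latin1')  # dict of arrays/labels
```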
--- a/mmcls/datasets/cub.py
+++ b/mmcls/datasets/cub.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import List
 
-from mmengine import FileClient, list_from_file
+from mmengine import get_file_backend, list_from_file
 
 from mmcls.registry import DATASETS
 from .base_dataset import BaseDataset
@@ -78,10 +78,10 @@ class CUB(BaseDataset):
                  image_class_labels_file: str = 'image_class_labels.txt',
                  train_test_split_file: str = 'train_test_split.txt',
                  **kwargs):
-        self.file_client = FileClient.infer_client(uri=data_root)
-        self.image_class_labels_file = self.file_client.join_path(
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        self.image_class_labels_file = self.backend.join_path(
             data_root, image_class_labels_file)
-        self.train_test_split_file = self.file_client.join_path(
+        self.train_test_split_file = self.backend.join_path(
             data_root, train_test_split_file)
         super(CUB, self).__init__(
             ann_file=ann_file,
@@ -123,8 +123,8 @@ class CUB(BaseDataset):
                 # skip test samples when test_mode=False
                 continue
 
-            img_path = self.file_client.join_path(self.img_prefix,
-                                                  sample_dict[sample_id])
+            img_path = self.backend.join_path(self.img_prefix,
+                                              sample_dict[sample_id])
             gt_label = int(label_dict[sample_id]) - 1
             info = dict(img_path=img_path, gt_label=gt_label)
             data_list.append(info)
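For datasets that join many paths against one root, the refactor binds a single backend to the instance (`self.backend`) and reuses its `join_path`. A standalone sketch of that pattern; the data root is a placeholder, while the file names are the defaults from the hunk above:

```python
# Instance-bound backend pattern used by CUB; the data root is a placeholder.
from mmengine import get_file_backend

data_root = 'data/CUB_200_2011'
backend = get_file_backend(data_root, enable_singleton=True)
labels_file = backend.join_path(data_root, 'image_class_labels.txt')
split_file = backend.join_path(data_root, 'train_test_split.txt')
```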
--- a/mmcls/datasets/custom.py
+++ b/mmcls/datasets/custom.py
@@ -1,20 +1,24 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
 
-import mmengine
-from mmengine import FileClient
+from mmengine.fileio import (BaseStorageBackend, get_file_backend,
+                             list_from_file)
 from mmengine.logging import MMLogger
 
 from mmcls.registry import DATASETS
 from .base_dataset import BaseDataset
 
 
-def find_folders(root: str,
-                 file_client: FileClient) -> Tuple[List[str], Dict[str, int]]:
+def find_folders(
+    root: str,
+    backend: Optional[BaseStorageBackend] = None
+) -> Tuple[List[str], Dict[str, int]]:
     """Find classes by folders under a root.
 
     Args:
         root (string): root directory of folders
+        backend (BaseStorageBackend | None): The file backend of the root.
+            If None, auto infer backend from the root path. Defaults to None.
 
     Returns:
         Tuple[List[str], Dict[str, int]]:
@@ -22,8 +26,10 @@ def find_folders(root: str,
         - folders: The name of sub folders under the root.
         - folder_to_idx: The map from folder name to class idx.
     """
+    # Pre-build file backend to prevent verbose file backend inference.
+    backend = backend or get_file_backend(root, enable_singleton=True)
     folders = list(
-        file_client.list_dir_or_file(
+        backend.list_dir_or_file(
             root,
             list_dir=True,
             list_file=False,
@@ -34,8 +40,12 @@ def find_folders(root: str,
     return folders, folder_to_idx
 
 
-def get_samples(root: str, folder_to_idx: Dict[str, int],
-                is_valid_file: Callable, file_client: FileClient):
+def get_samples(
+    root: str,
+    folder_to_idx: Dict[str, int],
+    is_valid_file: Callable,
+    backend: Optional[BaseStorageBackend] = None,
+):
     """Make dataset by walking all images under a root.
 
     Args:
@@ -43,6 +53,8 @@ def get_samples(root: str, folder_to_idx: Dict[str, int],
         folder_to_idx (dict): the map from class name to class idx
         is_valid_file (Callable): A function that takes path of a file
             and check if the file is a valid sample file.
+        backend (BaseStorageBackend | None): The file backend of the root.
+            If None, auto infer backend from the root path. Defaults to None.
 
     Returns:
         Tuple[list, set]:
@@ -52,19 +64,20 @@ def get_samples(root: str, folder_to_idx: Dict[str, int],
     """
     samples = []
    available_classes = set()
+    # Pre-build file backend to prevent verbose file backend inference.
+    backend = backend or get_file_backend(root, enable_singleton=True)
 
     for folder_name in sorted(list(folder_to_idx.keys())):
-        _dir = file_client.join_path(root, folder_name)
-        files = list(
-            file_client.list_dir_or_file(
-                _dir,
-                list_dir=False,
-                list_file=True,
-                recursive=True,
-            ))
+        _dir = backend.join_path(root, folder_name)
+        files = backend.list_dir_or_file(
+            _dir,
+            list_dir=False,
+            list_file=True,
+            recursive=True,
+        )
         for file in sorted(list(files)):
             if is_valid_file(file):
-                path = file_client.join_path(folder_name, file)
+                path = backend.join_path(folder_name, file)
                 item = (path, folder_to_idx[folder_name])
                 samples.append(item)
                 available_classes.add(folder_name)
@@ -169,14 +182,13 @@ class CustomDataset(BaseDataset):
         if not lazy_init:
             self.full_init()
 
-    def _find_samples(self, file_client):
+    def _find_samples(self):
         """find samples from ``data_prefix``."""
-        classes, folder_to_idx = find_folders(self.img_prefix, file_client)
+        classes, folder_to_idx = find_folders(self.img_prefix)
         samples, empty_classes = get_samples(
             self.img_prefix,
             folder_to_idx,
             is_valid_file=self.is_valid_file,
-            file_client=file_client,
         )
 
         if len(samples) == 0:
@@ -205,24 +217,17 @@ class CustomDataset(BaseDataset):
     def load_data_list(self):
         """Load image paths and gt_labels."""
-        if self.img_prefix:
-            file_client = FileClient.infer_client(uri=self.img_prefix)
-
         if not self.ann_file:
-            samples = self._find_samples(file_client)
+            samples = self._find_samples()
         else:
-            lines = mmengine.list_from_file(self.ann_file)
+            lines = list_from_file(self.ann_file)
             samples = [x.strip().rsplit(' ', 1) for x in lines]
 
-        def add_prefix(filename, prefix=''):
-            if not prefix:
-                return filename
-            else:
-                return file_client.join_path(prefix, filename)
-
+        # Pre-build file backend to prevent verbose file backend inference.
+        backend = get_file_backend(self.img_prefix, enable_singleton=True)
         data_list = []
         for filename, gt_label in samples:
-            img_path = add_prefix(filename, self.img_prefix)
+            img_path = backend.join_path(self.img_prefix, filename)
             info = {'img_path': img_path, 'gt_label': int(gt_label)}
             data_list.append(info)
         return data_list
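This file carries the comment the commit message mentions: the helpers pre-build the backend with `enable_singleton=True` so that per-sample calls reuse one cached instance instead of re-running (verbose) backend inference. A small sketch of the caching behavior as I understand the API:

```python
# With enable_singleton=True the first inferred backend is cached, so
# repeated calls for URIs of the same kind return the same instance.
from mmengine.fileio import get_file_backend

backend_a = get_file_backend('data/imagenet/train', enable_singleton=True)
backend_b = get_file_backend('data/imagenet/val', enable_singleton=True)
assert backend_a is backend_b  # one shared LocalBackend for local paths
```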
--- a/mmcls/datasets/mnist.py
+++ b/mmcls/datasets/mnist.py
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import codecs
 from typing import List, Optional
+from urllib.parse import urljoin
 
 import mmengine.dist as dist
 import numpy as np
 import torch
-from mmengine import FileClient
+from mmengine.fileio import LocalBackend, exists, get_file_backend, join_path
 
 from mmcls.registry import DATASETS
 from .base_dataset import BaseDataset
@@ -67,10 +68,10 @@ class MNIST(BaseDataset):
     def load_data_list(self):
         """Load images and ground truth labels."""
         root = self.data_prefix['root']
-        file_client = FileClient.infer_client(uri=root)
+        backend = get_file_backend(root, enable_singleton=True)
 
         if dist.is_main_process() and not self._check_exists():
-            if file_client.name != 'HardDiskBackend':
+            if not isinstance(backend, LocalBackend):
                 raise RuntimeError(f'The dataset on {root} is not integrated, '
                                    f'please manually handle it.')
 
@@ -93,10 +94,9 @@ class MNIST(BaseDataset):
             file_list = self.test_list
 
         # load data from SN3 files
-        imgs = read_image_file(
-            file_client.join_path(root, rm_suffix(file_list[0][0])))
+        imgs = read_image_file(join_path(root, rm_suffix(file_list[0][0])))
         gt_labels = read_label_file(
-            file_client.join_path(root, rm_suffix(file_list[1][0])))
+            join_path(root, rm_suffix(file_list[1][0])))
 
         data_infos = []
         for img, gt_label in zip(imgs, gt_labels):
@@ -108,28 +108,23 @@ class MNIST(BaseDataset):
     def _check_exists(self):
         """Check the exists of data files."""
         root = self.data_prefix['root']
-        file_client = FileClient.infer_client(uri=root)
 
         for filename, _ in (self.train_list + self.test_list):
             # get extracted filename of data
             extract_filename = rm_suffix(filename)
-            fpath = file_client.join_path(root, extract_filename)
-            if not file_client.exists(fpath):
+            fpath = join_path(root, extract_filename)
+            if not exists(fpath):
                 return False
         return True
 
     def _download(self):
         """Download and extract data files."""
         root = self.data_prefix['root']
-        file_client = FileClient.infer_client(uri=root)
 
         for filename, md5 in (self.train_list + self.test_list):
-            url = file_client.join_path(self.url_prefix, filename)
+            url = urljoin(self.url_prefix, filename)
             download_and_extract_archive(
-                url,
-                download_root=self.data_prefix['root'],
-                filename=filename,
-                md5=md5)
+                url, download_root=root, filename=filename, md5=md5)
 
     def extra_repr(self) -> List[str]:
         """The extra repr information of the dataset."""
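A small behavioral fix hides in the MNIST hunks: download URLs are now built with `urllib.parse.urljoin` instead of a client's `join_path`, since URL resolution follows different rules from filesystem joining. The URL below is illustrative:

```python
# urljoin resolves relative references per RFC 3986; note the trailing
# slash on the prefix, without which the last segment would be replaced.
from urllib.parse import urljoin

url_prefix = 'http://yann.lecun.com/exdb/mnist/'
print(urljoin(url_prefix, 'train-images-idx3-ubyte.gz'))
# -> http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
```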
--- a/mmcls/datasets/utils.py
+++ b/mmcls/datasets/utils.py
@@ -10,7 +10,7 @@ import urllib.error
 import urllib.request
 import zipfile
 
-from mmengine.fileio.file_client import FileClient
+from mmengine.fileio import LocalBackend, get_file_backend
 
 __all__ = [
     'rm_suffix', 'check_integrity', 'download_and_extract_archive',
@@ -25,16 +25,16 @@ def rm_suffix(s, suffix=None):
         return s[:s.rfind(suffix)]
 
 
-def calculate_md5(fpath: str,
-                  file_client: FileClient = None,
-                  chunk_size: int = 1024 * 1024):
+def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024):
     md5 = hashlib.md5()
-    if file_client is None or file_client.name == 'HardDiskBackend':
+    backend = get_file_backend(fpath, enable_singleton=True)
+    if isinstance(backend, LocalBackend):
+        # Enable chunk update for local file.
         with open(fpath, 'rb') as f:
             for chunk in iter(lambda: f.read(chunk_size), b''):
                 md5.update(chunk)
     else:
-        md5.update(file_client.get(fpath))
+        md5.update(backend.get(fpath))
     return md5.hexdigest()
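`calculate_md5` keeps both code paths (chunked reads bound memory for local files; remote backends are fetched in one `get()`), but callers no longer thread a `file_client` through. A hedged usage sketch, assuming the helper signatures implied above; the file name is a placeholder:

```python
# Callers now pass only a path (plus an optional chunk size); the backend
# is inferred inside the helper. The archive path is a placeholder.
from mmcls.datasets.utils import calculate_md5, check_integrity

digest = calculate_md5('data/archive.tar.gz', chunk_size=1024 * 1024)
ok = check_integrity('data/archive.tar.gz', md5=digest)
```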
--- a/mmcls/datasets/voc.py
+++ b/mmcls/datasets/voc.py
@@ -2,7 +2,7 @@
 import xml.etree.ElementTree as ET
 from typing import List, Optional, Union
 
-from mmengine import FileClient, list_from_file
+from mmengine import get_file_backend, list_from_file
 
 from mmcls.registry import DATASETS
 from .base_dataset import expanduser
@@ -72,9 +72,8 @@ class VOC(MultiLabelDataset):
             ' False.'
 
         self.data_root = data_root
-        self.file_client = FileClient.infer_client(uri=data_root)
-        self.image_set_path = self.file_client.join_path(
-            data_root, image_set_path)
+        self.backend = get_file_backend(data_root, enable_singleton=True)
+        self.image_set_path = self.backend.join_path(data_root, image_set_path)
 
         super().__init__(
             ann_file='',
@@ -94,8 +93,8 @@ class VOC(MultiLabelDataset):
 
     def _get_labels_from_xml(self, img_id):
         """Get gt_labels and labels_difficult from xml file."""
-        xml_path = self.file_client.join_path(self.ann_prefix, f'{img_id}.xml')
-        content = self.file_client.get(xml_path)
+        xml_path = self.backend.join_path(self.ann_prefix, f'{img_id}.xml')
+        content = self.backend.get(xml_path)
         root = ET.fromstring(content)
 
         labels, labels_difficult = set(), set()
@@ -120,8 +119,7 @@ class VOC(MultiLabelDataset):
         img_ids = list_from_file(self.image_set_path)
 
         for img_id in img_ids:
-            img_path = self.file_client.join_path(self.img_prefix,
-                                                  f'{img_id}.jpg')
+            img_path = self.backend.join_path(self.img_prefix, f'{img_id}.jpg')
 
             labels, labels_difficult = None, None
             if self.ann_prefix is not None:
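VOC reads its annotation XML through the same instance-bound backend, which works for local and remote roots alike because `backend.get()` returns bytes that `ElementTree` parses directly. A sketch with a placeholder layout:

```python
# Reading one VOC annotation via a pre-built backend; paths are placeholders.
import xml.etree.ElementTree as ET

from mmengine import get_file_backend

ann_prefix = 'data/VOCdevkit/VOC2007/Annotations'
backend = get_file_backend(ann_prefix, enable_singleton=True)
xml_path = backend.join_path(ann_prefix, '000001.xml')
root = ET.fromstring(backend.get(xml_path))  # bytes parse directly
```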
--- a/mmcls/engine/hooks/visualization_hook.py
+++ b/mmcls/engine/hooks/visualization_hook.py
@@ -3,7 +3,7 @@ import math
 import os.path as osp
 from typing import Optional, Sequence
 
-from mmengine import FileClient
+from mmengine.fileio import join_path
 from mmengine.hooks import Hook
 from mmengine.runner import EpochBasedTrainLoop, Runner
 from mmengine.visualization import Visualizer
@@ -45,10 +45,6 @@ class VisualizationHook(Hook):
         self.interval = interval
         self.show = show
         self.out_dir = out_dir
-        if out_dir is not None:
-            self.file_client = FileClient.infer_client(uri=out_dir)
-        else:
-            self.file_client = None
 
         self.draw_args = {**kwargs, 'show': show}
 
@@ -89,8 +85,8 @@ class VisualizationHook(Hook):
 
         draw_args = self.draw_args
         if self.out_dir is not None:
-            draw_args['out_file'] = self.file_client.join_path(
-                self.out_dir, f'{sample_name}_{step}.png')
+            draw_args['out_file'] = join_path(self.out_dir,
+                                              f'{sample_name}_{step}.png')
 
         self._visualizer.add_datasample(
             sample_name,
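With no stored client, the hook joins output paths on demand, which is also why the `test_initialize` assertions on `hook.file_client` are deleted in the test file below. A sketch of the on-demand join, with placeholder values:

```python
# On-demand path joining replaces the stored file_client attribute.
from mmengine.fileio import join_path

out_dir = 'work_dirs/vis'          # placeholder output directory
sample_name, step = 'demo_sample', 100
out_file = join_path(out_dir, f'{sample_name}_{step}.png')
```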
--- a/tests/test_engine/test_hooks/test_visualization_hook.py
+++ b/tests/test_engine/test_hooks/test_visualization_hook.py
@@ -32,18 +32,6 @@ class TestVisualizationHook(TestCase):
 
         self.tmpdir = tempfile.TemporaryDirectory()
 
-    def test_initialize(self):
-        # test file_client
-        cfg = dict(type='VisualizationHook')
-        hook = HOOKS.build(cfg)
-        self.assertIsNone(hook.file_client)
-
-        cfg = dict(type='VisualizationHook', out_dir=self.tmpdir.name)
-        hook = HOOKS.build(cfg)
-        self.assertIsNotNone(hook.file_client)
-
-        # test draw_args
-
     def test_draw_samples(self):
         # test enable=False
         cfg = dict(type='VisualizationHook', enable=False)