Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[feature] tmp add fileio (#17)
* tmp add fileio
* ignore fileio mypy check error
This commit is contained in:
parent 019c2f5cc9
commit 8e9de77da4
mmengine/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .utils import *
mmengine/utils/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from .fileio import (FileClient, dict_from_file, dump, list_from_file, load,
                     register_handler)
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
                   has_method, import_modules_from_strings, is_list_of,
                   is_method_overridden, is_seq_of, is_str, is_tuple_of,
                   iter_cast, list_cast, requires_executable, requires_package,
                   slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
                   to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
                   scandir, symlink)

__all__ = [
    'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
    'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
    'check_prerequisites', 'requires_package', 'requires_executable',
    'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
    'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
    'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
    'is_method_overridden', 'has_method', 'dict_from_file', 'list_from_file',
    'register_handler', 'dump', 'load', 'FileClient'
]
mmengine/utils/fileio/__init__.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from .file_client import BaseStorageBackend, FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file

__all__ = [
    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
    'list_from_file', 'dict_from_file'
]
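A minimal usage sketch of the API exported above (not part of the commit), assuming the target path is on local disk so the default backend is used and the format is inferred from the file extension:

>>> from mmengine.utils import dump, load
>>> data = {'a': 1, 'b': [2, 3]}
>>> dump(data, '/tmp/demo.json')        # serialized by JsonHandler
>>> load('/tmp/demo.json') == data      # round-trips through the same handler
True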
mmengine/utils/fileio/file_client.py (new file, 1149 lines; diff suppressed because it is too large)
mmengine/utils/fileio/handlers/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFileHandler
from .json_handler import JsonHandler
from .pickle_handler import PickleHandler
from .yaml_handler import YamlHandler

__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']
mmengine/utils/fileio/handlers/base.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod


class BaseFileHandler(metaclass=ABCMeta):
    # `str_like` is a flag to indicate whether the type of file object is
    # str-like object or bytes-like object. Pickle only processes bytes-like
    # objects but json only processes str-like object. If it is str-like
    # object, `StringIO` will be used to process the buffer.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        pass

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        pass

    def load_from_path(self, filepath, mode='r', **kwargs):
        with open(filepath, mode) as f:
            return self.load_from_fileobj(f, **kwargs)

    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
        with open(filepath, mode) as f:
            self.dump_to_fileobj(obj, f, **kwargs)
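For orientation, a hypothetical subclass only has to fill in the three abstract methods; the TxtHandler name and behavior below are an illustrative sketch, not part of this commit:

from mmengine.utils.fileio.handlers import BaseFileHandler


class TxtHandler(BaseFileHandler):
    """Toy str-like handler that stores one item per line."""

    def load_from_fileobj(self, file, **kwargs):
        return [line.rstrip('\n') for line in file]

    def dump_to_fileobj(self, obj, file, **kwargs):
        file.write(self.dump_to_str(obj, **kwargs))

    def dump_to_str(self, obj, **kwargs):
        return '\n'.join(str(item) for item in obj)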
mmengine/utils/fileio/handlers/json_handler.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json

import numpy as np

from .base import BaseFileHandler


def set_default(obj):
    """Set default json values for non-serializable values.

    It helps convert ``set``, ``range`` and ``np.ndarray`` data types to list.
    It also converts ``np.generic`` (including ``np.int32``, ``np.float32``,
    etc.) into plain numbers of plain python built-in types.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')


class JsonHandler(BaseFileHandler):

    def load_from_fileobj(self, file):
        return json.load(file)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)
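The ``set_default`` hook is what lets sets, ranges and NumPy values pass through ``json``; a short doctest-style sketch of the expected effect (not part of the commit):

>>> import numpy as np
>>> from mmengine.utils.fileio.handlers import JsonHandler
>>> JsonHandler().dump_to_str({'ids': {1, 2}, 'score': np.float32(0.5)})
'{"ids": [1, 2], "score": 0.5}'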
mmengine/utils/fileio/handlers/pickle_handler.py (new file, 28 lines)
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pickle

from .base import BaseFileHandler


class PickleHandler(BaseFileHandler):

    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        return super(PickleHandler, self).load_from_path(
            filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('protocol', 2)
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('protocol', 2)
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        super(PickleHandler, self).dump_to_path(
            obj, filepath, mode='wb', **kwargs)
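Since ``str_like`` is ``False`` here, the io layer (see io.py below) routes pickle data through ``BytesIO`` rather than ``StringIO``, and protocol 2 is pinned by default. A small round-trip sketch, not part of the commit:

>>> import io
>>> from mmengine.utils.fileio.handlers import PickleHandler
>>> handler = PickleHandler()
>>> blob = handler.dump_to_str({'answer': 42})   # bytes, protocol 2 by default
>>> handler.load_from_fileobj(io.BytesIO(blob))
{'answer': 42}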
mmengine/utils/fileio/handlers/yaml_handler.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
import yaml

try:
    from yaml import CDumper as Dumper  # type: ignore
    from yaml import CLoader as Loader  # type: ignore
except ImportError:
    from yaml import Loader, Dumper  # type: ignore

from .base import BaseFileHandler  # isort:skip


class YamlHandler(BaseFileHandler):

    def load_from_fileobj(self, file, **kwargs):
        kwargs.setdefault('Loader', Loader)
        return yaml.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        yaml.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        kwargs.setdefault('Dumper', Dumper)
        return yaml.dump(obj, **kwargs)
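A round-trip sketch for the YAML handler (not part of the commit); the C-accelerated Loader/Dumper is used when available and the handler silently falls back to the pure-Python one otherwise:

>>> import io
>>> from mmengine.utils.fileio.handlers import YamlHandler
>>> handler = YamlHandler()
>>> text = handler.dump_to_str({'lr': 0.01, 'steps': [8, 11]})
>>> handler.load_from_fileobj(io.StringIO(text))
{'lr': 0.01, 'steps': [8, 11]}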
mmengine/utils/fileio/io.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from io import BytesIO, StringIO
from pathlib import Path

from mmengine import is_list_of, is_str
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler

file_handlers = {
    'json': JsonHandler(),
    'yaml': YamlHandler(),
    'yml': YamlHandler(),
    'pickle': PickleHandler(),
    'pkl': PickleHandler()
}


def load(file, file_format=None, file_client_args=None, **kwargs):
    """Load data from json/yaml/pickle files.

    This method provides a unified api for loading data from serialized files.

    Note:
        In v1.3.16 and later, ``load`` supports loading data from serialized
        files that can be stored in different backends.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a file-like
            object.
        file_format (str, optional): If not specified, the file format will be
            inferred from the file extension, otherwise use the specified one.
            Currently supported formats include "json", "yaml/yml" and
            "pickle/pkl".
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Examples:
        >>> load('/path/of/your/file')  # file is stored on disk
        >>> load('https://path/of/your/file')  # file is stored on the Internet
        >>> load('s3://path/of/your/file')  # file is stored in petrel

    Returns:
        The content from the file.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None and is_str(file):
        file_format = file.split('.')[-1]
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = file_handlers[file_format]
    if is_str(file):
        file_client = FileClient.infer_client(file_client_args, file)
        if handler.str_like:
            with StringIO(file_client.get_text(file)) as f:
                obj = handler.load_from_fileobj(f, **kwargs)
        else:
            with BytesIO(file_client.get(file)) as f:
                obj = handler.load_from_fileobj(f, **kwargs)
    elif hasattr(file, 'read'):
        obj = handler.load_from_fileobj(file, **kwargs)
    else:
        raise TypeError('"file" must be a filepath str or a file-object')
    return obj


def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
    """Dump data to json/yaml/pickle strings or files.

    This method provides a unified api for dumping data as strings or to files,
    and also supports custom arguments for each file format.

    Note:
        In v1.3.16 and later, ``dump`` supports dumping data as strings or to
        files which can be saved to different backends.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, then the object is dumped to a str, otherwise to a file
            specified by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel

    Returns:
        bool: True for success, False otherwise.
    """
    if isinstance(file, Path):
        file = str(file)
    if file_format is None:
        if is_str(file):
            file_format = file.split('.')[-1]
        elif file is None:
            raise ValueError(
                'file_format must be specified since file is None')
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')

    handler = file_handlers[file_format]
    if file is None:
        return handler.dump_to_str(obj, **kwargs)
    elif is_str(file):
        file_client = FileClient.infer_client(file_client_args, file)
        if handler.str_like:
            with StringIO() as f:
                handler.dump_to_fileobj(obj, f, **kwargs)
                file_client.put_text(f.getvalue(), file)
        else:
            with BytesIO() as f:
                handler.dump_to_fileobj(obj, f, **kwargs)
                file_client.put(f.getvalue(), file)
    elif hasattr(file, 'write'):
        handler.dump_to_fileobj(obj, file, **kwargs)
    else:
        raise TypeError('"file" must be a filename str or a file-object')


def _register_handler(handler, file_formats):
    """Register a handler for some file extensions.

    Args:
        handler (:obj:`BaseFileHandler`): Handler to be registered.
        file_formats (str or list[str]): File formats to be handled by this
            handler.
    """
    if not isinstance(handler, BaseFileHandler):
        raise TypeError(
            f'handler must be a child of BaseFileHandler, not {type(handler)}')
    if isinstance(file_formats, str):
        file_formats = [file_formats]
    if not is_list_of(file_formats, str):
        raise TypeError('file_formats must be a str or a list of str')
    for ext in file_formats:
        file_handlers[ext] = handler


def register_handler(file_formats, **kwargs):

    def wrap(cls):
        _register_handler(cls(**kwargs), file_formats)
        return cls

    return wrap
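A hedged sketch of how ``register_handler`` is intended to be used; the '.cfg' extension and the CfgHandler class below are illustrative only, reusing the existing JsonHandler so that the new suffix dispatches through ``load``/``dump``:

from mmengine.utils.fileio import dump, load, register_handler
from mmengine.utils.fileio.handlers import JsonHandler


# Hypothetical: register the JSON handler for a '.cfg' extension so that
# load()/dump() also dispatch on that suffix.
@register_handler('cfg')
class CfgHandler(JsonHandler):
    pass


dump({'lr': 0.01}, '/tmp/exp.cfg')      # serialized as JSON under the hood
assert load('/tmp/exp.cfg') == {'lr': 0.01}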
mmengine/utils/fileio/parse.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from io import StringIO

from .file_client import FileClient


def list_from_file(filename,
                   prefix='',
                   offset=0,
                   max_num=0,
                   encoding='utf-8',
                   file_client_args=None):
    """Load a text file and parse the content as a list of strings.

    Note:
        In v1.3.16 and later, ``list_from_file`` supports loading a text file
        which can be stored in different backends and parsing the content as
        a list of strings.

    Args:
        filename (str): Filename.
        prefix (str): The prefix to be inserted at the beginning of each item.
        offset (int): The number of lines to skip at the beginning.
        max_num (int): The maximum number of lines to be read,
            zero and negative values mean no limitation.
        encoding (str): Encoding used to open the file. Default: utf-8.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Examples:
        >>> list_from_file('/path/of/your/file')  # disk
        ['hello', 'world']
        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
        ['hello', 'world']

    Returns:
        list[str]: A list of strings.
    """
    cnt = 0
    item_list = []
    file_client = FileClient.infer_client(file_client_args, filename)
    with StringIO(file_client.get_text(filename, encoding)) as f:
        for _ in range(offset):
            f.readline()
        for line in f:
            if 0 < max_num <= cnt:
                break
            item_list.append(prefix + line.rstrip('\n\r'))
            cnt += 1
    return item_list


def dict_from_file(filename,
                   key_type=str,
                   encoding='utf-8',
                   file_client_args=None):
    """Load a text file and parse the content as a dict.

    Each line of the text file will be two or more columns split by
    whitespaces or tabs. The first column will be parsed as dict keys, and
    the following columns will be parsed as dict values.

    Note:
        In v1.3.16 and later, ``dict_from_file`` supports loading a text file
        which can be stored in different backends and parsing the content as
        a dict.

    Args:
        filename (str): Filename.
        key_type (type): Type of the dict keys. str is used by default and
            type conversion will be performed if specified.
        encoding (str): Encoding used to open the file. Default: utf-8.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmcv.fileio.FileClient` for details.
            Default: None.

    Examples:
        >>> dict_from_file('/path/of/your/file')  # disk
        {'key1': 'value1', 'key2': 'value2'}
        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
        {'key1': 'value1', 'key2': 'value2'}

    Returns:
        dict: The parsed contents.
    """
    mapping = {}
    file_client = FileClient.infer_client(file_client_args, filename)
    with StringIO(file_client.get_text(filename, encoding)) as f:
        for line in f:
            items = line.rstrip('\n').split()
            assert len(items) >= 2
            key = key_type(items[0])
            val = items[1:] if len(items) > 2 else items[1]
            mapping[key] = val
    return mapping
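Doctest-style sketches of the two parsers (not part of the commit), assuming the two small local text files described in the comments exist:

>>> from mmengine.utils import dict_from_file, list_from_file
>>> # assume /tmp/classes.txt holds three lines: person / car / bicycle
>>> list_from_file('/tmp/classes.txt')
['person', 'car', 'bicycle']
>>> list_from_file('/tmp/classes.txt', offset=1, prefix='cls/')
['cls/car', 'cls/bicycle']
>>> # assume /tmp/mapping.txt holds two lines: "1 person" and "2 car red"
>>> dict_from_file('/tmp/mapping.txt', key_type=int)
{1: 'person', 2: ['car', 'red']}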
mmengine/utils/misc.py (new file, 377 lines)
@@ -0,0 +1,377 @@
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import functools
import itertools
import subprocess
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat


# From PyTorch internals
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


def is_str(x):
    """Whether the input is a string instance.

    Note: This method is deprecated since python 2 is no longer supported.
    """
    return isinstance(x, str)


def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, the failed imports will return
            None. Otherwise, an ImportError is raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if allow_failed_imports:
                warnings.warn(f'{imp} failed to import and is ignored.',
                              UserWarning)
                imported_tmp = None
            else:
                raise ImportError
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported


def iter_cast(inputs, dst_type, return_type=None):
    """Cast elements of an iterable object into some type.

    Args:
        inputs (Iterable): The input object.
        dst_type (type): Destination type.
        return_type (type, optional): If specified, the output object will be
            converted to this type, otherwise an iterator.

    Returns:
        iterator or specified type: The converted object.
    """
    if not isinstance(inputs, abc.Iterable):
        raise TypeError('inputs must be an iterable object')
    if not isinstance(dst_type, type):
        raise TypeError('"dst_type" must be a valid type')

    out_iterable = map(dst_type, inputs)

    if return_type is None:
        return out_iterable
    else:
        return return_type(out_iterable)


def list_cast(inputs, dst_type):
    """Cast elements of an iterable object into a list of some type.

    A partial method of :func:`iter_cast`.
    """
    return iter_cast(inputs, dst_type, return_type=list)


def tuple_cast(inputs, dst_type):
    """Cast elements of an iterable object into a tuple of some type.

    A partial method of :func:`iter_cast`.
    """
    return iter_cast(inputs, dst_type, return_type=tuple)


def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is None:
        exp_seq_type = abc.Sequence
    else:
        assert isinstance(seq_type, type)
        exp_seq_type = seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    for item in seq:
        if not isinstance(item, expected_type):
            return False
    return True


def is_list_of(seq, expected_type):
    """Check whether it is a list of some type.

    A partial method of :func:`is_seq_of`.
    """
    return is_seq_of(seq, expected_type, seq_type=list)


def is_tuple_of(seq, expected_type):
    """Check whether it is a tuple of some type.

    A partial method of :func:`is_seq_of`.
    """
    return is_seq_of(seq, expected_type, seq_type=tuple)


def slice_list(in_list, lens):
    """Slice a list into several sub lists by a list of given lengths.

    Args:
        in_list (list): The list to be sliced.
        lens (int or list): The expected length of each output list.

    Returns:
        list: A list of sliced lists.
    """
    if isinstance(lens, int):
        assert len(in_list) % lens == 0
        lens = [lens] * int(len(in_list) / lens)
    if not isinstance(lens, list):
        raise TypeError('"indices" must be an integer or a list of integers')
    elif sum(lens) != len(in_list):
        raise ValueError('sum of lens and list length does not '
                         f'match: {sum(lens)} != {len(in_list)}')
    out_list = []
    idx = 0
    for i in range(len(lens)):
        out_list.append(in_list[idx:idx + lens[i]])
        idx += lens[i]
    return out_list


def concat_list(in_list):
    """Concatenate a list of lists into a single flat list.

    Args:
        in_list (list): The list of lists to be merged.

    Returns:
        list: The concatenated flat list.
    """
    return list(itertools.chain(*in_list))


def check_prerequisites(
        prerequisites,
        checker,
        msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
        'found, please install them first.'):  # yapf: disable
    """A decorator factory to check if prerequisites are satisfied.

    Args:
        prerequisites (str or list[str]): Prerequisites to be checked.
        checker (callable): The checker method that returns True if a
            prerequisite is met, False otherwise.
        msg_tmpl (str): The message template with two variables.

    Returns:
        decorator: A specific decorator.
    """

    def wrap(func):

        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            requirements = [prerequisites] if isinstance(
                prerequisites, str) else prerequisites
            missing = []
            for item in requirements:
                if not checker(item):
                    missing.append(item)
            if missing:
                print(msg_tmpl.format(', '.join(missing), func.__name__))
                raise RuntimeError('Prerequisites not met.')
            else:
                return func(*args, **kwargs)

        return wrapped_func

    return wrap


def _check_py_package(package):
    try:
        import_module(package)
    except ImportError:
        return False
    else:
        return True


def _check_executable(cmd):
    if subprocess.call(f'which {cmd}', shell=True) != 0:
        return False
    else:
        return True


def requires_package(prerequisites):
    """A decorator to check if some python packages are installed.

    Example:
        >>> @requires_package('numpy')
        >>> def func(arg1, args):
        >>>     return numpy.zeros(1)
        array([0.])
        >>> @requires_package(['numpy', 'non_package'])
        >>> def func(arg1, args):
        >>>     return numpy.zeros(1)
        ImportError
    """
    return check_prerequisites(prerequisites, checker=_check_py_package)


def requires_executable(prerequisites):
    """A decorator to check if some executable files are installed.

    Example:
        >>> @requires_executable('ffmpeg')
        >>> def func(arg1, args):
        >>>     print(1)
        1
    """
    return check_prerequisites(prerequisites, checker=_check_executable)


def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecated and try to
    replace the deprecated src_arg_name with dst_arg_name.

    Args:
        name_dict (dict):
            key (str): Deprecated argument names.
            val (str): Expected argument names.

    Returns:
        func: New function.
    """

    def api_warning_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # get the arg spec of the decorated method
            args_info = getfullargspec(old_func)
            # get name of the function
            func_name = old_func.__name__
            if cls_name is not None:
                func_name = f'{cls_name}.{func_name}'
            if args:
                arg_names = args_info.args[:len(args)]
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in arg_names:
                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        arg_names[arg_names.index(src_arg_name)] = dst_arg_name
            if kwargs:
                for src_arg_name, dst_arg_name in name_dict.items():
                    if src_arg_name in kwargs:

                        assert dst_arg_name not in kwargs, (
                            f'The expected behavior is to replace '
                            f'the deprecated key `{src_arg_name}` with '
                            f'new key `{dst_arg_name}`, but got them '
                            f'in the arguments at the same time, which '
                            f'is confusing. `{src_arg_name}` will be '
                            f'deprecated in the future, please '
                            f'use `{dst_arg_name}` instead.')

                        warnings.warn(
                            f'"{src_arg_name}" is deprecated in '
                            f'`{func_name}`, please use "{dst_arg_name}" '
                            'instead', DeprecationWarning)
                        kwargs[dst_arg_name] = kwargs.pop(src_arg_name)

            # apply converted arguments to the decorated method
            output = old_func(*args, **kwargs)
            return output

        return new_func

    return api_warning_wrapper


def is_method_overridden(method, base_class, derived_class):
    """Check if a method of base class is overridden in derived class.

    Args:
        method (str): the method name to check.
        base_class (type): the class of the base class.
        derived_class (type | Any): the class or instance of the derived class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, please pass class instead."

    if not isinstance(derived_class, type):
        derived_class = derived_class.__class__

    base_method = getattr(base_class, method)
    derived_method = getattr(derived_class, method)
    return derived_method != base_method


def has_method(obj: object, method: str) -> bool:
    """Check whether the object has a method.

    Args:
        method (str): The method name to check.
        obj (object): The object to check.

    Returns:
        bool: True if the object has the method else False.
    """
    return hasattr(obj, method) and callable(getattr(obj, method))
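A few doctest-style checks of the helpers above (illustrative only), showing the intended behavior of the sequence and tuple utilities:

>>> from mmengine.utils import is_list_of, slice_list, to_2tuple
>>> is_list_of([1, 2, 3], int)
True
>>> slice_list([1, 2, 3, 4, 5, 6], [2, 4])
[[1, 2], [3, 4, 5, 6]]
>>> to_2tuple(7)          # scalars are repeated, iterables pass through
(7, 7)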
mmengine/utils/path.py (new file, 101 lines)
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from pathlib import Path

from .misc import is_str


def is_filepath(x):
    return is_str(x) or isinstance(x, Path)


def fopen(filepath, *args, **kwargs):
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    elif isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError('`filepath` should be a string or a Path')


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


def mkdir_or_exist(dir_name, mode=0o777):
    if dir_name == '':
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


def symlink(src, dst, overwrite=True, **kwargs):
    if os.path.lexists(dst) and overwrite:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)


def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Scan a directory to find the interested files.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional): If set to False, ignore the case of
            suffix. Default: True.

    Returns:
        A generator for all the interested files with relative paths.
    """
    if isinstance(dir_path, (str, Path)):
        dir_path = str(dir_path)
    else:
        raise TypeError('"dir_path" must be a string or Path object')

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        suffix = suffix.lower() if isinstance(suffix, str) else tuple(
            item.lower() for item in suffix)

    root = dir_path

    def _scandir(dir_path, suffix, recursive, case_sensitive):
        for entry in os.scandir(dir_path):
            if not entry.name.startswith('.') and entry.is_file():
                rel_path = osp.relpath(entry.path, root)
                _rel_path = rel_path if case_sensitive else rel_path.lower()
                if suffix is None or _rel_path.endswith(suffix):
                    yield rel_path
            elif recursive and os.path.isdir(entry.path):
                # scan recursively if entry.path is a directory
                yield from _scandir(entry.path, suffix, recursive,
                                    case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)


def find_vcs_root(path, markers=('.git', )):
    """Finds the root directory (including itself) of specified markers.

    Args:
        path (str): Path of directory or file.
        markers (list[str], optional): List of file or directory names.

    Returns:
        The directory that contains one of the markers, or None if not found.
    """
    if osp.isfile(path):
        path = osp.dirname(path)

    prev, cur = None, osp.abspath(osp.expanduser(path))
    while cur != prev:
        if any(osp.exists(osp.join(cur, marker)) for marker in markers):
            return cur
        prev, cur = cur, osp.split(cur)[0]
    return None
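A quick doctest-style sketch of ``scandir`` (not part of the commit), assuming a scratch directory on a POSIX filesystem:

>>> from mmengine.utils import mkdir_or_exist, scandir
>>> mkdir_or_exist('/tmp/demo_dir/sub')
>>> open('/tmp/demo_dir/a.json', 'w').close()
>>> open('/tmp/demo_dir/sub/b.json', 'w').close()
>>> sorted(scandir('/tmp/demo_dir', suffix='.json', recursive=True))
['a.json', 'sub/b.json']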