[feature] tmp add fileio (#17)

* tmp add fileio

* ignore fileio mypy check error
This commit is contained in:
Mashiro 2022-02-14 21:55:35 +08:00 committed by GitHub
parent 019c2f5cc9
commit 8e9de77da4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 2040 additions and 0 deletions

3
mmengine/__init__.py Normal file
View File

@ -0,0 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .utils import *

View File

@ -0,0 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from .fileio import (FileClient, dict_from_file, dump, list_from_file, load,
register_handler)
from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
has_method, import_modules_from_strings, is_list_of,
is_method_overridden, is_seq_of, is_str, is_tuple_of,
iter_cast, list_cast, requires_executable, requires_package,
slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple,
to_ntuple, tuple_cast)
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
# Public API of this package: the names imported above from ``.fileio``,
# ``.misc`` and ``.path`` that are exposed to ``from ... import *``.
__all__ = [
    'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
    'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
    'check_prerequisites', 'requires_package', 'requires_executable',
    'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
    'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
    'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
    'is_method_overridden', 'has_method', 'dict_from_file', 'list_from_file',
    'register_handler', 'dump', 'load', 'FileClient'
]

View File

@ -0,0 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from .file_client import BaseStorageBackend, FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from .io import dump, load, register_handler
from .parse import dict_from_file, list_from_file
# Public API of the file I/O package: storage backends, format handlers and
# the load/dump/parse helpers imported above.
__all__ = [
    'BaseStorageBackend', 'FileClient', 'load', 'dump', 'register_handler',
    'BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler',
    'list_from_file', 'dict_from_file'
]

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFileHandler
from .json_handler import JsonHandler
from .pickle_handler import PickleHandler
from .yaml_handler import YamlHandler
# Public API of the handlers package: the abstract base plus one concrete
# handler per supported serialization format.
__all__ = ['BaseFileHandler', 'JsonHandler', 'PickleHandler', 'YamlHandler']

View File

@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
class BaseFileHandler(metaclass=ABCMeta):
    """Abstract interface implemented by every serialization handler.

    Subclasses must provide ``load_from_fileobj``, ``dump_to_fileobj`` and
    ``dump_to_str``. The path-based helpers are implemented here on top of
    the file-object methods.
    """

    # Flag telling callers whether the handler works on text (``True``) or
    # on bytes (``False``). Pickle only processes bytes-like objects while
    # json only processes str-like objects; for str-like handlers a
    # ``StringIO`` buffer is used to process the data.
    str_like = True

    @abstractmethod
    def load_from_fileobj(self, file, **kwargs):
        """Deserialize and return the object stored in the open ``file``."""

    @abstractmethod
    def dump_to_fileobj(self, obj, file, **kwargs):
        """Serialize ``obj`` and write it to the open ``file``."""

    @abstractmethod
    def dump_to_str(self, obj, **kwargs):
        """Serialize ``obj`` and return the serialized representation."""

    def load_from_path(self, filepath, mode='r', **kwargs):
        """Open ``filepath`` with ``mode`` and load its content.

        Extra keyword arguments are forwarded to ``load_from_fileobj``.
        """
        with open(filepath, mode) as stream:
            loaded = self.load_from_fileobj(stream, **kwargs)
        return loaded

    def dump_to_path(self, obj, filepath, mode='w', **kwargs):
        """Open ``filepath`` with ``mode`` and dump ``obj`` into it.

        Extra keyword arguments are forwarded to ``dump_to_fileobj``.
        """
        with open(filepath, mode) as stream:
            self.dump_to_fileobj(obj, stream, **kwargs)

View File

@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import json
import numpy as np
from .base import BaseFileHandler
def set_default(obj):
    """Fallback converter for values that ``json`` cannot serialize itself.

    ``set`` and ``range`` objects become lists, ``np.ndarray`` is converted
    with ``tolist()`` and ``np.generic`` scalars (``np.int32``,
    ``np.float32``, ...) are unwrapped into plain Python numbers.

    Raises:
        TypeError: If ``obj`` is none of the supported types.
    """
    if isinstance(obj, (set, range)):
        return list(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'{type(obj)} is unsupported for json dump')
class JsonHandler(BaseFileHandler):
    """Handler that (de)serializes objects as JSON text.

    When dumping, ``set_default`` is installed as the ``default`` hook unless
    the caller passes their own ``default``.
    """

    def load_from_fileobj(self, file, **kwargs):
        """Parse JSON from the open text ``file`` and return the result.

        Keyword arguments are forwarded to ``json.load``. Accepting them
        keeps the signature consistent with ``BaseFileHandler`` and the
        other handlers, so extra arguments no longer raise ``TypeError``
        for json only.
        """
        return json.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` as JSON to the open text ``file``.

        Keyword arguments are forwarded to ``json.dump``.
        """
        kwargs.setdefault('default', set_default)
        json.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` serialized as a JSON string.

        Keyword arguments are forwarded to ``json.dumps``.
        """
        kwargs.setdefault('default', set_default)
        return json.dumps(obj, **kwargs)

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from .base import BaseFileHandler
class PickleHandler(BaseFileHandler):
    """Handler that (de)serializes objects with :mod:`pickle`.

    Pickle is a binary format, so files are opened in binary mode and
    ``str_like`` is ``False``. Dumps use protocol 2 unless the caller passes
    ``protocol`` explicitly.

    NOTE(review): unpickling can execute arbitrary code; confirm that callers
    only load trusted data.
    """

    str_like = False

    def load_from_fileobj(self, file, **kwargs):
        """Unpickle and return the object stored in the binary ``file``."""
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        """Unpickle the file at ``filepath``, opened in ``'rb'`` mode."""
        return super().load_from_path(filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` pickled to bytes."""
        if 'protocol' not in kwargs:
            kwargs['protocol'] = 2
        return pickle.dumps(obj, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Pickle ``obj`` into the binary ``file``."""
        if 'protocol' not in kwargs:
            kwargs['protocol'] = 2
        pickle.dump(obj, file, **kwargs)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Pickle ``obj`` into ``filepath``, opened in ``'wb'`` mode."""
        super().dump_to_path(obj, filepath, mode='wb', **kwargs)

View File

@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
import yaml
try:
from yaml import CDumper as Dumper # type: ignore
from yaml import CLoader as Loader # type: ignore
except ImportError:
from yaml import Loader, Dumper # type: ignore
from .base import BaseFileHandler # isort:skip
class YamlHandler(BaseFileHandler):
    """Handler that (de)serializes objects as YAML text.

    The module-level ``Loader``/``Dumper`` (the C implementations when they
    can be imported, the pure-Python ones otherwise) are used unless the
    caller supplies their own.

    NOTE(review): the default ``Loader`` is not the safe loader; confirm
    that callers only load trusted YAML.
    """

    def load_from_fileobj(self, file, **kwargs):
        """Parse YAML from the open ``file`` and return the result."""
        if 'Loader' not in kwargs:
            kwargs['Loader'] = Loader
        return yaml.load(file, **kwargs)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``obj`` as YAML to the open ``file``."""
        if 'Dumper' not in kwargs:
            kwargs['Dumper'] = Dumper
        yaml.dump(obj, file, **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` serialized as a YAML string."""
        if 'Dumper' not in kwargs:
            kwargs['Dumper'] = Dumper
        return yaml.dump(obj, **kwargs)

152
mmengine/utils/fileio/io.py Normal file
View File

@ -0,0 +1,152 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from io import BytesIO, StringIO
from pathlib import Path
from mmengine import is_list_of, is_str
from .file_client import FileClient
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
# Registry mapping a file-format key (the extension without the leading dot)
# to the handler instance that ``load``/``dump`` use for it. Aliases such as
# 'yaml'/'yml' and 'pickle'/'pkl' each get their own handler instance.
# ``_register_handler`` adds to (or overrides entries in) this dict.
file_handlers = {
    'json': JsonHandler(),
    'yaml': YamlHandler(),
    'yml': YamlHandler(),
    'pickle': PickleHandler(),
    'pkl': PickleHandler()
}
def load(file, file_format=None, file_client_args=None, **kwargs):
    """Load data from json/yaml/pickle files.

    This is the unified entry point for reading serialized files. When
    ``file`` is a path, the content is fetched through a ``FileClient``, so
    the file may be stored in any backend the client supports.

    Args:
        file (str or :obj:`Path` or file-like object): Filename or a
            file-like object.
        file_format (str, optional): If not specified, the file format is
            inferred from the file extension, otherwise the given one is
            used. Currently supported formats include "json", "yaml/yml"
            and "pickle/pkl".
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Default: None.
        **kwargs: Forwarded to the handler's ``load_from_fileobj``.

    Examples:
        >>> load('/path/of/your/file')  # file is stored on disk
        >>> load('https://path/of/your/file')  # file is stored on Internet
        >>> load('s3://path/of/your/file')  # file is stored in petrel

    Returns:
        The content from the file.

    Raises:
        TypeError: If the format is unsupported, or ``file`` is neither a
            path nor an object with a ``read`` attribute.
    """
    if isinstance(file, Path):
        file = str(file)

    path_given = is_str(file)
    if file_format is None and path_given:
        # Everything after the last dot is taken as the format key.
        file_format = file.rsplit('.', 1)[-1]
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')
    handler = file_handlers[file_format]

    if path_given:
        client = FileClient.infer_client(file_client_args, file)
        # Text handlers read from a str buffer, binary ones from bytes.
        if handler.str_like:
            buffer = StringIO(client.get_text(file))
        else:
            buffer = BytesIO(client.get(file))
        with buffer as stream:
            return handler.load_from_fileobj(stream, **kwargs)

    if hasattr(file, 'read'):
        return handler.load_from_fileobj(file, **kwargs)

    raise TypeError('"file" must be a filepath str or a file-object')
def dump(obj, file=None, file_format=None, file_client_args=None, **kwargs):
    """Dump data to json/yaml/pickle strings or files.

    This is the unified entry point for serializing data, either to a string
    or to a file. When ``file`` is a path, the serialized content is written
    through a ``FileClient``, so it may be saved to any backend the client
    supports.

    Args:
        obj (any): The python object to be dumped.
        file (str or :obj:`Path` or file-like object, optional): If not
            specified, the object is dumped to a str, otherwise to the file
            given by the filename or file-like object.
        file_format (str, optional): Same as :func:`load`. Required when
            ``file`` is None.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Default: None.
        **kwargs: Forwarded to the handler's dump method.

    Examples:
        >>> dump('hello world', '/path/of/your/file')  # disk
        >>> dump('hello world', 's3://path/of/your/file')  # ceph or petrel

    Returns:
        The serialized object when ``file`` is None, otherwise None.

    Raises:
        ValueError: If both ``file`` and ``file_format`` are None.
        TypeError: If the format is unsupported, or ``file`` is neither a
            path nor an object with a ``write`` attribute.
    """
    if isinstance(file, Path):
        file = str(file)

    if file_format is None:
        if file is None:
            raise ValueError(
                'file_format must be specified since file is None')
        if is_str(file):
            # Everything after the last dot is taken as the format key.
            file_format = file.rsplit('.', 1)[-1]
    if file_format not in file_handlers:
        raise TypeError(f'Unsupported format: {file_format}')
    handler = file_handlers[file_format]

    if file is None:
        return handler.dump_to_str(obj, **kwargs)

    if is_str(file):
        client = FileClient.infer_client(file_client_args, file)
        # Serialize into an in-memory buffer first, then hand the whole
        # payload to the storage backend.
        if handler.str_like:
            with StringIO() as buffer:
                handler.dump_to_fileobj(obj, buffer, **kwargs)
                client.put_text(buffer.getvalue(), file)
        else:
            with BytesIO() as buffer:
                handler.dump_to_fileobj(obj, buffer, **kwargs)
                client.put(buffer.getvalue(), file)
        return None

    if hasattr(file, 'write'):
        handler.dump_to_fileobj(obj, file, **kwargs)
        return None

    raise TypeError('"file" must be a filename str or a file-object')
def _register_handler(handler, file_formats):
    """Register a handler for some file extensions.

    Existing entries in ``file_handlers`` for the same keys are replaced.

    Args:
        handler (:obj:`BaseFileHandler`): Handler to be registered.
        file_formats (str or list[str]): File formats to be handled by this
            handler.

    Raises:
        TypeError: If ``handler`` is not a ``BaseFileHandler`` instance or
            ``file_formats`` is neither a str nor a list of str.
    """
    if not isinstance(handler, BaseFileHandler):
        raise TypeError(
            f'handler must be a child of BaseFileHandler, not {type(handler)}')

    # Normalize a single format name to a one-element list.
    formats = [file_formats] if isinstance(file_formats, str) else file_formats
    if not is_list_of(formats, str):
        raise TypeError('file_formats must be a str or a list of str')

    file_handlers.update(dict.fromkeys(formats, handler))
def register_handler(file_formats, **kwargs):
    """Class decorator that registers a handler class for ``file_formats``.

    The decorated class is instantiated with ``**kwargs`` and the instance
    is passed to ``_register_handler``; the class itself is returned
    unchanged.

    Args:
        file_formats (str or list[str]): File formats handled by the class.
        **kwargs: Arguments used to instantiate the handler class.

    Returns:
        callable: The class decorator.
    """

    def wrap(cls):
        instance = cls(**kwargs)
        _register_handler(instance, file_formats)
        return cls

    return wrap

View File

@ -0,0 +1,97 @@
# Copyright (c) OpenMMLab. All rights reserved.
# type: ignore
from io import StringIO
from .file_client import FileClient
def list_from_file(filename,
                   prefix='',
                   offset=0,
                   max_num=0,
                   encoding='utf-8',
                   file_client_args=None):
    """Load a text file and parse the content as a list of strings.

    The text is fetched through a ``FileClient``, so the file may be stored
    in any backend the client supports.

    Args:
        filename (str): Filename.
        prefix (str): The prefix to be inserted to the beginning of each item.
        offset (int): The number of leading lines to skip.
        max_num (int): The maximum number of lines to be read,
            zeros and negatives mean no limitation.
        encoding (str): Encoding used to open the file. Default utf-8.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Default: None.

    Examples:
        >>> list_from_file('/path/of/your/file')  # disk
        ['hello', 'world']
        >>> list_from_file('s3://path/of/your/file')  # ceph or petrel
        ['hello', 'world']

    Returns:
        list[str]: A list of strings.
    """
    client = FileClient.infer_client(file_client_args, filename)
    lines = []
    with StringIO(client.get_text(filename, encoding)) as buffer:
        # Discard the first ``offset`` lines.
        for _ in range(offset):
            buffer.readline()
        for taken, raw in enumerate(buffer):
            # A non-positive ``max_num`` disables the limit.
            if max_num > 0 and taken >= max_num:
                break
            lines.append(prefix + raw.rstrip('\n\r'))
    return lines
def dict_from_file(filename,
                   key_type=str,
                   encoding='utf-8',
                   file_client_args=None):
    """Load a text file and parse the content as a dict.

    Each line of the text file must hold two or more columns separated by
    whitespace. The first column becomes the dict key; a single remaining
    column becomes the value as a str, several remaining columns become the
    value as a list of str.

    The text is fetched through a ``FileClient``, so the file may be stored
    in any backend the client supports.

    Args:
        filename(str): Filename.
        key_type(type): Type of the dict keys. str is used by default and
            type conversion will be performed if specified.
        encoding (str): Encoding used to open the file. Default utf-8.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. Default: None.

    Examples:
        >>> dict_from_file('/path/of/your/file')  # disk
        {'key1': 'value1', 'key2': 'value2'}
        >>> dict_from_file('s3://path/of/your/file')  # ceph or petrel
        {'key1': 'value1', 'key2': 'value2'}

    Returns:
        dict: The parsed contents.
    """
    client = FileClient.infer_client(file_client_args, filename)
    parsed = {}
    with StringIO(client.get_text(filename, encoding)) as buffer:
        for raw in buffer:
            columns = raw.rstrip('\n').split()
            assert len(columns) >= 2
            first, *rest = columns
            key = key_type(first)
            # One value column -> plain str; several -> list of str.
            parsed[key] = rest if len(rest) > 1 else rest[0]
    return parsed

377
mmengine/utils/misc.py Normal file
View File

@ -0,0 +1,377 @@
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import functools
import itertools
import subprocess
import warnings
from collections import abc
from importlib import import_module
from inspect import getfullargspec
from itertools import repeat
# From PyTorch internals
def _ntuple(n):
    """Build a converter that expands a scalar into an ``n``-tuple.

    The returned function hands back any iterable input unchanged (this
    includes strings) and repeats every other value ``n`` times in a tuple.
    """

    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            return (x, ) * n
        return x

    return parse


# Ready-made converters for the common sizes; ``to_ntuple(n)`` builds one
# for an arbitrary ``n``.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def is_str(x):
    """Return True when ``x`` is a ``str`` instance.

    Note: This helper is deprecated because python 2 is no longer supported;
    ``isinstance(x, str)`` does the same thing.
    """
    result = isinstance(x, str)
    return result
def import_modules_from_strings(imports, allow_failed_imports=False):
    """Import modules from the given list of strings.

    Args:
        imports (list | str | None): The given module names to be imported.
        allow_failed_imports (bool): If True, a failed import yields None
            (and a ``UserWarning``) instead of an error. Otherwise the
            original ImportError is raised. Default: False.

    Returns:
        list[module] | module | None: The imported modules. A single module
        is returned when ``imports`` is a str, None when it is empty.

    Raises:
        TypeError: If ``imports`` is not a list/str or contains a non-str.
        ImportError: If a module cannot be imported and
            ``allow_failed_imports`` is False.

    Examples:
        >>> osp, sys = import_modules_from_strings(
        ...     ['os.path', 'sys'])
        >>> import os.path as osp_
        >>> import sys as sys_
        >>> assert osp == osp_
        >>> assert sys == sys_
    """
    if not imports:
        return None
    single_import = False
    if isinstance(imports, str):
        single_import = True
        imports = [imports]
    if not isinstance(imports, list):
        raise TypeError(
            f'custom_imports must be a list but got type {type(imports)}')
    imported = []
    for imp in imports:
        if not isinstance(imp, str):
            raise TypeError(
                f'{imp} is of type {type(imp)} and cannot be imported.')
        try:
            imported_tmp = import_module(imp)
        except ImportError:
            if not allow_failed_imports:
                # Re-raise the original error so the message and the name
                # of the missing module are kept; a fresh bare ImportError
                # would discard both.
                raise
            warnings.warn(f'{imp} failed to import and is ignored.',
                          UserWarning)
            imported_tmp = None
        imported.append(imported_tmp)
    if single_import:
        imported = imported[0]
    return imported
def iter_cast(inputs, dst_type, return_type=None):
    """Cast elements of an iterable object into some type.

    Args:
        inputs (Iterable): The input object.
        dst_type (type): Destination type applied to every element.
        return_type (type, optional): If specified, the converted elements
            are collected into this type, otherwise a lazy iterator is
            returned.

    Returns:
        iterator or specified type: The converted object.

    Raises:
        TypeError: If ``inputs`` is not iterable or ``dst_type`` is not a
            type.
    """
    if not isinstance(inputs, abc.Iterable):
        raise TypeError('inputs must be an iterable object')
    if not isinstance(dst_type, type):
        raise TypeError('"dst_type" must be a valid type')

    converted = map(dst_type, inputs)
    return converted if return_type is None else return_type(converted)
def list_cast(inputs, dst_type):
    """Cast every element of ``inputs`` to ``dst_type`` and return a list.

    Shorthand for :func:`iter_cast` with ``return_type=list``.
    """
    return iter_cast(inputs=inputs, dst_type=dst_type, return_type=list)
def tuple_cast(inputs, dst_type):
    """Cast every element of ``inputs`` to ``dst_type`` and return a tuple.

    Shorthand for :func:`iter_cast` with ``return_type=tuple``.
    """
    return iter_cast(inputs=inputs, dst_type=dst_type, return_type=tuple)
def is_seq_of(seq, expected_type, seq_type=None):
    """Check whether it is a sequence of some type.

    Args:
        seq (Sequence): The sequence to be checked.
        expected_type (type): Expected type of sequence items.
        seq_type (type, optional): Expected sequence type. When omitted,
            any ``collections.abc.Sequence`` is accepted.

    Returns:
        bool: Whether the sequence is valid.
    """
    if seq_type is not None:
        assert isinstance(seq_type, type)
        container_type = seq_type
    else:
        container_type = abc.Sequence

    if not isinstance(seq, container_type):
        return False
    return all(isinstance(element, expected_type) for element in seq)
def is_list_of(seq, expected_type):
    """Check whether ``seq`` is a list whose items are all ``expected_type``.

    Shorthand for :func:`is_seq_of` with ``seq_type=list``.
    """
    return is_seq_of(seq=seq, expected_type=expected_type, seq_type=list)
def is_tuple_of(seq, expected_type):
    """Check whether ``seq`` is a tuple whose items are all ``expected_type``.

    Shorthand for :func:`is_seq_of` with ``seq_type=tuple``.
    """
    return is_seq_of(seq=seq, expected_type=expected_type, seq_type=tuple)
def slice_list(in_list, lens):
    """Slice a list into several sub lists by a list of given length.

    Args:
        in_list (list): The list to be sliced.
        lens(int or list): The expected length of each out list. An int
            means equally sized chunks and must divide ``len(in_list)``.

    Returns:
        list: A list of sliced list.

    Raises:
        TypeError: If ``lens`` is neither an int nor a list.
        ValueError: If the lengths do not add up to ``len(in_list)``.
    """
    if isinstance(lens, int):
        assert len(in_list) % lens == 0
        lens = [lens] * (len(in_list) // lens)
    if not isinstance(lens, list):
        raise TypeError('"indices" must be an integer or a list of integers')

    total = sum(lens)
    if total != len(in_list):
        raise ValueError('sum of lens and list length does not '
                         f'match: {total} != {len(in_list)}')

    chunks = []
    start = 0
    for size in lens:
        chunks.append(in_list[start:start + size])
        start += size
    return chunks
def concat_list(in_list):
    """Concatenate a list of list into a single list.

    Args:
        in_list (list): The list of list to be merged.

    Returns:
        list: The concatenated flat list.
    """
    flattened = itertools.chain.from_iterable(in_list)
    return list(flattened)
def check_prerequisites(
        prerequisites,
        checker,
        msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
        'found, please install them first.'):  # yapf: disable
    """A decorator factory to check if prerequisites are satisfied.

    The check runs on every call of the decorated function. When something
    is missing, the formatted message is printed and a RuntimeError is
    raised instead of calling the function.

    Args:
        prerequisites (str of list[str]): Prerequisites to be checked.
        checker (callable): The checker method that returns True if a
            prerequisite is meet, False otherwise.
        msg_tmpl (str): The message template with two variables: the
            comma-joined missing items and the function name.

    Returns:
        decorator: A specific decorator.
    """

    def wrap(func):

        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            if isinstance(prerequisites, str):
                requirements = [prerequisites]
            else:
                requirements = prerequisites
            missing = [item for item in requirements if not checker(item)]
            if not missing:
                return func(*args, **kwargs)
            print(msg_tmpl.format(', '.join(missing), func.__name__))
            raise RuntimeError('Prerequisites not meet.')

        return wrapped_func

    return wrap
def _check_py_package(package):
    """Return True if ``package`` can be imported, False on ImportError."""
    try:
        import_module(package)
    except ImportError:
        return False
    return True
def _check_executable(cmd):
    """Return True if an executable named ``cmd`` is found on ``PATH``.

    Uses ``shutil.which`` rather than running ``which`` through a shell.
    Interpolating ``cmd`` into a shell string allowed command injection
    (``'x; true'`` was reported as available), printed the lookup result to
    stdout and depended on a ``which`` binary being installed.

    Args:
        cmd (str): Name of the executable to look up.

    Returns:
        bool: Whether the executable was found.
    """
    # Imported locally so the block does not depend on a new module-level
    # import.
    import shutil

    return shutil.which(cmd) is not None
def requires_package(prerequisites):
    """A decorator to check if some python packages are installed.

    Built on :func:`check_prerequisites`; a package counts as installed
    when it can be imported. If a package is missing, calling the decorated
    function raises a RuntimeError.

    Example:
        >>> @requires_package('numpy')
        >>> def func(arg1, args):
        >>>     return numpy.zeros(1)
        array([0.])
        >>> @requires_package(['numpy', 'non_package'])
        >>> def func(arg1, args):
        >>>     return numpy.zeros(1)
        RuntimeError
    """
    decorator = check_prerequisites(prerequisites, checker=_check_py_package)
    return decorator
def requires_executable(prerequisites):
    """A decorator to check if some executable files are installed.

    Built on :func:`check_prerequisites` with ``_check_executable`` as the
    checker. If an executable is missing, calling the decorated function
    raises a RuntimeError.

    Example:
        >>> @requires_executable('ffmpeg')
        >>> def func(arg1, args):
        >>>     print(1)
        1
    """
    decorator = check_prerequisites(prerequisites, checker=_check_executable)
    return decorator
def deprecated_api_warning(name_dict, cls_name=None):
    """A decorator to check if some arguments are deprecated and replace
    the deprecated keyword names with their new names.

    Deprecated keyword arguments are renamed before the call. For positional
    arguments only a warning is emitted; the values are passed through
    untouched.

    Args:
        name_dict(dict):
            key (str): Deprecated argument names.
            val (str): Expected argument names.
        cls_name (str, optional): Class name prepended to the function name
            in the warning text.

    Returns:
        func: New function.
    """

    def api_warning_wrapper(old_func):

        @functools.wraps(old_func)
        def new_func(*args, **kwargs):
            # Inspect the decorated callable to learn its parameter names.
            spec = getfullargspec(old_func)
            func_name = old_func.__name__
            if cls_name is not None:
                func_name = f'{cls_name}.{func_name}'

            if args:
                # Names of the parameters that were filled positionally.
                positional = spec.args[:len(args)]
                for old_name, new_name in name_dict.items():
                    if old_name not in positional:
                        continue
                    warnings.warn(
                        f'"{old_name}" is deprecated in '
                        f'`{func_name}`, please use "{new_name}" '
                        'instead', DeprecationWarning)
                    # Only the local name list is updated, so later entries
                    # of name_dict see the renamed parameter.
                    positional[positional.index(old_name)] = new_name

            if kwargs:
                for old_name, new_name in name_dict.items():
                    if old_name not in kwargs:
                        continue
                    assert new_name not in kwargs, (
                        f'The expected behavior is to replace '
                        f'the deprecated key `{old_name}` to '
                        f'new key `{new_name}`, but got them '
                        f'in the arguments at the same time, which '
                        f'is confusing. `{old_name} will be '
                        f'deprecated in the future, please '
                        f'use `{new_name}` instead.')
                    warnings.warn(
                        f'"{old_name}" is deprecated in '
                        f'`{func_name}`, please use "{new_name}" '
                        'instead', DeprecationWarning)
                    kwargs[new_name] = kwargs.pop(old_name)

            # Call the decorated function with the converted arguments.
            return old_func(*args, **kwargs)

        return new_func

    return api_warning_wrapper
def is_method_overridden(method, base_class, derived_class):
    """Check if a method of base class is overridden in derived class.

    Args:
        method (str): the method name to check.
        base_class (type): the class of the base class.
        derived_class (type | Any): the class or instance of the derived class.

    Returns:
        bool: True when the attribute looked up on the derived class differs
        from the one looked up on the base class.
    """
    assert isinstance(base_class, type), \
        "base_class doesn't accept instance, Please pass class instead."

    # Accept an instance as well: compare against its class.
    if isinstance(derived_class, type):
        derived_type = derived_class
    else:
        derived_type = derived_class.__class__

    in_base = getattr(base_class, method)
    in_derived = getattr(derived_type, method)
    return in_derived != in_base
def has_method(obj: object, method: str) -> bool:
    """Check whether the object has a method.

    Args:
        method (str): The method name to check.
        obj (object): The object to check.

    Returns:
        bool: True if the object has a callable attribute with that name,
        else False.
    """
    if not hasattr(obj, method):
        return False
    return callable(getattr(obj, method))

101
mmengine/utils/path.py Normal file
View File

@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from pathlib import Path
from .misc import is_str
def is_filepath(x):
    """Return True when ``x`` is a str or a :obj:`Path`.

    Only the type is checked; the path does not have to exist.
    """
    if is_str(x):
        return True
    return isinstance(x, Path)
def fopen(filepath, *args, **kwargs):
    """Open a file given either as a str or as a :obj:`Path`.

    Args:
        filepath (str | :obj:`Path`): The file to open.
        *args: Forwarded to ``open`` / ``Path.open``.
        **kwargs: Forwarded to ``open`` / ``Path.open``.

    Returns:
        The opened file object.

    Raises:
        ValueError: If ``filepath`` is neither a str nor a Path.
    """
    if is_str(filepath):
        return open(filepath, *args, **kwargs)
    if isinstance(filepath, Path):
        return filepath.open(*args, **kwargs)
    raise ValueError('`filepath` should be a string or a Path')
def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise FileNotFoundError unless ``filename`` is an existing file.

    Args:
        filename (str): Path to check.
        msg_tmpl (str): Error message template; ``{}`` is replaced by
            ``filename``.

    Raises:
        FileNotFoundError: If ``filename`` is not a regular file.
    """
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))
def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` (and missing parents) if it does not exist yet.

    An empty string is silently ignored. A leading ``~`` is expanded before
    the directory is created.

    Args:
        dir_name (str): Directory to create.
        mode (int): Permission mode passed to ``os.makedirs``.
    """
    if dir_name == '':
        return
    expanded = osp.expanduser(dir_name)
    os.makedirs(expanded, mode=mode, exist_ok=True)
def symlink(src, dst, overwrite=True, **kwargs):
    """Create a symbolic link ``dst`` pointing to ``src``.

    Args:
        src (str): Link target.
        dst (str): Path of the link to create.
        overwrite (bool): If True, an existing ``dst`` (including a broken
            link) is removed first. Default: True.
        **kwargs: Forwarded to ``os.symlink``.
    """
    replace_existing = os.path.lexists(dst) and overwrite
    if replace_existing:
        os.remove(dst)
    os.symlink(src, dst, **kwargs)
def scandir(dir_path, suffix=None, recursive=False, case_sensitive=True):
    """Scan a directory to find the interested files.

    Files whose name starts with ``.`` are never yielded. Arguments are
    validated immediately; the directory itself is only read once the
    returned generator is iterated.

    Args:
        dir_path (str | :obj:`Path`): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        case_sensitive (bool, optional) : If set to False, ignore the case of
            suffix. Default: True.

    Returns:
        A generator for all the interested files with relative paths.

    Raises:
        TypeError: If ``dir_path`` or ``suffix`` has an unsupported type.
    """
    if not isinstance(dir_path, (str, Path)):
        raise TypeError('"dir_path" must be a string or Path object')
    dir_path = str(dir_path)

    if suffix is not None and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    if suffix is not None and not case_sensitive:
        # Lower-case the suffix once; candidate paths are lowered to match.
        if isinstance(suffix, str):
            suffix = suffix.lower()
        else:
            suffix = tuple(item.lower() for item in suffix)

    # Yielded paths are relative to the directory the scan started from.
    root = dir_path

    def _scandir(dir_path, suffix, recursive, case_sensitive):
        for entry in os.scandir(dir_path):
            visible_file = not entry.name.startswith('.') and entry.is_file()
            if visible_file:
                relative = osp.relpath(entry.path, root)
                candidate = relative if case_sensitive else relative.lower()
                if suffix is None or candidate.endswith(suffix):
                    yield relative
            elif recursive and os.path.isdir(entry.path):
                # Descend into sub-directories when requested.
                yield from _scandir(entry.path, suffix, recursive,
                                    case_sensitive)

    return _scandir(dir_path, suffix, recursive, case_sensitive)
def find_vcs_root(path, markers=('.git', )):
    """Finds the root directory (including itself) of specified markers.

    Starting at ``path`` (or its parent directory when ``path`` is a file),
    each ancestor directory is checked in turn until one contains a marker
    or the filesystem root has been checked.

    Args:
        path (str): Path of directory or file.
        markers (list[str], optional): List of file or directory names.

    Returns:
        The directory contained one of the markers or None if not found.
    """
    if osp.isfile(path):
        path = osp.dirname(path)

    current = osp.abspath(osp.expanduser(path))
    while True:
        for marker in markers:
            if osp.exists(osp.join(current, marker)):
                return current
        parent = osp.dirname(current)
        if parent == current:
            # Reached the filesystem root without finding a marker.
            return None
        current = parent