diff --git a/mmcv/utils/path.py b/mmcv/utils/path.py index aed078fe9..3a4d03844 100644 --- a/mmcv/utils/path.py +++ b/mmcv/utils/path.py @@ -63,16 +63,12 @@ def scandir(dir_path, suffix=None, recursive=False): for entry in os.scandir(dir_path): if not entry.name.startswith('.') and entry.is_file(): rel_path = osp.relpath(entry.path, root) - if suffix is None: + if suffix is None or rel_path.endswith(suffix): yield rel_path - elif rel_path.endswith(suffix): - yield rel_path - else: - if recursive: - yield from _scandir( - entry.path, suffix=suffix, recursive=recursive) - else: - continue + elif recursive and os.path.isdir(entry.path): + # scan recursively if entry.path is a directory + yield from _scandir( + entry.path, suffix=suffix, recursive=recursive) return _scandir(dir_path, suffix=suffix, recursive=recursive) diff --git a/tests/data/for_scan/.file b/tests/data/for_scan/.file new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_utils/test_path.py b/tests/test_utils/test_path.py index 42f308ef6..aa6537eaf 100644 --- a/tests/test_utils/test_path.py +++ b/tests/test_utils/test_path.py @@ -40,12 +40,13 @@ def test_scandir(): filenames_recursive = [ 'a.bin', '1.txt', '2.txt', '1.json', '2.json', 'sub/1.json', - 'sub/1.txt' + 'sub/1.txt', '.file' ] - assert set(mmcv.scandir(folder, - recursive=True)) == set(filenames_recursive) - assert set(mmcv.scandir(Path(folder), - recursive=True)) == set(filenames_recursive) + # .file starts with '.' and is a file so it will not be scanned + assert set(mmcv.scandir(folder, recursive=True)) == set( + [filename for filename in filenames_recursive if filename != '.file']) + assert set(mmcv.scandir(Path(folder), recursive=True)) == set( + [filename for filename in filenames_recursive if filename != '.file']) assert set(mmcv.scandir(folder, '.txt', recursive=True)) == set([ filename for filename in filenames_recursive if filename.endswith('.txt')