[Fix]: fix missing check of directory in scandir (#1110)

* Missing check for dir in the 'else' clause

Fixing issue when recursively scanning directories with filenames starting with '.'  Without this fix, the `if not entry.name.startswith('.') and entry.is_file()` logic falls through to the `else` clause which in the current code base will error out as it encounters '.' files (e.g. .DS_Store)

* Updated code per comments

* fixing indentation

* fix indenterror and add comment

* remove .DS_Store and add .file

Co-authored-by: zhouzaida <zhouzaida@163.com>
This commit is contained in:
achaiah 2021-06-29 08:31:00 -05:00 committed by GitHub
parent 76d9bf1efb
commit 21845db455
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 14 deletions

View File

@ -63,16 +63,12 @@ def scandir(dir_path, suffix=None, recursive=False):
for entry in os.scandir(dir_path):
if not entry.name.startswith('.') and entry.is_file():
rel_path = osp.relpath(entry.path, root)
if suffix is None:
if suffix is None or rel_path.endswith(suffix):
yield rel_path
elif rel_path.endswith(suffix):
yield rel_path
else:
if recursive:
yield from _scandir(
entry.path, suffix=suffix, recursive=recursive)
else:
continue
elif recursive and os.path.isdir(entry.path):
# scan recursively if entry.path is a directory
yield from _scandir(
entry.path, suffix=suffix, recursive=recursive)
return _scandir(dir_path, suffix=suffix, recursive=recursive)

View File

View File

@ -40,12 +40,13 @@ def test_scandir():
filenames_recursive = [
'a.bin', '1.txt', '2.txt', '1.json', '2.json', 'sub/1.json',
'sub/1.txt'
'sub/1.txt', '.file'
]
assert set(mmcv.scandir(folder,
recursive=True)) == set(filenames_recursive)
assert set(mmcv.scandir(Path(folder),
recursive=True)) == set(filenames_recursive)
# .file starts with '.' and is a file so it will not be scanned
assert set(mmcv.scandir(folder, recursive=True)) == set(
[filename for filename in filenames_recursive if filename != '.file'])
assert set(mmcv.scandir(Path(folder), recursive=True)) == set(
[filename for filename in filenames_recursive if filename != '.file'])
assert set(mmcv.scandir(folder, '.txt', recursive=True)) == set([
filename for filename in filenames_recursive
if filename.endswith('.txt')