[Fix] Fix searching path of config when invoking train, test and gridsearch command (#75)

* Add `followlinks` argument in `recursively_find`.

* When searching config, `pkg/.mim/configs`  is the default path to be searched if it exists.
This commit is contained in:
Ma Zerun 2021-07-27 04:02:24 -04:00 committed by GitHub
parent eb35cceb84
commit d70f9cf14b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 5 deletions

View File

@ -222,7 +222,16 @@ def gridsearch(
pkg_root = get_installed_path(package) pkg_root = get_installed_path(package)
if not osp.exists(config): if not osp.exists(config):
files = recursively_find(pkg_root, osp.basename(config)) # configs is put in pkg/.mim in PR #68
config_root = osp.join(pkg_root, '.mim', 'configs')
if not osp.exists(config_root):
# If not pkg/.mim/config, try to search the whole pkg root.
config_root = pkg_root
# pkg/.mim/configs is a symbolic link to the real config folder,
# so we need to follow links.
files = recursively_find(
pkg_root, osp.basename(config), followlinks=True)
if len(files) == 0: if len(files) == 0:
msg = (f"The path {config} doesn't exist and we can not " msg = (f"The path {config} doesn't exist and we can not "

View File

@ -211,7 +211,16 @@ def test(
pkg_root = get_installed_path(package) pkg_root = get_installed_path(package)
if not osp.exists(config): if not osp.exists(config):
files = recursively_find(pkg_root, osp.basename(config)) # configs is put in pkg/.mim in PR #68
config_root = osp.join(pkg_root, '.mim', 'configs')
if not osp.exists(config_root):
# If not pkg/.mim/config, try to search the whole pkg root.
config_root = pkg_root
# pkg/.mim/configs is a symbolic link to the real config folder,
# so we need to follow links.
files = recursively_find(
pkg_root, osp.basename(config), followlinks=True)
if len(files) == 0: if len(files) == 0:
msg = (f"The path {config} doesn't exist and we can not find " msg = (f"The path {config} doesn't exist and we can not find "

View File

@ -188,7 +188,16 @@ def train(
pkg_root = get_installed_path(package) pkg_root = get_installed_path(package)
if not osp.exists(config): if not osp.exists(config):
files = recursively_find(pkg_root, osp.basename(config)) # configs is put in pkg/.mim in PR #68
config_root = osp.join(pkg_root, '.mim', 'configs')
if not osp.exists(config_root):
# If not pkg/.mim/config, try to search the whole pkg root.
config_root = pkg_root
# pkg/.mim/configs is a symbolic link to the real config folder,
# so we need to follow links.
files = recursively_find(
pkg_root, osp.basename(config), followlinks=True)
if len(files) == 0: if len(files) == 0:
msg = (f"The path {config} doesn't exist and we can not find " msg = (f"The path {config} doesn't exist and we can not find "

View File

@ -369,18 +369,19 @@ def cast2lowercase(input: Union[list, tuple, str]) -> Any:
return outputs return outputs
def recursively_find(root: str, base_name: str) -> list: def recursively_find(root: str, base_name: str, followlinks=False) -> list:
"""Recursive list a directory, return all files with a given base_name. """Recursive list a directory, return all files with a given base_name.
Args: Args:
root (str): The root directory to list. root (str): The root directory to list.
base_name (str): The base_name. base_name (str): The base_name.
followlinks (bool): Follow symbolic links. Defaults to False.
Return: Return:
Files with given base_name. Files with given base_name.
""" """
files = [] files = []
for _root, _, _files in os.walk(root): for _root, _, _files in os.walk(root, followlinks=followlinks):
if base_name in _files: if base_name in _files:
files.append(osp.join(_root, base_name)) files.append(osp.join(_root, base_name))