fast-reid/fastreid/config/config.py

320 lines
11 KiB
Python

# encoding: utf-8
"""
@author: l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import functools
import inspect
import logging
import os
from typing import Any
import yaml
from yacs.config import CfgNode as _CfgNode
from ..utils.file_io import PathManager
# Reserved top-level key in a YAML config file. Its value names the base config
# file whose contents the current file inherits and overrides (see
# CfgNode.load_yaml_with_base). It is rejected in merge_from_other_cfg and
# merge_from_list, i.e. it may only appear in files.
BASE_KEY = "_BASE_"
class CfgNode(_CfgNode):
    """
    Our own extended version of :class:`yacs.config.CfgNode`.

    It contains the following extra features:

    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
       which allows the new CfgNode to inherit all the attributes from the
       base configuration file.
    2. Keys that start with "COMPUTED_" are treated as insertion-only
       "computed" attributes. They can be inserted regardless of whether
       the CfgNode is frozen or not.
    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
       expressions in config. See examples in
       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    """

    @staticmethod
    def load_yaml_with_base(filename: str, allow_unsafe: bool = False):
        """
        Just like `yaml.load(open(filename))`, but inherit attributes from its
        `_BASE_`.

        Args:
            filename (str): the file name of the current config. Will be used to
                find the base config file.
            allow_unsafe (bool): whether to allow loading the config file with
                `yaml.unsafe_load`.

        Returns:
            (dict): the loaded yaml. An empty file yields an empty dict.

        Raises:
            yaml.constructor.ConstructorError: if the file needs unsafe tags
                and ``allow_unsafe`` is False.
            AssertionError: if a dict in the current file would override a
                non-dict value of the same key in the base file.
        """
        needs_unsafe_reload = False
        with PathManager.open(filename, "r") as f:
            try:
                cfg = yaml.safe_load(f)
            except yaml.constructor.ConstructorError:
                if not allow_unsafe:
                    raise
                needs_unsafe_reload = True

        if needs_unsafe_reload:
            logger = logging.getLogger(__name__)
            logger.warning(
                "Loading config {} with yaml.unsafe_load. Your machine may "
                "be at risk if the file contains malicious content.".format(
                    filename
                )
            )
            # SECURITY: yaml.unsafe_load can execute arbitrary code. It is only
            # reached when the caller explicitly opted in via allow_unsafe.
            # Reopen through PathManager (not the builtin open) so the retry
            # resolves the path the same way as the first attempt.
            with PathManager.open(filename, "r") as f:
                cfg = yaml.unsafe_load(f)

        # An empty YAML document loads as None; treat it as an empty config
        # instead of failing on the membership test below.
        if cfg is None:
            cfg = {}

        def merge_a_into_b(a, b):
            # merge dict a into dict b. values in a will overwrite b.
            for k, v in a.items():
                if isinstance(v, dict) and k in b:
                    assert isinstance(
                        b[k], dict
                    ), "Cannot inherit key '{}' from base!".format(k)
                    merge_a_into_b(v, b[k])
                else:
                    b[k] = v

        if BASE_KEY in cfg:
            base_cfg_file = cfg[BASE_KEY]
            if base_cfg_file.startswith("~"):
                base_cfg_file = os.path.expanduser(base_cfg_file)
            if not base_cfg_file.startswith(("/", "https://", "http://")):
                # the path to base cfg is relative to the config file itself.
                base_cfg_file = os.path.join(
                    os.path.dirname(filename), base_cfg_file
                )
            # Recurse so that a base file may itself declare a _BASE_.
            base_cfg = CfgNode.load_yaml_with_base(
                base_cfg_file, allow_unsafe=allow_unsafe
            )
            del cfg[BASE_KEY]

            # Values of the current file take precedence over the base.
            merge_a_into_b(cfg, base_cfg)
            return base_cfg
        return cfg

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False):
        """
        Merge configs from a given yaml file.

        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        loaded_cfg = CfgNode.load_yaml_with_base(
            cfg_filename, allow_unsafe=allow_unsafe
        )
        # Wrap the plain dict in the same node type as self before merging.
        loaded_cfg = type(self)(loaded_cfg)
        self.merge_from_other_cfg(loaded_cfg)

    # Forward the following calls to base, but with a check on the BASE_KEY.
    def merge_from_other_cfg(self, cfg_other):
        """
        Args:
            cfg_other (CfgNode): configs to merge from.

        Raises:
            AssertionError: if ``cfg_other`` contains the reserved "_BASE_" key.
        """
        assert (
            BASE_KEY not in cfg_other
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list: list):
        """
        Args:
            cfg_list (list): list of configs to merge from, laid out as
                alternating keys and values.

        Raises:
            AssertionError: if any key is the reserved "_BASE_" key.
        """
        # Keys sit at the even positions of the flat [key, value, ...] list.
        keys = set(cfg_list[0::2])
        assert (
            BASE_KEY not in keys
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_list(cfg_list)

    def __setattr__(self, name: str, val: Any):
        """
        Set an attribute, giving "COMPUTED_" names insertion-only semantics.

        A "COMPUTED_" attribute is stored directly as a dict item. Setting it
        again to an equal value is a no-op; setting it to a different value
        raises ``KeyError``. All other names go through the parent class.
        """
        if name.startswith("COMPUTED_"):
            if name in self:
                old_val = self[name]
                if old_val == val:
                    return
                raise KeyError(
                    "Computed attributed '{}' already exists "
                    "with a different value! old={}, new={}.".format(
                        name, old_val, val
                    )
                )
            self[name] = val
        else:
            super().__setattr__(name, val)
# Module-level config object, created empty. set_global_cfg() clears and
# refills this same object in place rather than rebinding the name.
global_cfg = CfgNode()
def get_cfg() -> CfgNode:
    """
    Return a clone of the default config.

    Returns:
        CfgNode: the result of calling ``clone()`` on ``_C`` from the sibling
        ``defaults`` module.
    """
    # The defaults module is imported at call time, not at module import time.
    from .defaults import _C as default_cfg

    cloned_cfg = default_cfg.clone()
    return cloned_cfg
def set_global_cfg(cfg: CfgNode) -> None:
    """
    Let the global config point to the given cfg.

    Assume that the given "cfg" has the key "KEY", after calling
    `set_global_cfg(cfg)`, the key can be accessed by:
    ::
        from fastreid.config.config import global_cfg
        print(global_cfg.KEY)

    By using a hacky global config, you can access these configs anywhere,
    without having to pass the config object or the values deep into the code.
    This is a hacky feature introduced for quick prototyping / research exploration.

    Args:
        cfg (CfgNode): the config whose entries replace the current contents
            of the module-level ``global_cfg``.
    """
    # Mutate the existing object in place instead of rebinding the name, so
    # that references to `global_cfg` obtained earlier observe the new content.
    # No `global` declaration is needed because the name is never reassigned.
    global_cfg.clear()
    global_cfg.update(cfg)
def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):   # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)       # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)       # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.

    Returns:
        callable: the wrapped ``__init__`` (usage 1), a decorator to apply to
        the target function (usage 2), or ``configurable`` itself when called
        with no arguments.

    Raises:
        AssertionError: on incorrect decorator usage, or when a decorated
            callable from a ``fastreid.`` module lacks "experimental" in its
            docstring.
    """

    def check_docstring(func):
        # Only callables defined inside the fastreid package are required to
        # carry the "experimental" marker in their docstring.
        if func.__module__.startswith("fastreid."):
            assert (
                func.__doc__ is not None and "experimental" in func.__doc__.lower()
            ), f"configurable {func} should be marked experimental"

    if init_func is None:
        # Usage 2, or a bare `@configurable()` call.
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            check_docstring(orig_func)

            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                return orig_func(*args, **kwargs)

            return wrapped

        return wrapper

    # Usage 1: decorating a class's __init__ directly.
    assert (
        inspect.isfunction(init_func)
        and from_config is None
        and init_func.__name__ == "__init__"
    ), "Incorrect use of @configurable. Check API documentation for examples."
    check_docstring(init_func)

    @functools.wraps(init_func)
    def wrapped(self, *args, **kwargs):
        # Look up from_config on the runtime type so subclasses may override it.
        try:
            from_config_func = type(self).from_config
        except AttributeError as e:
            raise AttributeError(
                "Class with @configurable must have a 'from_config' classmethod."
            ) from e
        # Accessing a classmethod through the class yields a bound method.
        if not inspect.ismethod(from_config_func):
            raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

        if _called_with_cfg(*args, **kwargs):
            explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
            init_func(self, **explicit_args)
        else:
            init_func(self, *args, **kwargs)

    return wrapped
def _get_args_from_config(from_config_func, *args, **kwargs):
"""
Use `from_config` to obtain explicit arguments.
Returns:
dict: arguments to be used for cls.__init__
"""
signature = inspect.signature(from_config_func)
if list(signature.parameters.keys())[0] != "cfg":
if inspect.isfunction(from_config_func):
name = from_config_func.__name__
else:
name = f"{from_config_func.__self__}.from_config"
raise TypeError(f"{name} must take 'cfg' as the first argument!")
support_var_arg = any(
param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
for param in signature.parameters.values()
)
if support_var_arg: # forward all arguments to from_config, if from_config accepts them
ret = from_config_func(*args, **kwargs)
else:
# forward supported arguments to from_config
supported_arg_names = set(signature.parameters.keys())
extra_kwargs = {}
for name in list(kwargs.keys()):
if name not in supported_arg_names:
extra_kwargs[name] = kwargs.pop(name)
ret = from_config_func(*args, **kwargs)
# forward the other arguments to __init__
ret.update(extra_kwargs)
return ret
def _called_with_cfg(*args, **kwargs):
    """
    Decide whether a call should be routed through ``from_config``.

    Returns:
        bool: True when the first positional argument, or the ``cfg`` keyword
        argument, is a ``_CfgNode`` instance; False otherwise.
    """
    leading_arg = args[0] if args else None
    if isinstance(leading_arg, _CfgNode):
        return True
    # `from_config`'s first argument is forced to be "cfg", so checking the
    # first positional and the "cfg" keyword covers all cases.
    return isinstance(kwargs.get("cfg"), _CfgNode)