# encoding: utf-8
"""
@author:  l1aoxingyu
@contact: sherlockliao01@gmail.com
"""

import functools
import inspect
import logging
import os
from typing import Any

import yaml
from yacs.config import CfgNode as _CfgNode

from ..utils.file_io import PathManager

BASE_KEY = "_BASE_"


class CfgNode(_CfgNode):
    """
    Our own extended version of :class:`yacs.config.CfgNode`.
    It contains the following extra features:

    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
       which allows the new CfgNode to inherit all the attributes from the
       base configuration file.
    2. Keys that start with "COMPUTED_" are treated as insertion-only
       "computed" attributes. They can be inserted regardless of whether
       the CfgNode is frozen or not.
    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
       expressions in config. See examples in
       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    """

    @staticmethod
    def load_yaml_with_base(filename: str, allow_unsafe: bool = False):
        """
        Just like `yaml.load(open(filename))`, but inherit attributes from its
        `_BASE_`.

        Args:
            filename (str): the file name of the current config. Will be used to
                find the base config file.
            allow_unsafe (bool): whether to fall back to `yaml.unsafe_load` when
                `yaml.safe_load` rejects a tag in the file.

        Returns:
            (dict): the loaded yaml. If it declares `_BASE_`, the base config is
                loaded recursively and the current file's values are merged on
                top of it.

        Raises:
            yaml.constructor.ConstructorError: if `yaml.safe_load` rejects the
                file and ``allow_unsafe`` is False.
        """
        try:
            with PathManager.open(filename, "r") as f:
                cfg = yaml.safe_load(f)
        except yaml.constructor.ConstructorError:
            if not allow_unsafe:
                raise
            logger = logging.getLogger(__name__)
            logger.warning(
                "Loading config {} with yaml.unsafe_load. Your machine may "
                "be at risk if the file contains malicious content.".format(
                    filename
                )
            )
            # Re-open through PathManager (not builtin `open`) so the unsafe
            # fallback reads from the same I/O backend as the safe path.
            with PathManager.open(filename, "r") as f:
                cfg = yaml.unsafe_load(f)

        def merge_a_into_b(a, b):
            # merge dict a into dict b. values in a will overwrite b.
            for k, v in a.items():
                if isinstance(v, dict) and k in b:
                    assert isinstance(
                        b[k], dict
                    ), "Cannot inherit key '{}' from base!".format(k)
                    merge_a_into_b(v, b[k])
                else:
                    b[k] = v

        if BASE_KEY in cfg:
            base_cfg_file = cfg[BASE_KEY]
            if base_cfg_file.startswith("~"):
                base_cfg_file = os.path.expanduser(base_cfg_file)
            if not base_cfg_file.startswith(("/", "https://", "http://")):
                # the path to base cfg is relative to the config file itself.
                base_cfg_file = os.path.join(
                    os.path.dirname(filename), base_cfg_file
                )
            base_cfg = CfgNode.load_yaml_with_base(
                base_cfg_file, allow_unsafe=allow_unsafe
            )
            # The marker key must not leak into the merged result.
            del cfg[BASE_KEY]

            merge_a_into_b(cfg, base_cfg)
            return base_cfg
        return cfg

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False):
        """
        Merge configs from a given yaml file.

        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        loaded_cfg = CfgNode.load_yaml_with_base(
            cfg_filename, allow_unsafe=allow_unsafe
        )
        loaded_cfg = type(self)(loaded_cfg)
        self.merge_from_other_cfg(loaded_cfg)

    # Forward the following calls to base, but with a check on the BASE_KEY.
    def merge_from_other_cfg(self, cfg_other):
        """
        Args:
            cfg_other (CfgNode): configs to merge from.

        Raises:
            AssertionError: if ``cfg_other`` contains the reserved `_BASE_` key,
                which is only meaningful inside yaml files.
        """
        assert (
            BASE_KEY not in cfg_other
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list: list):
        """
        Args:
            cfg_list (list): list of configs to merge from, as alternating
                keys and values (keys at even indices).

        Raises:
            AssertionError: if any key is the reserved `_BASE_` key.
        """
        keys = set(cfg_list[0::2])
        assert (
            BASE_KEY not in keys
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_list(cfg_list)

    def __setattr__(self, name: str, val: Any):
        # "COMPUTED_" keys are insertion-only: they are written straight into
        # the mapping (bypassing the frozen check of the base class), and may
        # only be re-set to an equal value.
        if name.startswith("COMPUTED_"):
            if name in self:
                old_val = self[name]
                if old_val == val:
                    return
                raise KeyError(
                    "Computed attributed '{}' already exists "
                    "with a different value! old={}, new={}.".format(
                        name, old_val, val
                    )
                )
            self[name] = val
        else:
            super().__setattr__(name, val)


global_cfg = CfgNode()


def get_cfg() -> CfgNode:
    """
    Get a copy of the default config.

    Returns:
        a fastreid CfgNode instance.
    """
    from .defaults import _C

    return _C.clone()


def set_global_cfg(cfg: CfgNode) -> None:
    """
    Let the global config point to the given cfg.

    Assume that the given "cfg" has the key "KEY". After calling
    `set_global_cfg(cfg)`, the key can be read from this module's
    ``global_cfg`` object wherever it has been imported, e.g.
    ``print(global_cfg.KEY)``.

    By using a hacky global config, you can access these configs anywhere,
    without having to pass the config object or the values deep into the code.
    This is a hacky feature introduced for quick prototyping / research exploration.
    """
    # Mutate the existing object in place (no rebinding, so no `global`
    # statement is needed): modules that already imported `global_cfg`
    # keep a reference to this same object and see the new contents.
    global_cfg.clear()
    global_cfg.update(cfg)


def configurable(init_func=None, *, from_config=None):
    """
    Decorate a function or a class's __init__ method so that it can be called
    with a :class:`CfgNode` object using a :func:`from_config` function that translates
    :class:`CfgNode` to arguments.

    Examples:
    ::
        # Usage 1: Decorator on __init__:
        class A:
            @configurable
            def __init__(self, a, b=2, c=3):
                pass

            @classmethod
            def from_config(cls, cfg):   # 'cfg' must be the first argument
                # Returns kwargs to be passed to __init__
                return {"a": cfg.A, "b": cfg.B}

        a1 = A(a=1, b=2)  # regular construction
        a2 = A(cfg)       # construct with a cfg
        a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

        # Usage 2: Decorator on any function. Needs an extra from_config argument:
        @configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
        def a_func(a, b=2, c=3):
            pass

        a1 = a_func(a=1, b=2)  # regular call
        a2 = a_func(cfg)       # call with a cfg
        a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite

    Args:
        init_func (callable): a class's ``__init__`` method in usage 1. The
            class must have a ``from_config`` classmethod which takes `cfg` as
            the first argument.
        from_config (callable): the from_config function in usage 2. It must take `cfg`
            as its first argument.
    """

    def check_docstring(func):
        # Decorated callables defined inside the fastreid package must carry
        # the word "experimental" in their docstring.
        if func.__module__.startswith("fastreid."):
            assert (
                func.__doc__ is not None and "experimental" in func.__doc__.lower()
            ), f"configurable {func} should be marked experimental"

    if init_func is not None:
        assert (
            inspect.isfunction(init_func)
            and from_config is None
            and init_func.__name__ == "__init__"
        ), "Incorrect use of @configurable. Check API documentation for examples."
        check_docstring(init_func)

        @functools.wraps(init_func)
        def wrapped(self, *args, **kwargs):
            try:
                from_config_func = type(self).from_config
            except AttributeError as e:
                raise AttributeError(
                    "Class with @configurable must have a 'from_config' classmethod."
                ) from e
            # A classmethod looked up on the class is a bound method; a plain
            # function or staticmethod is not, so this rejects those.
            if not inspect.ismethod(from_config_func):
                raise TypeError("Class with @configurable must have a 'from_config' classmethod.")

            if _called_with_cfg(*args, **kwargs):
                explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
                init_func(self, **explicit_args)
            else:
                init_func(self, *args, **kwargs)

        return wrapped

    else:
        if from_config is None:
            return configurable  # @configurable() is made equivalent to @configurable
        assert inspect.isfunction(
            from_config
        ), "from_config argument of configurable must be a function!"

        def wrapper(orig_func):
            check_docstring(orig_func)

            @functools.wraps(orig_func)
            def wrapped(*args, **kwargs):
                if _called_with_cfg(*args, **kwargs):
                    explicit_args = _get_args_from_config(from_config, *args, **kwargs)
                    return orig_func(**explicit_args)
                else:
                    return orig_func(*args, **kwargs)

            return wrapped

        return wrapper


def _get_args_from_config(from_config_func, *args, **kwargs):
    """
    Use `from_config` to obtain explicit arguments.

    Keyword arguments that ``from_config_func`` does not declare (and cannot
    absorb via ``*args``/``**kwargs``) are not passed to it; they are added
    to the returned dict instead so they reach the decorated callable.

    Returns:
        dict: arguments to be used for cls.__init__

    Raises:
        TypeError: if ``from_config_func`` does not take ``cfg`` as its first
            parameter (including when it takes no parameters at all).
    """
    signature = inspect.signature(from_config_func)
    param_names = list(signature.parameters.keys())
    # Check emptiness first so a parameterless from_config gets this
    # explanatory TypeError rather than an IndexError.
    if not param_names or param_names[0] != "cfg":
        if inspect.isfunction(from_config_func):
            name = from_config_func.__name__
        else:
            name = f"{from_config_func.__self__}.from_config"
        raise TypeError(f"{name} must take 'cfg' as the first argument!")
    support_var_arg = any(
        param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
        for param in signature.parameters.values()
    )
    if support_var_arg:  # forward all arguments to from_config, if from_config accepts them
        ret = from_config_func(*args, **kwargs)
    else:
        # forward supported arguments to from_config
        supported_arg_names = set(param_names)
        extra_kwargs = {}
        for arg_name in list(kwargs.keys()):
            if arg_name not in supported_arg_names:
                extra_kwargs[arg_name] = kwargs.pop(arg_name)
        ret = from_config_func(*args, **kwargs)
        # forward the other arguments to __init__
        ret.update(extra_kwargs)
    return ret


def _called_with_cfg(*args, **kwargs):
    """
    Returns:
        bool: whether the arguments contain CfgNode and should be considered
            forwarded to from_config.
    """
    if args and isinstance(args[0], _CfgNode):
        return True
    # `kwargs` is this call's own dict; a read-only lookup is sufficient.
    if isinstance(kwargs.get("cfg"), _CfgNode):
        return True
    # `from_config`'s first argument is forced to be "cfg".
    # So the above check covers all cases.
    return False