diff --git a/timm/utils.py b/timm/utils.py index 4739064f..65255b53 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -326,3 +326,11 @@ def setup_default_logging(default_level=logging.INFO, log_path=''): file_formatter = logging.Formatter("%(asctime)s - %(name)20s: [%(levelname)8s] - %(message)s") file_handler.setFormatter(file_formatter) logging.root.addHandler(file_handler) + + +def add_bool_arg(parser, name, default=False, help=''): + dest_name = name.replace('-', '_') + group = parser.add_mutually_exclusive_group(required=False) + group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) + group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) + parser.set_defaults(**{dest_name: default})