diff --git a/timm/models/factory.py b/timm/models/factory.py index 03d8cc1f..70209c96 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -39,11 +39,6 @@ def create_model( kwargs.pop('bn_momentum', None) kwargs.pop('bn_eps', None) - # Parameters that aren't supported by all models should default to None in command line args, - # remove them if they are present and not set so that non-supporting models don't break. - if kwargs.get('drop_block_rate', None) is None: - kwargs.pop('drop_block_rate', None) - # handle backwards compat with drop_connect -> drop_path change drop_connect_rate = kwargs.pop('drop_connect_rate', None) if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: @@ -51,8 +46,10 @@ def create_model( " Setting drop_path to %f." % drop_connect_rate) kwargs['drop_path_rate'] = drop_connect_rate - if kwargs.get('drop_path_rate', None) is None: - kwargs.pop('drop_path_rate', None) + # Parameters that aren't supported by all models or are intended to only override model defaults if set + # should default to None in command line args/cfg. Remove them if they are present and not set so that + # non-supporting models don't break and default args remain in effect. + kwargs = {k: v for k, v in kwargs.items() if v is not None} with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): if is_model(model_name): diff --git a/train.py b/train.py index ff421d53..260de18b 100755 --- a/train.py +++ b/train.py @@ -74,8 +74,8 @@ parser.add_argument('--no-resume-opt', action='store_true', default=False, help='prevent resume of optimizer state when resuming model') parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes (default: 1000)') -parser.add_argument('--gp', default='avg', type=str, metavar='POOL', - help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') parser.add_argument('--crop-pct', default=None, type=float, diff --git a/validate.py b/validate.py index 587ac6a6..ceca7014 100755 --- a/validate.py +++ b/validate.py @@ -64,6 +64,8 @@ parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset') parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') parser.add_argument('--log-freq', default=10, type=int, metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -127,6 +129,7 @@ def validate(args): pretrained=args.pretrained, num_classes=args.num_classes, in_chans=3, + global_pool=args.gp, scriptable=args.torchscript) if args.checkpoint: