Add global_pool (--gp) arg changes to allow passing 'fast' easily for train/validate to avoid channels_last issue with AdaptiveAvgPool

This commit is contained in:
Ross Wightman 2020-09-02 16:13:47 -07:00
parent 9c297ec67d
commit 751b0bba98
3 changed files with 9 additions and 9 deletions

View File

@ -39,11 +39,6 @@ def create_model(
kwargs.pop('bn_momentum', None)
kwargs.pop('bn_eps', None)
# Parameters that aren't supported by all models should default to None in command line args,
# remove them if they are present and not set so that non-supporting models don't break.
if kwargs.get('drop_block_rate', None) is None:
kwargs.pop('drop_block_rate', None)
# handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
@ -51,8 +46,10 @@ def create_model(
" Setting drop_path to %f." % drop_connect_rate)
kwargs['drop_path_rate'] = drop_connect_rate
if kwargs.get('drop_path_rate', None) is None:
kwargs.pop('drop_path_rate', None)
# Parameters that aren't supported by all models or are intended to only override model defaults if set
# should default to None in command line args/cfg. Remove them if they are present and not set so that
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
if is_model(model_name):

View File

@ -74,8 +74,8 @@ parser.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
help='number of label classes (default: 1000)')
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--crop-pct', default=None, type=float,

View File

@ -64,6 +64,8 @@ parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -127,6 +129,7 @@ def validate(args):
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
if args.checkpoint: