mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add global_pool (--gp) arg changes to allow passing 'fast' easily for train/validate to avoid channels_last issue with AdaptiveAvgPool
This commit is contained in:
parent
9c297ec67d
commit
751b0bba98
@ -39,11 +39,6 @@ def create_model(
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
|
||||
# Parameters that aren't supported by all models should default to None in command line args,
|
||||
# remove them if they are present and not set so that non-supporting models don't break.
|
||||
if kwargs.get('drop_block_rate', None) is None:
|
||||
kwargs.pop('drop_block_rate', None)
|
||||
|
||||
# handle backwards compat with drop_connect -> drop_path change
|
||||
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
|
||||
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:
|
||||
@ -51,8 +46,10 @@ def create_model(
|
||||
" Setting drop_path to %f." % drop_connect_rate)
|
||||
kwargs['drop_path_rate'] = drop_connect_rate
|
||||
|
||||
if kwargs.get('drop_path_rate', None) is None:
|
||||
kwargs.pop('drop_path_rate', None)
|
||||
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
||||
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
||||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||
if is_model(model_name):
|
||||
|
4
train.py
4
train.py
@ -74,8 +74,8 @@ parser.add_argument('--no-resume-opt', action='store_true', default=False,
|
||||
help='prevent resume of optimizer state when resuming model')
|
||||
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
|
||||
help='number of label classes (default: 1000)')
|
||||
parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
|
||||
help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
|
||||
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
|
||||
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
|
||||
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
||||
help='Image patch size (default: None => model default)')
|
||||
parser.add_argument('--crop-pct', default=None, type=float,
|
||||
|
@ -64,6 +64,8 @@ parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
|
||||
help='path to class to idx mapping file (default: "")')
|
||||
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
|
||||
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
|
||||
parser.add_argument('--log-freq', default=10, type=int,
|
||||
metavar='N', help='batch logging frequency (default: 10)')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
@ -127,6 +129,7 @@ def validate(args):
|
||||
pretrained=args.pretrained,
|
||||
num_classes=args.num_classes,
|
||||
in_chans=3,
|
||||
global_pool=args.gp,
|
||||
scriptable=args.torchscript)
|
||||
|
||||
if args.checkpoint:
|
||||
|
Loading…
x
Reference in New Issue
Block a user