mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Pass drop connect arg through to EfficientNet models
This commit is contained in:
parent
31453b039e
commit
7b83e67f77
@ -25,12 +25,13 @@ def create_model(
|
||||
"""
|
||||
margs = dict(pretrained=pretrained, num_classes=num_classes, in_chans=in_chans)
|
||||
|
||||
# Not all models have support for batchnorm params passed as args, only gen_efficientnet variants
|
||||
supports_bn_params = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||
if not supports_bn_params and any([x in kwargs for x in ['bn_tf', 'bn_momentum', 'bn_eps']]):
|
||||
# Only gen_efficientnet models have support for batchnorm params or drop_connect_rate passed as args
|
||||
is_efficientnet = is_model_in_modules(model_name, ['gen_efficientnet'])
|
||||
if not is_efficientnet:
|
||||
kwargs.pop('bn_tf', None)
|
||||
kwargs.pop('bn_momentum', None)
|
||||
kwargs.pop('bn_eps', None)
|
||||
kwargs.pop('drop_connect_rate', None)
|
||||
|
||||
if is_model(model_name):
|
||||
create_fn = model_entrypoint(model_name)
|
||||
|
5
train.py
5
train.py
@ -65,6 +65,8 @@ parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
|
||||
help='input batch size for training (default: 32)')
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||
help='Dropout rate (default: 0.)')
|
||||
parser.add_argument('--drop-connect', type=float, default=0.0, metavar='DROP',
|
||||
help='Drop connect rate (default: 0.)')
|
||||
# Optimizer parameters
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
help='Optimizer (default: "sgd"')
|
||||
@ -208,6 +210,7 @@ def main():
|
||||
pretrained=args.pretrained,
|
||||
num_classes=args.num_classes,
|
||||
drop_rate=args.drop,
|
||||
drop_connect_rate=args.drop_connect,
|
||||
global_pool=args.gp,
|
||||
bn_tf=args.bn_tf,
|
||||
bn_momentum=args.bn_momentum,
|
||||
@ -253,7 +256,7 @@ def main():
|
||||
if args.local_rank == 0:
|
||||
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
||||
amp.load_state_dict(resume_state['amp'])
|
||||
resume_state = None
|
||||
resume_state = None # clear it
|
||||
|
||||
model_ema = None
|
||||
if args.model_ema:
|
||||
|
Loading…
x
Reference in New Issue
Block a user