mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Revert ml-decoder changes to model factory and train script
This commit is contained in:
parent
72b57163d1
commit
d98aa47d12
@ -29,7 +29,6 @@ def create_model(
|
|||||||
scriptable=None,
|
scriptable=None,
|
||||||
exportable=None,
|
exportable=None,
|
||||||
no_jit=None,
|
no_jit=None,
|
||||||
use_ml_decoder_head=False,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Create a model
|
"""Create a model
|
||||||
|
|
||||||
@ -81,10 +80,6 @@ def create_model(
|
|||||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||||
model = create_fn(pretrained=pretrained, **kwargs)
|
model = create_fn(pretrained=pretrained, **kwargs)
|
||||||
|
|
||||||
if use_ml_decoder_head:
|
|
||||||
from timm.models.layers.ml_decoder import add_ml_decoder_head
|
|
||||||
model = add_ml_decoder_head(model)
|
|
||||||
|
|
||||||
if checkpoint_path:
|
if checkpoint_path:
|
||||||
load_checkpoint(model, checkpoint_path)
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
|
||||||
|
4
train.py
4
train.py
@ -115,7 +115,6 @@ parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
|
|||||||
help='input batch size for training (default: 128)')
|
help='input batch size for training (default: 128)')
|
||||||
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
|
parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
|
||||||
help='validation batch size override (default: None)')
|
help='validation batch size override (default: None)')
|
||||||
parser.add_argument('--use-ml-decoder-head', type=int, default=0)
|
|
||||||
|
|
||||||
# Optimizer parameters
|
# Optimizer parameters
|
||||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||||
@ -380,8 +379,7 @@ def main():
|
|||||||
bn_momentum=args.bn_momentum,
|
bn_momentum=args.bn_momentum,
|
||||||
bn_eps=args.bn_eps,
|
bn_eps=args.bn_eps,
|
||||||
scriptable=args.torchscript,
|
scriptable=args.torchscript,
|
||||||
checkpoint_path=args.initial_checkpoint,
|
checkpoint_path=args.initial_checkpoint)
|
||||||
use_ml_decoder_head=args.use_ml_decoder_head)
|
|
||||||
if args.num_classes is None:
|
if args.num_classes is None:
|
||||||
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
||||||
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
|
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
|
||||||
|
Loading…
x
Reference in New Issue
Block a user