mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add parsable json results output for train.py, tweak --pretrained-path to force head adaptation
This commit is contained in:
parent
e0079c92da
commit
f2fdd97e9f
@ -261,7 +261,7 @@ def _filter_kwargs(kwargs, names):
|
||||
kwargs.pop(n, None)
|
||||
|
||||
|
||||
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
||||
def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
||||
""" Update the default_cfg and kwargs before passing to model
|
||||
|
||||
Args:
|
||||
@ -288,6 +288,11 @@ def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
||||
if input_size is not None:
|
||||
assert len(input_size) == 3
|
||||
kwargs.setdefault(n, input_size[0])
|
||||
elif n == 'num_classes':
|
||||
default_val = pretrained_cfg.get(n, None)
|
||||
# if default is < 0, don't pass through to model
|
||||
if default_val is not None and default_val >= 0:
|
||||
kwargs.setdefault(n, pretrained_cfg[n])
|
||||
else:
|
||||
default_val = pretrained_cfg.get(n, None)
|
||||
if default_val is not None:
|
||||
@ -379,7 +384,7 @@ def build_model_with_cfg(
|
||||
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
|
||||
pretrained_cfg = pretrained_cfg.to_dict()
|
||||
|
||||
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
||||
_update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
||||
|
||||
# Setup for feature extraction wrapper done at end of this fn
|
||||
if kwargs.pop('features_only', False):
|
||||
|
16
train.py
16
train.py
@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
|
||||
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
@ -425,7 +426,10 @@ def main():
|
||||
factory_kwargs = {}
|
||||
if args.pretrained_path:
|
||||
# merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
|
||||
factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
|
||||
factory_kwargs['pretrained_cfg_overlay'] = dict(
|
||||
file=args.pretrained_path,
|
||||
num_classes=-1, # force head adaptation
|
||||
)
|
||||
|
||||
model = create_model(
|
||||
args.model,
|
||||
@ -770,6 +774,7 @@ def main():
|
||||
_logger.info(
|
||||
f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
|
||||
|
||||
results = []
|
||||
try:
|
||||
for epoch in range(start_epoch, num_epochs):
|
||||
if hasattr(dataset_train, 'set_epoch'):
|
||||
@ -841,11 +846,20 @@ def main():
|
||||
# step LR for next epoch
|
||||
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
|
||||
|
||||
results.append({
|
||||
'epoch': epoch,
|
||||
'train': train_metrics,
|
||||
'validation': eval_metrics,
|
||||
})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
results = {'all': results}
|
||||
if best_metric is not None:
|
||||
results['best'] = results['all'][best_epoch]
|
||||
_logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
||||
print(f'--result\n{json.dumps(results, indent=4)}')
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
|
Loading…
x
Reference in New Issue
Block a user