Add parsable json results output for train.py, tweak --pretrained-path to force head adaptation

Ross Wightman 2023-12-22 11:18:25 -08:00
parent e0079c92da
commit f2fdd97e9f
2 changed files with 22 additions and 3 deletions

timm/models/_builder.py

@@ -261,7 +261,7 @@ def _filter_kwargs(kwargs, names):
         kwargs.pop(n, None)
 
 
-def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     """ Update the default_cfg and kwargs before passing to model
 
     Args:
@@ -288,6 +288,11 @@ def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
             if input_size is not None:
                 assert len(input_size) == 3
                 kwargs.setdefault(n, input_size[0])
+        elif n == 'num_classes':
+            default_val = pretrained_cfg.get(n, None)
+            # if default is < 0, don't pass through to model
+            if default_val is not None and default_val >= 0:
+                kwargs.setdefault(n, pretrained_cfg[n])
         else:
             default_val = pretrained_cfg.get(n, None)
             if default_val is not None:
@@ -379,7 +384,7 @@ def build_model_with_cfg(
         # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
         pretrained_cfg = pretrained_cfg.to_dict()
 
-    _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
+    _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
 
     # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
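
For illustration, a minimal sketch of what this enables through timm's public create_model API (the model name and checkpoint path below are placeholders, not part of the commit): a pretrained_cfg num_classes below zero is no longer forwarded as a model kwarg, so the model is built with the caller's num_classes and the checkpoint's classifier head is adapted/re-initialized on load.

import timm

# Hypothetical fine-tuning setup: build the head for 10 classes while loading
# backbone weights from a local file; num_classes=-1 in the overlay acts as a
# sentinel so the checkpoint's head weights are never loaded as-is.
model = timm.create_model(
    'resnet50',
    pretrained=True,
    num_classes=10,
    pretrained_cfg_overlay=dict(
        file='output/resnet50_in1k.pth',  # 'file' has priority over 'url' and 'hf_hub'
        num_classes=-1,                   # force head adaptation
    ),
)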

train.py

@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import json
 import logging
 import os
 import time
@@ -425,7 +426,10 @@ def main():
     factory_kwargs = {}
     if args.pretrained_path:
         # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
-        factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
+        factory_kwargs['pretrained_cfg_overlay'] = dict(
+            file=args.pretrained_path,
+            num_classes=-1,  # force head adaptation
+        )
 
     model = create_model(
         args.model,
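
A hypothetical invocation of the updated flag (dataset and checkpoint paths are placeholders): because the overlay now pins num_classes=-1, the classifier head is always re-initialized to match --num-classes instead of silently inheriting the checkpoint's head.

python train.py \
    --data-dir /data/my-dataset \
    --model resnet50 \
    --pretrained \
    --pretrained-path output/resnet50_in1k.pth \
    --num-classes 10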
@@ -770,6 +774,7 @@ def main():
         _logger.info(
             f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
 
+    results = []
     try:
         for epoch in range(start_epoch, num_epochs):
             if hasattr(dataset_train, 'set_epoch'):
@@ -841,11 +846,20 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
+            results.append({
+                'epoch': epoch,
+                'train': train_metrics,
+                'validation': eval_metrics,
+            })
+
     except KeyboardInterrupt:
         pass
 
+    results = {'all': results}
     if best_metric is not None:
+        results['best'] = results['all'][best_epoch]
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
 
+    print(f'--result\n{json.dumps(results, indent=4)}')
 
 
 def train_one_epoch(
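
The '--result' marker makes the JSON block easy to recover from a captured log. A minimal sketch of a consumer (the log filename is hypothetical): results['all'] holds one {'epoch', 'train', 'validation'} entry per completed epoch, and results['best'] is present when a best metric was tracked.

import json

with open('train.log') as f:  # captured stdout of a train.py run
    log = f.read()

# everything after the last '--result' marker line is the JSON document
_, _, tail = log.rpartition('--result\n')
results = json.loads(tail)

for entry in results['all']:
    print(entry['epoch'], entry['train']['loss'], entry['validation']['top1'])
if 'best' in results:
    print('best epoch:', results['best']['epoch'])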