From f2fdd97e9f859285363c05988820c9350b737e59 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 22 Dec 2023 11:18:25 -0800
Subject: [PATCH] Add parsable json results output for train.py, tweak
 --pretrained-path to force head adaptation

---
 timm/models/_builder.py |  9 +++++++--
 train.py                | 16 +++++++++++++++-
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/timm/models/_builder.py b/timm/models/_builder.py
index c11a84a3..e6150b9a 100644
--- a/timm/models/_builder.py
+++ b/timm/models/_builder.py
@@ -261,7 +261,7 @@ def _filter_kwargs(kwargs, names):
         kwargs.pop(n, None)
 
 
-def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
+def _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     """ Update the default_cfg and kwargs before passing to model
 
     Args:
@@ -288,6 +288,11 @@ def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
             if input_size is not None:
                 assert len(input_size) == 3
                 kwargs.setdefault(n, input_size[0])
+        elif n == 'num_classes':
+            default_val = pretrained_cfg.get(n, None)
+            # if default is < 0, don't pass through to model
+            if default_val is not None and default_val >= 0:
+                kwargs.setdefault(n, pretrained_cfg[n])
         else:
             default_val = pretrained_cfg.get(n, None)
             if default_val is not None:
@@ -379,7 +384,7 @@ def build_model_with_cfg(
     # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
     pretrained_cfg = pretrained_cfg.to_dict()
 
-    _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
+    _update_default_model_kwargs(pretrained_cfg, kwargs, kwargs_filter)
 
     # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
diff --git a/train.py b/train.py
index 05fff7fa..9f4ff4ec 100755
--- a/train.py
+++ b/train.py
@@ -15,6 +15,7 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
+import json
 import logging
 import os
 import time
@@ -425,7 +426,10 @@ def main():
     factory_kwargs = {}
     if args.pretrained_path:
         # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
-        factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
+        factory_kwargs['pretrained_cfg_overlay'] = dict(
+            file=args.pretrained_path,
+            num_classes=-1,  # force head adaptation
+        )
 
     model = create_model(
         args.model,
@@ -770,6 +774,7 @@ def main():
         _logger.info(
             f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
 
+    results = []
     try:
         for epoch in range(start_epoch, num_epochs):
             if hasattr(dataset_train, 'set_epoch'):
@@ -841,11 +846,20 @@ def main():
                 # step LR for next epoch
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
+            results.append({
+                'epoch': epoch,
+                'train': train_metrics,
+                'validation': eval_metrics,
+            })
+
     except KeyboardInterrupt:
         pass
 
+    results = {'all': results}
     if best_metric is not None:
+        results['best'] = results['all'][best_epoch]
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
+    print(f'--result\n{json.dumps(results, indent=4)}')
 
 
 def train_one_epoch(
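
The JSON block that train.py now prints after the '--result' marker can be pulled back out of a captured run roughly as follows. This is a minimal consumer sketch, not part of the patch: the 'train.log' path and the load_train_results helper are hypothetical, and it assumes stdout was redirected to that file so the JSON dump emitted at the end of main() is the last thing in it.

import json


def load_train_results(log_path='train.log'):
    # hypothetical helper: read captured stdout and keep everything after the
    # '--result' marker, which is the json.dumps(...) output from train.py
    with open(log_path) as f:
        log_text = f.read()
    _, _, json_text = log_text.partition('--result\n')
    return json.loads(json_text)


results = load_train_results()
for entry in results['all']:
    # per-epoch train/validation metric dicts as recorded by train.py
    print(entry['epoch'], entry['train'], entry['validation'])
if 'best' in results:
    print('best epoch:', results['best']['epoch'], results['best']['validation'])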