# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from collections import OrderedDict

import torch
from mmcv.runner import Hook
from torch.utils.data import DataLoader

if torch.cuda.is_available():
    from easycv.datasets.shared.dali_tfrecord_imagenet import DaliLoaderWrapper


class EvalHook(Hook):
    """Evaluation hook.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        initial (bool): Whether to evaluate once before training starts.
            Default: False.
        interval (int): Evaluation interval (by epochs). Default: 1.
        mode (str): Model forward mode. Default: 'test'.
        flush_buffer (bool): Whether to mark the log buffer as ready after
            evaluation so the results are flushed to the loggers.
            Default: True.
    """

    def __init__(self,
                 dataloader,
                 initial=False,
                 interval=1,
                 mode='test',
                 flush_buffer=True,
                 **eval_kwargs):
        if torch.cuda.is_available():
            if not isinstance(dataloader, (DataLoader, DaliLoaderWrapper)):
                raise TypeError(
                    'dataloader must be a pytorch DataLoader or '
                    f'DaliLoaderWrapper, but got {type(dataloader)}')
        else:
            if not isinstance(dataloader, DataLoader):
                raise TypeError(
                    'dataloader must be a pytorch DataLoader, but got'
                    f' {type(dataloader)}')
        self.dataloader = dataloader
        self.interval = interval
        self.initial = initial
        self.mode = mode
        self.eval_kwargs = eval_kwargs
        # Pop the kwargs consumed by the hook itself at init, so that
        # evaluate(), which runs every `interval` epochs, does not forward
        # them to dataset.evaluate().
        self.vis_config = self.eval_kwargs.pop('visualization_config', {})
        self.flush_buffer = flush_buffer

    def before_run(self, runner):
        if self.initial:
            self.after_train_epoch(runner)

    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return
        from easycv.apis import single_gpu_test
        if runner.rank == 0:
            # Prefer the EMA weights when an EMA hook is attached to the
            # runner.
            if hasattr(runner, 'ema'):
                results = single_gpu_test(
                    runner.ema.model,
                    self.dataloader,
                    mode=self.mode,
                    show=False)
            else:
                results = single_gpu_test(
                    runner.model, self.dataloader, mode=self.mode, show=False)
            self.evaluate(runner, results)

    def add_visualization_info(self, runner, results):
        if runner.visualization_buffer.output.get('eval_results',
                                                  None) is None:
            runner.visualization_buffer.output['eval_results'] = OrderedDict()

        if isinstance(self.dataloader, DataLoader):
            dataset_obj = self.dataloader.dataset
        else:
            dataset_obj = self.dataloader

        if hasattr(dataset_obj, 'visualize'):
            runner.visualization_buffer.output['eval_results'].update(
                dataset_obj.visualize(results, **self.vis_config))

    def evaluate(self, runner, results):
        self.add_visualization_info(runner, results)

        if isinstance(self.dataloader, DataLoader):
            eval_res = self.dataloader.dataset.evaluate(
                results, logger=runner.logger, **self.eval_kwargs)
        else:
            eval_res = self.dataloader.evaluate(
                results, logger=runner.logger, **self.eval_kwargs)

        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val

        if self.flush_buffer:
            runner.log_buffer.ready = True

        # Register eval_res on the runner so that a save-best checkpoint hook
        # can pick the metrics up later.
        if getattr(runner, 'eval_res', None) is None:
            runner.eval_res = {}

        for k in eval_res.keys():
            tmp_res = {k: eval_res[k], 'runner_epoch': runner.epoch}
            if k not in runner.eval_res:
                runner.eval_res[k] = []
            runner.eval_res[k].append(tmp_res)
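# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original module): one way to
# attach EvalHook to an mmcv runner. `runner` and `val_dataset` are
# hypothetical objects assumed to be built by the surrounding training
# script.
def _example_register_eval_hook(runner, val_dataset):
    from torch.utils.data import DataLoader

    # Plain PyTorch dataloader over the validation set.
    val_loader = DataLoader(val_dataset, batch_size=32, num_workers=2)
    # Evaluate once before training starts (initial=True), then after every
    # epoch (interval=1); extra kwargs are forwarded to dataset.evaluate().
    runner.register_hook(
        EvalHook(val_loader, initial=True, interval=1, mode='test'))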
""" def __init__(self, dataloader, interval=1, mode='test', initial=False, gpu_collect=False, flush_buffer=True, **eval_kwargs): if not isinstance(dataloader, DataLoader) and not isinstance( dataloader, DaliLoaderWrapper): raise TypeError('dataloader must be a pytorch DataLoader, but got' f' {type(dataloader)}') self.dataloader = dataloader self.interval = interval self.mode = mode self.eval_kwargs = eval_kwargs self.initial = initial self.flush_buffer = flush_buffer # hook.evaluate runs every interval epoch or iter, popped at init self.vis_config = self.eval_kwargs.pop('visualization_config', {}) self.gpu_collect = self.eval_kwargs.pop('gpu_collect', gpu_collect) def before_run(self, runner): if self.initial: self.after_train_epoch(runner) def after_train_epoch(self, runner): if not self.every_n_epochs(runner, self.interval): return from easycv.apis import multi_gpu_test results = multi_gpu_test( runner.model, self.dataloader, mode=self.mode, tmpdir=osp.join(runner.work_dir, '.eval_hook'), gpu_collect=self.gpu_collect) if runner.rank == 0: self.evaluate(runner, results)