EasyCV/easycv/hooks/eval_hook.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from collections import OrderedDict

import torch
import torch.distributed as dist
from mmcv.runner import Hook
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import DataLoader

from easycv.datasets.loader.loader_wrapper import TorchaccLoaderWrapper
from easycv.framework.errors import TypeError
from easycv.hooks.tensorboard import TensorboardLoggerHookV2
from easycv.hooks.wandb import WandbLoggerHookV2

if torch.cuda.is_available():
    from easycv.datasets.shared.dali_tfrecord_imagenet import DaliLoaderWrapper


class EvalHook(Hook):
    """Evaluation hook.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        initial (bool): Whether to run an evaluation pass before training
            starts. Default: False.
        interval (int): Evaluation interval (by epochs). Default: 1.
        mode (str): Model forward mode used for evaluation. Default: 'test'.
        flush_buffer (bool): Whether to mark the log buffer as ready after
            evaluation so the results are flushed to the loggers.
            Default: True.
    """

    def __init__(self,
                 dataloader,
                 initial=False,
                 interval=1,
                 mode='test',
                 flush_buffer=True,
                 **eval_kwargs):
        if torch.cuda.is_available():
            if not isinstance(
                    dataloader,
                    (DataLoader, DaliLoaderWrapper, TorchaccLoaderWrapper)):
                raise TypeError(
                    'dataloader must be a pytorch DataLoader, but got'
                    f' {type(dataloader)}')
        else:
            if not isinstance(dataloader, DataLoader):
                raise TypeError(
                    'dataloader must be a pytorch DataLoader, but got'
                    f' {type(dataloader)}')

        self.dataloader = dataloader
        self.interval = interval
        self.initial = initial
        self.mode = mode
        self.eval_kwargs = eval_kwargs
        # `visualization_config` is popped at init time so that it is not
        # forwarded to `dataset.evaluate` when `self.evaluate` runs every
        # `interval` epochs (or iterations).
        self.vis_config = self.eval_kwargs.pop('visualization_config', {})
        self.flush_buffer = flush_buffer
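
    # Sketch of how the constructor splits its keyword arguments (the loader
    # variable and the `vis_num`/`topk` keys below are illustrative
    # assumptions, not values defined in this file):
    #
    #     hook = EvalHook(
    #         val_loader,                             # hypothetical DataLoader
    #         interval=1,
    #         visualization_config=dict(vis_num=10),  # popped into vis_config
    #         topk=(1, 5))                            # kept in eval_kwargs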

    def before_run(self, runner):
        if self.initial:
            self.after_train_epoch(runner)

    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return

        from easycv.apis import single_gpu_test
        if runner.rank == 0:
            # Evaluate the EMA weights when an EMA hook is attached.
            if hasattr(runner, 'ema'):
                results = single_gpu_test(
                    runner.ema.model,
                    self.dataloader,
                    mode=self.mode,
                    show=False)
            else:
                results = single_gpu_test(
                    runner.model, self.dataloader, mode=self.mode, show=False)
            self.evaluate(runner, results)

    def add_visualization_info(self, runner, results):
        if runner.visualization_buffer.output.get('eval_results',
                                                  None) is None:
            runner.visualization_buffer.output['eval_results'] = OrderedDict()

        if isinstance(self.dataloader, DataLoader):
            dataset_obj = self.dataloader.dataset
        else:
            dataset_obj = self.dataloader

        if hasattr(dataset_obj, 'visualize'):
            runner.visualization_buffer.output['eval_results'].update(
                dataset_obj.visualize(results, **self.vis_config))

    def evaluate(self, runner, results):
        for _hook in runner.hooks:
            # Only apply `add_visualization_info` when a visualization-capable
            # logger hook is registered.
            if isinstance(_hook, (TensorboardLoggerHookV2, WandbLoggerHookV2)):
                self.add_visualization_info(runner, results)
                break

        if isinstance(self.dataloader, DaliLoaderWrapper):
            eval_res = self.dataloader.evaluate(
                results, logger=runner.logger, **self.eval_kwargs)
        else:
            eval_res = self.dataloader.dataset.evaluate(
                results, logger=runner.logger, **self.eval_kwargs)

        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val

        if self.flush_buffer:
            runner.log_buffer.ready = True

        # Register eval_res on the runner so a "save best checkpoint" hook
        # can read the per-metric history, e.g.
        # runner.eval_res[name] == [{name: value, 'runner_epoch': epoch}, ...]
        if getattr(runner, 'eval_res', None) is None:
            runner.eval_res = {}

        for k in eval_res.keys():
            tmp_res = {k: eval_res[k], 'runner_epoch': runner.epoch}
            if k not in runner.eval_res:
                runner.eval_res[k] = []
            runner.eval_res[k].append(tmp_res)
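

# Minimal registration sketch (not part of this module; `runner` and
# `val_loader` are assumed to exist, and mmcv runners expose
# `register_hook`):
#
#     from easycv.hooks.eval_hook import EvalHook
#
#     hook = EvalHook(val_loader, initial=True, interval=1, mode='test')
#     runner.register_hook(hook)
#
# With `initial=True`, `before_run` triggers one evaluation pass before the
# first training epoch.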


class DistEvalHook(EvalHook):
    """Distributed evaluation hook.

    Attributes:
        dataloader (DataLoader): A PyTorch dataloader.
        interval (int): Evaluation interval (by epochs). Default: 1.
        mode (str): Model forward mode used for evaluation. Default: 'test'.
        gpu_collect (bool): Whether to use gpu or cpu to collect results
            across processes. Default: False.
        broadcast_bn_buffer (bool): Whether to broadcast the buffers
            (running_mean and running_var) of rank 0 to the other ranks
            before evaluation. Default: True.

    Note:
        Intermediate results of all processes are collected in a temporary
        '.eval_hook' directory under ``runner.work_dir``.
    """

    def __init__(self,
                 dataloader,
                 interval=1,
                 mode='test',
                 initial=False,
                 gpu_collect=False,
                 flush_buffer=True,
                 broadcast_bn_buffer=True,
                 **eval_kwargs):
        super(DistEvalHook, self).__init__(
            dataloader=dataloader,
            initial=initial,
            interval=interval,
            mode=mode,
            flush_buffer=flush_buffer,
            **eval_kwargs)
        self.broadcast_bn_buffer = broadcast_bn_buffer
        # A `gpu_collect` passed through `eval_kwargs` overrides the keyword
        # argument of the same name.
        self.gpu_collect = self.eval_kwargs.pop('gpu_collect', gpu_collect)

    def after_train_epoch(self, runner):
        if not self.every_n_epochs(runner, self.interval):
            return

        # Synchronizing BatchNorm's buffers (running_mean and running_var)
        # is not supported by PyTorch DDP, which can make model performance
        # inconsistent across ranks, so we broadcast the buffers of rank 0
        # to the other ranks before evaluation.
        if self.broadcast_bn_buffer:
            model = runner.model
            for name, module in model.named_modules():
                if isinstance(module,
                              _BatchNorm) and module.track_running_stats:
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        from easycv.apis import multi_gpu_test
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            mode=self.mode,
            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
            gpu_collect=self.gpu_collect)

        # All ranks run inference; only rank 0 aggregates and logs metrics.
        if runner.rank == 0:
            self.evaluate(runner, results)
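

# Distributed usage sketch (assumptions: `torch.distributed` is already
# initialized, `val_loader` uses a DistributedSampler, and both names are
# hypothetical):
#
#     hook = DistEvalHook(val_loader, interval=1, gpu_collect=True)
#     runner.register_hook(hook)
#
# With `gpu_collect=True`, partial results are gathered via GPU
# communication; otherwise each rank writes its shard to
# `<work_dir>/.eval_hook` and rank 0 collects them from disk.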