mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* Support progressive test with lower memory cost. * Temp code * Using processor to refactor evaluation workflow. * refactor eval hook. * Fix progress bar. * Fix middle save argument. * Modify some variable name of dataset evaluate api. * Modify some variable name of eval hook. * Fix some priority bugs of eval hook. * Deprecated efficient_test. * Fix training progress blocked by eval hook. * Deprecated old test api. * Fix test api error. * Modify outer api. * Build a sampler test api. * TODO: Refactor format_results. * Modify variable names. * Fix num_classes bug. * Fix sampler index bug. * Fix grammar bug. * Support batch sampler. * More readable test api. * Remove some command arg and fix eval hook bug. * Support format-only arg. * Modify format_results of datasets. * Modify tool which use test apis. * support cityscapes eval * fixed cityscapes * 1. Add comments for batch_sampler; 2. Keep eval hook api same and add deprecated warning; 3. Add doc string for dataset.pre_eval; * Add efficient_test doc string. * Modify test tool to compat old version. * Modify eval hook to compat with old version. * Modify test api to compat old version api. * Sampler explanation. * update warning * Modify deploy_test.py * compatible with old output, add efficient test back * clear logic of exclusive * Warning about efficient_test. * Modify format_results save folder. * Fix bugs of format_results. * Modify deploy_test.py. * Update doc * Fix deploy test bugs. * Fix custom dataset unit tests. * Fix dataset unit tests. * Fix eval hook unit tests. * Fix some incompatibilities. * Add pre_eval argument for eval hooks. * Update eval hook doc string. * Make pre_eval false by default. * Add unit tests for dataset format_results. * Fix some comments and bc-breaking bug. * Fix pre_eval set cfg field. * Remove redundant code. Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
129 lines
4.6 KiB
Python
129 lines
4.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import warnings
|
|
|
|
import torch.distributed as dist
|
|
from mmcv.runner import DistEvalHook as _DistEvalHook
|
|
from mmcv.runner import EvalHook as _EvalHook
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
|
|
class EvalHook(_EvalHook):
    """Single GPU EvalHook, with efficient test support.

    Runs ``mmseg.apis.single_gpu_test`` on ``self.dataloader`` and passes the
    results to ``self.evaluate``. The hook itself returns nothing.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Deprecated. The value is not used by this
            hook; a truthy value only emits a warning that points to
            ``pre_eval``. Default: False.
        pre_eval (bool): Whether to use progressive mode to evaluate model.
            Forwarded unchanged to ``single_gpu_test``. Default: False.
    """

    # Metric names declared for the base hook; presumably these are the
    # metrics for which a larger value is better -- confirm against mmcv.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        """Forward the generic options to the base hook and store
        ``pre_eval``."""
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        if efficient_test:
            # Kept only for backward compatibility: the flag is accepted but
            # otherwise ignored.
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``single_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        if not self._should_evaluate(runner):
            return

        # Imported inside the method rather than at module level (presumably
        # to avoid a circular import with ``mmseg.apis`` -- confirm).
        from mmseg.apis import single_gpu_test
        results = single_gpu_test(
            runner.model, self.dataloader, show=False, pre_eval=self.pre_eval)
        runner.log_buffer.clear()
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        key_score = self.evaluate(runner, results)
        # Only save a "best" checkpoint when a score was actually produced.
        # NOTE(review): ``evaluate`` comes from the mmcv base class and can
        # presumably return ``None`` (e.g. empty evaluation results); a
        # ``None`` score cannot be ranked against the previous best.
        if self.save_best and key_score is not None:
            self._save_ckpt(runner, key_score)
|
|
|
|
|
|
class DistEvalHook(_DistEvalHook):
    """Distributed EvalHook, with efficient test support.

    Runs ``mmseg.apis.multi_gpu_test`` on ``self.dataloader``; rank 0 then
    passes the results to ``self.evaluate``. The hook itself returns nothing.

    Args:
        by_epoch (bool): Determine perform evaluation by epoch or by iteration.
            If set to True, it will perform by epoch. Otherwise, by iteration.
            Default: False.
        efficient_test (bool): Deprecated. The value is not used by this
            hook; a truthy value only emits a warning that points to
            ``pre_eval``. Default: False.
        pre_eval (bool): Whether to use progressive mode to evaluate model.
            Forwarded unchanged to ``multi_gpu_test``. Default: False.
    """

    # Metric names declared for the base hook; presumably these are the
    # metrics for which a larger value is better -- confirm against mmcv.
    greater_keys = ['mIoU', 'mAcc', 'aAcc']

    def __init__(self,
                 *args,
                 by_epoch=False,
                 efficient_test=False,
                 pre_eval=False,
                 **kwargs):
        """Forward the generic options to the base hook and store
        ``pre_eval``."""
        super().__init__(*args, by_epoch=by_epoch, **kwargs)
        self.pre_eval = pre_eval
        if efficient_test:
            # Kept only for backward compatibility: the flag is accepted but
            # otherwise ignored.
            warnings.warn(
                'DeprecationWarning: ``efficient_test`` for evaluation hook '
                'is deprecated, the evaluation hook is CPU memory friendly '
                'with ``pre_eval=True`` as argument for ``multi_gpu_test()`` '
                'function')

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        # Synchronization of BatchNorm's buffer (running_mean
        # and running_var) is not supported in the DDP of pytorch,
        # which may cause the inconsistent performance of models in
        # different ranks, so we broadcast BatchNorm's buffers
        # of rank 0 to other ranks to avoid this.
        # The broadcast runs before the ``_should_evaluate`` check, exactly
        # as in the original ordering.
        if self.broadcast_bn_buffer:
            for module in runner.model.modules():
                if isinstance(module,
                              _BatchNorm) and module.track_running_stats:
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        if not self._should_evaluate(runner):
            return

        # Fall back to a directory inside the runner's work dir when no
        # temporary directory was configured on the hook.
        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, '.eval_hook')

        # Imported inside the method rather than at module level (presumably
        # to avoid a circular import with ``mmseg.apis`` -- confirm).
        from mmseg.apis import multi_gpu_test
        results = multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect,
            pre_eval=self.pre_eval)

        runner.log_buffer.clear()

        # Metrics are computed and checkpoints saved on rank 0 only.
        if runner.rank == 0:
            print('\n')
            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
            key_score = self.evaluate(runner, results)

            # Only save a "best" checkpoint when a score was actually
            # produced.
            # NOTE(review): ``evaluate`` comes from the mmcv base class and
            # can presumably return ``None`` (e.g. empty evaluation results);
            # a ``None`` score cannot be ranked against the previous best.
            if self.save_best and key_score is not None:
                self._save_ckpt(runner, key_score)
|