From 9dd3e15f834c1bd22fa19d0001ab7c51cec3e6ce Mon Sep 17 00:00:00 2001
From: Yinhao Li
Date: Thu, 3 Jun 2021 07:25:26 +0800
Subject: [PATCH] [feature]: Able to use save_best option (#575)

* Add save_best option in eval_hook.

* Update meta to fix the bug that the best model cannot be tested.

* Refactor with _do_evaluate.

* Remove redundant code.

* Add meta.

Co-authored-by: Jiarui XU
---
 mmseg/core/evaluation/eval_hooks.py | 85 ++++++++++++-----------------
 tools/test.py                       | 12 +++-
 tools/train.py                      |  2 +
 3 files changed, 47 insertions(+), 52 deletions(-)

diff --git a/mmseg/core/evaluation/eval_hooks.py b/mmseg/core/evaluation/eval_hooks.py
index 34c44c7fe..ce5809146 100644
--- a/mmseg/core/evaluation/eval_hooks.py
+++ b/mmseg/core/evaluation/eval_hooks.py
@@ -1,7 +1,9 @@
 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ class EvalHook(_EvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
 
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
         from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
         results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
@@ -71,39 +57,38 @@ class DistEvalHook(_DistEvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
 
-        Override default ``multi_gpu_test``.
- """ - if self.by_epoch or not self.every_n_iters(runner, self.interval): + if not self._should_evaluate(runner): return + + tmpdir = self.tmpdir + if tmpdir is None: + tmpdir = osp.join(runner.work_dir, '.eval_hook') + from mmseg.apis import multi_gpu_test - runner.log_buffer.clear() results = multi_gpu_test( runner.model, self.dataloader, - tmpdir=osp.join(runner.work_dir, '.eval_hook'), - gpu_collect=self.gpu_collect, - efficient_test=self.efficient_test) - if runner.rank == 0: - print('\n') - self.evaluate(runner, results) - - def after_train_epoch(self, runner): - """After train epoch hook. - - Override default ``multi_gpu_test``. - """ - if not self.by_epoch or not self.every_n_epochs(runner, self.interval): - return - from mmseg.apis import multi_gpu_test - runner.log_buffer.clear() - results = multi_gpu_test( - runner.model, - self.dataloader, - tmpdir=osp.join(runner.work_dir, '.eval_hook'), + tmpdir=tmpdir, gpu_collect=self.gpu_collect) if runner.rank == 0: print('\n') - self.evaluate(runner, results) + runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) + key_score = self.evaluate(runner, results) + + if self.save_best: + self._save_ckpt(runner, key_score) diff --git a/tools/test.py b/tools/test.py index fd8589c02..ab2bd6017 100644 --- a/tools/test.py +++ b/tools/test.py @@ -122,8 +122,16 @@ def main(): if fp16_cfg is not None: wrap_fp16_model(model) checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') - model.CLASSES = checkpoint['meta']['CLASSES'] - model.PALETTE = checkpoint['meta']['PALETTE'] + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + print('"CLASSES" not found in meta, use dataset.CLASSES instead') + model.CLASSES = dataset.CLASSES + if 'PALETTE' in checkpoint.get('meta', {}): + model.PALETTE = checkpoint['meta']['PALETTE'] + else: + print('"PALETTE" not found in meta, use dataset.PALETTE instead') + model.PALETTE = dataset.PALETTE efficient_test = False if args.eval_options is not None: diff --git a/tools/train.py b/tools/train.py index 51fe4065d..69ca7335d 100644 --- a/tools/train.py +++ b/tools/train.py @@ -149,6 +149,8 @@ def main(): PALETTE=datasets[0].PALETTE) # add an attribute for visualization convenience model.CLASSES = datasets[0].CLASSES + # passing checkpoint meta for saving best checkpoint + meta.update(cfg.checkpoint_config.meta) train_segmentor( model, datasets,