[feature]: Able to use save_best option (#575)

* Add save_best option in eval_hook.

* Update meta to fix the bug where the best model could not be tested.

* Refactor with _do_evaluate.

* Remove redundant code.

* Add meta.
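
With this change, saving the best checkpoint can be requested from the evaluation
config. A minimal sketch (interval and metric values are illustrative; save_best is
forwarded to mmcv's EvalHook, which also accepts 'auto'):

    evaluation = dict(interval=16000, metric='mIoU', save_best='mIoU')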

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
Yinhao Li 2021-06-03 07:25:26 +08:00 committed by GitHub
parent a95f6d8173
commit 9dd3e15f83
3 changed files with 47 additions and 52 deletions


@@ -1,7 +1,9 @@
 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ class EvalHook(_EvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
+
         from mmseg.apis import single_gpu_test
         runner.log_buffer.clear()
         results = single_gpu_test(
             runner.model,
             self.dataloader,
             show=False,
             efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
-
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
@@ -71,39 +57,38 @@ class DistEvalHook(_DistEvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
 
-        Override default ``multi_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+        if not self._should_evaluate(runner):
             return
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
         from mmseg.apis import multi_gpu_test
         runner.log_buffer.clear()
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            tmpdir=tmpdir,
             gpu_collect=self.gpu_collect,
             efficient_test=self.efficient_test)
         if runner.rank == 0:
             print('\n')
-            self.evaluate(runner, results)
-
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
-        from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
-            gpu_collect=self.gpu_collect)
-        if runner.rank == 0:
-            print('\n')
-            self.evaluate(runner, results)
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            if self.save_best:
+                self._save_ckpt(runner, key_score)
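
The _do_evaluate refactor above relies on the base hooks dispatching both hook points
to a single evaluation entry point. A simplified sketch of that dispatch (an assumption
about the mmcv 1.3-era base EvalHook, not part of this diff; the subclasses above only
override _do_evaluate):

    class EvalHookDispatchSketch:
        # by_epoch is set on the hook at construction time.
        def after_train_iter(self, runner):
            # Iteration-based training: evaluate only when by_epoch is False.
            if not self.by_epoch:
                self._do_evaluate(runner)

        def after_train_epoch(self, runner):
            # Epoch-based training: evaluate only when by_epoch is True.
            if self.by_epoch:
                self._do_evaluate(runner)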


@@ -122,8 +122,16 @@ def main():
     if fp16_cfg is not None:
         wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-    model.CLASSES = checkpoint['meta']['CLASSES']
-    model.PALETTE = checkpoint['meta']['PALETTE']
+    if 'CLASSES' in checkpoint.get('meta', {}):
+        model.CLASSES = checkpoint['meta']['CLASSES']
+    else:
+        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
+        model.CLASSES = dataset.CLASSES
+    if 'PALETTE' in checkpoint.get('meta', {}):
+        model.PALETTE = checkpoint['meta']['PALETTE']
+    else:
+        print('"PALETTE" not found in meta, use dataset.PALETTE instead')
+        model.PALETTE = dataset.PALETTE
 
     efficient_test = False
     if args.eval_options is not None:
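
This fallback in the test script matters for checkpoints written before CLASSES and
PALETTE were stored in meta. A quick way to check what a given checkpoint actually
carries (the path is hypothetical):

    import torch

    # mmcv-style checkpoints are plain dicts with 'state_dict' and optional 'meta'.
    ckpt = torch.load('work_dirs/example/latest.pth', map_location='cpu')
    print(sorted(ckpt.get('meta', {}).keys()))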


@@ -149,6 +149,8 @@ def main():
             PALETTE=datasets[0].PALETTE)
     # add an attribute for visualization convenience
     model.CLASSES = datasets[0].CLASSES
+    # passing checkpoint meta for saving best checkpoint
+    meta.update(cfg.checkpoint_config.meta)
     train_segmentor(
         model,
         datasets,
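
For reference, the dict merged into meta in the training-script hunk above is built a
few lines earlier in the same script; its tail (PALETTE=datasets[0].PALETTE) is visible
in the hunk context. Roughly, and with illustrative field values, it looks like:

    cfg.checkpoint_config.meta = dict(
        mmseg_version=__version__,      # version of the installed mmseg (assumed field)
        config=cfg.pretty_text,         # full training config text (assumed field)
        CLASSES=datasets[0].CLASSES,
        PALETTE=datasets[0].PALETTE)

Passing it through meta lets the runner write the same CLASSES/PALETTE into the best
checkpoint saved by the eval hook, so that checkpoint can later be loaded for testing
without hitting the fallback path.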