[feature]: Able to use save_best option (#575)
* Add save_best option in eval_hook. * Update meta to fix best model can not test bug * refactor with _do_evaluate * remove redundent * add meta Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>pull/1801/head
parent
725d5aa002
commit
02b5d768aa
|
@ -1,7 +1,9 @@
|
|||
import os.path as osp
|
||||
|
||||
import torch.distributed as dist
|
||||
from mmcv.runner import DistEvalHook as _DistEvalHook
|
||||
from mmcv.runner import EvalHook as _EvalHook
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
class EvalHook(_EvalHook):
|
||||
|
@ -23,33 +25,17 @@ class EvalHook(_EvalHook):
|
|||
super().__init__(*args, by_epoch=by_epoch, **kwargs)
|
||||
self.efficient_test = efficient_test
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
"""After train epoch hook.
|
||||
|
||||
Override default ``single_gpu_test``.
|
||||
"""
|
||||
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||
def _do_evaluate(self, runner):
|
||||
"""perform evaluation and save ckpt."""
|
||||
if not self._should_evaluate(runner):
|
||||
return
|
||||
from mmseg.apis import single_gpu_test
|
||||
runner.log_buffer.clear()
|
||||
results = single_gpu_test(
|
||||
runner.model,
|
||||
self.dataloader,
|
||||
show=False,
|
||||
efficient_test=self.efficient_test)
|
||||
self.evaluate(runner, results)
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
"""After train epoch hook.
|
||||
|
||||
Override default ``single_gpu_test``.
|
||||
"""
|
||||
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||
return
|
||||
from mmseg.apis import single_gpu_test
|
||||
runner.log_buffer.clear()
|
||||
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
||||
self.evaluate(runner, results)
|
||||
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
|
||||
key_score = self.evaluate(runner, results)
|
||||
if self.save_best:
|
||||
self._save_ckpt(runner, key_score)
|
||||
|
||||
|
||||
class DistEvalHook(_DistEvalHook):
|
||||
|
@ -71,39 +57,38 @@ class DistEvalHook(_DistEvalHook):
|
|||
super().__init__(*args, by_epoch=by_epoch, **kwargs)
|
||||
self.efficient_test = efficient_test
|
||||
|
||||
def after_train_iter(self, runner):
|
||||
"""After train epoch hook.
|
||||
def _do_evaluate(self, runner):
|
||||
"""perform evaluation and save ckpt."""
|
||||
# Synchronization of BatchNorm's buffer (running_mean
|
||||
# and running_var) is not supported in the DDP of pytorch,
|
||||
# which may cause the inconsistent performance of models in
|
||||
# different ranks, so we broadcast BatchNorm's buffers
|
||||
# of rank 0 to other ranks to avoid this.
|
||||
if self.broadcast_bn_buffer:
|
||||
model = runner.model
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module,
|
||||
_BatchNorm) and module.track_running_stats:
|
||||
dist.broadcast(module.running_var, 0)
|
||||
dist.broadcast(module.running_mean, 0)
|
||||
|
||||
Override default ``multi_gpu_test``.
|
||||
"""
|
||||
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||
if not self._should_evaluate(runner):
|
||||
return
|
||||
|
||||
tmpdir = self.tmpdir
|
||||
if tmpdir is None:
|
||||
tmpdir = osp.join(runner.work_dir, '.eval_hook')
|
||||
|
||||
from mmseg.apis import multi_gpu_test
|
||||
runner.log_buffer.clear()
|
||||
results = multi_gpu_test(
|
||||
runner.model,
|
||||
self.dataloader,
|
||||
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
||||
gpu_collect=self.gpu_collect,
|
||||
efficient_test=self.efficient_test)
|
||||
if runner.rank == 0:
|
||||
print('\n')
|
||||
self.evaluate(runner, results)
|
||||
|
||||
def after_train_epoch(self, runner):
|
||||
"""After train epoch hook.
|
||||
|
||||
Override default ``multi_gpu_test``.
|
||||
"""
|
||||
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||
return
|
||||
from mmseg.apis import multi_gpu_test
|
||||
runner.log_buffer.clear()
|
||||
results = multi_gpu_test(
|
||||
runner.model,
|
||||
self.dataloader,
|
||||
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
||||
tmpdir=tmpdir,
|
||||
gpu_collect=self.gpu_collect)
|
||||
if runner.rank == 0:
|
||||
print('\n')
|
||||
self.evaluate(runner, results)
|
||||
runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
|
||||
key_score = self.evaluate(runner, results)
|
||||
|
||||
if self.save_best:
|
||||
self._save_ckpt(runner, key_score)
|
||||
|
|
|
@ -122,8 +122,16 @@ def main():
|
|||
if fp16_cfg is not None:
|
||||
wrap_fp16_model(model)
|
||||
checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
model.PALETTE = checkpoint['meta']['PALETTE']
|
||||
if 'CLASSES' in checkpoint.get('meta', {}):
|
||||
model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
else:
|
||||
print('"CLASSES" not found in meta, use dataset.CLASSES instead')
|
||||
model.CLASSES = dataset.CLASSES
|
||||
if 'PALETTE' in checkpoint.get('meta', {}):
|
||||
model.PALETTE = checkpoint['meta']['PALETTE']
|
||||
else:
|
||||
print('"PALETTE" not found in meta, use dataset.PALETTE instead')
|
||||
model.PALETTE = dataset.PALETTE
|
||||
|
||||
efficient_test = False
|
||||
if args.eval_options is not None:
|
||||
|
|
|
@ -149,6 +149,8 @@ def main():
|
|||
PALETTE=datasets[0].PALETTE)
|
||||
# add an attribute for visualization convenience
|
||||
model.CLASSES = datasets[0].CLASSES
|
||||
# passing checkpoint meta for saving best checkpoint
|
||||
meta.update(cfg.checkpoint_config.meta)
|
||||
train_segmentor(
|
||||
model,
|
||||
datasets,
|
||||
|
|
Loading…
Reference in New Issue