[feature]: Able to use save_best option (#575)

* Add save_best option in eval_hook.
* Update meta to fix the bug where the best model could not be tested.
* Refactor with _do_evaluate.
* Remove redundant code.
* Add meta.

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>

commit 02b5d768aa (parent 725d5aa002)
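
With this change, save_best can be set in the evaluation config and is forwarded to the hooks below. A minimal sketch in the usual MMSegmentation config style (the interval value is illustrative, not from this PR):

    evaluation = dict(interval=4000, metric='mIoU', save_best='mIoU')

The hook then tracks the best mIoU seen so far and keeps a checkpoint of the model that produced it.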
@@ -1,7 +1,9 @@
 import os.path as osp
 
+import torch.distributed as dist
 from mmcv.runner import DistEvalHook as _DistEvalHook
 from mmcv.runner import EvalHook as _EvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
 
 
 class EvalHook(_EvalHook):
@@ -23,33 +25,17 @@ class EvalHook(_EvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
-        results = single_gpu_test(
-            runner.model,
-            self.dataloader,
-            show=False,
-            efficient_test=self.efficient_test)
-        self.evaluate(runner, results)
-
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``single_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
+
         from mmseg.apis import single_gpu_test
-        runner.log_buffer.clear()
         results = single_gpu_test(runner.model, self.dataloader, show=False)
-        self.evaluate(runner, results)
+        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+        key_score = self.evaluate(runner, results)
+        if self.save_best:
+            self._save_ckpt(runner, key_score)
 
 
 class DistEvalHook(_DistEvalHook):
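
The new _do_evaluate defers scheduling to _should_evaluate and best-checkpoint bookkeeping to _save_ckpt, both inherited from the mmcv base hooks. The comparison that save_best boils down to looks roughly like this (a simplified sketch, not mmcv's exact implementation; the metric names are illustrative):

    def is_better(current, best, key_indicator,
                  greater_keys=('mIoU', 'mAcc', 'aAcc')):
        """Decide whether `current` should replace the stored best score."""
        if best is None:               # the first evaluation always wins
            return True
        if key_indicator in greater_keys:
            return current > best      # higher is better for IoU/accuracy
        return current < best          # lower is better for loss-like metrics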
@@ -71,39 +57,38 @@ class DistEvalHook(_DistEvalHook):
         super().__init__(*args, by_epoch=by_epoch, **kwargs)
         self.efficient_test = efficient_test
 
-    def after_train_iter(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if self.by_epoch or not self.every_n_iters(runner, self.interval):
+    def _do_evaluate(self, runner):
+        """perform evaluation and save ckpt."""
+        # Synchronization of BatchNorm's buffer (running_mean
+        # and running_var) is not supported in the DDP of pytorch,
+        # which may cause the inconsistent performance of models in
+        # different ranks, so we broadcast BatchNorm's buffers
+        # of rank 0 to other ranks to avoid this.
+        if self.broadcast_bn_buffer:
+            model = runner.model
+            for name, module in model.named_modules():
+                if isinstance(module,
+                              _BatchNorm) and module.track_running_stats:
+                    dist.broadcast(module.running_var, 0)
+                    dist.broadcast(module.running_mean, 0)
+
+        if not self._should_evaluate(runner):
             return
-        from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
-        results = multi_gpu_test(
-            runner.model,
-            self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
-            gpu_collect=self.gpu_collect,
-            efficient_test=self.efficient_test)
-        if runner.rank == 0:
-            print('\n')
-            self.evaluate(runner, results)
-
-    def after_train_epoch(self, runner):
-        """After train epoch hook.
-
-        Override default ``multi_gpu_test``.
-        """
-        if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
-            return
+
+        tmpdir = self.tmpdir
+        if tmpdir is None:
+            tmpdir = osp.join(runner.work_dir, '.eval_hook')
+
         from mmseg.apis import multi_gpu_test
-        runner.log_buffer.clear()
         results = multi_gpu_test(
             runner.model,
             self.dataloader,
-            tmpdir=osp.join(runner.work_dir, '.eval_hook'),
+            tmpdir=tmpdir,
             gpu_collect=self.gpu_collect)
         if runner.rank == 0:
             print('\n')
-            self.evaluate(runner, results)
+            runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
+            key_score = self.evaluate(runner, results)
+            if self.save_best:
+                self._save_ckpt(runner, key_score)
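
A self-contained way to see which modules the broadcast guard in _do_evaluate matches (the tiny model below is hypothetical; only BatchNorm layers that track running stats are broadcast):

    import torch
    from torch.nn.modules.batchnorm import _BatchNorm

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, 3),
        torch.nn.BatchNorm2d(8),                             # keeps running stats
        torch.nn.BatchNorm2d(8, track_running_stats=False))  # skipped by the guard
    for name, module in model.named_modules():
        if isinstance(module, _BatchNorm) and module.track_running_stats:
            print(name)  # prints '1': the only module dist.broadcast would touch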
@@ -122,8 +122,16 @@ def main():
     if fp16_cfg is not None:
         wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
-    model.CLASSES = checkpoint['meta']['CLASSES']
-    model.PALETTE = checkpoint['meta']['PALETTE']
+    if 'CLASSES' in checkpoint.get('meta', {}):
+        model.CLASSES = checkpoint['meta']['CLASSES']
+    else:
+        print('"CLASSES" not found in meta, use dataset.CLASSES instead')
+        model.CLASSES = dataset.CLASSES
+    if 'PALETTE' in checkpoint.get('meta', {}):
+        model.PALETTE = checkpoint['meta']['PALETTE']
+    else:
+        print('"PALETTE" not found in meta, use dataset.PALETTE instead')
+        model.PALETTE = dataset.PALETTE
 
     efficient_test = False
     if args.eval_options is not None:
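
The guarded lookups matter for checkpoints whose meta never recorded class information, including the best checkpoints this PR starts saving. A toy illustration (the dict is hypothetical, not a real checkpoint):

    # Hypothetical checkpoint: 'meta' exists but was saved without CLASSES.
    checkpoint = {'state_dict': {}, 'meta': {'epoch': 20}}
    if 'CLASSES' in checkpoint.get('meta', {}):
        classes = checkpoint['meta']['CLASSES']
    else:
        classes = ('background', 'foreground')  # stand-in for dataset.CLASSES
    print(classes)  # falls back to the dataset's own class names

checkpoint.get('meta', {}) also covers checkpoints with no meta entry at all.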
@@ -149,6 +149,8 @@ def main():
         PALETTE=datasets[0].PALETTE)
     # add an attribute for visualization convenience
     model.CLASSES = datasets[0].CLASSES
+    # passing checkpoint meta for saving best checkpoint
+    meta.update(cfg.checkpoint_config.meta)
     train_segmentor(
         model,
         datasets,
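
This is the other half of the "best model can not test" fix: checkpoints saved during training, including the best one kept by _save_ckpt, are written with the runner's meta, so CLASSES and PALETTE must be copied into it before training starts. A self-contained sketch of the data flow (SimpleNamespace stands in for the real config object; the class names are illustrative):

    from types import SimpleNamespace

    # Stand-in for cfg after checkpoint_config.meta was populated above.
    cfg = SimpleNamespace(checkpoint_config=SimpleNamespace(
        meta=dict(CLASSES=('road', 'sidewalk'), PALETTE=[[128, 64, 128]])))
    meta = dict(seed=0)                      # runner meta assembled earlier in main()
    meta.update(cfg.checkpoint_config.meta)  # best ckpts now carry CLASSES/PALETTE
    assert 'CLASSES' in meta and 'PALETTE' in meta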