mmselfsup/openselfsup/hooks/validate_hook.py

72 lines
2.3 KiB
Python

from mmcv.runner import Hook
import torch
from torch.utils.data import Dataset
from openselfsup.utils import nondist_forward_collect, dist_forward_collect
from .registry import HOOKS
@HOOKS.register_module
class ValidateHook(Hook):
    """Run the model in test mode over a dataset and log evaluation results.

    Validation happens once before training (if ``initial``) and after every
    ``interval`` training epochs.

    Args:
        dataset (Dataset | dict): A ``torch.utils.data.Dataset`` instance, or
            a config dict passed to ``openselfsup.datasets.build_dataset``.
        dist_mode (bool): When True, the data loader is built with
            ``dist=True`` and outputs are gathered with
            ``dist_forward_collect``; otherwise ``nondist_forward_collect``
            is used. Default: True.
        initial (bool): Also validate in ``before_run``. Default: True.
        interval (int): Validate after every ``interval`` epochs.
            Default: 1.
        **eval_kwargs: Must contain ``imgs_per_gpu`` and ``workers_per_gpu``
            (used to build the data loader). May contain ``eval_param``, a
            dict of extra keyword arguments forwarded to
            ``dataset.evaluate``; treated as empty when absent or None.

    Raises:
        TypeError: If ``dataset`` is neither a Dataset nor a dict.
        KeyError: If ``imgs_per_gpu`` or ``workers_per_gpu`` is missing
            from ``eval_kwargs``.
    """

    def __init__(self,
                 dataset,
                 dist_mode=True,
                 initial=True,
                 interval=1,
                 **eval_kwargs):
        # Function-scope import kept from the original; presumably it avoids
        # an import cycle between hooks and datasets (not verifiable here).
        from openselfsup import datasets
        if isinstance(dataset, Dataset):
            self.dataset = dataset
        elif isinstance(dataset, dict):
            self.dataset = datasets.build_dataset(dataset)
        else:
            raise TypeError(
                'dataset must be a Dataset object or a dict, not {}'.format(
                    type(dataset)))
        # Validation must see every sample in a stable order, hence no
        # shuffling.
        self.data_loader = datasets.build_dataloader(
            self.dataset,
            eval_kwargs['imgs_per_gpu'],
            eval_kwargs['workers_per_gpu'],
            dist=dist_mode,
            shuffle=False)
        self.dist_mode = dist_mode
        self.initial = initial
        self.interval = interval
        self.eval_kwargs = eval_kwargs

    def before_run(self, runner):
        """Validate once before training starts when ``initial`` is set."""
        if self.initial:
            self._run_validate(runner)

    def after_train_epoch(self, runner):
        """Validate after every ``interval`` training epochs."""
        if not self.every_n_epochs(runner, self.interval):
            return
        self._run_validate(runner)

    def _run_validate(self, runner):
        """Collect test-mode outputs over the whole dataset and evaluate them.

        The model is put in eval mode for the pass and is always switched
        back to train mode afterwards, even if collection or evaluation
        raises.
        """
        runner.model.eval()
        try:
            def func(**data):
                return runner.model(mode='test', **data)

            if self.dist_mode:
                results = dist_forward_collect(
                    func, self.data_loader, runner.rank,
                    len(self.dataset))  # dict{key: np.ndarray}
            else:
                results = nondist_forward_collect(
                    func, self.data_loader, len(self.dataset))
            # Only rank 0 evaluates and writes to the log buffer.
            if runner.rank == 0:
                for name, val in results.items():
                    self._evaluate(runner, torch.from_numpy(val), name)
        finally:
            runner.model.train()

    def _evaluate(self, runner, results, keyword):
        """Evaluate one output tensor and publish the metrics to the runner.

        Args:
            runner: The mmcv runner; its ``logger`` is passed to
                ``dataset.evaluate`` and its ``log_buffer`` receives the
                metrics.
            results (torch.Tensor): Collected outputs for ``keyword``.
            keyword (str): Name of the output being evaluated.
        """
        # ``eval_param`` is optional; a missing or None value means no extra
        # keyword arguments for ``dataset.evaluate``.
        eval_param = self.eval_kwargs.get('eval_param') or {}
        eval_res = self.dataset.evaluate(
            results,
            keyword=keyword,
            logger=runner.logger,
            **eval_param)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        # Mark the buffer ready so logger hooks emit these values.
        runner.log_buffer.ready = True