72 lines
2.3 KiB
Python
72 lines
2.3 KiB
Python
from mmcv.runner import Hook
|
|
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
|
|
from openselfsup.utils import nondist_forward_collect, dist_forward_collect
|
|
from .registry import HOOKS
|
|
|
|
|
|
@HOOKS.register_module
class ValidateHook(Hook):
    """Periodically run the model in test mode over a dataset and log metrics.

    Args:
        dataset (Dataset | dict): A ready ``torch.utils.data.Dataset``, or a
            config dict passed to ``openselfsup.datasets.build_dataset``.
        dist_mode (bool): If True, build a distributed data loader and gather
            outputs with ``dist_forward_collect``; otherwise use
            ``nondist_forward_collect``. Default: True.
        initial (bool): If True, also validate once in ``before_run``.
            Default: True.
        interval (int): Validate after every ``interval`` training epochs.
            Default: 1.
        **eval_kwargs: Must contain ``imgs_per_gpu`` and ``workers_per_gpu``
            (used to build the data loader). An optional ``eval_param`` dict
            is forwarded as keyword arguments to ``dataset.evaluate``.

    Raises:
        TypeError: If ``dataset`` is neither a ``Dataset`` nor a dict.
        KeyError: If ``imgs_per_gpu`` or ``workers_per_gpu`` is missing.
    """

    def __init__(self,
                 dataset,
                 dist_mode=True,
                 initial=True,
                 interval=1,
                 **eval_kwargs):
        # NOTE(review): imported locally, presumably to avoid an import
        # cycle with the hooks package — confirm before hoisting.
        from openselfsup import datasets
        if isinstance(dataset, Dataset):
            self.dataset = dataset
        elif isinstance(dataset, dict):
            self.dataset = datasets.build_dataset(dataset)
        else:
            raise TypeError(
                'dataset must be a Dataset object or a dict, not {}'.format(
                    type(dataset)))
        # Evaluation order must be stable so outputs line up with the
        # dataset when `dataset.evaluate` scores them, hence shuffle=False.
        self.data_loader = datasets.build_dataloader(
            self.dataset,
            eval_kwargs['imgs_per_gpu'],
            eval_kwargs['workers_per_gpu'],
            dist=dist_mode,
            shuffle=False)
        self.dist_mode = dist_mode
        self.initial = initial
        self.interval = interval
        self.eval_kwargs = eval_kwargs

    def before_run(self, runner):
        """Validate once before training starts when ``initial`` is set."""
        if self.initial:
            self._run_validate(runner)

    def after_train_epoch(self, runner):
        """Validate at the end of every ``interval``-th training epoch."""
        if not self.every_n_epochs(runner, self.interval):
            return
        self._run_validate(runner)

    def _run_validate(self, runner):
        """Run one validation pass and evaluate the outputs on rank 0.

        The model is put in eval mode for the pass and is always returned to
        train mode afterwards, even if collection or evaluation raises.
        """
        runner.model.eval()
        try:
            def forward_test(**inputs):
                return runner.model(mode='test', **inputs)

            if self.dist_mode:
                results = dist_forward_collect(
                    forward_test, self.data_loader, runner.rank,
                    len(self.dataset))  # dict{key: np.ndarray}
            else:
                results = nondist_forward_collect(forward_test,
                                                  self.data_loader,
                                                  len(self.dataset))
            # Only rank 0 scores and logs, so metrics are reported once.
            if runner.rank == 0:
                for name, val in results.items():
                    self._evaluate(runner, torch.from_numpy(val), name)
        finally:
            runner.model.train()

    def _evaluate(self, runner, results, keyword):
        """Score one collected output and publish the metrics to the runner.

        Args:
            runner: The runner; its ``logger`` is passed to
                ``dataset.evaluate`` and its ``log_buffer`` receives metrics.
            results (torch.Tensor): Collected outputs stored under ``keyword``.
            keyword (str): Key of this output in the collected results dict.
        """
        # `eval_param` is optional; absent or None means no extra kwargs.
        eval_param = self.eval_kwargs.get('eval_param') or {}
        eval_res = self.dataset.evaluate(
            results,
            keyword=keyword,
            logger=runner.logger,
            **eval_param)
        for name, val in eval_res.items():
            runner.log_buffer.output[name] = val
        # NOTE(review): presumably signals logger hooks that `output` holds
        # fresh values to emit (mmcv LogBuffer convention) — confirm.
        runner.log_buffer.ready = True
|