[Enhance] Return loop results. (#392)

This commit is contained in:
RangiLyu 2022-07-30 20:22:52 +08:00 committed by GitHub
parent 39e7efb04d
commit cfee85ff16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 14 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Dict, Union
from typing import Any, Dict, Union
from torch.utils.data import DataLoader
@ -33,5 +33,5 @@ class BaseLoop(metaclass=ABCMeta):
return self._runner
@abstractmethod
def run(self) -> None:
def run(self) -> Any:
"""Execute loop."""

View File

@ -81,7 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
"""int: Current iteration."""
return self._iter
def run(self) -> None:
def run(self) -> torch.nn.Module:
"""Launch training."""
self.runner.call_hook('before_train')
@ -95,6 +95,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.runner.val_loop.run()
self.runner.call_hook('after_train')
return self.runner.model
def run_epoch(self) -> None:
"""Iterate one epoch."""
@ -266,6 +267,7 @@ class IterBasedTrainLoop(BaseLoop):
self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train')
return self.runner.model
def run_iter(self, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch.
@ -330,7 +332,7 @@ class ValLoop(BaseLoop):
'visualizer will be None.')
self.fp16 = fp16
def run(self):
def run(self) -> dict:
"""Launch validation."""
self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch')
@ -342,6 +344,7 @@ class ValLoop(BaseLoop):
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]):
@ -399,7 +402,7 @@ class TestLoop(BaseLoop):
'visualizer will be None.')
self.fp16 = fp16
def run(self) -> None:
def run(self) -> dict:
"""Launch test."""
self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch')
@ -411,6 +414,7 @@ class TestLoop(BaseLoop):
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test')
return metrics
@torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:

View File

@ -1595,8 +1595,12 @@ class Runner:
self.load_checkpoint(self._load_from)
self._has_loaded = True
def train(self) -> None:
"""Launch training."""
def train(self) -> nn.Module:
"""Launch training.
Returns:
nn.Module: The model after training.
"""
if self._train_loop is None:
raise RuntimeError(
'`self._train_loop` should not be None when calling train '
@ -1634,11 +1638,16 @@ class Runner:
self._train_loop.iter, # type: ignore
self._train_loop.max_iters) # type: ignore
self.train_loop.run() # type: ignore
model = self.train_loop.run() # type: ignore
self.call_hook('after_run')
return model
def val(self) -> None:
"""Launch validation."""
def val(self) -> dict:
"""Launch validation.
Returns:
dict: A dict of metrics on validation set.
"""
if self._val_loop is None:
raise RuntimeError(
'`self._val_loop` should not be None when calling val method.'
@ -1652,11 +1661,16 @@ class Runner:
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()
self.val_loop.run() # type: ignore
metrics = self.val_loop.run() # type: ignore
self.call_hook('after_run')
return metrics
def test(self) -> None:
"""Launch test."""
def test(self) -> dict:
"""Launch test.
Returns:
dict: A dict of metrics on testing set.
"""
if self._test_loop is None:
raise RuntimeError(
'`self._test_loop` should not be None when calling test '
@ -1670,8 +1684,9 @@ class Runner:
# make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume()
self.test_loop.run() # type: ignore
metrics = self.test_loop.run() # type: ignore
self.call_hook('after_run')
return metrics
def call_hook(self, fn_name: str, **kwargs) -> None:
"""Call all hooks.