[Enhance] Return loop results. (#392)

This commit is contained in:
RangiLyu 2022-07-30 20:22:52 +08:00 committed by GitHub
parent 39e7efb04d
commit cfee85ff16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 14 deletions

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from typing import Dict, Union from typing import Any, Dict, Union
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@ -33,5 +33,5 @@ class BaseLoop(metaclass=ABCMeta):
return self._runner return self._runner
@abstractmethod @abstractmethod
def run(self) -> None: def run(self) -> Any:
"""Execute loop.""" """Execute loop."""

View File

@ -81,7 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
"""int: Current iteration.""" """int: Current iteration."""
return self._iter return self._iter
def run(self) -> None: def run(self) -> torch.nn.Module:
"""Launch training.""" """Launch training."""
self.runner.call_hook('before_train') self.runner.call_hook('before_train')
@ -95,6 +95,7 @@ class EpochBasedTrainLoop(BaseLoop):
self.runner.val_loop.run() self.runner.val_loop.run()
self.runner.call_hook('after_train') self.runner.call_hook('after_train')
return self.runner.model
def run_epoch(self) -> None: def run_epoch(self) -> None:
"""Iterate one epoch.""" """Iterate one epoch."""
@ -266,6 +267,7 @@ class IterBasedTrainLoop(BaseLoop):
self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train_epoch')
self.runner.call_hook('after_train') self.runner.call_hook('after_train')
return self.runner.model
def run_iter(self, data_batch: Sequence[dict]) -> None: def run_iter(self, data_batch: Sequence[dict]) -> None:
"""Iterate one mini-batch. """Iterate one mini-batch.
@ -330,7 +332,7 @@ class ValLoop(BaseLoop):
'visualizer will be None.') 'visualizer will be None.')
self.fp16 = fp16 self.fp16 = fp16
def run(self): def run(self) -> dict:
"""Launch validation.""" """Launch validation."""
self.runner.call_hook('before_val') self.runner.call_hook('before_val')
self.runner.call_hook('before_val_epoch') self.runner.call_hook('before_val_epoch')
@ -342,6 +344,7 @@ class ValLoop(BaseLoop):
metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
self.runner.call_hook('after_val_epoch', metrics=metrics) self.runner.call_hook('after_val_epoch', metrics=metrics)
self.runner.call_hook('after_val') self.runner.call_hook('after_val')
return metrics
@torch.no_grad() @torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]): def run_iter(self, idx, data_batch: Sequence[dict]):
@ -399,7 +402,7 @@ class TestLoop(BaseLoop):
'visualizer will be None.') 'visualizer will be None.')
self.fp16 = fp16 self.fp16 = fp16
def run(self) -> None: def run(self) -> dict:
"""Launch test.""" """Launch test."""
self.runner.call_hook('before_test') self.runner.call_hook('before_test')
self.runner.call_hook('before_test_epoch') self.runner.call_hook('before_test_epoch')
@ -411,6 +414,7 @@ class TestLoop(BaseLoop):
metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
self.runner.call_hook('after_test_epoch', metrics=metrics) self.runner.call_hook('after_test_epoch', metrics=metrics)
self.runner.call_hook('after_test') self.runner.call_hook('after_test')
return metrics
@torch.no_grad() @torch.no_grad()
def run_iter(self, idx, data_batch: Sequence[dict]) -> None: def run_iter(self, idx, data_batch: Sequence[dict]) -> None:

View File

@ -1595,8 +1595,12 @@ class Runner:
self.load_checkpoint(self._load_from) self.load_checkpoint(self._load_from)
self._has_loaded = True self._has_loaded = True
def train(self) -> None: def train(self) -> nn.Module:
"""Launch training.""" """Launch training.
Returns:
nn.Module: The model after training.
"""
if self._train_loop is None: if self._train_loop is None:
raise RuntimeError( raise RuntimeError(
'`self._train_loop` should not be None when calling train ' '`self._train_loop` should not be None when calling train '
@ -1634,11 +1638,16 @@ class Runner:
self._train_loop.iter, # type: ignore self._train_loop.iter, # type: ignore
self._train_loop.max_iters) # type: ignore self._train_loop.max_iters) # type: ignore
self.train_loop.run() # type: ignore model = self.train_loop.run() # type: ignore
self.call_hook('after_run') self.call_hook('after_run')
return model
def val(self) -> None: def val(self) -> dict:
"""Launch validation.""" """Launch validation.
Returns:
dict: A dict of metrics on validation set.
"""
if self._val_loop is None: if self._val_loop is None:
raise RuntimeError( raise RuntimeError(
'`self._val_loop` should not be None when calling val method.' '`self._val_loop` should not be None when calling val method.'
@ -1652,11 +1661,16 @@ class Runner:
# make sure checkpoint-related hooks are triggered after `before_run` # make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume() self.load_or_resume()
self.val_loop.run() # type: ignore metrics = self.val_loop.run() # type: ignore
self.call_hook('after_run') self.call_hook('after_run')
return metrics
def test(self) -> None: def test(self) -> dict:
"""Launch test.""" """Launch test.
Returns:
dict: A dict of metrics on testing set.
"""
if self._test_loop is None: if self._test_loop is None:
raise RuntimeError( raise RuntimeError(
'`self._test_loop` should not be None when calling test ' '`self._test_loop` should not be None when calling test '
@ -1670,8 +1684,9 @@ class Runner:
# make sure checkpoint-related hooks are triggered after `before_run` # make sure checkpoint-related hooks are triggered after `before_run`
self.load_or_resume() self.load_or_resume()
self.test_loop.run() # type: ignore metrics = self.test_loop.run() # type: ignore
self.call_hook('after_run') self.call_hook('after_run')
return metrics
def call_hook(self, fn_name: str, **kwargs) -> None: def call_hook(self, fn_name: str, **kwargs) -> None:
"""Call all hooks. """Call all hooks.