Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Enhance] Return loop results. (#392)
parent: 39e7efb04d
commit: cfee85ff16
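
This change makes the run() methods of EpochBasedTrainLoop, IterBasedTrainLoop, ValLoop, and TestLoop return their results (the trained model or the evaluation metrics), and Runner.train(), Runner.val(), and Runner.test() now forward those results to the caller instead of returning None. BaseLoop.run() is retyped from -> None to -> Any accordingly.
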
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
-from typing import Dict, Union
+from typing import Any, Dict, Union

 from torch.utils.data import DataLoader

@@ -33,5 +33,5 @@ class BaseLoop(metaclass=ABCMeta):
         return self._runner

     @abstractmethod
-    def run(self) -> None:
+    def run(self) -> Any:
         """Execute loop."""
@@ -81,7 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
         """int: Current iteration."""
         return self._iter

-    def run(self) -> None:
+    def run(self) -> torch.nn.Module:
         """Launch training."""
         self.runner.call_hook('before_train')

@@ -95,6 +95,7 @@ class EpochBasedTrainLoop(BaseLoop):
                 self.runner.val_loop.run()

         self.runner.call_hook('after_train')
+        return self.runner.model

     def run_epoch(self) -> None:
         """Iterate one epoch."""
@@ -266,6 +267,7 @@ class IterBasedTrainLoop(BaseLoop):

         self.runner.call_hook('after_train_epoch')
         self.runner.call_hook('after_train')
+        return self.runner.model

     def run_iter(self, data_batch: Sequence[dict]) -> None:
         """Iterate one mini-batch.
@@ -330,7 +332,7 @@ class ValLoop(BaseLoop):
                 'visualizer will be None.')
         self.fp16 = fp16

-    def run(self):
+    def run(self) -> dict:
         """Launch validation."""
         self.runner.call_hook('before_val')
         self.runner.call_hook('before_val_epoch')
@@ -342,6 +344,7 @@ class ValLoop(BaseLoop):
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
         self.runner.call_hook('after_val_epoch', metrics=metrics)
         self.runner.call_hook('after_val')
+        return metrics

     @torch.no_grad()
     def run_iter(self, idx, data_batch: Sequence[dict]):
@@ -399,7 +402,7 @@ class TestLoop(BaseLoop):
                 'visualizer will be None.')
         self.fp16 = fp16

-    def run(self) -> None:
+    def run(self) -> dict:
         """Launch test."""
         self.runner.call_hook('before_test')
         self.runner.call_hook('before_test_epoch')
@@ -411,6 +414,7 @@ class TestLoop(BaseLoop):
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
         self.runner.call_hook('after_test_epoch', metrics=metrics)
         self.runner.call_hook('after_test')
+        return metrics

     @torch.no_grad()
     def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
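
With BaseLoop.run() now typed as -> Any, a custom loop can hand its result back to whoever calls it. The following is a minimal, hypothetical sketch: the loop name, the use of model.test_step, and the registration via the LOOPS registry are illustrative assumptions, not part of this commit.

from mmengine.registry import LOOPS
from mmengine.runner import BaseLoop


@LOOPS.register_module()
class CollectOutputsLoop(BaseLoop):
    """Hypothetical loop illustrating the new run() -> Any contract."""

    def run(self) -> list:
        # Gather per-batch outputs and return them to the caller,
        # which run() could not meaningfully do while typed -> None.
        outputs = []
        for data_batch in self.dataloader:
            outputs.append(self.runner.model.test_step(data_batch))
        return outputs
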
@@ -1595,8 +1595,12 @@ class Runner:
             self.load_checkpoint(self._load_from)
         self._has_loaded = True

-    def train(self) -> None:
-        """Launch training."""
+    def train(self) -> nn.Module:
+        """Launch training.
+
+        Returns:
+            nn.Module: The model after training.
+        """
         if self._train_loop is None:
             raise RuntimeError(
                 '`self._train_loop` should not be None when calling train '
@@ -1634,11 +1638,16 @@ class Runner:
             self._train_loop.iter,  # type: ignore
             self._train_loop.max_iters)  # type: ignore

-        self.train_loop.run()  # type: ignore
+        model = self.train_loop.run()  # type: ignore
         self.call_hook('after_run')
+        return model

-    def val(self) -> None:
-        """Launch validation."""
+    def val(self) -> dict:
+        """Launch validation.
+
+        Returns:
+            dict: A dict of metrics on validation set.
+        """
         if self._val_loop is None:
             raise RuntimeError(
                 '`self._val_loop` should not be None when calling val method.'
@@ -1652,11 +1661,16 @@ class Runner:
         # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()

-        self.val_loop.run()  # type: ignore
+        metrics = self.val_loop.run()  # type: ignore
         self.call_hook('after_run')
+        return metrics

-    def test(self) -> None:
-        """Launch test."""
+    def test(self) -> dict:
+        """Launch test.
+
+        Returns:
+            dict: A dict of metrics on testing set.
+        """
         if self._test_loop is None:
             raise RuntimeError(
                 '`self._test_loop` should not be None when calling test '
@@ -1670,8 +1684,9 @@ class Runner:
         # make sure checkpoint-related hooks are triggered after `before_run`
         self.load_or_resume()

-        self.test_loop.run()  # type: ignore
+        metrics = self.test_loop.run()  # type: ignore
         self.call_hook('after_run')
+        return metrics

     def call_hook(self, fn_name: str, **kwargs) -> None:
         """Call all hooks.