diff --git a/mmengine/runner/base_loop.py b/mmengine/runner/base_loop.py index 48a6f69c..5bae459a 100644 --- a/mmengine/runner/base_loop.py +++ b/mmengine/runner/base_loop.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import ABCMeta, abstractmethod -from typing import Dict, Union +from typing import Any, Dict, Union from torch.utils.data import DataLoader @@ -33,5 +33,5 @@ class BaseLoop(metaclass=ABCMeta): return self._runner @abstractmethod - def run(self) -> None: + def run(self) -> Any: """Execute loop.""" diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 3e32e549..e155263b 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -81,7 +81,7 @@ class EpochBasedTrainLoop(BaseLoop): """int: Current iteration.""" return self._iter - def run(self) -> None: + def run(self) -> torch.nn.Module: """Launch training.""" self.runner.call_hook('before_train') @@ -95,6 +95,7 @@ class EpochBasedTrainLoop(BaseLoop): self.runner.val_loop.run() self.runner.call_hook('after_train') + return self.runner.model def run_epoch(self) -> None: """Iterate one epoch.""" @@ -266,6 +267,7 @@ class IterBasedTrainLoop(BaseLoop): self.runner.call_hook('after_train_epoch') self.runner.call_hook('after_train') + return self.runner.model def run_iter(self, data_batch: Sequence[dict]) -> None: """Iterate one mini-batch. @@ -330,7 +332,7 @@ class ValLoop(BaseLoop): 'visualizer will be None.') self.fp16 = fp16 - def run(self): + def run(self) -> dict: """Launch validation.""" self.runner.call_hook('before_val') self.runner.call_hook('before_val_epoch') @@ -342,6 +344,7 @@ class ValLoop(BaseLoop): metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) self.runner.call_hook('after_val_epoch', metrics=metrics) self.runner.call_hook('after_val') + return metrics @torch.no_grad() def run_iter(self, idx, data_batch: Sequence[dict]): @@ -399,7 +402,7 @@ class TestLoop(BaseLoop): 'visualizer will be None.') self.fp16 = fp16 - def run(self) -> None: + def run(self) -> dict: """Launch test.""" self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') @@ -411,6 +414,7 @@ class TestLoop(BaseLoop): metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) self.runner.call_hook('after_test_epoch', metrics=metrics) self.runner.call_hook('after_test') + return metrics @torch.no_grad() def run_iter(self, idx, data_batch: Sequence[dict]) -> None: diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index f04897b5..35509174 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -1595,8 +1595,12 @@ class Runner: self.load_checkpoint(self._load_from) self._has_loaded = True - def train(self) -> None: - """Launch training.""" + def train(self) -> nn.Module: + """Launch training. + + Returns: + nn.Module: The model after training. + """ if self._train_loop is None: raise RuntimeError( '`self._train_loop` should not be None when calling train ' @@ -1634,11 +1638,16 @@ class Runner: self._train_loop.iter, # type: ignore self._train_loop.max_iters) # type: ignore - self.train_loop.run() # type: ignore + model = self.train_loop.run() # type: ignore self.call_hook('after_run') + return model - def val(self) -> None: - """Launch validation.""" + def val(self) -> dict: + """Launch validation. + + Returns: + dict: A dict of metrics on validation set. + """ if self._val_loop is None: raise RuntimeError( '`self._val_loop` should not be None when calling val method.' @@ -1652,11 +1661,16 @@ class Runner: # make sure checkpoint-related hooks are triggered after `before_run` self.load_or_resume() - self.val_loop.run() # type: ignore + metrics = self.val_loop.run() # type: ignore self.call_hook('after_run') + return metrics - def test(self) -> None: - """Launch test.""" + def test(self) -> dict: + """Launch test. + + Returns: + dict: A dict of metrics on testing set. + """ if self._test_loop is None: raise RuntimeError( '`self._test_loop` should not be None when calling test ' @@ -1670,8 +1684,9 @@ class Runner: # make sure checkpoint-related hooks are triggered after `before_run` self.load_or_resume() - self.test_loop.run() # type: ignore + metrics = self.test_loop.run() # type: ignore self.call_hook('after_run') + return metrics def call_hook(self, fn_name: str, **kwargs) -> None: """Call all hooks.