mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Return loop results. (#392)
This commit is contained in:
parent
39e7efb04d
commit
cfee85ff16
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from typing import Dict, Union
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
@ -33,5 +33,5 @@ class BaseLoop(metaclass=ABCMeta):
|
|||||||
return self._runner
|
return self._runner
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def run(self) -> None:
|
def run(self) -> Any:
|
||||||
"""Execute loop."""
|
"""Execute loop."""
|
||||||
|
@ -81,7 +81,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
"""int: Current iteration."""
|
"""int: Current iteration."""
|
||||||
return self._iter
|
return self._iter
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> torch.nn.Module:
|
||||||
"""Launch training."""
|
"""Launch training."""
|
||||||
self.runner.call_hook('before_train')
|
self.runner.call_hook('before_train')
|
||||||
|
|
||||||
@ -95,6 +95,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||||||
self.runner.val_loop.run()
|
self.runner.val_loop.run()
|
||||||
|
|
||||||
self.runner.call_hook('after_train')
|
self.runner.call_hook('after_train')
|
||||||
|
return self.runner.model
|
||||||
|
|
||||||
def run_epoch(self) -> None:
|
def run_epoch(self) -> None:
|
||||||
"""Iterate one epoch."""
|
"""Iterate one epoch."""
|
||||||
@ -266,6 +267,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||||||
|
|
||||||
self.runner.call_hook('after_train_epoch')
|
self.runner.call_hook('after_train_epoch')
|
||||||
self.runner.call_hook('after_train')
|
self.runner.call_hook('after_train')
|
||||||
|
return self.runner.model
|
||||||
|
|
||||||
def run_iter(self, data_batch: Sequence[dict]) -> None:
|
def run_iter(self, data_batch: Sequence[dict]) -> None:
|
||||||
"""Iterate one mini-batch.
|
"""Iterate one mini-batch.
|
||||||
@ -330,7 +332,7 @@ class ValLoop(BaseLoop):
|
|||||||
'visualizer will be None.')
|
'visualizer will be None.')
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
|
|
||||||
def run(self):
|
def run(self) -> dict:
|
||||||
"""Launch validation."""
|
"""Launch validation."""
|
||||||
self.runner.call_hook('before_val')
|
self.runner.call_hook('before_val')
|
||||||
self.runner.call_hook('before_val_epoch')
|
self.runner.call_hook('before_val_epoch')
|
||||||
@ -342,6 +344,7 @@ class ValLoop(BaseLoop):
|
|||||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||||
self.runner.call_hook('after_val_epoch', metrics=metrics)
|
self.runner.call_hook('after_val_epoch', metrics=metrics)
|
||||||
self.runner.call_hook('after_val')
|
self.runner.call_hook('after_val')
|
||||||
|
return metrics
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def run_iter(self, idx, data_batch: Sequence[dict]):
|
def run_iter(self, idx, data_batch: Sequence[dict]):
|
||||||
@ -399,7 +402,7 @@ class TestLoop(BaseLoop):
|
|||||||
'visualizer will be None.')
|
'visualizer will be None.')
|
||||||
self.fp16 = fp16
|
self.fp16 = fp16
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> dict:
|
||||||
"""Launch test."""
|
"""Launch test."""
|
||||||
self.runner.call_hook('before_test')
|
self.runner.call_hook('before_test')
|
||||||
self.runner.call_hook('before_test_epoch')
|
self.runner.call_hook('before_test_epoch')
|
||||||
@ -411,6 +414,7 @@ class TestLoop(BaseLoop):
|
|||||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||||
self.runner.call_hook('after_test_epoch', metrics=metrics)
|
self.runner.call_hook('after_test_epoch', metrics=metrics)
|
||||||
self.runner.call_hook('after_test')
|
self.runner.call_hook('after_test')
|
||||||
|
return metrics
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
|
||||||
|
@ -1595,8 +1595,12 @@ class Runner:
|
|||||||
self.load_checkpoint(self._load_from)
|
self.load_checkpoint(self._load_from)
|
||||||
self._has_loaded = True
|
self._has_loaded = True
|
||||||
|
|
||||||
def train(self) -> None:
|
def train(self) -> nn.Module:
|
||||||
"""Launch training."""
|
"""Launch training.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: The model after training.
|
||||||
|
"""
|
||||||
if self._train_loop is None:
|
if self._train_loop is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'`self._train_loop` should not be None when calling train '
|
'`self._train_loop` should not be None when calling train '
|
||||||
@ -1634,11 +1638,16 @@ class Runner:
|
|||||||
self._train_loop.iter, # type: ignore
|
self._train_loop.iter, # type: ignore
|
||||||
self._train_loop.max_iters) # type: ignore
|
self._train_loop.max_iters) # type: ignore
|
||||||
|
|
||||||
self.train_loop.run() # type: ignore
|
model = self.train_loop.run() # type: ignore
|
||||||
self.call_hook('after_run')
|
self.call_hook('after_run')
|
||||||
|
return model
|
||||||
|
|
||||||
def val(self) -> None:
|
def val(self) -> dict:
|
||||||
"""Launch validation."""
|
"""Launch validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dict of metrics on validation set.
|
||||||
|
"""
|
||||||
if self._val_loop is None:
|
if self._val_loop is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'`self._val_loop` should not be None when calling val method.'
|
'`self._val_loop` should not be None when calling val method.'
|
||||||
@ -1652,11 +1661,16 @@ class Runner:
|
|||||||
# make sure checkpoint-related hooks are triggered after `before_run`
|
# make sure checkpoint-related hooks are triggered after `before_run`
|
||||||
self.load_or_resume()
|
self.load_or_resume()
|
||||||
|
|
||||||
self.val_loop.run() # type: ignore
|
metrics = self.val_loop.run() # type: ignore
|
||||||
self.call_hook('after_run')
|
self.call_hook('after_run')
|
||||||
|
return metrics
|
||||||
|
|
||||||
def test(self) -> None:
|
def test(self) -> dict:
|
||||||
"""Launch test."""
|
"""Launch test.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dict of metrics on testing set.
|
||||||
|
"""
|
||||||
if self._test_loop is None:
|
if self._test_loop is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
'`self._test_loop` should not be None when calling test '
|
'`self._test_loop` should not be None when calling test '
|
||||||
@ -1670,8 +1684,9 @@ class Runner:
|
|||||||
# make sure checkpoint-related hooks are triggered after `before_run`
|
# make sure checkpoint-related hooks are triggered after `before_run`
|
||||||
self.load_or_resume()
|
self.load_or_resume()
|
||||||
|
|
||||||
self.test_loop.run() # type: ignore
|
metrics = self.test_loop.run() # type: ignore
|
||||||
self.call_hook('after_run')
|
self.call_hook('after_run')
|
||||||
|
return metrics
|
||||||
|
|
||||||
def call_hook(self, fn_name: str, **kwargs) -> None:
|
def call_hook(self, fn_name: str, **kwargs) -> None:
|
||||||
"""Call all hooks.
|
"""Call all hooks.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user