mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix output argument of after_iter, train_after_ter and val_after_iter (#115)
* Fix hook * Fix * Fix docs * FIx * Fix * Fix as comment
This commit is contained in:
parent
3bdd27c4e2
commit
ec3034b765
@ -168,18 +168,17 @@ class CheckpointHook(Hook):
|
||||
else:
|
||||
break
|
||||
|
||||
def after_train_iter(
|
||||
self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs=Optional[dict]) -> None:
|
||||
"""Save the checkpoint and synchronize buffers after each iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
if self.by_epoch:
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -37,14 +37,16 @@ class EmptyCacheHook(Hook):
|
||||
def after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
|
||||
-> None:
|
||||
"""Empty cache after an iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||
outputs (dict or sequence, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
if self._after_iter:
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
|
||||
@ -19,7 +19,8 @@ class Hook:
|
||||
operations before the training process.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the training/validation/testing
|
||||
process.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -27,11 +28,66 @@ class Hook:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after the training process.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training/validation/testing
|
||||
process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before train.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_train(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after train.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_val(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before val.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_val(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after val.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_test(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before test.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_test(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after test.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def before_epoch(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each epoch.
|
||||
@ -64,7 +120,9 @@ class Hook:
|
||||
def after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
|
||||
-> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each epoch.
|
||||
|
||||
@ -72,8 +130,8 @@ class Hook:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -184,11 +242,10 @@ class Hook:
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
|
||||
def after_train_iter(
|
||||
self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each training iteration.
|
||||
|
||||
@ -196,16 +253,16 @@ class Hook:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.after_iter(runner, data_batch=None, outputs=None)
|
||||
|
||||
def after_val_iter(
|
||||
self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
def after_val_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) \
|
||||
-> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation iteration.
|
||||
|
||||
@ -213,7 +270,7 @@ class Hook:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from
|
||||
outputs (dict or sequence, optional): Outputs from
|
||||
model. Defaults to None.
|
||||
"""
|
||||
self.after_iter(runner, data_batch=None, outputs=None)
|
||||
@ -230,7 +287,7 @@ class Hook:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self.after_iter(runner, data_batch=None, outputs=None)
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import time
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.registry import HOOKS
|
||||
@ -40,15 +40,17 @@ class IterTimerHook(Hook):
|
||||
def after_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
outputs:
|
||||
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
|
||||
-> None:
|
||||
"""Logging time for a iteration and update the time flag.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
to None.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
runner.log_buffer.update({'time': time.time() - self.t})
|
||||
|
@ -171,18 +171,17 @@ class LoggerHook(Hook):
|
||||
if runner.meta is not None:
|
||||
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
|
||||
|
||||
def after_train_iter(
|
||||
self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""Record training logs.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
||||
|
@ -56,11 +56,10 @@ class OptimizerHook(Hook):
|
||||
return clip_grad.clip_grad_norm_(params, **self.grad_clip)
|
||||
return None
|
||||
|
||||
def after_train_iter(
|
||||
self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""All operations need to be finished after each training iteration.
|
||||
|
||||
This function will finish following 3 operations:
|
||||
@ -80,7 +79,7 @@ class OptimizerHook(Hook):
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks,
|
||||
we keep ``outputs`` here. Defaults to None.
|
||||
"""
|
||||
|
@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook):
|
||||
|
||||
priority = 'LOW'
|
||||
|
||||
def after_train_iter(
|
||||
self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
def after_train_iter(self,
|
||||
runner,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""Call step function for each scheduler after each iteration.
|
||||
|
||||
Args:
|
||||
@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook):
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user