mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix output argument of after_iter, train_after_ter and val_after_iter (#115)
* Fix hook * Fix * Fix docs * FIx * Fix * Fix as comment
This commit is contained in:
parent
3bdd27c4e2
commit
ec3034b765
@ -168,18 +168,17 @@ class CheckpointHook(Hook):
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
def after_train_iter(
|
def after_train_iter(self,
|
||||||
self,
|
runner,
|
||||||
runner,
|
data_batch: DATA_BATCH = None,
|
||||||
data_batch: DATA_BATCH = None,
|
outputs=Optional[dict]) -> None:
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
|
||||||
"""Save the checkpoint and synchronize buffers after each iteration.
|
"""Save the checkpoint and synchronize buffers after each iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
if self.by_epoch:
|
if self.by_epoch:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Any, Optional, Sequence, Tuple
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -37,14 +37,16 @@ class EmptyCacheHook(Hook):
|
|||||||
def after_iter(self,
|
def after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
outputs:
|
||||||
|
Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
|
||||||
|
-> None:
|
||||||
"""Empty cache after an iteration.
|
"""Empty cache after an iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
outputs (dict or sequence, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
if self._after_iter:
|
if self._after_iter:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Any, Optional, Sequence, Tuple
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from mmengine.data import BaseDataSample
|
from mmengine.data import BaseDataSample
|
||||||
|
|
||||||
@ -19,7 +19,8 @@ class Hook:
|
|||||||
operations before the training process.
|
operations before the training process.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training/validation/testing
|
||||||
|
process.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -27,11 +28,66 @@ class Hook:
|
|||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations after the training process.
|
operations after the training process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training/validation/testing
|
||||||
|
process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def before_train(self, runner) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations before train.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def after_train(self, runner) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after train.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def before_val(self, runner) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations before val.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the validation process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def after_val(self, runner) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after val.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the validation process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def before_test(self, runner) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations before test.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the testing process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def after_test(self, runner) -> None:
|
||||||
|
"""All subclasses should override this method, if they need any
|
||||||
|
operations after test.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the testing process.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
def before_epoch(self, runner) -> None:
|
def before_epoch(self, runner) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations before each epoch.
|
operations before each epoch.
|
||||||
@ -64,7 +120,9 @@ class Hook:
|
|||||||
def after_iter(self,
|
def after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
outputs:
|
||||||
|
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
|
||||||
|
-> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations after each epoch.
|
operations after each epoch.
|
||||||
|
|
||||||
@ -72,8 +130,8 @@ class Hook:
|
|||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||||
Defaults to None.
|
to None.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -184,11 +242,10 @@ class Hook:
|
|||||||
"""
|
"""
|
||||||
self.before_iter(runner, data_batch=None)
|
self.before_iter(runner, data_batch=None)
|
||||||
|
|
||||||
def after_train_iter(
|
def after_train_iter(self,
|
||||||
self,
|
runner,
|
||||||
runner,
|
data_batch: DATA_BATCH = None,
|
||||||
data_batch: DATA_BATCH = None,
|
outputs: Optional[dict] = None) -> None:
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations after each training iteration.
|
operations after each training iteration.
|
||||||
|
|
||||||
@ -196,16 +253,16 @@ class Hook:
|
|||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.after_iter(runner, data_batch=None, outputs=None)
|
self.after_iter(runner, data_batch=None, outputs=None)
|
||||||
|
|
||||||
def after_val_iter(
|
def after_val_iter(self,
|
||||||
self,
|
runner,
|
||||||
runner,
|
data_batch: DATA_BATCH = None,
|
||||||
data_batch: DATA_BATCH = None,
|
outputs: Optional[Sequence[BaseDataSample]] = None) \
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
-> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations after each validation iteration.
|
operations after each validation iteration.
|
||||||
|
|
||||||
@ -213,7 +270,7 @@ class Hook:
|
|||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from
|
outputs (dict or sequence, optional): Outputs from
|
||||||
model. Defaults to None.
|
model. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.after_iter(runner, data_batch=None, outputs=None)
|
self.after_iter(runner, data_batch=None, outputs=None)
|
||||||
@ -230,7 +287,7 @@ class Hook:
|
|||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||||
Data from dataloader. Defaults to None.
|
Data from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.after_iter(runner, data_batch=None, outputs=None)
|
self.after_iter(runner, data_batch=None, outputs=None)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import time
|
import time
|
||||||
from typing import Any, Optional, Sequence, Tuple
|
from typing import Any, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
from mmengine.data import BaseDataSample
|
from mmengine.data import BaseDataSample
|
||||||
from mmengine.registry import HOOKS
|
from mmengine.registry import HOOKS
|
||||||
@ -40,15 +40,17 @@ class IterTimerHook(Hook):
|
|||||||
def after_iter(self,
|
def after_iter(self,
|
||||||
runner,
|
runner,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
outputs:
|
||||||
|
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
|
||||||
|
-> None:
|
||||||
"""Logging time for a iteration and update the time flag.
|
"""Logging time for a iteration and update the time flag.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||||
from dataloader. Defaults to None.
|
from dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||||
Defaults to None.
|
to None.
|
||||||
"""
|
"""
|
||||||
# TODO: update for new logging system
|
# TODO: update for new logging system
|
||||||
runner.log_buffer.update({'time': time.time() - self.t})
|
runner.log_buffer.update({'time': time.time() - self.t})
|
||||||
|
@ -171,18 +171,17 @@ class LoggerHook(Hook):
|
|||||||
if runner.meta is not None:
|
if runner.meta is not None:
|
||||||
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
|
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
|
||||||
|
|
||||||
def after_train_iter(
|
def after_train_iter(self,
|
||||||
self,
|
runner,
|
||||||
runner,
|
data_batch: DATA_BATCH = None,
|
||||||
data_batch: DATA_BATCH = None,
|
outputs: Optional[dict] = None) -> None:
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
|
||||||
"""Record training logs.
|
"""Record training logs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||||
dataloader. Defaults to None.
|
dataloader. Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
if runner.meta is not None and 'exp_name' in runner.meta:
|
||||||
|
@ -56,11 +56,10 @@ class OptimizerHook(Hook):
|
|||||||
return clip_grad.clip_grad_norm_(params, **self.grad_clip)
|
return clip_grad.clip_grad_norm_(params, **self.grad_clip)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def after_train_iter(
|
def after_train_iter(self,
|
||||||
self,
|
runner,
|
||||||
runner,
|
data_batch: DATA_BATCH = None,
|
||||||
data_batch: DATA_BATCH = None,
|
outputs: Optional[dict] = None) -> None:
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
|
||||||
"""All operations need to be finished after each training iteration.
|
"""All operations need to be finished after each training iteration.
|
||||||
|
|
||||||
This function will finish following 3 operations:
|
This function will finish following 3 operations:
|
||||||
@ -80,7 +79,7 @@ class OptimizerHook(Hook):
|
|||||||
from dataloader. In order to keep this interface consistent
|
from dataloader. In order to keep this interface consistent
|
||||||
with other hooks, we keep ``data_batch`` here.
|
with other hooks, we keep ``data_batch`` here.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
In order to keep this interface consistent with other hooks,
|
In order to keep this interface consistent with other hooks,
|
||||||
we keep ``outputs`` here. Defaults to None.
|
we keep ``outputs`` here. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook):
|
|||||||
|
|
||||||
priority = 'LOW'
|
priority = 'LOW'
|
||||||
|
|
||||||
def after_train_iter(
|
def after_train_iter(self,
|
||||||
self,
|
runner,
|
||||||
runner,
|
data_batch: DATA_BATCH = None,
|
||||||
data_batch: DATA_BATCH = None,
|
outputs: Optional[dict] = None) -> None:
|
||||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
|
||||||
"""Call step function for each scheduler after each iteration.
|
"""Call step function for each scheduler after each iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook):
|
|||||||
from dataloader. In order to keep this interface consistent
|
from dataloader. In order to keep this interface consistent
|
||||||
with other hooks, we keep ``data_batch`` here.
|
with other hooks, we keep ``data_batch`` here.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
In order to keep this interface consistent with other hooks, we
|
In order to keep this interface consistent with other hooks, we
|
||||||
keep ``data_batch`` here. Defaults to None.
|
keep ``data_batch`` here. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user