[Fix] Fix output argument of after_iter, train_after_ter and val_after_iter (#115)

* Fix hook

* Fix

* Fix docs

* FIx

* Fix

* Fix as comment
This commit is contained in:
Mashiro 2022-03-09 23:10:19 +08:00 committed by GitHub
parent 3bdd27c4e2
commit ec3034b765
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 106 additions and 49 deletions

View File

@ -168,18 +168,17 @@ class CheckpointHook(Hook):
else: else:
break break
def after_train_iter( def after_train_iter(self,
self, runner,
runner, data_batch: DATA_BATCH = None,
data_batch: DATA_BATCH = None, outputs=Optional[dict]) -> None:
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration. """Save the checkpoint and synchronize buffers after each iteration.
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None. from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict, optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
if self.by_epoch: if self.by_epoch:

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple, Union
import torch import torch
@ -37,14 +37,16 @@ class EmptyCacheHook(Hook):
def after_iter(self, def after_iter(self,
runner, runner,
data_batch: DATA_BATCH = None, data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None)\
-> None:
"""Empty cache after an iteration. """Empty cache after an iteration.
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None. from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model. outputs (dict or sequence, optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
if self._after_iter: if self._after_iter:

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
@ -19,7 +19,8 @@ class Hook:
operations before the training process. operations before the training process.
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training/validation/testing
process.
""" """
pass pass
@ -27,11 +28,66 @@ class Hook:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after the training process. operations after the training process.
Args:
runner (Runner): The runner of the training/validation/testing
process.
"""
pass
def before_train(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before train.
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
""" """
pass pass
def after_train(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after train.
Args:
runner (Runner): The runner of the training process.
"""
pass
def before_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before val.
Args:
runner (Runner): The runner of the validation process.
"""
pass
def after_val(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after val.
Args:
runner (Runner): The runner of the validation process.
"""
pass
def before_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before test.
Args:
runner (Runner): The runner of the testing process.
"""
pass
def after_test(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after test.
Args:
runner (Runner): The runner of the testing process.
"""
pass
def before_epoch(self, runner) -> None: def before_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations before each epoch. operations before each epoch.
@ -64,7 +120,9 @@ class Hook:
def after_iter(self, def after_iter(self,
runner, runner,
data_batch: DATA_BATCH = None, data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
-> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each epoch. operations after each epoch.
@ -72,8 +130,8 @@ class Hook:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict or sequence, optional): Outputs from model. Defaults
Defaults to None. to None.
""" """
pass pass
@ -184,11 +242,10 @@ class Hook:
""" """
self.before_iter(runner, data_batch=None) self.before_iter(runner, data_batch=None)
def after_train_iter( def after_train_iter(self,
self, runner,
runner, data_batch: DATA_BATCH = None,
data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None:
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each training iteration. operations after each training iteration.
@ -196,16 +253,16 @@ class Hook:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict, optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
self.after_iter(runner, data_batch=None, outputs=None) self.after_iter(runner, data_batch=None, outputs=None)
def after_val_iter( def after_val_iter(self,
self, runner,
runner, data_batch: DATA_BATCH = None,
data_batch: DATA_BATCH = None, outputs: Optional[Sequence[BaseDataSample]] = None) \
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each validation iteration. operations after each validation iteration.
@ -213,7 +270,7 @@ class Hook:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from outputs (dict or sequence, optional): Outputs from
model. Defaults to None. model. Defaults to None.
""" """
self.after_iter(runner, data_batch=None, outputs=None) self.after_iter(runner, data_batch=None, outputs=None)
@ -230,7 +287,7 @@ class Hook:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict, optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
self.after_iter(runner, data_batch=None, outputs=None) self.after_iter(runner, data_batch=None, outputs=None)

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import time import time
from typing import Any, Optional, Sequence, Tuple from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
@ -40,15 +40,17 @@ class IterTimerHook(Hook):
def after_iter(self, def after_iter(self,
runner, runner,
data_batch: DATA_BATCH = None, data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs:
Optional[Union[dict, Sequence[BaseDataSample]]] = None) \
-> None:
"""Logging time for a iteration and update the time flag. """Logging time for a iteration and update the time flag.
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None. from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model. outputs (dict or sequence, optional): Outputs from model. Defaults
Defaults to None. to None.
""" """
# TODO: update for new logging system # TODO: update for new logging system
runner.log_buffer.update({'time': time.time() - self.t}) runner.log_buffer.update({'time': time.time() - self.t})

View File

@ -171,18 +171,17 @@ class LoggerHook(Hook):
if runner.meta is not None: if runner.meta is not None:
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path) runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
def after_train_iter( def after_train_iter(self,
self, runner,
runner, data_batch: DATA_BATCH = None,
data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None:
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Record training logs. """Record training logs.
Args: Args:
runner (Runner): The runner of the training process. runner (Runner): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None. dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict, optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
if runner.meta is not None and 'exp_name' in runner.meta: if runner.meta is not None and 'exp_name' in runner.meta:

View File

@ -56,11 +56,10 @@ class OptimizerHook(Hook):
return clip_grad.clip_grad_norm_(params, **self.grad_clip) return clip_grad.clip_grad_norm_(params, **self.grad_clip)
return None return None
def after_train_iter( def after_train_iter(self,
self, runner,
runner, data_batch: DATA_BATCH = None,
data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None:
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All operations need to be finished after each training iteration. """All operations need to be finished after each training iteration.
This function will finish following 3 operations: This function will finish following 3 operations:
@ -80,7 +79,7 @@ class OptimizerHook(Hook):
from dataloader. In order to keep this interface consistent from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here. with other hooks, we keep ``data_batch`` here.
Defaults to None. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks, In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None. we keep ``outputs`` here. Defaults to None.
""" """

View File

@ -15,11 +15,10 @@ class ParamSchedulerHook(Hook):
priority = 'LOW' priority = 'LOW'
def after_train_iter( def after_train_iter(self,
self, runner,
runner, data_batch: DATA_BATCH = None,
data_batch: DATA_BATCH = None, outputs: Optional[dict] = None) -> None:
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration. """Call step function for each scheduler after each iteration.
Args: Args:
@ -28,7 +27,7 @@ class ParamSchedulerHook(Hook):
from dataloader. In order to keep this interface consistent from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here. with other hooks, we keep ``data_batch`` here.
Defaults to None. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks, we In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None. keep ``data_batch`` here. Defaults to None.
""" """