open-mmlab/mmengine: fix type hint in hooks (#106)

commit ed8dcb4c61 (parent 9f0d1a9628)
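The commit does two things throughout mmengine's hooks: it drops the `runner: object` annotations, which had forced a `# type: ignore` onto almost every `runner` attribute access, and it factors the repeated `Optional[Sequence[Tuple[Any, BaseDataSample]]]` annotation into a module-level `DATA_BATCH` alias. A minimal self-contained sketch of the resulting pattern (with `BaseDataSample` stubbed in, since the real class lives in `mmengine.data`):

    # Sketch of the typing pattern this commit introduces; BaseDataSample is a
    # stand-in for mmengine.data.BaseDataSample.
    from typing import Any, Optional, Sequence, Tuple


    class BaseDataSample:  # stub for the real mmengine class
        pass


    # One module-level alias instead of repeating the full annotation per method.
    DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]


    class Hook:
        def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
            # `runner` is deliberately unannotated: a type checker treats it as
            # Any, so runner.epoch and friends need no `# type: ignore`.
            pass

Leaving `runner` unannotated makes checkers treat it as `Any`, so attribute access type-checks without suppressions while the informative annotations stay in place.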
Changes in CheckpointHook:

@@ -9,6 +9,8 @@ from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
 from .hook import Hook

+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

 @HOOKS.register_module()
 class CheckpointHook(Hook):
@@ -65,7 +67,7 @@ class CheckpointHook(Hook):
         self.sync_buffer = sync_buffer
         self.file_client_args = file_client_args

-    def before_run(self, runner: object) -> None:
+    def before_run(self, runner) -> None:
         """Finish all operations, related to checkpoint.

         This function will get the appropriate file client, and the directory
@@ -75,7 +77,7 @@ class CheckpointHook(Hook):
             runner (Runner): The runner of the training process.
         """
         if not self.out_dir:
-            self.out_dir = runner.work_dir  # type: ignore
+            self.out_dir = runner.work_dir

         self.file_client = FileClient.infer_client(self.file_client_args,
                                                    self.out_dir)
@@ -84,16 +86,12 @@ class CheckpointHook(Hook):
         # `self.out_dir` is set so the final `self.out_dir` is the
         # concatenation of `self.out_dir` and the last level directory of
         # `runner.work_dir`
-        if self.out_dir != runner.work_dir:  # type: ignore
-            basename = osp.basename(
-                runner.work_dir.rstrip(  # type: ignore
-                    osp.sep))
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(
-                self.out_dir,  # type: ignore
-                basename)
+                self.out_dir, basename)  # type: ignore  # noqa: E501

-        runner.logger.info((  # type: ignore
-            f'Checkpoints will be saved to {self.out_dir} by '
+        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
             f'{self.file_client.name}.'))

         # disable the create_symlink option because some file backends do not
@@ -109,7 +107,7 @@ class CheckpointHook(Hook):
         else:
             self.args['create_symlink'] = self.file_client.allow_symlink

-    def after_train_epoch(self, runner: object) -> None:
+    def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.

         Args:
@@ -124,46 +122,40 @@ class CheckpointHook(Hook):
         if self.every_n_epochs(
                 runner, self.interval) or (self.save_last
                                            and self.is_last_epoch(runner)):
-            runner.logger.info(  # type: ignore
-                f'Saving checkpoint at \
-                {runner.epoch + 1} epochs')  # type: ignore
+            runner.logger.info(f'Saving checkpoint at \
+                {runner.epoch + 1} epochs')
             if self.sync_buffer:
                 pass
                 # TODO
             self._save_checkpoint(runner)

     # TODO Add master_only decorator
-    def _save_checkpoint(self, runner: object) -> None:
+    def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.

         Args:
             runner (Runner): The runner of the training process.
         """
-        runner.save_checkpoint(  # type: ignore
-            self.out_dir,
-            save_optimizer=self.save_optimizer,
-            **self.args)
-        if runner.meta is not None:  # type: ignore
+        runner.save_checkpoint(
+            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+        if runner.meta is not None:
             if self.by_epoch:
                 cur_ckpt_filename = self.args.get(
-                    'filename_tmpl',
-                    'epoch_{}.pth').format(runner.epoch + 1)  # type: ignore
+                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
             else:
                 cur_ckpt_filename = self.args.get(
-                    'filename_tmpl',
-                    'iter_{}.pth').format(runner.iter + 1)  # type: ignore
-            runner.meta.setdefault('hook_msgs', dict())  # type: ignore
-            runner.meta['hook_msgs'][  # type: ignore
-                'last_ckpt'] = self.file_client.join_path(
-                    self.out_dir, cur_ckpt_filename)  # type: ignore
+                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            runner.meta.setdefault('hook_msgs', dict())
+            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
+                self.out_dir, cur_ckpt_filename)  # type: ignore
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
                 name = 'epoch_{}.pth'
-                current_ckpt = runner.epoch + 1  # type: ignore
+                current_ckpt = runner.epoch + 1
             else:
                 name = 'iter_{}.pth'
-                current_ckpt = runner.iter + 1  # type: ignore
+                current_ckpt = runner.iter + 1
             redundant_ckpts = range(
                 current_ckpt - self.max_keep_ckpts * self.interval, 0,
                 -self.interval)
@@ -178,8 +170,8 @@ class CheckpointHook(Hook):

     def after_train_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Save the checkpoint and synchronize buffers after each iteration.

@@ -199,9 +191,8 @@ class CheckpointHook(Hook):
         if self.every_n_iters(
                 runner, self.interval) or (self.save_last
                                            and self.is_last_iter(runner)):
-            runner.logger.info(  # type: ignore
-                f'Saving checkpoint at \
-                {runner.iter + 1} iterations')  # type: ignore
+            runner.logger.info(f'Saving checkpoint at \
+                {runner.iter + 1} iterations')
             if self.sync_buffer:
                 pass
                 # TODO
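Two pieces of arithmetic in `_save_checkpoint` above are worth tracing by hand: the filename template keyed on the (zero-based) epoch or iteration counter, and the `max_keep_ckpts` pruning range. A standalone sketch with illustrative numbers:

    # Filename: runner.epoch is zero-based, so epoch index 11 saves file 12.
    args: dict = {}  # no custom 'filename_tmpl' supplied
    epoch = 11
    print(args.get('filename_tmpl', 'epoch_{}.pth').format(epoch + 1))
    # -> epoch_12.pth

    # Pruning: with interval=2 and max_keep_ckpts=3, everything older than the
    # three most recent checkpoints becomes a removal candidate.
    interval, max_keep_ckpts, current_ckpt = 2, 3, 12
    redundant_ckpts = range(current_ckpt - max_keep_ckpts * interval, 0,
                            -interval)
    print(list(redundant_ckpts))
    # -> [6, 4, 2]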
Changes in EmptyCacheHook:

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple

 import torch

@@ -7,6 +7,8 @@ from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook

+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

 @HOOKS.register_module()
 class EmptyCacheHook(Hook):
@@ -33,35 +35,35 @@ class EmptyCacheHook(Hook):
         self._after_iter = after_iter

     def after_iter(self,
-                   runner: object,
-                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   runner,
+                   data_batch: DATA_BATCH = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Empty cache after an iteration.

         Args:
-            runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample]): Outputs from model.
                 Defaults to None.
         """
         if self._after_iter:
             torch.cuda.empty_cache()

-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """Empty cache before an epoch.

         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         if self._before_epoch:
             torch.cuda.empty_cache()

-    def after_epoch(self, runner: object) -> None:
+    def after_epoch(self, runner) -> None:
         """Empty cache after an epoch.

         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         if self._after_epoch:
             torch.cuda.empty_cache()
Changes in the base Hook class:

@@ -3,6 +3,8 @@ from typing import Any, Optional, Sequence, Tuple

 from mmengine.data import BaseDataSample

+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

 class Hook:
     """Base hook class.
@@ -12,7 +14,7 @@ class Hook:

     priority = 'NORMAL'

-    def before_run(self, runner: object) -> None:
+    def before_run(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before the training process.

@@ -21,7 +23,7 @@ class Hook:
         """
         pass

-    def after_run(self, runner: object) -> None:
+    def after_run(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after the training process.

@@ -30,7 +32,7 @@ class Hook:
         """
         pass

-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each epoch.

@@ -39,7 +41,7 @@ class Hook:
         """
         pass

-    def after_epoch(self, runner: object) -> None:
+    def after_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each epoch.

@@ -48,11 +50,7 @@ class Hook:
         """
         pass

-    def before_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each iter.

@@ -64,9 +62,8 @@ class Hook:
         pass

     def after_iter(self,
-                   runner: object,
-                   data_batch: Optional[Sequence[Tuple[
-                       Any, BaseDataSample]]] = None,
+                   runner,
+                   data_batch: DATA_BATCH = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each epoch.
@@ -80,7 +77,7 @@ class Hook:
         """
         pass

-    def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None:
+    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """All subclasses should override this method, if they need any
         operations before saving the checkpoint.

@@ -90,7 +87,7 @@ class Hook:
         """
         pass

-    def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None:
+    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """All subclasses should override this method, if they need any
         operations after loading the checkpoint.

@@ -100,7 +97,7 @@ class Hook:
         """
         pass

-    def before_train_epoch(self, runner: object) -> None:
+    def before_train_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each training epoch.

@@ -109,7 +106,7 @@ class Hook:
         """
         self.before_epoch(runner)

-    def before_val_epoch(self, runner: object) -> None:
+    def before_val_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation epoch.

@@ -118,7 +115,7 @@ class Hook:
         """
         self.before_epoch(runner)

-    def before_test_epoch(self, runner: object) -> None:
+    def before_test_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each test epoch.

@@ -127,7 +124,7 @@ class Hook:
         """
         self.before_epoch(runner)

-    def after_train_epoch(self, runner: object) -> None:
+    def after_train_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each training epoch.

@@ -136,7 +133,7 @@ class Hook:
         """
         self.after_epoch(runner)

-    def after_val_epoch(self, runner: object) -> None:
+    def after_val_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each validation epoch.

@@ -145,7 +142,7 @@ class Hook:
         """
         self.after_epoch(runner)

-    def after_test_epoch(self, runner: object) -> None:
+    def after_test_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each test epoch.

@@ -154,11 +151,7 @@ class Hook:
         """
         self.after_epoch(runner)

-    def before_train_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each training iteration.

@@ -169,11 +162,7 @@ class Hook:
         """
         self.before_iter(runner, data_batch=None)

-    def before_val_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation iteration.

@@ -184,11 +173,7 @@ class Hook:
         """
         self.before_iter(runner, data_batch=None)

-    def before_test_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each test iteration.

@@ -201,8 +186,8 @@ class Hook:

     def after_train_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each training iteration.
@@ -218,8 +203,8 @@ class Hook:

     def after_val_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each validation iteration.
@@ -235,8 +220,8 @@ class Hook:

     def after_test_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each test iteration.
@@ -250,7 +235,7 @@ class Hook:
         """
         self.after_iter(runner, data_batch=None, outputs=None)

-    def every_n_epochs(self, runner: object, n: int) -> bool:
+    def every_n_epochs(self, runner, n: int) -> bool:
         """Test whether or not current epoch can be evenly divided by n.

         Args:
@@ -260,9 +245,9 @@ class Hook:
         Returns:
             bool: whether or not current epoch can be evenly divided by n.
         """
-        return (runner.epoch + 1) % n == 0 if n > 0 else False  # type: ignore
+        return (runner.epoch + 1) % n == 0 if n > 0 else False

-    def every_n_inner_iters(self, runner: object, n: int) -> bool:
+    def every_n_inner_iters(self, runner, n: int) -> bool:
         """Test whether or not current inner iteration can be evenly divided by
         n.

@@ -275,10 +260,9 @@ class Hook:
             bool: whether or not current inner iteration can be evenly
                 divided by n.
         """
-        return (runner.inner_iter +  # type: ignore
-                1) % n == 0 if n > 0 else False
+        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

-    def every_n_iters(self, runner: object, n: int) -> bool:
+    def every_n_iters(self, runner, n: int) -> bool:
         """Test whether or not current iteration can be evenly divided by n.

         Args:
@@ -290,9 +274,9 @@ class Hook:
             bool: Return True if the current iteration can be evenly divided
                 by n, otherwise False.
         """
-        return (runner.iter + 1) % n == 0 if n > 0 else False  # type: ignore
+        return (runner.iter + 1) % n == 0 if n > 0 else False

-    def end_of_epoch(self, runner: object) -> bool:
+    def end_of_epoch(self, runner) -> bool:
         """Check whether the current epoch reaches the `max_epochs` or not.

         Args:
@@ -301,9 +285,9 @@ class Hook:
         Returns:
             bool: whether the end of current epoch or not.
         """
-        return runner.inner_iter + 1 == len(runner.data_loader)  # type: ignore
+        return runner.inner_iter + 1 == len(runner.data_loader)

-    def is_last_epoch(self, runner: object) -> bool:
+    def is_last_epoch(self, runner) -> bool:
         """Test whether or not current epoch is the last epoch.

         Args:
@@ -313,9 +297,9 @@ class Hook:
             bool: bool: Return True if the current epoch reaches the
                 `max_epochs`, otherwise False.
         """
-        return runner.epoch + 1 == runner._max_epochs  # type: ignore
+        return runner.epoch + 1 == runner._max_epochs

-    def is_last_iter(self, runner: object) -> bool:
+    def is_last_iter(self, runner) -> bool:
         """Test whether or not current epoch is the last iteration.

         Args:
@@ -324,4 +308,4 @@ class Hook:
         Returns:
             bool: whether or not current iteration is the last iteration.
         """
-        return runner.iter + 1 == runner._max_iters  # type: ignore
+        return runner.iter + 1 == runner._max_iters
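With the base class reworked as above, a subclass only spells out the annotations that carry information. A sketch of a downstream hook against the new interface; the hook itself is hypothetical, and the import path assumes the base class module is `mmengine.hooks.hook`, as the `from .hook import Hook` lines in this diff suggest:

    from mmengine.hooks.hook import DATA_BATCH, Hook
    from mmengine.registry import HOOKS


    @HOOKS.register_module()
    class LossPrinterHook(Hook):  # hypothetical example hook
        """Print the loss every `interval` training iterations."""

        def __init__(self, interval: int = 50) -> None:
            self.interval = interval

        def after_train_iter(self,
                             runner,
                             data_batch: DATA_BATCH = None,
                             outputs=None) -> None:
            # every_n_iters is inherited from the Hook base class shown above.
            if self.every_n_iters(runner, self.interval):
                print(f"iter {runner.iter + 1}: loss={runner.outputs['loss']}")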
Changes in IterTimerHook:

@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import time
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple

 from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook

+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

 @HOOKS.register_module()
 class IterTimerHook(Hook):
@@ -16,45 +18,38 @@ class IterTimerHook(Hook):

     priority = 'NORMAL'

-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """Record time flag before start a epoch.

         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         self.t = time.time()

-    def before_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """Logging time for loading data and update the time flag.

         Args:
-            runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
         """
         # TODO: update for new logging system
-        runner.log_buffer.update({  # type: ignore
-            'data_time': time.time() - self.t
-        })
+        runner.log_buffer.update({'data_time': time.time() - self.t})

     def after_iter(self,
-                   runner: object,
-                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   runner,
+                   data_batch: DATA_BATCH = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Logging time for a iteration and update the time flag.

         Args:
-            runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample]): Outputs from model.
                 Defaults to None.
         """
         # TODO: update for new logging system
-        runner.log_buffer.update({  # type: ignore
-            'time': time.time() - self.t
-        })
+        runner.log_buffer.update({'time': time.time() - self.t})
         self.t = time.time()
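IterTimerHook's bookkeeping is a single float: `before_epoch` seeds it, `before_iter` reads it off as data-loading time, and `after_iter` reads it as the full iteration time and resets it. The same accounting outside any runner, as a sketch:

    import time

    t = time.time()              # before_epoch: seed the flag
    # ... the dataloader yields a batch ...
    data_time = time.time() - t  # before_iter: time spent waiting on data
    # ... forward / backward / optimizer step ...
    iter_time = time.time() - t  # after_iter: data time plus compute time
    t = time.time()              # after_iter: reset for the next iteration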
Changes in OptimizerHook:

@@ -10,6 +10,8 @@ from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook

+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

 @HOOKS.register_module()
 class OptimizerHook(Hook):
@@ -56,8 +58,8 @@ class OptimizerHook(Hook):

     def after_train_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All operations need to be finished after each training iteration.

@@ -82,32 +84,27 @@ class OptimizerHook(Hook):
                 In order to keep this interface consistent with other hooks,
                 we keep ``outputs`` here. Defaults to None.
         """
-        runner.optimizer.zero_grad()  # type: ignore
+        runner.optimizer.zero_grad()
         if self.detect_anomalous_params:
-            self.detect_anomalous_parameters(
-                runner.outputs['loss'],  # type: ignore
-                runner)
-        runner.outputs['loss'].backward()  # type: ignore
+            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
+        runner.outputs['loss'].backward()

         if self.grad_clip is not None:
-            grad_norm = self.clip_grads(
-                runner.model.parameters())  # type: ignore
+            grad_norm = self.clip_grads(runner.model.parameters())
             if grad_norm is not None:
                 # Add grad norm to the logger
-                runner.log_buffer.update(  # type: ignore
-                    {'grad_norm': float(grad_norm)},
-                    runner.outputs['num_samples'])  # type: ignore
-        runner.optimizer.step()  # type: ignore
+                runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()

-    def detect_anomalous_parameters(self, loss: torch.Tensor,
-                                    runner: object) -> None:
+    def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
         """Detect anomalous parameters that are not included in the graph.

         Args:
             loss (torch.Tensor): The loss of current iteration.
             runner (Runner): The runner of the training process.
         """
-        logger = runner.logger  # type: ignore
+        logger = runner.logger
         parameters_in_graph = set()
         visited = set()

@@ -125,7 +122,7 @@ class OptimizerHook(Hook):
             traverse(grad_fn)

         traverse(loss.grad_fn)
-        for n, p in runner.model.named_parameters():  # type: ignore
+        for n, p in runner.model.named_parameters():
             if p not in parameters_in_graph and p.requires_grad:
                 logger.log(
                     level=logging.ERROR,
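`detect_anomalous_parameters` walks the autograd graph backwards from `loss.grad_fn` and collects every parameter reachable from the loss; parameters that require gradients but never show up are logged as anomalous. A standalone sketch of that traversal (no Runner needed; leaf tensors surface through the `variable` attribute of `AccumulateGrad` nodes):

    import torch


    def parameters_in_graph(loss: torch.Tensor) -> set:
        found, visited = set(), set()

        def traverse(grad_fn):
            if grad_fn is None or grad_fn in visited:
                return
            visited.add(grad_fn)
            if hasattr(grad_fn, 'variable'):  # AccumulateGrad -> leaf tensor
                found.add(grad_fn.variable)
            for parent, _ in grad_fn.next_functions:
                traverse(parent)

        traverse(loss.grad_fn)
        return found


    used = torch.nn.Parameter(torch.randn(3))
    unused = torch.nn.Parameter(torch.randn(3))  # never touches the loss
    loss = (used * 2.0).sum()
    assert used in parameters_in_graph(loss)
    assert unused not in parameters_in_graph(loss)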
Changes in ParamSchedulerHook:

@@ -5,6 +5,8 @@ from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook

+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

 @HOOKS.register_module()
 class ParamSchedulerHook(Hook):
@@ -15,8 +17,8 @@ class ParamSchedulerHook(Hook):

     def after_train_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Call step function for each scheduler after each iteration.

@@ -30,16 +32,16 @@ class ParamSchedulerHook(Hook):
                 In order to keep this interface consistent with other hooks, we
                 keep ``data_batch`` here. Defaults to None.
         """
-        for scheduler in runner.schedulers:  # type: ignore
+        for scheduler in runner.schedulers:
             if not scheduler.by_epoch:
                 scheduler.step()

-    def after_train_epoch(self, runner: object) -> None:
+    def after_train_epoch(self, runner) -> None:
         """Call step function for each scheduler after each epoch.

         Args:
             runner (Runner): The runner of the training process.
         """
-        for scheduler in runner.schedulers:  # type: ignore
+        for scheduler in runner.schedulers:
             if scheduler.by_epoch:
                 scheduler.step()
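The two ParamSchedulerHook callbacks implement one dispatch rule: schedulers flagged `by_epoch` step once per epoch, all others once per iteration. The rule in isolation, with a stand-in scheduler class since the real ones live elsewhere in mmengine:

    class FakeScheduler:  # stand-in for a real param scheduler
        def __init__(self, by_epoch: bool) -> None:
            self.by_epoch = by_epoch
            self.steps = 0

        def step(self) -> None:
            self.steps += 1


    schedulers = [FakeScheduler(by_epoch=True), FakeScheduler(by_epoch=False)]

    for _ in range(10):                   # 10 iterations -> after_train_iter
        for s in schedulers:
            if not s.by_epoch:
                s.step()
    for s in schedulers:                  # end of epoch -> after_train_epoch
        if s.by_epoch:
            s.step()

    print([s.steps for s in schedulers])  # -> [1, 10]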
Changes in DistSamplerSeedHook:

@@ -14,18 +14,15 @@ class DistSamplerSeedHook(Hook):

     priority = 'NORMAL'

-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """Set the seed for sampler and batch_sampler.

         Args:
             runner (Runner): The runner of the training process.
         """
-        if hasattr(runner.data_loader.sampler, 'set_epoch'):  # type: ignore
+        if hasattr(runner.data_loader.sampler, 'set_epoch'):
             # in case the data loader uses `SequentialSampler` in Pytorch
-            runner.data_loader.sampler.set_epoch(runner.epoch)  # type: ignore
-        elif hasattr(
-                runner.data_loader.batch_sampler.sampler,  # type: ignore
-                'set_epoch'):
+            runner.data_loader.sampler.set_epoch(runner.epoch)
+        elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
             # batch sampler in pytorch warps the sampler as its attributes.
-            runner.data_loader.batch_sampler.sampler.set_epoch(  # type: ignore
-                runner.epoch)  # type: ignore
+            runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
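`set_epoch` matters because `DistributedSampler` derives its shuffle order from `(seed, epoch)`; without the call, every epoch replays the first epoch's order. A runnable sketch (`num_replicas` and `rank` pinned so no process group is needed):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler, batch_size=4)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # what DistSamplerSeedHook.before_epoch does
        print([batch[0].tolist() for batch in loader])  # order differs per epoch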
Changes in SyncBuffersHook:

@@ -89,11 +89,11 @@ class SyncBuffersHook(Hook):
     def __init__(self) -> None:
         self.distributed = dist.IS_DIST

-    def after_epoch(self, runner: object) -> None:
+    def after_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.

         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         if self.distributed:
-            allreduce_params(runner.model.buffers())  # type: ignore
+            allreduce_params(runner.model.buffers())
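The buffers being all-reduced here are module state that gradient synchronization does not cover, most commonly BatchNorm running statistics:

    import torch

    bn = torch.nn.BatchNorm2d(8)
    print([name for name, _ in bn.named_buffers()])
    # -> ['running_mean', 'running_var', 'num_batches_tracked']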