Mirror of https://github.com/open-mmlab/mmengine.git
fix type hint in hooks (#106)
Parent: 9f0d1a9628
Commit: ed8dcb4c61
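The same two changes repeat in every hook file below: the `runner: object` annotation is dropped so attribute access on the runner no longer needs `# type: ignore`, and the long inline `Optional[Sequence[Tuple[Any, BaseDataSample]]]` annotation on `data_batch` is replaced by a module-level `DATA_BATCH` alias. A minimal sketch of the new style (the hook class name here is invented; the alias and the signature shapes are taken from the hunks below):

from typing import Any, Optional, Sequence, Tuple

from mmengine.data import BaseDataSample

# Alias added to each hook module by this commit.
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]


class ExampleHook:
    # Before: after_iter(self, runner: object, data_batch: Optional[...] = None, ...)
    # After: runner is left unannotated and data_batch uses the alias.
    def after_iter(self,
                   runner,
                   data_batch: DATA_BATCH = None,
                   outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        pass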
@@ -9,6 +9,8 @@ from mmengine.fileio import FileClient
from mmengine.registry import HOOKS
from .hook import Hook

+ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

@HOOKS.register_module()
class CheckpointHook(Hook):
@@ -65,7 +67,7 @@ class CheckpointHook(Hook):
self.sync_buffer = sync_buffer
self.file_client_args = file_client_args

- def before_run(self, runner: object) -> None:
+ def before_run(self, runner) -> None:
"""Finish all operations, related to checkpoint.

This function will get the appropriate file client, and the directory
@@ -75,7 +77,7 @@ class CheckpointHook(Hook):
runner (Runner): The runner of the training process.
"""
if not self.out_dir:
- self.out_dir = runner.work_dir # type: ignore
+ self.out_dir = runner.work_dir

self.file_client = FileClient.infer_client(self.file_client_args,
self.out_dir)
@@ -84,16 +86,12 @@ class CheckpointHook(Hook):
# `self.out_dir` is set so the final `self.out_dir` is the
# concatenation of `self.out_dir` and the last level directory of
# `runner.work_dir`
- if self.out_dir != runner.work_dir: # type: ignore
- basename = osp.basename(
- runner.work_dir.rstrip( # type: ignore
- osp.sep))
+ if self.out_dir != runner.work_dir:
+ basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(
- self.out_dir, # type: ignore
- basename)
+ self.out_dir, basename) # type: ignore # noqa: E501

- runner.logger.info(( # type: ignore
- f'Checkpoints will be saved to {self.out_dir} by '
+ runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
f'{self.file_client.name}.'))

# disable the create_symlink option because some file backends do not
@@ -109,7 +107,7 @@ class CheckpointHook(Hook):
else:
self.args['create_symlink'] = self.file_client.allow_symlink

- def after_train_epoch(self, runner: object) -> None:
+ def after_train_epoch(self, runner) -> None:
"""Save the checkpoint and synchronize buffers after each epoch.

Args:
@@ -124,46 +122,40 @@ class CheckpointHook(Hook):
if self.every_n_epochs(
runner, self.interval) or (self.save_last
and self.is_last_epoch(runner)):
- runner.logger.info( # type: ignore
- f'Saving checkpoint at \
- {runner.epoch + 1} epochs') # type: ignore
+ runner.logger.info(f'Saving checkpoint at \
+ {runner.epoch + 1} epochs')
if self.sync_buffer:
pass
# TODO
self._save_checkpoint(runner)

# TODO Add master_only decorator
- def _save_checkpoint(self, runner: object) -> None:
+ def _save_checkpoint(self, runner) -> None:
"""Save the current checkpoint and delete outdated checkpoint.

Args:
runner (Runner): The runner of the training process.
"""
- runner.save_checkpoint( # type: ignore
- self.out_dir,
- save_optimizer=self.save_optimizer,
- **self.args)
- if runner.meta is not None: # type: ignore
+ runner.save_checkpoint(
+ self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+ if runner.meta is not None:
if self.by_epoch:
cur_ckpt_filename = self.args.get(
- 'filename_tmpl',
- 'epoch_{}.pth').format(runner.epoch + 1) # type: ignore
+ 'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
else:
cur_ckpt_filename = self.args.get(
- 'filename_tmpl',
- 'iter_{}.pth').format(runner.iter + 1) # type: ignore
- runner.meta.setdefault('hook_msgs', dict()) # type: ignore
- runner.meta['hook_msgs'][ # type: ignore
- 'last_ckpt'] = self.file_client.join_path(
+ 'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+ runner.meta.setdefault('hook_msgs', dict())
+ runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
self.out_dir, cur_ckpt_filename) # type: ignore
# remove other checkpoints
if self.max_keep_ckpts > 0:
if self.by_epoch:
name = 'epoch_{}.pth'
- current_ckpt = runner.epoch + 1 # type: ignore
+ current_ckpt = runner.epoch + 1
else:
name = 'iter_{}.pth'
- current_ckpt = runner.iter + 1 # type: ignore
+ current_ckpt = runner.iter + 1
redundant_ckpts = range(
current_ckpt - self.max_keep_ckpts * self.interval, 0,
-self.interval)
@@ -178,8 +170,8 @@ class CheckpointHook(Hook):

def after_train_iter(
self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.

@@ -199,9 +191,8 @@ class CheckpointHook(Hook):
if self.every_n_iters(
runner, self.interval) or (self.save_last
and self.is_last_iter(runner)):
- runner.logger.info( # type: ignore
- f'Saving checkpoint at \
- {runner.iter + 1} iterations') # type: ignore
+ runner.logger.info(f'Saving checkpoint at \
+ {runner.iter + 1} iterations')
if self.sync_buffer:
pass
# TODO
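For orientation, a hedged sketch of how the checkpoint hook above is usually configured. The keyword names mirror the attributes used in the hunks (interval, by_epoch, save_optimizer, max_keep_ckpts, save_last, sync_buffer, out_dir, file_client_args); the exact constructor signature and defaults are an assumption here, not something this diff shows.

# Hypothetical config-style usage; values are illustrative only.
checkpoint_cfg = dict(
    type='CheckpointHook',    # looked up through the HOOKS registry
    interval=1,               # used by every_n_epochs()/every_n_iters()
    by_epoch=True,            # picks epoch_{}.pth vs iter_{}.pth filename templates
    save_optimizer=True,      # forwarded to runner.save_checkpoint()
    max_keep_ckpts=3,         # older checkpoints beyond this window get removed
    save_last=True,           # also save on is_last_epoch()/is_last_iter()
    sync_buffer=False,        # buffer syncing is still a TODO in the hunks above
    out_dir=None,             # falls back to runner.work_dir in before_run()
    file_client_args=None)    # forwarded to FileClient.infer_client()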
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
- from typing import Optional, Sequence
+ from typing import Any, Optional, Sequence, Tuple

import torch

@@ -7,6 +7,8 @@ from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook

+ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

@HOOKS.register_module()
class EmptyCacheHook(Hook):
@@ -33,35 +35,35 @@ class EmptyCacheHook(Hook):
self._after_iter = after_iter

def after_iter(self,
- runner: object,
- data_batch: Optional[Sequence[BaseDataSample]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Empty cache after an iteration.

Args:
- runner (object): The runner of the training process.
- data_batch (Sequence[BaseDataSample]): Data from dataloader.
- Defaults to None.
+ runner (Runner): The runner of the training process.
+ data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+ from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
"""
if self._after_iter:
torch.cuda.empty_cache()

- def before_epoch(self, runner: object) -> None:
+ def before_epoch(self, runner) -> None:
"""Empty cache before an epoch.

Args:
- runner (object): The runner of the training process.
+ runner (Runner): The runner of the training process.
"""
if self._before_epoch:
torch.cuda.empty_cache()

- def after_epoch(self, runner: object) -> None:
+ def after_epoch(self, runner) -> None:
"""Empty cache after an epoch.

Args:
- runner (object): The runner of the training process.
+ runner (Runner): The runner of the training process.
"""
if self._after_epoch:
torch.cuda.empty_cache()
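The three flags stored above only control when torch.cuda.empty_cache() runs. A small sketch, assuming the constructor keywords match the attribute names (_before_epoch, _after_epoch, _after_iter) seen in the hunks:

# Hypothetical usage: release cached CUDA blocks once per epoch only.
empty_cache_cfg = dict(
    type='EmptyCacheHook',
    before_epoch=False,   # stored as self._before_epoch
    after_epoch=True,     # stored as self._after_epoch -> torch.cuda.empty_cache()
    after_iter=False)     # per-iteration emptying is rarely worth the overhead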
@@ -3,6 +3,8 @@ from typing import Any, Optional, Sequence, Tuple

from mmengine.data import BaseDataSample

+ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

class Hook:
"""Base hook class.
@@ -12,7 +14,7 @@ class Hook:

priority = 'NORMAL'

- def before_run(self, runner: object) -> None:
+ def before_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before the training process.

@@ -21,7 +23,7 @@ class Hook:
"""
pass

- def after_run(self, runner: object) -> None:
+ def after_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after the training process.

@@ -30,7 +32,7 @@ class Hook:
"""
pass

- def before_epoch(self, runner: object) -> None:
+ def before_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each epoch.

@@ -39,7 +41,7 @@ class Hook:
"""
pass

- def after_epoch(self, runner: object) -> None:
+ def after_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.

@@ -48,11 +50,7 @@ class Hook:
"""
pass

- def before_iter(
- self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any,
- BaseDataSample]]] = None) -> None:
+ def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each iter.

@@ -64,9 +62,8 @@ class Hook:
pass

def after_iter(self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[
- Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
@@ -80,7 +77,7 @@ class Hook:
"""
pass

- def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None:
+ def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
operations before saving the checkpoint.

@@ -90,7 +87,7 @@ class Hook:
"""
pass

- def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None:
+ def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
"""All subclasses should override this method, if they need any
operations after loading the checkpoint.

@@ -100,7 +97,7 @@ class Hook:
"""
pass

- def before_train_epoch(self, runner: object) -> None:
+ def before_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each training epoch.

@@ -109,7 +106,7 @@ class Hook:
"""
self.before_epoch(runner)

- def before_val_epoch(self, runner: object) -> None:
+ def before_val_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each validation epoch.

@@ -118,7 +115,7 @@ class Hook:
"""
self.before_epoch(runner)

- def before_test_epoch(self, runner: object) -> None:
+ def before_test_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations before each test epoch.

@@ -127,7 +124,7 @@ class Hook:
"""
self.before_epoch(runner)

- def after_train_epoch(self, runner: object) -> None:
+ def after_train_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each training epoch.

@@ -136,7 +133,7 @@ class Hook:
"""
self.after_epoch(runner)

- def after_val_epoch(self, runner: object) -> None:
+ def after_val_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each validation epoch.

@@ -145,7 +142,7 @@ class Hook:
"""
self.after_epoch(runner)

- def after_test_epoch(self, runner: object) -> None:
+ def after_test_epoch(self, runner) -> None:
"""All subclasses should override this method, if they need any
operations after each test epoch.

@@ -154,11 +151,7 @@ class Hook:
"""
self.after_epoch(runner)

- def before_train_iter(
- self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any,
- BaseDataSample]]] = None) -> None:
+ def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each training iteration.

@@ -169,11 +162,7 @@ class Hook:
"""
self.before_iter(runner, data_batch=None)

- def before_val_iter(
- self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any,
- BaseDataSample]]] = None) -> None:
+ def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each validation iteration.

@@ -184,11 +173,7 @@ class Hook:
"""
self.before_iter(runner, data_batch=None)

- def before_test_iter(
- self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any,
- BaseDataSample]]] = None) -> None:
+ def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""All subclasses should override this method, if they need any
operations before each test iteration.

@@ -201,8 +186,8 @@ class Hook:

def after_train_iter(
self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each training iteration.
@@ -218,8 +203,8 @@ class Hook:

def after_val_iter(
self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation iteration.
@@ -235,8 +220,8 @@ class Hook:

def after_test_iter(
self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test iteration.
@@ -250,7 +235,7 @@ class Hook:
"""
self.after_iter(runner, data_batch=None, outputs=None)

- def every_n_epochs(self, runner: object, n: int) -> bool:
+ def every_n_epochs(self, runner, n: int) -> bool:
"""Test whether or not current epoch can be evenly divided by n.

Args:
@@ -260,9 +245,9 @@ class Hook:
Returns:
bool: whether or not current epoch can be evenly divided by n.
"""
- return (runner.epoch + 1) % n == 0 if n > 0 else False # type: ignore
+ return (runner.epoch + 1) % n == 0 if n > 0 else False

- def every_n_inner_iters(self, runner: object, n: int) -> bool:
+ def every_n_inner_iters(self, runner, n: int) -> bool:
"""Test whether or not current inner iteration can be evenly divided by
n.

@@ -275,10 +260,9 @@ class Hook:
bool: whether or not current inner iteration can be evenly
divided by n.
"""
- return (runner.inner_iter + # type: ignore
- 1) % n == 0 if n > 0 else False
+ return (runner.inner_iter + 1) % n == 0 if n > 0 else False

- def every_n_iters(self, runner: object, n: int) -> bool:
+ def every_n_iters(self, runner, n: int) -> bool:
"""Test whether or not current iteration can be evenly divided by n.

Args:
@@ -290,9 +274,9 @@ class Hook:
bool: Return True if the current iteration can be evenly divided
by n, otherwise False.
"""
- return (runner.iter + 1) % n == 0 if n > 0 else False # type: ignore
+ return (runner.iter + 1) % n == 0 if n > 0 else False

- def end_of_epoch(self, runner: object) -> bool:
+ def end_of_epoch(self, runner) -> bool:
"""Check whether the current epoch reaches the `max_epochs` or not.

Args:
@@ -301,9 +285,9 @@ class Hook:
Returns:
bool: whether the end of current epoch or not.
"""
- return runner.inner_iter + 1 == len(runner.data_loader) # type: ignore
+ return runner.inner_iter + 1 == len(runner.data_loader)

- def is_last_epoch(self, runner: object) -> bool:
+ def is_last_epoch(self, runner) -> bool:
"""Test whether or not current epoch is the last epoch.

Args:
@@ -313,9 +297,9 @@ class Hook:
bool: bool: Return True if the current epoch reaches the
`max_epochs`, otherwise False.
"""
- return runner.epoch + 1 == runner._max_epochs # type: ignore
+ return runner.epoch + 1 == runner._max_epochs

- def is_last_iter(self, runner: object) -> bool:
+ def is_last_iter(self, runner) -> bool:
"""Test whether or not current epoch is the last iteration.

Args:
@@ -324,4 +308,4 @@ class Hook:
Returns:
bool: whether or not current iteration is the last iteration.
"""
- return runner.iter + 1 == runner._max_iters # type: ignore
+ return runner.iter + 1 == runner._max_iters
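Every class in this commit derives from the base Hook above, so a minimal custom hook written against the new signatures may be the clearest summary. The registry decorator, the DATA_BATCH alias and the helper methods come from the hunks; the hook itself (LossPrinterHook) is invented purely for illustration.

from typing import Any, Optional, Sequence, Tuple

from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook

DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]


@HOOKS.register_module()
class LossPrinterHook(Hook):
    """Illustrative hook: log the loss every ``interval`` training iterations."""

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        # every_n_iters() and runner.logger are used the same way by the hooks
        # in this diff (see CheckpointHook.after_train_iter above).
        if self.every_n_iters(runner, self.interval):
            runner.logger.info(
                f"iter {runner.iter + 1}: loss={runner.outputs.get('loss', 'n/a')}")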
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
import time
- from typing import Optional, Sequence
+ from typing import Any, Optional, Sequence, Tuple

from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook

+ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

@HOOKS.register_module()
class IterTimerHook(Hook):
@@ -16,45 +18,38 @@ class IterTimerHook(Hook):

priority = 'NORMAL'

- def before_epoch(self, runner: object) -> None:
+ def before_epoch(self, runner) -> None:
"""Record time flag before start a epoch.

Args:
- runner (object): The runner of the training process.
+ runner (Runner): The runner of the training process.
"""
self.t = time.time()

- def before_iter(
- self,
- runner: object,
- data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+ def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
"""Logging time for loading data and update the time flag.

Args:
- runner (object): The runner of the training process.
- data_batch (Sequence[BaseDataSample]): Data from dataloader.
- Defaults to None.
+ runner (Runner): The runner of the training process.
+ data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+ from dataloader. Defaults to None.
"""
# TODO: update for new logging system
- runner.log_buffer.update({ # type: ignore
- 'data_time': time.time() - self.t
- })
+ runner.log_buffer.update({'data_time': time.time() - self.t})

def after_iter(self,
- runner: object,
- data_batch: Optional[Sequence[BaseDataSample]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Logging time for a iteration and update the time flag.

Args:
- runner (object): The runner of the training process.
- data_batch (Sequence[BaseDataSample]): Data from dataloader.
- Defaults to None.
+ runner (Runner): The runner of the training process.
+ data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+ from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
Defaults to None.
"""
# TODO: update for new logging system
- runner.log_buffer.update({ # type: ignore
- 'time': time.time() - self.t
- })
+ runner.log_buffer.update({'time': time.time() - self.t})
self.t = time.time()
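The timing hook above boils down to: stamp time.time() before an iteration, log the delta after it. A standalone, framework-free sketch of the same bookkeeping:

import time

t = time.time()
batch = list(range(100_000))      # stand-in for fetching a batch from the dataloader
data_time = time.time() - t       # what before_iter() logs as 'data_time'
_ = sum(batch)                    # stand-in for the actual train step
iter_time = time.time() - t       # what after_iter() logs as 'time'
t = time.time()                   # reset the flag, as after_iter() does
print(f'data_time={data_time:.6f}s  time={iter_time:.6f}s')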
@@ -10,6 +10,8 @@ from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook

+ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

@HOOKS.register_module()
class OptimizerHook(Hook):
@@ -56,8 +58,8 @@ class OptimizerHook(Hook):

def after_train_iter(
self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All operations need to be finished after each training iteration.

@@ -82,32 +84,27 @@ class OptimizerHook(Hook):
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.
"""
- runner.optimizer.zero_grad() # type: ignore
+ runner.optimizer.zero_grad()
if self.detect_anomalous_params:
- self.detect_anomalous_parameters(
- runner.outputs['loss'], # type: ignore
- runner)
- runner.outputs['loss'].backward() # type: ignore
+ self.detect_anomalous_parameters(runner.outputs['loss'], runner)
+ runner.outputs['loss'].backward()

if self.grad_clip is not None:
- grad_norm = self.clip_grads(
- runner.model.parameters()) # type: ignore
+ grad_norm = self.clip_grads(runner.model.parameters())
if grad_norm is not None:
# Add grad norm to the logger
- runner.log_buffer.update( # type: ignore
- {'grad_norm': float(grad_norm)},
- runner.outputs['num_samples']) # type: ignore
- runner.optimizer.step() # type: ignore
+ runner.log_buffer.update({'grad_norm': float(grad_norm)},
+ runner.outputs['num_samples'])
+ runner.optimizer.step()

- def detect_anomalous_parameters(self, loss: torch.Tensor,
- runner: object) -> None:
+ def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
"""Detect anomalous parameters that are not included in the graph.

Args:
loss (torch.Tensor): The loss of current iteration.
runner (Runner): The runner of the training process.
"""
- logger = runner.logger # type: ignore
+ logger = runner.logger
parameters_in_graph = set()
visited = set()

@@ -125,7 +122,7 @@ class OptimizerHook(Hook):
traverse(grad_fn)

traverse(loss.grad_fn)
- for n, p in runner.model.named_parameters(): # type: ignore
+ for n, p in runner.model.named_parameters():
if p not in parameters_in_graph and p.requires_grad:
logger.log(
level=logging.ERROR,
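The clipping branch above only runs when grad_clip is configured, and the returned norm is pushed into the log buffer. A plain-PyTorch sketch of what self.clip_grads(runner.model.parameters()) typically amounts to (the max_norm/norm_type values are illustrative, not taken from this diff):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# Clip only parameters that actually received gradients, then keep the norm
# around for logging, mirroring the grad_norm -> log_buffer flow above.
params = [p for p in model.parameters() if p.requires_grad and p.grad is not None]
grad_norm = nn.utils.clip_grad_norm_(params, max_norm=35, norm_type=2)
print(float(grad_norm))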
@@ -5,6 +5,8 @@ from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
from .hook import Hook

+ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+

@HOOKS.register_module()
class ParamSchedulerHook(Hook):
@@ -15,8 +17,8 @@ class ParamSchedulerHook(Hook):

def after_train_iter(
self,
- runner: object,
- data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+ runner,
+ data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration.

@@ -30,16 +32,16 @@ class ParamSchedulerHook(Hook):
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
- for scheduler in runner.schedulers: # type: ignore
+ for scheduler in runner.schedulers:
if not scheduler.by_epoch:
scheduler.step()

- def after_train_epoch(self, runner: object) -> None:
+ def after_train_epoch(self, runner) -> None:
"""Call step function for each scheduler after each epoch.

Args:
runner (Runner): The runner of the training process.
"""
- for scheduler in runner.schedulers: # type: ignore
+ for scheduler in runner.schedulers:
if scheduler.by_epoch:
scheduler.step()
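The two methods above dispatch on a by_epoch flag carried by each mmengine scheduler. A plain-PyTorch sketch of the same split, with the flag attached manually since stock torch schedulers do not have it:

import torch
import torch.nn as nn

model = nn.Linear(2, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(optim, step_size=10)
sched.by_epoch = True          # assumed flag; mmengine schedulers provide it natively
schedulers = [sched]

# after_train_iter(): step only the per-iteration schedulers.
for s in schedulers:
    if not s.by_epoch:
        s.step()

# after_train_epoch(): step the per-epoch schedulers.
optim.step()                   # avoid the "scheduler before optimizer" warning
for s in schedulers:
    if s.by_epoch:
        s.step()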
@@ -14,18 +14,15 @@ class DistSamplerSeedHook(Hook):

priority = 'NORMAL'

- def before_epoch(self, runner: object) -> None:
+ def before_epoch(self, runner) -> None:
"""Set the seed for sampler and batch_sampler.

Args:
runner (Runner): The runner of the training process.
"""
- if hasattr(runner.data_loader.sampler, 'set_epoch'): # type: ignore
+ if hasattr(runner.data_loader.sampler, 'set_epoch'):
# in case the data loader uses `SequentialSampler` in Pytorch
- runner.data_loader.sampler.set_epoch(runner.epoch) # type: ignore
- elif hasattr(
- runner.data_loader.batch_sampler.sampler, # type: ignore
- 'set_epoch'):
+ runner.data_loader.sampler.set_epoch(runner.epoch)
+ elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
# batch sampler in pytorch warps the sampler as its attributes.
- runner.data_loader.batch_sampler.sampler.set_epoch( # type: ignore
- runner.epoch) # type: ignore
+ runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)
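The hasattr checks above exist because only samplers such as torch's DistributedSampler expose set_epoch; calling it once per epoch is what gives each epoch a different shuffle order across ranks. A single-process sketch of the underlying PyTorch call:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(16))
# num_replicas/rank passed explicitly so no process group is needed for the demo.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(2):
    # The call DistSamplerSeedHook makes in before_epoch(); skipping it would
    # replay the same shuffled order every epoch.
    sampler.set_epoch(epoch)
    for batch in loader:
        pass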
@@ -89,11 +89,11 @@ class SyncBuffersHook(Hook):
def __init__(self) -> None:
self.distributed = dist.IS_DIST

- def after_epoch(self, runner: object) -> None:
+ def after_epoch(self, runner) -> None:
"""All-reduce model buffers at the end of each epoch.

Args:
- runner (object): The runner of the training process.
+ runner (Runner): The runner of the training process.
"""
if self.distributed:
- allreduce_params(runner.model.buffers()) # type: ignore
+ allreduce_params(runner.model.buffers())
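allreduce_params(runner.model.buffers()) above synchronises buffers (for example BatchNorm running statistics), not gradients, across processes. A hedged sketch of roughly equivalent raw torch.distributed calls, assuming a process group is already initialised and that the helper averages the values:

import torch.distributed as dist


def allreduce_buffers(model) -> None:
    """Average every floating-point buffer (e.g. BN running_mean/var) across ranks."""
    world_size = dist.get_world_size()
    for buf in model.buffers():
        if buf.is_floating_point():   # skip integer buffers such as num_batches_tracked
            dist.all_reduce(buf, op=dist.ReduceOp.SUM)
            buf.div_(world_size)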