fix type hint in hooks (#106)

Zaida Zhou authored 2022-03-07 19:35:37 +08:00, committed by GitHub
parent 9f0d1a9628
commit ed8dcb4c61
8 changed files with 122 additions and 154 deletions
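
The change is mechanical across all eight files: every hook signature drops the `runner: object` annotation (which forced a `# type: ignore` on nearly every `runner` attribute access) and the long inline `data_batch` annotation is replaced by a module-level alias. A minimal sketch of the pattern, with `MyHook` as a hypothetical illustration:

from typing import Any, Optional, Sequence, Tuple

from mmengine.data import BaseDataSample

DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]


class MyHook:  # hypothetical, for illustration only
    # Before: data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None
    # After: the DATA_BATCH alias, plus an untyped `runner`, so attribute
    # access like `runner.iter` no longer needs `# type: ignore`.
    def after_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
        print(runner.iter)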

View File

@@ -9,6 +9,8 @@ from mmengine.fileio import FileClient
 from mmengine.registry import HOOKS
 from .hook import Hook
 
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
 
 @HOOKS.register_module()
 class CheckpointHook(Hook):
@@ -65,7 +67,7 @@ class CheckpointHook(Hook):
         self.sync_buffer = sync_buffer
         self.file_client_args = file_client_args
 
-    def before_run(self, runner: object) -> None:
+    def before_run(self, runner) -> None:
         """Finish all operations, related to checkpoint.
 
         This function will get the appropriate file client, and the directory
@@ -75,7 +77,7 @@ class CheckpointHook(Hook):
             runner (Runner): The runner of the training process.
         """
         if not self.out_dir:
-            self.out_dir = runner.work_dir  # type: ignore
+            self.out_dir = runner.work_dir
 
         self.file_client = FileClient.infer_client(self.file_client_args,
                                                    self.out_dir)
@@ -84,16 +86,12 @@ class CheckpointHook(Hook):
         # `self.out_dir` is set so the final `self.out_dir` is the
         # concatenation of `self.out_dir` and the last level directory of
         # `runner.work_dir`
-        if self.out_dir != runner.work_dir:  # type: ignore
-            basename = osp.basename(
-                runner.work_dir.rstrip(  # type: ignore
-                    osp.sep))
+        if self.out_dir != runner.work_dir:
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
             self.out_dir = self.file_client.join_path(
-                self.out_dir,  # type: ignore
-                basename)
+                self.out_dir, basename)  # type: ignore  # noqa: E501
 
-        runner.logger.info((  # type: ignore
-            f'Checkpoints will be saved to {self.out_dir} by '
-            f'{self.file_client.name}.'))
+        runner.logger.info((f'Checkpoints will be saved to {self.out_dir} by '
+                            f'{self.file_client.name}.'))
 
         # disable the create_symlink option because some file backends do not
@@ -109,7 +107,7 @@ class CheckpointHook(Hook):
         else:
             self.args['create_symlink'] = self.file_client.allow_symlink
 
-    def after_train_epoch(self, runner: object) -> None:
+    def after_train_epoch(self, runner) -> None:
         """Save the checkpoint and synchronize buffers after each epoch.
 
         Args:
@@ -124,46 +122,40 @@ class CheckpointHook(Hook):
         if self.every_n_epochs(
                 runner, self.interval) or (self.save_last
                                            and self.is_last_epoch(runner)):
-            runner.logger.info(  # type: ignore
-                f'Saving checkpoint at \
-                {runner.epoch + 1} epochs')  # type: ignore
+            runner.logger.info(f'Saving checkpoint at \
+                {runner.epoch + 1} epochs')
             if self.sync_buffer:
                 pass
                 # TODO
             self._save_checkpoint(runner)
 
     # TODO Add master_only decorator
-    def _save_checkpoint(self, runner: object) -> None:
+    def _save_checkpoint(self, runner) -> None:
         """Save the current checkpoint and delete outdated checkpoint.
 
         Args:
             runner (Runner): The runner of the training process.
         """
-        runner.save_checkpoint(  # type: ignore
-            self.out_dir,
-            save_optimizer=self.save_optimizer,
-            **self.args)
-        if runner.meta is not None:  # type: ignore
+        runner.save_checkpoint(
+            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
+        if runner.meta is not None:
             if self.by_epoch:
                 cur_ckpt_filename = self.args.get(
-                    'filename_tmpl',
-                    'epoch_{}.pth').format(runner.epoch + 1)  # type: ignore
+                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
             else:
                 cur_ckpt_filename = self.args.get(
-                    'filename_tmpl',
-                    'iter_{}.pth').format(runner.iter + 1)  # type: ignore
-            runner.meta.setdefault('hook_msgs', dict())  # type: ignore
-            runner.meta['hook_msgs'][  # type: ignore
-                'last_ckpt'] = self.file_client.join_path(
-                    self.out_dir, cur_ckpt_filename)  # type: ignore
+                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
+            runner.meta.setdefault('hook_msgs', dict())
+            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
+                self.out_dir, cur_ckpt_filename)  # type: ignore
         # remove other checkpoints
         if self.max_keep_ckpts > 0:
             if self.by_epoch:
                 name = 'epoch_{}.pth'
-                current_ckpt = runner.epoch + 1  # type: ignore
+                current_ckpt = runner.epoch + 1
             else:
                 name = 'iter_{}.pth'
-                current_ckpt = runner.iter + 1  # type: ignore
+                current_ckpt = runner.iter + 1
             redundant_ckpts = range(
                 current_ckpt - self.max_keep_ckpts * self.interval, 0,
                 -self.interval)
@@ -178,8 +170,8 @@ class CheckpointHook(Hook):
 
     def after_train_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Save the checkpoint and synchronize buffers after each iteration.
@@ -199,9 +191,8 @@ class CheckpointHook(Hook):
         if self.every_n_iters(
                 runner, self.interval) or (self.save_last
                                            and self.is_last_iter(runner)):
-            runner.logger.info(  # type: ignore
-                f'Saving checkpoint at \
-                {runner.iter + 1} iterations')  # type: ignore
+            runner.logger.info(f'Saving checkpoint at \
+                {runner.iter + 1} iterations')
             if self.sync_buffer:
                 pass
                 # TODO
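
The unchanged `redundant_ckpts` arithmetic above is easy to misread, so a small worked example (values chosen purely for illustration):

# With interval=2, max_keep_ckpts=3 and the 10th epoch just saved, only the
# newest three checkpoints (epochs 10, 8, 6) are kept; older ones go:
interval, max_keep_ckpts, current_ckpt = 2, 3, 10
redundant_ckpts = range(current_ckpt - max_keep_ckpts * interval, 0, -interval)
print(list(redundant_ckpts))  # [4, 2] -> epoch_4.pth and epoch_2.pth deleted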

View File

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple
 
 import torch
 
@@ -7,6 +7,8 @@ from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook
 
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
 
 @HOOKS.register_module()
 class EmptyCacheHook(Hook):
@@ -33,35 +35,35 @@ class EmptyCacheHook(Hook):
         self._after_iter = after_iter
 
     def after_iter(self,
-                   runner: object,
-                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   runner,
+                   data_batch: DATA_BATCH = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Empty cache after an iteration.
 
         Args:
-            runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample]): Outputs from model.
                 Defaults to None.
         """
         if self._after_iter:
             torch.cuda.empty_cache()
 
-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """Empty cache before an epoch.
 
         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         if self._before_epoch:
             torch.cuda.empty_cache()
 
-    def after_epoch(self, runner: object) -> None:
+    def after_epoch(self, runner) -> None:
         """Empty cache after an epoch.
 
         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         if self._after_epoch:
             torch.cuda.empty_cache()
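
Since the hook is registered in `HOOKS`, it is normally enabled from a config dict. A hedged example, assuming the usual `custom_hooks` plumbing (not part of this diff) and the constructor flags implied by `self._before_epoch` / `self._after_epoch` / `self._after_iter` above:

# Hypothetical config snippet; flag names mirror the attributes seen above.
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True, after_iter=False)]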

View File

@@ -3,6 +3,8 @@ from typing import Any, Optional, Sequence, Tuple
 
 from mmengine.data import BaseDataSample
 
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
 
 class Hook:
     """Base hook class.
@@ -12,7 +14,7 @@ class Hook:
 
     priority = 'NORMAL'
 
-    def before_run(self, runner: object) -> None:
+    def before_run(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before the training process.
@@ -21,7 +23,7 @@ class Hook:
         """
         pass
 
-    def after_run(self, runner: object) -> None:
+    def after_run(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after the training process.
@@ -30,7 +32,7 @@ class Hook:
         """
         pass
 
-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each epoch.
@@ -39,7 +41,7 @@ class Hook:
         """
         pass
 
-    def after_epoch(self, runner: object) -> None:
+    def after_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each epoch.
@@ -48,11 +50,7 @@ class Hook:
         """
         pass
 
-    def before_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each iter.
@@ -64,9 +62,8 @@ class Hook:
         pass
 
     def after_iter(self,
-                   runner: object,
-                   data_batch: Optional[Sequence[Tuple[
-                       Any, BaseDataSample]]] = None,
+                   runner,
+                   data_batch: DATA_BATCH = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each epoch.
@@ -80,7 +77,7 @@ class Hook:
         """
         pass
 
-    def before_save_checkpoint(self, runner: object, checkpoint: dict) -> None:
+    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
         """All subclasses should override this method, if they need any
         operations before saving the checkpoint.
@@ -90,7 +87,7 @@ class Hook:
         """
         pass
 
-    def after_load_checkpoint(self, runner: object, checkpoint: dict) -> None:
+    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
         """All subclasses should override this method, if they need any
         operations after loading the checkpoint.
@@ -100,7 +97,7 @@ class Hook:
         """
         pass
 
-    def before_train_epoch(self, runner: object) -> None:
+    def before_train_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each training epoch.
@@ -109,7 +106,7 @@ class Hook:
         """
         self.before_epoch(runner)
 
-    def before_val_epoch(self, runner: object) -> None:
+    def before_val_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation epoch.
@@ -118,7 +115,7 @@ class Hook:
         """
         self.before_epoch(runner)
 
-    def before_test_epoch(self, runner: object) -> None:
+    def before_test_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations before each test epoch.
@@ -127,7 +124,7 @@ class Hook:
         """
         self.before_epoch(runner)
 
-    def after_train_epoch(self, runner: object) -> None:
+    def after_train_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each training epoch.
@@ -136,7 +133,7 @@ class Hook:
         """
         self.after_epoch(runner)
 
-    def after_val_epoch(self, runner: object) -> None:
+    def after_val_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each validation epoch.
@@ -145,7 +142,7 @@ class Hook:
         """
         self.after_epoch(runner)
 
-    def after_test_epoch(self, runner: object) -> None:
+    def after_test_epoch(self, runner) -> None:
         """All subclasses should override this method, if they need any
         operations after each test epoch.
@@ -154,11 +151,7 @@ class Hook:
         """
         self.after_epoch(runner)
 
-    def before_train_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_train_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each training iteration.
@@ -169,11 +162,7 @@ class Hook:
         """
         self.before_iter(runner, data_batch=None)
 
-    def before_val_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_val_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each validation iteration.
@@ -184,11 +173,7 @@ class Hook:
         """
         self.before_iter(runner, data_batch=None)
 
-    def before_test_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any,
-                                                BaseDataSample]]] = None) -> None:
+    def before_test_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """All subclasses should override this method, if they need any
         operations before each test iteration.
@@ -201,8 +186,8 @@ class Hook:
 
     def after_train_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each training iteration.
@@ -218,8 +203,8 @@ class Hook:
 
     def after_val_iter(
            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each validation iteration.
@@ -235,8 +220,8 @@ class Hook:
 
     def after_test_iter(
             self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All subclasses should override this method, if they need any
         operations after each test iteration.
@@ -250,7 +235,7 @@ class Hook:
         """
         self.after_iter(runner, data_batch=None, outputs=None)
 
-    def every_n_epochs(self, runner: object, n: int) -> bool:
+    def every_n_epochs(self, runner, n: int) -> bool:
         """Test whether or not current epoch can be evenly divided by n.
 
         Args:
@@ -260,9 +245,9 @@ class Hook:
         Returns:
             bool: whether or not current epoch can be evenly divided by n.
         """
-        return (runner.epoch + 1) % n == 0 if n > 0 else False  # type: ignore
+        return (runner.epoch + 1) % n == 0 if n > 0 else False
 
-    def every_n_inner_iters(self, runner: object, n: int) -> bool:
+    def every_n_inner_iters(self, runner, n: int) -> bool:
         """Test whether or not current inner iteration can be evenly divided by
         n.
@@ -275,10 +260,9 @@ class Hook:
             bool: whether or not current inner iteration can be evenly
                 divided by n.
         """
-        return (runner.inner_iter +  # type: ignore
-                1) % n == 0 if n > 0 else False
+        return (runner.inner_iter + 1) % n == 0 if n > 0 else False
 
-    def every_n_iters(self, runner: object, n: int) -> bool:
+    def every_n_iters(self, runner, n: int) -> bool:
         """Test whether or not current iteration can be evenly divided by n.
 
         Args:
@@ -290,9 +274,9 @@ class Hook:
             bool: Return True if the current iteration can be evenly divided
                 by n, otherwise False.
         """
-        return (runner.iter + 1) % n == 0 if n > 0 else False  # type: ignore
+        return (runner.iter + 1) % n == 0 if n > 0 else False
 
-    def end_of_epoch(self, runner: object) -> bool:
+    def end_of_epoch(self, runner) -> bool:
         """Check whether the current epoch reaches the `max_epochs` or not.
 
         Args:
@@ -301,9 +285,9 @@ class Hook:
         Returns:
             bool: whether the end of current epoch or not.
         """
-        return runner.inner_iter + 1 == len(runner.data_loader)  # type: ignore
+        return runner.inner_iter + 1 == len(runner.data_loader)
 
-    def is_last_epoch(self, runner: object) -> bool:
+    def is_last_epoch(self, runner) -> bool:
         """Test whether or not current epoch is the last epoch.
 
         Args:
@@ -313,9 +297,9 @@ class Hook:
             bool: bool: Return True if the current epoch reaches the
                 `max_epochs`, otherwise False.
         """
-        return runner.epoch + 1 == runner._max_epochs  # type: ignore
+        return runner.epoch + 1 == runner._max_epochs
 
-    def is_last_iter(self, runner: object) -> bool:
+    def is_last_iter(self, runner) -> bool:
         """Test whether or not current epoch is the last iteration.
 
         Args:
@@ -324,4 +308,4 @@ class Hook:
         Returns:
             bool: whether or not current iteration is the last iteration.
         """
-        return runner.iter + 1 == runner._max_iters  # type: ignore
+        return runner.iter + 1 == runner._max_iters
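
With the alias in place, a downstream hook now reads as below. This is a sketch only: `IterLoggerHook` is a hypothetical name and the `mmengine.hooks` import path is an assumption; everything else uses only the base-class API visible in this diff.

from typing import Any, Optional, Sequence, Tuple

from mmengine.data import BaseDataSample
from mmengine.hooks import Hook  # import path assumed
from mmengine.registry import HOOKS

DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]


@HOOKS.register_module()
class IterLoggerHook(Hook):  # hypothetical example, not part of the diff
    """Log every `interval`-th training iteration."""

    def __init__(self, interval: int = 50) -> None:
        self.interval = interval

    def after_train_iter(self,
                         runner,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[Sequence[BaseDataSample]] = None
                         ) -> None:
        # `every_n_iters` and `runner.iter` are used exactly as in hook.py.
        if self.every_n_iters(runner, self.interval):
            runner.logger.info(f'reached iteration {runner.iter + 1}')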

View File

@@ -1,11 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import time
-from typing import Optional, Sequence
+from typing import Any, Optional, Sequence, Tuple
 
 from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook
 
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
 
 @HOOKS.register_module()
 class IterTimerHook(Hook):
@@ -16,45 +18,38 @@ class IterTimerHook(Hook):
 
     priority = 'NORMAL'
 
-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """Record time flag before start a epoch.
 
         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         self.t = time.time()
 
-    def before_iter(
-            self,
-            runner: object,
-            data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
+    def before_iter(self, runner, data_batch: DATA_BATCH = None) -> None:
         """Logging time for loading data and update the time flag.
 
         Args:
-            runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
         """
         # TODO: update for new logging system
-        runner.log_buffer.update({  # type: ignore
-            'data_time': time.time() - self.t
-        })
+        runner.log_buffer.update({'data_time': time.time() - self.t})
 
     def after_iter(self,
-                   runner: object,
-                   data_batch: Optional[Sequence[BaseDataSample]] = None,
+                   runner,
+                   data_batch: DATA_BATCH = None,
                    outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Logging time for a iteration and update the time flag.
 
         Args:
-            runner (object): The runner of the training process.
-            data_batch (Sequence[BaseDataSample]): Data from dataloader.
-                Defaults to None.
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
             outputs (Sequence[BaseDataSample]): Outputs from model.
                 Defaults to None.
         """
         # TODO: update for new logging system
-        runner.log_buffer.update({  # type: ignore
-            'time': time.time() - self.t
-        })
+        runner.log_buffer.update({'time': time.time() - self.t})
 
         self.t = time.time()

View File

@@ -10,6 +10,8 @@ from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook
 
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
 
 @HOOKS.register_module()
 class OptimizerHook(Hook):
@@ -56,8 +58,8 @@ class OptimizerHook(Hook):
 
     def after_train_iter(
            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """All operations need to be finished after each training iteration.
@@ -82,32 +84,27 @@ class OptimizerHook(Hook):
             In order to keep this interface consistent with other hooks,
             we keep ``outputs`` here. Defaults to None.
         """
-        runner.optimizer.zero_grad()  # type: ignore
+        runner.optimizer.zero_grad()
         if self.detect_anomalous_params:
-            self.detect_anomalous_parameters(
-                runner.outputs['loss'],  # type: ignore
-                runner)
-        runner.outputs['loss'].backward()  # type: ignore
+            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
+        runner.outputs['loss'].backward()
 
         if self.grad_clip is not None:
-            grad_norm = self.clip_grads(
-                runner.model.parameters())  # type: ignore
+            grad_norm = self.clip_grads(runner.model.parameters())
             if grad_norm is not None:
                 # Add grad norm to the logger
-                runner.log_buffer.update(  # type: ignore
-                    {'grad_norm': float(grad_norm)},
-                    runner.outputs['num_samples'])  # type: ignore
-        runner.optimizer.step()  # type: ignore
+                runner.log_buffer.update({'grad_norm': float(grad_norm)},
+                                         runner.outputs['num_samples'])
+        runner.optimizer.step()
 
-    def detect_anomalous_parameters(self, loss: torch.Tensor,
-                                    runner: object) -> None:
+    def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
         """Detect anomalous parameters that are not included in the graph.
 
         Args:
             loss (torch.Tensor): The loss of current iteration.
             runner (Runner): The runner of the training process.
         """
-        logger = runner.logger  # type: ignore
+        logger = runner.logger
         parameters_in_graph = set()
         visited = set()
@@ -125,7 +122,7 @@ class OptimizerHook(Hook):
                     traverse(grad_fn)
 
         traverse(loss.grad_fn)
-        for n, p in runner.model.named_parameters():  # type: ignore
+        for n, p in runner.model.named_parameters():
             if p not in parameters_in_graph and p.requires_grad:
                 logger.log(
                     level=logging.ERROR,
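
The `traverse(loss.grad_fn)` call above walks PyTorch's autograd graph; its body sits outside this hunk, but the standard technique looks roughly like this (a sketch of the general approach, not necessarily the exact mmengine code):

import torch

def collect_parameters_in_graph(loss: torch.Tensor) -> set:
    """Walk the autograd graph behind `loss`, collecting reachable params."""
    parameters_in_graph, visited = set(), set()

    def traverse(grad_fn):
        if grad_fn is None or grad_fn in visited:
            return
        visited.add(grad_fn)
        # Leaf accumulator nodes expose the parameter tensor as `.variable`.
        if hasattr(grad_fn, 'variable'):
            parameters_in_graph.add(grad_fn.variable)
        for next_fn, _ in grad_fn.next_functions:
            traverse(next_fn)

    traverse(loss.grad_fn)
    return parameters_in_graph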

View File

@@ -5,6 +5,8 @@ from mmengine.data import BaseDataSample
 from mmengine.registry import HOOKS
 from .hook import Hook
 
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
 
 @HOOKS.register_module()
 class ParamSchedulerHook(Hook):
@@ -15,8 +17,8 @@ class ParamSchedulerHook(Hook):
 
     def after_train_iter(
            self,
-            runner: object,
-            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            runner,
+            data_batch: DATA_BATCH = None,
             outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
         """Call step function for each scheduler after each iteration.
@@ -30,16 +32,16 @@ class ParamSchedulerHook(Hook):
             In order to keep this interface consistent with other hooks, we
             keep ``data_batch`` here. Defaults to None.
         """
-        for scheduler in runner.schedulers:  # type: ignore
+        for scheduler in runner.schedulers:
            if not scheduler.by_epoch:
                 scheduler.step()
 
-    def after_train_epoch(self, runner: object) -> None:
+    def after_train_epoch(self, runner) -> None:
         """Call step function for each scheduler after each epoch.
 
         Args:
             runner (Runner): The runner of the training process.
         """
-        for scheduler in runner.schedulers:  # type: ignore
+        for scheduler in runner.schedulers:
             if scheduler.by_epoch:
                 scheduler.step()

View File

@@ -14,18 +14,15 @@ class DistSamplerSeedHook(Hook):
 
     priority = 'NORMAL'
 
-    def before_epoch(self, runner: object) -> None:
+    def before_epoch(self, runner) -> None:
         """Set the seed for sampler and batch_sampler.
 
         Args:
             runner (Runner): The runner of the training process.
         """
-        if hasattr(runner.data_loader.sampler, 'set_epoch'):  # type: ignore
+        if hasattr(runner.data_loader.sampler, 'set_epoch'):
             # in case the data loader uses `SequentialSampler` in Pytorch
-            runner.data_loader.sampler.set_epoch(runner.epoch)  # type: ignore
-        elif hasattr(
-                runner.data_loader.batch_sampler.sampler,  # type: ignore
-                'set_epoch'):
+            runner.data_loader.sampler.set_epoch(runner.epoch)
+        elif hasattr(runner.data_loader.batch_sampler.sampler, 'set_epoch'):
             # batch sampler in pytorch warps the sampler as its attributes.
-            runner.data_loader.batch_sampler.sampler.set_epoch(  # type: ignore
-                runner.epoch)  # type: ignore
+            runner.data_loader.batch_sampler.sampler.set_epoch(runner.epoch)

View File

@@ -89,11 +89,11 @@ class SyncBuffersHook(Hook):
     def __init__(self) -> None:
         self.distributed = dist.IS_DIST
 
-    def after_epoch(self, runner: object) -> None:
+    def after_epoch(self, runner) -> None:
         """All-reduce model buffers at the end of each epoch.
 
         Args:
-            runner (object): The runner of the training process.
+            runner (Runner): The runner of the training process.
         """
         if self.distributed:
-            allreduce_params(runner.model.buffers())  # type: ignore
+            allreduce_params(runner.model.buffers())