[Fix]: Fix data batch type in base hook (#99)

* [Fix]: Fix data batch type in base hook

* [Fix]: Fix the type hint bug in checkpoint, optimizer, param scheduler hooks

Co-authored-by: Your <you@example.com>
pull/100/head
Yuan Liu 2022-03-07 13:25:45 +08:00 committed by GitHub
parent 3adf4ea6b8
commit 15abb061ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 50 deletions

View File

@ -2,7 +2,7 @@
import os.path as osp import os.path as osp
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Optional, Sequence, Union from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
from mmengine.fileio import FileClient from mmengine.fileio import FileClient
@ -179,14 +179,14 @@ class CheckpointHook(Hook):
def after_train_iter( def after_train_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration. """Save the checkpoint and synchronize buffers after each iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
Defaults to None. from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None. Defaults to None.
""" """

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
@ -49,31 +49,33 @@ class Hook:
pass pass
def before_iter( def before_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations before each iter. operations before each iter.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Defaults to None. Data from dataloader. Defaults to None.
""" """
pass pass
def after_iter(self, def after_iter(self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each epoch. operations after each epoch.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model. outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
pass pass
@ -153,59 +155,62 @@ class Hook:
self.after_epoch(runner) self.after_epoch(runner)
def before_train_iter( def before_train_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations before each training iteration. operations before each training iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
dataloader. Defaults to None. Data from dataloader. Defaults to None.
""" """
self.before_iter(runner, data_batch=None) self.before_iter(runner, data_batch=None)
def before_val_iter( def before_val_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations before each validation iteration. operations before each validation iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
dataloader. Defaults to None. Data from dataloader. Defaults to None.
""" """
self.before_iter(runner, data_batch=None) self.before_iter(runner, data_batch=None)
def before_test_iter( def before_test_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None: data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations before each test iteration. operations before each test iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
dataloader. Defaults to None. Data from dataloader. Defaults to None.
""" """
self.before_iter(runner, data_batch=None) self.before_iter(runner, data_batch=None)
def after_train_iter( def after_train_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each training iteration. operations after each training iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None. Defaults to None.
""" """
@ -214,15 +219,15 @@ class Hook:
def after_val_iter( def after_val_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each validation iteration. operations after each validation iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from outputs (Sequence[BaseDataSample], optional): Outputs from
model. Defaults to None. model. Defaults to None.
""" """
@ -231,15 +236,15 @@ class Hook:
def after_test_iter( def after_test_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations after each test iteration. operations after each test iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
dataloader. Defaults to None. Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None. Defaults to None.
""" """

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging import logging
from typing import List, Optional, Sequence from typing import Any, List, Optional, Sequence, Tuple
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@ -57,7 +57,7 @@ class OptimizerHook(Hook):
def after_train_iter( def after_train_iter(
self, self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All operations need to be finished after each training iteration. """All operations need to be finished after each training iteration.
@ -74,9 +74,10 @@ class OptimizerHook(Hook):
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
dataloader. In order to keep this interface consistent with from dataloader. In order to keep this interface consistent
other hooks, we keep ``data_batch`` here. Defaults to None. with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model. outputs (Sequence[BaseDataSample], optional): Outputs from model.
In order to keep this interface consistent with other hooks, In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None. we keep ``outputs`` here. Defaults to None.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS from mmengine.registry import HOOKS
@ -15,17 +15,19 @@ class ParamSchedulerHook(Hook):
def after_iter(self, def after_iter(self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None: outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration. """Call step function for each scheduler after each iteration.
Args: Args:
runner (object): The runner of the training process. runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. In data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
order to keep this interface consistent with other hooks, we from dataloader. In order to keep this interface consistent
keep ``data_batch`` here. Defaults to None. with other hooks, we keep ``data_batch`` here.
outputs (Sequence[BaseDataSample]): Outputs from model. In Defaults to None.
order to keep this interface consistent with other hooks, we outputs (Sequence[BaseDataSample], optional): Outputs from model.
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None. keep ``data_batch`` here. Defaults to None.
""" """
for scheduler in runner.schedulers: # type: ignore for scheduler in runner.schedulers: # type: ignore