[Fix]: Fix data batch type in base hook (#99)

* [Fix]: Fix data batch type in base hook

* [Fix]: Fix the type hint bug in checkpoint, optimizer, param scheduler hooks

Co-authored-by: Your <you@example.com>
pull/100/head
Yuan Liu 2022-03-07 13:25:45 +08:00 committed by GitHub
parent 3adf4ea6b8
commit 15abb061ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 50 deletions

View File

@ -2,7 +2,7 @@
import os.path as osp
import warnings
from pathlib import Path
from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataSample
from mmengine.fileio import FileClient
@ -179,14 +179,14 @@ class CheckpointHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Save the checkpoint and synchronize buffers after each iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample
@ -49,31 +49,33 @@ class Hook:
pass
def before_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each iter.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
pass
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
pass
@ -153,59 +155,62 @@ class Hook:
self.after_epoch(runner)
def before_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each training iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
def before_val_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each validation iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
def before_test_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
self,
runner: object,
data_batch: Optional[Sequence[Tuple[Any,
BaseDataSample]]] = None) -> None:
"""All subclasses should override this method, if they need any
operations before each test iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
"""
self.before_iter(runner, data_batch=None)
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each training iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
@ -214,15 +219,15 @@ class Hook:
def after_val_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from
model. Defaults to None.
"""
@ -231,15 +236,15 @@ class Hook:
def after_test_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Sequence
from typing import Any, List, Optional, Sequence, Tuple
import torch
from torch.nn.parameter import Parameter
@ -57,7 +57,7 @@ class OptimizerHook(Hook):
def after_train_iter(
self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""All operations need to be finished after each training iteration.
@ -74,9 +74,10 @@ class OptimizerHook(Hook):
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample], optional): Data from
dataloader. In order to keep this interface consistent with
other hooks, we keep ``data_batch`` here. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
In order to keep this interface consistent with other hooks,
we keep ``outputs`` here. Defaults to None.

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Any, Optional, Sequence, Tuple
from mmengine.data import BaseDataSample
from mmengine.registry import HOOKS
@ -15,17 +15,19 @@ class ParamSchedulerHook(Hook):
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,
data_batch: Optional[Sequence[Tuple[
Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Call step function for each scheduler after each iteration.
Args:
runner (object): The runner of the training process.
data_batch (Sequence[BaseDataSample]): Data from dataloader. In
order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
outputs (Sequence[BaseDataSample]): Outputs from model. In
order to keep this interface consistent with other hooks, we
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. In order to keep this interface consistent
with other hooks, we keep ``data_batch`` here.
Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
"""
for scheduler in runner.schedulers: # type: ignore