[Fix]: Fix data batch type in base hook (#99)
* [Fix]: Fix data batch type in base hook * [Fix]: Fix the type hint bug in checkpoint, optimizer, param scheduler hooks Co-authored-by: Your <you@example.com>pull/100/head
parent
3adf4ea6b8
commit
15abb061ef
|
@ -2,7 +2,7 @@
|
|||
import os.path as osp
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence, Union
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.fileio import FileClient
|
||||
|
@ -179,14 +179,14 @@ class CheckpointHook(Hook):
|
|||
def after_train_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Save the checkpoint and synchronize buffers after each iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
|
||||
|
@ -49,31 +49,33 @@ class Hook:
|
|||
pass
|
||||
|
||||
def before_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[Tuple[Any,
|
||||
BaseDataSample]]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each iter.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
||||
def after_iter(self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[
|
||||
Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each epoch.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
pass
|
||||
|
@ -153,59 +155,62 @@ class Hook:
|
|||
self.after_epoch(runner)
|
||||
|
||||
def before_train_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[Tuple[Any,
|
||||
BaseDataSample]]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each training iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
|
||||
def before_val_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[Tuple[Any,
|
||||
BaseDataSample]]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each validation iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
|
||||
def before_test_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[Tuple[Any,
|
||||
BaseDataSample]]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before each test iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
"""
|
||||
self.before_iter(runner, data_batch=None)
|
||||
|
||||
def after_train_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each training iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
@ -214,15 +219,15 @@ class Hook:
|
|||
def after_val_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from
|
||||
model. Defaults to None.
|
||||
"""
|
||||
|
@ -231,15 +236,15 @@ class Hook:
|
|||
def after_test_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each test iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import Any, List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
@ -57,7 +57,7 @@ class OptimizerHook(Hook):
|
|||
def after_train_iter(
|
||||
self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""All operations need to be finished after each training iteration.
|
||||
|
||||
|
@ -74,9 +74,10 @@ class OptimizerHook(Hook):
|
|||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample], optional): Data from
|
||||
dataloader. In order to keep this interface consistent with
|
||||
other hooks, we keep ``data_batch`` here. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks,
|
||||
we keep ``outputs`` here. Defaults to None.
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
from typing import Any, Optional, Sequence, Tuple
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.registry import HOOKS
|
||||
|
@ -15,17 +15,19 @@ class ParamSchedulerHook(Hook):
|
|||
|
||||
def after_iter(self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
data_batch: Optional[Sequence[Tuple[
|
||||
Any, BaseDataSample]]] = None,
|
||||
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
|
||||
"""Call step function for each scheduler after each iteration.
|
||||
|
||||
Args:
|
||||
runner (object): The runner of the training process.
|
||||
data_batch (Sequence[BaseDataSample]): Data from dataloader. In
|
||||
order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
outputs (Sequence[BaseDataSample]): Outputs from model. In
|
||||
order to keep this interface consistent with other hooks, we
|
||||
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
|
||||
from dataloader. In order to keep this interface consistent
|
||||
with other hooks, we keep ``data_batch`` here.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataSample], optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
"""
|
||||
for scheduler in runner.schedulers: # type: ignore
|
||||
|
|
Loading…
Reference in New Issue