fix type hint and format (#88)

pull/91/head^2
Zaida Zhou 2022-03-05 17:44:31 +08:00 committed by GitHub
parent 11b38b12d6
commit fd85156412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 35 additions and 14 deletions

View File

@ -406,7 +406,7 @@ class BaseDataElement:
# Tensor-like methods # Tensor-like methods
def numpy(self) -> 'BaseDataElement': def numpy(self) -> 'BaseDataElement':
"""Convert all tensor to np.narray in metainfo and data.""" """Convert all tensor to np.narray in metainfo and data."""
new_data = self.new() new_data = self.new()
for k, v in self.data_items(): for k, v in self.data_items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):

View File

@ -500,7 +500,7 @@ class BaseDataSample:
# Tensor-like methods # Tensor-like methods
def numpy(self) -> 'BaseDataSample': def numpy(self) -> 'BaseDataSample':
"""Convert all tensor to np.narray in metainfo and data.""" """Convert all tensor to np.narray in metainfo and data."""
new_data = self.new() new_data = self.new()
for k, v in self.data_items(): for k, v in self.data_items():
if isinstance(v, (torch.Tensor, BaseDataElement)): if isinstance(v, (torch.Tensor, BaseDataElement)):

View File

@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from mmengine.data import BaseDataSample
from mmengine.utils import mkdir_or_exist from mmengine.utils import mkdir_or_exist
@ -45,13 +46,13 @@ class BaseEvaluator(metaclass=ABCMeta):
self._dataset_meta = dataset_meta self._dataset_meta = dataset_meta
@abstractmethod @abstractmethod
def process(self, data_samples: dict, predictions: dict) -> None: def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
"""Process one batch of data samples and predictions. The processed """Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed. compute the metrics when all batches have been processed.
Args: Args:
data_samples (dict): The data samples from the dataset. data_samples (BaseDataSample): The data samples from the dataset.
predictions (dict): The output of the model. predictions (dict): The output of the model.
""" """
@ -61,6 +62,7 @@ class BaseEvaluator(metaclass=ABCMeta):
Args: Args:
results (list): The processed results of each batch. results (list): The processed results of each batch.
Returns: Returns:
dict: The computed metrics. The keys are the names of the metrics, dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results. and the values are corresponding results.
@ -78,9 +80,8 @@ class BaseEvaluator(metaclass=ABCMeta):
this size. this size.
Returns: Returns:
metrics (dict): Evaluation metrics dict on the val dataset. The dict: Evaluation metrics dict on the val dataset. The keys are the
keys are the names of the metrics, and the values are names of the metrics, and the values are corresponding results.
corresponding results.
""" """
if len(self.results) == 0: if len(self.results) == 0:
warnings.warn( warnings.warn(

View File

@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from ..registry import EVALUATORS from ..registry import EVALUATORS
from .base import BaseEvaluator
from .composed_evaluator import ComposedEvaluator from .composed_evaluator import ComposedEvaluator
def build_evaluator(cfg: dict) -> object: def build_evaluator(
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
"""Build function of evaluator. """Build function of evaluator.
When the evaluator config is a list, it will automatically build composed When the evaluator config is a list, it will automatically build composed

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union from typing import Optional, Sequence, Union
from mmengine.data import BaseDataSample
from .base import BaseEvaluator from .base import BaseEvaluator
@ -31,11 +32,11 @@ class ComposedEvaluator:
for evaluator in self.evaluators: for evaluator in self.evaluators:
evaluator.dataset_meta = dataset_meta evaluator.dataset_meta = dataset_meta
def process(self, data_samples: dict, predictions: dict): def process(self, data_samples: BaseDataSample, predictions: dict):
"""Invoke process method of each wrapped evaluator. """Invoke process method of each wrapped evaluator.
Args: Args:
data_samples (dict): The data samples from the dataset. data_samples (BaseDataSample): The data samples from the dataset.
predictions (dict): The output of the model. predictions (dict): The output of the model.
""" """
@ -54,9 +55,8 @@ class ComposedEvaluator:
this size. this size.
Returns: Returns:
metrics (dict): Evaluation metrics of all wrapped evaluators. The dict: Evaluation metrics of all wrapped evaluators. The keys are
keys are the names of the metrics, and the values are the names of the metrics, and the values are corresponding results.
corresponding results.
""" """
metrics = {} metrics = {}
for evaluator in self.evaluators: for evaluator in self.evaluators:

View File

@ -43,6 +43,8 @@ class CheckpointHook(Hook):
Default: None. Default: None.
""" """
priority = 'VERY_LOW'
def __init__(self, def __init__(self,
interval: int = -1, interval: int = -1,
by_epoch: bool = True, by_epoch: bool = True,

View File

@ -22,6 +22,8 @@ class EmptyCacheHook(Hook):
Defaults to False. Defaults to False.
""" """
priority = 'NORMAL'
def __init__(self, def __init__(self,
before_epoch: bool = False, before_epoch: bool = False,
after_epoch: bool = True, after_epoch: bool = True,

View File

@ -10,6 +10,8 @@ class Hook:
All hooks should inherit from this class. All hooks should inherit from this class.
""" """
priority = 'NORMAL'
def before_run(self, runner: object) -> None: def before_run(self, runner: object) -> None:
"""All subclasses should override this method, if they need any """All subclasses should override this method, if they need any
operations before the training process. operations before the training process.

View File

@ -14,6 +14,8 @@ class IterTimerHook(Hook):
Eg. ``data_time`` for loading data and ``time`` for a model train step. Eg. ``data_time`` for loading data and ``time`` for a model train step.
""" """
priority = 'NORMAL'
def before_epoch(self, runner: object) -> None: def before_epoch(self, runner: object) -> None:
"""Record time flag before start a epoch. """Record time flag before start a epoch.

View File

@ -30,6 +30,8 @@ class OptimizerHook(Hook):
Defaults to False. Defaults to False.
""" """
priority = 'HIGH'
def __init__(self, def __init__(self,
grad_clip: Optional[dict] = None, grad_clip: Optional[dict] = None,
detect_anomalous_params: bool = False) -> None: detect_anomalous_params: bool = False) -> None:

View File

@ -11,6 +11,8 @@ class ParamSchedulerHook(Hook):
"""A hook to update some hyper-parameters in optimizer, e.g learning rate """A hook to update some hyper-parameters in optimizer, e.g learning rate
and momentum.""" and momentum."""
priority = 'LOW'
def after_iter(self, def after_iter(self,
runner: object, runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None, data_batch: Optional[Sequence[BaseDataSample]] = None,

View File

@ -12,6 +12,8 @@ class DistSamplerSeedHook(Hook):
purpose with :obj:`IterLoader`. purpose with :obj:`IterLoader`.
""" """
priority = 'NORMAL'
def before_epoch(self, runner: object) -> None: def before_epoch(self, runner: object) -> None:
"""Set the seed for sampler and batch_sampler. """Set the seed for sampler and batch_sampler.

View File

@ -84,6 +84,8 @@ class SyncBuffersHook(Hook):
"""Synchronize model buffers such as running_mean and running_var in BN at """Synchronize model buffers such as running_mean and running_var in BN at
the end of each epoch.""" the end of each epoch."""
priority = 'NORMAL'
def __init__(self) -> None: def __init__(self) -> None:
self.distributed = dist.IS_DIST self.distributed = dist.IS_DIST

View File

@ -8,7 +8,7 @@ def is_model_wrapper(model):
The following 4 model in MMEngine (and their subclasses) are regarded as The following 4 model in MMEngine (and their subclasses) are regarded as
model wrappers: DataParallel, DistributedDataParallel, model wrappers: DataParallel, DistributedDataParallel,
MMDataParallel, MMDistributedDataParallel. You may add you own MMDataParallel, MMDistributedDataParallel. You may add you own
model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS. model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.
Args: Args:
model (nn.Module): The model to be checked. model (nn.Module): The model to be checked.