fix type hint and format (#88)

pull/91/head^2
Zaida Zhou 2022-03-05 17:44:31 +08:00 committed by GitHub
parent 11b38b12d6
commit fd85156412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 35 additions and 14 deletions

View File

@ -406,7 +406,7 @@ class BaseDataElement:
# Tensor-like methods
def numpy(self) -> 'BaseDataElement':
"""Convert all tensor to np.narray in metainfo and data."""
"""Convert all tensor to np.narray in metainfo and data."""
new_data = self.new()
for k, v in self.data_items():
if isinstance(v, torch.Tensor):

View File

@ -500,7 +500,7 @@ class BaseDataSample:
# Tensor-like methods
def numpy(self) -> 'BaseDataSample':
"""Convert all tensor to np.narray in metainfo and data."""
"""Convert all tensor to np.narray in metainfo and data."""
new_data = self.new()
for k, v in self.data_items():
if isinstance(v, (torch.Tensor, BaseDataElement)):

View File

@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
import torch
import torch.distributed as dist
from mmengine.data import BaseDataSample
from mmengine.utils import mkdir_or_exist
@ -45,13 +46,13 @@ class BaseEvaluator(metaclass=ABCMeta):
self._dataset_meta = dataset_meta
@abstractmethod
def process(self, data_samples: dict, predictions: dict) -> None:
def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_samples (dict): The data samples from the dataset.
data_samples (BaseDataSample): The data samples from the dataset.
predictions (dict): The output of the model.
"""
@ -61,6 +62,7 @@ class BaseEvaluator(metaclass=ABCMeta):
Args:
results (list): The processed results of each batch.
Returns:
dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
@ -78,9 +80,8 @@ class BaseEvaluator(metaclass=ABCMeta):
this size.
Returns:
metrics (dict): Evaluation metrics dict on the val dataset. The
keys are the names of the metrics, and the values are
corresponding results.
dict: Evaluation metrics dict on the val dataset. The keys are the
names of the metrics, and the values are corresponding results.
"""
if len(self.results) == 0:
warnings.warn(

View File

@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from ..registry import EVALUATORS
from .base import BaseEvaluator
from .composed_evaluator import ComposedEvaluator
def build_evaluator(cfg: dict) -> object:
def build_evaluator(
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
"""Build function of evaluator.
When the evaluator config is a list, it will automatically build composed

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union
from mmengine.data import BaseDataSample
from .base import BaseEvaluator
@ -31,11 +32,11 @@ class ComposedEvaluator:
for evaluator in self.evaluators:
evaluator.dataset_meta = dataset_meta
def process(self, data_samples: dict, predictions: dict):
def process(self, data_samples: BaseDataSample, predictions: dict):
"""Invoke process method of each wrapped evaluator.
Args:
data_samples (dict): The data samples from the dataset.
data_samples (BaseDataSample): The data samples from the dataset.
predictions (dict): The output of the model.
"""
@ -54,9 +55,8 @@ class ComposedEvaluator:
this size.
Returns:
metrics (dict): Evaluation metrics of all wrapped evaluators. The
keys are the names of the metrics, and the values are
corresponding results.
dict: Evaluation metrics of all wrapped evaluators. The keys are
the names of the metrics, and the values are corresponding results.
"""
metrics = {}
for evaluator in self.evaluators:

View File

@ -43,6 +43,8 @@ class CheckpointHook(Hook):
Default: None.
"""
priority = 'VERY_LOW'
def __init__(self,
interval: int = -1,
by_epoch: bool = True,

View File

@ -22,6 +22,8 @@ class EmptyCacheHook(Hook):
Defaults to False.
"""
priority = 'NORMAL'
def __init__(self,
before_epoch: bool = False,
after_epoch: bool = True,

View File

@ -10,6 +10,8 @@ class Hook:
All hooks should inherit from this class.
"""
priority = 'NORMAL'
def before_run(self, runner: object) -> None:
"""All subclasses should override this method, if they need any
operations before the training process.

View File

@ -14,6 +14,8 @@ class IterTimerHook(Hook):
Eg. ``data_time`` for loading data and ``time`` for a model train step.
"""
priority = 'NORMAL'
def before_epoch(self, runner: object) -> None:
"""Record time flag before start a epoch.

View File

@ -30,6 +30,8 @@ class OptimizerHook(Hook):
Defaults to False.
"""
priority = 'HIGH'
def __init__(self,
grad_clip: Optional[dict] = None,
detect_anomalous_params: bool = False) -> None:

View File

@ -11,6 +11,8 @@ class ParamSchedulerHook(Hook):
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
and momentum."""
priority = 'LOW'
def after_iter(self,
runner: object,
data_batch: Optional[Sequence[BaseDataSample]] = None,

View File

@ -12,6 +12,8 @@ class DistSamplerSeedHook(Hook):
purpose with :obj:`IterLoader`.
"""
priority = 'NORMAL'
def before_epoch(self, runner: object) -> None:
"""Set the seed for sampler and batch_sampler.

View File

@ -84,6 +84,8 @@ class SyncBuffersHook(Hook):
"""Synchronize model buffers such as running_mean and running_var in BN at
the end of each epoch."""
priority = 'NORMAL'
def __init__(self) -> None:
self.distributed = dist.IS_DIST

View File

@ -8,7 +8,7 @@ def is_model_wrapper(model):
The following 4 model in MMEngine (and their subclasses) are regarded as
model wrappers: DataParallel, DistributedDataParallel,
MMDataParallel, MMDistributedDataParallel. You may add you own
model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS.
model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.
Args:
model (nn.Module): The model to be checked.