fix type hint and format (#88)
parent
11b38b12d6
commit
fd85156412
|
@ -406,7 +406,7 @@ class BaseDataElement:
|
|||
|
||||
# Tensor-like methods
|
||||
def numpy(self) -> 'BaseDataElement':
|
||||
"""Convert all tensor to np.narray in metainfo and data."""
|
||||
"""Convert all tensor to np.narray in metainfo and data."""
|
||||
new_data = self.new()
|
||||
for k, v in self.data_items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
|
|
@ -500,7 +500,7 @@ class BaseDataSample:
|
|||
|
||||
# Tensor-like methods
|
||||
def numpy(self) -> 'BaseDataSample':
|
||||
"""Convert all tensor to np.narray in metainfo and data."""
|
||||
"""Convert all tensor to np.narray in metainfo and data."""
|
||||
new_data = self.new()
|
||||
for k, v in self.data_items():
|
||||
if isinstance(v, (torch.Tensor, BaseDataElement)):
|
||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from mmengine.utils import mkdir_or_exist
|
||||
|
||||
|
||||
|
@ -45,13 +46,13 @@ class BaseEvaluator(metaclass=ABCMeta):
|
|||
self._dataset_meta = dataset_meta
|
||||
|
||||
@abstractmethod
|
||||
def process(self, data_samples: dict, predictions: dict) -> None:
|
||||
def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
|
||||
"""Process one batch of data samples and predictions. The processed
|
||||
results should be stored in ``self.results``, which will be used to
|
||||
compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_samples (dict): The data samples from the dataset.
|
||||
data_samples (BaseDataSample): The data samples from the dataset.
|
||||
predictions (dict): The output of the model.
|
||||
"""
|
||||
|
||||
|
@ -61,6 +62,7 @@ class BaseEvaluator(metaclass=ABCMeta):
|
|||
|
||||
Args:
|
||||
results (list): The processed results of each batch.
|
||||
|
||||
Returns:
|
||||
dict: The computed metrics. The keys are the names of the metrics,
|
||||
and the values are corresponding results.
|
||||
|
@ -78,9 +80,8 @@ class BaseEvaluator(metaclass=ABCMeta):
|
|||
this size.
|
||||
|
||||
Returns:
|
||||
metrics (dict): Evaluation metrics dict on the val dataset. The
|
||||
keys are the names of the metrics, and the values are
|
||||
corresponding results.
|
||||
dict: Evaluation metrics dict on the val dataset. The keys are the
|
||||
names of the metrics, and the values are corresponding results.
|
||||
"""
|
||||
if len(self.results) == 0:
|
||||
warnings.warn(
|
||||
|
|
|
@ -1,9 +1,13 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Union
|
||||
|
||||
from ..registry import EVALUATORS
|
||||
from .base import BaseEvaluator
|
||||
from .composed_evaluator import ComposedEvaluator
|
||||
|
||||
|
||||
def build_evaluator(cfg: dict) -> object:
|
||||
def build_evaluator(
|
||||
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
|
||||
"""Build function of evaluator.
|
||||
|
||||
When the evaluator config is a list, it will automatically build composed
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from mmengine.data import BaseDataSample
|
||||
from .base import BaseEvaluator
|
||||
|
||||
|
||||
|
@ -31,11 +32,11 @@ class ComposedEvaluator:
|
|||
for evaluator in self.evaluators:
|
||||
evaluator.dataset_meta = dataset_meta
|
||||
|
||||
def process(self, data_samples: dict, predictions: dict):
|
||||
def process(self, data_samples: BaseDataSample, predictions: dict):
|
||||
"""Invoke process method of each wrapped evaluator.
|
||||
|
||||
Args:
|
||||
data_samples (dict): The data samples from the dataset.
|
||||
data_samples (BaseDataSample): The data samples from the dataset.
|
||||
predictions (dict): The output of the model.
|
||||
"""
|
||||
|
||||
|
@ -54,9 +55,8 @@ class ComposedEvaluator:
|
|||
this size.
|
||||
|
||||
Returns:
|
||||
metrics (dict): Evaluation metrics of all wrapped evaluators. The
|
||||
keys are the names of the metrics, and the values are
|
||||
corresponding results.
|
||||
dict: Evaluation metrics of all wrapped evaluators. The keys are
|
||||
the names of the metrics, and the values are corresponding results.
|
||||
"""
|
||||
metrics = {}
|
||||
for evaluator in self.evaluators:
|
||||
|
|
|
@ -43,6 +43,8 @@ class CheckpointHook(Hook):
|
|||
Default: None.
|
||||
"""
|
||||
|
||||
priority = 'VERY_LOW'
|
||||
|
||||
def __init__(self,
|
||||
interval: int = -1,
|
||||
by_epoch: bool = True,
|
||||
|
|
|
@ -22,6 +22,8 @@ class EmptyCacheHook(Hook):
|
|||
Defaults to False.
|
||||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def __init__(self,
|
||||
before_epoch: bool = False,
|
||||
after_epoch: bool = True,
|
||||
|
|
|
@ -10,6 +10,8 @@ class Hook:
|
|||
All hooks should inherit from this class.
|
||||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def before_run(self, runner: object) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations before the training process.
|
||||
|
|
|
@ -14,6 +14,8 @@ class IterTimerHook(Hook):
|
|||
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
||||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def before_epoch(self, runner: object) -> None:
|
||||
"""Record time flag before start a epoch.
|
||||
|
||||
|
|
|
@ -30,6 +30,8 @@ class OptimizerHook(Hook):
|
|||
Defaults to False.
|
||||
"""
|
||||
|
||||
priority = 'HIGH'
|
||||
|
||||
def __init__(self,
|
||||
grad_clip: Optional[dict] = None,
|
||||
detect_anomalous_params: bool = False) -> None:
|
||||
|
|
|
@ -11,6 +11,8 @@ class ParamSchedulerHook(Hook):
|
|||
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
|
||||
and momentum."""
|
||||
|
||||
priority = 'LOW'
|
||||
|
||||
def after_iter(self,
|
||||
runner: object,
|
||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||
|
|
|
@ -12,6 +12,8 @@ class DistSamplerSeedHook(Hook):
|
|||
purpose with :obj:`IterLoader`.
|
||||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def before_epoch(self, runner: object) -> None:
|
||||
"""Set the seed for sampler and batch_sampler.
|
||||
|
||||
|
|
|
@ -84,6 +84,8 @@ class SyncBuffersHook(Hook):
|
|||
"""Synchronize model buffers such as running_mean and running_var in BN at
|
||||
the end of each epoch."""
|
||||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.distributed = dist.IS_DIST
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ def is_model_wrapper(model):
|
|||
The following 4 model in MMEngine (and their subclasses) are regarded as
|
||||
model wrappers: DataParallel, DistributedDataParallel,
|
||||
MMDataParallel, MMDistributedDataParallel. You may add you own
|
||||
model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS.
|
||||
model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be checked.
|
||||
|
|
Loading…
Reference in New Issue