fix type hint and format (#88)
parent
11b38b12d6
commit
fd85156412
|
@ -406,7 +406,7 @@ class BaseDataElement:
|
||||||
|
|
||||||
# Tensor-like methods
|
# Tensor-like methods
|
||||||
def numpy(self) -> 'BaseDataElement':
|
def numpy(self) -> 'BaseDataElement':
|
||||||
"""Convert all tensor to np.narray in metainfo and data."""
|
"""Convert all tensor to np.narray in metainfo and data."""
|
||||||
new_data = self.new()
|
new_data = self.new()
|
||||||
for k, v in self.data_items():
|
for k, v in self.data_items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
|
|
|
@ -500,7 +500,7 @@ class BaseDataSample:
|
||||||
|
|
||||||
# Tensor-like methods
|
# Tensor-like methods
|
||||||
def numpy(self) -> 'BaseDataSample':
|
def numpy(self) -> 'BaseDataSample':
|
||||||
"""Convert all tensor to np.narray in metainfo and data."""
|
"""Convert all tensor to np.narray in metainfo and data."""
|
||||||
new_data = self.new()
|
new_data = self.new()
|
||||||
for k, v in self.data_items():
|
for k, v in self.data_items():
|
||||||
if isinstance(v, (torch.Tensor, BaseDataElement)):
|
if isinstance(v, (torch.Tensor, BaseDataElement)):
|
||||||
|
|
|
@ -10,6 +10,7 @@ from typing import Any, List, Optional, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
from mmengine.utils import mkdir_or_exist
|
from mmengine.utils import mkdir_or_exist
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,13 +46,13 @@ class BaseEvaluator(metaclass=ABCMeta):
|
||||||
self._dataset_meta = dataset_meta
|
self._dataset_meta = dataset_meta
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def process(self, data_samples: dict, predictions: dict) -> None:
|
def process(self, data_samples: BaseDataSample, predictions: dict) -> None:
|
||||||
"""Process one batch of data samples and predictions. The processed
|
"""Process one batch of data samples and predictions. The processed
|
||||||
results should be stored in ``self.results``, which will be used to
|
results should be stored in ``self.results``, which will be used to
|
||||||
compute the metrics when all batches have been processed.
|
compute the metrics when all batches have been processed.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_samples (dict): The data samples from the dataset.
|
data_samples (BaseDataSample): The data samples from the dataset.
|
||||||
predictions (dict): The output of the model.
|
predictions (dict): The output of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -61,6 +62,7 @@ class BaseEvaluator(metaclass=ABCMeta):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
results (list): The processed results of each batch.
|
results (list): The processed results of each batch.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The computed metrics. The keys are the names of the metrics,
|
dict: The computed metrics. The keys are the names of the metrics,
|
||||||
and the values are corresponding results.
|
and the values are corresponding results.
|
||||||
|
@ -78,9 +80,8 @@ class BaseEvaluator(metaclass=ABCMeta):
|
||||||
this size.
|
this size.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
metrics (dict): Evaluation metrics dict on the val dataset. The
|
dict: Evaluation metrics dict on the val dataset. The keys are the
|
||||||
keys are the names of the metrics, and the values are
|
names of the metrics, and the values are corresponding results.
|
||||||
corresponding results.
|
|
||||||
"""
|
"""
|
||||||
if len(self.results) == 0:
|
if len(self.results) == 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
|
|
|
@ -1,9 +1,13 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from ..registry import EVALUATORS
|
from ..registry import EVALUATORS
|
||||||
|
from .base import BaseEvaluator
|
||||||
from .composed_evaluator import ComposedEvaluator
|
from .composed_evaluator import ComposedEvaluator
|
||||||
|
|
||||||
|
|
||||||
def build_evaluator(cfg: dict) -> object:
|
def build_evaluator(
|
||||||
|
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
|
||||||
"""Build function of evaluator.
|
"""Build function of evaluator.
|
||||||
|
|
||||||
When the evaluator config is a list, it will automatically build composed
|
When the evaluator config is a list, it will automatically build composed
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
from .base import BaseEvaluator
|
from .base import BaseEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,11 +32,11 @@ class ComposedEvaluator:
|
||||||
for evaluator in self.evaluators:
|
for evaluator in self.evaluators:
|
||||||
evaluator.dataset_meta = dataset_meta
|
evaluator.dataset_meta = dataset_meta
|
||||||
|
|
||||||
def process(self, data_samples: dict, predictions: dict):
|
def process(self, data_samples: BaseDataSample, predictions: dict):
|
||||||
"""Invoke process method of each wrapped evaluator.
|
"""Invoke process method of each wrapped evaluator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_samples (dict): The data samples from the dataset.
|
data_samples (BaseDataSample): The data samples from the dataset.
|
||||||
predictions (dict): The output of the model.
|
predictions (dict): The output of the model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -54,9 +55,8 @@ class ComposedEvaluator:
|
||||||
this size.
|
this size.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
metrics (dict): Evaluation metrics of all wrapped evaluators. The
|
dict: Evaluation metrics of all wrapped evaluators. The keys are
|
||||||
keys are the names of the metrics, and the values are
|
the names of the metrics, and the values are corresponding results.
|
||||||
corresponding results.
|
|
||||||
"""
|
"""
|
||||||
metrics = {}
|
metrics = {}
|
||||||
for evaluator in self.evaluators:
|
for evaluator in self.evaluators:
|
||||||
|
|
|
@ -43,6 +43,8 @@ class CheckpointHook(Hook):
|
||||||
Default: None.
|
Default: None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'VERY_LOW'
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
interval: int = -1,
|
interval: int = -1,
|
||||||
by_epoch: bool = True,
|
by_epoch: bool = True,
|
||||||
|
|
|
@ -22,6 +22,8 @@ class EmptyCacheHook(Hook):
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
before_epoch: bool = False,
|
before_epoch: bool = False,
|
||||||
after_epoch: bool = True,
|
after_epoch: bool = True,
|
||||||
|
|
|
@ -10,6 +10,8 @@ class Hook:
|
||||||
All hooks should inherit from this class.
|
All hooks should inherit from this class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def before_run(self, runner: object) -> None:
|
def before_run(self, runner: object) -> None:
|
||||||
"""All subclasses should override this method, if they need any
|
"""All subclasses should override this method, if they need any
|
||||||
operations before the training process.
|
operations before the training process.
|
||||||
|
|
|
@ -14,6 +14,8 @@ class IterTimerHook(Hook):
|
||||||
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
Eg. ``data_time`` for loading data and ``time`` for a model train step.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def before_epoch(self, runner: object) -> None:
|
def before_epoch(self, runner: object) -> None:
|
||||||
"""Record time flag before start a epoch.
|
"""Record time flag before start a epoch.
|
||||||
|
|
||||||
|
|
|
@ -30,6 +30,8 @@ class OptimizerHook(Hook):
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'HIGH'
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
grad_clip: Optional[dict] = None,
|
grad_clip: Optional[dict] = None,
|
||||||
detect_anomalous_params: bool = False) -> None:
|
detect_anomalous_params: bool = False) -> None:
|
||||||
|
|
|
@ -11,6 +11,8 @@ class ParamSchedulerHook(Hook):
|
||||||
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
|
"""A hook to update some hyper-parameters in optimizer, e.g learning rate
|
||||||
and momentum."""
|
and momentum."""
|
||||||
|
|
||||||
|
priority = 'LOW'
|
||||||
|
|
||||||
def after_iter(self,
|
def after_iter(self,
|
||||||
runner: object,
|
runner: object,
|
||||||
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
data_batch: Optional[Sequence[BaseDataSample]] = None,
|
||||||
|
|
|
@ -12,6 +12,8 @@ class DistSamplerSeedHook(Hook):
|
||||||
purpose with :obj:`IterLoader`.
|
purpose with :obj:`IterLoader`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def before_epoch(self, runner: object) -> None:
|
def before_epoch(self, runner: object) -> None:
|
||||||
"""Set the seed for sampler and batch_sampler.
|
"""Set the seed for sampler and batch_sampler.
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,8 @@ class SyncBuffersHook(Hook):
|
||||||
"""Synchronize model buffers such as running_mean and running_var in BN at
|
"""Synchronize model buffers such as running_mean and running_var in BN at
|
||||||
the end of each epoch."""
|
the end of each epoch."""
|
||||||
|
|
||||||
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.distributed = dist.IS_DIST
|
self.distributed = dist.IS_DIST
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ def is_model_wrapper(model):
|
||||||
The following 4 model in MMEngine (and their subclasses) are regarded as
|
The following 4 model in MMEngine (and their subclasses) are regarded as
|
||||||
model wrappers: DataParallel, DistributedDataParallel,
|
model wrappers: DataParallel, DistributedDataParallel,
|
||||||
MMDataParallel, MMDistributedDataParallel. You may add you own
|
MMDataParallel, MMDistributedDataParallel. You may add you own
|
||||||
model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS.
|
model wrapper by registering it to ``mmengine.registry.MODEL_WRAPPERS``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module): The model to be checked.
|
model (nn.Module): The model to be checked.
|
||||||
|
|
Loading…
Reference in New Issue