[Refactor] Refactor data flow to make the interface more natural (#468)

* [Refactor]: modify interface of Visualizer.add_datasample (#365)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader`  (#323)

* collate data in dataloader

* fix docstring

* refine comment

* fix as comment

* refactor default collate and pseudo collate

* format test file

* fix docstring

* fix as comment

* rename elem to data_item

* minor fix

* fix as comment

* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360)

* refine evaluator and metric

* compatible with new default collate

* replace default collate with pseudo

* Handle data_batch in metric

* fix unit test

* fix unit test

* fix unit test

* minor refine

* make data_batch optional

* rename outputs to predictions

* fix ut

* rename predictions to outputs

* fix docstring

* fix docstring

* fix unit test

* make outputs and data_batch to kwargs

* fix unit test

* keep signature of metric

* fix ut

* rename pred_sample argument to data_sample (Visualizer)

* fix loop and ut

* [Refactor]: Refactor model dataflow (#398)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* refactor model data flow

tmp commit

* make val_cfg and test_cfg optional

* roll back runner

* pass test mmdet

* fix as comment

fix ci in DataPreprocessor

* fix ut

* fix ut

* fix rebase main

* [Fix]: Fix test val ddp (#462)

* [Fix] Fix docstring and type hint of data flow (#463)

* Fix docstring of data flow

* change signature of hook

* fix unit test

* resolve conflicts

* fix lint
Mashiro 2022-08-24 22:04:55 +08:00, committed by GitHub
commit 8770c6c7fc (parent 7e1d7af2d9)
30 changed files with 842 additions and 448 deletions

View File

@ -228,10 +228,8 @@ class CheckInvalidLossHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
if self.every_n_train_iters(runner, self.interval):
assert torch.isfinite(outputs['loss']),\

View File

@ -2,10 +2,11 @@
from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn
from .utils import (COLLATE_FUNCTIONS, default_collate, pseudo_collate,
worker_init_fn)
__all__ = [
'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset',
'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler',
'worker_init_fn', 'pseudo_collate'
'worker_init_fn', 'pseudo_collate', 'COLLATE_FUNCTIONS', 'default_collate'
]

View File

@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Sequence
from typing import Any, Mapping, Sequence
import numpy as np
import torch
from torch.utils.data._utils.collate import \
default_collate as torch_default_collate
DATA_BATCH = Sequence[dict]
from mmengine.registry import Registry
from mmengine.structures import BaseDataElement
COLLATE_FUNCTIONS = Registry('Collate Functions')
def worker_init_fn(worker_id: int, num_workers: int, rank: int,
@ -28,16 +33,124 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int,
torch.manual_seed(worker_seed)
def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH:
"""The default behavior of dataloader is to merge a list of samples to form
a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` does
nothing just returns ``data_batch``.
@COLLATE_FUNCTIONS.register_module()
def pseudo_collate(data_batch: Sequence) -> Any:
"""Convert list of data sampled from dataset into a batch of data, of which
type consistent with the type of each data_itement in ``data_batch``.
The default behavior of dataloader is to merge a list of samples to form
a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
will not stack tensors to batch tensors, and convert int, float, ndarray to
tensors.
This code is referenced from:
`Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_. # noqa: E501
Args:
data_batch (Sequence[dict]): Batch of data from
dataloader.
data_batch (Sequence): Batch of data from dataloader.
Returns:
Sequence[dict]: Return input ``data_batch``.
Any: Transposed data in the same format as each element of
``data_batch``.
"""
return data_batch
data_item = data_batch[0]
data_item_type = type(data_item)
if isinstance(data_item, (str, bytes)):
return data_batch
elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
# named tuple
return data_item_type(*(pseudo_collate(samples)
for samples in zip(*data_batch)))
elif isinstance(data_item, Sequence):
# check to make sure that the elements in batch have
# consistent size
it = iter(data_batch)
data_item_size = len(next(it))
if not all(len(data_item) == data_item_size for data_item in it):
raise RuntimeError(
'each element in list of batch should be of equal size')
transposed = list(zip(*data_batch))
if isinstance(data_item, tuple):
return [pseudo_collate(samples)
for samples in transposed] # Compat with Pytorch.
else:
try:
return data_item_type(
[pseudo_collate(samples) for samples in transposed])
except TypeError:
# The sequence type may not support `__init__(iterable)`
# (e.g., `range`).
return [pseudo_collate(samples) for samples in transposed]
elif isinstance(data_item, Mapping):
return data_item_type({
key: pseudo_collate([d[key] for d in data_batch])
for key in data_item
})
else:
return data_batch
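# A minimal usage sketch of the new behavior (assuming the post-refactor
# mmengine API): per-sample values stay in lists; nothing is stacked or
# converted to tensors.
import torch
from mmengine.dataset import pseudo_collate

batch = [
    dict(inputs=torch.randn(3, 5), label=1),
    dict(inputs=torch.randn(3, 5), label=2),
]
collated = pseudo_collate(batch)
assert isinstance(collated['inputs'], list)  # tensors are not stacked
assert collated['label'] == [1, 2]           # ints are not converted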
@COLLATE_FUNCTIONS.register_module()
def default_collate(data_batch: Sequence) -> Any:
"""Convert list of data sampled from dataset into a batch of data, of which
type consistent with the type of each data_itement in ``data_batch``.
Different from :func:`pseudo_collate`, ``default_collate`` will stack
tensor contained in ``data_batch`` into a batched tensor with the
first dimension batch size, and then move input tensor to the target
device.
Different from ``default_collate`` in PyTorch, this ``default_collate`` will
not process ``BaseDataElement``.
This code is referenced from:
`Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_. # noqa: E501
Note:
``default_collate`` only accepts input tensors with the same shape.
Args:
data_batch (Sequence): Data sampled from dataset.
Returns:
Any: Data in the same format as each element of ``data_batch``, in which
tensors have been stacked, and ndarray, int and float values have
been converted to tensors.
"""
data_item = data_batch[0]
data_item_type = type(data_item)
if isinstance(data_item, (BaseDataElement, str, bytes)):
return data_batch
elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
# named_tuple
return data_item_type(*(default_collate(samples)
for samples in zip(*data_batch)))
elif isinstance(data_item, Sequence):
# check to make sure that the elements in batch have
# consistent size
it = iter(data_batch)
data_item_size = len(next(it))
if not all(len(data_item) == data_item_size for data_item in it):
raise RuntimeError(
'each element in list of batch should be of equal size')
transposed = list(zip(*data_batch))
if isinstance(data_item, tuple):
return [default_collate(samples)
for samples in transposed] # Compat with Pytorch.
else:
try:
return data_item_type(
[default_collate(samples) for samples in transposed])
except TypeError:
# The sequence type may not support `__init__(iterable)`
# (e.g., `range`).
return [default_collate(samples) for samples in transposed]
elif isinstance(data_item, Mapping):
return data_item_type({
key: default_collate([d[key] for d in data_batch])
for key in data_item
})
else:
return torch_default_collate(data_batch)
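# Sketch of `default_collate` under the same assumptions: tensors are
# stacked along a new batch dimension, ndarray/int/float values become
# tensors, and `BaseDataElement` instances are kept as a plain list.
import numpy as np
import torch
from mmengine.dataset import default_collate
from mmengine.structures import BaseDataElement

batch = [
    dict(inputs=torch.randn(3, 5), array=np.array(1),
         data_sample=BaseDataElement(label=1)),
    dict(inputs=torch.randn(3, 5), array=np.array(2),
         data_sample=BaseDataElement(label=2)),
]
collated = default_collate(batch)
assert collated['inputs'].shape == (2, 3, 5)      # stacked batch tensor
assert collated['array'].tolist() == [1, 2]       # ndarray -> tensor
assert isinstance(collated['data_sample'], list)  # BaseDataElement untouched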

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, List, Optional, Sequence, Union
from typing import Any, Iterator, List, Optional, Sequence, Union
from mmengine.dataset import pseudo_collate
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric
@ -37,34 +38,26 @@ class Evaluator:
for metric in self.metrics:
metric.dataset_meta = dataset_meta
def process(self, data_batch: Sequence[dict],
predictions: Sequence[BaseDataElement]):
def process(self,
data_samples: Sequence[BaseDataElement],
data_batch: Optional[Any] = None):
"""Convert ``BaseDataSample`` to dict and invoke process method of each
metric.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
the model.
data_samples (Sequence[BaseDataElement]): Predictions of the model,
and the ground truth of the validation set.
data_batch (Any, optional): A batch of data from the dataloader.
"""
_data_batch = []
for data in data_batch:
if isinstance(data['data_sample'], BaseDataElement):
_data_batch.append(
dict(
inputs=data['inputs'],
data_sample=data['data_sample'].to_dict()))
_data_samples = []
for data_sample in data_samples:
if isinstance(data_sample, BaseDataElement):
_data_samples.append(data_sample.to_dict())
else:
_data_batch.append(data)
_predictions = []
for pred in predictions:
if isinstance(pred, BaseDataElement):
_predictions.append(pred.to_dict())
else:
_predictions.append(pred)
_data_samples.append(data_sample)
for metric in self.metrics:
metric.process(_data_batch, _predictions)
metric.process(data_batch, _data_samples)
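# Sketch of the new call order inside a validation loop (mirrors the
# `ValLoop` change further below; `model`, `val_dataloader` and
# `evaluator` are assumed to already exist):
for data_batch in val_dataloader:
    outputs = model.val_step(data_batch)
    evaluator.process(data_samples=outputs, data_batch=data_batch)
metrics = evaluator.evaluate(size=len(val_dataloader.dataset))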
def evaluate(self, size: int) -> dict:
"""Invoke ``evaluate`` method of each metric and collect the metrics
@ -97,20 +90,26 @@ class Evaluator:
return metrics
def offline_evaluate(self,
data: Sequence,
predictions: Sequence,
data_samples: Sequence,
data: Optional[Sequence] = None,
chunk_size: int = 1):
"""Offline evaluate the dumped predictions on the given data .
Args:
data (Sequence): All data of the validation set.
predictions (Sequence): All predictions of the model on the
validation set.
data_samples (Sequence): All predictions of the model and the ground
truth of the validation set.
data (Sequence, optional): All data of the validation set.
chunk_size (int): The number of data samples and predictions to be
processed in a batch.
"""
# support chunking iterable objects
if data is not None:
assert len(data_samples) == len(data), (
'outputs and data should have the same length, but got '
f'outputs length: {len(data_samples)} '
f'data length: {len(data)}')
def get_chunks(seq: Iterator, chunk_size=1):
stop = False
while not stop:
@ -125,9 +124,11 @@ class Evaluator:
yield chunk
size = 0
for data_chunk, pred_chunk in zip(
get_chunks(iter(data), chunk_size),
get_chunks(iter(predictions), chunk_size)):
size += len(data_chunk)
self.process(data_chunk, pred_chunk)
for output_chunk in get_chunks(iter(data_samples), chunk_size):
if data is not None:
data_chunk = pseudo_collate(data[size:size + chunk_size])
else:
data_chunk = None
size += len(output_chunk)
self.process(output_chunk, data_chunk)
return self.evaluate(size)
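# Sketch of the reordered `offline_evaluate` arguments: dumped
# predictions come first and the raw data is optional
# (`all_predictions`/`all_data` are assumed names):
evaluator.offline_evaluate(all_predictions, all_data, chunk_size=4)
evaluator.offline_evaluate(all_predictions)  # data may be omitted entirely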

View File

@ -58,15 +58,14 @@ class BaseMetric(metaclass=ABCMeta):
self._dataset_meta = dataset_meta
@abstractmethod
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from
data_batch (Any): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""
@ -146,8 +145,7 @@ class DumpResults(BaseMetric):
raise ValueError('The output file must be a pkl file.')
self.out_file_path = out_file_path
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))

View File

@ -12,7 +12,7 @@ from mmengine.registry import HOOKS
from mmengine.utils import is_list_of, is_seq_of
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()
@ -470,9 +470,8 @@ class CheckpointHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model. Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
"""
if self.by_epoch:
return

View File

@ -4,10 +4,9 @@ from typing import Optional, Sequence, Union
import torch
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()
@ -38,18 +37,15 @@ class EmptyCacheHook(Hook):
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[dict,
Sequence[BaseDataElement]]] = None,
outputs: Optional[Union[dict, Sequence]] = None,
mode: str = 'train') -> None:
"""Empty cache after an iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict or sequence, optional): Outputs from model.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_iter:

View File

@ -1,9 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union
from mmengine.structures import BaseDataElement
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
class Hook:
@ -184,8 +182,7 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
"""
self._before_iter(
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
@ -200,7 +197,7 @@ class Hook:
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict], optional): Data from dataloader.
data_batch (dict or tuple or list, optional): Data from dataloader.
Defaults to None.
"""
self._before_iter(
@ -216,7 +213,7 @@ class Hook:
Args:
runner (Runner): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data from dataloader.
data_batch (dict or tuple or list, optional): Data from dataloader.
Defaults to None.
"""
self._before_iter(
@ -233,10 +230,8 @@ class Hook:
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self._after_iter(
runner,
@ -249,18 +244,15 @@ class Hook:
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) \
-> None:
outputs: Optional[Sequence] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each validation iteration.
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the val loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict or sequence, optional): Outputs from
model. Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (Sequence, optional): Outputs from model.
"""
self._after_iter(
runner,
@ -269,22 +261,19 @@ class Hook:
outputs=outputs,
mode='val')
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
def after_test_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence] = None) -> None:
"""All subclasses should override this method, if they need any
operations after each test iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (Sequence, optional): Outputs from model.
"""
self._after_iter(
runner,
@ -327,8 +316,7 @@ class Hook:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
@ -337,8 +325,7 @@ class Hook:
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[Sequence[BaseDataElement],
dict]] = None,
outputs: Optional[Union[Sequence, dict]] = None,
mode: str = 'train') -> None:
"""All subclasses should override this method, if they need any
operations after each epoch.
@ -347,10 +334,8 @@ class Hook:
runner (Runner): The runner of the training, validation or testing
process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (Sequence[BaseDataElement], optional): Outputs from model.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict or Sequence, optional): Outputs from model.
mode (str): Current mode of runner. Defaults to 'train'.
"""
pass
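# A custom hook written against the loosened `data_batch` typing
# (a sketch; `MyHook` is hypothetical):
from mmengine.hooks import Hook
from mmengine.registry import HOOKS

@HOOKS.register_module()
class MyHook(Hook):
    def after_train_iter(self, runner, batch_idx, data_batch=None,
                         outputs=None):
        # `data_batch` may now be a dict, tuple or list, depending on
        # the dataloader's collate_fn.
        if isinstance(data_batch, dict):
            runner.logger.info(f'batch keys: {list(data_batch)}')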

View File

@ -3,10 +3,9 @@ import time
from typing import Optional, Sequence, Union
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()
@ -54,8 +53,8 @@ class IterTimerHook(Hook):
runner (Runner): The runner of the training, validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from
dataloader.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# Update data loading time in `runner.message_hub`.
@ -66,8 +65,7 @@ class IterTimerHook(Hook):
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[dict,
Sequence[BaseDataElement]]] = None,
outputs: Optional[Union[dict, Sequence]] = None,
mode: str = 'train') -> None:
"""Calculating time for an iteration and updating "time"
``HistoryBuffer`` of ``runner.message_hub``.
@ -76,10 +74,8 @@ class IterTimerHook(Hook):
runner (Runner): The runner of the training validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict or sequence, optional): Outputs from model. Defaults
to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict or sequence, optional): Outputs from model.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# Update iteration time in `runner.message_hub`.

View File

@ -7,10 +7,9 @@ from typing import Dict, Optional, Sequence, Union
from mmengine.fileio import FileClient, dump
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_tuple_of, scandir
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
SUFFIX_TYPE = Union[Sequence[str], str]
@ -125,9 +124,8 @@ class LoggerHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model. Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
"""
# Print experiment name every n iterations.
if self.every_n_train_iters(
@ -152,41 +150,38 @@ class LoggerHook(Hook):
runner.visualizer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path)
def after_val_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
def after_val_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence] = None) -> None:
"""Record logs after validation iteration.
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the validation
loop.
data_batch (Sequence[dict], optional): Data from dataloader.
data_batch (dict or tuple or list, optional): Data from dataloader.
Defaults to None.
outputs (sequence, optional): Outputs from model. Defaults to None.
outputs (sequence, optional): Outputs from model.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
_, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'val')
runner.logger.info(log_str)
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
def after_test_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence] = None) -> None:
"""Record logs after testing iteration.
Args:
runner (Runner): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (sequence, optional): Outputs from model. Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (sequence, optional): Outputs from model.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
_, log_str = runner.log_processor.get_log_after_iter(

View File

@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple, Union
import cv2
import numpy as np
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from mmengine.utils.dl_utils import tensor2imgs
DATA_BATCH = Optional[Union[dict, tuple, list]]
# TODO: Due to interface changes, the current class
# functions incorrectly
@ -48,21 +49,18 @@ class NaiveVisualizationHook(Hook):
unpad_image = input[:unpad_height, :unpad_width]
return unpad_image
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: Optional[Sequence[dict]] = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
def after_test_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence] = None) -> None:
"""Show or Write the predicted results.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[dict], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataElement], optional): Outputs from model.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (Sequence, optional): Outputs from model.
"""
if self.every_n_inner_iters(batch_idx, self._interval):
for data, output in zip(data_batch, outputs): # type: ignore

View File

@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence
from typing import Optional, Union
from mmengine.registry import HOOKS
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()
@ -24,12 +24,12 @@ class ParamSchedulerHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
data_batch (dict or tuple or list, optional): Data from dataloader.
In order to keep this interface consistent with other hooks,
we keep ``data_batch`` here. Defaults to None.
we keep ``data_batch`` here.
outputs (dict, optional): Outputs from model.
In order to keep this interface consistent with other hooks, we
keep ``data_batch`` here. Defaults to None.
keep ``outputs`` here.
"""
def step(param_schedulers):

View File

@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence
from typing import Dict, Optional, Union
from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
from mmengine.version import __version__
from .hook import Hook
DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]
@HOOKS.register_module()

View File

@ -1,9 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_model import BaseModel
from .data_preprocessor import (BaseDataElement, BaseDataPreprocessor,
ImgDataPreprocessor)
from .data_preprocessor import BaseDataPreprocessor, ImgDataPreprocessor
__all__ = [
'BaseModel', 'BaseDataElement', 'ImgDataPreprocessor',
'BaseDataPreprocessor'
]
__all__ = ['BaseModel', 'ImgDataPreprocessor', 'BaseDataPreprocessor']

View File

@ -1,21 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from ..base_module import BaseModule
from .data_preprocessor import BaseDataPreprocessor
ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement],
Tuple[torch.Tensor], torch.Tensor]
class BaseModel(BaseModule):
"""Base class for all algorithmic models.
@ -85,7 +81,7 @@ class BaseModel(BaseModule):
f'`nn.Module` instance, but got '
f'{type(data_preprocessor)}')
def train_step(self, data: List[dict],
def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Implements the default model training process including
preprocessing, model forward propagation, loss calculation,
@ -96,7 +92,7 @@ class BaseModel(BaseModule):
:class:`IterBasedTrainLoop` will call this method to update model
parameters. The default parameter update process is as follows:
1. Calls ``self.data_processor(data, training=False) to collext
1. Calls ``self.data_processor(data, training=False) to collect
batch_inputs and corresponding data_samples(labels).
2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw
loss
@ -105,7 +101,7 @@ class BaseModel(BaseModule):
4. Calls ``optim_wrapper.update_params(loss)`` to update model.
Args:
data (List[dict]): Data sampled from dataloader.
data (dict or tuple or list): Data sampled from dataset.
optim_wrapper (OptimWrapper): OptimWrapper instance
used to update model parameters.
@ -114,13 +110,13 @@ class BaseModel(BaseModule):
"""
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.data_preprocessor(data, True)
losses = self(batch_inputs, data_samples, mode='loss')
parsed_losses, log_vars = self.parse_losses(losses)
data = self.data_preprocessor(data, True)
losses = self._run_forward(data, mode='loss') # type: ignore
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
optim_wrapper.update_params(parsed_losses)
return log_vars
def val_step(self, data: List[dict]) -> List[BaseDataElement]:
def val_step(self, data: Union[tuple, dict, list]) -> list:
"""Gets the predictions of given data.
Calls ``self.data_preprocessor(data, False)`` and
@ -128,25 +124,25 @@ class BaseModel(BaseModule):
predictions which will be passed to evaluator.
Args:
data (List[dict]): Data sampled from dataloader.
data (dict or tuple or list): Data sampled from dataset.
Returns:
List[BaseDataElement]: The predictions of given data.
list: The predictions of given data.
"""
inputs, data_sample = self.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
Args:
data (List[dict]): Data sampled from dataloader.
data (dict or tuple or list): Data sampled from dataset.
Returns:
List[BaseDataElement]: The predictions of given data.
list: The predictions of given data.
"""
inputs, data_sample = self.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')
data = self.data_preprocessor(data, False)
return self._run_forward(data, mode='predict') # type: ignore
def parse_losses(
self, losses: Dict[str, torch.Tensor]
@ -239,16 +235,16 @@ class BaseModel(BaseModule):
@abstractmethod
def forward(self,
batch_inputs: torch.Tensor,
data_samples: Optional[List[BaseDataElement]] = None,
mode: str = 'tensor') -> ForwardResults:
inputs: torch.Tensor,
data_samples: Optional[list] = None,
mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
"""Returns losses or predictions of training, validation, testing, and
simple inference process.
``forward`` method of BaseModel is an abstract method, its subclasses
must implement this method.
Accepts ``batch_inputs`` and ``data_samples`` processed by
Accepts ``inputs`` and ``data_samples`` processed by
:attr:`data_preprocessor`, and returns results according to mode
arguments.
@ -263,9 +259,9 @@ class BaseModel(BaseModule):
loss.
Args:
batch_inputs (torch.Tensor): batch input tensor collated by
inputs (torch.Tensor): batch input tensor collated by
:attr:`data_preprocessor`.
data_samples (List[BaseDataElement], optional):
data_samples (list, optional):
data samples collated by :attr:`data_preprocessor`.
mode (str): mode should be one of ``loss``, ``predict`` and
``tensor``
@ -273,19 +269,36 @@ class BaseModel(BaseModule):
- ``loss``: Called by ``train_step`` and return loss ``dict``
used for logging
- ``predict``: Called by ``val_step`` and ``test_step``
and return list of ``BaseDataElement`` results used for
computing metric.
and return list of results used for computing metric.
- ``tensor``: Called by custom use to get ``Tensor`` type
results.
Returns:
ForwardResults:
dict or list:
- If ``mode == loss``, return a ``dict`` of loss tensor used
for backward and logging.
- If ``mode == predict``, return a ``list`` of
:obj:`BaseDataElement` for computing metric
and getting inference result.
- If ``mode == predict``, return a ``list`` of inference
results.
- If ``mode == tensor``, return a tensor or ``tuple`` of tensor
or ``dict`` of tensor for custom use.
"""
def _run_forward(self, data: Union[dict, tuple, list],
mode: str) -> Union[Dict[str, torch.Tensor], list]:
"""Unpacks data for :meth:`forward`
Args:
data (dict or tuple or list): Data sampled from dataset.
mode (str): Mode of forward.
Returns:
dict or list: Results of training or testing mode.
"""
if isinstance(data, dict):
results = self(**data, mode=mode)
elif isinstance(data, (list, tuple)):
results = self(*data, mode=mode)
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
return results
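# Sketch of how `_run_forward` dispatches on the preprocessor output
# (assumes `model` is a `BaseModel` subclass whose `forward` follows the
# new `forward(inputs, data_samples, mode=...)` convention):
import torch

data = dict(inputs=torch.randn(2, 3, 32, 32), data_samples=None)
losses = model._run_forward(data, mode='loss')    # dict  -> self(**data, ...)
data = (torch.randn(2, 3, 32, 32), None)
preds = model._run_forward(data, mode='predict')  # tuple -> self(*data, ...)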

View File

@ -1,94 +1,74 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence, Tuple, Union
import math
from typing import Mapping, Optional, Sequence, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.registry import MODELS
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
from ..utils import stack_batch
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
@MODELS.register_module()
class BaseDataPreprocessor(nn.Module):
"""Base data pre-processor used for collating and copying data to the
target device.
``BaseDataPreprocessor`` performs data pre-processing according to the
following steps:
- Collates the data sampled from dataloader.
- Copies data to the target device.
- Stacks the input tensor at the first dimension.
"""Base data pre-processor used for copying data to the target device.
Subclasses that inherit from ``BaseDataPreprocessor`` could override the
forward method to implement custom data pre-processing, such as
batch-resize, MixUp, or CutMix.
Warnings:
Each item of data sampled from dataloader must be a dict and at least
contain the ``inputs`` key. Furthermore, the value of ``inputs``
must be a ``Tensor`` with the same shape.
Note:
The data returned by the dataloader must be a dict that at least
contains the ``inputs`` key.
"""
def __init__(self):
super().__init__()
self._device = torch.device('cpu')
def collate_data(
self,
data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]:
"""Collating and copying data to the target device.
Collates the data sampled from dataloader into a list of tensor and
list of labels, and then copies tensor to the target device.
Subclasses could override it to be compatible with the custom format
data sampled from custom dataloader.
def cast_data(self, data: CastData) -> CastData:
"""Copying data to the target device.
Args:
data (Sequence[dict]): Data sampled from dataloader.
data (dict): Data returned by ``DataLoader``.
Returns:
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
tensor and list of labels at target device.
CastData: Inputs and data samples at the target device.
"""
inputs = [_data['inputs'].to(self._device).float() for _data in data]
batch_data_samples: List[BaseDataElement] = []
# Model can get predictions without any data samples.
for _data in data:
if 'data_sample' in _data:
batch_data_samples.append(_data['data_sample'])
# Move data from CPU to corresponding device.
batch_data_samples = [
data_sample.to(self._device) for data_sample in batch_data_samples
]
if isinstance(data, Mapping):
return {key: self.cast_data(data[key]) for key in data}
elif isinstance(data, tuple) and hasattr(data, '_fields'):
# namedtuple
return type(data)(*(self.cast_data(sample) for sample in data))  # type: ignore # noqa: E501 # yapf:disable
elif isinstance(data, Sequence):
return [self.cast_data(sample) for sample in data]
elif isinstance(data, torch.Tensor):
return data.to(self.device)
elif isinstance(data, BaseDataElement):
return data.to(self.device)
else:
return data
if not batch_data_samples:
batch_data_samples = None # type: ignore
return inputs, batch_data_samples
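# Sketch of the `cast_data` recursion: containers are traversed, tensors
# and `BaseDataElement` are moved to the target device (CPU by default
# here), and all other leaves pass through unchanged.
import torch
from mmengine.model import BaseDataPreprocessor
from mmengine.structures import BaseDataElement

preprocessor = BaseDataPreprocessor()
data = dict(inputs=[torch.randn(3, 5)], meta='keep-me',
            data_sample=BaseDataElement(label=torch.tensor(1)))
out = preprocessor.cast_data(data)
assert out['inputs'][0].device.type == 'cpu'
assert out['meta'] == 'keep-me'  # non-tensor leaf is returned as-is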
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
"""Preprocesses the data into the model input format.
After the data pre-processing of :meth:`collate_data`, ``forward``
After the data pre-processing of :meth:`cast_data`, ``forward``
returns the data in the same structure, with all tensors moved to the
target device.
Args:
data (Sequence[dict]): data sampled from dataloader.
data (dict): Data returned by the dataloader.
training (bool): Whether to enable training time augmentation.
Returns:
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
model input.
dict or list: Data in the same format as the model input.
"""
inputs, batch_data_samples = self.collate_data(data)
batch_inputs = torch.stack(inputs, dim=0)
return batch_inputs, batch_data_samples
return self.cast_data(data) # type: ignore
@property
def device(self):
@ -203,42 +183,79 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
torch.tensor(std).view(-1, 1, 1), False)
else:
self._enable_normalize = False
self.channel_conversion = rgb_to_bgr or bgr_to_rgb
self._channel_conversion = rgb_to_bgr or bgr_to_rgb
self.pad_size_divisor = pad_size_divisor
self.pad_value = pad_value
def forward(self,
data: Sequence[dict],
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
"""Performs normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor``.
Args:
data (Sequence[dict]): data sampled from dataloader.
data (dict): Data sampled from the dataloader. If the collate
function of the DataLoader is :obj:`pseudo_collate`,
``data['inputs']`` will be a list of CHW tensors; if it is
:obj:`default_collate`, ``data['inputs']`` will be an already
batched NCHW tensor.
training (bool): Whether to enable training time augmentation. If
subclasses override this method, they can perform different
preprocessing strategies for training and testing based on the
value of ``training``.
Returns:
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
model input.
dict or list: Data in the same format as the model input.
"""
inputs, batch_data_samples = self.collate_data(data)
for idx, _input in enumerate(inputs):
# channel transform
if self.channel_conversion:
_input = _input[[2, 1, 0], ...]
# Normalization.
data = self.cast_data(data) # type: ignore
_batch_inputs = data['inputs']
# Process data with `pseudo_collate`.
if is_list_of(_batch_inputs, torch.Tensor):
batch_inputs = []
for _batch_input in _batch_inputs:
# channel transform
if self._channel_conversion:
_batch_input = _batch_input[[2, 1, 0], ...]
# Convert to float after channel conversion to ensure
# efficiency
_batch_input = _batch_input.float()
# Normalization.
if self._enable_normalize:
if self.mean.shape[0] == 3:
assert _batch_input.dim(
) == 3 and _batch_input.shape[0] == 3, (
'If the mean has 3 values, the input tensor '
'should in shape of (3, H, W), but got the tensor '
f'with shape {_batch_input.shape}')
_batch_input = (_batch_input - self.mean) / self.std
batch_inputs.append(_batch_input)
# Pad and stack Tensor.
batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor,
self.pad_value)
# Process data with `default_collate`.
elif isinstance(_batch_inputs, torch.Tensor):
assert _batch_inputs.dim() == 4, (
'The input of `ImgDataPreprocessor` should be a NCHW tensor '
'or a list of tensor, but got a tensor with shape: '
f'{_batch_inputs.shape}')
if self._channel_conversion:
_batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
# Convert to float after channel conversion to ensure
# efficiency
_batch_inputs = _batch_inputs.float()
if self._enable_normalize:
if self.mean.shape[0] == 3:
assert _input.dim() == 3 and _input.shape[0] == 3, (
'If the mean has 3 values, the input tensor should in '
'shape of (3, H, W), but got the tensor with shape '
f'{_input.shape}')
_input = (_input - self.mean) / self.std
inputs[idx] = _input
# Pad and stack Tensor.
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
self.pad_value)
return batch_inputs, batch_data_samples
_batch_inputs = (_batch_inputs - self.mean) / self.std
h, w = _batch_inputs.shape[2:]
target_h = math.ceil(
h / self.pad_size_divisor) * self.pad_size_divisor
target_w = math.ceil(
w / self.pad_size_divisor) * self.pad_size_divisor
pad_h = target_h - h
pad_w = target_w - w
batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h),
'constant', self.pad_value)
else:
raise TypeError('Output of `cast_data` should be a list of dict '
'or a tuple with inputs and data_samples, but got '
f'{type(data)} {data}')
data['inputs'] = batch_inputs
data.setdefault('data_samples', None)
return data
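# Sketch of the two input formats `forward` now accepts (shapes and
# normalization values are assumptions):
import torch
from mmengine.model import ImgDataPreprocessor

img_preprocessor = ImgDataPreprocessor(mean=[127.5] * 3, std=[127.5] * 3)
# `pseudo_collate` style: a list of CHW tensors, padded and stacked here.
data = dict(inputs=[torch.rand(3, 10, 12) * 255, torch.rand(3, 10, 12) * 255])
assert img_preprocessor(data)['inputs'].shape == (2, 3, 10, 12)
# `default_collate` style: an already batched NCHW tensor, only padded.
data = dict(inputs=torch.rand(2, 3, 10, 12) * 255)
assert img_preprocessor(data)['inputs'].shape == (2, 3, 10, 12)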

View File

@ -1,12 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from typing import Any, Dict, Union
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS
from mmengine.structures import BaseDataElement
from ..utils import detect_anomalous_params
@ -90,7 +89,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
super().__init__(module=module, **kwargs)
self.detect_anomalous_params = detect_anomalous_params
def train_step(self, data: List[dict],
def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Interface for model forward, backward and parameters updating during
training process.
@ -105,7 +104,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
- Return log messages of losses.
Args:
data (List[dict]): Data sampled by dataloader.
data (dict or tuple or list): Data sampled from dataset.
optim_wrapper (OptimWrapper): A wrapper of optimizer to
update parameters.
@ -114,33 +113,53 @@ class MMDistributedDataParallel(DistributedDataParallel):
"""
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.module.data_preprocessor(
data, training=True)
losses = self(batch_inputs, data_samples, mode='loss')
data = self.module.data_preprocessor(data, training=True)
losses = self._run_forward(data, mode='loss')
if self.detect_anomalous_params:
detect_anomalous_params(losses, model=self)
parsed_loss, log_vars = self.module.parse_losses(losses)
optim_wrapper.update_params(parsed_loss)
return log_vars
def val_step(self, data: List[dict]) -> List[BaseDataElement]:
def val_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the prediction of module during validation process.
Args:
data (List[dict]): Data sampled by dataloader.
data (dict or tuple or list): Data sampled from dataset.
Returns:
List[BaseDataElement] or dict: The predictions of given data.
list: The predictions of given data.
"""
return self.module.val_step(data)
data = self.module.data_preprocessor(data, training=False)
return self._run_forward(data, mode='predict')
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the predictions of module during testing process.
Args:
data: Data sampled by dataloader.
data (dict or tuple or list): Data sampled from dataset.
Returns:
List[BaseDataElement]: The predictions of given data.
list: The predictions of given data.
"""
return self.module.test_step(data)
data = self.module.data_preprocessor(data, training=False)
return self._run_forward(data, mode='predict')
def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
"""Unpacks data for :meth:`forward`
Args:
data (dict or tuple or list): Data sampled from dataset.
mode (str): Mode of forward.
Returns:
dict or list: Results of training or testing mode.
"""
if isinstance(data, dict):
results = self(**data, mode=mode)
elif isinstance(data, (list, tuple)):
results = self(*data, mode=mode)
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
return results

View File

@ -153,7 +153,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
super().__init__(module, process_group, cpu_offload,
fsdp_auto_wrap_policy, backward_prefetch)
def train_step(self, data: List[dict],
def train_step(self, data: dict,
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
"""Interface for model forward, backward and parameters updating during
training process.
@ -168,7 +168,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
- Return log messages of losses.
Args:
data (List[dict]): Data sampled by dataloader.
data (dict): Data sampled by dataloader.
optim_wrapper (OptimWrapper): A wrapper of optimizer to
update parameters.
@ -177,18 +177,23 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
"""
# enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
batch_inputs, data_samples = self.module.data_preprocessor(
data, training=True)
losses = self(batch_inputs, data_samples, mode='loss')
data = self.module.data_preprocessor(data, training=True)
if isinstance(data, dict):
losses = self(**data, mode='loss')
elif isinstance(data, (list, tuple)):
losses = self(*data, mode='loss')
else:
raise TypeError('Output of `data_preprocessor` should be '
f'list, tuple or dict, but got {type(data)}')
parsed_loss, log_vars = self.module.parse_losses(losses)
optim_wrapper.update_params(parsed_loss)
return log_vars
def val_step(self, data: List[dict]) -> List[BaseDataElement]:
def val_step(self, data: dict) -> List[BaseDataElement]:
"""Gets the prediction of module during validation process.
Args:
data (List[dict]): Data sampled by dataloader.
data (dict): Data sampled by dataloader.
Returns:
List[BaseDataElement] or dict: The predictions of given data.
@ -196,11 +201,11 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
inputs, data_sample = self.module.data_preprocessor(data, False)
return self(inputs, data_sample, mode='predict')
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
def test_step(self, data: dict) -> List[BaseDataElement]:
"""Gets the predictions of module during testing process.
Args:
data: Data sampled by dataloader.
data (dict): Data sampled by dataloader.
Returns:
List[BaseDataElement]: The predictions of given data.

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import ExitStack, contextmanager
from typing import Dict, List
from typing import Dict, Union
import torch
import torch.nn as nn
@ -9,7 +9,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel
from mmengine.device import get_device
from mmengine.optim import OptimWrapperDict
from mmengine.registry import MODEL_WRAPPERS
from mmengine.structures import BaseDataElement
from .distributed import MMDistributedDataParallel
@ -86,13 +85,13 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
**kwargs)
module._modules[name] = sub_module
def train_step(self, data: List[dict],
def train_step(self, data: Union[dict, tuple, list],
optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
"""Interface for model forward, backward and parameters updating during
training process.
Args:
data: Data sampled by dataloader.
data (dict or tuple or list): Data sampled from dataset.
optim_wrapper (OptimWrapperDict): A wrapper of optimizer to
update parameters.
@ -101,25 +100,25 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
"""
return self.module.train_step(data, optim_wrapper)
def val_step(self, data) -> List[BaseDataElement]:
def val_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the prediction of module during validation process.
Args:
data (List[dict]): Data sampled by dataloader.
data (dict or tuple or list): Data sampled from dataset.
Returns:
List[BaseDataElement]: The predictions of given data.
list: The predictions of given data.
"""
return self.module.val_step(data)
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
def test_step(self, data: Union[dict, tuple, list]) -> list:
"""Gets the predictions of module during testing process.
Args:
data: Data sampled by dataloader.
data (dict or tuple or list): Data sampled from dataset.
Returns:
ForwardResults: The predictions of given data.
list: The predictions of given data.
"""
return self.module.test_step(data)

View File

@ -361,7 +361,7 @@ class ValLoop(BaseLoop):
# outputs should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
outputs = self.runner.model.val_step(data_batch)
self.evaluator.process(data_batch, outputs)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_val_iter',
batch_idx=idx,
@ -429,10 +429,10 @@ class TestLoop(BaseLoop):
'before_test_iter', batch_idx=idx, data_batch=data_batch)
# predictions should be sequence of BaseDataElement
with autocast(enabled=self.fp16):
predictions = self.runner.model.test_step(data_batch)
self.evaluator.process(data_batch, predictions)
outputs = self.runner.model.test_step(data_batch)
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
self.runner.call_hook(
'after_test_iter',
batch_idx=idx,
data_batch=data_batch,
outputs=predictions)
outputs=outputs)

View File

@ -20,7 +20,7 @@ from torch.utils.data import DataLoader
import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.dataset import pseudo_collate, worker_init_fn
from mmengine.dataset import COLLATE_FUNCTIONS, worker_init_fn
from mmengine.device import get_device
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
is_distributed, master_only, sync_random_seed)
@ -1393,14 +1393,19 @@ class Runner:
# The default behavior of `collate_fn` in dataloader is to
# merge a list of samples to form a mini-batch of Tensor(s).
# However, to make this more flexible, collate_fn in MMengine does
# nothing. The action to merge a list of samples will be handled
# in model.
# However, in MMEngine, if `collate_fn` is not defined in
# dataloader_cfg, `pseudo_collate` will only convert the list of
# samples into a dict of lists without stacking the batch tensor.
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
dict(type='pseudo_collate'))
collate_fn_type = collate_fn_cfg.pop('type')
collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type)
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
data_loader = DataLoader(
dataset=dataset,
sampler=sampler if batch_sampler is None else None,
batch_sampler=batch_sampler,
collate_fn=pseudo_collate,
collate_fn=collate_fn,
worker_init_fn=init_fn,
**dataloader_cfg)
return data_loader
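# Hypothetical dataloader config selecting a registered collate function;
# when `collate_fn` is omitted, `pseudo_collate` stays the default:
train_dataloader = dict(
    dataset=dict(type='ToyDataset'),          # assumed dataset config
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_size=2,
    collate_fn=dict(type='default_collate'),  # or 'pseudo_collate'
)
# Custom collate functions can be registered and then referenced by name:
from mmengine.dataset import COLLATE_FUNCTIONS

@COLLATE_FUNCTIONS.register_module()
def my_collate(data_batch):  # hypothetical custom collate
    return data_batch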

View File

@ -1080,8 +1080,7 @@ class Visualizer(ManagerMixin):
def add_datasample(self,
name,
image: np.ndarray,
gt_sample: Optional['BaseDataElement'] = None,
pred_sample: Optional['BaseDataElement'] = None,
data_sample: Optional['BaseDataElement'] = None,
draw_gt: bool = True,
draw_pred: bool = True,
show: bool = False,
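# Hypothetical call after the interface change: ground truth and
# prediction now travel together in a single `data_sample` argument,
# toggled by `draw_gt`/`draw_pred`:
visualizer.add_datasample('demo', image, data_sample=data_sample,
                          draw_gt=True, draw_pred=True)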

View File

@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.dataset import default_collate, pseudo_collate
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
class TestDataUtils(TestCase):
def test_pseudo_collate(self):
# Test with a list of dicts containing tensor inputs.
input1 = torch.randn(1, 3, 5)
input2 = torch.randn(1, 3, 5)
label1 = torch.randn(1)
label2 = torch.randn(1)
data_batch = [
dict(inputs=input1, data_sample=label1),
dict(inputs=input2, data_sample=label2)
]
data_batch = pseudo_collate(data_batch)
self.assertTrue(torch.allclose(input1, data_batch['inputs'][0]))
self.assertTrue(torch.allclose(input2, data_batch['inputs'][1]))
self.assertTrue(torch.allclose(label1, data_batch['data_sample'][0]))
self.assertTrue(torch.allclose(label2, data_batch['data_sample'][1]))
# Test with list of dict, where each dict contains a `data_sample`
# of type BaseDataElement
data_sample1 = BaseDataElement(label=torch.tensor(1))
data_sample2 = BaseDataElement(label=torch.tensor(1))
data = [
dict(inputs=input1, data_sample=data_sample1),
dict(inputs=input2, data_sample=data_sample2),
]
data_batch = pseudo_collate(data)
batch_inputs, batch_data_sample = (data_batch['inputs'],
data_batch['data_sample'])
# check batch_inputs
self.assertTrue(is_list_of(batch_inputs, torch.Tensor))
self.assertIs(input1, batch_inputs[0])
self.assertIs(input2, batch_inputs[1])
# check data_sample
self.assertIs(batch_data_sample[0], data_sample1)
self.assertIs(batch_data_sample[1], data_sample2)
# Test with list of tuple, each tuple is a nested dict instance
data_batch = [(dict(
inputs=input1,
data_sample=data_sample1,
value=1,
name='1',
nested=dict(data_sample=data_sample1)),
dict(
inputs=input2,
data_sample=data_sample2,
value=2,
name='2',
nested=dict(data_sample=data_sample2))),
(dict(
inputs=input1,
data_sample=data_sample1,
value=1,
name='1',
nested=dict(data_sample=data_sample1)),
dict(
inputs=input2,
data_sample=data_sample2,
value=2,
name='2',
nested=dict(data_sample=data_sample2)))]
data_batch = pseudo_collate(data_batch)
batch_inputs_0 = data_batch[0]['inputs']
batch_inputs_1 = data_batch[1]['inputs']
batch_data_sample_0 = data_batch[0]['data_sample']
batch_data_sample_1 = data_batch[1]['data_sample']
batch_value_0 = data_batch[0]['value']
batch_value_1 = data_batch[1]['value']
batch_name_0 = data_batch[0]['name']
batch_name_1 = data_batch[1]['name']
batch_nested_0 = data_batch[0]['nested']
batch_nested_1 = data_batch[1]['nested']
self.assertTrue(is_list_of(batch_inputs_0, torch.Tensor))
self.assertTrue(is_list_of(batch_inputs_1, torch.Tensor))
self.assertIs(batch_inputs_0[0], input1)
self.assertIs(batch_inputs_0[1], input1)
self.assertIs(batch_inputs_1[0], input2)
self.assertIs(batch_inputs_1[1], input2)
self.assertIs(batch_data_sample_0[0], data_sample1)
self.assertIs(batch_data_sample_0[1], data_sample1)
self.assertIs(batch_data_sample_1[0], data_sample2)
self.assertIs(batch_data_sample_1[1], data_sample2)
self.assertEqual(batch_value_0, [1, 1])
self.assertEqual(batch_value_1, [2, 2])
self.assertEqual(batch_name_0, ['1', '1'])
self.assertEqual(batch_name_1, ['2', '2'])
self.assertIs(batch_nested_0['data_sample'][0], data_sample1)
self.assertIs(batch_nested_0['data_sample'][1], data_sample1)
self.assertIs(batch_nested_1['data_sample'][0], data_sample2)
self.assertIs(batch_nested_1['data_sample'][1], data_sample2)
def test_default_collate(self):
# `default_collate` shares common logic with `pseudo_collate`, therefore
# only test that it can stack batch tensors and convert int, float or
# ndarray to tensors.
input1 = torch.randn(1, 3, 5)
input2 = torch.randn(1, 3, 5)
data_batch = [(
dict(inputs=input1, value=1, array=np.array(1)),
dict(inputs=input2, value=2, array=np.array(2)),
),
(
dict(inputs=input1, value=1, array=np.array(1)),
dict(inputs=input2, value=2, array=np.array(2)),
)]
data_batch = default_collate(data_batch)
batch_inputs_0 = data_batch[0]['inputs']
batch_inputs_1 = data_batch[1]['inputs']
batch_value_0 = data_batch[0]['value']
batch_value_1 = data_batch[1]['value']
batch_array_0 = data_batch[0]['array']
batch_array_1 = data_batch[1]['array']
self.assertEqual(tuple(batch_inputs_0.shape), (2, 1, 3, 5))
self.assertEqual(tuple(batch_inputs_1.shape), (2, 1, 3, 5))
self.assertTrue(
torch.allclose(batch_inputs_0, torch.stack([input1, input1])))
self.assertTrue(
torch.allclose(batch_inputs_1, torch.stack([input2, input2])))
self.assertTrue(
torch.allclose(batch_value_0,
torch.stack([torch.tensor(1),
torch.tensor(1)])))
self.assertTrue(
torch.allclose(batch_value_1,
torch.stack([torch.tensor(2),
torch.tensor(2)])))
self.assertTrue(
torch.allclose(batch_array_0,
torch.stack([torch.tensor(1),
torch.tensor(1)])))
self.assertTrue(
torch.allclose(batch_array_1,
torch.stack([torch.tensor(2),
torch.tensor(2)])))

View File

@ -41,9 +41,9 @@ class ToyMetric(BaseMetric):
def process(self, data_batch, predictions):
results = [{
'pred': pred.get('pred'),
'label': data['data_sample'].get('label')
} for pred, data in zip(predictions, data_batch)]
'pred': prediction['label'],
'label': prediction['label']
} for prediction in predictions]
self.results.extend(results)
def compute_metrics(self, results: List):
@ -68,8 +68,7 @@ class NonPrefixedMetric(BaseMetric):
"""Evaluator with unassigned `default_prefix` to test the warning
information."""
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
def process(self, data_batch, predictions: Sequence[dict]) -> None:
pass
def compute_metrics(self, results: list) -> dict:
@ -81,12 +80,13 @@ def generate_test_results(size, batch_size, pred, label):
bs_residual = size % batch_size
for i in range(num_batch):
bs = bs_residual if i == num_batch - 1 else batch_size
data_batch = [
dict(
inputs=np.zeros((3, 10, 10)),
data_sample=BaseDataElement(label=label)) for _ in range(bs)
data_batch = {
'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)],
'data_sample': [BaseDataElement(label=label) for _ in range(bs)]
}
predictions = [
BaseDataElement(pred=pred, label=label) for _ in range(bs)
]
predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
yield (data_batch, predictions)
@ -99,9 +99,9 @@ class TestEvaluator(TestCase):
size = 10
batch_size = 4
for data_samples, predictions in generate_test_results(
for data_samples, outputs in generate_test_results(
size, batch_size, pred=1, label=1):
evaluator.process(data_samples, predictions)
evaluator.process(data_samples=outputs, data_batch=data_samples)
metrics = evaluator.evaluate(size=size)
self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
@ -124,9 +124,9 @@ class TestEvaluator(TestCase):
size = 10
batch_size = 4
for data_samples, predictions in generate_test_results(
for data_samples, outputs in generate_test_results(
size, batch_size, pred=1, label=1):
evaluator.process(data_samples, predictions)
evaluator.process(data_samples=outputs, data_batch=data_samples)
metrics = evaluator.evaluate(size=size)
@ -145,9 +145,9 @@ class TestEvaluator(TestCase):
size = 10
batch_size = 4
for data_samples, predictions in generate_test_results(
for data_samples, outputs in generate_test_results(
size, batch_size, pred=1, label=1):
evaluator.process(data_samples, predictions)
evaluator.process(data_samples=outputs, data_batch=data_samples)
with self.assertRaisesRegex(
ValueError,
@ -233,13 +233,22 @@ class TestEvaluator(TestCase):
size = 10
all_data = [
dict(
inputs=np.zeros((3, 10, 10)),
data_sample=BaseDataElement(label=1)) for _ in range(size)
all_data = [dict() for _ in range(size)]
all_predictions = [
BaseDataElement(pred=0, label=1) for _ in range(size)
]
all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
evaluator.offline_evaluate(all_data, all_predictions)
evaluator.offline_evaluate(all_predictions, all_data)
# Test with None data
all_data = None
evaluator.offline_evaluate(all_predictions, all_data)
# Different length of data and predictions will raise an error.
all_data = [dict() for _ in range(9)]
with self.assertRaisesRegex(
AssertionError,
'outputs and data should have the same length'):
evaluator.offline_evaluate(all_predictions, all_data)
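Condensed, the call pattern these evaluator tests rely on (a sketch with a hypothetical `CountMetric`; the keyword names follow the calls above, with predictions passed as `data_samples` and the raw batch optional):

from mmengine.evaluator import BaseMetric, Evaluator

class CountMetric(BaseMetric):  # hypothetical toy metric
    default_prefix = 'toy'

    def process(self, data_batch, predictions):
        self.results.extend(dict(pred=p) for p in predictions)

    def compute_metrics(self, results):
        return dict(num=len(results))

evaluator = Evaluator(CountMetric())
# predictions first; data_batch may be None
evaluator.process(data_samples=[dict(pred=1)], data_batch=None)
metrics = evaluator.evaluate(size=1)  # {'toy/num': 1}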
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
def test_evaluate_cast_cpu(self):
@ -256,11 +265,12 @@ class TestEvaluator(TestCase):
for _ in range(size)
]
all_predictions = [
BaseDataElement(pred=torch.zeros((1, ), device='cuda'))
for _ in range(size)
BaseDataElement(
pred=torch.zeros((1, ), device='cuda'),
label=torch.ones((1, ), device='cuda')) for _ in range(size)
]
for data, pred in zip(all_data, all_predictions):
evaluator.process([data], [pred])
evaluator.process([pred], [data])
def test_results_device(results: List):
for result in results:

View File

@ -19,8 +19,8 @@ class TestDumpResults(TestCase):
def test_process(self):
metric = DumpResults(out_file_path='./results.pkl')
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, predictions)
data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, data_samples)
self.assertEqual(len(metric.results), 1)
self.assertEqual(metric.results[0]['data'][0].device,
torch.device('cpu'))
@ -29,8 +29,8 @@ class TestDumpResults(TestCase):
temp_dir = tempfile.TemporaryDirectory()
path = osp.join(temp_dir.name, 'results.pkl')
metric = DumpResults(out_file_path=path)
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, predictions)
data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
metric.process(None, data_samples)
metric.compute_metrics(metric.results)
self.assertTrue(osp.isfile(path))

View File

@ -22,9 +22,10 @@ class ToyModel(nn.Module):
super().__init__()
self.linear = nn.Linear(2, 1)
def forward(self, batch_inputs, labels, mode='tensor'):
labels = torch.stack(labels)
outputs = self.linear(batch_inputs)
def forward(self, inputs, data_sample, mode='tensor'):
labels = torch.stack(data_sample)
inputs = torch.stack(inputs)
outputs = self.linear(inputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':

View File

@ -28,15 +28,15 @@ class ToyModel(BaseModel):
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
self.conv = nn.Conv2d(3, 1, 1)
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
def forward(self, inputs, data_sample=None, mode='tensor'):
if mode == 'loss':
out = self.conv(batch_inputs)
out = self.conv(inputs)
return dict(loss=out)
elif mode == 'predict':
out = self.conv(batch_inputs)
out = self.conv(inputs)
return out
elif mode == 'tensor':
out = self.conv(batch_inputs)
out = self.conv(inputs)
return out
@ -98,34 +98,34 @@ class TestBaseModel(TestCase):
model = ToyModel()
optimizer = SGD(model.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optimizer)
inputs = torch.randn(3, 1, 1)
data = dict(inputs=inputs)
inputs = torch.randn(1, 3, 1, 1)
data = dict(inputs=inputs, data_sample=None)
# initiate grad.
# model.conv.weight.grad = torch.randn(1, 3, 1, 1)
log_vars = model.train_step([data], optim_wrapper)
log_vars = model.train_step(data, optim_wrapper)
self.assertIsNotNone(model.conv.weight.grad)
self.assertIsInstance(log_vars['loss'], torch.Tensor)
def test_val_step(self):
inputs = torch.randn(3, 1, 1)
data = dict(inputs=inputs)
inputs = torch.randn(1, 3, 1, 1)
data = dict(inputs=inputs, data_sample=None)
model = ToyModel()
out = model.val_step([data])
out = model.val_step(data)
self.assertIsInstance(out, torch.Tensor)
def test_test_step(self):
inputs = torch.randn(3, 1, 1)
data = dict(inputs=inputs)
inputs = torch.randn(1, 3, 1, 1)
data = dict(inputs=inputs, data_sample=None)
model = ToyModel()
out = model.val_step([data])
out = model.val_step(data)
self.assertIsInstance(out, torch.Tensor)
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
def test_cuda(self):
inputs = torch.randn(3, 1, 1).cuda()
data = dict(inputs=inputs)
inputs = torch.randn(1, 3, 1, 1).cuda()
data = dict(inputs=inputs, data_sample=None)
model = ToyModel().cuda()
out = model.val_step([data])
out = model.val_step(data)
self.assertEqual(out.device.type, 'cuda')
model = NestedModel()
@ -139,10 +139,10 @@ class TestBaseModel(TestCase):
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
def test_to(self):
inputs = torch.randn(3, 1, 1).to('cuda:0')
data = dict(inputs=inputs)
inputs = torch.randn(1, 3, 1, 1).to('cuda:0')
data = dict(inputs=inputs, data_sample=None)
model = ToyModel().to(torch.cuda.current_device())
out = model.val_step([data])
out = model.val_step(data)
self.assertEqual(out.device.type, 'cuda')
model = NestedModel()
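The new step contract these model tests encode, as a standalone sketch (hypothetical `TinyModel`; `val_step` now takes the collated dict directly, and the `data_preprocessor` output is unpacked into `forward` as keyword arguments):

import torch
import torch.nn as nn
from mmengine.model import BaseModel

class TinyModel(BaseModel):  # hypothetical, mirrors ToyModel above
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 1)

    def forward(self, inputs, data_sample=None, mode='tensor'):
        out = self.conv(inputs)
        if mode == 'loss':
            return dict(loss=out.sum())
        return out

model = TinyModel()
data = dict(inputs=torch.randn(1, 3, 1, 1), data_sample=None)
out = model.val_step(data)  # was: model.val_step([data])
assert isinstance(out, torch.Tensor)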

View File

@ -16,54 +16,73 @@ class TestBaseDataPreprocessor(TestCase):
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
def test_forward(self):
# Test cpu forward with list of data samples.
base_data_preprocessor = BaseDataPreprocessor()
input1 = torch.randn(1, 3, 5)
input2 = torch.randn(1, 3, 5)
label1 = torch.randn(1)
label2 = torch.randn(1)
data = [
dict(inputs=input1, data_sample=label1),
dict(inputs=input2, data_sample=label2)
]
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output['data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
assert_allclose(input1, batch_inputs[0])
assert_allclose(input2, batch_inputs[1])
assert_allclose(label1, batch_labels[0])
assert_allclose(label2, batch_labels[1])
# Test with stacked batch inputs and a list of batch data samples
data = dict(
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
output = base_data_preprocessor(data)['inputs']
self.assertTrue(torch.is_floating_point(output))
# Test cuda forward
if torch.cuda.is_available():
# Test with list of data samples.
base_data_preprocessor = base_data_preprocessor.cuda()
batch_inputs, batch_labels = base_data_preprocessor(data)
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
base_data_preprocessor = base_data_preprocessor.cpu()
batch_inputs, batch_labels = base_data_preprocessor(data)
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cpu')
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
batch_inputs, batch_labels = base_data_preprocessor(data)
output = base_data_preprocessor(data)
batch_inputs, batch_labels = output['inputs'], output[
'data_sample']
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
# device of `base_data_preprocessor` is cuda, output should be
# cuda tensor.
self.assertEqual(batch_inputs.device.type, 'cuda')
self.assertEqual(batch_labels[0].device.type, 'cuda')
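In short, what the assertions above establish (a sketch): `BaseDataPreprocessor` only casts the collated dict to the module's device and returns the same structure it was given.

import torch
from mmengine.model import BaseDataPreprocessor

preprocessor = BaseDataPreprocessor()
data = dict(inputs=[torch.randn(1, 3, 5)], data_sample=[torch.randn(1)])
output = preprocessor(data)  # same keys, tensors on preprocessor._device
assert output['inputs'][0].device == preprocessor._device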
class TestImgataPreprocessor(TestBaseDataPreprocessor):
class TestImgDataPreprocessor(TestBaseDataPreprocessor):
def test_init(self):
# initiate model without `data_preprocessor`
# Initiate processor without arguments
data_processor = ImgDataPreprocessor()
self.assertFalse(data_processor.channel_conversion)
self.assertFalse(data_processor._channel_conversion)
self.assertFalse(hasattr(data_processor, 'mean'))
self.assertFalse(hasattr(data_processor, 'std'))
self.assertEqual(data_processor.pad_size_divisor, 1)
assert_allclose(data_processor.pad_value, torch.tensor(0))
# initiate model with data_preprocessor` and feat keys
# Initiate model with bgr2rgb, mean, std, etc.
data_processor = ImgDataPreprocessor(
bgr_to_rgb=True,
mean=[0, 0, 0],
@ -71,7 +90,7 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
pad_size_divisor=16,
pad_value=10)
self.assertTrue(data_processor._enable_normalize)
self.assertTrue(data_processor.channel_conversion, True)
self.assertTrue(data_processor._channel_conversion)
assert_allclose(data_processor.mean,
torch.tensor([0, 0, 0]).view(-1, 1, 1))
assert_allclose(data_processor.std,
@ -113,10 +132,11 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
inputs2 = torch.randn(3, 15, 15)
data_sample1 = InstanceData(bboxes=torch.randn(5, 4))
data_sample2 = InstanceData(bboxes=torch.randn(5, 4))
data = [
dict(inputs=inputs1.clone(), data_sample=data_sample1.clone()),
dict(inputs=inputs2.clone(), data_sample=data_sample2.clone())
]
data = dict(
inputs=[inputs1.clone(), inputs2.clone()],
data_sample=[data_sample1.clone(),
data_sample2.clone()])
std = torch.tensor([1, 2, 3]).view(-1, 1, 1)
target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std
@ -126,7 +146,8 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10)
target_inputs = [target_inputs1, target_inputs2]
inputs, data_samples = data_preprocessor(data, True)
output = data_preprocessor(data, True)
inputs, data_samples = output['inputs'], output['data_sample']
self.assertTrue(torch.is_floating_point(inputs))
target_data_samples = [data_sample1, data_sample2]
@ -147,7 +168,8 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10)
target_inputs = [target_inputs1, target_inputs2]
inputs, data_samples = data_preprocessor(data, True)
output = data_preprocessor(data, True)
inputs, data_samples = output['inputs'], output['data_sample']
self.assertTrue(torch.is_floating_point(inputs))
target_data_samples = [data_sample1, data_sample2]
@ -159,22 +181,53 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
# Test that a gray image with a 3-dim mean raises an error
data_preprocessor = ImgDataPreprocessor(
mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5))
data = [
dict(inputs=torch.ones(10, 10)),
dict(inputs=torch.ones(10, 10))
]
data = dict(
inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None)
with self.assertRaisesRegex(AssertionError,
'If the mean has 3 values'):
data_preprocessor(data)
data = [
dict(inputs=torch.ones(1, 10, 10)),
dict(inputs=torch.ones(1, 10, 10))
]
data = dict(
inputs=[torch.ones(1, 10, 10), torch.ones(1, 10, 10)], data_sample=None)
with self.assertRaisesRegex(AssertionError,
'If the mean has 3 values'):
data_preprocessor(data)
# Test stacked batch inputs and batch data samples
data_preprocessor = ImgDataPreprocessor(
mean=(127.5, 127.5, 127.5),
std=(127.5, 127.5, 127.5),
rgb_to_bgr=True,
pad_size_divisor=16)
_batch_inputs = torch.randn(2, 3, 10, 10)
_batch_labels = [torch.randn(1), torch.randn(1)]
data = dict(inputs=_batch_inputs, data_sample=_batch_labels)
output = data_preprocessor(data)
inputs, data_samples = output['inputs'], output['data_sample']
target_batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
target_batch_inputs = (target_batch_inputs - 127.5) / 127.5
target_batch_inputs = F.pad(target_batch_inputs, (0, 6, 0, 6), value=0)
self.assertEqual(inputs.shape, torch.Size([2, 3, 16, 16]))
self.assertTrue(torch.is_floating_point(inputs))
assert_allclose(target_batch_inputs, inputs)
# Test batch inputs without channel conversion and padding
data_preprocessor = ImgDataPreprocessor(
mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5))
_batch_inputs = torch.randn(2, 3, 10, 10)
_batch_labels = [torch.randn(1), torch.randn(1)]
data = dict(inputs=_batch_inputs, data_sample=_batch_labels)
output = data_preprocessor(data)
inputs, data_samples = output['inputs'], output['data_sample']
target_batch_inputs = (_batch_inputs - 127.5) / 127.5
self.assertEqual(inputs.shape, torch.Size([2, 3, 10, 10]))
self.assertTrue(torch.is_floating_point(inputs))
assert_allclose(target_batch_inputs, inputs)
# Test empty `data_sample`
data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())]
data_preprocessor(data, True)
data = dict(
inputs=[inputs1.clone(), inputs2.clone()], data_sample=None)
output = data_preprocessor(data, True)
inputs, data_samples = output['inputs'], output['data_sample']
self.assertIsNone(data_samples)
self.assertTrue(torch.is_floating_point(inputs))
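Condensed, the input contract these cases pin down (a sketch): under `inputs`, `ImgDataPreprocessor` accepts either a list of CHW tensors, which are normalized, padded, and stacked into an NCHW batch, or an already-stacked NCHW tensor, which is normalized as-is; either way the result comes back as a dict.

import torch
from mmengine.model import ImgDataPreprocessor

processor = ImgDataPreprocessor(mean=(127.5, 127.5, 127.5),
                                std=(127.5, 127.5, 127.5))

# list of CHW tensors -> normalized and stacked into NCHW
data = dict(inputs=[torch.ones(3, 10, 10), torch.ones(3, 10, 10)],
            data_sample=None)
output = processor(data)
assert output['inputs'].shape == (2, 3, 10, 10)

# pre-stacked NCHW tensor -> normalized as-is
data = dict(inputs=torch.ones(2, 3, 10, 10), data_sample=None)
output = processor(data)
assert output['inputs'].shape == (2, 3, 10, 10)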

View File

@ -9,9 +9,10 @@ import torch.nn as nn
from torch.optim import SGD
from mmengine.dist import all_gather
from mmengine.model import (BaseModel, MMDistributedDataParallel,
from mmengine.model import (BaseDataPreprocessor, BaseModel,
ExponentialMovingAverage,
MMDistributedDataParallel,
MMSeparateDistributedDataParallel)
from mmengine.model.averaged_model import ExponentialMovingAverage
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
from mmengine.testing import assert_allclose
from mmengine.testing._internal import MultiProcessTestCase
@ -22,15 +23,22 @@ if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
from mmengine.model import MMFullyShardedDataParallel # noqa: F401
class ToyDataPreprocessor(BaseDataPreprocessor):
def forward(self, data: dict, training: bool = False):
self.called = True
return super().forward(data, training)
class ToyModel(BaseModel):
def __init__(self):
super().__init__()
super().__init__(data_preprocessor=ToyDataPreprocessor())
self.conv1 = nn.Conv2d(3, 1, 1)
self.conv2 = nn.Conv2d(1, 1, 1)
def forward(self, x, data_samples=None, mode='tensor'):
x = self.conv1(x)
def forward(self, inputs, data_sample=None, mode='tensor'):
x = self.conv1(inputs)
x = self.conv2(x)
if mode == 'loss':
return dict(loss=x)
@ -48,10 +56,10 @@ class ComplexModel(BaseModel):
self.conv2 = nn.Conv2d(3, 1, 1)
def train_step(self, data, optim_wrapper):
batch_inputs, _ = self.data_preprocessor(data)
loss1 = self.conv1(batch_inputs)
inputs = self.data_preprocessor(data)['inputs']
loss1 = self.conv1(inputs)
optim_wrapper['optim_wrapper1'].update_params(loss1)
loss2 = self.conv2(batch_inputs)
loss2 = self.conv2(inputs)
optim_wrapper['optim_wrapper2'].update_params(loss2)
return dict(loss1=loss1, loss2=loss2)
@ -82,9 +90,9 @@ class TestDistributedDataParallel(MultiProcessTestCase):
optimizer = SGD(ddp_model.parameters(), lr=0)
optim_wrapper = AmpOptimWrapper(
optimizer=optimizer, accumulative_counts=3)
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss']
inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs, data_sample=None)
res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
self.assertIs(res.dtype, torch.float16)
grad = ddp_model.module.conv1.weight.grad
all_grads = all_gather(grad)
@ -92,10 +100,10 @@ class TestDistributedDataParallel(MultiProcessTestCase):
assert_allclose(all_grads[0], all_grads[1])
# Gradient accumulation
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
ddp_model.train_step(data, optim_wrapper=optim_wrapper)
# Test update params and clean grads.
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
ddp_model.train_step(data, optim_wrapper=optim_wrapper)
grad = ddp_model.module.conv1.weight.grad
all_grads = all_gather(grad)
assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
@ -105,20 +113,22 @@ class TestDistributedDataParallel(MultiProcessTestCase):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel()
ddp_model = MMDistributedDataParallel(module=model)
inputs = torch.randn(3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=None)
# Test get predictions.
predictions = ddp_model.val_step([data])
predictions = ddp_model.val_step(data)
self.assertIsInstance(predictions, torch.Tensor)
self.assertTrue(model.data_preprocessor.called)
def test_test_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel()
ddp_model = MMDistributedDataParallel(module=model)
inputs = torch.randn(3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
predictions = ddp_model.test_step([data])
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=None)
predictions = ddp_model.test_step(data)
self.assertIsInstance(predictions, torch.Tensor)
self.assertTrue(model.data_preprocessor.called)
def _init_dist_env(self, rank, world_size):
"""Initialize the distributed environment."""
@ -157,30 +167,30 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
optim_wrapper2 = OptimWrapper(optimizer2, 1)
optim_wrapper_dict = OptimWrapperDict(
optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2)
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs)
inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
data = dict(inputs=inputs, data_sample=None)
# Automatically sync grads of `optim_wrapper1` since
# `cumulative_iters` = 1
ddp_model.train()
self.assertTrue(ddp_model.training)
ddp_model.train_step([data], optim_wrapper=optim_wrapper_dict)
ddp_model.train_step(data, optim_wrapper=optim_wrapper_dict)
def test_val_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ComplexModel()
ddp_model = MMSeparateDistributedDataParallel(model)
data = torch.randn(3, 1, 1)
data = torch.randn(1, 3, 1, 1)
# Test get predictions.
ddp_model.eval()
self.assertFalse(ddp_model.training)
predictions = ddp_model.val_step([data])
predictions = ddp_model.val_step(data)
self.assertEqual(predictions, 1)
def test_test_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ComplexModel()
ddp_model = MMSeparateDistributedDataParallel(model)
data = torch.randn(3, 1, 1)
data = torch.randn(1, 3, 1, 1)
# Test get predictions.
ddp_model.eval()
self.assertFalse(ddp_model.training)
@ -225,27 +235,27 @@ class TestMMFullyShardedDataParallel(MultiProcessTestCase):
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
optimizer = SGD(fsdp_model.parameters(), lr=0)
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
inputs = torch.randn(3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=[inputs], data_sample=MagicMock())
fsdp_model.train()
self.assertTrue(fsdp_model.training)
fsdp_model.train_step([data], optim_wrapper=optim_wrapper)
fsdp_model.train_step(data, optim_wrapper=optim_wrapper)
def test_val_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel()
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
inputs = torch.randn(3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=[inputs], data_sample=MagicMock())
# Test get predictions.
predictions = fsdp_model.val_step([data])
predictions = fsdp_model.val_step(data)
self.assertIsInstance(predictions, torch.Tensor)
def test_test_step(self):
self._init_dist_env(self.rank, self.world_size)
model = ToyModel()
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
inputs = torch.randn(3, 1, 1) * self.rank * 255
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
data = dict(inputs=inputs, data_sample=MagicMock())
predictions = fsdp_model.test_step([data])
predictions = fsdp_model.test_step(data)
self.assertIsInstance(predictions, torch.Tensor)

View File

@ -14,7 +14,7 @@ from torch.optim import SGD, Adam
from torch.utils.data import DataLoader, Dataset
from mmengine.config import Config
from mmengine.dataset import DefaultSampler
from mmengine.dataset import COLLATE_FUNCTIONS, DefaultSampler, pseudo_collate
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, ParamSchedulerHook,
@ -44,15 +44,18 @@ class ToyModel(BaseModel):
self.linear1 = nn.Linear(2, 2)
self.linear2 = nn.Linear(2, 1)
def forward(self, batch_inputs, labels, mode='tensor'):
labels = torch.stack(labels)
outputs = self.linear1(batch_inputs)
def forward(self, inputs, data_sample, mode='tensor'):
if isinstance(inputs, list):
inputs = torch.stack(inputs)
if isinstance(data_sample, list):
data_sample = torch.stack(data_sample)
outputs = self.linear1(inputs)
outputs = self.linear2(outputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
loss = (data_sample - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
@ -74,15 +77,16 @@ class ToySyncBNModel(BaseModel):
self.conv = nn.Conv2d(3, 8, 2)
self.bn = nn.SyncBatchNorm(8)
def forward(self, batch_inputs, labels, mode='tensor'):
labels = torch.stack(labels)
outputs = self.conv(batch_inputs)
def forward(self, inputs, data_sample, mode='tensor'):
data_sample = torch.stack(data_sample)
inputs = torch.stack(inputs)
outputs = self.conv(inputs)
outputs = self.bn(outputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
loss = (data_sample - outputs).sum()
outputs = dict(loss=loss)
return outputs
elif mode == 'predict':
@ -98,24 +102,25 @@ class ToyGANModel(BaseModel):
self.linear1 = nn.Linear(2, 1)
self.linear2 = nn.Linear(2, 1)
def forward(self, batch_inputs, labels, mode='tensor'):
labels = torch.stack(labels)
output1 = self.linear1(batch_inputs)
output2 = self.linear2(batch_inputs)
def forward(self, inputs, data_sample, mode='tensor'):
data_sample = torch.stack(data_sample)
inputs = torch.stack(inputs)
output1 = self.linear1(inputs)
output2 = self.linear2(inputs)
if mode == 'tensor':
return output1, output2
elif mode == 'loss':
loss1 = (labels - output1).sum()
loss2 = (labels - output2).sum()
loss1 = (data_sample - output1).sum()
loss2 = (data_sample - output2).sum()
outputs = dict(linear1=loss1, linear2=loss2)
return outputs
elif mode == 'predict':
return output1, output2
def train_step(self, data, optim_wrapper):
batch_inputs, batch_labels = self.data_preprocessor(data)
loss = self(batch_inputs, batch_labels, mode='loss')
data = self.data_preprocessor(data)
loss = self(**data, mode='loss')
optim_wrapper['linear1'].update_params(loss['linear1'])
optim_wrapper['linear2'].update_params(loss['linear2'])
return loss
@ -193,7 +198,7 @@ class ToyMetric1(BaseMetric):
super().__init__(collect_device=collect_device)
self.dummy_metrics = dummy_metrics
def process(self, data_samples, predictions):
def process(self, data_batch, predictions):
result = {'acc': 1}
self.results.append(result)
@ -208,7 +213,7 @@ class ToyMetric2(BaseMetric):
super().__init__(collect_device=collect_device)
self.dummy_metrics = dummy_metrics
def process(self, data_samples, predictions):
def process(self, data_batch, predictions):
result = {'acc': 1}
self.results.append(result)
@ -335,7 +340,12 @@ class ToyEvaluator(Evaluator):
def collate_fn(data_batch):
return data_batch
return pseudo_collate(data_batch)
@COLLATE_FUNCTIONS.register_module()
def custom_collate(data_batch, pad_value):
return pseudo_collate(data_batch)
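For context, roughly how the runner turns such a config entry into a working collate function (a sketch; the actual logic lives in `Runner.build_dataloader`): the `type` key is looked up in the `COLLATE_FUNCTIONS` registry and the remaining keys are bound with `functools.partial`.

from functools import partial
from mmengine.dataset import COLLATE_FUNCTIONS

collate_cfg = dict(type='custom_collate', pad_value=100)
collate_type = collate_cfg.pop('type')
collate_fn = partial(COLLATE_FUNCTIONS.get(collate_type), **collate_cfg)
# Omitting `pad_value` leaves a required argument unbound, so the
# dataloader raises a TypeError on its first batch (checked in test 11
# below).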
class TestRunner(TestCase):
@ -1428,7 +1438,7 @@ class TestRunner(TestCase):
val_batch_idx_targets):
self.assertEqual(result, target)
# 5. test dynamic interval in IterBasedTrainLoop
# 5.1 test dynamic interval in IterBasedTrainLoop
max_iters = 12
interval = 5
dynamic_intervals = [(11, 2)]
@ -1465,7 +1475,7 @@ class TestRunner(TestCase):
for result, target, in zip(val_interval_results, val_interval_targets):
self.assertEqual(result, target)
# 6. test dynamic interval in EpochBasedTrainLoop
# 5.2 test dynamic interval in EpochBasedTrainLoop
max_epochs = 12
interval = 5
dynamic_intervals = [(11, 2)]
@ -1579,6 +1589,30 @@ class TestRunner(TestCase):
runner = runner.from_cfg(cfg)
runner.train()
# 10.1 Test build dataloader with default collate function
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train10.1'
cfg.train_dataloader.update(collate_fn=dict(type='default_collate'))
runner = Runner.from_cfg(cfg)
runner.train()
# 10.2 Test build dataloader with custom collate function
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train10.2'
cfg.train_dataloader.update(
collate_fn=dict(type='custom_collate', pad_value=100))
runner = Runner.from_cfg(cfg)
runner.train()
# 11 Test that building a dataloader without the required arguments of
# the collate function raises a TypeError.
with self.assertRaises(TypeError):
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_train11'
cfg.train_dataloader.update(collate_fn=dict(type='custom_collate'))
runner = Runner.from_cfg(cfg)
runner.train()
def test_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_val1'
@ -1890,8 +1924,6 @@ class TestRunner(TestCase):
ckpt = torch.load(path)
self.assertEqual(ckpt['meta']['epoch'], 3)
self.assertEqual(ckpt['meta']['iter'], 12)
self.assertEqual(ckpt['meta']['dataset_meta'],
runner.train_dataloader.dataset.metainfo)
self.assertEqual(ckpt['meta']['experiment_name'],
runner.experiment_name)
self.assertEqual(ckpt['meta']['seed'], runner.seed)