mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor data flow to make the interface more natural (#468)
* [Refactor]: modify interface of Visualizer.add_datasample (#365) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323) * acollate data in dataloader * fix docstring * refine comment * fix as comment * refactor default collate and psedo collate * foramt test file * fix docstring * fix as comment * rename elem to data_item * minor fix * fix as comment * [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process is a `dict` (#360) * refine evaluator and metric * compatible with new default collate * replace default collate with pseudo * Handle data_batch in metric * fix unit test * fix unit test * fix unit test * minor refine * make data_batch optional make data_batch optional * rename outputs to predictions * fix ut * rename predictions to outputs * fix docstring * fix docstring * fix unit test * make outputs and data_batch to kwargs * fix unit test * keep signature of metric * fix ut * rename pred_sample arguments to data_sample(Visualizer) * fix loop and ut * [refactor]: Refactor model dataflow (#398) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * refactor model data flow tmp_commt tmp commit * make val_cfg and test_cfg optional * roll back runner * pass test mmdet * fix as comment fix as comment fix ci in DataPreprocessor * fix ut * fix ut * fix rebase main * [Fix]: Fix test val ddp (#462) * [Fix] Fix docstring and type hint of data flow (#463) * Fix docstring of data flow * change signature of hook * fix unit test * resolve conflicts * fix lint
This commit is contained in:
parent
7e1d7af2d9
commit
8770c6c7fc
@ -228,10 +228,8 @@ class CheckInvalidLossHook(Hook):
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
if self.every_n_train_iters(runner, self.interval):
|
||||
assert torch.isfinite(outputs['loss']),\
|
||||
|
@ -2,10 +2,11 @@
|
||||
from .base_dataset import BaseDataset, Compose, force_full_init
|
||||
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
|
||||
from .sampler import DefaultSampler, InfiniteSampler
|
||||
from .utils import pseudo_collate, worker_init_fn
|
||||
from .utils import (COLLATE_FUNCTIONS, default_collate, pseudo_collate,
|
||||
worker_init_fn)
|
||||
|
||||
__all__ = [
|
||||
'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset',
|
||||
'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler',
|
||||
'worker_init_fn', 'pseudo_collate'
|
||||
'worker_init_fn', 'pseudo_collate', 'COLLATE_FUNCTIONS', 'default_collate'
|
||||
]
|
||||
|
@ -1,11 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import random
|
||||
from typing import Sequence
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data._utils.collate import \
|
||||
default_collate as torch_default_collate
|
||||
|
||||
DATA_BATCH = Sequence[dict]
|
||||
from mmengine.registry import Registry
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
COLLATE_FUNCTIONS = Registry('Collate Functions')
|
||||
|
||||
|
||||
def worker_init_fn(worker_id: int, num_workers: int, rank: int,
|
||||
@ -28,16 +33,124 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int,
|
||||
torch.manual_seed(worker_seed)
|
||||
|
||||
|
||||
def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH:
|
||||
"""The default behavior of dataloader is to merge a list of samples to form
|
||||
a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` does
|
||||
nothing just returns ``data_batch``.
|
||||
@COLLATE_FUNCTIONS.register_module()
|
||||
def pseudo_collate(data_batch: Sequence) -> Any:
|
||||
"""Convert list of data sampled from dataset into a batch of data, of which
|
||||
type consistent with the type of each data_itement in ``data_batch``.
|
||||
|
||||
The default behavior of dataloader is to merge a list of samples to form
|
||||
a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
|
||||
will not stack tensors to batch tensors, and convert int, float, ndarray to
|
||||
tensors.
|
||||
|
||||
This code is referenced from:
|
||||
`Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_. # noqa: E501
|
||||
Args:
|
||||
data_batch (Sequence[dict]): Batch of data from
|
||||
dataloader.
|
||||
data_batch (Sequence): Batch of data from dataloader.
|
||||
|
||||
Returns:
|
||||
Sequence[dict]: Return input ``data_batch``.
|
||||
Any: Transversed Data in the same format as the data_itement of
|
||||
``data_batch``.
|
||||
"""
|
||||
return data_batch
|
||||
data_item = data_batch[0]
|
||||
data_item_type = type(data_item)
|
||||
if isinstance(data_item, (str, bytes)):
|
||||
return data_batch
|
||||
elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
|
||||
# named tuple
|
||||
return data_item_type(*(pseudo_collate(samples)
|
||||
for samples in zip(*data_batch)))
|
||||
elif isinstance(data_item, Sequence):
|
||||
# check to make sure that the data_itements in batch have
|
||||
# consistent size
|
||||
it = iter(data_batch)
|
||||
data_item_size = len(next(it))
|
||||
if not all(len(data_item) == data_item_size for data_item in it):
|
||||
raise RuntimeError(
|
||||
'each data_itement in list of batch should be of equal size')
|
||||
transposed = list(zip(*data_batch))
|
||||
|
||||
if isinstance(data_item, tuple):
|
||||
return [pseudo_collate(samples)
|
||||
for samples in transposed] # Compat with Pytorch.
|
||||
else:
|
||||
try:
|
||||
return data_item_type(
|
||||
[pseudo_collate(samples) for samples in transposed])
|
||||
except TypeError:
|
||||
# The sequence type may not support `__init__(iterable)`
|
||||
# (e.g., `range`).
|
||||
return [pseudo_collate(samples) for samples in transposed]
|
||||
elif isinstance(data_item, Mapping):
|
||||
return data_item_type({
|
||||
key: pseudo_collate([d[key] for d in data_batch])
|
||||
for key in data_item
|
||||
})
|
||||
else:
|
||||
return data_batch
|
||||
|
||||
|
||||
@COLLATE_FUNCTIONS.register_module()
|
||||
def default_collate(data_batch: Sequence) -> Any:
|
||||
"""Convert list of data sampled from dataset into a batch of data, of which
|
||||
type consistent with the type of each data_itement in ``data_batch``.
|
||||
|
||||
Different from :func:`pseudo_collate`, ``default_collate`` will stack
|
||||
tensor contained in ``data_batch`` into a batched tensor with the
|
||||
first dimension batch size, and then move input tensor to the target
|
||||
device.
|
||||
|
||||
Different from ``default_collate`` in pytorch, ``default_collate`` will
|
||||
not process ``BaseDataElement``.
|
||||
|
||||
This code is referenced from:
|
||||
`Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_. # noqa: E501
|
||||
|
||||
Note:
|
||||
``default_collate`` only accept input tensor with the same shape.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
Any: Data in the same format as the data_itement of ``data_batch``, of which
|
||||
tensors have been stacked, and ndarray, int, float have been
|
||||
converted to tensors.
|
||||
"""
|
||||
data_item = data_batch[0]
|
||||
data_item_type = type(data_item)
|
||||
|
||||
if isinstance(data_item, (BaseDataElement, str, bytes)):
|
||||
return data_batch
|
||||
elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
|
||||
# named_tuple
|
||||
return data_item_type(*(default_collate(samples)
|
||||
for samples in zip(*data_batch)))
|
||||
elif isinstance(data_item, Sequence):
|
||||
# check to make sure that the data_itements in batch have
|
||||
# consistent size
|
||||
it = iter(data_batch)
|
||||
data_item_size = len(next(it))
|
||||
if not all(len(data_item) == data_item_size for data_item in it):
|
||||
raise RuntimeError(
|
||||
'each data_itement in list of batch should be of equal size')
|
||||
transposed = list(zip(*data_batch))
|
||||
|
||||
if isinstance(data_item, tuple):
|
||||
return [default_collate(samples)
|
||||
for samples in transposed] # Compat with Pytorch.
|
||||
else:
|
||||
try:
|
||||
return data_item_type(
|
||||
[default_collate(samples) for samples in transposed])
|
||||
except TypeError:
|
||||
# The sequence type may not support `__init__(iterable)`
|
||||
# (e.g., `range`).
|
||||
return [default_collate(samples) for samples in transposed]
|
||||
elif isinstance(data_item, Mapping):
|
||||
return data_item_type({
|
||||
key: default_collate([d[key] for d in data_batch])
|
||||
for key in data_item
|
||||
})
|
||||
else:
|
||||
return torch_default_collate(data_batch)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Iterator, List, Optional, Sequence, Union
|
||||
from typing import Any, Iterator, List, Optional, Sequence, Union
|
||||
|
||||
from mmengine.dataset import pseudo_collate
|
||||
from mmengine.registry import EVALUATOR, METRICS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from .metric import BaseMetric
|
||||
@ -37,34 +38,26 @@ class Evaluator:
|
||||
for metric in self.metrics:
|
||||
metric.dataset_meta = dataset_meta
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[BaseDataElement]):
|
||||
def process(self,
|
||||
data_samples: Sequence[BaseDataElement],
|
||||
data_batch: Optional[Any] = None):
|
||||
"""Convert ``BaseDataSample`` to dict and invoke process method of each
|
||||
metric.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[BaseDataElement]): A batch of outputs from
|
||||
the model.
|
||||
data_samples (Sequence[BaseDataElement]): predictions of the model,
|
||||
and the ground truth of the validation set.
|
||||
data_batch (Any, optional): A batch of data from the dataloader.
|
||||
"""
|
||||
_data_batch = []
|
||||
for data in data_batch:
|
||||
if isinstance(data['data_sample'], BaseDataElement):
|
||||
_data_batch.append(
|
||||
dict(
|
||||
inputs=data['inputs'],
|
||||
data_sample=data['data_sample'].to_dict()))
|
||||
_data_samples = []
|
||||
for data_sample in data_samples:
|
||||
if isinstance(data_sample, BaseDataElement):
|
||||
_data_samples.append(data_sample.to_dict())
|
||||
else:
|
||||
_data_batch.append(data)
|
||||
_predictions = []
|
||||
for pred in predictions:
|
||||
if isinstance(pred, BaseDataElement):
|
||||
_predictions.append(pred.to_dict())
|
||||
else:
|
||||
_predictions.append(pred)
|
||||
_data_samples.append(data_sample)
|
||||
|
||||
for metric in self.metrics:
|
||||
metric.process(_data_batch, _predictions)
|
||||
metric.process(data_batch, _data_samples)
|
||||
|
||||
def evaluate(self, size: int) -> dict:
|
||||
"""Invoke ``evaluate`` method of each metric and collect the metrics
|
||||
@ -97,20 +90,26 @@ class Evaluator:
|
||||
return metrics
|
||||
|
||||
def offline_evaluate(self,
|
||||
data: Sequence,
|
||||
predictions: Sequence,
|
||||
data_samples: Sequence,
|
||||
data: Optional[Sequence] = None,
|
||||
chunk_size: int = 1):
|
||||
"""Offline evaluate the dumped predictions on the given data .
|
||||
|
||||
Args:
|
||||
data (Sequence): All data of the validation set.
|
||||
predictions (Sequence): All predictions of the model on the
|
||||
validation set.
|
||||
data_samples (Sequence): All predictions and ground truth of the
|
||||
model and the validation set.
|
||||
data (Sequence, optional): All data of the validation set.
|
||||
chunk_size (int): The number of data samples and predictions to be
|
||||
processed in a batch.
|
||||
"""
|
||||
|
||||
# support chunking iterable objects
|
||||
if data is not None:
|
||||
assert len(data_samples) == len(data), (
|
||||
'outputs and data should have the same length, but got '
|
||||
f'outputs length: {len(data_samples)} '
|
||||
f'data length: {len(data)}')
|
||||
|
||||
def get_chunks(seq: Iterator, chunk_size=1):
|
||||
stop = False
|
||||
while not stop:
|
||||
@ -125,9 +124,11 @@ class Evaluator:
|
||||
yield chunk
|
||||
|
||||
size = 0
|
||||
for data_chunk, pred_chunk in zip(
|
||||
get_chunks(iter(data), chunk_size),
|
||||
get_chunks(iter(predictions), chunk_size)):
|
||||
size += len(data_chunk)
|
||||
self.process(data_chunk, pred_chunk)
|
||||
for output_chunk in get_chunks(iter(data_samples), chunk_size):
|
||||
if data is not None:
|
||||
data_chunk = pseudo_collate(data[size:size + chunk_size])
|
||||
else:
|
||||
data_chunk = None
|
||||
size += len(output_chunk)
|
||||
self.process(output_chunk, data_chunk)
|
||||
return self.evaluate(size)
|
||||
|
@ -58,15 +58,14 @@ class BaseMetric(metaclass=ABCMeta):
|
||||
self._dataset_meta = dataset_meta
|
||||
|
||||
@abstractmethod
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
|
||||
"""Process one batch of data samples and predictions. The processed
|
||||
results should be stored in ``self.results``, which will be used to
|
||||
compute the metrics when all batches have been processed.
|
||||
|
||||
Args:
|
||||
data_batch (Sequence[dict]): A batch of data from the dataloader.
|
||||
predictions (Sequence[dict]): A batch of outputs from
|
||||
data_batch (Any): A batch of data from the dataloader.
|
||||
data_samples (Sequence[dict]): A batch of outputs from
|
||||
the model.
|
||||
"""
|
||||
|
||||
@ -146,8 +145,7 @@ class DumpResults(BaseMetric):
|
||||
raise ValueError('The output file must be a pkl file.')
|
||||
self.out_file_path = out_file_path
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
|
||||
"""transfer tensors in predictions to CPU."""
|
||||
self.results.extend(_to_cpu(predictions))
|
||||
|
||||
|
@ -12,7 +12,7 @@ from mmengine.registry import HOOKS
|
||||
from mmengine.utils import is_list_of, is_seq_of
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -470,9 +470,8 @@ class CheckpointHook(Hook):
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model. Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
"""
|
||||
if self.by_epoch:
|
||||
return
|
||||
|
@ -4,10 +4,9 @@ from typing import Optional, Sequence, Union
|
||||
import torch
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -38,18 +37,15 @@ class EmptyCacheHook(Hook):
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[dict,
|
||||
Sequence[BaseDataElement]]] = None,
|
||||
outputs: Optional[Union[dict, Sequence]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Empty cache after an iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (dict or sequence, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if self._do_after_iter:
|
||||
|
@ -1,9 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
from mmengine.structures import BaseDataElement
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
class Hook:
|
||||
@ -184,8 +182,7 @@ class Hook:
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
"""
|
||||
self._before_iter(
|
||||
runner, batch_idx=batch_idx, data_batch=data_batch, mode='train')
|
||||
@ -200,7 +197,7 @@ class Hook:
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
data_batch (dict, optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._before_iter(
|
||||
@ -216,7 +213,7 @@ class Hook:
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._before_iter(
|
||||
@ -233,10 +230,8 @@ class Hook:
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (dict tuple or list, optional): Data from dataloader.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._after_iter(
|
||||
runner,
|
||||
@ -249,18 +244,15 @@ class Hook:
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) \
|
||||
-> None:
|
||||
outputs: Optional[Sequence] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each validation iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from
|
||||
model. Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (Sequence, optional): Outputs from model.
|
||||
"""
|
||||
self._after_iter(
|
||||
runner,
|
||||
@ -269,22 +261,19 @@ class Hook:
|
||||
outputs=outputs,
|
||||
mode='val')
|
||||
|
||||
def after_test_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
def after_test_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence] = None) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each test iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (Sequence, optional): Outputs from model.
|
||||
"""
|
||||
self._after_iter(
|
||||
runner,
|
||||
@ -327,8 +316,7 @@ class Hook:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
@ -337,8 +325,7 @@ class Hook:
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[Sequence[BaseDataElement],
|
||||
dict]] = None,
|
||||
outputs: Optional[Union[Sequence, dict]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
operations after each epoch.
|
||||
@ -347,10 +334,8 @@ class Hook:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (Sequence[BaseDataElement], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (dict or Sequence, optional): Outputs from model.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
pass
|
||||
|
@ -3,10 +3,9 @@ import time
|
||||
from typing import Optional, Sequence, Union
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -54,8 +53,8 @@ class IterTimerHook(Hook):
|
||||
runner (Runner): The runner of the training, validation and
|
||||
testing process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from
|
||||
dataloader.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# Update data loading time in `runner.message_hub`.
|
||||
@ -66,8 +65,7 @@ class IterTimerHook(Hook):
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Union[dict,
|
||||
Sequence[BaseDataElement]]] = None,
|
||||
outputs: Optional[Union[dict, Sequence]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Calculating time for an iteration and updating "time"
|
||||
``HistoryBuffer`` of ``runner.message_hub``.
|
||||
@ -76,10 +74,8 @@ class IterTimerHook(Hook):
|
||||
runner (Runner): The runner of the training validation and
|
||||
testing process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict or sequence, optional): Outputs from model. Defaults
|
||||
to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (dict or sequence, optional): Outputs from model.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# Update iteration time in `runner.message_hub`.
|
||||
|
@ -7,10 +7,9 @@ from typing import Dict, Optional, Sequence, Union
|
||||
from mmengine.fileio import FileClient, dump
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import is_tuple_of, scandir
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
SUFFIX_TYPE = Union[Sequence[str], str]
|
||||
|
||||
|
||||
@ -125,9 +124,8 @@ class LoggerHook(Hook):
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (dict, optional): Outputs from model. Defaults to None.
|
||||
data_batch (dict tuple or list, optional): Data from dataloader.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
"""
|
||||
# Print experiment name every n iterations.
|
||||
if self.every_n_train_iters(
|
||||
@ -152,41 +150,38 @@ class LoggerHook(Hook):
|
||||
runner.visualizer.add_scalars(
|
||||
tag, step=runner.iter + 1, file_path=self.json_log_path)
|
||||
|
||||
def after_val_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
def after_val_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence] = None) -> None:
|
||||
"""Record logs after validation iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the validation process.
|
||||
batch_idx (int): The index of the current batch in the validation
|
||||
loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (sequence, optional): Outputs from model. Defaults to None.
|
||||
outputs (sequence, optional): Outputs from model.
|
||||
"""
|
||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||
_, log_str = runner.log_processor.get_log_after_iter(
|
||||
runner, batch_idx, 'val')
|
||||
runner.logger.info(log_str)
|
||||
|
||||
def after_test_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
def after_test_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence] = None) -> None:
|
||||
"""Record logs after testing iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the testing process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
Defaults to None.
|
||||
outputs (sequence, optional): Outputs from model. Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (sequence, optional): Outputs from model.
|
||||
"""
|
||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||
_, log_str = runner.log_processor.get_log_after_iter(
|
||||
|
@ -1,15 +1,16 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from typing import Optional, Sequence, Tuple
|
||||
from typing import Optional, Sequence, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils.dl_utils import tensor2imgs
|
||||
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
# TODO: Due to interface changes, the current class
|
||||
# functions incorrectly
|
||||
@ -48,21 +49,18 @@ class NaiveVisualizationHook(Hook):
|
||||
unpad_image = input[:unpad_height, :unpad_width]
|
||||
return unpad_image
|
||||
|
||||
def after_test_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: Optional[Sequence[dict]] = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
def after_test_iter(self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence] = None) -> None:
|
||||
"""Show or Write the predicted results.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[dict], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
outputs (Sequence[BaseDataElement], optional): Outputs from model.
|
||||
Defaults to None.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
outputs (Sequence, optional): Outputs from model.
|
||||
"""
|
||||
if self.every_n_inner_iters(batch_idx, self._interval):
|
||||
for data, output in zip(data_batch, outputs): # type: ignore
|
||||
|
@ -1,10 +1,10 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
@ -24,12 +24,12 @@ class ParamSchedulerHook(Hook):
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||
data_batch (dict or tuple or list, optional): Data from dataloader.
|
||||
In order to keep this interface consistent with other hooks,
|
||||
we keep ``data_batch`` here. Defaults to None.
|
||||
we keep ``data_batch`` here.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
In order to keep this interface consistent with other hooks, we
|
||||
keep ``data_batch`` here. Defaults to None.
|
||||
keep ``data_batch`` here.
|
||||
"""
|
||||
|
||||
def step(param_schedulers):
|
||||
|
@ -1,12 +1,12 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from mmengine.utils import get_git_hash
|
||||
from mmengine.version import __version__
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -1,9 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .base_model import BaseModel
|
||||
from .data_preprocessor import (BaseDataElement, BaseDataPreprocessor,
|
||||
ImgDataPreprocessor)
|
||||
from .data_preprocessor import BaseDataPreprocessor, ImgDataPreprocessor
|
||||
|
||||
__all__ = [
|
||||
'BaseModel', 'BaseDataElement', 'ImgDataPreprocessor',
|
||||
'BaseDataPreprocessor'
|
||||
]
|
||||
__all__ = ['BaseModel', 'ImgDataPreprocessor', 'BaseDataPreprocessor']
|
||||
|
@ -1,21 +1,17 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import abstractmethod
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import MODELS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import is_list_of
|
||||
from ..base_module import BaseModule
|
||||
from .data_preprocessor import BaseDataPreprocessor
|
||||
|
||||
ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement],
|
||||
Tuple[torch.Tensor], torch.Tensor]
|
||||
|
||||
|
||||
class BaseModel(BaseModule):
|
||||
"""Base class for all algorithmic models.
|
||||
@ -85,7 +81,7 @@ class BaseModel(BaseModule):
|
||||
f'`nn.Module` instance, but got '
|
||||
f'{type(data_preprocessor)}')
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
def train_step(self, data: Union[dict, tuple, list],
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""Implements the default model training process including
|
||||
preprocessing, model forward propagation, loss calculation,
|
||||
@ -96,7 +92,7 @@ class BaseModel(BaseModule):
|
||||
:class:`IterBasedTrainLoop` will call this method to update model
|
||||
parameters. The default parameter update process is as follows:
|
||||
|
||||
1. Calls ``self.data_processor(data, training=False) to collext
|
||||
1. Calls ``self.data_processor(data, training=False) to collect
|
||||
batch_inputs and corresponding data_samples(labels).
|
||||
2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw
|
||||
loss
|
||||
@ -105,7 +101,7 @@ class BaseModel(BaseModule):
|
||||
4. Calls ``optim_wrapper.update_params(loss)`` to update model.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled from dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
optim_wrapper (OptimWrapper): OptimWrapper instance
|
||||
used to update model parameters.
|
||||
|
||||
@ -114,13 +110,13 @@ class BaseModel(BaseModule):
|
||||
"""
|
||||
# Enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
batch_inputs, data_samples = self.data_preprocessor(data, True)
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
parsed_losses, log_vars = self.parse_losses(losses)
|
||||
data = self.data_preprocessor(data, True)
|
||||
losses = self._run_forward(data, mode='loss') # type: ignore
|
||||
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
|
||||
optim_wrapper.update_params(parsed_losses)
|
||||
return log_vars
|
||||
|
||||
def val_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def val_step(self, data: Union[tuple, dict, list]) -> list:
|
||||
"""Gets the predictions of given data.
|
||||
|
||||
Calls ``self.data_preprocessor(data, False)`` and
|
||||
@ -128,25 +124,25 @@ class BaseModel(BaseModule):
|
||||
predictions which will be passed to evaluator.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled from dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement]: The predictions of given data.
|
||||
list: The predictions of given data.
|
||||
"""
|
||||
inputs, data_sample = self.data_preprocessor(data, False)
|
||||
return self(inputs, data_sample, mode='predict')
|
||||
data = self.data_preprocessor(data, False)
|
||||
return self._run_forward(data, mode='predict') # type: ignore
|
||||
|
||||
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def test_step(self, data: Union[dict, tuple, list]) -> list:
|
||||
"""``BaseModel`` implements ``test_step`` the same as ``val_step``.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled from dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement]: The predictions of given data.
|
||||
list: The predictions of given data.
|
||||
"""
|
||||
inputs, data_sample = self.data_preprocessor(data, False)
|
||||
return self(inputs, data_sample, mode='predict')
|
||||
data = self.data_preprocessor(data, False)
|
||||
return self._run_forward(data, mode='predict') # type: ignore
|
||||
|
||||
def parse_losses(
|
||||
self, losses: Dict[str, torch.Tensor]
|
||||
@ -239,16 +235,16 @@ class BaseModel(BaseModule):
|
||||
|
||||
@abstractmethod
|
||||
def forward(self,
|
||||
batch_inputs: torch.Tensor,
|
||||
data_samples: Optional[List[BaseDataElement]] = None,
|
||||
mode: str = 'tensor') -> ForwardResults:
|
||||
inputs: torch.Tensor,
|
||||
data_samples: Optional[list] = None,
|
||||
mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]:
|
||||
"""Returns losses or predictions of training, validation, testing, and
|
||||
simple inference process.
|
||||
|
||||
``forward`` method of BaseModel is an abstract method, its subclasses
|
||||
must implement this method.
|
||||
|
||||
Accepts ``batch_inputs`` and ``data_samples`` processed by
|
||||
Accepts ``batch_inputs`` and ``data_sample`` processed by
|
||||
:attr:`data_preprocessor`, and returns results according to mode
|
||||
arguments.
|
||||
|
||||
@ -263,9 +259,9 @@ class BaseModel(BaseModule):
|
||||
loss.
|
||||
|
||||
Args:
|
||||
batch_inputs (torch.Tensor): batch input tensor collated by
|
||||
inputs (torch.Tensor): batch input tensor collated by
|
||||
:attr:`data_preprocessor`.
|
||||
data_samples (List[BaseDataElement], optional):
|
||||
data_samples (list, optional):
|
||||
data samples collated by :attr:`data_preprocessor`.
|
||||
mode (str): mode should be one of ``loss``, ``predict`` and
|
||||
``tensor``
|
||||
@ -273,19 +269,36 @@ class BaseModel(BaseModule):
|
||||
- ``loss``: Called by ``train_step`` and return loss ``dict``
|
||||
used for logging
|
||||
- ``predict``: Called by ``val_step`` and ``test_step``
|
||||
and return list of ``BaseDataElement`` results used for
|
||||
computing metric.
|
||||
and return list of `results used for computing metric.
|
||||
- ``tensor``: Called by custom use to get ``Tensor`` type
|
||||
results.
|
||||
|
||||
Returns:
|
||||
ForwardResults:
|
||||
|
||||
dict or list:
|
||||
- If ``mode == loss``, return a ``dict`` of loss tensor used
|
||||
for backward and logging.
|
||||
- If ``mode == predict``, return a ``list`` of
|
||||
:obj:`BaseDataElement` for computing metric
|
||||
and getting inference result.
|
||||
- If ``mode == predict``, return a ``list`` of inference
|
||||
results.
|
||||
- If ``mode == tensor``, return a tensor or ``tuple`` of tensor
|
||||
or ``dict of tensor for custom use.
|
||||
"""
|
||||
|
||||
def _run_forward(self, data: Union[dict, tuple, list],
|
||||
mode: str) -> Union[Dict[str, torch.Tensor], list]:
|
||||
"""Unpacks data for :meth:`forward`
|
||||
|
||||
Args:
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
mode (str): Mode of forward.
|
||||
|
||||
Returns:
|
||||
dict or list: Results of training or testing mode.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
results = self(**data, mode=mode)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
results = self(*data, mode=mode)
|
||||
else:
|
||||
raise TypeError('Output of `data_preprocessor` should be '
|
||||
f'list, tuple or dict, but got {type(data)}')
|
||||
return results
|
||||
|
@ -1,94 +1,74 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
import math
|
||||
from typing import Mapping, Optional, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmengine.registry import MODELS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import is_list_of
|
||||
from ..utils import stack_batch
|
||||
|
||||
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class BaseDataPreprocessor(nn.Module):
|
||||
"""Base data pre-processor used for collating and copying data to the
|
||||
target device.
|
||||
|
||||
``BaseDataPreprocessor`` performs data pre-processing according to the
|
||||
following steps:
|
||||
|
||||
- Collates the data sampled from dataloader.
|
||||
- Copies data to the target device.
|
||||
- Stacks the input tensor at the first dimension.
|
||||
"""Base data pre-processor used for copying data to the target device.
|
||||
|
||||
Subclasses inherit from ``BaseDataPreprocessor`` could override the
|
||||
forward method to implement custom data pre-processing, such as
|
||||
batch-resize, MixUp, or CutMix.
|
||||
|
||||
Warnings:
|
||||
Each item of data sampled from dataloader must be a dict and at least
|
||||
contain the ``inputs`` key. Furthermore, the value of ``inputs``
|
||||
must be a ``Tensor`` with the same shape.
|
||||
Note:
|
||||
Data dictionary returned by dataloader must be a dict and at least
|
||||
contain the ``inputs`` key.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._device = torch.device('cpu')
|
||||
|
||||
def collate_data(
|
||||
self,
|
||||
data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]:
|
||||
"""Collating and copying data to the target device.
|
||||
|
||||
Collates the data sampled from dataloader into a list of tensor and
|
||||
list of labels, and then copies tensor to the target device.
|
||||
|
||||
Subclasses could override it to be compatible with the custom format
|
||||
data sampled from custom dataloader.
|
||||
def cast_data(self, data: CastData) -> CastData:
|
||||
"""Copying data to the target device.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): Data sampled from dataloader.
|
||||
data (dict): Data returned by ``DataLoader``.
|
||||
|
||||
Returns:
|
||||
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
||||
tensor and list of labels at target device.
|
||||
CollatedResult: Inputs and data sample at target device.
|
||||
"""
|
||||
inputs = [_data['inputs'].to(self._device).float() for _data in data]
|
||||
batch_data_samples: List[BaseDataElement] = []
|
||||
# Model can get predictions without any data samples.
|
||||
for _data in data:
|
||||
if 'data_sample' in _data:
|
||||
batch_data_samples.append(_data['data_sample'])
|
||||
# Move data from CPU to corresponding device.
|
||||
batch_data_samples = [
|
||||
data_sample.to(self._device) for data_sample in batch_data_samples
|
||||
]
|
||||
if isinstance(data, Mapping):
|
||||
return {key: self.cast_data(data[key]) for key in data}
|
||||
elif isinstance(data, tuple) and hasattr(data, '_fields'):
|
||||
# namedtuple
|
||||
return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable
|
||||
elif isinstance(data, Sequence):
|
||||
return [self.cast_data(sample) for sample in data]
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.to(self.device)
|
||||
elif isinstance(data, BaseDataElement):
|
||||
return data.to(self.device)
|
||||
else:
|
||||
return data
|
||||
|
||||
if not batch_data_samples:
|
||||
batch_data_samples = None # type: ignore
|
||||
|
||||
return inputs, batch_data_samples
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
|
||||
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
|
||||
"""Preprocesses the data into the model input format.
|
||||
|
||||
After the data pre-processing of :meth:`collate_data`, ``forward``
|
||||
After the data pre-processing of :meth:`cast_data`, ``forward``
|
||||
will stack the input tensor list to a batch tensor at the first
|
||||
dimension.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
data (dict): Data returned by dataloader
|
||||
training (bool): Whether to enable training time augmentation.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
dict or list: Data in the same format as the model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
batch_inputs = torch.stack(inputs, dim=0)
|
||||
return batch_inputs, batch_data_samples
|
||||
return self.cast_data(data) # type: ignore
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@ -203,42 +183,79 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
|
||||
torch.tensor(std).view(-1, 1, 1), False)
|
||||
else:
|
||||
self._enable_normalize = False
|
||||
self.channel_conversion = rgb_to_bgr or bgr_to_rgb
|
||||
self._channel_conversion = rgb_to_bgr or bgr_to_rgb
|
||||
self.pad_size_divisor = pad_size_divisor
|
||||
self.pad_value = pad_value
|
||||
|
||||
def forward(self,
|
||||
data: Sequence[dict],
|
||||
training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
|
||||
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
|
||||
"""Performs normalization、padding and bgr2rgb conversion based on
|
||||
``BaseDataPreprocessor``.
|
||||
|
||||
Args:
|
||||
data (Sequence[dict]): data sampled from dataloader.
|
||||
data (dict): Data sampled from dataset. If the collate
|
||||
function of DataLoader is :obj:`pseudo_collate`, data will be a
|
||||
list of dict. If collate function is :obj:`default_collate`,
|
||||
data will be a tuple with batch input tensor and list of data
|
||||
samples.
|
||||
training (bool): Whether to enable training time augmentation. If
|
||||
subclasses override this method, they can perform different
|
||||
preprocessing strategies for training and testing based on the
|
||||
value of ``training``.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
|
||||
model input.
|
||||
dict or list: Data in the same format as the model input.
|
||||
"""
|
||||
inputs, batch_data_samples = self.collate_data(data)
|
||||
for idx, _input in enumerate(inputs):
|
||||
# channel transform
|
||||
if self.channel_conversion:
|
||||
_input = _input[[2, 1, 0], ...]
|
||||
# Normalization.
|
||||
data = self.cast_data(data) # type: ignore
|
||||
_batch_inputs = data['inputs']
|
||||
# Process data with `pseudo_collate`.
|
||||
if is_list_of(_batch_inputs, torch.Tensor):
|
||||
batch_inputs = []
|
||||
for _batch_input in _batch_inputs:
|
||||
# channel transform
|
||||
if self._channel_conversion:
|
||||
_batch_input = _batch_input[[2, 1, 0], ...]
|
||||
# Convert to float after channel conversion to ensure
|
||||
# efficiency
|
||||
_batch_input = _batch_input.float()
|
||||
# Normalization.
|
||||
if self._enable_normalize:
|
||||
if self.mean.shape[0] == 3:
|
||||
assert _batch_input.dim(
|
||||
) == 3 and _batch_input.shape[0] == 3, (
|
||||
'If the mean has 3 values, the input tensor '
|
||||
'should in shape of (3, H, W), but got the tensor '
|
||||
f'with shape {_batch_input.shape}')
|
||||
_batch_input = (_batch_input - self.mean) / self.std
|
||||
batch_inputs.append(_batch_input)
|
||||
# Pad and stack Tensor.
|
||||
batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor,
|
||||
self.pad_value)
|
||||
# Process data with `default_collate`.
|
||||
elif isinstance(_batch_inputs, torch.Tensor):
|
||||
assert _batch_inputs.dim() == 4, (
|
||||
'The input of `ImgDataPreprocessor` should be a NCHW tensor '
|
||||
'or a list of tensor, but got a tensor with shape: '
|
||||
f'{_batch_inputs.shape}')
|
||||
if self._channel_conversion:
|
||||
_batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
|
||||
# Convert to float after channel conversion to ensure
|
||||
# efficiency
|
||||
_batch_inputs = _batch_inputs.float()
|
||||
if self._enable_normalize:
|
||||
if self.mean.shape[0] == 3:
|
||||
assert _input.dim() == 3 and _input.shape[0] == 3, (
|
||||
'If the mean has 3 values, the input tensor should in '
|
||||
'shape of (3, H, W), but got the tensor with shape '
|
||||
f'{_input.shape}')
|
||||
_input = (_input - self.mean) / self.std
|
||||
inputs[idx] = _input
|
||||
# Pad and stack Tensor.
|
||||
batch_inputs = stack_batch(inputs, self.pad_size_divisor,
|
||||
self.pad_value)
|
||||
return batch_inputs, batch_data_samples
|
||||
_batch_inputs = (_batch_inputs - self.mean) / self.std
|
||||
h, w = _batch_inputs.shape[2:]
|
||||
target_h = math.ceil(
|
||||
h / self.pad_size_divisor) * self.pad_size_divisor
|
||||
target_w = math.ceil(
|
||||
w / self.pad_size_divisor) * self.pad_size_divisor
|
||||
pad_h = target_h - h
|
||||
pad_w = target_w - w
|
||||
batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h),
|
||||
'constant', self.pad_value)
|
||||
else:
|
||||
raise TypeError('Output of `cast_data` should be a list of dict '
|
||||
'or a tuple with inputs and data_samples, but got'
|
||||
f'{type(data)}: {data}')
|
||||
data['inputs'] = batch_inputs
|
||||
data.setdefault('data_samples', None)
|
||||
return data
|
||||
|
@ -1,12 +1,11 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, List
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
from mmengine.optim import OptimWrapper
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from ..utils import detect_anomalous_params
|
||||
|
||||
|
||||
@ -90,7 +89,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
|
||||
super().__init__(module=module, **kwargs)
|
||||
self.detect_anomalous_params = detect_anomalous_params
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
def train_step(self, data: Union[dict, tuple, list],
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""Interface for model forward, backward and parameters updating during
|
||||
training process.
|
||||
@ -105,7 +104,7 @@ class MMDistributedDataParallel(DistributedDataParallel):
|
||||
- Return log messages of losses.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
||||
update parameters.
|
||||
|
||||
@ -114,33 +113,53 @@ class MMDistributedDataParallel(DistributedDataParallel):
|
||||
"""
|
||||
# Enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
batch_inputs, data_samples = self.module.data_preprocessor(
|
||||
data, training=True)
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
data = self.module.data_preprocessor(data, training=True)
|
||||
losses = self._run_forward(data, mode='loss')
|
||||
if self.detect_anomalous_params:
|
||||
detect_anomalous_params(losses, model=self)
|
||||
parsed_loss, log_vars = self.module.parse_losses(losses)
|
||||
optim_wrapper.update_params(parsed_loss)
|
||||
return log_vars
|
||||
|
||||
def val_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def val_step(self, data: Union[dict, tuple, list]) -> list:
|
||||
"""Gets the prediction of module during validation process.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement] or dict: The predictions of given data.
|
||||
list: The predictions of given data.
|
||||
"""
|
||||
return self.module.val_step(data)
|
||||
data = self.module.data_preprocessor(data, training=False)
|
||||
return self._run_forward(data, mode='predict')
|
||||
|
||||
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def test_step(self, data: Union[dict, tuple, list]) -> list:
|
||||
"""Gets the predictions of module during testing process.
|
||||
|
||||
Args:
|
||||
data: Data sampled by dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement]: The predictions of given data.
|
||||
list: The predictions of given data.
|
||||
"""
|
||||
return self.module.test_step(data)
|
||||
data = self.module.data_preprocessor(data, training=False)
|
||||
return self._run_forward(data, mode='predict')
|
||||
|
||||
def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any:
|
||||
"""Unpacks data for :meth:`forward`
|
||||
|
||||
Args:
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
mode (str): Mode of forward.
|
||||
|
||||
Returns:
|
||||
dict or list: Results of training or testing mode.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
results = self(**data, mode=mode)
|
||||
elif isinstance(data, (list, tuple)):
|
||||
results = self(*data, mode=mode)
|
||||
else:
|
||||
raise TypeError('Output of `data_preprocessor` should be '
|
||||
f'list, tuple or dict, but got {type(data)}')
|
||||
return results
|
||||
|
@ -153,7 +153,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
super().__init__(module, process_group, cpu_offload,
|
||||
fsdp_auto_wrap_policy, backward_prefetch)
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
def train_step(self, data: dict,
|
||||
optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
|
||||
"""Interface for model forward, backward and parameters updating during
|
||||
training process.
|
||||
@ -168,7 +168,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
- Return log messages of losses.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
data (dict): Data sampled by dataloader.
|
||||
optim_wrapper (OptimWrapper): A wrapper of optimizer to
|
||||
update parameters.
|
||||
|
||||
@ -177,18 +177,23 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
"""
|
||||
# enable automatic mixed precision training context.
|
||||
with optim_wrapper.optim_context(self):
|
||||
batch_inputs, data_samples = self.module.data_preprocessor(
|
||||
data, training=True)
|
||||
losses = self(batch_inputs, data_samples, mode='loss')
|
||||
data = self.module.data_preprocessor(data, training=True)
|
||||
if isinstance(data, dict):
|
||||
losses = self(**data, mode='loss')
|
||||
elif isinstance(data, (list, tuple)):
|
||||
losses = self(*data, mode='loss')
|
||||
else:
|
||||
raise TypeError('Output of `data_preprocessor` should be '
|
||||
f'list tuple or dict, but got {type(data)}')
|
||||
parsed_loss, log_vars = self.module.parse_losses(losses)
|
||||
optim_wrapper.update_params(parsed_loss)
|
||||
return log_vars
|
||||
|
||||
def val_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def val_step(self, data: dict) -> List[BaseDataElement]:
|
||||
"""Gets the prediction of module during validation process.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
data (dict): Data sampled by dataloader.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement] or dict: The predictions of given data.
|
||||
@ -196,11 +201,11 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel):
|
||||
inputs, data_sample = self.module.data_preprocessor(data, False)
|
||||
return self(inputs, data_sample, mode='predict')
|
||||
|
||||
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def test_step(self, data: dict) -> List[BaseDataElement]:
|
||||
"""Gets the predictions of module during testing process.
|
||||
|
||||
Args:
|
||||
data: Data sampled by dataloader.
|
||||
data (dict): Data sampled by dataloader.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement]: The predictions of given data.
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from typing import Dict, List
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -9,7 +9,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from mmengine.device import get_device
|
||||
from mmengine.optim import OptimWrapperDict
|
||||
from mmengine.registry import MODEL_WRAPPERS
|
||||
from mmengine.structures import BaseDataElement
|
||||
from .distributed import MMDistributedDataParallel
|
||||
|
||||
|
||||
@ -86,13 +85,13 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
|
||||
**kwargs)
|
||||
module._modules[name] = sub_module
|
||||
|
||||
def train_step(self, data: List[dict],
|
||||
def train_step(self, data: Union[dict, tuple, list],
|
||||
optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
|
||||
"""Interface for model forward, backward and parameters updating during
|
||||
training process.
|
||||
|
||||
Args:
|
||||
data: Data sampled by dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
optim_wrapper (OptimWrapperDict): A wrapper of optimizer to
|
||||
update parameters.
|
||||
|
||||
@ -101,25 +100,25 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
|
||||
"""
|
||||
return self.module.train_step(data, optim_wrapper)
|
||||
|
||||
def val_step(self, data) -> List[BaseDataElement]:
|
||||
def val_step(self, data: Union[dict, tuple, list]) -> list:
|
||||
"""Gets the prediction of module during validation process.
|
||||
|
||||
Args:
|
||||
data (List[dict]): Data sampled by dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
List[BaseDataElement]: The predictions of given data.
|
||||
list: The predictions of given data.
|
||||
"""
|
||||
return self.module.val_step(data)
|
||||
|
||||
def test_step(self, data: List[dict]) -> List[BaseDataElement]:
|
||||
def test_step(self, data: Union[dict, tuple, list]) -> list:
|
||||
"""Gets the predictions of module during testing process.
|
||||
|
||||
Args:
|
||||
data: Data sampled by dataloader.
|
||||
data (dict or tuple or list): Data sampled from dataset.
|
||||
|
||||
Returns:
|
||||
ForwardResults: The predictions of given data.
|
||||
list: The predictions of given data.
|
||||
"""
|
||||
return self.module.test_step(data)
|
||||
|
||||
|
@ -361,7 +361,7 @@ class ValLoop(BaseLoop):
|
||||
# outputs should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
outputs = self.runner.model.val_step(data_batch)
|
||||
self.evaluator.process(data_batch, outputs)
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_val_iter',
|
||||
batch_idx=idx,
|
||||
@ -429,10 +429,10 @@ class TestLoop(BaseLoop):
|
||||
'before_test_iter', batch_idx=idx, data_batch=data_batch)
|
||||
# predictions should be sequence of BaseDataElement
|
||||
with autocast(enabled=self.fp16):
|
||||
predictions = self.runner.model.test_step(data_batch)
|
||||
self.evaluator.process(data_batch, predictions)
|
||||
outputs = self.runner.model.test_step(data_batch)
|
||||
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
|
||||
self.runner.call_hook(
|
||||
'after_test_iter',
|
||||
batch_idx=idx,
|
||||
data_batch=data_batch,
|
||||
outputs=predictions)
|
||||
outputs=outputs)
|
||||
|
@ -20,7 +20,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
import mmengine
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.dataset import pseudo_collate, worker_init_fn
|
||||
from mmengine.dataset import COLLATE_FUNCTIONS, worker_init_fn
|
||||
from mmengine.device import get_device
|
||||
from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
|
||||
is_distributed, master_only, sync_random_seed)
|
||||
@ -1393,14 +1393,19 @@ class Runner:
|
||||
|
||||
# The default behavior of `collat_fn` in dataloader is to
|
||||
# merge a list of samples to form a mini-batch of Tensor(s).
|
||||
# However, to make this more flexible, collate_fn in MMengine does
|
||||
# nothing. The action to merge a list of samples will be handled
|
||||
# in model.
|
||||
# However, in mmengine, if `collate_fn` is not defined in
|
||||
# dataloader_cfg, `pseudo_collate` will only convert the list of
|
||||
# samples into a dict without stacking the batch tensor.
|
||||
collate_fn_cfg = dataloader_cfg.pop('collate_fn',
|
||||
dict(type='pseudo_collate'))
|
||||
collate_fn_type = collate_fn_cfg.pop('type')
|
||||
collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type)
|
||||
collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore
|
||||
data_loader = DataLoader(
|
||||
dataset=dataset,
|
||||
sampler=sampler if batch_sampler is None else None,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=pseudo_collate,
|
||||
collate_fn=collate_fn,
|
||||
worker_init_fn=init_fn,
|
||||
**dataloader_cfg)
|
||||
return data_loader
|
||||
|
@ -1080,8 +1080,7 @@ class Visualizer(ManagerMixin):
|
||||
def add_datasample(self,
|
||||
name,
|
||||
image: np.ndarray,
|
||||
gt_sample: Optional['BaseDataElement'] = None,
|
||||
pred_sample: Optional['BaseDataElement'] = None,
|
||||
data_sample: Optional['BaseDataElement'] = None,
|
||||
draw_gt: bool = True,
|
||||
draw_pred: bool = True,
|
||||
show: bool = False,
|
||||
|
155
tests/test_data/test_data_utils.py
Normal file
155
tests/test_data/test_data_utils.py
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.dataset import default_collate, pseudo_collate
|
||||
from mmengine.structures import BaseDataElement
|
||||
from mmengine.utils import is_list_of
|
||||
|
||||
|
||||
class TestDataUtils(TestCase):
|
||||
|
||||
def test_pseudo_collate(self):
|
||||
# Test with list of dict tensor inputs.
|
||||
input1 = torch.randn(1, 3, 5)
|
||||
input2 = torch.randn(1, 3, 5)
|
||||
label1 = torch.randn(1)
|
||||
label2 = torch.randn(1)
|
||||
|
||||
data_batch = [
|
||||
dict(inputs=input1, data_sample=label1),
|
||||
dict(inputs=input2, data_sample=label2)
|
||||
]
|
||||
data_batch = pseudo_collate(data_batch)
|
||||
self.assertTrue(torch.allclose(input1, data_batch['inputs'][0]))
|
||||
self.assertTrue(torch.allclose(input2, data_batch['inputs'][1]))
|
||||
self.assertTrue(torch.allclose(label1, data_batch['data_sample'][0]))
|
||||
self.assertTrue(torch.allclose(label2, data_batch['data_sample'][1]))
|
||||
|
||||
# Test with list of dict, and each element contains `data_sample`
|
||||
# inputs
|
||||
data_sample1 = BaseDataElement(label=torch.tensor(1))
|
||||
data_sample2 = BaseDataElement(label=torch.tensor(1))
|
||||
data = [
|
||||
dict(inputs=input1, data_sample=data_sample1),
|
||||
dict(inputs=input2, data_sample=data_sample2),
|
||||
]
|
||||
data_batch = pseudo_collate(data)
|
||||
batch_inputs, batch_data_sample = (data_batch['inputs'],
|
||||
data_batch['data_sample'])
|
||||
# check batch_inputs
|
||||
self.assertTrue(is_list_of(batch_inputs, torch.Tensor))
|
||||
self.assertIs(input1, batch_inputs[0])
|
||||
self.assertIs(input2, batch_inputs[1])
|
||||
|
||||
# check data_sample
|
||||
self.assertIs(batch_data_sample[0], data_sample1)
|
||||
self.assertIs(batch_data_sample[1], data_sample2)
|
||||
|
||||
# Test with list of tuple, each tuple is a nested dict instance
|
||||
data_batch = [(dict(
|
||||
inputs=input1,
|
||||
data_sample=data_sample1,
|
||||
value=1,
|
||||
name='1',
|
||||
nested=dict(data_sample=data_sample1)),
|
||||
dict(
|
||||
inputs=input2,
|
||||
data_sample=data_sample2,
|
||||
value=2,
|
||||
name='2',
|
||||
nested=dict(data_sample=data_sample2))),
|
||||
(dict(
|
||||
inputs=input1,
|
||||
data_sample=data_sample1,
|
||||
value=1,
|
||||
name='1',
|
||||
nested=dict(data_sample=data_sample1)),
|
||||
dict(
|
||||
inputs=input2,
|
||||
data_sample=data_sample2,
|
||||
value=2,
|
||||
name='2',
|
||||
nested=dict(data_sample=data_sample2)))]
|
||||
data_batch = pseudo_collate(data_batch)
|
||||
batch_inputs_0 = data_batch[0]['inputs']
|
||||
batch_inputs_1 = data_batch[1]['inputs']
|
||||
batch_data_sample_0 = data_batch[0]['data_sample']
|
||||
batch_data_sample_1 = data_batch[1]['data_sample']
|
||||
batch_value_0 = data_batch[0]['value']
|
||||
batch_value_1 = data_batch[1]['value']
|
||||
batch_name_0 = data_batch[0]['name']
|
||||
batch_name_1 = data_batch[1]['name']
|
||||
batch_nested_0 = data_batch[0]['nested']
|
||||
batch_nested_1 = data_batch[1]['nested']
|
||||
|
||||
self.assertTrue(is_list_of(batch_inputs_0, torch.Tensor))
|
||||
self.assertTrue(is_list_of(batch_inputs_1, torch.Tensor))
|
||||
self.assertIs(batch_inputs_0[0], input1)
|
||||
self.assertIs(batch_inputs_0[1], input1)
|
||||
self.assertIs(batch_inputs_1[0], input2)
|
||||
self.assertIs(batch_inputs_1[1], input2)
|
||||
|
||||
self.assertIs(batch_data_sample_0[0], data_sample1)
|
||||
self.assertIs(batch_data_sample_0[1], data_sample1)
|
||||
self.assertIs(batch_data_sample_1[0], data_sample2)
|
||||
self.assertIs(batch_data_sample_1[1], data_sample2)
|
||||
|
||||
self.assertEqual(batch_value_0, [1, 1])
|
||||
self.assertEqual(batch_value_1, [2, 2])
|
||||
|
||||
self.assertEqual(batch_name_0, ['1', '1'])
|
||||
self.assertEqual(batch_name_1, ['2', '2'])
|
||||
|
||||
self.assertIs(batch_nested_0['data_sample'][0], data_sample1)
|
||||
self.assertIs(batch_nested_0['data_sample'][1], data_sample1)
|
||||
self.assertIs(batch_nested_1['data_sample'][0], data_sample2)
|
||||
self.assertIs(batch_nested_1['data_sample'][1], data_sample2)
|
||||
|
||||
def test_default_collate(self):
|
||||
# `default_collate` has comment logic with `pseudo_collate`, therefore
|
||||
# only test it cam stack batch tensor, convert int or float to tensor.
|
||||
input1 = torch.randn(1, 3, 5)
|
||||
input2 = torch.randn(1, 3, 5)
|
||||
data_batch = [(
|
||||
dict(inputs=input1, value=1, array=np.array(1)),
|
||||
dict(inputs=input2, value=2, array=np.array(2)),
|
||||
),
|
||||
(
|
||||
dict(inputs=input1, value=1, array=np.array(1)),
|
||||
dict(inputs=input2, value=2, array=np.array(2)),
|
||||
)]
|
||||
data_batch = default_collate(data_batch)
|
||||
batch_inputs_0 = data_batch[0]['inputs']
|
||||
batch_inputs_1 = data_batch[1]['inputs']
|
||||
batch_value_0 = data_batch[0]['value']
|
||||
batch_value_1 = data_batch[1]['value']
|
||||
batch_array_0 = data_batch[0]['array']
|
||||
batch_array_1 = data_batch[1]['array']
|
||||
|
||||
self.assertEqual(tuple(batch_inputs_0.shape), (2, 1, 3, 5))
|
||||
self.assertEqual(tuple(batch_inputs_1.shape), (2, 1, 3, 5))
|
||||
self.assertTrue(
|
||||
torch.allclose(batch_inputs_0, torch.stack([input1, input1])))
|
||||
self.assertTrue(
|
||||
torch.allclose(batch_inputs_1, torch.stack([input2, input2])))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(batch_value_0,
|
||||
torch.stack([torch.tensor(1),
|
||||
torch.tensor(1)])))
|
||||
self.assertTrue(
|
||||
torch.allclose(batch_value_1,
|
||||
torch.stack([torch.tensor(2),
|
||||
torch.tensor(2)])))
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(batch_array_0,
|
||||
torch.stack([torch.tensor(1),
|
||||
torch.tensor(1)])))
|
||||
self.assertTrue(
|
||||
torch.allclose(batch_array_1,
|
||||
torch.stack([torch.tensor(2),
|
||||
torch.tensor(2)])))
|
@ -41,9 +41,9 @@ class ToyMetric(BaseMetric):
|
||||
|
||||
def process(self, data_batch, predictions):
|
||||
results = [{
|
||||
'pred': pred.get('pred'),
|
||||
'label': data['data_sample'].get('label')
|
||||
} for pred, data in zip(predictions, data_batch)]
|
||||
'pred': prediction['label'],
|
||||
'label': prediction['label']
|
||||
} for prediction in predictions]
|
||||
self.results.extend(results)
|
||||
|
||||
def compute_metrics(self, results: List):
|
||||
@ -68,8 +68,7 @@ class NonPrefixedMetric(BaseMetric):
|
||||
"""Evaluator with unassigned `default_prefix` to test the warning
|
||||
information."""
|
||||
|
||||
def process(self, data_batch: Sequence[dict],
|
||||
predictions: Sequence[dict]) -> None:
|
||||
def process(self, data_batch, predictions: Sequence[dict]) -> None:
|
||||
pass
|
||||
|
||||
def compute_metrics(self, results: list) -> dict:
|
||||
@ -81,12 +80,13 @@ def generate_test_results(size, batch_size, pred, label):
|
||||
bs_residual = size % batch_size
|
||||
for i in range(num_batch):
|
||||
bs = bs_residual if i == num_batch - 1 else batch_size
|
||||
data_batch = [
|
||||
dict(
|
||||
inputs=np.zeros((3, 10, 10)),
|
||||
data_sample=BaseDataElement(label=label)) for _ in range(bs)
|
||||
data_batch = {
|
||||
'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)],
|
||||
'data_sample': [BaseDataElement(label=label) for _ in range(bs)]
|
||||
}
|
||||
predictions = [
|
||||
BaseDataElement(pred=pred, label=label) for _ in range(bs)
|
||||
]
|
||||
predictions = [BaseDataElement(pred=pred) for _ in range(bs)]
|
||||
yield (data_batch, predictions)
|
||||
|
||||
|
||||
@ -99,9 +99,9 @@ class TestEvaluator(TestCase):
|
||||
size = 10
|
||||
batch_size = 4
|
||||
|
||||
for data_samples, predictions in generate_test_results(
|
||||
for data_samples, outputs in generate_test_results(
|
||||
size, batch_size, pred=1, label=1):
|
||||
evaluator.process(data_samples, predictions)
|
||||
evaluator.process(data_samples=outputs, data_batch=data_samples)
|
||||
|
||||
metrics = evaluator.evaluate(size=size)
|
||||
self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0)
|
||||
@ -124,9 +124,9 @@ class TestEvaluator(TestCase):
|
||||
size = 10
|
||||
batch_size = 4
|
||||
|
||||
for data_samples, predictions in generate_test_results(
|
||||
for data_samples, outputs in generate_test_results(
|
||||
size, batch_size, pred=1, label=1):
|
||||
evaluator.process(data_samples, predictions)
|
||||
evaluator.process(data_samples=outputs, data_batch=data_samples)
|
||||
|
||||
metrics = evaluator.evaluate(size=size)
|
||||
|
||||
@ -145,9 +145,9 @@ class TestEvaluator(TestCase):
|
||||
size = 10
|
||||
batch_size = 4
|
||||
|
||||
for data_samples, predictions in generate_test_results(
|
||||
for data_samples, outputs in generate_test_results(
|
||||
size, batch_size, pred=1, label=1):
|
||||
evaluator.process(data_samples, predictions)
|
||||
evaluator.process(data_samples=outputs, data_batch=data_samples)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
@ -233,13 +233,22 @@ class TestEvaluator(TestCase):
|
||||
|
||||
size = 10
|
||||
|
||||
all_data = [
|
||||
dict(
|
||||
inputs=np.zeros((3, 10, 10)),
|
||||
data_sample=BaseDataElement(label=1)) for _ in range(size)
|
||||
all_data = [dict() for _ in range(10)]
|
||||
all_predictions = [
|
||||
BaseDataElement(pred=0, label=1) for _ in range(size)
|
||||
]
|
||||
all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
|
||||
evaluator.offline_evaluate(all_data, all_predictions)
|
||||
evaluator.offline_evaluate(all_predictions, all_data)
|
||||
|
||||
# Test with None data
|
||||
all_data = None
|
||||
evaluator.offline_evaluate(all_predictions, all_data)
|
||||
|
||||
# Different length of data and predictions will raise an error.
|
||||
all_data = [dict() for _ in range(9)]
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
'outputs and data should have the same length'):
|
||||
evaluator.offline_evaluate(all_predictions, all_data)
|
||||
|
||||
@unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu')
|
||||
def test_evaluate_cast_cpu(self):
|
||||
@ -256,11 +265,12 @@ class TestEvaluator(TestCase):
|
||||
for _ in range(size)
|
||||
]
|
||||
all_predictions = [
|
||||
BaseDataElement(pred=torch.zeros((1, ), device='cuda'))
|
||||
for _ in range(size)
|
||||
BaseDataElement(
|
||||
pred=torch.zeros((1, ), device='cuda'),
|
||||
label=torch.ones((1, ), device='cuda')) for _ in range(size)
|
||||
]
|
||||
for data, pred in zip(all_data, all_predictions):
|
||||
evaluator.process([data], [pred])
|
||||
evaluator.process([pred], [data])
|
||||
|
||||
def test_results_device(results: List):
|
||||
for result in results:
|
||||
|
@ -19,8 +19,8 @@ class TestDumpResults(TestCase):
|
||||
|
||||
def test_process(self):
|
||||
metric = DumpResults(out_file_path='./results.pkl')
|
||||
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
|
||||
metric.process(None, predictions)
|
||||
data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
|
||||
metric.process(None, data_samples)
|
||||
self.assertEqual(len(metric.results), 1)
|
||||
self.assertEqual(metric.results[0]['data'][0].device,
|
||||
torch.device('cpu'))
|
||||
@ -29,8 +29,8 @@ class TestDumpResults(TestCase):
|
||||
temp_dir = tempfile.TemporaryDirectory()
|
||||
path = osp.join(temp_dir.name, 'results.pkl')
|
||||
metric = DumpResults(out_file_path=path)
|
||||
predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
|
||||
metric.process(None, predictions)
|
||||
data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))]
|
||||
metric.process(None, data_samples)
|
||||
metric.compute_metrics(metric.results)
|
||||
self.assertTrue(osp.isfile(path))
|
||||
|
||||
|
@ -22,9 +22,10 @@ class ToyModel(nn.Module):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, batch_inputs, labels, mode='tensor'):
|
||||
labels = torch.stack(labels)
|
||||
outputs = self.linear(batch_inputs)
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
labels = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
outputs = self.linear(inputs)
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
|
@ -28,15 +28,15 @@ class ToyModel(BaseModel):
|
||||
super().__init__(data_preprocessor=data_preprocessor, init_cfg=None)
|
||||
self.conv = nn.Conv2d(3, 1, 1)
|
||||
|
||||
def forward(self, batch_inputs, data_samples=None, mode='tensor'):
|
||||
def forward(self, inputs, data_sample=None, mode='tensor'):
|
||||
if mode == 'loss':
|
||||
out = self.conv(batch_inputs)
|
||||
out = self.conv(inputs)
|
||||
return dict(loss=out)
|
||||
elif mode == 'predict':
|
||||
out = self.conv(batch_inputs)
|
||||
out = self.conv(inputs)
|
||||
return out
|
||||
elif mode == 'tensor':
|
||||
out = self.conv(batch_inputs)
|
||||
out = self.conv(inputs)
|
||||
return out
|
||||
|
||||
|
||||
@ -98,34 +98,34 @@ class TestBaseModel(TestCase):
|
||||
model = ToyModel()
|
||||
optimizer = SGD(model.parameters(), lr=0.1)
|
||||
optim_wrapper = OptimWrapper(optimizer)
|
||||
inputs = torch.randn(3, 1, 1)
|
||||
data = dict(inputs=inputs)
|
||||
inputs = torch.randn(1, 3, 1, 1)
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
# initiate grad.
|
||||
# model.conv.weight.grad = torch.randn(1, 3, 1, 1)
|
||||
log_vars = model.train_step([data], optim_wrapper)
|
||||
log_vars = model.train_step(data, optim_wrapper)
|
||||
self.assertIsNotNone(model.conv.weight.grad)
|
||||
self.assertIsInstance(log_vars['loss'], torch.Tensor)
|
||||
|
||||
def test_val_step(self):
|
||||
inputs = torch.randn(3, 1, 1)
|
||||
data = dict(inputs=inputs)
|
||||
inputs = torch.randn(1, 3, 1, 1)
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
model = ToyModel()
|
||||
out = model.val_step([data])
|
||||
out = model.val_step(data)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
|
||||
def test_test_step(self):
|
||||
inputs = torch.randn(3, 1, 1)
|
||||
data = dict(inputs=inputs)
|
||||
inputs = torch.randn(1, 3, 1, 1)
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
model = ToyModel()
|
||||
out = model.val_step([data])
|
||||
out = model.val_step(data)
|
||||
self.assertIsInstance(out, torch.Tensor)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
|
||||
def test_cuda(self):
|
||||
inputs = torch.randn(3, 1, 1).cuda()
|
||||
data = dict(inputs=inputs)
|
||||
inputs = torch.randn(1, 3, 1, 1).cuda()
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
model = ToyModel().cuda()
|
||||
out = model.val_step([data])
|
||||
out = model.val_step(data)
|
||||
self.assertEqual(out.device.type, 'cuda')
|
||||
|
||||
model = NestedModel()
|
||||
@ -139,10 +139,10 @@ class TestBaseModel(TestCase):
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
|
||||
def test_to(self):
|
||||
inputs = torch.randn(3, 1, 1).to('cuda:0')
|
||||
data = dict(inputs=inputs)
|
||||
inputs = torch.randn(1, 3, 1, 1).to('cuda:0')
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
model = ToyModel().to(torch.cuda.current_device())
|
||||
out = model.val_step([data])
|
||||
out = model.val_step(data)
|
||||
self.assertEqual(out.device.type, 'cuda')
|
||||
|
||||
model = NestedModel()
|
||||
|
@ -16,54 +16,73 @@ class TestBaseDataPreprocessor(TestCase):
|
||||
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
|
||||
|
||||
def test_forward(self):
|
||||
# Test cpu forward with list of data samples.
|
||||
base_data_preprocessor = BaseDataPreprocessor()
|
||||
input1 = torch.randn(1, 3, 5)
|
||||
input2 = torch.randn(1, 3, 5)
|
||||
label1 = torch.randn(1)
|
||||
label2 = torch.randn(1)
|
||||
|
||||
data = [
|
||||
dict(inputs=input1, data_sample=label1),
|
||||
dict(inputs=input2, data_sample=label2)
|
||||
]
|
||||
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
|
||||
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output['data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
|
||||
|
||||
assert_allclose(input1, batch_inputs[0])
|
||||
assert_allclose(input2, batch_inputs[1])
|
||||
assert_allclose(label1, batch_labels[0])
|
||||
assert_allclose(label2, batch_labels[1])
|
||||
|
||||
# Test with tuple of batch inputs and batch data samples
|
||||
data = dict(
|
||||
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
|
||||
output = base_data_preprocessor(data)['inputs']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
|
||||
# Test cuda forward
|
||||
if torch.cuda.is_available():
|
||||
# Test with list of data samples.
|
||||
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output[
|
||||
'data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
|
||||
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output[
|
||||
'data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cpu')
|
||||
|
||||
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output[
|
||||
'data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
|
||||
# device of `base_data_preprocessor` is cuda, output should be
|
||||
# cuda tensor.
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
self.assertEqual(batch_labels[0].device.type, 'cuda')
|
||||
|
||||
class TestImgataPreprocessor(TestBaseDataPreprocessor):
|
||||
|
||||
class TestImgDataPreprocessor(TestBaseDataPreprocessor):
|
||||
|
||||
def test_init(self):
|
||||
# initiate model without `data_preprocessor`
|
||||
# Initiate processor without arguments
|
||||
data_processor = ImgDataPreprocessor()
|
||||
self.assertFalse(data_processor.channel_conversion)
|
||||
self.assertFalse(data_processor._channel_conversion)
|
||||
self.assertFalse(hasattr(data_processor, 'mean'))
|
||||
self.assertFalse(hasattr(data_processor, 'std'))
|
||||
self.assertEqual(data_processor.pad_size_divisor, 1)
|
||||
assert_allclose(data_processor.pad_value, torch.tensor(0))
|
||||
# initiate model with data_preprocessor` and feat keys
|
||||
|
||||
# Initiate model with bgr2rgb, mean, std .etc..
|
||||
data_processor = ImgDataPreprocessor(
|
||||
bgr_to_rgb=True,
|
||||
mean=[0, 0, 0],
|
||||
@ -71,7 +90,7 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
|
||||
pad_size_divisor=16,
|
||||
pad_value=10)
|
||||
self.assertTrue(data_processor._enable_normalize)
|
||||
self.assertTrue(data_processor.channel_conversion, True)
|
||||
self.assertTrue(data_processor._channel_conversion, True)
|
||||
assert_allclose(data_processor.mean,
|
||||
torch.tensor([0, 0, 0]).view(-1, 1, 1))
|
||||
assert_allclose(data_processor.std,
|
||||
@ -113,10 +132,11 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
|
||||
inputs2 = torch.randn(3, 15, 15)
|
||||
data_sample1 = InstanceData(bboxes=torch.randn(5, 4))
|
||||
data_sample2 = InstanceData(bboxes=torch.randn(5, 4))
|
||||
data = [
|
||||
dict(inputs=inputs1.clone(), data_sample=data_sample1.clone()),
|
||||
dict(inputs=inputs2.clone(), data_sample=data_sample2.clone())
|
||||
]
|
||||
|
||||
data = dict(
|
||||
inputs=[inputs1.clone(), inputs2.clone()],
|
||||
data_sample=[data_sample1.clone(),
|
||||
data_sample2.clone()])
|
||||
|
||||
std = torch.tensor([1, 2, 3]).view(-1, 1, 1)
|
||||
target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std
|
||||
@ -126,7 +146,8 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
|
||||
target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10)
|
||||
|
||||
target_inputs = [target_inputs1, target_inputs2]
|
||||
inputs, data_samples = data_preprocessor(data, True)
|
||||
output = data_preprocessor(data, True)
|
||||
inputs, data_samples = output['inputs'], output['data_sample']
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
|
||||
target_data_samples = [data_sample1, data_sample2]
|
||||
@ -147,7 +168,8 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
|
||||
target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10)
|
||||
|
||||
target_inputs = [target_inputs1, target_inputs2]
|
||||
inputs, data_samples = data_preprocessor(data, True)
|
||||
output = data_preprocessor(data, True)
|
||||
inputs, data_samples = output['inputs'], output['data_sample']
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
|
||||
target_data_samples = [data_sample1, data_sample2]
|
||||
@ -159,22 +181,53 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor):
|
||||
# Test gray image with 3 dim mean will raise error
|
||||
data_preprocessor = ImgDataPreprocessor(
|
||||
mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5))
|
||||
data = [
|
||||
dict(inputs=torch.ones(10, 10)),
|
||||
dict(inputs=torch.ones(10, 10))
|
||||
]
|
||||
data = dict(
|
||||
inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None)
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'If the mean has 3 values'):
|
||||
data_preprocessor(data)
|
||||
|
||||
data = [
|
||||
dict(inputs=torch.ones(1, 10, 10)),
|
||||
dict(inputs=torch.ones(1, 10, 10))
|
||||
]
|
||||
data = dict(
|
||||
inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None)
|
||||
with self.assertRaisesRegex(AssertionError,
|
||||
'If the mean has 3 values'):
|
||||
data_preprocessor(data)
|
||||
|
||||
# Test stacked batch inputs and batch data samples
|
||||
data_preprocessor = ImgDataPreprocessor(
|
||||
mean=(127.5, 127.5, 127.5),
|
||||
std=(127.5, 127.5, 127.5),
|
||||
rgb_to_bgr=True,
|
||||
pad_size_divisor=16)
|
||||
_batch_inputs = torch.randn(2, 3, 10, 10)
|
||||
_batch_labels = [torch.randn(1), torch.randn(1)]
|
||||
data = dict(inputs=_batch_inputs, data_sample=_batch_labels)
|
||||
output = data_preprocessor(data)
|
||||
inputs, data_samples = output['inputs'], output['data_sample']
|
||||
target_batch_inputs = _batch_inputs[:, [2, 1, 0], ...]
|
||||
target_batch_inputs = (target_batch_inputs - 127.5) / 127.5
|
||||
target_batch_inputs = F.pad(target_batch_inputs, (0, 6, 0, 6), value=0)
|
||||
self.assertEqual(inputs.shape, torch.Size([2, 3, 16, 16]))
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
assert_allclose(target_batch_inputs, inputs)
|
||||
|
||||
# Test batch inputs without convert channel order and pad
|
||||
data_preprocessor = ImgDataPreprocessor(
|
||||
mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5))
|
||||
_batch_inputs = torch.randn(2, 3, 10, 10)
|
||||
_batch_labels = [torch.randn(1), torch.randn(1)]
|
||||
data = dict(inputs=_batch_inputs, data_sample=_batch_labels)
|
||||
output = data_preprocessor(data)
|
||||
inputs, data_samples = output['inputs'], output['data_sample']
|
||||
target_batch_inputs = (_batch_inputs - 127.5) / 127.5
|
||||
self.assertEqual(inputs.shape, torch.Size([2, 3, 10, 10]))
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
assert_allclose(target_batch_inputs, inputs)
|
||||
|
||||
# Test empty `data_sample`
|
||||
data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())]
|
||||
data_preprocessor(data, True)
|
||||
data = dict(
|
||||
inputs=[inputs1.clone(), inputs2.clone()], data_sample=None)
|
||||
output = data_preprocessor(data, True)
|
||||
inputs, data_samples = output['inputs'], output['data_sample']
|
||||
self.assertIsNone(data_samples)
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
|
@ -9,9 +9,10 @@ import torch.nn as nn
|
||||
from torch.optim import SGD
|
||||
|
||||
from mmengine.dist import all_gather
|
||||
from mmengine.model import (BaseModel, MMDistributedDataParallel,
|
||||
from mmengine.model import (BaseDataPreprocessor, BaseModel,
|
||||
ExponentialMovingAverage,
|
||||
MMDistributedDataParallel,
|
||||
MMSeparateDistributedDataParallel)
|
||||
from mmengine.model.averaged_model import ExponentialMovingAverage
|
||||
from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
|
||||
from mmengine.testing import assert_allclose
|
||||
from mmengine.testing._internal import MultiProcessTestCase
|
||||
@ -22,15 +23,22 @@ if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
|
||||
from mmengine.model import MMFullyShardedDataParallel # noqa: F401
|
||||
|
||||
|
||||
class ToyDataPreprocessor(BaseDataPreprocessor):
|
||||
|
||||
def forward(self, data: dict, training: bool = False):
|
||||
self.called = True
|
||||
return super().forward(data, training)
|
||||
|
||||
|
||||
class ToyModel(BaseModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
super().__init__(data_preprocessor=ToyDataPreprocessor())
|
||||
self.conv1 = nn.Conv2d(3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(1, 1, 1)
|
||||
|
||||
def forward(self, x, data_samples=None, mode='tensor'):
|
||||
x = self.conv1(x)
|
||||
def forward(self, inputs, data_sample=None, mode='tensor'):
|
||||
x = self.conv1(inputs)
|
||||
x = self.conv2(x)
|
||||
if mode == 'loss':
|
||||
return dict(loss=x)
|
||||
@ -48,10 +56,10 @@ class ComplexModel(BaseModel):
|
||||
self.conv2 = nn.Conv2d(3, 1, 1)
|
||||
|
||||
def train_step(self, data, optim_wrapper):
|
||||
batch_inputs, _ = self.data_preprocessor(data)
|
||||
loss1 = self.conv1(batch_inputs)
|
||||
inputs = self.data_preprocessor(data)['inputs']
|
||||
loss1 = self.conv1(inputs)
|
||||
optim_wrapper['optim_wrapper1'].update_params(loss1)
|
||||
loss2 = self.conv2(batch_inputs)
|
||||
loss2 = self.conv2(inputs)
|
||||
optim_wrapper['optim_wrapper2'].update_params(loss2)
|
||||
return dict(loss1=loss1, loss2=loss2)
|
||||
|
||||
@ -82,9 +90,9 @@ class TestDistributedDataParallel(MultiProcessTestCase):
|
||||
optimizer = SGD(ddp_model.parameters(), lr=0)
|
||||
optim_wrapper = AmpOptimWrapper(
|
||||
optimizer=optimizer, accumulative_counts=3)
|
||||
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss']
|
||||
inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss']
|
||||
self.assertIs(res.dtype, torch.float16)
|
||||
grad = ddp_model.module.conv1.weight.grad
|
||||
all_grads = all_gather(grad)
|
||||
@ -92,10 +100,10 @@ class TestDistributedDataParallel(MultiProcessTestCase):
|
||||
assert_allclose(all_grads[0], all_grads[1])
|
||||
|
||||
# Gradient accumulation
|
||||
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
|
||||
ddp_model.train_step(data, optim_wrapper=optim_wrapper)
|
||||
|
||||
# Test update params and clean grads.
|
||||
ddp_model.train_step([data], optim_wrapper=optim_wrapper)
|
||||
ddp_model.train_step(data, optim_wrapper=optim_wrapper)
|
||||
grad = ddp_model.module.conv1.weight.grad
|
||||
all_grads = all_gather(grad)
|
||||
assert_allclose(all_grads[0], torch.zeros_like(all_grads[0]))
|
||||
@ -105,20 +113,22 @@ class TestDistributedDataParallel(MultiProcessTestCase):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ToyModel()
|
||||
ddp_model = MMDistributedDataParallel(module=model)
|
||||
inputs = torch.randn(3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
# Test get predictions.
|
||||
predictions = ddp_model.val_step([data])
|
||||
predictions = ddp_model.val_step(data)
|
||||
self.assertIsInstance(predictions, torch.Tensor)
|
||||
self.assertTrue(model.data_preprocessor.called)
|
||||
|
||||
def test_test_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ToyModel()
|
||||
ddp_model = MMDistributedDataParallel(module=model)
|
||||
inputs = torch.randn(3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
predictions = ddp_model.test_step([data])
|
||||
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
predictions = ddp_model.test_step(data)
|
||||
self.assertIsInstance(predictions, torch.Tensor)
|
||||
self.assertTrue(model.data_preprocessor.called)
|
||||
|
||||
def _init_dist_env(self, rank, world_size):
|
||||
"""Initialize the distributed environment."""
|
||||
@ -157,30 +167,30 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
|
||||
optim_wrapper2 = OptimWrapper(optimizer2, 1)
|
||||
optim_wrapper_dict = OptimWrapperDict(
|
||||
optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2)
|
||||
inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
|
||||
data = dict(inputs=inputs)
|
||||
inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=None)
|
||||
# Automatically sync grads of `optim_wrapper1` since
|
||||
# `cumulative_iters` = 1
|
||||
ddp_model.train()
|
||||
self.assertTrue(ddp_model.training)
|
||||
ddp_model.train_step([data], optim_wrapper=optim_wrapper_dict)
|
||||
ddp_model.train_step(data, optim_wrapper=optim_wrapper_dict)
|
||||
|
||||
def test_val_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ComplexModel()
|
||||
ddp_model = MMSeparateDistributedDataParallel(model)
|
||||
data = torch.randn(3, 1, 1)
|
||||
data = torch.randn(1, 3, 1, 1)
|
||||
# Test get predictions.
|
||||
ddp_model.eval()
|
||||
self.assertFalse(ddp_model.training)
|
||||
predictions = ddp_model.val_step([data])
|
||||
predictions = ddp_model.val_step(data)
|
||||
self.assertEqual(predictions, 1)
|
||||
|
||||
def test_test_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ComplexModel()
|
||||
ddp_model = MMSeparateDistributedDataParallel(model)
|
||||
data = torch.randn(3, 1, 1)
|
||||
data = torch.randn(1, 3, 1, 1)
|
||||
# Test get predictions.
|
||||
ddp_model.eval()
|
||||
self.assertFalse(ddp_model.training)
|
||||
@ -225,27 +235,27 @@ class TestMMFullyShardedDataParallel(MultiProcessTestCase):
|
||||
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
|
||||
optimizer = SGD(fsdp_model.parameters(), lr=0)
|
||||
optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
|
||||
inputs = torch.randn(3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=[inputs], data_sample=MagicMock())
|
||||
fsdp_model.train()
|
||||
self.assertTrue(fsdp_model.training)
|
||||
fsdp_model.train_step([data], optim_wrapper=optim_wrapper)
|
||||
fsdp_model.train_step(data, optim_wrapper=optim_wrapper)
|
||||
|
||||
def test_val_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ToyModel()
|
||||
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
|
||||
inputs = torch.randn(3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=[inputs], data_sample=MagicMock())
|
||||
# Test get predictions.
|
||||
predictions = fsdp_model.val_step([data])
|
||||
predictions = fsdp_model.val_step(data)
|
||||
self.assertIsInstance(predictions, torch.Tensor)
|
||||
|
||||
def test_test_step(self):
|
||||
self._init_dist_env(self.rank, self.world_size)
|
||||
model = ToyModel()
|
||||
fsdp_model = MMFullyShardedDataParallel(module=model.cuda())
|
||||
inputs = torch.randn(3, 1, 1) * self.rank * 255
|
||||
inputs = torch.randn(1, 3, 1, 1) * self.rank * 255
|
||||
data = dict(inputs=inputs, data_sample=MagicMock())
|
||||
predictions = fsdp_model.test_step([data])
|
||||
predictions = fsdp_model.test_step(data)
|
||||
self.assertIsInstance(predictions, torch.Tensor)
|
||||
|
@ -14,7 +14,7 @@ from torch.optim import SGD, Adam
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.dataset import DefaultSampler
|
||||
from mmengine.dataset import COLLATE_FUNCTIONS, DefaultSampler, pseudo_collate
|
||||
from mmengine.evaluator import BaseMetric, Evaluator
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
||||
IterTimerHook, LoggerHook, ParamSchedulerHook,
|
||||
@ -44,15 +44,18 @@ class ToyModel(BaseModel):
|
||||
self.linear1 = nn.Linear(2, 2)
|
||||
self.linear2 = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, batch_inputs, labels, mode='tensor'):
|
||||
labels = torch.stack(labels)
|
||||
outputs = self.linear1(batch_inputs)
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
if isinstance(inputs, list):
|
||||
inputs = torch.stack(inputs)
|
||||
if isinstance(data_sample, list):
|
||||
data_sample = torch.stack(data_sample)
|
||||
outputs = self.linear1(inputs)
|
||||
outputs = self.linear2(outputs)
|
||||
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (labels - outputs).sum()
|
||||
loss = (data_sample - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
elif mode == 'predict':
|
||||
@ -74,15 +77,16 @@ class ToySyncBNModel(BaseModel):
|
||||
self.conv = nn.Conv2d(3, 8, 2)
|
||||
self.bn = nn.SyncBatchNorm(8)
|
||||
|
||||
def forward(self, batch_inputs, labels, mode='tensor'):
|
||||
labels = torch.stack(labels)
|
||||
outputs = self.conv(batch_inputs)
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
data_sample = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
outputs = self.conv(inputs)
|
||||
outputs = self.bn(outputs)
|
||||
|
||||
if mode == 'tensor':
|
||||
return outputs
|
||||
elif mode == 'loss':
|
||||
loss = (labels - outputs).sum()
|
||||
loss = (data_sample - outputs).sum()
|
||||
outputs = dict(loss=loss)
|
||||
return outputs
|
||||
elif mode == 'predict':
|
||||
@ -98,24 +102,25 @@ class ToyGANModel(BaseModel):
|
||||
self.linear1 = nn.Linear(2, 1)
|
||||
self.linear2 = nn.Linear(2, 1)
|
||||
|
||||
def forward(self, batch_inputs, labels, mode='tensor'):
|
||||
labels = torch.stack(labels)
|
||||
output1 = self.linear1(batch_inputs)
|
||||
output2 = self.linear2(batch_inputs)
|
||||
def forward(self, inputs, data_sample, mode='tensor'):
|
||||
data_sample = torch.stack(data_sample)
|
||||
inputs = torch.stack(inputs)
|
||||
output1 = self.linear1(inputs)
|
||||
output2 = self.linear2(inputs)
|
||||
|
||||
if mode == 'tensor':
|
||||
return output1, output2
|
||||
elif mode == 'loss':
|
||||
loss1 = (labels - output1).sum()
|
||||
loss2 = (labels - output2).sum()
|
||||
loss1 = (data_sample - output1).sum()
|
||||
loss2 = (data_sample - output2).sum()
|
||||
outputs = dict(linear1=loss1, linear2=loss2)
|
||||
return outputs
|
||||
elif mode == 'predict':
|
||||
return output1, output2
|
||||
|
||||
def train_step(self, data, optim_wrapper):
|
||||
batch_inputs, batch_labels = self.data_preprocessor(data)
|
||||
loss = self(batch_inputs, batch_labels, mode='loss')
|
||||
data = self.data_preprocessor(data)
|
||||
loss = self(**data, mode='loss')
|
||||
optim_wrapper['linear1'].update_params(loss['linear1'])
|
||||
optim_wrapper['linear2'].update_params(loss['linear2'])
|
||||
return loss
|
||||
@ -193,7 +198,7 @@ class ToyMetric1(BaseMetric):
|
||||
super().__init__(collect_device=collect_device)
|
||||
self.dummy_metrics = dummy_metrics
|
||||
|
||||
def process(self, data_samples, predictions):
|
||||
def process(self, data_batch, predictions):
|
||||
result = {'acc': 1}
|
||||
self.results.append(result)
|
||||
|
||||
@ -208,7 +213,7 @@ class ToyMetric2(BaseMetric):
|
||||
super().__init__(collect_device=collect_device)
|
||||
self.dummy_metrics = dummy_metrics
|
||||
|
||||
def process(self, data_samples, predictions):
|
||||
def process(self, data_batch, predictions):
|
||||
result = {'acc': 1}
|
||||
self.results.append(result)
|
||||
|
||||
@ -335,7 +340,12 @@ class ToyEvaluator(Evaluator):
|
||||
|
||||
|
||||
def collate_fn(data_batch):
|
||||
return data_batch
|
||||
return pseudo_collate(data_batch)
|
||||
|
||||
|
||||
@COLLATE_FUNCTIONS.register_module()
|
||||
def custom_collate(data_batch, pad_value):
|
||||
return pseudo_collate(data_batch)
|
||||
|
||||
|
||||
class TestRunner(TestCase):
|
||||
@ -1428,7 +1438,7 @@ class TestRunner(TestCase):
|
||||
val_batch_idx_targets):
|
||||
self.assertEqual(result, target)
|
||||
|
||||
# 5. test dynamic interval in IterBasedTrainLoop
|
||||
# 5.1 test dynamic interval in IterBasedTrainLoop
|
||||
max_iters = 12
|
||||
interval = 5
|
||||
dynamic_intervals = [(11, 2)]
|
||||
@ -1465,7 +1475,7 @@ class TestRunner(TestCase):
|
||||
for result, target, in zip(val_interval_results, val_interval_targets):
|
||||
self.assertEqual(result, target)
|
||||
|
||||
# 6. test dynamic interval in EpochBasedTrainLoop
|
||||
# 5.2 test dynamic interval in EpochBasedTrainLoop
|
||||
max_epochs = 12
|
||||
interval = 5
|
||||
dynamic_intervals = [(11, 2)]
|
||||
@ -1579,6 +1589,30 @@ class TestRunner(TestCase):
|
||||
runner = runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 10.1 Test build dataloader with default collate function
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_train10.1'
|
||||
cfg.train_dataloader.update(collate_fn=dict(type='default_collate'))
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 10.2 Test build dataloader with custom collate function
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_train10.2'
|
||||
cfg.train_dataloader.update(
|
||||
collate_fn=dict(type='custom_collate', pad_value=100))
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 11 test build dataloader without default arguments of collate
|
||||
# function.
|
||||
with self.assertRaises(TypeError):
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_train11'
|
||||
cfg.train_dataloader.update(collate_fn=dict(type='custom_collate'))
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
def test_val(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_val1'
|
||||
@ -1890,8 +1924,6 @@ class TestRunner(TestCase):
|
||||
ckpt = torch.load(path)
|
||||
self.assertEqual(ckpt['meta']['epoch'], 3)
|
||||
self.assertEqual(ckpt['meta']['iter'], 12)
|
||||
self.assertEqual(ckpt['meta']['dataset_meta'],
|
||||
runner.train_dataloader.dataset.metainfo)
|
||||
self.assertEqual(ckpt['meta']['experiment_name'],
|
||||
runner.experiment_name)
|
||||
self.assertEqual(ckpt['meta']['seed'], runner.seed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user