From 8770c6c7fc21b50e90e7aca2a4faaaeee444da87 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 24 Aug 2022 22:04:55 +0800 Subject: [PATCH] [Refactor] Refactor data flow to make the interface more natural (#468) * [Refactor]: modify interface of Visualizer.add_datasample (#365) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323) * acollate data in dataloader * fix docstring * refine comment * fix as comment * refactor default collate and psedo collate * foramt test file * fix docstring * fix as comment * rename elem to data_item * minor fix * fix as comment * [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process is a `dict` (#360) * refine evaluator and metric * compatible with new default collate * replace default collate with pseudo * Handle data_batch in metric * fix unit test * fix unit test * fix unit test * minor refine * make data_batch optional make data_batch optional * rename outputs to predictions * fix ut * rename predictions to outputs * fix docstring * fix docstring * fix unit test * make outputs and data_batch to kwargs * fix unit test * keep signature of metric * fix ut * rename pred_sample arguments to data_sample(Visualizer) * fix loop and ut * [refactor]: Refactor model dataflow (#398) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * refactor model data flow tmp_commt tmp commit * make val_cfg and test_cfg optional * roll back runner * pass test mmdet * fix as comment fix as comment fix ci in DataPreprocessor * fix ut * fix ut * fix rebase main * [Fix]: Fix test val ddp (#462) * [Fix] Fix docstring and type hint of data flow (#463) * Fix docstring of data flow * change signature of hook * fix unit test * resolve conflicts * fix lint --- docs/zh_cn/tutorials/hook.md | 4 +- mmengine/dataset/__init__.py | 5 +- mmengine/dataset/utils.py | 133 ++++++++++++-- mmengine/evaluator/evaluator.py | 63 +++---- mmengine/evaluator/metric.py | 10 +- mmengine/hooks/checkpoint_hook.py | 7 +- mmengine/hooks/empty_cache_hook.py | 10 +- mmengine/hooks/hook.py | 53 ++---- mmengine/hooks/iter_timer_hook.py | 16 +- mmengine/hooks/logger_hook.py | 39 ++-- mmengine/hooks/naive_visualization_hook.py | 22 ++- mmengine/hooks/param_scheduler_hook.py | 10 +- mmengine/hooks/runtime_info_hook.py | 4 +- mmengine/model/base_model/__init__.py | 8 +- mmengine/model/base_model/base_model.py | 81 +++++---- .../model/base_model/data_preprocessor.py | 169 ++++++++++-------- mmengine/model/wrappers/distributed.py | 49 +++-- .../wrappers/fully_sharded_distributed.py | 23 ++- .../model/wrappers/seperate_distributed.py | 19 +- mmengine/runner/loops.py | 8 +- mmengine/runner/runner.py | 15 +- mmengine/visualization/visualizer.py | 3 +- tests/test_data/test_data_utils.py | 155 ++++++++++++++++ tests/test_evaluator/test_evaluator.py | 60 ++++--- tests/test_evaluator/test_metric.py | 8 +- tests/test_hooks/test_ema_hook.py | 7 +- .../test_base_model/test_base_model.py | 38 ++-- .../test_base_model/test_data_preprocessor.py | 115 ++++++++---- .../test_wrappers/test_model_wrapper.py | 76 ++++---- tests/test_runner/test_runner.py | 80 ++++++--- 30 files changed, 842 insertions(+), 448 deletions(-) create mode 100644 
tests/test_data/test_data_utils.py diff --git a/docs/zh_cn/tutorials/hook.md b/docs/zh_cn/tutorials/hook.md index 1a919edc..e3b3d498 100644 --- a/docs/zh_cn/tutorials/hook.md +++ b/docs/zh_cn/tutorials/hook.md @@ -228,10 +228,8 @@ class CheckInvalidLossHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. outputs (dict, optional): Outputs from model. - Defaults to None. """ if self.every_n_train_iters(runner, self.interval): assert torch.isfinite(outputs['loss']),\ diff --git a/mmengine/dataset/__init__.py b/mmengine/dataset/__init__.py index 33345831..c58ef983 100644 --- a/mmengine/dataset/__init__.py +++ b/mmengine/dataset/__init__.py @@ -2,10 +2,11 @@ from .base_dataset import BaseDataset, Compose, force_full_init from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset from .sampler import DefaultSampler, InfiniteSampler -from .utils import pseudo_collate, worker_init_fn +from .utils import (COLLATE_FUNCTIONS, default_collate, pseudo_collate, + worker_init_fn) __all__ = [ 'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset', 'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler', - 'worker_init_fn', 'pseudo_collate' + 'worker_init_fn', 'pseudo_collate', 'COLLATE_FUNCTIONS', 'default_collate' ] diff --git a/mmengine/dataset/utils.py b/mmengine/dataset/utils.py index c284a336..0b93f842 100644 --- a/mmengine/dataset/utils.py +++ b/mmengine/dataset/utils.py @@ -1,11 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import random -from typing import Sequence +from typing import Any, Mapping, Sequence import numpy as np import torch +from torch.utils.data._utils.collate import \ default_collate as torch_default_collate -DATA_BATCH = Sequence[dict] +from mmengine.registry import Registry +from mmengine.structures import BaseDataElement + +COLLATE_FUNCTIONS = Registry('Collate Functions') def worker_init_fn(worker_id: int, num_workers: int, rank: int, @@ -28,16 +33,124 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, torch.manual_seed(worker_seed) -def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH: - """The default behavior of dataloader is to merge a list of samples to form - a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` does - nothing just returns ``data_batch``. +@COLLATE_FUNCTIONS.register_module() +def pseudo_collate(data_batch: Sequence) -> Any: + """Convert a list of data sampled from the dataset into a batch of data + whose type is consistent with the type of each data item in ``data_batch``. + The default behavior of dataloader is to merge a list of samples to form + a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` + will neither stack tensors into batch tensors nor convert int, float or + ndarray values to tensors. + + This code is referenced from: + `Pytorch default_collate `_. # noqa: E501 Args: - data_batch (Sequence[dict]): Batch of data from - dataloader. + data_batch (Sequence): Batch of data from dataloader. Returns: - Sequence[dict]: Return input ``data_batch``. + Any: Transposed data in the same format as the data items of + ``data_batch``. 
""" - return data_batch + data_item = data_batch[0] + data_item_type = type(data_item) + if isinstance(data_item, (str, bytes)): + return data_batch + elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'): + # named tuple + return data_item_type(*(pseudo_collate(samples) + for samples in zip(*data_batch))) + elif isinstance(data_item, Sequence): + # check to make sure that the data_itements in batch have + # consistent size + it = iter(data_batch) + data_item_size = len(next(it)) + if not all(len(data_item) == data_item_size for data_item in it): + raise RuntimeError( + 'each data_itement in list of batch should be of equal size') + transposed = list(zip(*data_batch)) + + if isinstance(data_item, tuple): + return [pseudo_collate(samples) + for samples in transposed] # Compat with Pytorch. + else: + try: + return data_item_type( + [pseudo_collate(samples) for samples in transposed]) + except TypeError: + # The sequence type may not support `__init__(iterable)` + # (e.g., `range`). + return [pseudo_collate(samples) for samples in transposed] + elif isinstance(data_item, Mapping): + return data_item_type({ + key: pseudo_collate([d[key] for d in data_batch]) + for key in data_item + }) + else: + return data_batch + + +@COLLATE_FUNCTIONS.register_module() +def default_collate(data_batch: Sequence) -> Any: + """Convert list of data sampled from dataset into a batch of data, of which + type consistent with the type of each data_itement in ``data_batch``. + + Different from :func:`pseudo_collate`, ``default_collate`` will stack + tensor contained in ``data_batch`` into a batched tensor with the + first dimension batch size, and then move input tensor to the target + device. + + Different from ``default_collate`` in pytorch, ``default_collate`` will + not process ``BaseDataElement``. + + This code is referenced from: + `Pytorch default_collate `_. # noqa: E501 + + Note: + ``default_collate`` only accept input tensor with the same shape. + + Args: + data_batch (Sequence): Data sampled from dataset. + + Returns: + Any: Data in the same format as the data_itement of ``data_batch``, of which + tensors have been stacked, and ndarray, int, float have been + converted to tensors. + """ + data_item = data_batch[0] + data_item_type = type(data_item) + + if isinstance(data_item, (BaseDataElement, str, bytes)): + return data_batch + elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'): + # named_tuple + return data_item_type(*(default_collate(samples) + for samples in zip(*data_batch))) + elif isinstance(data_item, Sequence): + # check to make sure that the data_itements in batch have + # consistent size + it = iter(data_batch) + data_item_size = len(next(it)) + if not all(len(data_item) == data_item_size for data_item in it): + raise RuntimeError( + 'each data_itement in list of batch should be of equal size') + transposed = list(zip(*data_batch)) + + if isinstance(data_item, tuple): + return [default_collate(samples) + for samples in transposed] # Compat with Pytorch. + else: + try: + return data_item_type( + [default_collate(samples) for samples in transposed]) + except TypeError: + # The sequence type may not support `__init__(iterable)` + # (e.g., `range`). 
+ return [default_collate(samples) for samples in transposed] + elif isinstance(data_item, Mapping): + return data_item_type({ + key: default_collate([d[key] for d in data_batch]) + for key in data_item + }) + else: + return torch_default_collate(data_batch) diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py index fb74dbed..a435a4ce 100644 --- a/mmengine/evaluator/evaluator.py +++ b/mmengine/evaluator/evaluator.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Iterator, List, Optional, Sequence, Union +from typing import Any, Iterator, List, Optional, Sequence, Union +from mmengine.dataset import pseudo_collate from mmengine.registry import EVALUATOR, METRICS from mmengine.structures import BaseDataElement from .metric import BaseMetric @@ -37,34 +38,26 @@ class Evaluator: for metric in self.metrics: metric.dataset_meta = dataset_meta - def process(self, data_batch: Sequence[dict], - predictions: Sequence[BaseDataElement]): + def process(self, + data_samples: Sequence[BaseDataElement], + data_batch: Optional[Any] = None): """Convert ``BaseDataSample`` to dict and invoke process method of each metric. Args: - data_batch (Sequence[dict]): A batch of data from the dataloader. - predictions (Sequence[BaseDataElement]): A batch of outputs from - the model. + data_samples (Sequence[BaseDataElement]): Predictions of the model + and the ground truth of the validation set. + data_batch (Any, optional): A batch of data from the dataloader. """ - _data_batch = [] - for data in data_batch: - if isinstance(data['data_sample'], BaseDataElement): - _data_batch.append( - dict( - inputs=data['inputs'], - data_sample=data['data_sample'].to_dict())) + _data_samples = [] + for data_sample in data_samples: + if isinstance(data_sample, BaseDataElement): + _data_samples.append(data_sample.to_dict()) else: - _data_batch.append(data) - _predictions = [] - for pred in predictions: - if isinstance(pred, BaseDataElement): - _predictions.append(pred.to_dict()) - else: - _predictions.append(pred) + _data_samples.append(data_sample) for metric in self.metrics: - metric.process(_data_batch, _predictions) + metric.process(data_batch, _data_samples) def evaluate(self, size: int) -> dict: """Invoke ``evaluate`` method of each metric and collect the metrics @@ -97,20 +90,26 @@ class Evaluator: return metrics def offline_evaluate(self, - data: Sequence, - predictions: Sequence, + data_samples: Sequence, + data: Optional[Sequence] = None, chunk_size: int = 1): """Offline evaluate the dumped predictions on the given data. Args: - data (Sequence): All data of the validation set. - predictions (Sequence): All predictions of the model on the - validation set. + data_samples (Sequence): All predictions of the model and the + ground truth of the validation set. + data (Sequence, optional): All data of the validation set. chunk_size (int): The number of data samples and predictions to be processed in a batch. 
""" # support chunking iterable objects + if data is not None: + assert len(data_samples) == len(data), ( + 'outputs and data should have the same length, but got ' + f'outputs length: {len(data_samples)} ' + f'data length: {len(data)}') + def get_chunks(seq: Iterator, chunk_size=1): stop = False while not stop: @@ -125,9 +124,11 @@ class Evaluator: yield chunk size = 0 - for data_chunk, pred_chunk in zip( - get_chunks(iter(data), chunk_size), - get_chunks(iter(predictions), chunk_size)): - size += len(data_chunk) - self.process(data_chunk, pred_chunk) + for output_chunk in get_chunks(iter(data_samples), chunk_size): + if data is not None: + data_chunk = pseudo_collate(data[size:size + chunk_size]) + else: + data_chunk = None + size += len(output_chunk) + self.process(output_chunk, data_chunk) return self.evaluate(size) diff --git a/mmengine/evaluator/metric.py b/mmengine/evaluator/metric.py index 5d7be857..f7f7df5b 100644 --- a/mmengine/evaluator/metric.py +++ b/mmengine/evaluator/metric.py @@ -58,15 +58,14 @@ class BaseMetric(metaclass=ABCMeta): self._dataset_meta = dataset_meta @abstractmethod - def process(self, data_batch: Sequence[dict], - predictions: Sequence[dict]) -> None: + def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: - data_batch (Sequence[dict]): A batch of data from the dataloader. - predictions (Sequence[dict]): A batch of outputs from + data_batch (Any): A batch of data from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from the model. """ @@ -146,8 +145,7 @@ class DumpResults(BaseMetric): raise ValueError('The output file must be a pkl file.') self.out_file_path = out_file_path - def process(self, data_batch: Sequence[dict], - predictions: Sequence[dict]) -> None: + def process(self, data_batch: Any, predictions: Sequence[dict]) -> None: """transfer tensors in predictions to CPU.""" self.results.extend(_to_cpu(predictions)) diff --git a/mmengine/hooks/checkpoint_hook.py b/mmengine/hooks/checkpoint_hook.py index 37e34210..448d2545 100644 --- a/mmengine/hooks/checkpoint_hook.py +++ b/mmengine/hooks/checkpoint_hook.py @@ -12,7 +12,7 @@ from mmengine.registry import HOOKS from mmengine.utils import is_list_of, is_seq_of from .hook import Hook -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] @HOOKS.register_module() @@ -470,9 +470,8 @@ class CheckpointHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (dict, optional): Outputs from model. 
""" if self.by_epoch: return diff --git a/mmengine/hooks/empty_cache_hook.py b/mmengine/hooks/empty_cache_hook.py index 60edb890..b9b5eba0 100644 --- a/mmengine/hooks/empty_cache_hook.py +++ b/mmengine/hooks/empty_cache_hook.py @@ -4,10 +4,9 @@ from typing import Optional, Sequence, Union import torch from mmengine.registry import HOOKS -from mmengine.structures import BaseDataElement from .hook import Hook -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] @HOOKS.register_module() @@ -38,18 +37,15 @@ class EmptyCacheHook(Hook): runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Union[dict, - Sequence[BaseDataElement]]] = None, + outputs: Optional[Union[dict, Sequence]] = None, mode: str = 'train') -> None: """Empty cache after an iteration. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. outputs (dict or sequence, optional): Outputs from model. - Defaults to None. mode (str): Current mode of runner. Defaults to 'train'. """ if self._do_after_iter: diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py index 1a5d88cf..67f6bc23 100644 --- a/mmengine/hooks/hook.py +++ b/mmengine/hooks/hook.py @@ -1,9 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, Optional, Sequence, Union -from mmengine.structures import BaseDataElement - -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] class Hook: @@ -184,8 +182,7 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. """ self._before_iter( runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') @@ -200,7 +197,7 @@ class Hook: Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. - data_batch (Sequence[dict], optional): Data from dataloader. + data_batch (dict, optional): Data from dataloader. Defaults to None. """ self._before_iter( @@ -216,7 +213,7 @@ class Hook: Args: runner (Runner): The runner of the testing process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[dict], optional): Data from dataloader. + data_batch (dict or tuple or list, optional): Data from dataloader. Defaults to None. """ self._before_iter( @@ -233,10 +230,8 @@ class Hook: Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. + data_batch (dict tuple or list, optional): Data from dataloader. outputs (dict, optional): Outputs from model. - Defaults to None. """ self._after_iter( runner, @@ -249,18 +244,15 @@ class Hook: runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataElement]] = None) \ - -> None: + outputs: Optional[Sequence] = None) -> None: """All subclasses should override this method, if they need any operations after each validation iteration. Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the val loop. 
- data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict or sequence, optional): Outputs from - model. Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (Sequence, optional): Outputs from model. """ self._after_iter( runner, @@ -269,22 +261,19 @@ class Hook: outputs=outputs, mode='val') - def after_test_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataElement]] = None) -> None: + def after_test_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence] = None) -> None: """All subclasses should override this method, if they need any operations after each test iteration. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (Sequence, optional): Outputs from model. """ self._after_iter( runner, @@ -327,8 +316,7 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. mode (str): Current mode of runner. Defaults to 'train'. """ pass @@ -337,8 +325,7 @@ class Hook: runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Union[Sequence[BaseDataElement], - dict]] = None, + outputs: Optional[Union[Sequence, dict]] = None, mode: str = 'train') -> None: """All subclasses should override this method, if they need any operations after each epoch. @@ -347,10 +334,8 @@ class Hook: runner (Runner): The runner of the training, validation or testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (Sequence[BaseDataElement], optional): Outputs from model. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (dict or Sequence, optional): Outputs from model. mode (str): Current mode of runner. Defaults to 'train'. """ pass diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index d5578371..5166c063 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -3,10 +3,9 @@ import time from typing import Optional, Sequence, Union from mmengine.registry import HOOKS -from mmengine.structures import BaseDataElement from .hook import Hook -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] @HOOKS.register_module() @@ -54,8 +53,8 @@ class IterTimerHook(Hook): runner (Runner): The runner of the training, validation and testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from + dataloader. mode (str): Current mode of runner. Defaults to 'train'. """ # Update data loading time in `runner.message_hub`. 
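The hook-facing change in this patch is mostly a type change: `data_batch` is no longer a `Sequence[dict]` but whatever the dataloader's collate function returns (a `dict` of lists under the default `pseudo_collate`). As a minimal sketch of a downstream hook written against the new interface (the hook name and the `inputs` key are illustrative assumptions, mirroring the `dict(inputs=..., data_sample=...)` samples used in this patch's tests):

```python
from typing import Optional, Union

from mmengine.hooks import Hook
from mmengine.registry import HOOKS

DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
class BatchSizeLoggerHook(Hook):  # illustrative name, not part of this patch
    """Logs the batch size under the refactored data flow."""

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        # With the default `pseudo_collate`, `data_batch` is a dict whose
        # values are lists holding one entry per sample.
        if isinstance(data_batch, dict) and 'inputs' in data_batch:
            runner.logger.info(
                f'iter {batch_idx}: batch size = {len(data_batch["inputs"])}')
```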
@@ -66,8 +65,7 @@ class IterTimerHook(Hook): runner, batch_idx: int, data_batch: DATA_BATCH = None, - outputs: Optional[Union[dict, - Sequence[BaseDataElement]]] = None, + outputs: Optional[Union[dict, Sequence]] = None, mode: str = 'train') -> None: """Calculating time for an iteration and updating "time" ``HistoryBuffer`` of ``runner.message_hub``. @@ -76,10 +74,8 @@ class IterTimerHook(Hook): runner (Runner): The runner of the training validation and testing process. batch_idx (int): The index of the current batch in the loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict or sequence, optional): Outputs from model. Defaults - to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (dict or sequence, optional): Outputs from model. mode (str): Current mode of runner. Defaults to 'train'. """ # Update iteration time in `runner.message_hub`. diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index f6f8b719..c189c2d1 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -7,10 +7,9 @@ from typing import Dict, Optional, Sequence, Union from mmengine.fileio import FileClient, dump from mmengine.hooks import Hook from mmengine.registry import HOOKS -from mmengine.structures import BaseDataElement from mmengine.utils import is_tuple_of, scandir -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] SUFFIX_TYPE = Union[Sequence[str], str] @@ -125,9 +124,8 @@ class LoggerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (dict, optional): Outputs from model. Defaults to None. + data_batch (dict tuple or list, optional): Data from dataloader. + outputs (dict, optional): Outputs from model. """ # Print experiment name every n iterations. if self.every_n_train_iters( @@ -152,41 +150,38 @@ class LoggerHook(Hook): runner.visualizer.add_scalars( tag, step=runner.iter + 1, file_path=self.json_log_path) - def after_val_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataElement]] = None) -> None: + def after_val_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence] = None) -> None: """Record logs after validation iteration. Args: runner (Runner): The runner of the validation process. batch_idx (int): The index of the current batch in the validation loop. - data_batch (Sequence[dict], optional): Data from dataloader. + data_batch (dict or tuple or list, optional): Data from dataloader. Defaults to None. - outputs (sequence, optional): Outputs from model. Defaults to None. + outputs (sequence, optional): Outputs from model. """ if self.every_n_inner_iters(batch_idx, self.interval): _, log_str = runner.log_processor.get_log_after_iter( runner, batch_idx, 'val') runner.logger.info(log_str) - def after_test_iter( - self, - runner, - batch_idx: int, - data_batch: DATA_BATCH = None, - outputs: Optional[Sequence[BaseDataElement]] = None) -> None: + def after_test_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence] = None) -> None: """Record logs after testing iteration. Args: runner (Runner): The runner of the testing process. batch_idx (int): The index of the current batch in the test loop. 
- data_batch (Sequence[dict], optional): Data from dataloader. - Defaults to None. - outputs (sequence, optional): Outputs from model. Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (sequence, optional): Outputs from model. """ if self.every_n_inner_iters(batch_idx, self.interval): _, log_str = runner.log_processor.get_log_after_iter( diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py index 671471c3..6a6c3f38 100644 --- a/mmengine/hooks/naive_visualization_hook.py +++ b/mmengine/hooks/naive_visualization_hook.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple, Union import cv2 import numpy as np from mmengine.hooks import Hook from mmengine.registry import HOOKS -from mmengine.structures import BaseDataElement from mmengine.utils.dl_utils import tensor2imgs +DATA_BATCH = Optional[Union[dict, tuple, list]] + # TODO: Due to interface changes, the current class # functions incorrectly @@ -48,21 +49,18 @@ class NaiveVisualizationHook(Hook): unpad_image = input[:unpad_height, :unpad_width] return unpad_image - def after_test_iter( - self, - runner, - batch_idx: int, - data_batch: Optional[Sequence[dict]] = None, - outputs: Optional[Sequence[BaseDataElement]] = None) -> None: + def after_test_iter(self, + runner, + batch_idx: int, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence] = None) -> None: """Show or Write the predicted results. Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the test loop. - data_batch (Sequence[dict], optional): Data - from dataloader. Defaults to None. - outputs (Sequence[BaseDataElement], optional): Outputs from model. - Defaults to None. + data_batch (dict or tuple or list, optional): Data from dataloader. + outputs (Sequence, optional): Outputs from model. """ if self.every_n_inner_iters(batch_idx, self._interval): for data, output in zip(data_batch, outputs): # type: ignore diff --git a/mmengine/hooks/param_scheduler_hook.py b/mmengine/hooks/param_scheduler_hook.py index b5a52dbc..cb033ce4 100644 --- a/mmengine/hooks/param_scheduler_hook.py +++ b/mmengine/hooks/param_scheduler_hook.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Optional, Sequence +from typing import Optional, Union from mmengine.registry import HOOKS from .hook import Hook -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] @HOOKS.register_module() @@ -24,12 +24,12 @@ class ParamSchedulerHook(Hook): Args: runner (Runner): The runner of the training process. batch_idx (int): The index of the current batch in the train loop. - data_batch (Sequence[dict], optional): Data from dataloader. + data_batch (dict or tuple or list, optional): Data from dataloader. In order to keep this interface consistent with other hooks, - we keep ``data_batch`` here. Defaults to None. + we keep ``data_batch`` here. outputs (dict, optional): Outputs from model. In order to keep this interface consistent with other hooks, we - keep ``data_batch`` here. Defaults to None. + keep ``data_batch`` here. 
""" def step(param_schedulers): diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py index 0d37cdae..64ecafa9 100644 --- a/mmengine/hooks/runtime_info_hook.py +++ b/mmengine/hooks/runtime_info_hook.py @@ -1,12 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Optional, Sequence +from typing import Dict, Optional, Union from mmengine.registry import HOOKS from mmengine.utils import get_git_hash from mmengine.version import __version__ from .hook import Hook -DATA_BATCH = Optional[Sequence[dict]] +DATA_BATCH = Optional[Union[dict, tuple, list]] @HOOKS.register_module() diff --git a/mmengine/model/base_model/__init__.py b/mmengine/model/base_model/__init__.py index 696c83ad..66a3cb89 100644 --- a/mmengine/model/base_model/__init__.py +++ b/mmengine/model/base_model/__init__.py @@ -1,9 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base_model import BaseModel -from .data_preprocessor import (BaseDataElement, BaseDataPreprocessor, - ImgDataPreprocessor) +from .data_preprocessor import BaseDataPreprocessor, ImgDataPreprocessor -__all__ = [ - 'BaseModel', 'BaseDataElement', 'ImgDataPreprocessor', - 'BaseDataPreprocessor' -] +__all__ = ['BaseModel', 'ImgDataPreprocessor', 'BaseDataPreprocessor'] diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index dbc3ad1e..19fd7845 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -1,21 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. from abc import abstractmethod from collections import OrderedDict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn as nn from mmengine.optim import OptimWrapper from mmengine.registry import MODELS -from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of from ..base_module import BaseModule from .data_preprocessor import BaseDataPreprocessor -ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement], - Tuple[torch.Tensor], torch.Tensor] - class BaseModel(BaseModule): """Base class for all algorithmic models. @@ -85,7 +81,7 @@ class BaseModel(BaseModule): f'`nn.Module` instance, but got ' f'{type(data_preprocessor)}') - def train_step(self, data: List[dict], + def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """Implements the default model training process including preprocessing, model forward propagation, loss calculation, @@ -96,7 +92,7 @@ class BaseModel(BaseModule): :class:`IterBasedTrainLoop` will call this method to update model parameters. The default parameter update process is as follows: - 1. Calls ``self.data_processor(data, training=False) to collext + 1. Calls ``self.data_processor(data, training=False) to collect batch_inputs and corresponding data_samples(labels). 2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw loss @@ -105,7 +101,7 @@ class BaseModel(BaseModule): 4. Calls ``optim_wrapper.update_params(loss)`` to update model. Args: - data (List[dict]): Data sampled from dataloader. + data (dict or tuple or list): Data sampled from dataset. optim_wrapper (OptimWrapper): OptimWrapper instance used to update model parameters. @@ -114,13 +110,13 @@ class BaseModel(BaseModule): """ # Enable automatic mixed precision training context. 
with optim_wrapper.optim_context(self): - batch_inputs, data_samples = self.data_preprocessor(data, True) - losses = self(batch_inputs, data_samples, mode='loss') - parsed_losses, log_vars = self.parse_losses(losses) + data = self.data_preprocessor(data, True) + losses = self._run_forward(data, mode='loss') # type: ignore + parsed_losses, log_vars = self.parse_losses(losses) # type: ignore optim_wrapper.update_params(parsed_losses) return log_vars - def val_step(self, data: List[dict]) -> List[BaseDataElement]: + def val_step(self, data: Union[tuple, dict, list]) -> list: """Gets the predictions of given data. Calls ``self.data_preprocessor(data, False)`` and ``self(inputs, data_sample, mode='predict')`` in order. Return the predictions which will be passed to evaluator. Args: - data (List[dict]): Data sampled from dataloader. + data (dict or tuple or list): Data sampled from dataset. Returns: - List[BaseDataElement]: The predictions of given data. + list: The predictions of given data. """ - inputs, data_sample = self.data_preprocessor(data, False) - return self(inputs, data_sample, mode='predict') + data = self.data_preprocessor(data, False) + return self._run_forward(data, mode='predict') # type: ignore - def test_step(self, data: List[dict]) -> List[BaseDataElement]: + def test_step(self, data: Union[dict, tuple, list]) -> list: """``BaseModel`` implements ``test_step`` the same as ``val_step``. Args: - data (List[dict]): Data sampled from dataloader. + data (dict or tuple or list): Data sampled from dataset. Returns: - List[BaseDataElement]: The predictions of given data. + list: The predictions of given data. """ - inputs, data_sample = self.data_preprocessor(data, False) - return self(inputs, data_sample, mode='predict') + data = self.data_preprocessor(data, False) + return self._run_forward(data, mode='predict') # type: ignore def parse_losses( self, losses: Dict[str, torch.Tensor] @@ -239,16 +235,16 @@ class BaseModel(BaseModule): @abstractmethod def forward(self, - batch_inputs: torch.Tensor, - data_samples: Optional[List[BaseDataElement]] = None, - mode: str = 'tensor') -> ForwardResults: + inputs: torch.Tensor, + data_samples: Optional[list] = None, + mode: str = 'tensor') -> Union[Dict[str, torch.Tensor], list]: """Returns losses or predictions of training, validation, testing, and simple inference process. ``forward`` method of BaseModel is an abstract method, its subclasses must implement this method. Accepts ``inputs`` and ``data_samples`` processed by :attr:`data_preprocessor`, and returns results according to mode arguments. During the training, validation, and testing process, ``forward`` will be called by ``train_step``, ``val_step`` and ``test_step``. Args: - batch_inputs (torch.Tensor): batch input tensor collated by + inputs (torch.Tensor): batch input tensor collated by :attr:`data_preprocessor`. - data_samples (List[BaseDataElement], optional): + data_samples (list, optional): data samples collated by :attr:`data_preprocessor`. mode (str): mode should be one of ``loss``, ``predict`` and ``tensor`` - ``loss``: Called by ``train_step`` and return loss ``dict`` used for logging - ``predict``: Called by ``val_step`` and ``test_step`` - and return list of ``BaseDataElement`` results used for - computing metric. + and return a list of results used for computing metrics. - ``tensor``: Called by custom use to get ``Tensor`` type results. 
Returns: - ForwardResults: - + dict or list: - If ``mode == loss``, return a ``dict`` of loss tensor used for backward and logging. - - If ``mode == predict``, return a ``list`` of - :obj:`BaseDataElement` for computing metric - and getting inference result. + - If ``mode == predict``, return a ``list`` of inference + results. - If ``mode == tensor``, return a tensor or ``tuple`` of tensors or a ``dict`` of tensors for custom use. """ + + def _run_forward(self, data: Union[dict, tuple, list], + mode: str) -> Union[Dict[str, torch.Tensor], list]: + """Unpacks data for :meth:`forward`. + + Args: + data (dict or tuple or list): Data sampled from dataset. + mode (str): Mode of forward. + + Returns: + dict or list: Results of training or testing mode. + """ + if isinstance(data, dict): + results = self(**data, mode=mode) + elif isinstance(data, (list, tuple)): + results = self(*data, mode=mode) + else: + raise TypeError('Output of `data_preprocessor` should be ' + f'list, tuple or dict, but got {type(data)}') + return results diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 8b4d3e87..92fae1dc 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -1,94 +1,74 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List, Optional, Sequence, Tuple, Union +import math +from typing import Mapping, Optional, Sequence, Union import torch import torch.nn as nn +import torch.nn.functional as F from mmengine.registry import MODELS from mmengine.structures import BaseDataElement +from mmengine.utils import is_list_of from ..utils import stack_batch +CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list] + @MODELS.register_module() class BaseDataPreprocessor(nn.Module): - """Base data pre-processor used for collating and copying data to the - target device. - - ``BaseDataPreprocessor`` performs data pre-processing according to the - following steps: - - - Collates the data sampled from dataloader. - - Copies data to the target device. - - Stacks the input tensor at the first dimension. + """Base data pre-processor used for copying data to the target device. Subclasses inheriting from ``BaseDataPreprocessor`` could override the forward method to implement custom data pre-processing, such as batch-resize, MixUp, or CutMix. - Warnings: - Each item of data sampled from dataloader must be a dict and at least - contain the ``inputs`` key. Furthermore, the value of ``inputs`` - must be a ``Tensor`` with the same shape. + Note: + The data dictionary returned by the dataloader must at least + contain the ``inputs`` key. """ def __init__(self): super().__init__() self._device = torch.device('cpu') - def collate_data( - self, - data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]: - """Collating and copying data to the target device. - - Collates the data sampled from dataloader into a list of tensor and - list of labels, and then copies tensor to the target device. - - Subclasses could override it to be compatible with the custom format - data sampled from custom dataloader. + def cast_data(self, data: CastData) -> CastData: + """Copying data to the target device. Args: - data (Sequence[dict]): Data sampled from dataloader. + data (CastData): Data returned by ``DataLoader``. Returns: - Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input - tensor and list of labels at target device. + CastData: Inputs and data samples at the target device. 
""" - inputs = [_data['inputs'].to(self._device).float() for _data in data] - batch_data_samples: List[BaseDataElement] = [] - # Model can get predictions without any data samples. - for _data in data: - if 'data_sample' in _data: - batch_data_samples.append(_data['data_sample']) - # Move data from CPU to corresponding device. - batch_data_samples = [ - data_sample.to(self._device) for data_sample in batch_data_samples - ] + if isinstance(data, Mapping): + return {key: self.cast_data(data[key]) for key in data} + elif isinstance(data, tuple) and hasattr(data, '_fields'): + # namedtuple + return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable + elif isinstance(data, Sequence): + return [self.cast_data(sample) for sample in data] + elif isinstance(data, torch.Tensor): + return data.to(self.device) + elif isinstance(data, BaseDataElement): + return data.to(self.device) + else: + return data - if not batch_data_samples: - batch_data_samples = None # type: ignore - - return inputs, batch_data_samples - - def forward(self, - data: Sequence[dict], - training: bool = False) -> Tuple[torch.Tensor, Optional[list]]: + def forward(self, data: dict, training: bool = False) -> Union[dict, list]: """Preprocesses the data into the model input format. - After the data pre-processing of :meth:`collate_data`, ``forward`` + After the data pre-processing of :meth:`cast_data`, ``forward`` will stack the input tensor list to a batch tensor at the first dimension. Args: - data (Sequence[dict]): data sampled from dataloader. + data (dict): Data returned by dataloader training (bool): Whether to enable training time augmentation. Returns: - Tuple[torch.Tensor, Optional[list]]: Data in the same format as the - model input. + dict or list: Data in the same format as the model input. """ - inputs, batch_data_samples = self.collate_data(data) - batch_inputs = torch.stack(inputs, dim=0) - return batch_inputs, batch_data_samples + return self.cast_data(data) # type: ignore @property def device(self): @@ -203,42 +183,79 @@ class ImgDataPreprocessor(BaseDataPreprocessor): torch.tensor(std).view(-1, 1, 1), False) else: self._enable_normalize = False - self.channel_conversion = rgb_to_bgr or bgr_to_rgb + self._channel_conversion = rgb_to_bgr or bgr_to_rgb self.pad_size_divisor = pad_size_divisor self.pad_value = pad_value - def forward(self, - data: Sequence[dict], - training: bool = False) -> Tuple[torch.Tensor, Optional[list]]: + def forward(self, data: dict, training: bool = False) -> Union[dict, list]: """Performs normalization、padding and bgr2rgb conversion based on ``BaseDataPreprocessor``. Args: - data (Sequence[dict]): data sampled from dataloader. + data (dict): Data sampled from dataset. If the collate + function of DataLoader is :obj:`pseudo_collate`, data will be a + list of dict. If collate function is :obj:`default_collate`, + data will be a tuple with batch input tensor and list of data + samples. training (bool): Whether to enable training time augmentation. If subclasses override this method, they can perform different preprocessing strategies for training and testing based on the value of ``training``. Returns: - Tuple[torch.Tensor, Optional[list]]: Data in the same format as the - model input. + dict or list: Data in the same format as the model input. """ - inputs, batch_data_samples = self.collate_data(data) - for idx, _input in enumerate(inputs): - # channel transform - if self.channel_conversion: - _input = _input[[2, 1, 0], ...] - # Normalization. 
+ data = self.cast_data(data) # type: ignore + _batch_inputs = data['inputs'] + # Process data with `pseudo_collate`. + if is_list_of(_batch_inputs, torch.Tensor): + batch_inputs = [] + for _batch_input in _batch_inputs: + # channel transform + if self._channel_conversion: + _batch_input = _batch_input[[2, 1, 0], ...] + # Convert to float after channel conversion to ensure + # efficiency + _batch_input = _batch_input.float() + # Normalization. + if self._enable_normalize: + if self.mean.shape[0] == 3: + assert _batch_input.dim( + ) == 3 and _batch_input.shape[0] == 3, ( + 'If the mean has 3 values, the input tensor ' + 'should in shape of (3, H, W), but got the tensor ' + f'with shape {_batch_input.shape}') + _batch_input = (_batch_input - self.mean) / self.std + batch_inputs.append(_batch_input) + # Pad and stack Tensor. + batch_inputs = stack_batch(batch_inputs, self.pad_size_divisor, + self.pad_value) + # Process data with `default_collate`. + elif isinstance(_batch_inputs, torch.Tensor): + assert _batch_inputs.dim() == 4, ( + 'The input of `ImgDataPreprocessor` should be a NCHW tensor ' + 'or a list of tensor, but got a tensor with shape: ' + f'{_batch_inputs.shape}') + if self._channel_conversion: + _batch_inputs = _batch_inputs[:, [2, 1, 0], ...] + # Convert to float after channel conversion to ensure + # efficiency + _batch_inputs = _batch_inputs.float() if self._enable_normalize: - if self.mean.shape[0] == 3: - assert _input.dim() == 3 and _input.shape[0] == 3, ( - 'If the mean has 3 values, the input tensor should in ' - 'shape of (3, H, W), but got the tensor with shape ' - f'{_input.shape}') - _input = (_input - self.mean) / self.std - inputs[idx] = _input - # Pad and stack Tensor. - batch_inputs = stack_batch(inputs, self.pad_size_divisor, - self.pad_value) - return batch_inputs, batch_data_samples + _batch_inputs = (_batch_inputs - self.mean) / self.std + h, w = _batch_inputs.shape[2:] + target_h = math.ceil( + h / self.pad_size_divisor) * self.pad_size_divisor + target_w = math.ceil( + w / self.pad_size_divisor) * self.pad_size_divisor + pad_h = target_h - h + pad_w = target_w - w + batch_inputs = F.pad(_batch_inputs, (0, pad_w, 0, pad_h), + 'constant', self.pad_value) + else: + raise TypeError('Output of `cast_data` should be a list of dict ' + 'or a tuple with inputs and data_samples, but got' + f'{type(data)}: {data}') + data['inputs'] = batch_inputs + data.setdefault('data_samples', None) + return data diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index 376796a2..b07b8210 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -1,12 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List +from typing import Any, Dict, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel from mmengine.optim import OptimWrapper from mmengine.registry import MODEL_WRAPPERS -from mmengine.structures import BaseDataElement from ..utils import detect_anomalous_params @@ -90,7 +89,7 @@ class MMDistributedDataParallel(DistributedDataParallel): super().__init__(module=module, **kwargs) self.detect_anomalous_params = detect_anomalous_params - def train_step(self, data: List[dict], + def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """Interface for model forward, backward and parameters updating during training process. 
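The rewritten `ImgDataPreprocessor.forward` above accepts two input layouts, and a small sketch makes them concrete (illustrative only; shapes and values are assumptions, and `mmengine.model` is assumed to re-export the class as in this patch):

```python
import torch

from mmengine.model import ImgDataPreprocessor

processor = ImgDataPreprocessor(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5],
    pad_size_divisor=32)

# Layout produced by `pseudo_collate`: a list of CHW tensors, possibly of
# different spatial sizes; they are normalized, padded and stacked.
data = dict(inputs=[
    torch.randint(0, 256, (3, 224, 224)).float(),
    torch.randint(0, 256, (3, 200, 180)).float(),
])
out = processor(data)
print(out['inputs'].shape)  # torch.Size([2, 3, 224, 224])

# Layout produced by `default_collate`: an already batched NCHW tensor,
# which is normalized and padded instead of being stacked.
data = dict(inputs=torch.randint(0, 256, (2, 3, 224, 224)).float())
out = processor(data)
print(out['inputs'].shape)  # torch.Size([2, 3, 224, 224])
```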
@@ -105,7 +104,7 @@ class MMDistributedDataParallel(DistributedDataParallel): - Return log messages of losses. Args: - data (List[dict]): Data sampled by dataloader. + data (dict or tuple or list): Data sampled from dataset. optim_wrapper (OptimWrapper): A wrapper of optimizer to update parameters. @@ -114,33 +113,53 @@ class MMDistributedDataParallel(DistributedDataParallel): """ # Enable automatic mixed precision training context. with optim_wrapper.optim_context(self): - batch_inputs, data_samples = self.module.data_preprocessor( - data, training=True) - losses = self(batch_inputs, data_samples, mode='loss') + data = self.module.data_preprocessor(data, training=True) + losses = self._run_forward(data, mode='loss') if self.detect_anomalous_params: detect_anomalous_params(losses, model=self) parsed_loss, log_vars = self.module.parse_losses(losses) optim_wrapper.update_params(parsed_loss) return log_vars - def val_step(self, data: List[dict]) -> List[BaseDataElement]: + def val_step(self, data: Union[dict, tuple, list]) -> list: """Gets the prediction of module during validation process. Args: - data (List[dict]): Data sampled by dataloader. + data (dict or tuple or list): Data sampled from dataset. Returns: - List[BaseDataElement] or dict: The predictions of given data. + list: The predictions of given data. """ - return self.module.val_step(data) + data = self.module.data_preprocessor(data, training=False) + return self._run_forward(data, mode='predict') - def test_step(self, data: List[dict]) -> List[BaseDataElement]: + def test_step(self, data: Union[dict, tuple, list]) -> list: """Gets the predictions of module during testing process. Args: - data: Data sampled by dataloader. + data (dict or tuple or list): Data sampled from dataset. Returns: - List[BaseDataElement]: The predictions of given data. + list: The predictions of given data. """ - return self.module.test_step(data) + data = self.module.data_preprocessor(data, training=False) + return self._run_forward(data, mode='predict') + + def _run_forward(self, data: Union[dict, tuple, list], mode: str) -> Any: + """Unpacks data for :meth:`forward` + + Args: + data (dict or tuple or list): Data sampled from dataset. + mode (str): Mode of forward. + + Returns: + dict or list: Results of training or testing mode. + """ + if isinstance(data, dict): + results = self(**data, mode=mode) + elif isinstance(data, (list, tuple)): + results = self(*data, mode=mode) + else: + raise TypeError('Output of `data_preprocessor` should be ' + f'list, tuple or dict, but got {type(data)}') + return results diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py index 5e430a9b..1dbb4263 100644 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ b/mmengine/model/wrappers/fully_sharded_distributed.py @@ -153,7 +153,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): super().__init__(module, process_group, cpu_offload, fsdp_auto_wrap_policy, backward_prefetch) - def train_step(self, data: List[dict], + def train_step(self, data: dict, optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]: """Interface for model forward, backward and parameters updating during training process. @@ -168,7 +168,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): - Return log messages of losses. Args: - data (List[dict]): Data sampled by dataloader. + data (dict): Data sampled by dataloader. optim_wrapper (OptimWrapper): A wrapper of optimizer to update parameters. 
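The `_run_forward` helpers added to `BaseModel` and `MMDistributedDataParallel` encode the central contract of the new data flow: a `dict` produced by the data preprocessor is unpacked as keyword arguments, while a `tuple` or `list` is unpacked positionally. The following toy model (a sketch, not part of this patch) shows why the batch keys must match the `forward` parameter names:

```python
import torch
import torch.nn as nn

from mmengine.model import BaseModel


class ToyModel(BaseModel):  # illustrative model, not part of this patch
    """Follows the refactored contract: `forward` receives the unpacked
    output of `data_preprocessor` plus the `mode` keyword."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, inputs, data_samples=None, mode='tensor'):
        feats = self.linear(inputs)
        if mode == 'loss':
            # `data_samples` carries the labels under this toy convention.
            return dict(loss=(feats - data_samples).abs().mean())
        elif mode == 'predict':
            return list(feats)
        return feats


model = ToyModel()
# A dict batch is dispatched as `self(**data, mode=...)` by `_run_forward`,
# so the keys below must match the `forward` argument names.
batch = dict(inputs=torch.randn(4, 2), data_samples=torch.randn(4, 1))
predictions = model.val_step(batch)  # list with 4 per-sample tensors
```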
@@ -177,18 +177,23 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): """ # enable automatic mixed precision training context. with optim_wrapper.optim_context(self): - batch_inputs, data_samples = self.module.data_preprocessor( - data, training=True) - losses = self(batch_inputs, data_samples, mode='loss') + data = self.module.data_preprocessor(data, training=True) + if isinstance(data, dict): + losses = self(**data, mode='loss') + elif isinstance(data, (list, tuple)): + losses = self(*data, mode='loss') + else: + raise TypeError('Output of `data_preprocessor` should be ' + f'list tuple or dict, but got {type(data)}') parsed_loss, log_vars = self.module.parse_losses(losses) optim_wrapper.update_params(parsed_loss) return log_vars - def val_step(self, data: List[dict]) -> List[BaseDataElement]: + def val_step(self, data: dict) -> List[BaseDataElement]: """Gets the prediction of module during validation process. Args: - data (List[dict]): Data sampled by dataloader. + data (dict): Data sampled by dataloader. Returns: List[BaseDataElement] or dict: The predictions of given data. @@ -196,11 +201,11 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): inputs, data_sample = self.module.data_preprocessor(data, False) return self(inputs, data_sample, mode='predict') - def test_step(self, data: List[dict]) -> List[BaseDataElement]: + def test_step(self, data: dict) -> List[BaseDataElement]: """Gets the predictions of module during testing process. Args: - data: Data sampled by dataloader. + data (dict): Data sampled by dataloader. Returns: List[BaseDataElement]: The predictions of given data. diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py index 9b1b260a..b7306326 100644 --- a/mmengine/model/wrappers/seperate_distributed.py +++ b/mmengine/model/wrappers/seperate_distributed.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from contextlib import ExitStack, contextmanager -from typing import Dict, List +from typing import Dict, Union import torch import torch.nn as nn @@ -9,7 +9,6 @@ from torch.nn.parallel.distributed import DistributedDataParallel from mmengine.device import get_device from mmengine.optim import OptimWrapperDict from mmengine.registry import MODEL_WRAPPERS -from mmengine.structures import BaseDataElement from .distributed import MMDistributedDataParallel @@ -86,13 +85,13 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel): **kwargs) module._modules[name] = sub_module - def train_step(self, data: List[dict], + def train_step(self, data: Union[dict, tuple, list], optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]: """Interface for model forward, backward and parameters updating during training process. Args: - data: Data sampled by dataloader. + data (dict or tuple or list): Data sampled from dataset. optim_wrapper (OptimWrapperDict): A wrapper of optimizer to update parameters. @@ -101,25 +100,25 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel): """ return self.module.train_step(data, optim_wrapper) - def val_step(self, data) -> List[BaseDataElement]: + def val_step(self, data: Union[dict, tuple, list]) -> list: """Gets the prediction of module during validation process. Args: - data (List[dict]): Data sampled by dataloader. + data (dict or tuple or list): Data sampled from dataset. Returns: - List[BaseDataElement]: The predictions of given data. + list: The predictions of given data. 
""" return self.module.val_step(data) - def test_step(self, data: List[dict]) -> List[BaseDataElement]: + def test_step(self, data: Union[dict, tuple, list]) -> list: """Gets the predictions of module during testing process. Args: - data: Data sampled by dataloader. + data (dict or tuple or list): Data sampled from dataset. Returns: - ForwardResults: The predictions of given data. + list: The predictions of given data. """ return self.module.test_step(data) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index c8bab682..baa1230c 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -361,7 +361,7 @@ class ValLoop(BaseLoop): # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) - self.evaluator.process(data_batch, outputs) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', batch_idx=idx, @@ -429,10 +429,10 @@ class TestLoop(BaseLoop): 'before_test_iter', batch_idx=idx, data_batch=data_batch) # predictions should be sequence of BaseDataElement with autocast(enabled=self.fp16): - predictions = self.runner.model.test_step(data_batch) - self.evaluator.process(data_batch, predictions) + outputs = self.runner.model.test_step(data_batch) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_test_iter', batch_idx=idx, data_batch=data_batch, - outputs=predictions) + outputs=outputs) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index cc59e866..4eae559d 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader import mmengine from mmengine.config import Config, ConfigDict -from mmengine.dataset import pseudo_collate, worker_init_fn +from mmengine.dataset import COLLATE_FUNCTIONS, worker_init_fn from mmengine.device import get_device from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist, is_distributed, master_only, sync_random_seed) @@ -1393,14 +1393,19 @@ class Runner: # The default behavior of `collat_fn` in dataloader is to # merge a list of samples to form a mini-batch of Tensor(s). - # However, to make this more flexible, collate_fn in MMengine does - # nothing. The action to merge a list of samples will be handled - # in model. + # However, in mmengine, if `collate_fn` is not defined in + # dataloader_cfg, `pseudo_collate` will only convert the list of + # samples into a dict without stacking the batch tensor. 
+ collate_fn_cfg = dataloader_cfg.pop('collate_fn', + dict(type='pseudo_collate')) + collate_fn_type = collate_fn_cfg.pop('type') + collate_fn = COLLATE_FUNCTIONS.get(collate_fn_type) + collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore data_loader = DataLoader( dataset=dataset, sampler=sampler if batch_sampler is None else None, batch_sampler=batch_sampler, - collate_fn=pseudo_collate, + collate_fn=collate_fn, worker_init_fn=init_fn, **dataloader_cfg) return data_loader diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index e41287e6..79a05882 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -1080,8 +1080,7 @@ class Visualizer(ManagerMixin): def add_datasample(self, name, image: np.ndarray, - gt_sample: Optional['BaseDataElement'] = None, - pred_sample: Optional['BaseDataElement'] = None, + data_sample: Optional['BaseDataElement'] = None, draw_gt: bool = True, draw_pred: bool = True, show: bool = False, diff --git a/tests/test_data/test_data_utils.py b/tests/test_data/test_data_utils.py new file mode 100644 index 00000000..186bfad2 --- /dev/null +++ b/tests/test_data/test_data_utils.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest import TestCase + +import numpy as np +import torch + +from mmengine.dataset import default_collate, pseudo_collate +from mmengine.structures import BaseDataElement +from mmengine.utils import is_list_of + + +class TestDataUtils(TestCase): + + def test_pseudo_collate(self): + # Test with a list of dicts whose values are tensors. + input1 = torch.randn(1, 3, 5) + input2 = torch.randn(1, 3, 5) + label1 = torch.randn(1) + label2 = torch.randn(1) + + data_batch = [ + dict(inputs=input1, data_sample=label1), + dict(inputs=input2, data_sample=label2) + ] + data_batch = pseudo_collate(data_batch) + self.assertTrue(torch.allclose(input1, data_batch['inputs'][0])) + self.assertTrue(torch.allclose(input2, data_batch['inputs'][1])) + self.assertTrue(torch.allclose(label1, data_batch['data_sample'][0])) + self.assertTrue(torch.allclose(label2, data_batch['data_sample'][1])) + + # Test with a list of dicts, where each `data_sample` value is a + # BaseDataElement instance. + data_sample1 = BaseDataElement(label=torch.tensor(1)) + data_sample2 = BaseDataElement(label=torch.tensor(1)) + data = [ + dict(inputs=input1, data_sample=data_sample1), + dict(inputs=input2, data_sample=data_sample2), + ] + data_batch = pseudo_collate(data) + batch_inputs, batch_data_sample = (data_batch['inputs'], + data_batch['data_sample']) + # check batch_inputs + self.assertTrue(is_list_of(batch_inputs, torch.Tensor)) + self.assertIs(input1, batch_inputs[0]) + self.assertIs(input2, batch_inputs[1]) + + # check data_sample + self.assertIs(batch_data_sample[0], data_sample1) + self.assertIs(batch_data_sample[1], data_sample2) + + # Test with a list of tuples of dicts, each containing a nested dict + data_batch = [(dict( + inputs=input1, + data_sample=data_sample1, + value=1, + name='1', + nested=dict(data_sample=data_sample1)), + dict( + inputs=input2, + data_sample=data_sample2, + value=2, + name='2', + nested=dict(data_sample=data_sample2))), + (dict( + inputs=input1, + data_sample=data_sample1, + value=1, + name='1', + nested=dict(data_sample=data_sample1)), + dict( + inputs=input2, + data_sample=data_sample2, + value=2, + name='2', + nested=dict(data_sample=data_sample2)))] + data_batch = pseudo_collate(data_batch) + batch_inputs_0 = data_batch[0]['inputs'] + batch_inputs_1 = data_batch[1]['inputs'] + batch_data_sample_0 = data_batch[0]['data_sample'] + batch_data_sample_1 = data_batch[1]['data_sample'] + batch_value_0 = data_batch[0]['value'] + batch_value_1 = data_batch[1]['value'] + batch_name_0 = data_batch[0]['name'] + batch_name_1 = data_batch[1]['name'] + batch_nested_0 = data_batch[0]['nested'] + batch_nested_1 = data_batch[1]['nested'] + + self.assertTrue(is_list_of(batch_inputs_0, torch.Tensor)) + self.assertTrue(is_list_of(batch_inputs_1, torch.Tensor)) + self.assertIs(batch_inputs_0[0], input1) + self.assertIs(batch_inputs_0[1], input1) + self.assertIs(batch_inputs_1[0], input2) + self.assertIs(batch_inputs_1[1], input2) + + self.assertIs(batch_data_sample_0[0], data_sample1) + self.assertIs(batch_data_sample_0[1], data_sample1) + self.assertIs(batch_data_sample_1[0], data_sample2) + self.assertIs(batch_data_sample_1[1], data_sample2) + + self.assertEqual(batch_value_0, [1, 1]) + self.assertEqual(batch_value_1, [2, 2]) + + self.assertEqual(batch_name_0, ['1', '1']) + self.assertEqual(batch_name_1, ['2', '2']) + + self.assertIs(batch_nested_0['data_sample'][0], data_sample1) + self.assertIs(batch_nested_0['data_sample'][1], data_sample1) + self.assertIs(batch_nested_1['data_sample'][0], data_sample2) + self.assertIs(batch_nested_1['data_sample'][1], data_sample2) + + def test_default_collate(self): + # `default_collate` shares common logic with `pseudo_collate`, so we + # only test that it can stack batch tensors and convert ints, floats + # and ndarrays to tensors. + input1 = torch.randn(1, 3, 5) + input2 = torch.randn(1, 3, 5) + data_batch = [( + dict(inputs=input1, value=1, array=np.array(1)), + dict(inputs=input2, value=2, array=np.array(2)), + ), + ( + dict(inputs=input1, value=1, array=np.array(1)), + dict(inputs=input2, value=2, array=np.array(2)), + )] + data_batch = default_collate(data_batch) + batch_inputs_0 = data_batch[0]['inputs'] + batch_inputs_1 = data_batch[1]['inputs'] + batch_value_0 = data_batch[0]['value'] + batch_value_1 = data_batch[1]['value'] + batch_array_0 = data_batch[0]['array'] + batch_array_1 = data_batch[1]['array'] + + self.assertEqual(tuple(batch_inputs_0.shape), (2, 1, 3, 5)) + self.assertEqual(tuple(batch_inputs_1.shape), (2, 1, 3, 5)) + self.assertTrue( + torch.allclose(batch_inputs_0, torch.stack([input1, input1]))) + self.assertTrue( + torch.allclose(batch_inputs_1, torch.stack([input2, input2]))) + + self.assertTrue( + torch.allclose(batch_value_0, + torch.stack([torch.tensor(1), + torch.tensor(1)]))) + self.assertTrue( + torch.allclose(batch_value_1, + torch.stack([torch.tensor(2), + torch.tensor(2)]))) + + self.assertTrue( + torch.allclose(batch_array_0, + torch.stack([torch.tensor(1), + torch.tensor(1)]))) + self.assertTrue( + torch.allclose(batch_array_1, + torch.stack([torch.tensor(2), + torch.tensor(2)]))) diff --git a/tests/test_evaluator/test_evaluator.py b/tests/test_evaluator/test_evaluator.py index ef4ebbaf..fb02167e 100644 --- a/tests/test_evaluator/test_evaluator.py +++ b/tests/test_evaluator/test_evaluator.py @@ -41,9 +41,9 @@ class ToyMetric(BaseMetric): def process(self, data_batch, predictions): results = [{ - 'pred': pred.get('pred'), - 'label': data['data_sample'].get('label') - } for pred, data in zip(predictions, data_batch)] + 'pred': prediction['label'], + 'label': prediction['label'] + } for prediction in predictions] self.results.extend(results) def compute_metrics(self, results: List): @@ -68,8 +68,7 @@ class NonPrefixedMetric(BaseMetric): """Evaluator with unassigned `default_prefix` to test the warning information.""" - def
process(self, data_batch: Sequence[dict], - predictions: Sequence[dict]) -> None: + def process(self, data_batch, predictions: Sequence[dict]) -> None: pass def compute_metrics(self, results: list) -> dict: @@ -81,12 +80,13 @@ def generate_test_results(size, batch_size, pred, label): bs_residual = size % batch_size for i in range(num_batch): bs = bs_residual if i == num_batch - 1 else batch_size - data_batch = [ - dict( - inputs=np.zeros((3, 10, 10)), - data_sample=BaseDataElement(label=label)) for _ in range(bs) + data_batch = { + 'inputs': [np.zeros((3, 10, 10)) for _ in range(bs)], + 'data_sample': [BaseDataElement(label=label) for _ in range(bs)] + } + predictions = [ + BaseDataElement(pred=pred, label=label) for _ in range(bs) ] - predictions = [BaseDataElement(pred=pred) for _ in range(bs)] yield (data_batch, predictions) @@ -99,9 +99,9 @@ class TestEvaluator(TestCase): size = 10 batch_size = 4 - for data_samples, predictions in generate_test_results( + for data_samples, outputs in generate_test_results( size, batch_size, pred=1, label=1): - evaluator.process(data_samples, predictions) + evaluator.process(data_samples=outputs, data_batch=data_samples) metrics = evaluator.evaluate(size=size) self.assertAlmostEqual(metrics['Toy/accuracy'], 1.0) @@ -124,9 +124,9 @@ class TestEvaluator(TestCase): size = 10 batch_size = 4 - for data_samples, predictions in generate_test_results( + for data_samples, outputs in generate_test_results( size, batch_size, pred=1, label=1): - evaluator.process(data_samples, predictions) + evaluator.process(data_samples=outputs, data_batch=data_samples) metrics = evaluator.evaluate(size=size) @@ -145,9 +145,9 @@ class TestEvaluator(TestCase): size = 10 batch_size = 4 - for data_samples, predictions in generate_test_results( + for data_samples, outputs in generate_test_results( size, batch_size, pred=1, label=1): - evaluator.process(data_samples, predictions) + evaluator.process(data_samples=outputs, data_batch=data_samples) with self.assertRaisesRegex( ValueError, @@ -233,13 +233,22 @@ class TestEvaluator(TestCase): size = 10 - all_data = [ - dict( - inputs=np.zeros((3, 10, 10)), - data_sample=BaseDataElement(label=1)) for _ in range(size) + all_data = [dict() for _ in range(10)] + all_predictions = [ + BaseDataElement(pred=0, label=1) for _ in range(size) ] - all_predictions = [BaseDataElement(pred=0) for _ in range(size)] - evaluator.offline_evaluate(all_data, all_predictions) + evaluator.offline_evaluate(all_predictions, all_data) + + # Test with None data + all_data = None + evaluator.offline_evaluate(all_predictions, all_data) + + # Different length of data and predictions will raise an error. 
+ all_data = [dict() for _ in range(9)] + with self.assertRaisesRegex( + AssertionError, + 'outputs and data should have the same length'): + evaluator.offline_evaluate(all_predictions, all_data) @unittest.skipUnless(torch.cuda.is_available(), 'can only run with gpu') def test_evaluate_cast_cpu(self): @@ -256,11 +265,12 @@ class TestEvaluator(TestCase): for _ in range(size) ] all_predictions = [ - BaseDataElement(pred=torch.zeros((1, ), device='cuda')) - for _ in range(size) + BaseDataElement( + pred=torch.zeros((1, ), device='cuda'), + label=torch.ones((1, ), device='cuda')) for _ in range(size) ] for data, pred in zip(all_data, all_predictions): - evaluator.process([data], [pred]) + evaluator.process([pred], [data]) def test_results_device(results: List): for result in results: diff --git a/tests/test_evaluator/test_metric.py b/tests/test_evaluator/test_metric.py index 91d66ea1..fb7a181d 100644 --- a/tests/test_evaluator/test_metric.py +++ b/tests/test_evaluator/test_metric.py @@ -19,8 +19,8 @@ class TestDumpResults(TestCase): def test_process(self): metric = DumpResults(out_file_path='./results.pkl') - predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))] - metric.process(None, predictions) + data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))] + metric.process(None, data_samples) self.assertEqual(len(metric.results), 1) self.assertEqual(metric.results[0]['data'][0].device, torch.device('cpu')) @@ -29,8 +29,8 @@ class TestDumpResults(TestCase): temp_dir = tempfile.TemporaryDirectory() path = osp.join(temp_dir.name, 'results.pkl') metric = DumpResults(out_file_path=path) - predictions = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))] - metric.process(None, predictions) + data_samples = [dict(data=(Tensor([1, 2, 3]), Tensor([4, 5, 6])))] + metric.process(None, data_samples) metric.compute_metrics(metric.results) self.assertTrue(osp.isfile(path)) diff --git a/tests/test_hooks/test_ema_hook.py b/tests/test_hooks/test_ema_hook.py index f1aeea56..4fcced7d 100644 --- a/tests/test_hooks/test_ema_hook.py +++ b/tests/test_hooks/test_ema_hook.py @@ -22,9 +22,10 @@ class ToyModel(nn.Module): super().__init__() self.linear = nn.Linear(2, 1) - def forward(self, batch_inputs, labels, mode='tensor'): - labels = torch.stack(labels) - outputs = self.linear(batch_inputs) + def forward(self, inputs, data_sample, mode='tensor'): + labels = torch.stack(data_sample) + inputs = torch.stack(inputs) + outputs = self.linear(inputs) if mode == 'tensor': return outputs elif mode == 'loss': diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 95ba3d46..22df50ec 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -28,15 +28,15 @@ class ToyModel(BaseModel): super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) self.conv = nn.Conv2d(3, 1, 1) - def forward(self, batch_inputs, data_samples=None, mode='tensor'): + def forward(self, inputs, data_sample=None, mode='tensor'): if mode == 'loss': - out = self.conv(batch_inputs) + out = self.conv(inputs) return dict(loss=out) elif mode == 'predict': - out = self.conv(batch_inputs) + out = self.conv(inputs) return out elif mode == 'tensor': - out = self.conv(batch_inputs) + out = self.conv(inputs) return out @@ -98,34 +98,34 @@ class TestBaseModel(TestCase): model = ToyModel() optimizer = SGD(model.parameters(), lr=0.1) optim_wrapper = OptimWrapper(optimizer) - inputs = torch.randn(3, 1, 1) 
- data = dict(inputs=inputs) + inputs = torch.randn(1, 3, 1, 1) + data = dict(inputs=inputs, data_sample=None) # initiate grad. # model.conv.weight.grad = torch.randn(1, 3, 1, 1) - log_vars = model.train_step([data], optim_wrapper) + log_vars = model.train_step(data, optim_wrapper) self.assertIsNotNone(model.conv.weight.grad) self.assertIsInstance(log_vars['loss'], torch.Tensor) def test_val_step(self): - inputs = torch.randn(3, 1, 1) - data = dict(inputs=inputs) + inputs = torch.randn(1, 3, 1, 1) + data = dict(inputs=inputs, data_sample=None) model = ToyModel() - out = model.val_step([data]) + out = model.val_step(data) self.assertIsInstance(out, torch.Tensor) def test_test_step(self): - inputs = torch.randn(3, 1, 1) - data = dict(inputs=inputs) + inputs = torch.randn(1, 3, 1, 1) + data = dict(inputs=inputs, data_sample=None) model = ToyModel() - out = model.val_step([data]) + out = model.val_step(data) self.assertIsInstance(out, torch.Tensor) @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available') def test_cuda(self): - inputs = torch.randn(3, 1, 1).cuda() - data = dict(inputs=inputs) + inputs = torch.randn(1, 3, 1, 1).cuda() + data = dict(inputs=inputs, data_sample=None) model = ToyModel().cuda() - out = model.val_step([data]) + out = model.val_step(data) self.assertEqual(out.device.type, 'cuda') model = NestedModel() @@ -139,10 +139,10 @@ class TestBaseModel(TestCase): @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available') def test_to(self): - inputs = torch.randn(3, 1, 1).to('cuda:0') - data = dict(inputs=inputs) + inputs = torch.randn(1, 3, 1, 1).to('cuda:0') + data = dict(inputs=inputs, data_sample=None) model = ToyModel().to(torch.cuda.current_device()) - out = model.val_step([data]) + out = model.val_step(data) self.assertEqual(out.device.type, 'cuda') model = NestedModel() diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index 18b773bd..15ba57d3 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -16,54 +16,73 @@ class TestBaseDataPreprocessor(TestCase): self.assertEqual(base_data_preprocessor._device.type, 'cpu') def test_forward(self): + # Test cpu forward with list of data samples. 
base_data_preprocessor = BaseDataPreprocessor() input1 = torch.randn(1, 3, 5) input2 = torch.randn(1, 3, 5) label1 = torch.randn(1) label2 = torch.randn(1) - data = [ - dict(inputs=input1, data_sample=label1), - dict(inputs=input2, data_sample=label2) - ] + data = dict(inputs=[input1, input2], data_sample=[label1, label2]) - batch_inputs, batch_labels = base_data_preprocessor(data) - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.shape, (2, 1, 3, 5)) + output = base_data_preprocessor(data) + batch_inputs, batch_labels = output['inputs'], output['data_sample'] + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].shape, (1, 3, 5)) assert_allclose(input1, batch_inputs[0]) assert_allclose(input2, batch_inputs[1]) assert_allclose(label1, batch_labels[0]) assert_allclose(label2, batch_labels[1]) + # Test with tuple of batch inputs and batch data samples + data = dict( + inputs=torch.stack([input1, input2]), data_sample=[label1, label2]) + output = base_data_preprocessor(data)['inputs'] + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + + # Test cuda forward if torch.cuda.is_available(): + # Test with list of data samples. base_data_preprocessor = base_data_preprocessor.cuda() - batch_inputs, batch_labels = base_data_preprocessor(data) + output = base_data_preprocessor(data) + batch_inputs, batch_labels = output['inputs'], output[ + 'data_sample'] self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.device.type, 'cuda') base_data_preprocessor = base_data_preprocessor.cpu() - batch_inputs, batch_labels = base_data_preprocessor(data) + output = base_data_preprocessor(data) + batch_inputs, batch_labels = output['inputs'], output[ + 'data_sample'] self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.device.type, 'cpu') base_data_preprocessor = base_data_preprocessor.to('cuda:0') - batch_inputs, batch_labels = base_data_preprocessor(data) + output = base_data_preprocessor(data) + batch_inputs, batch_labels = output['inputs'], output[ + 'data_sample'] self.assertTrue(torch.is_floating_point(batch_inputs)) self.assertEqual(batch_inputs.device.type, 'cuda') + # device of `base_data_preprocessor` is cuda, output should be + # cuda tensor. + self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertEqual(batch_labels[0].device.type, 'cuda') -class TestImgataPreprocessor(TestBaseDataPreprocessor): + +class TestImgDataPreprocessor(TestBaseDataPreprocessor): def test_init(self): - # initiate model without `data_preprocessor` + # Initialize processor without arguments data_processor = ImgDataPreprocessor() - self.assertFalse(data_processor.channel_conversion) + self.assertFalse(data_processor._channel_conversion) self.assertFalse(hasattr(data_processor, 'mean')) self.assertFalse(hasattr(data_processor, 'std')) self.assertEqual(data_processor.pad_size_divisor, 1) assert_allclose(data_processor.pad_value, torch.tensor(0)) - # initiate model with data_preprocessor` and feat keys + + # Initialize processor with bgr_to_rgb, mean, std, etc.
data_processor = ImgDataPreprocessor( bgr_to_rgb=True, mean=[0, 0, 0], @@ -71,7 +90,7 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor): pad_size_divisor=16, pad_value=10) self.assertTrue(data_processor._enable_normalize) - self.assertTrue(data_processor.channel_conversion, True) + self.assertTrue(data_processor._channel_conversion, True) assert_allclose(data_processor.mean, torch.tensor([0, 0, 0]).view(-1, 1, 1)) assert_allclose(data_processor.std, @@ -113,10 +132,11 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor): inputs2 = torch.randn(3, 15, 15) data_sample1 = InstanceData(bboxes=torch.randn(5, 4)) data_sample2 = InstanceData(bboxes=torch.randn(5, 4)) - data = [ - dict(inputs=inputs1.clone(), data_sample=data_sample1.clone()), - dict(inputs=inputs2.clone(), data_sample=data_sample2.clone()) - ] + + data = dict( + inputs=[inputs1.clone(), inputs2.clone()], + data_sample=[data_sample1.clone(), + data_sample2.clone()]) std = torch.tensor([1, 2, 3]).view(-1, 1, 1) target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std @@ -126,7 +146,8 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor): target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10) target_inputs = [target_inputs1, target_inputs2] - inputs, data_samples = data_preprocessor(data, True) + output = data_preprocessor(data, True) + inputs, data_samples = output['inputs'], output['data_sample'] self.assertTrue(torch.is_floating_point(inputs)) target_data_samples = [data_sample1, data_sample2] @@ -147,7 +168,8 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor): target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10) target_inputs = [target_inputs1, target_inputs2] - inputs, data_samples = data_preprocessor(data, True) + output = data_preprocessor(data, True) + inputs, data_samples = output['inputs'], output['data_sample'] self.assertTrue(torch.is_floating_point(inputs)) target_data_samples = [data_sample1, data_sample2] @@ -159,22 +181,53 @@ class TestImgataPreprocessor(TestBaseDataPreprocessor): # Test gray image with 3 dim mean will raise error data_preprocessor = ImgDataPreprocessor( mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) - data = [ - dict(inputs=torch.ones(10, 10)), - dict(inputs=torch.ones(10, 10)) - ] + data = dict( + inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) with self.assertRaisesRegex(AssertionError, 'If the mean has 3 values'): data_preprocessor(data) - data = [ - dict(inputs=torch.ones(1, 10, 10)), - dict(inputs=torch.ones(1, 10, 10)) - ] + data = dict( + inputs=[torch.ones(10, 10), torch.ones(10, 10)], data_sample=None) with self.assertRaisesRegex(AssertionError, 'If the mean has 3 values'): data_preprocessor(data) + # Test stacked batch inputs and batch data samples + data_preprocessor = ImgDataPreprocessor( + mean=(127.5, 127.5, 127.5), + std=(127.5, 127.5, 127.5), + rgb_to_bgr=True, + pad_size_divisor=16) + _batch_inputs = torch.randn(2, 3, 10, 10) + _batch_labels = [torch.randn(1), torch.randn(1)] + data = dict(inputs=_batch_inputs, data_sample=_batch_labels) + output = data_preprocessor(data) + inputs, data_samples = output['inputs'], output['data_sample'] + target_batch_inputs = _batch_inputs[:, [2, 1, 0], ...] 
+ target_batch_inputs = (target_batch_inputs - 127.5) / 127.5 + target_batch_inputs = F.pad(target_batch_inputs, (0, 6, 0, 6), value=0) + self.assertEqual(inputs.shape, torch.Size([2, 3, 16, 16])) + self.assertTrue(torch.is_floating_point(inputs)) + assert_allclose(target_batch_inputs, inputs) + + # Test batch inputs without converting channel order or padding + data_preprocessor = ImgDataPreprocessor( + mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) + _batch_inputs = torch.randn(2, 3, 10, 10) + _batch_labels = [torch.randn(1), torch.randn(1)] + data = dict(inputs=_batch_inputs, data_sample=_batch_labels) + output = data_preprocessor(data) + inputs, data_samples = output['inputs'], output['data_sample'] + target_batch_inputs = (_batch_inputs - 127.5) / 127.5 + self.assertEqual(inputs.shape, torch.Size([2, 3, 10, 10])) + self.assertTrue(torch.is_floating_point(inputs)) + assert_allclose(target_batch_inputs, inputs) + # Test empty `data_sample` - data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())] - data_preprocessor(data, True) + data = dict( + inputs=[inputs1.clone(), inputs2.clone()], data_sample=None) + output = data_preprocessor(data, True) + inputs, data_samples = output['inputs'], output['data_sample'] + self.assertIsNone(data_samples) + self.assertTrue(torch.is_floating_point(inputs)) diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py index 14aaa221..4884b826 100644 --- a/tests/test_model/test_wrappers/test_model_wrapper.py +++ b/tests/test_model/test_wrappers/test_model_wrapper.py @@ -9,9 +9,10 @@ import torch.nn as nn from torch.optim import SGD from mmengine.dist import all_gather -from mmengine.model import (BaseModel, MMDistributedDataParallel, +from mmengine.model import (BaseDataPreprocessor, BaseModel, + ExponentialMovingAverage, + MMDistributedDataParallel, MMSeparateDistributedDataParallel) -from mmengine.model.averaged_model import ExponentialMovingAverage from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict from mmengine.testing import assert_allclose from mmengine.testing._internal import MultiProcessTestCase @@ -22,15 +23,22 @@ if digit_version(TORCH_VERSION) >= digit_version('1.11.0'): from mmengine.model import MMFullyShardedDataParallel # noqa: F401 +class ToyDataPreprocessor(BaseDataPreprocessor): + + def forward(self, data: dict, training: bool = False): + self.called = True + return super().forward(data, training) + + class ToyModel(BaseModel): def __init__(self): - super().__init__() + super().__init__(data_preprocessor=ToyDataPreprocessor()) self.conv1 = nn.Conv2d(3, 1, 1) self.conv2 = nn.Conv2d(1, 1, 1) - def forward(self, x, data_samples=None, mode='tensor'): - x = self.conv1(x) + def forward(self, inputs, data_sample=None, mode='tensor'): + x = self.conv1(inputs) x = self.conv2(x) if mode == 'loss': return dict(loss=x) @@ -48,10 +56,10 @@ class ComplexModel(BaseModel): self.conv2 = nn.Conv2d(3, 1, 1) def train_step(self, data, optim_wrapper): - batch_inputs, _ = self.data_preprocessor(data) - loss1 = self.conv1(batch_inputs) + inputs = self.data_preprocessor(data)['inputs'] + loss1 = self.conv1(inputs) optim_wrapper['optim_wrapper1'].update_params(loss1) - loss2 = self.conv2(batch_inputs) + loss2 = self.conv2(inputs) optim_wrapper['optim_wrapper2'].update_params(loss2) return dict(loss1=loss1, loss2=loss2) @@ -82,9 +90,9 @@ class TestDistributedDataParallel(MultiProcessTestCase): optimizer = SGD(ddp_model.parameters(), lr=0) optim_wrapper =
AmpOptimWrapper( optimizer=optimizer, accumulative_counts=3) - inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255 - data = dict(inputs=inputs, data_sample=MagicMock()) - res = ddp_model.train_step([data], optim_wrapper=optim_wrapper)['loss'] + inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 + data = dict(inputs=inputs, data_sample=None) + res = ddp_model.train_step(data, optim_wrapper=optim_wrapper)['loss'] self.assertIs(res.dtype, torch.float16) grad = ddp_model.module.conv1.weight.grad all_grads = all_gather(grad) @@ -92,10 +100,10 @@ class TestDistributedDataParallel(MultiProcessTestCase): assert_allclose(all_grads[0], all_grads[1]) # Gradient accumulation - ddp_model.train_step([data], optim_wrapper=optim_wrapper) + ddp_model.train_step(data, optim_wrapper=optim_wrapper) # Test update params and clean grads. - ddp_model.train_step([data], optim_wrapper=optim_wrapper) + ddp_model.train_step(data, optim_wrapper=optim_wrapper) grad = ddp_model.module.conv1.weight.grad all_grads = all_gather(grad) assert_allclose(all_grads[0], torch.zeros_like(all_grads[0])) @@ -105,20 +113,22 @@ class TestDistributedDataParallel(MultiProcessTestCase): self._init_dist_env(self.rank, self.world_size) model = ToyModel() ddp_model = MMDistributedDataParallel(module=model) - inputs = torch.randn(3, 1, 1) * self.rank * 255 - data = dict(inputs=inputs, data_sample=MagicMock()) + inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, data_sample=None) # Test get predictions. - predictions = ddp_model.val_step([data]) + predictions = ddp_model.val_step(data) self.assertIsInstance(predictions, torch.Tensor) + self.assertTrue(model.data_preprocessor.called) def test_test_step(self): self._init_dist_env(self.rank, self.world_size) model = ToyModel() ddp_model = MMDistributedDataParallel(module=model) - inputs = torch.randn(3, 1, 1) * self.rank * 255 - data = dict(inputs=inputs, data_sample=MagicMock()) - predictions = ddp_model.test_step([data]) + inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 + data = dict(inputs=inputs, data_sample=None) + predictions = ddp_model.test_step(data) self.assertIsInstance(predictions, torch.Tensor) + self.assertTrue(model.data_preprocessor.called) def _init_dist_env(self, rank, world_size): """Initialize the distributed environment.""" @@ -157,30 +167,30 @@ class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel): optim_wrapper2 = OptimWrapper(optimizer2, 1) optim_wrapper_dict = OptimWrapperDict( optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2) - inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255 - data = dict(inputs=inputs) + inputs = torch.randn(1, 3, 1, 1).cuda() * self.rank * 255 + data = dict(inputs=inputs, data_sample=None) # Automatically sync grads of `optim_wrapper1` since # `cumulative_iters` = 1 ddp_model.train() self.assertTrue(ddp_model.training) - ddp_model.train_step([data], optim_wrapper=optim_wrapper_dict) + ddp_model.train_step(data, optim_wrapper=optim_wrapper_dict) def test_val_step(self): self._init_dist_env(self.rank, self.world_size) model = ComplexModel() ddp_model = MMSeparateDistributedDataParallel(model) - data = torch.randn(3, 1, 1) + data = torch.randn(1, 3, 1, 1) # Test get predictions. 
ddp_model.eval() self.assertFalse(ddp_model.training) - predictions = ddp_model.val_step([data]) + predictions = ddp_model.val_step(data) self.assertEqual(predictions, 1) def test_test_step(self): self._init_dist_env(self.rank, self.world_size) model = ComplexModel() ddp_model = MMSeparateDistributedDataParallel(model) - data = torch.randn(3, 1, 1) + data = torch.randn(1, 3, 1, 1) # Test get predictions. ddp_model.eval() self.assertFalse(ddp_model.training) @@ -225,27 +235,27 @@ class TestMMFullyShardedDataParallel(MultiProcessTestCase): fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) optimizer = SGD(fsdp_model.parameters(), lr=0) optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1) - inputs = torch.randn(3, 1, 1) * self.rank * 255 - data = dict(inputs=inputs, data_sample=MagicMock()) + inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 + data = dict(inputs=[inputs], data_sample=MagicMock()) fsdp_model.train() self.assertTrue(fsdp_model.training) - fsdp_model.train_step([data], optim_wrapper=optim_wrapper) + fsdp_model.train_step(data, optim_wrapper=optim_wrapper) def test_val_step(self): self._init_dist_env(self.rank, self.world_size) model = ToyModel() fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) - inputs = torch.randn(3, 1, 1) * self.rank * 255 - data = dict(inputs=inputs, data_sample=MagicMock()) + inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 + data = dict(inputs=[inputs], data_sample=MagicMock()) # Test get predictions. - predictions = fsdp_model.val_step([data]) + predictions = fsdp_model.val_step(data) self.assertIsInstance(predictions, torch.Tensor) def test_test_step(self): self._init_dist_env(self.rank, self.world_size) model = ToyModel() fsdp_model = MMFullyShardedDataParallel(module=model.cuda()) - inputs = torch.randn(3, 1, 1) * self.rank * 255 + inputs = torch.randn(1, 3, 1, 1) * self.rank * 255 data = dict(inputs=inputs, data_sample=MagicMock()) - predictions = fsdp_model.test_step([data]) + predictions = fsdp_model.test_step(data) self.assertIsInstance(predictions, torch.Tensor) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 8fbdc967..b0f293fd 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -14,7 +14,7 @@ from torch.optim import SGD, Adam from torch.utils.data import DataLoader, Dataset from mmengine.config import Config -from mmengine.dataset import DefaultSampler +from mmengine.dataset import COLLATE_FUNCTIONS, DefaultSampler, pseudo_collate from mmengine.evaluator import BaseMetric, Evaluator from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook, IterTimerHook, LoggerHook, ParamSchedulerHook, @@ -44,15 +44,18 @@ class ToyModel(BaseModel): self.linear1 = nn.Linear(2, 2) self.linear2 = nn.Linear(2, 1) - def forward(self, batch_inputs, labels, mode='tensor'): - labels = torch.stack(labels) - outputs = self.linear1(batch_inputs) + def forward(self, inputs, data_sample, mode='tensor'): + if isinstance(inputs, list): + inputs = torch.stack(inputs) + if isinstance(data_sample, list): + data_sample = torch.stack(data_sample) + outputs = self.linear1(inputs) outputs = self.linear2(outputs) if mode == 'tensor': return outputs elif mode == 'loss': - loss = (labels - outputs).sum() + loss = (data_sample - outputs).sum() outputs = dict(loss=loss) return outputs elif mode == 'predict': @@ -74,15 +77,16 @@ class ToySyncBNModel(BaseModel): self.conv = nn.Conv2d(3, 8, 2) self.bn = nn.SyncBatchNorm(8) - def forward(self, batch_inputs, labels, 
mode='tensor'): - labels = torch.stack(labels) - outputs = self.conv(batch_inputs) + def forward(self, inputs, data_sample, mode='tensor'): + data_sample = torch.stack(data_sample) + inputs = torch.stack(inputs) + outputs = self.conv(inputs) outputs = self.bn(outputs) if mode == 'tensor': return outputs elif mode == 'loss': - loss = (labels - outputs).sum() + loss = (data_sample - outputs).sum() outputs = dict(loss=loss) return outputs elif mode == 'predict': @@ -98,24 +102,25 @@ class ToyGANModel(BaseModel): self.linear1 = nn.Linear(2, 1) self.linear2 = nn.Linear(2, 1) - def forward(self, batch_inputs, labels, mode='tensor'): - labels = torch.stack(labels) - output1 = self.linear1(batch_inputs) - output2 = self.linear2(batch_inputs) + def forward(self, inputs, data_sample, mode='tensor'): + data_sample = torch.stack(data_sample) + inputs = torch.stack(inputs) + output1 = self.linear1(inputs) + output2 = self.linear2(inputs) if mode == 'tensor': return output1, output2 elif mode == 'loss': - loss1 = (labels - output1).sum() - loss2 = (labels - output2).sum() + loss1 = (data_sample - output1).sum() + loss2 = (data_sample - output2).sum() outputs = dict(linear1=loss1, linear2=loss2) return outputs elif mode == 'predict': return output1, output2 def train_step(self, data, optim_wrapper): - batch_inputs, batch_labels = self.data_preprocessor(data) - loss = self(batch_inputs, batch_labels, mode='loss') + data = self.data_preprocessor(data) + loss = self(**data, mode='loss') optim_wrapper['linear1'].update_params(loss['linear1']) optim_wrapper['linear2'].update_params(loss['linear2']) return loss @@ -193,7 +198,7 @@ class ToyMetric1(BaseMetric): super().__init__(collect_device=collect_device) self.dummy_metrics = dummy_metrics - def process(self, data_samples, predictions): + def process(self, data_batch, predictions): result = {'acc': 1} self.results.append(result) @@ -208,7 +213,7 @@ class ToyMetric2(BaseMetric): super().__init__(collect_device=collect_device) self.dummy_metrics = dummy_metrics - def process(self, data_samples, predictions): + def process(self, data_batch, predictions): result = {'acc': 1} self.results.append(result) @@ -335,7 +340,12 @@ class ToyEvaluator(Evaluator): def collate_fn(data_batch): - return data_batch + return pseudo_collate(data_batch) + + +@COLLATE_FUNCTIONS.register_module() +def custom_collate(data_batch, pad_value): + return pseudo_collate(data_batch) class TestRunner(TestCase): @@ -1428,7 +1438,7 @@ class TestRunner(TestCase): val_batch_idx_targets): self.assertEqual(result, target) - # 5. test dynamic interval in IterBasedTrainLoop + # 5.1 test dynamic interval in IterBasedTrainLoop max_iters = 12 interval = 5 dynamic_intervals = [(11, 2)] @@ -1465,7 +1475,7 @@ class TestRunner(TestCase): for result, target, in zip(val_interval_results, val_interval_targets): self.assertEqual(result, target) - # 6. 
test dynamic interval in EpochBasedTrainLoop max_epochs = 12 interval = 5 dynamic_intervals = [(11, 2)] @@ -1579,6 +1589,30 @@ class TestRunner(TestCase): runner = runner.from_cfg(cfg) runner.train() + # 10.1 Test build dataloader with default collate function + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_train10.1' + cfg.train_dataloader.update(collate_fn=dict(type='default_collate')) + runner = Runner.from_cfg(cfg) + runner.train() + + # 10.2 Test build dataloader with custom collate function + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_train10.2' + cfg.train_dataloader.update( + collate_fn=dict(type='custom_collate', pad_value=100)) + runner = Runner.from_cfg(cfg) + runner.train() + + # 11 Test build dataloader with a custom collate function whose + # required argument is not provided in the config. + with self.assertRaises(TypeError): + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_train11' + cfg.train_dataloader.update(collate_fn=dict(type='custom_collate')) + runner = Runner.from_cfg(cfg) + runner.train() + def test_val(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_val1' @@ -1890,8 +1924,6 @@ class TestRunner(TestCase): ckpt = torch.load(path) self.assertEqual(ckpt['meta']['epoch'], 3) self.assertEqual(ckpt['meta']['iter'], 12) - self.assertEqual(ckpt['meta']['dataset_meta'], - runner.train_dataloader.dataset.metainfo) self.assertEqual(ckpt['meta']['experiment_name'], runner.experiment_name) self.assertEqual(ckpt['meta']['seed'], runner.seed)
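
The tests above pin down the new contract of `pseudo_collate`: a list of per-sample dicts comes out as a single dict whose values are lists, with tensors left unstacked. A minimal sketch of that behavior, assuming a build of mmengine with this patch applied (the shapes are illustrative):

import torch

from mmengine.dataset import pseudo_collate

batch = [
    dict(inputs=torch.randn(1, 3, 5), data_sample=torch.randn(1)),
    dict(inputs=torch.randn(1, 3, 5), data_sample=torch.randn(1)),
]
collated = pseudo_collate(batch)
# Keys are merged across samples; values stay as lists of the original
# tensors, so no batch dimension is created at this stage.
assert isinstance(collated, dict)
assert isinstance(collated['inputs'], list) and len(collated['inputs']) == 2

Stacking into batch tensors is deferred to `default_collate` or to the model's `data_preprocessor`.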
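Because `Runner.build_dataloader` now resolves `collate_fn` through the `COLLATE_FUNCTIONS` registry, downstream projects can plug their own collate logic in from the config. A hedged sketch mirroring the `custom_collate` tests above; `my_collate`, `pad_value` and `train_dataloader_cfg` are illustrative names, not part of this patch:

from mmengine.dataset import COLLATE_FUNCTIONS, pseudo_collate

@COLLATE_FUNCTIONS.register_module()
def my_collate(data_batch, pad_value):
    # `pad_value` is bound from the dataloader config via functools.partial;
    # this toy version ignores it and falls through to pseudo_collate.
    return pseudo_collate(data_batch)

# In the dataloader config, `type` selects the registered function and
# every remaining key becomes a bound keyword argument:
train_dataloader_cfg = dict(collate_fn=dict(type='my_collate', pad_value=100))

Omitting `pad_value` from the config leaves `my_collate` with a missing required argument when the dataloader first calls it, which is exactly the TypeError exercised by test 11 above.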
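Finally, the convention that the wrapper and base-model changes implement: `data_preprocessor` receives the collated dict and returns a dict (or list/tuple) that `train_step`/`val_step`/`test_step` unpack straight into `forward`. A toy model sketching the contract, assuming this patch; the layer choice and shapes are illustrative:

import torch
import torch.nn as nn

from mmengine.model import BaseModel

class ToyModel(BaseModel):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, 1)

    def forward(self, inputs, data_sample=None, mode='tensor'):
        out = self.conv(inputs)
        # 'loss' mode must return a dict of losses; other modes return
        # the raw network output.
        return dict(loss=out.sum()) if mode == 'loss' else out

model = ToyModel()
data = dict(inputs=torch.randn(1, 3, 1, 1), data_sample=None)
# val_step runs the default BaseDataPreprocessor and then unpacks the
# resulting dict into forward, so forward's argument names must match
# the keys produced by the collate function.
out = model.val_step(data)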