diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py index f866b826..4622f146 100644 --- a/mmengine/dataset/base_dataset.py +++ b/mmengine/dataset/base_dataset.py @@ -4,11 +4,13 @@ import functools import gc import logging import pickle +from collections.abc import Mapping from typing import Any, Callable, List, Optional, Sequence, Tuple, Union import numpy as np from torch.utils.data import Dataset +from mmengine.config import Config from mmengine.fileio import join_path, list_from_file, load from mmengine.logging import print_log from mmengine.registry import TRANSFORMS @@ -155,8 +157,8 @@ class BaseDataset(Dataset): Args: ann_file (str, optional): Annotation file path. Defaults to ''. - metainfo (dict, optional): Meta information for dataset, such as class - information. Defaults to None. + metainfo (Mapping or Config, optional): Meta information for + dataset, such as class information. Defaults to None. data_root (str, optional): The root directory for ``data_prefix`` and ``ann_file``. Defaults to ''. data_prefix (dict): Prefix for training data. Defaults to @@ -213,7 +215,7 @@ class BaseDataset(Dataset): def __init__(self, ann_file: Optional[str] = '', - metainfo: Optional[dict] = None, + metainfo: Union[Mapping, Config, None] = None, data_root: Optional[str] = '', data_prefix: dict = dict(img_path=''), filter_cfg: Optional[dict] = None, @@ -472,13 +474,14 @@ class BaseDataset(Dataset): return data_list @classmethod - def _load_metainfo(cls, metainfo: dict = None) -> dict: + def _load_metainfo(cls, + metainfo: Union[Mapping, Config, None] = None) -> dict: """Collect meta information from the dictionary of meta. Args: - metainfo (dict): Meta information dict. If ``metainfo`` - contains existed filename, it will be parsed by - ``list_from_file``. + metainfo (Mapping or Config, optional): Meta information dict. + If ``metainfo`` contains existed filename, it will be + parsed by ``list_from_file``. Returns: dict: Parsed meta information. @@ -487,9 +490,9 @@ class BaseDataset(Dataset): cls_metainfo = copy.deepcopy(cls.METAINFO) if metainfo is None: return cls_metainfo - if not isinstance(metainfo, dict): - raise TypeError( - f'metainfo should be a dict, but got {type(metainfo)}') + if not isinstance(metainfo, (Mapping, Config)): + raise TypeError('metainfo should be a Mapping or Config, ' + f'but got {type(metainfo)}') for k, v in metainfo.items(): if isinstance(v, str): diff --git a/tests/test_dataset/test_base_dataset.py b/tests/test_dataset/test_base_dataset.py index 0c364baa..f4ec815e 100644 --- a/tests/test_dataset/test_base_dataset.py +++ b/tests/test_dataset/test_base_dataset.py @@ -6,6 +6,7 @@ from unittest.mock import MagicMock import pytest import torch +from mmengine.config import Config, ConfigDict from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose, ConcatDataset, RepeatDataset, force_full_init) from mmengine.registry import DATASETS, TRANSFORMS @@ -202,6 +203,39 @@ class TestBaseDataset: task_name='new_task', classes=('dog', ), empty_list=[]) + + # test dataset.metainfo with passing metainfo as Config into + # self.base_dataset + metainfo = Config(dict(classes=('dog', ), task_name='new_task')) + dataset = BaseDataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) + assert BaseDataset.METAINFO == dict( + dataset_type=dataset_type, classes=('dog', 'cat')) + assert dataset.metainfo == dict( + dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) + + # test dataset.metainfo with passing metainfo as ConfigDict (Mapping) + # into self.base_dataset + metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task')) + dataset = BaseDataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img_path='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=metainfo) + assert BaseDataset.METAINFO == dict( + dataset_type=dataset_type, classes=('dog', 'cat')) + assert dataset.metainfo == dict( + dataset_type=dataset_type, + task_name='new_task', + classes=('dog', ), + empty_list=[]) + # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not # change BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')