[Enhance] metainfo of dataset can be a generic dict-like Mapping (#1378)
parent
9cbe0665d9
commit
eb5834fa66
|
@ -4,11 +4,13 @@ import functools
|
|||
import gc
|
||||
import logging
|
||||
import pickle
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from mmengine.config import Config
|
||||
from mmengine.fileio import join_path, list_from_file, load
|
||||
from mmengine.logging import print_log
|
||||
from mmengine.registry import TRANSFORMS
|
||||
|
@ -155,8 +157,8 @@ class BaseDataset(Dataset):
|
|||
|
||||
Args:
|
||||
ann_file (str, optional): Annotation file path. Defaults to ''.
|
||||
metainfo (dict, optional): Meta information for dataset, such as class
|
||||
information. Defaults to None.
|
||||
metainfo (Mapping or Config, optional): Meta information for
|
||||
dataset, such as class information. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to ''.
|
||||
data_prefix (dict): Prefix for training data. Defaults to
|
||||
|
@ -213,7 +215,7 @@ class BaseDataset(Dataset):
|
|||
|
||||
def __init__(self,
|
||||
ann_file: Optional[str] = '',
|
||||
metainfo: Optional[dict] = None,
|
||||
metainfo: Union[Mapping, Config, None] = None,
|
||||
data_root: Optional[str] = '',
|
||||
data_prefix: dict = dict(img_path=''),
|
||||
filter_cfg: Optional[dict] = None,
|
||||
|
@ -472,13 +474,14 @@ class BaseDataset(Dataset):
|
|||
return data_list
|
||||
|
||||
@classmethod
|
||||
def _load_metainfo(cls, metainfo: dict = None) -> dict:
|
||||
def _load_metainfo(cls,
|
||||
metainfo: Union[Mapping, Config, None] = None) -> dict:
|
||||
"""Collect meta information from the dictionary of meta.
|
||||
|
||||
Args:
|
||||
metainfo (dict): Meta information dict. If ``metainfo``
|
||||
contains existed filename, it will be parsed by
|
||||
``list_from_file``.
|
||||
metainfo (Mapping or Config, optional): Meta information dict.
|
||||
If ``metainfo`` contains an existing filename, it will be
|
||||
parsed by ``list_from_file``.
|
||||
|
||||
Returns:
|
||||
dict: Parsed meta information.
|
||||
|
@ -487,9 +490,9 @@ class BaseDataset(Dataset):
|
|||
cls_metainfo = copy.deepcopy(cls.METAINFO)
|
||||
if metainfo is None:
|
||||
return cls_metainfo
|
||||
if not isinstance(metainfo, dict):
|
||||
raise TypeError(
|
||||
f'metainfo should be a dict, but got {type(metainfo)}')
|
||||
if not isinstance(metainfo, (Mapping, Config)):
|
||||
raise TypeError('metainfo should be a Mapping or Config, '
|
||||
f'but got {type(metainfo)}')
|
||||
|
||||
for k, v in metainfo.items():
|
||||
if isinstance(v, str):
|
||||
|
|
|
@ -6,6 +6,7 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.config import Config, ConfigDict
|
||||
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
|
||||
ConcatDataset, RepeatDataset, force_full_init)
|
||||
from mmengine.registry import DATASETS, TRANSFORMS
|
||||
|
@ -202,6 +203,39 @@ class TestBaseDataset:
|
|||
task_name='new_task',
|
||||
classes=('dog', ),
|
||||
empty_list=[])
|
||||
|
||||
# test dataset.metainfo when passing metainfo as a Config into
|
||||
# self.base_dataset
|
||||
metainfo = Config(dict(classes=('dog', ), task_name='new_task'))
|
||||
dataset = BaseDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img_path='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
metainfo=metainfo)
|
||||
assert BaseDataset.METAINFO == dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat'))
|
||||
assert dataset.metainfo == dict(
|
||||
dataset_type=dataset_type,
|
||||
task_name='new_task',
|
||||
classes=('dog', ),
|
||||
empty_list=[])
|
||||
|
||||
# test dataset.metainfo when passing metainfo as a ConfigDict (Mapping)
|
||||
# into self.base_dataset
|
||||
metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task'))
|
||||
dataset = BaseDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img_path='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json',
|
||||
metainfo=metainfo)
|
||||
assert BaseDataset.METAINFO == dict(
|
||||
dataset_type=dataset_type, classes=('dog', 'cat'))
|
||||
assert dataset.metainfo == dict(
|
||||
dataset_type=dataset_type,
|
||||
task_name='new_task',
|
||||
classes=('dog', ),
|
||||
empty_list=[])
|
||||
|
||||
# reset `BaseDataset.METAINFO`; the `dataset.metainfo` should not
|
||||
# change
|
||||
BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')
|
||||
|
|
Loading…
Reference in New Issue