[Enhance] metainfo of dataset can be a generic dict-like Mapping (#1378)

pull/1292/head^2
hiyyg 2023-10-08 01:25:26 -05:00 committed by GitHub
parent 9cbe0665d9
commit eb5834fa66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 47 additions and 10 deletions

View File

@ -4,11 +4,13 @@ import functools
import gc
import logging
import pickle
from collections.abc import Mapping
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
from torch.utils.data import Dataset
from mmengine.config import Config
from mmengine.fileio import join_path, list_from_file, load
from mmengine.logging import print_log
from mmengine.registry import TRANSFORMS
@ -155,8 +157,8 @@ class BaseDataset(Dataset):
Args:
ann_file (str, optional): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
metainfo (Mapping or Config, optional): Meta information for
dataset, such as class information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
data_prefix (dict): Prefix for training data. Defaults to
@ -213,7 +215,7 @@ class BaseDataset(Dataset):
def __init__(self,
ann_file: Optional[str] = '',
metainfo: Optional[dict] = None,
metainfo: Union[Mapping, Config, None] = None,
data_root: Optional[str] = '',
data_prefix: dict = dict(img_path=''),
filter_cfg: Optional[dict] = None,
@ -472,13 +474,14 @@ class BaseDataset(Dataset):
return data_list
@classmethod
def _load_metainfo(cls, metainfo: dict = None) -> dict:
def _load_metainfo(cls,
metainfo: Union[Mapping, Config, None] = None) -> dict:
"""Collect meta information from the dictionary of meta.
Args:
metainfo (dict): Meta information dict. If ``metainfo``
contains existed filename, it will be parsed by
``list_from_file``.
metainfo (Mapping or Config, optional): Meta information dict.
If ``metainfo`` contains an existing filename, it will be
parsed by ``list_from_file``.
Returns:
dict: Parsed meta information.
@ -487,9 +490,9 @@ class BaseDataset(Dataset):
cls_metainfo = copy.deepcopy(cls.METAINFO)
if metainfo is None:
return cls_metainfo
if not isinstance(metainfo, dict):
raise TypeError(
f'metainfo should be a dict, but got {type(metainfo)}')
if not isinstance(metainfo, (Mapping, Config)):
raise TypeError('metainfo should be a Mapping or Config, '
f'but got {type(metainfo)}')
for k, v in metainfo.items():
if isinstance(v, str):

View File

@ -6,6 +6,7 @@ from unittest.mock import MagicMock
import pytest
import torch
from mmengine.config import Config, ConfigDict
from mmengine.dataset import (BaseDataset, ClassBalancedDataset, Compose,
ConcatDataset, RepeatDataset, force_full_init)
from mmengine.registry import DATASETS, TRANSFORMS
@ -202,6 +203,39 @@ class TestBaseDataset:
task_name='new_task',
classes=('dog', ),
empty_list=[])
# test dataset.metainfo with passing metainfo as Config into
# self.base_dataset
metainfo = Config(dict(classes=('dog', ), task_name='new_task'))
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
assert BaseDataset.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])
# test dataset.metainfo with passing metainfo as ConfigDict (Mapping)
# into self.base_dataset
metainfo = ConfigDict(dict(classes=('dog', ), task_name='new_task'))
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=metainfo)
assert BaseDataset.METAINFO == dict(
dataset_type=dataset_type, classes=('dog', 'cat'))
assert dataset.metainfo == dict(
dataset_type=dataset_type,
task_name='new_task',
classes=('dog', ),
empty_list=[])
# reset `BaseDataset.METAINFO`; the `dataset.metainfo` should not
# change
BaseDataset.METAINFO['classes'] = ('dog', 'cat', 'fish')