[Fix] Fix BaseDataset: join prefix in parse_data_info (#226)

* implement parse_data_info

* add unit test

* fix join prefix of ann_file

* fix docstring
pull/236/head
Mashiro 2022-05-17 20:53:13 +08:00 committed by GitHub
parent f5867f8442
commit cc8a6b86e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 4 deletions

View File

@ -155,7 +155,7 @@ class BaseDataset(Dataset):
data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to
dict(img=None, ann=None).
dict(img_path=None, seg_path=None).
filter_cfg (dict, optional): Config for filter data. Defaults to None.
indices (int or Sequence[int], optional): Support using first few
data in annotation file to facilitate training/testing on a smaller
@ -211,7 +211,7 @@ class BaseDataset(Dataset):
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(img=None, ann=None),
data_prefix: dict = dict(img_path=None, seg_path=None),
filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True,
@ -330,6 +330,12 @@ class BaseDataset(Dataset):
Returns:
list or list[dict]: Parsed annotation.
"""
for prefix_key, prefix in self.data_prefix.items():
assert prefix_key in raw_data_info, (
f'raw_data_info: {raw_data_info} dose not contain prefix key'
f'{prefix_key}, please check your data_prefix.')
raw_data_info[prefix_key] = osp.join(prefix,
raw_data_info[prefix_key])
return raw_data_info
def filter_data(self) -> List[dict]:
@ -520,7 +526,7 @@ class BaseDataset(Dataset):
"""
# Automatically join annotation file path with `self.root` if
# `self.ann_file` is not an absolute path.
if not osp.isabs(self.ann_file):
if not osp.isabs(self.ann_file) and self.ann_file:
self.ann_file = osp.join(self.data_root, self.ann_file)
# Automatically join data directory with `self.root` if path value in
# `self.data_prefix` is not an absolute path.

View File

@ -39,11 +39,13 @@ class TestBaseDataset:
filename='test_img.jpg', height=604, width=640, sample_idx=0)
self.imgs = torch.rand((2, 3, 32, 32))
self.ori_meta = BaseDataset.METAINFO
self.ori_parse_data_info = BaseDataset.parse_data_info
BaseDataset.parse_data_info = MagicMock(return_value=self.data_info)
self.pipeline = MagicMock(return_value=dict(imgs=self.imgs))
def teardown(self):
BaseDataset.METAINFO = self.ori_meta
BaseDataset.parse_data_info = self.ori_parse_data_info
def test_init(self):
# test the instantiation of self.base_dataset
@ -83,7 +85,6 @@ class TestBaseDataset:
lazy_init=True)
assert not dataset._fully_initialized
assert not dataset.data_list
# test the instantiation of self.base_dataset if ann_file is not
# existed.
with pytest.raises(FileNotFoundError):
@ -147,6 +148,15 @@ class TestBaseDataset:
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json')
# test the instantiation of self.base_dataset without `ann_file`
BaseDataset.parse_data_info = self.ori_parse_data_info
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='',
serialize_data=False,
lazy_init=True)
assert not dataset.ann_file
def test_meta(self):
# test dataset.metainfo with setting the metainfo from annotation file
@ -369,6 +379,16 @@ class TestBaseDataset:
assert dataset.get_data_info(0) == self.data_info
assert dataset._fully_initialized
assert hasattr(dataset, 'data_list')
# Test parse_data_info with `data_prefix`
BaseDataset.parse_data_info = self.ori_parse_data_info
data_root = osp.join(osp.dirname(__file__), '../data/')
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json')
data_info = dataset.get_data_info(0)
assert data_info['img_path'] == osp.join(data_root, 'imgs',
'test_img.jpg')
def test_force_full_init(self):
with pytest.raises(AttributeError):