[Fix] Fix BaseDataset: join prefix in parse_data_info (#226)
* implement parse_data_info * add unit test * fix join prefix of ann_file * fix docstringpull/236/head
parent
f5867f8442
commit
cc8a6b86e1
|
@ -155,7 +155,7 @@ class BaseDataset(Dataset):
|
|||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (dict, optional): Prefix for training data. Defaults to
|
||||
dict(img=None, ann=None).
|
||||
dict(img_path=None, seg_path=None).
|
||||
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
||||
indices (int or Sequence[int], optional): Support using first few
|
||||
data in annotation file to facilitate training/testing on a smaller
|
||||
|
@ -211,7 +211,7 @@ class BaseDataset(Dataset):
|
|||
ann_file: str = '',
|
||||
metainfo: Optional[dict] = None,
|
||||
data_root: Optional[str] = None,
|
||||
data_prefix: dict = dict(img=None, ann=None),
|
||||
data_prefix: dict = dict(img_path=None, seg_path=None),
|
||||
filter_cfg: Optional[dict] = None,
|
||||
indices: Optional[Union[int, Sequence[int]]] = None,
|
||||
serialize_data: bool = True,
|
||||
|
@ -330,6 +330,12 @@ class BaseDataset(Dataset):
|
|||
Returns:
|
||||
list or list[dict]: Parsed annotation.
|
||||
"""
|
||||
for prefix_key, prefix in self.data_prefix.items():
|
||||
assert prefix_key in raw_data_info, (
|
||||
f'raw_data_info: {raw_data_info} dose not contain prefix key'
|
||||
f'{prefix_key}, please check your data_prefix.')
|
||||
raw_data_info[prefix_key] = osp.join(prefix,
|
||||
raw_data_info[prefix_key])
|
||||
return raw_data_info
|
||||
|
||||
def filter_data(self) -> List[dict]:
|
||||
|
@ -520,7 +526,7 @@ class BaseDataset(Dataset):
|
|||
"""
|
||||
# Automatically join annotation file path with `self.root` if
|
||||
# `self.ann_file` is not an absolute path.
|
||||
if not osp.isabs(self.ann_file):
|
||||
if not osp.isabs(self.ann_file) and self.ann_file:
|
||||
self.ann_file = osp.join(self.data_root, self.ann_file)
|
||||
# Automatically join data directory with `self.root` if path value in
|
||||
# `self.data_prefix` is not an absolute path.
|
||||
|
|
|
@ -39,11 +39,13 @@ class TestBaseDataset:
|
|||
filename='test_img.jpg', height=604, width=640, sample_idx=0)
|
||||
self.imgs = torch.rand((2, 3, 32, 32))
|
||||
self.ori_meta = BaseDataset.METAINFO
|
||||
self.ori_parse_data_info = BaseDataset.parse_data_info
|
||||
BaseDataset.parse_data_info = MagicMock(return_value=self.data_info)
|
||||
self.pipeline = MagicMock(return_value=dict(imgs=self.imgs))
|
||||
|
||||
def teardown(self):
|
||||
BaseDataset.METAINFO = self.ori_meta
|
||||
BaseDataset.parse_data_info = self.ori_parse_data_info
|
||||
|
||||
def test_init(self):
|
||||
# test the instantiation of self.base_dataset
|
||||
|
@ -83,7 +85,6 @@ class TestBaseDataset:
|
|||
lazy_init=True)
|
||||
assert not dataset._fully_initialized
|
||||
assert not dataset.data_list
|
||||
|
||||
# test the instantiation of self.base_dataset if ann_file is not
|
||||
# existed.
|
||||
with pytest.raises(FileNotFoundError):
|
||||
|
@ -147,6 +148,15 @@ class TestBaseDataset:
|
|||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
# test the instantiation of self.base_dataset without `ann_file`
|
||||
BaseDataset.parse_data_info = self.ori_parse_data_info
|
||||
dataset = BaseDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img='imgs'),
|
||||
ann_file='',
|
||||
serialize_data=False,
|
||||
lazy_init=True)
|
||||
assert not dataset.ann_file
|
||||
|
||||
def test_meta(self):
|
||||
# test dataset.metainfo with setting the metainfo from annotation file
|
||||
|
@ -369,6 +379,16 @@ class TestBaseDataset:
|
|||
assert dataset.get_data_info(0) == self.data_info
|
||||
assert dataset._fully_initialized
|
||||
assert hasattr(dataset, 'data_list')
|
||||
# Test parse_data_info with `data_prefix`
|
||||
BaseDataset.parse_data_info = self.ori_parse_data_info
|
||||
data_root = osp.join(osp.dirname(__file__), '../data/')
|
||||
dataset = BaseDataset(
|
||||
data_root=osp.join(osp.dirname(__file__), '../data/'),
|
||||
data_prefix=dict(img_path='imgs'),
|
||||
ann_file='annotations/dummy_annotation.json')
|
||||
data_info = dataset.get_data_info(0)
|
||||
assert data_info['img_path'] == osp.join(data_root, 'imgs',
|
||||
'test_img.jpg')
|
||||
|
||||
def test_force_full_init(self):
|
||||
with pytest.raises(AttributeError):
|
||||
|
|
Loading…
Reference in New Issue