[Refactor] Add VOC dataset and mutil_label dataset
parent
522ab1fd84
commit
d6fa480915
|
@ -16,8 +16,9 @@ from .voc import VOC
|
|||
|
||||
__all__ = [
|
||||
'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
|
||||
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
|
||||
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
|
||||
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
|
||||
'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
|
||||
'VOC', 'build_dataloader', 'build_dataset', 'DistributedSampler',
|
||||
'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset', 'DATASETS',
|
||||
'PIPELINES', 'ImageNet21k', 'SAMPLERS', 'build_sampler',
|
||||
'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset',
|
||||
'MultiLabelDataset'
|
||||
]
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
# Pre-defined categories names of various datasets.
|
||||
|
||||
VOC2007_CATEGORIES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
|
||||
'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
|
||||
'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
|
||||
'sofa', 'train', 'tvmonitor')
|
||||
|
||||
IMAGENET_CATEGORIES = (
|
||||
'tench, Tinca tinca',
|
||||
'goldfish, Carassius auratus',
|
||||
|
|
|
@ -1,14 +1,80 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mmcls.core import average_performance, mAP
|
||||
from mmcls.registry import DATASETS
|
||||
from .base_dataset import BaseDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MultiLabelDataset(BaseDataset):
|
||||
"""Multi-label Dataset."""
|
||||
"""Multi-label Dataset.
|
||||
|
||||
This dataset support annotation file in `OpenMMLab 2.0 style annotation
|
||||
format`.
|
||||
|
||||
.. _OpenMMLab 2.0 style annotation format:
|
||||
https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md
|
||||
|
||||
The annotation format is shown as follows.
|
||||
|
||||
.. code-block:: none
|
||||
{
|
||||
"metainfo":
|
||||
{
|
||||
"classes":['A', 'B', 'C'....]
|
||||
},
|
||||
"data_list":
|
||||
[
|
||||
{
|
||||
"img_path": "test_img1.jpg",
|
||||
'img_label': [0, 1],
|
||||
},
|
||||
{
|
||||
"img_path": "test_img2.jpg",
|
||||
'img_label': [2],
|
||||
},
|
||||
]
|
||||
....
|
||||
}
|
||||
|
||||
|
||||
Args:
|
||||
ann_file (str): Annotation file path.
|
||||
metainfo (dict, optional): Meta information for dataset, such as class
|
||||
information. Defaults to None.
|
||||
data_root (str, optional): The root directory for ``data_prefix`` and
|
||||
``ann_file``. Defaults to None.
|
||||
data_prefix (str | dict, optional): Prefix for training data. Defaults
|
||||
to None.
|
||||
filter_cfg (dict, optional): Config for filter data. Defaults to None.
|
||||
indices (int or Sequence[int], optional): Support using first few
|
||||
data in annotation file to facilitate training/testing on a smaller
|
||||
dataset. Defaults to None which means using all ``data_infos``.
|
||||
serialize_data (bool, optional): Whether to hold memory using
|
||||
serialized objects, when enabled, data loader workers can use
|
||||
shared RAM from master process instead of making a copy. Defaults
|
||||
to True.
|
||||
pipeline (list, optional): Processing pipeline. Defaults to [].
|
||||
test_mode (bool, optional): ``test_mode=True`` means in test phase.
|
||||
Defaults to False.
|
||||
lazy_init (bool, optional): Whether to load annotation during
|
||||
instantiation. In some cases, such as visualization, only the meta
|
||||
information of the dataset is needed, which is not necessary to
|
||||
load annotation file. ``Basedataset`` can skip load annotations to
|
||||
save time by set ``lazy_init=False``. Defaults to False.
|
||||
max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
|
||||
None img. The maximum extra number of cycles to get a valid
|
||||
image. Defaults to 1000.
|
||||
classes (str | Sequence[str], optional): Specify names of classes.
|
||||
|
||||
- If is string, it should be a file path, and the every line of
|
||||
the file is a name of a class.
|
||||
- If is a sequence of string, every item is a name of class.
|
||||
- If is None, use categories information in ``metainfo`` argument,
|
||||
annotation file or the class attribute ``METAINFO``.
|
||||
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def get_cat_ids(self, idx: int) -> List[int]:
|
||||
"""Get category ids by index.
|
||||
|
@ -19,61 +85,4 @@ class MultiLabelDataset(BaseDataset):
|
|||
Returns:
|
||||
cat_ids (List[int]): Image categories of specified index.
|
||||
"""
|
||||
gt_labels = self.data_infos[idx]['gt_label']
|
||||
cat_ids = np.where(gt_labels == 1)[0].tolist()
|
||||
return cat_ids
|
||||
|
||||
def evaluate(self,
|
||||
results,
|
||||
metric='mAP',
|
||||
metric_options=None,
|
||||
indices=None,
|
||||
logger=None):
|
||||
"""Evaluate the dataset.
|
||||
|
||||
Args:
|
||||
results (list): Testing results of the dataset.
|
||||
metric (str | list[str]): Metrics to be evaluated.
|
||||
Default value is 'mAP'. Options are 'mAP', 'CP', 'CR', 'CF1',
|
||||
'OP', 'OR' and 'OF1'.
|
||||
metric_options (dict, optional): Options for calculating metrics.
|
||||
Allowed keys are 'k' and 'thr'. Defaults to None
|
||||
logger (logging.Logger | str, optional): Logger used for printing
|
||||
related information during evaluation. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict: evaluation results
|
||||
"""
|
||||
if metric_options is None or metric_options == {}:
|
||||
metric_options = {'thr': 0.5}
|
||||
|
||||
if isinstance(metric, str):
|
||||
metrics = [metric]
|
||||
else:
|
||||
metrics = metric
|
||||
allowed_metrics = ['mAP', 'CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
|
||||
eval_results = {}
|
||||
results = np.vstack(results)
|
||||
gt_labels = self.get_gt_labels()
|
||||
if indices is not None:
|
||||
gt_labels = gt_labels[indices]
|
||||
num_imgs = len(results)
|
||||
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
|
||||
'be of the same length as gt_labels.'
|
||||
|
||||
invalid_metrics = set(metrics) - set(allowed_metrics)
|
||||
if len(invalid_metrics) != 0:
|
||||
raise ValueError(f'metric {invalid_metrics} is not supported.')
|
||||
|
||||
if 'mAP' in metrics:
|
||||
mAP_value = mAP(results, gt_labels)
|
||||
eval_results['mAP'] = mAP_value
|
||||
if len(set(metrics) - {'mAP'}) != 0:
|
||||
performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
|
||||
performance_values = average_performance(results, gt_labels,
|
||||
**metric_options)
|
||||
for k, v in zip(performance_keys, performance_values):
|
||||
if k in metrics:
|
||||
eval_results[k] = v
|
||||
|
||||
return eval_results
|
||||
return self.get_data_info(idx)['gt_label']
|
||||
|
|
|
@ -1,69 +1,141 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine import FileClient, list_from_file
|
||||
|
||||
from mmcls.registry import DATASETS
|
||||
from .base_dataset import expanduser
|
||||
from .categories import VOC2007_CATEGORIES
|
||||
from .multi_label import MultiLabelDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class VOC(MultiLabelDataset):
|
||||
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset."""
|
||||
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Dataset.
|
||||
|
||||
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
|
||||
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
|
||||
'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
|
||||
'tvmonitor')
|
||||
After decompression, the dataset directory structure is as follows:
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(VOC, self).__init__(**kwargs)
|
||||
if 'VOC2007' in self.data_prefix:
|
||||
self.year = 2007
|
||||
VOC dataset directory: ::
|
||||
|
||||
VOC2007 (data_root)/
|
||||
├── JPEGImages (data_prefix['img_path'])
|
||||
│ ├── xxx.jpg
|
||||
│ ├── xxy.jpg
|
||||
│ └── ...
|
||||
├── Annotations (data_prefix['ann_path'])
|
||||
│ ├── xxx.xml
|
||||
│ ├── xxy.xml
|
||||
│ └── ...
|
||||
└── ImageSets (directory contains various imageset file)
|
||||
|
||||
Args:
|
||||
data_root (str): The root directory for VOC dataset.
|
||||
image_set_path (str): The path of image set, The file which
|
||||
lists image ids of the sub dataset, and this path is relative
|
||||
to ``data_root``.
|
||||
data_prefix (dict): Prefix for data and annotation, keyword
|
||||
'img_path' and 'ann_path' can be set. Defaults to be
|
||||
``dict(img_path='JPEGImages', ann_path='Annotations')``.
|
||||
test_mode (bool): ``test_mode=True`` means in test phase.
|
||||
It determines to use the training set or test set.
|
||||
metainfo (dict, optional): Meta information for dataset, such as
|
||||
categories information. Defaults to None.
|
||||
**kwargs: Other keyword arguments in :class:`BaseDataset`.
|
||||
""" # noqa: E501
|
||||
|
||||
METAINFO = {'classes': VOC2007_CATEGORIES}
|
||||
|
||||
def __init__(self,
|
||||
data_root: str,
|
||||
image_set_path: str,
|
||||
data_prefix: Union[str, dict] = dict(
|
||||
img_path='JPEGImages', ann_path='Annotations'),
|
||||
test_mode: bool = False,
|
||||
metainfo: Optional[dict] = None,
|
||||
**kwargs):
|
||||
if isinstance(data_prefix, str):
|
||||
data_prefix = dict(img_path=expanduser(data_prefix))
|
||||
assert isinstance(data_prefix, dict) and 'img_path' in data_prefix, \
|
||||
'`data_prefix` must be a dict with key img_path'
|
||||
|
||||
if test_mode is False:
|
||||
assert 'ann_path' in data_prefix and data_prefix[
|
||||
'ann_path'] is not None, \
|
||||
'"ann_path" must be set in `data_prefix` if `test_mode` is' \
|
||||
' False.'
|
||||
|
||||
self.data_root = data_root
|
||||
self.file_client = FileClient.infer_client(uri=data_root)
|
||||
self.image_set_path = self.file_client.join_path(
|
||||
data_root, image_set_path)
|
||||
|
||||
super().__init__(
|
||||
ann_file='',
|
||||
metainfo=metainfo,
|
||||
data_root=data_root,
|
||||
data_prefix=data_prefix,
|
||||
test_mode=test_mode,
|
||||
**kwargs)
|
||||
|
||||
@property
|
||||
def ann_prefix(self):
|
||||
"""The prefix of images."""
|
||||
if 'ann_path' in self.data_prefix:
|
||||
return self.data_prefix['ann_path']
|
||||
else:
|
||||
raise ValueError('Cannot infer dataset year from img_prefix.')
|
||||
return None
|
||||
|
||||
def load_annotations(self):
|
||||
"""Load annotations.
|
||||
def _get_labels_from_xml(self, img_id):
|
||||
"""Get gt_labels and labels_difficult from xml file."""
|
||||
xml_path = self.file_client.join_path(self.ann_prefix, f'{img_id}.xml')
|
||||
content = self.file_client.get(xml_path)
|
||||
root = ET.fromstring(content)
|
||||
|
||||
labels, labels_difficult = set(), set()
|
||||
for obj in root.findall('object'):
|
||||
label_name = obj.find('name').text
|
||||
# in case customized dataset has wrong labels
|
||||
# or CLASSES has been override.
|
||||
if label_name not in self.CLASSES:
|
||||
continue
|
||||
label = self.class_to_idx[label_name]
|
||||
difficult = int(obj.find('difficult').text)
|
||||
if difficult:
|
||||
labels_difficult.add(label)
|
||||
else:
|
||||
labels.add(label)
|
||||
|
||||
return list(labels), list(labels_difficult)
|
||||
|
||||
def load_data_list(self):
|
||||
"""Load images and ground truth labels."""
|
||||
data_list = []
|
||||
img_ids = list_from_file(self.image_set_path)
|
||||
|
||||
Returns:
|
||||
list[dict]: Annotation info from XML file.
|
||||
"""
|
||||
data_infos = []
|
||||
img_ids = mmcv.list_from_file(self.ann_file)
|
||||
for img_id in img_ids:
|
||||
filename = f'JPEGImages/{img_id}.jpg'
|
||||
xml_path = osp.join(self.data_prefix, 'Annotations',
|
||||
f'{img_id}.xml')
|
||||
tree = ET.parse(xml_path)
|
||||
root = tree.getroot()
|
||||
labels = []
|
||||
labels_difficult = []
|
||||
for obj in root.findall('object'):
|
||||
label_name = obj.find('name').text
|
||||
# in case customized dataset has wrong labels
|
||||
# or CLASSES has been override.
|
||||
if label_name not in self.CLASSES:
|
||||
continue
|
||||
label = self.class_to_idx[label_name]
|
||||
difficult = int(obj.find('difficult').text)
|
||||
if difficult:
|
||||
labels_difficult.append(label)
|
||||
else:
|
||||
labels.append(label)
|
||||
img_path = self.file_client.join_path(self.img_prefix,
|
||||
f'{img_id}.jpg')
|
||||
|
||||
gt_label = np.zeros(len(self.CLASSES))
|
||||
# The order cannot be swapped for the case where multiple objects
|
||||
# of the same kind exist and some are difficult.
|
||||
gt_label[labels_difficult] = -1
|
||||
gt_label[labels] = 1
|
||||
labels, labels_difficult = None, None
|
||||
if self.ann_prefix is not None:
|
||||
labels, labels_difficult = self._get_labels_from_xml(img_id)
|
||||
|
||||
info = dict(
|
||||
img_prefix=self.data_prefix,
|
||||
img_info=dict(filename=filename),
|
||||
gt_label=gt_label.astype(np.int8))
|
||||
data_infos.append(info)
|
||||
img_path=img_path,
|
||||
gt_label=labels,
|
||||
gt_label_difficult=labels_difficult)
|
||||
data_list.append(info)
|
||||
|
||||
return data_infos
|
||||
return data_list
|
||||
|
||||
def extra_repr(self) -> List[str]:
|
||||
"""The extra repr information of the dataset."""
|
||||
body = [
|
||||
f'Prefix of dataset: \t{self.data_root}',
|
||||
f'Path of image set: \t{self.image_set_path}',
|
||||
f'Prefix of images: \t{self.img_prefix}',
|
||||
f'Prefix of annotations: \t{self.ann_prefix}'
|
||||
]
|
||||
|
||||
return body
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
{
|
||||
"metainfo": {
|
||||
"categories": [
|
||||
{
|
||||
"category_name": "first",
|
||||
"id": 0
|
||||
},
|
||||
{
|
||||
"category_name": "second",
|
||||
"id": 1
|
||||
}
|
||||
]
|
||||
},
|
||||
"data_list": [
|
||||
{
|
||||
"img_path": "a/1.JPG",
|
||||
"gt_label": [0]
|
||||
},
|
||||
{
|
||||
"img_path": "b/2.jpeg",
|
||||
"gt_label": [1]
|
||||
},
|
||||
{
|
||||
"img_path": "b/subb/2.jpeg",
|
||||
"gt_label": [0, 1]
|
||||
}
|
||||
]
|
||||
}
|
|
@ -12,8 +12,6 @@ from mmengine.registry import TRANSFORMS
|
|||
from mmcls.datasets import DATASETS
|
||||
from mmcls.utils import get_root_logger
|
||||
|
||||
# import torch
|
||||
|
||||
mmcls_logger = get_root_logger()
|
||||
ASSETS_ROOT = osp.abspath(
|
||||
osp.join(osp.dirname(__file__), '../../data/dataset'))
|
||||
|
@ -100,72 +98,6 @@ class TestBaseDataset(TestCase):
|
|||
repr(dataset))
|
||||
|
||||
|
||||
"""Temporarily disabled.
|
||||
class TestMultiLabelDataset(TestBaseDataset):
|
||||
DATASET_TYPE = 'MultiLabelDataset'
|
||||
|
||||
def test_get_cat_ids(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
fake_ann = [
|
||||
dict(
|
||||
img_prefix='',
|
||||
img_info=dict(),
|
||||
gt_label=np.array([0, 1, 1, 0], dtype=np.uint8))
|
||||
]
|
||||
|
||||
with patch.object(dataset_class, 'load_annotations') as mock_load:
|
||||
mock_load.return_value = fake_ann
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
cat_ids = dataset.get_cat_ids(0)
|
||||
self.assertIsInstance(cat_ids, list)
|
||||
self.assertEqual(len(cat_ids), 2)
|
||||
self.assertIsInstance(cat_ids[0], int)
|
||||
self.assertEqual(cat_ids, [1, 2])
|
||||
|
||||
def test_evaluate(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
fake_ann = [
|
||||
dict(gt_label=np.array([1, 1, 0, -1], dtype=np.int8)),
|
||||
dict(gt_label=np.array([1, 1, 0, -1], dtype=np.int8)),
|
||||
dict(gt_label=np.array([0, -1, 1, -1], dtype=np.int8)),
|
||||
dict(gt_label=np.array([0, 1, 0, -1], dtype=np.int8)),
|
||||
dict(gt_label=np.array([0, 1, 0, -1], dtype=np.int8)),
|
||||
]
|
||||
|
||||
with patch.object(dataset_class, 'load_annotations') as mock_load:
|
||||
mock_load.return_value = fake_ann
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
|
||||
fake_results = np.array([
|
||||
[0.9, 0.8, 0.3, 0.2],
|
||||
[0.1, 0.2, 0.2, 0.1],
|
||||
[0.7, 0.5, 0.9, 0.3],
|
||||
[0.8, 0.1, 0.1, 0.2],
|
||||
[0.8, 0.1, 0.1, 0.2],
|
||||
])
|
||||
|
||||
# the metric must be valid for the dataset
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
"{'unknown'} is not supported"):
|
||||
dataset.evaluate(fake_results, metric='unknown')
|
||||
|
||||
# only one metric
|
||||
eval_results = dataset.evaluate(fake_results, metric='mAP')
|
||||
self.assertEqual(eval_results.keys(), {'mAP'})
|
||||
self.assertAlmostEqual(eval_results['mAP'], 67.5, places=4)
|
||||
|
||||
# multiple metrics
|
||||
eval_results = dataset.evaluate(
|
||||
fake_results, metric=['mAP', 'CR', 'OF1'])
|
||||
self.assertEqual(eval_results.keys(), {'mAP', 'CR', 'OF1'})
|
||||
self.assertAlmostEqual(eval_results['mAP'], 67.50, places=2)
|
||||
self.assertAlmostEqual(eval_results['CR'], 43.75, places=2)
|
||||
self.assertAlmostEqual(eval_results['OF1'], 42.86, places=2)
|
||||
"""
|
||||
|
||||
|
||||
class TestCustomDataset(TestBaseDataset):
|
||||
DATASET_TYPE = 'CustomDataset'
|
||||
|
||||
|
@ -529,6 +461,200 @@ class TestCIFAR100(TestCIFAR10):
|
|||
DATASET_TYPE = 'CIFAR100'
|
||||
|
||||
|
||||
class TestMultiLabelDataset(TestBaseDataset):
|
||||
DATASET_TYPE = 'MultiLabelDataset'
|
||||
|
||||
DEFAULT_ARGS = dict(data_root=ASSETS_ROOT, ann_file='multi_label_ann.json')
|
||||
|
||||
def test_get_cat_ids(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
cfg = {**self.DEFAULT_ARGS}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
cat_ids = dataset.get_cat_ids(0)
|
||||
self.assertTrue(cat_ids, [0])
|
||||
|
||||
cat_ids = dataset.get_cat_ids(1)
|
||||
self.assertTrue(cat_ids, [1])
|
||||
|
||||
cat_ids = dataset.get_cat_ids(1)
|
||||
self.assertTrue(cat_ids, [0, 1])
|
||||
|
||||
|
||||
class TestVOC(TestBaseDataset):
|
||||
DATASET_TYPE = 'VOC'
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
super().setUpClass()
|
||||
|
||||
tmpdir = tempfile.TemporaryDirectory()
|
||||
cls.tmpdir = tmpdir
|
||||
data_root = tmpdir.name
|
||||
|
||||
cls.DEFAULT_ARGS = dict(
|
||||
data_root=data_root,
|
||||
image_set_path='ImageSets/train.txt',
|
||||
data_prefix=dict(img_path='JPEGImages', ann_path='Annotations'),
|
||||
pipeline=[],
|
||||
test_mode=False)
|
||||
|
||||
cls.image_folder = osp.join(data_root, 'JPEGImages')
|
||||
cls.ann_folder = osp.join(data_root, 'Annotations')
|
||||
cls.image_set_folder = osp.join(data_root, 'ImageSets')
|
||||
os.mkdir(cls.image_set_folder)
|
||||
os.mkdir(cls.image_folder)
|
||||
os.mkdir(cls.ann_folder)
|
||||
|
||||
cls.fake_img_paths = [f'{i}' for i in range(6)]
|
||||
cls.fake_labels = [[
|
||||
np.random.randint(10) for _ in range(np.random.randint(1, 4))
|
||||
] for _ in range(6)]
|
||||
cls.fake_classes = [f'C_{i}' for i in range(10)]
|
||||
train_list = [i for i in range(0, 4)]
|
||||
test_list = [i for i in range(4, 6)]
|
||||
|
||||
with open(osp.join(cls.image_set_folder, 'train.txt'), 'w') as f:
|
||||
for train_item in train_list:
|
||||
f.write(str(train_item) + '\n')
|
||||
with open(osp.join(cls.image_set_folder, 'test.txt'), 'w') as f:
|
||||
for test_item in test_list:
|
||||
f.write(str(test_item) + '\n')
|
||||
with open(osp.join(cls.image_set_folder, 'full_path_test.txt'),
|
||||
'w') as f:
|
||||
for test_item in test_list:
|
||||
f.write(osp.join(cls.image_folder, str(test_item)) + '\n')
|
||||
|
||||
for train_item in train_list:
|
||||
with open(osp.join(cls.ann_folder, f'{train_item}.xml'), 'w') as f:
|
||||
temple = '<object><name>C_{}</name>{}</object>'
|
||||
ann_data = ''.join([
|
||||
temple.format(label, '<difficult>0</difficult>')
|
||||
for label in cls.fake_labels[train_item]
|
||||
])
|
||||
# add difficult label
|
||||
ann_data += ''.join([
|
||||
temple.format(label, '<difficult>1</difficult>')
|
||||
for label in cls.fake_labels[train_item]
|
||||
])
|
||||
xml_ann_data = f'<annotation>{ann_data}</annotation>'
|
||||
f.write(xml_ann_data + '\n')
|
||||
|
||||
for test_item in test_list:
|
||||
with open(osp.join(cls.ann_folder, f'{test_item}.xml'), 'w') as f:
|
||||
temple = '<object><name>C_{}</name>{}</object>'
|
||||
ann_data = ''.join([
|
||||
temple.format(label, '<difficult>0</difficult>')
|
||||
for label in cls.fake_labels[test_item]
|
||||
])
|
||||
xml_ann_data = f'<annotation>{ann_data}</annotation>'
|
||||
f.write(xml_ann_data + '\n')
|
||||
|
||||
def test_initialize(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test overriding metainfo by `classes` argument
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': ['bus', 'car']}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
|
||||
# Test overriding CLASSES by classes file
|
||||
classes_file = osp.join(ASSETS_ROOT, 'classes.txt')
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': classes_file}
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(dataset.CLASSES, ('bus', 'car'))
|
||||
self.assertEqual(dataset.class_to_idx, {'bus': 0, 'car': 1})
|
||||
|
||||
# Test invalid classes
|
||||
cfg = {**self.DEFAULT_ARGS, 'classes': dict(classes=1)}
|
||||
with self.assertRaisesRegex(ValueError, "type <class 'dict'>"):
|
||||
dataset_class(**cfg)
|
||||
|
||||
def test_get_cat_ids(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
cfg = {'classes': self.fake_classes, **self.DEFAULT_ARGS}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
cat_ids = dataset.get_cat_ids(0)
|
||||
self.assertIsInstance(cat_ids, list)
|
||||
self.assertIsInstance(cat_ids[0], int)
|
||||
|
||||
def test_load_data_list(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
|
||||
# Test default behavior
|
||||
dataset = dataset_class(**self.DEFAULT_ARGS)
|
||||
self.assertEqual(len(dataset), 4)
|
||||
self.assertEqual(len(dataset.CLASSES), 20)
|
||||
|
||||
cfg = {
|
||||
'classes': self.fake_classes,
|
||||
'lazy_init': True,
|
||||
**self.DEFAULT_ARGS
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
self.assertIn("Haven't been initialized", repr(dataset))
|
||||
dataset.full_init()
|
||||
self.assertIn(f'Number of samples: \t{len(dataset)}', repr(dataset))
|
||||
|
||||
data_info = dataset[0]
|
||||
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[0])
|
||||
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
|
||||
self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[0]))
|
||||
|
||||
# Test with test_mode=True
|
||||
cfg['image_set_path'] = 'ImageSets/test.txt'
|
||||
cfg['test_mode'] = True
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 2)
|
||||
|
||||
data_info = dataset[0]
|
||||
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
|
||||
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
|
||||
self.assertEqual(set(data_info['gt_label']), set(self.fake_labels[4]))
|
||||
|
||||
# Test with test_mode=True and ann_path = None
|
||||
cfg['image_set_path'] = 'ImageSets/test.txt'
|
||||
cfg['test_mode'] = True
|
||||
cfg['data_prefix'] = 'JPEGImages'
|
||||
dataset = dataset_class(**cfg)
|
||||
self.assertEqual(len(dataset), 2)
|
||||
|
||||
data_info = dataset[0]
|
||||
fake_img_path = osp.join(self.image_folder, self.fake_img_paths[4])
|
||||
self.assertEqual(data_info['img_path'], f'{fake_img_path}.jpg')
|
||||
self.assertEqual(data_info['gt_label'], None)
|
||||
|
||||
# Test different backend
|
||||
cfg = {
|
||||
**self.DEFAULT_ARGS, 'lazy_init': True,
|
||||
'data_root': 's3:/openmmlab/voc'
|
||||
}
|
||||
dataset = dataset_class(**cfg)
|
||||
dataset._check_integrity = MagicMock(return_value=False)
|
||||
with self.assertRaisesRegex(FileNotFoundError, 's3:/openmmlab/voc'):
|
||||
dataset.full_init()
|
||||
|
||||
def test_extra_repr(self):
|
||||
dataset_class = DATASETS.get(self.DATASET_TYPE)
|
||||
cfg = {**self.DEFAULT_ARGS}
|
||||
dataset = dataset_class(**cfg)
|
||||
|
||||
self.assertIn(f'Path of image set: \t{dataset.image_set_path}',
|
||||
repr(dataset))
|
||||
self.assertIn(f'Prefix of dataset: \t{dataset.data_root}',
|
||||
repr(dataset))
|
||||
self.assertIn(f'Prefix of annotations: \t{dataset.ann_prefix}',
|
||||
repr(dataset))
|
||||
self.assertIn(f'Prefix of images: \t{dataset.img_prefix}',
|
||||
repr(dataset))
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
||||
|
||||
"""Temporarily disabled.
|
||||
|
||||
class TestMNIST(TestBaseDataset):
|
||||
|
@ -593,12 +719,6 @@ class TestMNIST(TestBaseDataset):
|
|||
def tearDownClass(cls):
|
||||
cls.tmpdir.cleanup()
|
||||
|
||||
|
||||
class TestVOC(TestMultiLabelDataset):
|
||||
DATASET_TYPE = 'VOC'
|
||||
|
||||
DEFAULT_ARGS = dict(data_prefix='VOC2007', pipeline=[])
|
||||
|
||||
class TestCUB(TestBaseDataset):
|
||||
DATASET_TYPE = 'CUB'
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ def test_pcpvt():
|
|||
assert len(outs) == 1
|
||||
assert outs[-1].shape == (1, 512, H // 32, W // 32)
|
||||
|
||||
# test with mutil outputs
|
||||
# test with multi outputs
|
||||
model = PCPVT('small', out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
outs = model(temp)
|
||||
|
@ -160,7 +160,7 @@ def test_svt():
|
|||
assert len(outs) == 1
|
||||
assert outs[-1].shape == (1, 512, H // 32, W // 32)
|
||||
|
||||
# test with mutil outputs
|
||||
# test with multi outputs
|
||||
model = SVT('small', out_indices=(0, 1, 2, 3))
|
||||
model.init_weights()
|
||||
outs = model(temp)
|
||||
|
|
Loading…
Reference in New Issue