[Fix] Fix bug when loading class name form file in custom dataset (#923)
* [Fix] #916 expection string type classes * add unittests for string path classes * fix double quote string in test_dataset.py * move the import to the top of the file * fix isort lint error fix isort lint error when move the import to the top of the filepull/1801/head
parent
1ce4904fe3
commit
adb1cd361b
|
@ -319,7 +319,7 @@ class CustomDataset(Dataset):
|
||||||
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
||||||
|
|
||||||
if self.CLASSES:
|
if self.CLASSES:
|
||||||
if not set(classes).issubset(self.CLASSES):
|
if not set(class_names).issubset(self.CLASSES):
|
||||||
raise ValueError('classes is not a subset of CLASSES.')
|
raise ValueError('classes is not a subset of CLASSES.')
|
||||||
|
|
||||||
# dictionary, its keys are the old label ids and its values
|
# dictionary, its keys are the old label ids and its values
|
||||||
|
@ -330,7 +330,7 @@ class CustomDataset(Dataset):
|
||||||
if c not in class_names:
|
if c not in class_names:
|
||||||
self.label_map[i] = -1
|
self.label_map[i] = -1
|
||||||
else:
|
else:
|
||||||
self.label_map[i] = classes.index(c)
|
self.label_map[i] = class_names.index(c)
|
||||||
|
|
||||||
palette = self.get_palette_for_custom_classes(class_names, palette)
|
palette = self.get_palette_for_custom_classes(class_names, palette)
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import shutil
|
import shutil
|
||||||
|
import tempfile
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
@ -26,6 +28,37 @@ def test_classes():
|
||||||
get_classes('unsupported')
|
get_classes('unsupported')
|
||||||
|
|
||||||
|
|
||||||
|
def test_classes_file_path():
|
||||||
|
tmp_file = tempfile.NamedTemporaryFile()
|
||||||
|
classes_path = f'{tmp_file.name}.txt'
|
||||||
|
train_pipeline = [dict(type='LoadImageFromFile')]
|
||||||
|
kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)
|
||||||
|
|
||||||
|
# classes.txt with full categories
|
||||||
|
categories = get_classes('cityscapes')
|
||||||
|
with open(classes_path, 'w') as f:
|
||||||
|
f.write('\n'.join(categories))
|
||||||
|
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
|
||||||
|
|
||||||
|
# classes.txt with sub categories
|
||||||
|
categories = ['road', 'sidewalk', 'building']
|
||||||
|
with open(classes_path, 'w') as f:
|
||||||
|
f.write('\n'.join(categories))
|
||||||
|
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
|
||||||
|
|
||||||
|
# classes.txt with unknown categories
|
||||||
|
categories = ['road', 'sidewalk', 'unknown']
|
||||||
|
with open(classes_path, 'w') as f:
|
||||||
|
f.write('\n'.join(categories))
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
CityscapesDataset(**kwargs)
|
||||||
|
|
||||||
|
tmp_file.close()
|
||||||
|
os.remove(classes_path)
|
||||||
|
assert not osp.exists(classes_path)
|
||||||
|
|
||||||
|
|
||||||
def test_palette():
|
def test_palette():
|
||||||
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
|
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
|
||||||
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
|
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
|
||||||
|
|
Loading…
Reference in New Issue