[Fix] Fix bug when loading class name form file in custom dataset (#923)
* [Fix] #916 expection string type classes * add unittests for string path classes * fix double quote string in test_dataset.py * move the import to the top of the file * fix isort lint error fix isort lint error when move the import to the top of the filepull/945/head
parent
8d49dd31e4
commit
796d5edebe
|
@ -319,7 +319,7 @@ class CustomDataset(Dataset):
|
|||
raise ValueError(f'Unsupported type {type(classes)} of classes.')
|
||||
|
||||
if self.CLASSES:
|
||||
if not set(classes).issubset(self.CLASSES):
|
||||
if not set(class_names).issubset(self.CLASSES):
|
||||
raise ValueError('classes is not a subset of CLASSES.')
|
||||
|
||||
# dictionary, its keys are the old label ids and its values
|
||||
|
@ -330,7 +330,7 @@ class CustomDataset(Dataset):
|
|||
if c not in class_names:
|
||||
self.label_map[i] = -1
|
||||
else:
|
||||
self.label_map[i] = classes.index(c)
|
||||
self.label_map[i] = class_names.index(c)
|
||||
|
||||
palette = self.get_palette_for_custom_classes(class_names, palette)
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
@ -26,6 +28,37 @@ def test_classes():
|
|||
get_classes('unsupported')
|
||||
|
||||
|
||||
def test_classes_file_path():
|
||||
tmp_file = tempfile.NamedTemporaryFile()
|
||||
classes_path = f'{tmp_file.name}.txt'
|
||||
train_pipeline = [dict(type='LoadImageFromFile')]
|
||||
kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)
|
||||
|
||||
# classes.txt with full categories
|
||||
categories = get_classes('cityscapes')
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
|
||||
|
||||
# classes.txt with sub categories
|
||||
categories = ['road', 'sidewalk', 'building']
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
|
||||
|
||||
# classes.txt with unknown categories
|
||||
categories = ['road', 'sidewalk', 'unknown']
|
||||
with open(classes_path, 'w') as f:
|
||||
f.write('\n'.join(categories))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
CityscapesDataset(**kwargs)
|
||||
|
||||
tmp_file.close()
|
||||
os.remove(classes_path)
|
||||
assert not osp.exists(classes_path)
|
||||
|
||||
|
||||
def test_palette():
|
||||
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
|
||||
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
|
||||
|
|
Loading…
Reference in New Issue