[Fix] Fix bug when loading class name form file in custom dataset (#923)

* [Fix] #916 expection string type classes

* add unittests for string path classes

* fix double quote string in test_dataset.py

* move the import to the top of the file

* fix isort lint error

fix isort lint error when move the import to the top of the file
pull/945/head
Shouping Shan 2021-10-08 01:06:18 +08:00 committed by GitHub
parent 8d49dd31e4
commit 796d5edebe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 2 deletions

View File

@ -319,7 +319,7 @@ class CustomDataset(Dataset):
raise ValueError(f'Unsupported type {type(classes)} of classes.')
if self.CLASSES:
if not set(classes).issubset(self.CLASSES):
if not set(class_names).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.')
# dictionary, its keys are the old label ids and its values
@ -330,7 +330,7 @@ class CustomDataset(Dataset):
if c not in class_names:
self.label_map[i] = -1
else:
self.label_map[i] = classes.index(c)
self.label_map[i] = class_names.index(c)
palette = self.get_palette_for_custom_classes(class_names, palette)

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from typing import Generator
from unittest.mock import MagicMock, patch
@ -26,6 +28,37 @@ def test_classes():
get_classes('unsupported')
def test_classes_file_path():
tmp_file = tempfile.NamedTemporaryFile()
classes_path = f'{tmp_file.name}.txt'
train_pipeline = [dict(type='LoadImageFromFile')]
kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)
# classes.txt with full categories
categories = get_classes('cityscapes')
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
# classes.txt with sub categories
categories = ['road', 'sidewalk', 'building']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
# classes.txt with unknown categories
categories = ['road', 'sidewalk', 'unknown']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
with pytest.raises(ValueError):
CityscapesDataset(**kwargs)
tmp_file.close()
os.remove(classes_path)
assert not osp.exists(classes_path)
def test_palette():
assert CityscapesDataset.PALETTE == get_palette('cityscapes')
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(