[Fix] Fix bug when loading class name form file in custom dataset (#923)

* [Fix] #916 expection string type classes

* add unittests for string path classes

* fix double quote string in test_dataset.py

* move the import to the top of the file

* fix isort lint error

fix isort lint error when move the import to the top of the file
pull/1801/head
Shouping Shan 2021-10-08 01:06:18 +08:00 committed by GitHub
parent 1ce4904fe3
commit adb1cd361b
2 changed files with 35 additions and 2 deletions

View File

@ -319,7 +319,7 @@ class CustomDataset(Dataset):
raise ValueError(f'Unsupported type {type(classes)} of classes.') raise ValueError(f'Unsupported type {type(classes)} of classes.')
if self.CLASSES: if self.CLASSES:
if not set(classes).issubset(self.CLASSES): if not set(class_names).issubset(self.CLASSES):
raise ValueError('classes is not a subset of CLASSES.') raise ValueError('classes is not a subset of CLASSES.')
# dictionary, its keys are the old label ids and its values # dictionary, its keys are the old label ids and its values
@ -330,7 +330,7 @@ class CustomDataset(Dataset):
if c not in class_names: if c not in class_names:
self.label_map[i] = -1 self.label_map[i] = -1
else: else:
self.label_map[i] = classes.index(c) self.label_map[i] = class_names.index(c)
palette = self.get_palette_for_custom_classes(class_names, palette) palette = self.get_palette_for_custom_classes(class_names, palette)

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp import os.path as osp
import shutil import shutil
import tempfile
from typing import Generator from typing import Generator
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -26,6 +28,37 @@ def test_classes():
get_classes('unsupported') get_classes('unsupported')
def test_classes_file_path():
tmp_file = tempfile.NamedTemporaryFile()
classes_path = f'{tmp_file.name}.txt'
train_pipeline = [dict(type='LoadImageFromFile')]
kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)
# classes.txt with full categories
categories = get_classes('cityscapes')
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
# classes.txt with sub categories
categories = ['road', 'sidewalk', 'building']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
assert list(CityscapesDataset(**kwargs).CLASSES) == categories
# classes.txt with unknown categories
categories = ['road', 'sidewalk', 'unknown']
with open(classes_path, 'w') as f:
f.write('\n'.join(categories))
with pytest.raises(ValueError):
CityscapesDataset(**kwargs)
tmp_file.close()
os.remove(classes_path)
assert not osp.exists(classes_path)
def test_palette(): def test_palette():
assert CityscapesDataset.PALETTE == get_palette('cityscapes') assert CityscapesDataset.PALETTE == get_palette('cityscapes')
assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette( assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(