# mmsegmentation/tests/test_data/test_dataset.py
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import shutil
from typing import Generator
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from PIL import Image

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                            ConcatDataset, CustomDataset, PascalVOCDataset,
                            RepeatDataset)


def test_classes():
    assert list(CityscapesDataset.CLASSES) == get_classes('cityscapes')
    assert list(PascalVOCDataset.CLASSES) == get_classes('voc') == get_classes(
        'pascal_voc')
    assert list(
        ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
    with pytest.raises(ValueError):
        get_classes('unsupported')


def test_palette():
    assert CityscapesDataset.PALETTE == get_palette('cityscapes')
    assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
        'pascal_voc')
    assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k')
    with pytest.raises(ValueError):
        get_palette('unsupported')


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_dataset_wrapper():
    # CustomDataset.load_annotations = MagicMock()
    # CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a
    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
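    # ConcatDataset indexes across both datasets: index 5 stays in dataset_a,
    # while index 25 falls past len_a and maps to dataset_b[25 - 10] == 15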
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
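    # RepeatDataset wraps indices back modulo the original length:
    # 15 % 10 == 5 and 27 % 10 == 7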
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)


def test_custom_dataset():
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    # with img_dir and ann_dir
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5
    # with img_dir, ann_dir, split
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        split='splits/train.txt')
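    # the split file lists only a subset of the images under img_dir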
    assert len(train_dataset) == 4
    # no data_root
    train_dataset = CustomDataset(
        train_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        ann_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts'),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5
    # with data_root but img_dir/ann_dir are abs path
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir=osp.abspath(
            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
        ann_dir=osp.abspath(
            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5
    # test_mode=True
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        img_suffix='img.jpg',
        test_mode=True)
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)

    # test data get
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)
    # get gt seg map
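    # with efficient_test=True the ground-truth maps are yielded lazily from a
    # generator instead of being loaded into memory all at once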
    gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5
    # format_results not implemented
    with pytest.raises(NotImplementedError):
        test_dataset.format_results([], '')
    # test past evaluation
    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mDice', 'mIoU'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    # test past evaluation with CLASSES
    train_dataset.CLASSES = tuple(['a'] * 7)
    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results
    # test evaluation with pre-eval; dataset.CLASSES is necessary here
    train_dataset.CLASSES = tuple(['a'] * 7)
    pseudo_results = []
    for idx in range(len(train_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx))
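    # each pre_eval call returns pre-computed per-image statistics rather than
    # full label maps, so evaluate() can aggregate them with low memory cost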
    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results


def test_ade():
    test_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(test_dataset) == 5

    # Test format_results
    pseudo_results = []
    for _ in range(len(test_dataset)):
        h, w = (2, 2)
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
    file_paths = test_dataset.format_results(pseudo_results, '.format_ade')
    assert len(file_paths) == len(test_dataset)
    temp = np.array(Image.open(file_paths[0]))
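    # ADE20K predictions are saved with a +1 offset, since label 0 is
    # reserved in the official annotation format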
    assert np.allclose(temp, pseudo_results[0] + 1)
    shutil.rmtree('.format_ade')


def test_cityscapes():
    test_dataset = CityscapesDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
    assert len(test_dataset) == 1

    gt_seg_maps = list(test_dataset.get_gt_seg_maps())

    # Test format_results
    pseudo_results = []
    for idx in range(len(test_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w)))
    file_paths = test_dataset.format_results(pseudo_results, '.format_city')
    assert len(file_paths) == len(test_dataset)
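    # format_results maps the train IDs in each prediction back to official
    # Cityscapes label IDs before saving, hence the _convert_to_label_id
    # comparison below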
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp,
                       test_dataset._convert_to_label_id(pseudo_results[0]))

    # Test cityscapes evaluate
    test_dataset.evaluate(
        pseudo_results, metric='cityscapes', imgfile_prefix='.format_city')
    shutil.rmtree('.format_city')


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):
    dataset_class = DATASETS.get(dataset)
    original_classes = dataset_class.CLASSES

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=classes,
        test_mode=True)
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == classes

    # Test setting classes as a list
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=list(classes),
        test_mode=True)
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == list(classes)

    # Test overriding not a subset
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=[classes[0]],
        test_mode=True)
    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == [classes[0]]

    # Test default behavior
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=None,
        test_mode=True)
    assert custom_dataset.CLASSES == original_classes


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_random_palette_is_generated():
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        test_mode=True)
    assert len(dataset.PALETTE) == 2
    for class_color in dataset.PALETTE:
        assert len(class_color) == 3
        assert all(x >= 0 and x <= 255 for x in class_color)


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_custom_palette():
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        palette=[[100, 100, 100], [200, 200, 200]],
        test_mode=True)
    assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])