Merge branch 'zhengmiao/tests_bp' into 'refactor_dev'

[Refactor] Clean UTs

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!2
pull/1801/head

commit b2abe15787

@@ -1,73 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import shutil
from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from mmseg.apis import single_gpu_test


class ExampleDataset(Dataset):

    def __getitem__(self, idx):
        results = dict(img=torch.tensor([1]), img_metas=dict())
        return results

    def __len__(self):
        return 1


class ExampleModel(nn.Module):

    def __init__(self):
        super(ExampleModel, self).__init__()
        self.test_cfg = None
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, img, img_metas, return_loss=False, **kwargs):
        return img


def test_single_gpu():
    test_dataset = ExampleDataset()
    data_loader = DataLoader(
        test_dataset,
        batch_size=1,
        sampler=None,
        num_workers=0,
        shuffle=False,
    )
    model = ExampleModel()
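
    # Note: with efficient_test=True, single_gpu_test saves every prediction
    # to a .npy file under '.efficient_test' and returns the file paths,
    # which is why the first result is reloaded with np.load below.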

    # Test efficient test compatibility (will be deprecated)
    results = single_gpu_test(model, data_loader, efficient_test=True)
    assert len(results) == 1
    pred = np.load(results[0])
    assert isinstance(pred, np.ndarray)
    assert pred.shape == (1, )
    assert pred[0] == 1

    shutil.rmtree('.efficient_test')

    # Test pre_eval
    test_dataset.pre_eval = MagicMock(return_value=['success'])
    results = single_gpu_test(model, data_loader, pre_eval=True)
    assert results == ['success']

    # Test format_only
    test_dataset.format_results = MagicMock(return_value=['success'])
    results = single_gpu_test(model, data_loader, format_only=True)
    assert results == ['success']

    # efficient_test, pre_eval and format_only are mutually exclusive
    with pytest.raises(AssertionError):
        single_gpu_test(
            model,
            data_loader,
            efficient_test=True,
            format_only=True,
            pre_eval=True)

@@ -1,851 +0,0 @@

# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from typing import Generator
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
import torch
from PIL import Image

from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
                            COCOStuffDataset, ConcatDataset, CustomDataset,
                            ISPRSDataset, LoveDADataset, MultiImageMixDataset,
                            PascalVOCDataset, PotsdamDataset, RepeatDataset,
                            build_dataset, iSAIDDataset)
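

# The tests below compare each dataset class's CLASSES/PALETTE constants
# against the lookup helpers in mmseg.core.evaluation; an unknown dataset
# alias is expected to raise ValueError.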


def test_classes():
    assert list(CityscapesDataset.CLASSES) == get_classes('cityscapes')
    assert list(PascalVOCDataset.CLASSES) == get_classes('voc') == get_classes(
        'pascal_voc')
    assert list(
        ADE20KDataset.CLASSES) == get_classes('ade') == get_classes('ade20k')
    assert list(COCOStuffDataset.CLASSES) == get_classes('cocostuff')
    assert list(LoveDADataset.CLASSES) == get_classes('loveda')
    assert list(PotsdamDataset.CLASSES) == get_classes('potsdam')
    assert list(ISPRSDataset.CLASSES) == get_classes('vaihingen')
    assert list(iSAIDDataset.CLASSES) == get_classes('isaid')

    with pytest.raises(ValueError):
        get_classes('unsupported')


def test_classes_file_path():
    tmp_file = tempfile.NamedTemporaryFile()
    classes_path = f'{tmp_file.name}.txt'
    train_pipeline = [dict(type='LoadImageFromFile')]
    kwargs = dict(pipeline=train_pipeline, img_dir='./', classes=classes_path)

    # classes.txt with full categories
    categories = get_classes('cityscapes')
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    assert list(CityscapesDataset(**kwargs).CLASSES) == categories

    # classes.txt with sub categories
    categories = ['road', 'sidewalk', 'building']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))
    assert list(CityscapesDataset(**kwargs).CLASSES) == categories

    # classes.txt with unknown categories
    categories = ['road', 'sidewalk', 'unknown']
    with open(classes_path, 'w') as f:
        f.write('\n'.join(categories))

    with pytest.raises(ValueError):
        CityscapesDataset(**kwargs)

    tmp_file.close()
    os.remove(classes_path)
    assert not osp.exists(classes_path)


def test_palette():
    assert CityscapesDataset.PALETTE == get_palette('cityscapes')
    assert PascalVOCDataset.PALETTE == get_palette('voc') == get_palette(
        'pascal_voc')
    assert ADE20KDataset.PALETTE == get_palette('ade') == get_palette('ade20k')
    assert LoveDADataset.PALETTE == get_palette('loveda')
    assert PotsdamDataset.PALETTE == get_palette('potsdam')
    assert COCOStuffDataset.PALETTE == get_palette('cocostuff')
    assert iSAIDDataset.PALETTE == get_palette('isaid')

    with pytest.raises(ValueError):
        get_palette('unsupported')
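

# A worked example of the ConcatDataset index arithmetic asserted below:
# with len(dataset_a) == 10 and len(dataset_b) == 20, global index 25 falls
# into dataset_b at local index 25 - 10 = 15, hence concat_dataset[25] == 15.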


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_dataset_wrapper():
    dataset_a = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_a = 10
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a
    dataset_b = CustomDataset(img_dir=MagicMock(), pipeline=[])
    len_b = 20
    dataset_b.img_infos = MagicMock()
    dataset_b.img_infos.__len__.return_value = len_b

    concat_dataset = ConcatDataset([dataset_a, dataset_b])
    assert concat_dataset[5] == 5
    assert concat_dataset[25] == 15
    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)

    repeat_dataset = RepeatDataset(dataset_a, 10)
    assert repeat_dataset[5] == 5
    assert repeat_dataset[15] == 5
    assert repeat_dataset[27] == 7
    assert len(repeat_dataset) == 10 * len(dataset_a)

    img_scale = (60, 60)
    pipeline = [
        dict(type='RandomMosaic', prob=1, img_scale=img_scale),
        dict(type='RandomFlip', prob=0.5),
        dict(type='Resize', img_scale=img_scale, keep_ratio=False),
    ]

    CustomDataset.load_annotations = MagicMock()
    results = []
    for _ in range(2):
        height = np.random.randint(10, 30)
        width = np.random.randint(10, 30)
        img = np.ones((height, width, 3))
        gt_semantic_seg = np.random.randint(5, size=(height, width))
        results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))

    classes = ['0', '1', '2', '3', '4']
    palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
    CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
    dataset_a = CustomDataset(
        img_dir=MagicMock(),
        pipeline=[],
        test_mode=True,
        classes=classes,
        palette=palette)
    len_a = 2
    dataset_a.img_infos = MagicMock()
    dataset_a.img_infos.__len__.return_value = len_a

    multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
    assert len(multi_image_mix_dataset) == len(dataset_a)

    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]

    # test skip_type_keys
    multi_image_mix_dataset = MultiImageMixDataset(
        dataset_a, pipeline, skip_type_keys=('RandomFlip', ))
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape == (img_scale[0], img_scale[1], 3)

    skip_type_keys = ('RandomFlip', 'Resize')
    multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
    for idx in range(len_a):
        results_ = multi_image_mix_dataset[idx]
        assert results_['img'].shape[:2] != img_scale

    # test pipeline
    with pytest.raises(TypeError):
        pipeline = [['Resize']]
        multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
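

# The pseudo_dataset fixture used below holds five images matching the
# '*img.jpg' suffix with '*gt.png' annotations, and a splits/train.txt that
# lists four of them; the length assertions in test_custom_dataset rely on
# this layout.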


def test_custom_dataset():
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    crop_size = (512, 1024)
    train_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations'),
        dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
        dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
        dict(type='RandomFlip', prob=0.5),
        dict(type='PhotoMetricDistortion'),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_semantic_seg']),
    ]
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]

    # with img_dir and ann_dir
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with img_dir, ann_dir, split
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir='imgs/',
        ann_dir='gts/',
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        split='splits/train.txt')
    assert len(train_dataset) == 4

    # no data_root
    train_dataset = CustomDataset(
        train_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        ann_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts'),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # with data_root but img_dir/ann_dir are abs path
    train_dataset = CustomDataset(
        train_pipeline,
        data_root=osp.join(osp.dirname(__file__), '../data/pseudo_dataset'),
        img_dir=osp.abspath(
            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs')),
        ann_dir=osp.abspath(
            osp.join(osp.dirname(__file__), '../data/pseudo_dataset/gts')),
        img_suffix='img.jpg',
        seg_map_suffix='gt.png')
    assert len(train_dataset) == 5

    # test_mode=True
    test_dataset = CustomDataset(
        test_pipeline,
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'),
        img_suffix='img.jpg',
        test_mode=True,
        classes=('pseudo_class', ))
    assert len(test_dataset) == 5

    # training data get
    train_data = train_dataset[0]
    assert isinstance(train_data, dict)

    # test data get
    test_data = test_dataset[0]
    assert isinstance(test_data, dict)

    # get gt seg map
    gt_seg_maps = train_dataset.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5

    # format_results not implemented
    with pytest.raises(NotImplementedError):
        test_dataset.format_results([], '')

    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    # test past evaluation without CLASSES
    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])

    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')

    with pytest.raises(TypeError):
        eval_results = train_dataset.evaluate(
            pseudo_results, metric=['mDice', 'mIoU'])

    # test past evaluation with CLASSES
    train_dataset.CLASSES = tuple(['a'] * 7)
    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

    assert not np.isnan(eval_results['mIoU'])
    assert not np.isnan(eval_results['mDice'])
    assert not np.isnan(eval_results['mAcc'])
    assert not np.isnan(eval_results['aAcc'])
    assert not np.isnan(eval_results['mFscore'])
    assert not np.isnan(eval_results['mPrecision'])
    assert not np.isnan(eval_results['mRecall'])

    # test evaluation with pre-eval and the dataset.CLASSES is necessary
    train_dataset.CLASSES = tuple(['a'] * 7)
    pseudo_results = []
    for idx in range(len(train_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.extend(train_dataset.pre_eval(pseudo_result, idx))
    eval_results = train_dataset.evaluate(pseudo_results, metric=['mIoU'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
    assert isinstance(eval_results, dict)
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(pseudo_results, metric='mFscore')
    assert isinstance(eval_results, dict)
    assert 'mRecall' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mFscore' in eval_results
    assert 'aAcc' in eval_results

    eval_results = train_dataset.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])
    assert isinstance(eval_results, dict)
    assert 'mIoU' in eval_results
    assert 'mDice' in eval_results
    assert 'mAcc' in eval_results
    assert 'aAcc' in eval_results
    assert 'mFscore' in eval_results
    assert 'mPrecision' in eval_results
    assert 'mRecall' in eval_results

    assert not np.isnan(eval_results['mIoU'])
    assert not np.isnan(eval_results['mDice'])
    assert not np.isnan(eval_results['mAcc'])
    assert not np.isnan(eval_results['aAcc'])
    assert not np.isnan(eval_results['mFscore'])
    assert not np.isnan(eval_results['mPrecision'])
    assert not np.isnan(eval_results['mRecall'])
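

# separate_eval controls how ConcatDataset.evaluate reports metrics: when
# True each sub-dataset is evaluated on its own and keys are prefixed with
# the sub-dataset index ('0_mIoU', '1_mIoU', ...); when False the results
# are pooled into a single set of unprefixed keys, as asserted below.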


@pytest.mark.parametrize('separate_eval', [True, False])
def test_eval_concat_custom_dataset(separate_eval):
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    test_pipeline = [
        dict(type='LoadImageFromFile'),
        dict(
            type='MultiScaleFlipAug',
            img_scale=(128, 256),
            # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
            flip=False,
            transforms=[
                dict(type='Resize', keep_ratio=True),
                dict(type='RandomFlip'),
                dict(type='Normalize', **img_norm_cfg),
                dict(type='ImageToTensor', keys=['img']),
                dict(type='Collect', keys=['img']),
            ])
    ]
    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = 'imgs/'
    ann_dir = 'gts/'

    cfg1 = dict(
        type='CustomDataset',
        pipeline=test_pipeline,
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        classes=tuple(['a'] * 7))
    dataset1 = build_dataset(cfg1)
    assert len(dataset1) == 5
    # get gt seg map
    gt_seg_maps = dataset1.get_gt_seg_maps(efficient_test=True)
    assert isinstance(gt_seg_maps, Generator)
    gt_seg_maps = list(gt_seg_maps)
    assert len(gt_seg_maps) == 5

    # test past evaluation
    pseudo_results = []
    for gt_seg_map in gt_seg_maps:
        h, w = gt_seg_map.shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
    eval_results1 = dataset1.evaluate(
        pseudo_results, metric=['mIoU', 'mDice', 'mFscore'])

    # We use same dir twice for simplicity
    # with ann_dir
    cfg2 = dict(
        type='CustomDataset',
        pipeline=test_pipeline,
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        ann_dir=[ann_dir, ann_dir],
        img_suffix='img.jpg',
        seg_map_suffix='gt.png',
        classes=tuple(['a'] * 7),
        separate_eval=separate_eval)
    dataset2 = build_dataset(cfg2)
    assert isinstance(dataset2, ConcatDataset)
    assert len(dataset2) == 10

    eval_results2 = dataset2.evaluate(
        pseudo_results * 2, metric=['mIoU', 'mDice', 'mFscore'])

    if separate_eval:
        assert eval_results1['mIoU'] == eval_results2[
            '0_mIoU'] == eval_results2['1_mIoU']
        assert eval_results1['mDice'] == eval_results2[
            '0_mDice'] == eval_results2['1_mDice']
        assert eval_results1['mAcc'] == eval_results2[
            '0_mAcc'] == eval_results2['1_mAcc']
        assert eval_results1['aAcc'] == eval_results2[
            '0_aAcc'] == eval_results2['1_aAcc']
        assert eval_results1['mFscore'] == eval_results2[
            '0_mFscore'] == eval_results2['1_mFscore']
        assert eval_results1['mPrecision'] == eval_results2[
            '0_mPrecision'] == eval_results2['1_mPrecision']
        assert eval_results1['mRecall'] == eval_results2[
            '0_mRecall'] == eval_results2['1_mRecall']
    else:
        assert eval_results1['mIoU'] == eval_results2['mIoU']
        assert eval_results1['mDice'] == eval_results2['mDice']
        assert eval_results1['mAcc'] == eval_results2['mAcc']
        assert eval_results1['aAcc'] == eval_results2['aAcc']
        assert eval_results1['mFscore'] == eval_results2['mFscore']
        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
        assert eval_results1['mRecall'] == eval_results2['mRecall']

    # test get dataset_idx and sample_idx from ConcatDataset
    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(3)
    assert dataset_idx == 0
    assert sample_idx == 3

    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(7)
    assert dataset_idx == 1
    assert sample_idx == 2

    dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-7)
    assert dataset_idx == 0
    assert sample_idx == 3

    # test negative index that exceeds the length of the dataset
    with pytest.raises(ValueError):
        dataset_idx, sample_idx = dataset2.get_dataset_idx_and_sample_idx(-11)

    # test negative index value
    indice = -6
    dataset_idx1, sample_idx1 = dataset2.get_dataset_idx_and_sample_idx(indice)
    dataset_idx2, sample_idx2 = dataset2.get_dataset_idx_and_sample_idx(
        len(dataset2) + indice)
    assert dataset_idx1 == dataset_idx2
    assert sample_idx1 == sample_idx2

    # test evaluation with pre-eval and the dataset.CLASSES is necessary
    pseudo_results = []
    eval_results1 = []
    for idx in range(len(dataset1)):
        h, w = gt_seg_maps[idx].shape
        pseudo_result = np.random.randint(low=0, high=7, size=(h, w))
        pseudo_results.append(pseudo_result)
        eval_results1.extend(dataset1.pre_eval(pseudo_result, idx))

    assert len(eval_results1) == len(dataset1)
    assert isinstance(eval_results1[0], tuple)
    assert len(eval_results1[0]) == 4
    assert isinstance(eval_results1[0][0], torch.Tensor)

    eval_results1 = dataset1.evaluate(
        eval_results1, metric=['mIoU', 'mDice', 'mFscore'])

    pseudo_results = pseudo_results * 2
    eval_results2 = []
    for idx in range(len(dataset2)):
        eval_results2.extend(dataset2.pre_eval(pseudo_results[idx], idx))

    assert len(eval_results2) == len(dataset2)
    assert isinstance(eval_results2[0], tuple)
    assert len(eval_results2[0]) == 4
    assert isinstance(eval_results2[0][0], torch.Tensor)

    eval_results2 = dataset2.evaluate(
        eval_results2, metric=['mIoU', 'mDice', 'mFscore'])

    if separate_eval:
        assert eval_results1['mIoU'] == eval_results2[
            '0_mIoU'] == eval_results2['1_mIoU']
        assert eval_results1['mDice'] == eval_results2[
            '0_mDice'] == eval_results2['1_mDice']
        assert eval_results1['mAcc'] == eval_results2[
            '0_mAcc'] == eval_results2['1_mAcc']
        assert eval_results1['aAcc'] == eval_results2[
            '0_aAcc'] == eval_results2['1_aAcc']
        assert eval_results1['mFscore'] == eval_results2[
            '0_mFscore'] == eval_results2['1_mFscore']
        assert eval_results1['mPrecision'] == eval_results2[
            '0_mPrecision'] == eval_results2['1_mPrecision']
        assert eval_results1['mRecall'] == eval_results2[
            '0_mRecall'] == eval_results2['1_mRecall']
    else:
        assert eval_results1['mIoU'] == eval_results2['mIoU']
        assert eval_results1['mDice'] == eval_results2['mDice']
        assert eval_results1['mAcc'] == eval_results2['mAcc']
        assert eval_results1['aAcc'] == eval_results2['aAcc']
        assert eval_results1['mFscore'] == eval_results2['mFscore']
        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
        assert eval_results1['mRecall'] == eval_results2['mRecall']

    # test batch_indices for pre eval
    eval_results2 = dataset2.pre_eval(pseudo_results,
                                      list(range(len(pseudo_results))))

    assert len(eval_results2) == len(dataset2)
    assert isinstance(eval_results2[0], tuple)
    assert len(eval_results2[0]) == 4
    assert isinstance(eval_results2[0][0], torch.Tensor)

    eval_results2 = dataset2.evaluate(
        eval_results2, metric=['mIoU', 'mDice', 'mFscore'])

    if separate_eval:
        assert eval_results1['mIoU'] == eval_results2[
            '0_mIoU'] == eval_results2['1_mIoU']
        assert eval_results1['mDice'] == eval_results2[
            '0_mDice'] == eval_results2['1_mDice']
        assert eval_results1['mAcc'] == eval_results2[
            '0_mAcc'] == eval_results2['1_mAcc']
        assert eval_results1['aAcc'] == eval_results2[
            '0_aAcc'] == eval_results2['1_aAcc']
        assert eval_results1['mFscore'] == eval_results2[
            '0_mFscore'] == eval_results2['1_mFscore']
        assert eval_results1['mPrecision'] == eval_results2[
            '0_mPrecision'] == eval_results2['1_mPrecision']
        assert eval_results1['mRecall'] == eval_results2[
            '0_mRecall'] == eval_results2['1_mRecall']
    else:
        assert eval_results1['mIoU'] == eval_results2['mIoU']
        assert eval_results1['mDice'] == eval_results2['mDice']
        assert eval_results1['mAcc'] == eval_results2['mAcc']
        assert eval_results1['aAcc'] == eval_results2['aAcc']
        assert eval_results1['mFscore'] == eval_results2['mFscore']
        assert eval_results1['mPrecision'] == eval_results2['mPrecision']
        assert eval_results1['mRecall'] == eval_results2['mRecall']
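

# ADE20K submission files store labels shifted by one (0 is reserved for
# 'other' in the benchmark), so the saved result checked below should equal
# the raw prediction plus 1.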


def test_ade():
    test_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(test_dataset) == 5

    # Test format_results
    pseudo_results = []
    for _ in range(len(test_dataset)):
        h, w = (2, 2)
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    file_paths = test_dataset.format_results(pseudo_results, '.format_ade')
    assert len(file_paths) == len(test_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')


@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_ade(separate_eval):
    test_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(test_dataset) == 5

    concat_dataset = ConcatDataset([test_dataset, test_dataset],
                                   separate_eval=separate_eval)
    assert len(concat_dataset) == 10
    # Test format_results
    pseudo_results = []
    for _ in range(len(concat_dataset)):
        h, w = (2, 2)
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))

    # test format per image
    file_paths = []
    for i in range(len(pseudo_results)):
        file_paths.extend(
            concat_dataset.format_results([pseudo_results[i]],
                                          '.format_ade',
                                          indices=[i]))
    assert len(file_paths) == len(concat_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')

    # test default argument
    file_paths = concat_dataset.format_results(pseudo_results, '.format_ade')
    assert len(file_paths) == len(concat_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp, pseudo_results[0] + 1)

    shutil.rmtree('.format_ade')
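

# Cityscapes submissions use full label IDs rather than the contiguous train
# IDs that models predict, so format_results converts predictions with
# _convert_to_label_id before saving; the allclose check below verifies this.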


def test_cityscapes():
    test_dataset = CityscapesDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
    assert len(test_dataset) == 1

    gt_seg_maps = list(test_dataset.get_gt_seg_maps())

    # Test format_results
    pseudo_results = []
    for idx in range(len(test_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_results.append(np.random.randint(low=0, high=19, size=(h, w)))

    file_paths = test_dataset.format_results(pseudo_results, '.format_city')
    assert len(file_paths) == len(test_dataset)
    temp = np.array(Image.open(file_paths[0]))
    assert np.allclose(temp,
                       test_dataset._convert_to_label_id(pseudo_results[0]))

    # Test cityscapes evaluate
    test_dataset.evaluate(
        pseudo_results, metric='cityscapes', imgfile_prefix='.format_city')

    shutil.rmtree('.format_city')


@pytest.mark.parametrize('separate_eval', [True, False])
def test_concat_cityscapes(separate_eval):
    cityscape_dataset = CityscapesDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_cityscapes_dataset/leftImg8bit'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_cityscapes_dataset/gtFine'))
    assert len(cityscape_dataset) == 1
    with pytest.raises(NotImplementedError):
        _ = ConcatDataset([cityscape_dataset, cityscape_dataset],
                          separate_eval=separate_eval)
    ade_dataset = ADE20KDataset(
        pipeline=[],
        img_dir=osp.join(osp.dirname(__file__), '../data/pseudo_dataset/imgs'))
    assert len(ade_dataset) == 5
    with pytest.raises(NotImplementedError):
        _ = ConcatDataset([cityscape_dataset, ade_dataset],
                          separate_eval=separate_eval)


def test_loveda():
    test_dataset = LoveDADataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_loveda_dataset/img_dir'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_loveda_dataset/ann_dir'))
    assert len(test_dataset) == 3

    gt_seg_maps = list(test_dataset.get_gt_seg_maps())

    # Test format_results
    pseudo_results = []
    for idx in range(len(test_dataset)):
        h, w = gt_seg_maps[idx].shape
        pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
    file_paths = test_dataset.format_results(pseudo_results, '.format_loveda')
    assert len(file_paths) == len(test_dataset)
    # Test loveda evaluate
    test_dataset.evaluate(
        pseudo_results, metric='mIoU', imgfile_prefix='.format_loveda')

    shutil.rmtree('.format_loveda')


def test_potsdam():
    test_dataset = PotsdamDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_potsdam_dataset/img_dir'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_potsdam_dataset/ann_dir'))
    assert len(test_dataset) == 1


def test_vaihingen():
    test_dataset = ISPRSDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_vaihingen_dataset/img_dir'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_vaihingen_dataset/ann_dir'))
    assert len(test_dataset) == 1


def test_isaid():
    test_dataset = iSAIDDataset(
        pipeline=[],
        img_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'))
    assert len(test_dataset) == 2
    isaid_info = test_dataset.load_annotations(
        img_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_isaid_dataset/img_dir'),
        img_suffix='.png',
        ann_dir=osp.join(
            osp.dirname(__file__), '../data/pseudo_isaid_dataset/ann_dir'),
        seg_map_suffix='.png',
        split=osp.join(
            osp.dirname(__file__),
            '../data/pseudo_isaid_dataset/splits/train.txt'))
    assert len(isaid_info) == 1
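

# Passing `classes` at construction time should override a dataset's default
# CLASSES attribute. CustomDataset has no default class list, so building it
# with classes=None is expected to trip the internal assertion exercised
# below.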


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
@pytest.mark.parametrize('dataset, classes', [
    ('ADE20KDataset', ('wall', 'building')),
    ('CityscapesDataset', ('road', 'sidewalk')),
    ('CustomDataset', ('bus', 'car')),
    ('PascalVOCDataset', ('aeroplane', 'bicycle')),
])
def test_custom_classes_override_default(dataset, classes):

    dataset_class = DATASETS.get(dataset)

    original_classes = dataset_class.CLASSES

    # Test setting classes as a tuple
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=classes,
        test_mode=True)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == classes

    # Test setting classes as a list
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=list(classes),
        test_mode=True)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == list(classes)

    # Test overriding not a subset
    custom_dataset = dataset_class(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=[classes[0]],
        test_mode=True)

    assert custom_dataset.CLASSES != original_classes
    assert custom_dataset.CLASSES == [classes[0]]

    # Test default behavior
    if dataset_class is CustomDataset:
        with pytest.raises(AssertionError):
            custom_dataset = dataset_class(
                pipeline=[],
                img_dir=MagicMock(),
                split=MagicMock(),
                classes=None,
                test_mode=True)
    else:
        custom_dataset = dataset_class(
            pipeline=[],
            img_dir=MagicMock(),
            split=MagicMock(),
            classes=None,
            test_mode=True)

        assert custom_dataset.CLASSES == original_classes


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_random_palette_is_generated():
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        test_mode=True)
    assert len(dataset.PALETTE) == 2
    for class_color in dataset.PALETTE:
        assert len(class_color) == 3
        assert all(0 <= x <= 255 for x in class_color)


@patch('mmseg.datasets.CustomDataset.load_annotations', MagicMock)
@patch('mmseg.datasets.CustomDataset.__getitem__',
       MagicMock(side_effect=lambda idx: idx))
def test_custom_dataset_custom_palette():
    dataset = CustomDataset(
        pipeline=[],
        img_dir=MagicMock(),
        split=MagicMock(),
        classes=('bus', 'car'),
        palette=[[100, 100, 100], [200, 200, 200]],
        test_mode=True)
    assert tuple(dataset.PALETTE) == tuple([[100, 100, 100], [200, 200, 200]])

@@ -1,200 +0,0 @@

# Copyright (c) OpenMMLab. All rights reserved.
import math
import os.path as osp

import pytest
from torch.utils.data import (DistributedSampler, RandomSampler,
                              SequentialSampler)

from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
                            build_dataloader, build_dataset)


@DATASETS.register_module()
class ToyDataset(object):

    def __init__(self, cnt=0):
        self.cnt = cnt

    def __getitem__(self, idx):
        return idx

    def __len__(self):
        return 100
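

# build_dataset expands list-valued img_dir/ann_dir/split entries into one
# dataset per entry and wraps them in a ConcatDataset, which is what the
# isinstance and length assertions below exercise.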


def test_build_dataset():
    cfg = dict(type='ToyDataset')
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ToyDataset)
    assert dataset.cnt == 0
    dataset = build_dataset(cfg, default_args=dict(cnt=1))
    assert isinstance(dataset, ToyDataset)
    assert dataset.cnt == 1

    data_root = osp.join(osp.dirname(__file__), '../data/pseudo_dataset')
    img_dir = 'imgs/'
    ann_dir = 'gts/'

    # We use same dir twice for simplicity
    # with ann_dir
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        ann_dir=[ann_dir, ann_dir])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 10

    cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, MultiImageMixDataset)
    assert len(dataset) == 10

    # with ann_dir, split
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=ann_dir,
        split=['splits/train.txt', 'splits/val.txt'])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 5

    # with ann_dir, split
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=img_dir,
        ann_dir=[ann_dir, ann_dir],
        split=['splits/train.txt', 'splits/val.txt'])
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 5

    # test mode
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        test_mode=True,
        classes=('pseudo_class', ))
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 10

    # test mode with splits
    cfg = dict(
        type='CustomDataset',
        pipeline=[],
        data_root=data_root,
        img_dir=[img_dir, img_dir],
        split=['splits/val.txt', 'splits/val.txt'],
        test_mode=True,
        classes=('pseudo_class', ))
    dataset = build_dataset(cfg)
    assert isinstance(dataset, ConcatDataset)
    assert len(dataset) == 2

    # len(ann_dir) should be zero or len(img_dir) when len(img_dir) > 1
    with pytest.raises(AssertionError):
        cfg = dict(
            type='CustomDataset',
            pipeline=[],
            data_root=data_root,
            img_dir=[img_dir, img_dir],
            ann_dir=[ann_dir, ann_dir, ann_dir])
        build_dataset(cfg)

    # len(splits) should be zero or len(img_dir) when len(img_dir) > 1
    with pytest.raises(AssertionError):
        cfg = dict(
            type='CustomDataset',
            pipeline=[],
            data_root=data_root,
            img_dir=[img_dir, img_dir],
            split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt'])
        build_dataset(cfg)

    # len(splits) == len(ann_dir) when only len(img_dir) == 1 and len(
    # ann_dir) > 1
    with pytest.raises(AssertionError):
        cfg = dict(
            type='CustomDataset',
            pipeline=[],
            data_root=data_root,
            img_dir=img_dir,
            ann_dir=[ann_dir, ann_dir],
            split=['splits/val.txt', 'splits/val.txt', 'splits/val.txt'])
        build_dataset(cfg)
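

# In distributed mode each process loads samples_per_gpu items, so the batch
# size stays at samples_per_gpu regardless of num_gpus; in non-distributed
# mode one dataloader feeds all GPUs, so the batch size scales to
# samples_per_gpu * num_gpus and the workers to workers_per_gpu * num_gpus
# (hence num_workers == 16 in the last case below).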


def test_build_dataloader():
    dataset = ToyDataset()
    samples_per_gpu = 3
    # dist=True, shuffle=True, 1GPU
    dataloader = build_dataloader(
        dataset, samples_per_gpu=samples_per_gpu, workers_per_gpu=2)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, DistributedSampler)
    assert dataloader.sampler.shuffle

    # dist=True, shuffle=False, 1GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        shuffle=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, DistributedSampler)
    assert not dataloader.sampler.shuffle

    # dist=True, shuffle=True, 8GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        num_gpus=8)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert dataloader.num_workers == 2

    # dist=False, shuffle=True, 1GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=2,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, RandomSampler)
    assert dataloader.num_workers == 2

    # dist=False, shuffle=False, 1GPU
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=3,
        workers_per_gpu=2,
        shuffle=False,
        dist=False)
    assert dataloader.batch_size == samples_per_gpu
    assert len(dataloader) == int(math.ceil(len(dataset) / samples_per_gpu))
    assert isinstance(dataloader.sampler, SequentialSampler)
    assert dataloader.num_workers == 2

    # dist=False, shuffle=True, 8GPU
    dataloader = build_dataloader(
        dataset, samples_per_gpu=3, workers_per_gpu=2, num_gpus=8, dist=False)
    assert dataloader.batch_size == samples_per_gpu * 8
    assert len(dataloader) == int(
        math.ceil(len(dataset) / samples_per_gpu / 8))
    assert isinstance(dataloader.sampler, RandomSampler)
    assert dataloader.num_workers == 16

@@ -1,199 +0,0 @@

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp
import tempfile

import mmcv
import numpy as np

from mmseg.datasets.pipelines import LoadAnnotations, LoadImageFromFile


class TestLoading(object):

    @classmethod
    def setup_class(cls):
        cls.data_prefix = osp.join(osp.dirname(__file__), '../data')

    def test_load_img(self):
        results = dict(
            img_prefix=self.data_prefix, img_info=dict(filename='color.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['filename'] == osp.join(self.data_prefix, 'color.jpg')
        assert results['ori_filename'] == 'color.jpg'
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8
        assert results['img_shape'] == (288, 512, 3)
        assert results['ori_shape'] == (288, 512, 3)
        assert results['pad_shape'] == (288, 512, 3)
        assert results['scale_factor'] == 1.0
        np.testing.assert_equal(results['img_norm_cfg']['mean'],
                                np.zeros(3, dtype=np.float32))
        assert repr(transform) == transform.__class__.__name__ + \
            "(to_float32=False,color_type='color',imdecode_backend='cv2')"

        # no img_prefix
        results = dict(
            img_prefix=None, img_info=dict(filename='tests/data/color.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['filename'] == 'tests/data/color.jpg'
        assert results['ori_filename'] == 'tests/data/color.jpg'
        assert results['img'].shape == (288, 512, 3)

        # to_float32
        transform = LoadImageFromFile(to_float32=True)
        results = transform(copy.deepcopy(results))
        assert results['img'].dtype == np.float32

        # gray image
        results = dict(
            img_prefix=self.data_prefix, img_info=dict(filename='gray.jpg'))
        transform = LoadImageFromFile()
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512, 3)
        assert results['img'].dtype == np.uint8

        transform = LoadImageFromFile(color_type='unchanged')
        results = transform(copy.deepcopy(results))
        assert results['img'].shape == (288, 512)
        assert results['img'].dtype == np.uint8
        np.testing.assert_equal(results['img_norm_cfg']['mean'],
                                np.zeros(1, dtype=np.float32))
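
    # Note: LoadAnnotations(reduce_zero_label=True), used in test_load_seg
    # below, remaps label 0 to the ignore index 255 and shifts the remaining
    # labels down by one, which matters for datasets whose background id is 0.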

    def test_load_seg(self):
        results = dict(
            seg_prefix=self.data_prefix,
            ann_info=dict(seg_map='seg.png'),
            seg_fields=[])
        transform = LoadAnnotations()
        results = transform(copy.deepcopy(results))
        assert results['seg_fields'] == ['gt_semantic_seg']
        assert results['gt_semantic_seg'].shape == (288, 512)
        assert results['gt_semantic_seg'].dtype == np.uint8
        assert repr(transform) == transform.__class__.__name__ + \
            "(reduce_zero_label=False,imdecode_backend='pillow')"

        # no seg_prefix
        results = dict(
            seg_prefix=None,
            ann_info=dict(seg_map='tests/data/seg.png'),
            seg_fields=[])
        transform = LoadAnnotations()
        results = transform(copy.deepcopy(results))
        assert results['gt_semantic_seg'].shape == (288, 512)
        assert results['gt_semantic_seg'].dtype == np.uint8

        # reduce_zero_label
        transform = LoadAnnotations(reduce_zero_label=True)
        results = transform(copy.deepcopy(results))
        assert results['gt_semantic_seg'].shape == (288, 512)
        assert results['gt_semantic_seg'].dtype == np.uint8

        # pillow backend
        results = dict(
            seg_prefix=self.data_prefix,
            ann_info=dict(seg_map='seg.png'),
            seg_fields=[])
        transform = LoadAnnotations(imdecode_backend='pillow')
        results = transform(copy.deepcopy(results))
        # this image is saved by PIL
        assert results['gt_semantic_seg'].shape == (288, 512)
        assert results['gt_semantic_seg'].dtype == np.uint8
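
    # A worked example of the label_map remapping exercised below: the ground
    # truth contains ids {0..4}; mapping {3: 1, everything else: 0} should
    # leave a mask that is 1 exactly where the original label was 3.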

    def test_load_seg_custom_classes(self):

        test_img = np.random.rand(10, 10)
        test_gt = np.zeros_like(test_img)
        test_gt[2:4, 2:4] = 1
        test_gt[2:4, 6:8] = 2
        test_gt[6:8, 2:4] = 3
        test_gt[6:8, 6:8] = 4

        tmp_dir = tempfile.TemporaryDirectory()
        img_path = osp.join(tmp_dir.name, 'img.jpg')
        gt_path = osp.join(tmp_dir.name, 'gt.png')

        mmcv.imwrite(test_img, img_path)
        mmcv.imwrite(test_gt, gt_path)

        # test only train with label with id 3
        results = dict(
            img_info=dict(filename=img_path),
            ann_info=dict(seg_map=gt_path),
            label_map={
                0: 0,
                1: 0,
                2: 0,
                3: 1,
                4: 0
            },
            seg_fields=[])

        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))

        load_anns = LoadAnnotations()
        results = load_anns(copy.deepcopy(results))

        gt_array = results['gt_semantic_seg']

        true_mask = np.zeros_like(gt_array)
        true_mask[6:8, 2:4] = 1

        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, true_mask)

        # test only train with label with id 4 and 3
        results = dict(
            img_info=dict(filename=img_path),
            ann_info=dict(seg_map=gt_path),
            label_map={
                0: 0,
                1: 0,
                2: 0,
                3: 2,
                4: 1
            },
            seg_fields=[])

        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))

        load_anns = LoadAnnotations()
        results = load_anns(copy.deepcopy(results))

        gt_array = results['gt_semantic_seg']

        true_mask = np.zeros_like(gt_array)
        true_mask[6:8, 2:4] = 2
        true_mask[6:8, 6:8] = 1

        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, true_mask)

        # test no custom classes
        results = dict(
            img_info=dict(filename=img_path),
            ann_info=dict(seg_map=gt_path),
            seg_fields=[])

        load_imgs = LoadImageFromFile()
        results = load_imgs(copy.deepcopy(results))

        load_anns = LoadAnnotations()
        results = load_anns(copy.deepcopy(results))

        gt_array = results['gt_semantic_seg']

        assert results['seg_fields'] == ['gt_semantic_seg']
        assert gt_array.shape == (10, 10)
        assert gt_array.dtype == np.uint8
        np.testing.assert_array_equal(gt_array, test_gt)

        tmp_dir.cleanup()

@@ -1,690 +0,0 @@

# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os.path as osp

import mmcv
import numpy as np
import pytest
from mmcv.utils import build_from_cfg
from PIL import Image

from mmseg.datasets.builder import PIPELINES


def test_resize_to_multiple():
    transform = dict(type='ResizeToMultiple', size_divisor=32)
    transform = build_from_cfg(transform, PIPELINES)

    img = np.random.randn(213, 232, 3)
    seg = np.random.randint(0, 19, (213, 232))
    results = dict()
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['pad_shape'] = img.shape

    results = transform(results)
    assert results['img'].shape == (224, 256, 3)
    assert results['gt_semantic_seg'].shape == (224, 256)
    assert results['img_shape'] == (224, 256, 3)
    assert results['pad_shape'] == (224, 256, 3)
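

# A worked example for the first keep_ratio case below: color.jpg is
# 288 x 512, so img_scale=(1333, 800) gives a scale factor of
# min(1333 / 512, 800 / 288) = 1333 / 512 and a resized shape of
# (round(288 * 1333 / 512), 1333) = (750, 1333).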


def test_resize():
    # test assertion if img_scale is a list
    with pytest.raises(AssertionError):
        transform = dict(type='Resize', img_scale=[1333, 800], keep_ratio=True)
        build_from_cfg(transform, PIPELINES)

    # test assertion if len(img_scale) > 1 while ratio_range is not None
    with pytest.raises(AssertionError):
        transform = dict(
            type='Resize',
            img_scale=[(1333, 800), (1333, 600)],
            ratio_range=(0.9, 1.1),
            keep_ratio=True)
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid multiscale_mode
    with pytest.raises(AssertionError):
        transform = dict(
            type='Resize',
            img_scale=[(1333, 800), (1333, 600)],
            keep_ratio=True,
            multiscale_mode='2333')
        build_from_cfg(transform, PIPELINES)

    transform = dict(type='Resize', img_scale=(1333, 800), keep_ratio=True)
    resize_module = build_from_cfg(transform, PIPELINES)

    results = dict()
    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (750, 1333, 3)

    # test keep_ratio=False
    transform = dict(
        type='Resize',
        img_scale=(1280, 800),
        multiscale_mode='value',
        keep_ratio=False)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (800, 1280, 3)

    # test multiscale_mode='range'
    transform = dict(
        type='Resize',
        img_scale=[(1333, 400), (1333, 1200)],
        multiscale_mode='range',
        keep_ratio=True)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert max(resized_results['img_shape'][:2]) <= 1333
    assert min(resized_results['img_shape'][:2]) >= 400
    assert min(resized_results['img_shape'][:2]) <= 1200

    # test multiscale_mode='value'
    transform = dict(
        type='Resize',
        img_scale=[(1333, 800), (1333, 400)],
        multiscale_mode='value',
        keep_ratio=True)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] in [(750, 1333, 3), (400, 711, 3)]

    # test ratio_range with a single img_scale
    transform = dict(
        type='Resize',
        img_scale=(1333, 800),
        ratio_range=(0.9, 1.1),
        keep_ratio=True)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert max(resized_results['img_shape'][:2]) <= 1333 * 1.1

    # test img_scale=None and ratio_range is tuple.
    # img shape: (288, 512, 3)
    transform = dict(
        type='Resize', img_scale=None, ratio_range=(0.5, 2.0), keep_ratio=True)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert int(288 * 0.5) <= resized_results['img_shape'][0] <= 288 * 2.0
    assert int(512 * 0.5) <= resized_results['img_shape'][1] <= 512 * 2.0

    # test min_size=640
    transform = dict(type='Resize', img_scale=(2560, 640), min_size=640)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (640, 1138, 3)

    # test min_size=640 and img_scale=(512, 640)
    transform = dict(type='Resize', img_scale=(512, 640), min_size=640)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (640, 1138, 3)

    # test h > w
    img = np.random.randn(512, 288, 3)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0
    transform = dict(type='Resize', img_scale=(2560, 640), min_size=640)
    resize_module = build_from_cfg(transform, PIPELINES)
    resized_results = resize_module(results.copy())
    assert resized_results['img_shape'] == (1138, 640, 3)
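

# Flipping with prob=1 is deterministic, so applying the same horizontal
# flip twice must restore the original image and segmentation map, which is
# what the equality checks at the end of test_flip verify.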


def test_flip():
    # test assertion for invalid prob
    with pytest.raises(AssertionError):
        transform = dict(type='RandomFlip', prob=1.5)
        build_from_cfg(transform, PIPELINES)

    # test assertion for invalid direction
    with pytest.raises(AssertionError):
        transform = dict(type='RandomFlip', prob=1, direction='horizonta')
        build_from_cfg(transform, PIPELINES)

    transform = dict(type='RandomFlip', prob=1)
    flip_module = build_from_cfg(transform, PIPELINES)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    original_seg = copy.deepcopy(seg)
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = flip_module(results)

    flip_module = build_from_cfg(transform, PIPELINES)
    results = flip_module(results)
    assert np.equal(original_img, results['img']).all()
    assert np.equal(original_seg, results['gt_semantic_seg']).all()


def test_random_crop():
    # test assertion for invalid random crop
    with pytest.raises(AssertionError):
        transform = dict(type='RandomCrop', crop_size=(-1, 0))
        build_from_cfg(transform, PIPELINES)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    h, w, _ = img.shape
    transform = dict(type='RandomCrop', crop_size=(h - 20, w - 20))
    crop_module = build_from_cfg(transform, PIPELINES)
    results = crop_module(results)
    assert results['img'].shape[:2] == (h - 20, w - 20)
    assert results['img_shape'][:2] == (h - 20, w - 20)
    assert results['gt_semantic_seg'].shape[:2] == (h - 20, w - 20)
def test_pad():
    # test assertion if both size_divisor and size are None
    with pytest.raises(AssertionError):
        transform = dict(type='Pad')
        build_from_cfg(transform, PIPELINES)

    transform = dict(type='Pad', size_divisor=32)
    transform = build_from_cfg(transform, PIPELINES)
    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)
    # original img already divisible by 32
    assert np.equal(results['img'], original_img).all()
    img_shape = results['img'].shape
    assert img_shape[0] % 32 == 0
    assert img_shape[1] % 32 == 0

    resize_transform = dict(
        type='Resize', img_scale=(1333, 800), keep_ratio=True)
    resize_module = build_from_cfg(resize_transform, PIPELINES)
    results = resize_module(results)
    results = transform(results)
    img_shape = results['img'].shape
    assert img_shape[0] % 32 == 0
    assert img_shape[1] % 32 == 0
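
# An illustrative helper (not part of mmseg) for the Pad semantics
# exercised above: with size_divisor=32 each spatial dim is rounded up to
# the next multiple of 32, so the 288x512 test image is left untouched.
def _padded_extent(dim, divisor=32):
    return -(-dim // divisor) * divisor  # ceil division


# _padded_extent(288) == 288, _padded_extent(801) == 832
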
def test_rotate():
    # test assertion degree should be positive
    with pytest.raises(AssertionError):
        transform = dict(type='RandomRotate', prob=0.5, degree=-10)
        build_from_cfg(transform, PIPELINES)
    # test assertion degree should be tuple[float] or float
    with pytest.raises(AssertionError):
        transform = dict(type='RandomRotate', prob=0.5, degree=(10., 20., 30.))
        build_from_cfg(transform, PIPELINES)

    transform = dict(type='RandomRotate', degree=10., prob=1.)
    transform = build_from_cfg(transform, PIPELINES)

    assert str(transform) == f'RandomRotate(' \
                             f'prob={1.}, ' \
                             f'degree=({-10.}, {10.}), ' \
                             f'pad_val={0}, ' \
                             f'seg_pad_val={255}, ' \
                             f'center={None}, ' \
                             f'auto_bound={False})'

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    h, w, _ = img.shape
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)
    assert results['img'].shape[:2] == (h, w)
    assert results['gt_semantic_seg'].shape[:2] == (h, w)
def test_normalize():
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True)
    transform = dict(type='Normalize', **img_norm_cfg)
    transform = build_from_cfg(transform, PIPELINES)
    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)

    mean = np.array(img_norm_cfg['mean'])
    std = np.array(img_norm_cfg['std'])
    converted_img = (original_img[..., ::-1] - mean) / std
    assert np.allclose(results['img'], converted_img)
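
# A one-pixel sketch (illustrative helper, not part of mmseg) of the
# Normalize math verified above: with to_rgb=True the BGR image from
# mmcv.imread is reversed to RGB before the per-channel (img - mean) / std.
def _normalize_pixel(pixel_bgr, mean, std):
    rgb = pixel_bgr[::-1]
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


# A BGR pixel equal to the reversed RGB mean normalizes to zero:
# _normalize_pixel([103.53, 116.28, 123.675],
#                  [123.675, 116.28, 103.53],
#                  [58.395, 57.12, 57.375]) == [0.0, 0.0, 0.0]
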
def test_rgb2gray():
    # test assertion out_channels should be greater than 0
    with pytest.raises(AssertionError):
        transform = dict(type='RGB2Gray', out_channels=-1)
        build_from_cfg(transform, PIPELINES)
    # test assertion weights should be tuple[float]
    with pytest.raises(AssertionError):
        transform = dict(type='RGB2Gray', out_channels=1, weights=1.1)
        build_from_cfg(transform, PIPELINES)

    # test out_channels is None
    transform = dict(type='RGB2Gray')
    transform = build_from_cfg(transform, PIPELINES)

    assert str(transform) == f'RGB2Gray(' \
                             f'out_channels={None}, ' \
                             f'weights={(0.299, 0.587, 0.114)})'

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    h, w, c = img.shape
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)
    assert results['img'].shape == (h, w, c)
    assert results['img_shape'] == (h, w, c)
    assert results['ori_shape'] == (h, w, c)

    # test out_channels = 2
    transform = dict(type='RGB2Gray', out_channels=2)
    transform = build_from_cfg(transform, PIPELINES)

    assert str(transform) == f'RGB2Gray(' \
                             f'out_channels={2}, ' \
                             f'weights={(0.299, 0.587, 0.114)})'

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    h, w, c = img.shape
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)
    assert results['img'].shape == (h, w, 2)
    assert results['img_shape'] == (h, w, 2)
    assert results['ori_shape'] == (h, w, c)
def test_adjust_gamma():
    # test assertion if gamma <= 0
    with pytest.raises(AssertionError):
        transform = dict(type='AdjustGamma', gamma=0)
        build_from_cfg(transform, PIPELINES)

    # test assertion if gamma is list
    with pytest.raises(AssertionError):
        transform = dict(type='AdjustGamma', gamma=[1.2])
        build_from_cfg(transform, PIPELINES)

    # test with gamma = 1.2
    transform = dict(type='AdjustGamma', gamma=1.2)
    transform = build_from_cfg(transform, PIPELINES)
    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)

    inv_gamma = 1.0 / 1.2
    table = np.array([((i / 255.0)**inv_gamma) * 255
                      for i in np.arange(0, 256)]).astype('uint8')
    converted_img = mmcv.lut_transform(
        np.array(original_img, dtype=np.uint8), table)
    assert np.allclose(results['img'], converted_img)
    assert str(transform) == f'AdjustGamma(gamma={1.2})'
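
# A one-value sketch (illustrative helper, not part of mmseg) of the gamma
# LUT rebuilt above: each intensity v maps to ((v / 255) ** (1 / gamma)) *
# 255, so gamma > 1 brightens mid-gray values.
def _gamma_lut_value(v, gamma=1.2):
    return int(((v / 255.0)**(1.0 / gamma)) * 255)


# _gamma_lut_value(128) == 143
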
def test_rerange():
    # test assertion if min_value or max_value is illegal
    with pytest.raises(AssertionError):
        transform = dict(type='Rerange', min_value=[0], max_value=[255])
        build_from_cfg(transform, PIPELINES)

    # test assertion if min_value >= max_value
    with pytest.raises(AssertionError):
        transform = dict(type='Rerange', min_value=1, max_value=1)
        build_from_cfg(transform, PIPELINES)

    # test assertion if img_min_value == img_max_value
    with pytest.raises(AssertionError):
        transform = dict(type='Rerange', min_value=0, max_value=1)
        transform = build_from_cfg(transform, PIPELINES)
        results = dict()
        results['img'] = np.array([[1, 1], [1, 1]])
        transform(results)

    img_rerange_cfg = dict()
    transform = dict(type='Rerange', **img_rerange_cfg)
    transform = build_from_cfg(transform, PIPELINES)
    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)

    min_value = np.min(original_img)
    max_value = np.max(original_img)
    converted_img = (original_img - min_value) / (max_value - min_value) * 255

    assert np.allclose(results['img'], converted_img)
    assert str(transform) == f'Rerange(min_value={0}, max_value={255})'
def test_CLAHE():
    # test assertion if clip_limit is None
    with pytest.raises(AssertionError):
        transform = dict(type='CLAHE', clip_limit=None)
        build_from_cfg(transform, PIPELINES)

    # test assertion if tile_grid_size elements are not int
    with pytest.raises(AssertionError):
        transform = dict(type='CLAHE', tile_grid_size=(8.0, 8.0))
        build_from_cfg(transform, PIPELINES)

    # test assertion if tile_grid_size is not a 2-tuple
    with pytest.raises(AssertionError):
        transform = dict(type='CLAHE', tile_grid_size=(9, 9, 9))
        build_from_cfg(transform, PIPELINES)

    transform = dict(type='CLAHE', clip_limit=2)
    transform = build_from_cfg(transform, PIPELINES)
    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    original_img = copy.deepcopy(img)
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    results = transform(results)

    converted_img = np.empty(original_img.shape)
    for i in range(original_img.shape[2]):
        converted_img[:, :, i] = mmcv.clahe(
            np.array(original_img[:, :, i], dtype=np.uint8), 2, (8, 8))

    assert np.allclose(results['img'], converted_img)
    assert str(transform) == f'CLAHE(clip_limit={2}, tile_grid_size={(8, 8)})'
def test_seg_rescale():
    results = dict()
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    h, w = seg.shape

    transform = dict(type='SegRescale', scale_factor=1. / 2)
    rescale_module = build_from_cfg(transform, PIPELINES)
    rescale_results = rescale_module(results.copy())
    assert rescale_results['gt_semantic_seg'].shape == (h // 2, w // 2)

    transform = dict(type='SegRescale', scale_factor=1)
    rescale_module = build_from_cfg(transform, PIPELINES)
    rescale_results = rescale_module(results.copy())
    assert rescale_results['gt_semantic_seg'].shape == (h, w)
def test_cutout():
    # test prob
    with pytest.raises(AssertionError):
        transform = dict(type='RandomCutOut', prob=1.5, n_holes=1)
        build_from_cfg(transform, PIPELINES)
    # test n_holes
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut', prob=0.5, n_holes=(5, 3), cutout_shape=(8, 8))
        build_from_cfg(transform, PIPELINES)
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=(3, 4, 5),
            cutout_shape=(8, 8))
        build_from_cfg(transform, PIPELINES)
    # test cutout_shape and cutout_ratio
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut', prob=0.5, n_holes=1, cutout_shape=8)
        build_from_cfg(transform, PIPELINES)
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut', prob=0.5, n_holes=1, cutout_ratio=0.2)
        build_from_cfg(transform, PIPELINES)
    # exactly one of cutout_shape and cutout_ratio should be given
    with pytest.raises(AssertionError):
        transform = dict(type='RandomCutOut', prob=0.5, n_holes=1)
        build_from_cfg(transform, PIPELINES)
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=1,
            cutout_shape=(2, 2),
            cutout_ratio=(0.4, 0.4))
        build_from_cfg(transform, PIPELINES)
    # test seg_fill_in
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=1,
            cutout_shape=(8, 8),
            seg_fill_in='a')
        build_from_cfg(transform, PIPELINES)
    with pytest.raises(AssertionError):
        transform = dict(
            type='RandomCutOut',
            prob=0.5,
            n_holes=1,
            cutout_shape=(8, 8),
            seg_fill_in=256)
        build_from_cfg(transform, PIPELINES)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))

    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    results['pad_shape'] = img.shape
    results['img_fields'] = ['img']

    transform = dict(
        type='RandomCutOut', prob=1, n_holes=1, cutout_shape=(10, 10))
    cutout_module = build_from_cfg(transform, PIPELINES)
    assert 'cutout_shape' in repr(cutout_module)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() < img.sum()

    transform = dict(
        type='RandomCutOut', prob=1, n_holes=1, cutout_ratio=(0.8, 0.8))
    cutout_module = build_from_cfg(transform, PIPELINES)
    assert 'cutout_ratio' in repr(cutout_module)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() < img.sum()

    transform = dict(
        type='RandomCutOut', prob=0, n_holes=1, cutout_ratio=(0.8, 0.8))
    cutout_module = build_from_cfg(transform, PIPELINES)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() == img.sum()
    assert cutout_result['gt_semantic_seg'].sum() == seg.sum()

    transform = dict(
        type='RandomCutOut',
        prob=1,
        n_holes=(2, 4),
        cutout_shape=[(10, 10), (15, 15)],
        fill_in=(255, 255, 255),
        seg_fill_in=None)
    cutout_module = build_from_cfg(transform, PIPELINES)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() > img.sum()
    assert cutout_result['gt_semantic_seg'].sum() == seg.sum()

    transform = dict(
        type='RandomCutOut',
        prob=1,
        n_holes=1,
        cutout_ratio=(0.8, 0.8),
        fill_in=(255, 255, 255),
        seg_fill_in=255)
    cutout_module = build_from_cfg(transform, PIPELINES)
    cutout_result = cutout_module(copy.deepcopy(results))
    assert cutout_result['img'].sum() > img.sum()
    assert cutout_result['gt_semantic_seg'].sum() > seg.sum()
def test_mosaic():
    # test prob
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1.5)
        build_from_cfg(transform, PIPELINES)
    # test assertion for invalid img_scale
    with pytest.raises(AssertionError):
        transform = dict(type='RandomMosaic', prob=1, img_scale=640)
        build_from_cfg(transform, PIPELINES)

    results = dict()
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    seg = np.array(
        Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))

    results['img'] = img
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    assert 'Mosaic' in repr(mosaic_module)

    # test assertion for invalid mix_results
    with pytest.raises(AssertionError):
        mosaic_module(results)

    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)

    results = dict()
    results['img'] = img[:, :, 0]
    results['gt_semantic_seg'] = seg
    results['seg_fields'] = ['gt_semantic_seg']

    transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    results['mix_results'] = [copy.deepcopy(results)] * 3
    results = mosaic_module(results)
    assert results['img'].shape[:2] == img.shape[:2]

    transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
    mosaic_module = build_from_cfg(transform, PIPELINES)
    results = mosaic_module(results)
    assert results['img'].shape[:2] == (20, 24)
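
# A shape sketch (illustrative, inferred from the asserts above):
# RandomMosaic stitches the base image and the three mix_results onto a
# canvas twice the size of img_scale, so img_scale=(10, 12) yields (20, 24).
def _mosaic_canvas(img_scale):
    return (2 * img_scale[0], 2 * img_scale[1])


# _mosaic_canvas((10, 12)) == (20, 24)
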
@ -1,151 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv
import pytest
from mmcv.utils import build_from_cfg

from mmseg.datasets.builder import PIPELINES


def test_multi_scale_flip_aug():
    # test assertion if img_scale=None, img_ratios=1 (not float).
    with pytest.raises(AssertionError):
        tta_transform = dict(
            type='MultiScaleFlipAug',
            img_scale=None,
            img_ratios=1,
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        build_from_cfg(tta_transform, PIPELINES)

    # test assertion if img_scale=None, img_ratios=None.
    with pytest.raises(AssertionError):
        tta_transform = dict(
            type='MultiScaleFlipAug',
            img_scale=None,
            img_ratios=None,
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        build_from_cfg(tta_transform, PIPELINES)

    # test assertion if img_scale=(512, 512), img_ratios=1 (not float).
    with pytest.raises(AssertionError):
        tta_transform = dict(
            type='MultiScaleFlipAug',
            img_scale=(512, 512),
            img_ratios=1,
            transforms=[dict(type='Resize', keep_ratio=False)],
        )
        build_from_cfg(tta_transform, PIPELINES)

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        img_ratios=[0.5, 1.0, 2.0],
        flip=False,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)

    results = dict()
    # (288, 512, 3)
    img = mmcv.imread(
        osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
    results['img'] = img
    results['img_shape'] = img.shape
    results['ori_shape'] = img.shape
    # Set initial values for default meta_keys
    results['pad_shape'] = img.shape
    results['scale_factor'] = 1.0

    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)]
    assert tta_results['flip'] == [False, False, False]

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        img_ratios=[0.5, 1.0, 2.0],
        flip=True,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
                                    (512, 512), (1024, 1024), (1024, 1024)]
    assert tta_results['flip'] == [False, True, False, True, False, True]
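
    # A minimal sketch (illustrative, inferred from the expected values
    # above) of the expansion: each img_ratio scales the base img_scale,
    # and flip=True pairs every scale with flip False and True.
    _base, _ratios = (512, 512), [0.5, 1.0, 2.0]
    _scales = [(int(_base[0] * r), int(_base[1] * r)) for r in _ratios]
    assert _scales == [(256, 256), (512, 512), (1024, 1024)]
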
    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        img_ratios=1.0,
        flip=False,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(512, 512)]
    assert tta_results['flip'] == [False]

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=(512, 512),
        img_ratios=1.0,
        flip=True,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(512, 512), (512, 512)]
    assert tta_results['flip'] == [False, True]

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=None,
        img_ratios=[0.5, 1.0, 2.0],
        flip=False,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)]
    assert tta_results['flip'] == [False, False, False]

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=None,
        img_ratios=[0.5, 1.0, 2.0],
        flip=True,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288),
                                    (512, 288), (1024, 576), (1024, 576)]
    assert tta_results['flip'] == [False, True, False, True, False, True]

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=[(256, 256), (512, 512), (1024, 1024)],
        img_ratios=None,
        flip=False,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(256, 256), (512, 512), (1024, 1024)]
    assert tta_results['flip'] == [False, False, False]

    tta_transform = dict(
        type='MultiScaleFlipAug',
        img_scale=[(256, 256), (512, 512), (1024, 1024)],
        img_ratios=None,
        flip=True,
        transforms=[dict(type='Resize', keep_ratio=False)],
    )
    tta_module = build_from_cfg(tta_transform, PIPELINES)
    tta_results = tta_module(results.copy())
    assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
                                    (512, 512), (1024, 1024), (1024, 1024)]
    assert tta_results['flip'] == [False, True, False, True, False, True]
@ -1,204 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import tempfile
from unittest.mock import MagicMock, patch

import mmcv.runner
import pytest
import torch
import torch.nn as nn
from mmcv.runner import obj_from_dict
from torch.utils.data import DataLoader, Dataset

from mmseg.apis import single_gpu_test
from mmseg.core import DistEvalHook, EvalHook


class ExampleDataset(Dataset):

    def __getitem__(self, idx):
        results = dict(img=torch.tensor([1]), img_metas=dict())
        return results

    def __len__(self):
        return 1


class ExampleModel(nn.Module):

    def __init__(self):
        super(ExampleModel, self).__init__()
        self.test_cfg = None
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, img, img_metas, test_mode=False, **kwargs):
        return img

    def train_step(self, data_batch, optimizer):
        loss = self.forward(**data_batch)
        return dict(loss=loss)


def test_iter_eval_hook():
    with pytest.raises(TypeError):
        # EvalHook requires a DataLoader, not a list of DataLoaders
        test_dataset = ExampleDataset()
        data_loader = [
            DataLoader(
                test_dataset,
                batch_size=1,
                sampler=None,
                num_workers=0,
                shuffle=False)
        ]
        EvalHook(data_loader)

    test_dataset = ExampleDataset()
    test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    loader = DataLoader(test_dataset, batch_size=1)
    model = ExampleModel()
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
    optimizer = obj_from_dict(optim_cfg, torch.optim,
                              dict(params=model.parameters()))

    # test EvalHook
    with tempfile.TemporaryDirectory() as tmpdir:
        eval_hook = EvalHook(data_loader, by_epoch=False, efficient_test=True)
        runner = mmcv.runner.IterBasedRunner(
            model=model,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logging.getLogger())
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 1)
        test_dataset.evaluate.assert_called_with([torch.tensor([1])],
                                                 logger=runner.logger)


def test_epoch_eval_hook():
    with pytest.raises(TypeError):
        # EvalHook requires a DataLoader, not a list of DataLoaders
        test_dataset = ExampleDataset()
        data_loader = [
            DataLoader(
                test_dataset,
                batch_size=1,
                sampler=None,
                num_workers=0,
                shuffle=False)
        ]
        EvalHook(data_loader, by_epoch=True)

    test_dataset = ExampleDataset()
    test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    loader = DataLoader(test_dataset, batch_size=1)
    model = ExampleModel()
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
    optimizer = obj_from_dict(optim_cfg, torch.optim,
                              dict(params=model.parameters()))

    # test EvalHook with interval
    with tempfile.TemporaryDirectory() as tmpdir:
        eval_hook = EvalHook(data_loader, by_epoch=True, interval=2)
        runner = mmcv.runner.EpochBasedRunner(
            model=model,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logging.getLogger())
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 2)
        test_dataset.evaluate.assert_called_once_with([torch.tensor([1])],
                                                      logger=runner.logger)


def multi_gpu_test(model,
                   data_loader,
                   tmpdir=None,
                   gpu_collect=False,
                   pre_eval=False):
    # Pre eval is set by default when training.
    results = single_gpu_test(model, data_loader, pre_eval=True)
    return results


@patch('mmseg.apis.multi_gpu_test', multi_gpu_test)
def test_dist_eval_hook():
    with pytest.raises(TypeError):
        # DistEvalHook requires a DataLoader, not a list of DataLoaders
        test_dataset = ExampleDataset()
        data_loader = [
            DataLoader(
                test_dataset,
                batch_size=1,
                sampler=None,
                num_workers=0,
                shuffle=False)
        ]
        DistEvalHook(data_loader)

    test_dataset = ExampleDataset()
    test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    loader = DataLoader(test_dataset, batch_size=1)
    model = ExampleModel()
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
    optimizer = obj_from_dict(optim_cfg, torch.optim,
                              dict(params=model.parameters()))

    # test DistEvalHook
    with tempfile.TemporaryDirectory() as tmpdir:
        eval_hook = DistEvalHook(
            data_loader, by_epoch=False, efficient_test=True)
        runner = mmcv.runner.IterBasedRunner(
            model=model,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logging.getLogger())
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 1)
        test_dataset.evaluate.assert_called_with([torch.tensor([1])],
                                                 logger=runner.logger)


@patch('mmseg.apis.multi_gpu_test', multi_gpu_test)
def test_dist_eval_hook_epoch():
    with pytest.raises(TypeError):
        # DistEvalHook requires a DataLoader, not a list of DataLoaders
        test_dataset = ExampleDataset()
        data_loader = [
            DataLoader(
                test_dataset,
                batch_size=1,
                sampler=None,
                num_workers=0,
                shuffle=False)
        ]
        DistEvalHook(data_loader)

    test_dataset = ExampleDataset()
    test_dataset.pre_eval = MagicMock(return_value=[torch.tensor([1])])
    test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
    loader = DataLoader(test_dataset, batch_size=1)
    model = ExampleModel()
    data_loader = DataLoader(
        test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
    optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
    optimizer = obj_from_dict(optim_cfg, torch.optim,
                              dict(params=model.parameters()))

    # test DistEvalHook
    with tempfile.TemporaryDirectory() as tmpdir:
        eval_hook = DistEvalHook(data_loader, by_epoch=True, interval=2)
        runner = mmcv.runner.EpochBasedRunner(
            model=model,
            optimizer=optimizer,
            work_dir=tmpdir,
            logger=logging.getLogger())
        runner.register_hook(eval_hook)
        runner.run([loader], [('train', 1)], 2)
        test_dataset.evaluate.assert_called_with([torch.tensor([1])],
                                                 logger=runner.logger)
@ -1,30 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp

import mmcv

from mmseg.apis import inference_segmentor, init_segmentor


def test_test_time_augmentation_on_cpu():
    config_file = 'configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py'
    config = mmcv.Config.fromfile(config_file)

    # Remove pretrain model download for testing
    config.model.pretrained = None
    # Replace SyncBN with BN to inference on CPU
    norm_cfg = dict(type='BN', requires_grad=True)
    config.model.backbone.norm_cfg = norm_cfg
    config.model.decode_head.norm_cfg = norm_cfg
    config.model.auxiliary_head.norm_cfg = norm_cfg

    # Enable test time augmentation
    config.data.test.pipeline[1].flip = True

    checkpoint_file = None
    model = init_segmentor(config, checkpoint_file, device='cpu')

    img = mmcv.imread(
        osp.join(osp.dirname(__file__), 'data/color.jpg'), 'color')
    result = inference_segmentor(model, img)
    assert result[0].shape == (288, 512)
@ -1,351 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np

from mmseg.core.evaluation import (eval_metrics, mean_dice, mean_fscore,
                                   mean_iou)
from mmseg.core.evaluation.metrics import f_score


def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
    """Compute the confusion matrix for a single image.

    Args:
        pred_label (np.ndarray): 2D prediction map.
        label (np.ndarray): 2D ground truth label map.
        num_classes (int): Number of categories.
        ignore_index (int): Index to be ignored in evaluation.
    """

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    n = num_classes
    inds = n * label + pred_label

    mat = np.bincount(inds, minlength=n**2).reshape(n, n)

    return mat
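
# A toy sanity check (illustrative only) of get_confusion_matrix: rows are
# ground-truth classes, columns are predicted classes.
_toy_pred = np.array([[0, 1], [1, 1]])
_toy_gt = np.array([[0, 1], [0, 1]])
assert np.array_equal(
    get_confusion_matrix(_toy_pred, _toy_gt, num_classes=2, ignore_index=255),
    np.array([[1, 1], [0, 2]]))
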
# This func is deprecated since it's not memory efficient
def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    iou = np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))

    return all_acc, acc, iou


# This func is deprecated since it's not memory efficient
def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    acc = np.diag(total_mat) / total_mat.sum(axis=1)
    dice = 2 * np.diag(total_mat) / (
        total_mat.sum(axis=1) + total_mat.sum(axis=0))

    return all_acc, acc, dice


# This func is deprecated since it's not memory efficient
def legacy_mean_fscore(results,
                       gt_seg_maps,
                       num_classes,
                       ignore_index,
                       beta=1):
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_mat = np.zeros((num_classes, num_classes), dtype=float)
    for i in range(num_imgs):
        mat = get_confusion_matrix(
            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
        total_mat += mat
    all_acc = np.diag(total_mat).sum() / total_mat.sum()
    recall = np.diag(total_mat) / total_mat.sum(axis=1)
    precision = np.diag(total_mat) / total_mat.sum(axis=0)
    fv = np.vectorize(f_score)
    fscore = fv(precision, recall, beta=beta)

    return all_acc, recall, precision, fscore
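
# f_score above is the standard F-beta measure; a reference sketch
# (illustrative, mirroring what np.vectorize applies elementwise):
def _f_beta(precision, recall, beta=1):
    return (1 + beta**2) * precision * recall / (beta**2 * precision + recall)


# _f_beta(0.5, 0.5) == 0.5; beta > 1 weights recall more heavily.
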
def test_metrics():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)

    # Test the availability of arg: ignore_index.
    label[:, 2, 5:10] = ignore_index

    # Test the correctness of the implementation of mIoU calculation.
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mIoU')
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)
    # Test the correctness of the implementation of mDice calculation.
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mDice')
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(dice, dice_l)
    # Test the correctness of the implementation of mFscore calculation.
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index, metrics='mFscore')
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
        results, label, num_classes, ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(recall, recall_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(fscore, fscore_l)
    # Test the correctness of the implementation of joint calculation.
    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index,
        metrics=['mIoU', 'mDice', 'mFscore'])
    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
            'Dice'], ret_metrics['Precision'], ret_metrics[
                'Recall'], ret_metrics['Fscore']
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)
    assert np.allclose(dice, dice_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(recall, recall_l)
    assert np.allclose(fscore, fscore_l)

    # Test the correctness of calculation when arg: num_classes is larger
    # than the maximum value of input maps.
    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mIoU',
        nan_to_num=-1)
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    assert acc[-1] == -1
    assert iou[-1] == -1

    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mDice',
        nan_to_num=-1)
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    assert acc[-1] == -1
    assert dice[-1] == -1

    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics='mFscore',
        nan_to_num=-1)
    all_acc, precision, recall, fscore = ret_metrics['aAcc'], ret_metrics[
        'Precision'], ret_metrics['Recall'], ret_metrics['Fscore']
    assert precision[-1] == -1
    assert recall[-1] == -1
    assert fscore[-1] == -1

    ret_metrics = eval_metrics(
        results,
        label,
        num_classes,
        ignore_index=255,
        metrics=['mDice', 'mIoU', 'mFscore'],
        nan_to_num=-1)
    all_acc, acc, iou, dice, precision, recall, fscore = ret_metrics[
        'aAcc'], ret_metrics['Acc'], ret_metrics['IoU'], ret_metrics[
            'Dice'], ret_metrics['Precision'], ret_metrics[
                'Recall'], ret_metrics['Fscore']
    assert acc[-1] == -1
    assert dice[-1] == -1
    assert iou[-1] == -1
    assert precision[-1] == -1
    assert recall[-1] == -1
    assert fscore[-1] == -1

    # Test the bug which is caused by torch.histc.
    # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
    # When the arg:bins is set to be same as arg:max,
    # some channels of mIoU may be nan.
    results = np.array([np.repeat(31, 59)])
    label = np.array([np.arange(59)])
    num_classes = 59
    ret_metrics = eval_metrics(
        results, label, num_classes, ignore_index=255, metrics='mIoU')
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    assert not np.any(np.isnan(iou))
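
# The nan guarded against above comes from a 0/0 division for a class that
# appears in neither prediction nor label; a minimal numpy sketch of that
# failure mode (illustrative only, using the IoU formula from
# legacy_mean_iou):
def _iou_of_empty_class():
    diag = row_sum = col_sum = np.zeros(1)
    return diag / (row_sum + col_sum - diag)  # -> [nan]
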
def test_mean_iou():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    ret_metrics = mean_iou(results, label, num_classes, ignore_index)
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
                                              ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(iou, iou_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = mean_iou(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'IoU']
    assert acc[-1] == -1
    assert iou[-1] == -1
def test_mean_dice():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    ret_metrics = mean_dice(results, label, num_classes, ignore_index)
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
                                                ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(acc, acc_l)
    assert np.allclose(dice, dice_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = mean_dice(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    all_acc, acc, dice = ret_metrics['aAcc'], ret_metrics['Acc'], ret_metrics[
        'Dice']
    assert acc[-1] == -1
    assert dice[-1] == -1
def test_mean_fscore():
    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    label = np.random.randint(0, num_classes, size=pred_size)
    label[:, 2, 5:10] = ignore_index
    ret_metrics = mean_fscore(results, label, num_classes, ignore_index)
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
        results, label, num_classes, ignore_index)
    assert all_acc == all_acc_l
    assert np.allclose(recall, recall_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(fscore, fscore_l)

    ret_metrics = mean_fscore(
        results, label, num_classes, ignore_index, beta=2)
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    all_acc_l, recall_l, precision_l, fscore_l = legacy_mean_fscore(
        results, label, num_classes, ignore_index, beta=2)
    assert all_acc == all_acc_l
    assert np.allclose(recall, recall_l)
    assert np.allclose(precision, precision_l)
    assert np.allclose(fscore, fscore_l)

    results = np.random.randint(0, 5, size=pred_size)
    label = np.random.randint(0, 4, size=pred_size)
    ret_metrics = mean_fscore(
        results, label, num_classes, ignore_index=255, nan_to_num=-1)
    all_acc, recall, precision, fscore = ret_metrics['aAcc'], ret_metrics[
        'Recall'], ret_metrics['Precision'], ret_metrics['Fscore']
    assert recall[-1] == -1
    assert precision[-1] == -1
    assert fscore[-1] == -1
def test_filename_inputs():
    import tempfile

    import cv2

    def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
        filenames = []
        SUFFIX = '.png' if is_image else '.npy'
        for idx, arr in enumerate(input_arrays):
            filename = '{}/{}-{}{}'.format(dir, title, idx, SUFFIX)
            if is_image:
                cv2.imwrite(filename, arr)
            else:
                np.save(filename, arr)
            filenames.append(filename)
        return filenames

    pred_size = (10, 30, 30)
    num_classes = 19
    ignore_index = 255
    results = np.random.randint(0, num_classes, size=pred_size)
    labels = np.random.randint(0, num_classes, size=pred_size)
    labels[:, 2, 5:10] = ignore_index

    with tempfile.TemporaryDirectory() as temp_dir:

        result_files = save_arr(results, 'pred', False, temp_dir)
        label_files = save_arr(labels, 'label', True, temp_dir)

        ret_metrics = eval_metrics(
            result_files,
            label_files,
            num_classes,
            ignore_index,
            metrics='mIoU')
        all_acc, acc, iou = ret_metrics['aAcc'], ret_metrics[
            'Acc'], ret_metrics['IoU']
        all_acc_l, acc_l, iou_l = legacy_mean_iou(results, labels, num_classes,
                                                  ignore_index)
        assert all_acc == all_acc_l
        assert np.allclose(acc, acc_l)
        assert np.allclose(iou, iou_l)
@ -1,19 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads import DAHead
from .utils import to_cuda


def test_da_head():

    inputs = [torch.randn(1, 16, 23, 23)]
    head = DAHead(in_channels=16, channels=8, num_classes=19, pam_channels=8)
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert isinstance(outputs, tuple) and len(outputs) == 3
    for output in outputs:
        assert output.shape == (1, head.num_classes, 23, 23)
    test_output = head.forward_test(inputs, None, None)
    assert test_output.shape == (1, head.num_classes, 23, 23)
@ -1,165 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

import pytest
import torch

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from .utils import to_cuda


@patch.multiple(BaseDecodeHead, __abstractmethods__=set())
def test_decode_head():

    with pytest.raises(AssertionError):
        # default input_transform doesn't accept multiple inputs
        BaseDecodeHead([32, 16], 16, num_classes=19)

    with pytest.raises(AssertionError):
        # default input_transform doesn't accept multiple inputs
        BaseDecodeHead(32, 16, num_classes=19, in_index=[-1, -2])

    with pytest.raises(AssertionError):
        # supported mode is resize_concat only
        BaseDecodeHead(32, 16, num_classes=19, input_transform='concat')

    with pytest.raises(AssertionError):
        # in_channels should be list|tuple
        BaseDecodeHead(32, 16, num_classes=19, input_transform='resize_concat')

    with pytest.raises(AssertionError):
        # in_index should be list|tuple
        BaseDecodeHead([32],
                       16,
                       in_index=-1,
                       num_classes=19,
                       input_transform='resize_concat')

    with pytest.raises(AssertionError):
        # len(in_index) should equal len(in_channels)
        BaseDecodeHead([32, 16],
                       16,
                       num_classes=19,
                       in_index=[-1],
                       input_transform='resize_concat')

    # test default dropout
    head = BaseDecodeHead(32, 16, num_classes=19)
    assert hasattr(head, 'dropout') and head.dropout.p == 0.1

    # test set dropout
    head = BaseDecodeHead(32, 16, num_classes=19, dropout_ratio=0.2)
    assert hasattr(head, 'dropout') and head.dropout.p == 0.2

    # test no input_transform
    inputs = [torch.randn(1, 32, 45, 45)]
    head = BaseDecodeHead(32, 16, num_classes=19)
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    assert head.in_channels == 32
    assert head.input_transform is None
    transformed_inputs = head._transform_inputs(inputs)
    assert transformed_inputs.shape == (1, 32, 45, 45)

    # test input_transform = resize_concat
    inputs = [torch.randn(1, 32, 45, 45), torch.randn(1, 16, 21, 21)]
    head = BaseDecodeHead([32, 16],
                          16,
                          num_classes=19,
                          in_index=[0, 1],
                          input_transform='resize_concat')
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    assert head.in_channels == 48
    assert head.input_transform == 'resize_concat'
    transformed_inputs = head._transform_inputs(inputs)
    assert transformed_inputs.shape == (1, 48, 45, 45)
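
    # A shape sketch (illustrative, inferred from the asserts above) of the
    # assumed resize_concat semantics: every selected input is resized to
    # the spatial size of the first one, then concatenated along channels.
    _shapes = [(1, 32, 45, 45), (1, 16, 21, 21)]
    _out_shape = (1, sum(s[1] for s in _shapes)) + _shapes[0][2:]
    assert _out_shape == (1, 48, 45, 45)
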
    # test multi-loss, loss_decode is dict
    with pytest.raises(TypeError):
        # loss_decode must be a dict or sequence of dict
        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])

    inputs = torch.randn(2, 19, 8, 8).float()
    target = torch.ones(2, 1, 64, 64).long()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
        head, target = to_cuda(head, target)
    loss = head.losses(seg_logit=inputs, seg_label=target)
    assert 'loss_ce' in loss

    # test multi-loss, loss_decode is list of dict
    inputs = torch.randn(2, 19, 8, 8).float()
    target = torch.ones(2, 1, 64, 64).long()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_1'),
            dict(type='CrossEntropyLoss', loss_name='loss_2')
        ])
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
        head, target = to_cuda(head, target)
    loss = head.losses(seg_logit=inputs, seg_label=target)
    assert 'loss_1' in loss
    assert 'loss_2' in loss

    # 'loss_decode' must be a dict or sequence of dict
    with pytest.raises(TypeError):
        BaseDecodeHead(3, 16, num_classes=19, loss_decode=['CrossEntropyLoss'])
    with pytest.raises(TypeError):
        BaseDecodeHead(3, 16, num_classes=19, loss_decode=0)

    # test multi-loss, loss_decode is tuple of dict
    inputs = torch.randn(2, 19, 8, 8).float()
    target = torch.ones(2, 1, 64, 64).long()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_1'),
                     dict(type='CrossEntropyLoss', loss_name='loss_2'),
                     dict(type='CrossEntropyLoss', loss_name='loss_3')))
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
        head, target = to_cuda(head, target)
    loss = head.losses(seg_logit=inputs, seg_label=target)
    assert 'loss_1' in loss
    assert 'loss_2' in loss
    assert 'loss_3' in loss

    # test multi-loss, loss_decode is list of dict, names of them are identical
    inputs = torch.randn(2, 19, 8, 8).float()
    target = torch.ones(2, 1, 64, 64).long()
    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce'),
                     dict(type='CrossEntropyLoss', loss_name='loss_ce'),
                     dict(type='CrossEntropyLoss', loss_name='loss_ce')))
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
        head, target = to_cuda(head, target)
    loss_3 = head.losses(seg_logit=inputs, seg_label=target)

    head = BaseDecodeHead(
        3,
        16,
        num_classes=19,
        loss_decode=(dict(type='CrossEntropyLoss', loss_name='loss_ce')))
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
        head, target = to_cuda(head, target)
    loss = head.losses(seg_logit=inputs, seg_label=target)
    assert 'loss_ce' in loss
    assert 'loss_ce' in loss_3
    # losses registered under the same name are accumulated
    assert loss_3['loss_ce'] == 3 * loss['loss_ce']
@ -1,47 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads import EncHead
from .utils import to_cuda


def test_enc_head():
    # with se_loss, w.o. lateral
    inputs = [torch.randn(1, 8, 21, 21)]
    head = EncHead(in_channels=[8], channels=4, num_classes=19, in_index=[-1])
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert isinstance(outputs, tuple) and len(outputs) == 2
    assert outputs[0].shape == (1, head.num_classes, 21, 21)
    assert outputs[1].shape == (1, head.num_classes)

    # w.o. se_loss, w.o. lateral
    inputs = [torch.randn(1, 8, 21, 21)]
    head = EncHead(
        in_channels=[8],
        channels=4,
        use_se_loss=False,
        num_classes=19,
        in_index=[-1])
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs.shape == (1, head.num_classes, 21, 21)

    # with se_loss, with lateral
    inputs = [torch.randn(1, 4, 45, 45), torch.randn(1, 8, 21, 21)]
    head = EncHead(
        in_channels=[4, 8],
        channels=4,
        add_lateral=True,
        num_classes=19,
        in_index=[-2, -1])
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert isinstance(outputs, tuple) and len(outputs) == 2
    assert outputs[0].shape == (1, head.num_classes, 21, 21)
    assert outputs[1].shape == (1, head.num_classes)
    test_output = head.forward_test(inputs, None, None)
    assert test_output.shape == (1, head.num_classes, 21, 21)
@ -1,195 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads.knet_head import (IterativeDecodeHead,
                                                 KernelUpdateHead)
from .utils import to_cuda

num_stages = 3
conv_kernel_size = 1

kernel_updator_cfg = dict(
    type='KernelUpdator',
    in_channels=16,
    feat_channels=16,
    out_channels=16,
    gate_norm_act=True,
    activate_out=True,
    act_cfg=dict(type='ReLU', inplace=True),
    norm_cfg=dict(type='LN'))


def test_knet_head():
    # test init function of kernel update head
    kernel_update_head = KernelUpdateHead(
        num_classes=150,
        num_ffn_fcs=2,
        num_heads=8,
        num_mask_fcs=1,
        feedforward_channels=128,
        in_channels=32,
        out_channels=32,
        dropout=0.0,
        conv_kernel_size=conv_kernel_size,
        ffn_act_cfg=dict(type='ReLU', inplace=True),
        with_ffn=True,
        feat_transform_cfg=dict(conv_cfg=dict(type='Conv2d'), act_cfg=None),
        kernel_init=True,
        kernel_updator_cfg=kernel_updator_cfg)
    kernel_update_head.init_weights()

    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()
    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)

    # test that only the prediction of the last stage
    # is returned during testing
    with torch.no_grad():
        head.eval()
        outputs = head(inputs)
        assert outputs.shape == (1, head.num_classes, 27, 16)

    # test K-Net without `feat_transform_cfg`
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=None,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()

    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 27, 16)

    # test K-Net with
    # self.mask_transform_stride == 2 and self.feat_gather_stride == 1
    head = IterativeDecodeHead(
        num_stages=num_stages,
        kernel_update_head=[
            dict(
                type='KernelUpdateHead',
                num_classes=150,
                num_ffn_fcs=2,
                num_heads=8,
                num_mask_fcs=1,
                feedforward_channels=128,
                in_channels=32,
                out_channels=32,
                dropout=0.0,
                conv_kernel_size=conv_kernel_size,
                ffn_act_cfg=dict(type='ReLU', inplace=True),
                with_ffn=True,
                feat_transform_cfg=dict(
                    conv_cfg=dict(type='Conv2d'), act_cfg=None),
                kernel_init=False,
                mask_transform_stride=2,
                feat_gather_stride=1,
                kernel_updator_cfg=kernel_updator_cfg)
            for _ in range(num_stages)
        ],
        kernel_generate_head=dict(
            type='FCNHead',
            in_channels=128,
            in_index=3,
            channels=32,
            num_convs=2,
            concat_input=True,
            dropout_ratio=0.1,
            num_classes=150,
            align_corners=False))
    head.init_weights()

    inputs = [
        torch.randn(1, 16, 27, 32),
        torch.randn(1, 32, 27, 16),
        torch.randn(1, 64, 27, 16),
        torch.randn(1, 128, 27, 16)
    ]

    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert outputs[-1].shape == (1, head.num_classes, 26, 16)

    # test loss function in K-Net
    fake_label = torch.ones_like(
        outputs[-1][:, 0:1, :, :], dtype=torch.int16).long()
    loss = head.losses(seg_logit=outputs, seg_label=fake_label)
    assert loss['loss_ce.s0'] != torch.zeros_like(loss['loss_ce.s0'])
    assert loss['loss_ce.s1'] != torch.zeros_like(loss['loss_ce.s1'])
    assert loss['loss_ce.s2'] != torch.zeros_like(loss['loss_ce.s2'])
    assert loss['loss_ce.s3'] != torch.zeros_like(loss['loss_ce.s3'])
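
# A hedged note on the 'loss_ce.sK' keys asserted above: IterativeDecodeHead
# is deeply supervised, producing one prediction from kernel_generate_head
# plus one per update stage, so num_stages=3 yields keys s0..s3. A small
# sketch of aggregating such stage-wise losses, assuming a dict shaped like
# the one `head.losses` returns:
def aggregate_stage_losses(losses, num_stages):
    # one CE term per stage, including the kernel-generating stage
    return sum(losses[f'loss_ce.s{i}'] for i in range(num_stages + 1))
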
@ -1,61 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.utils import ConfigDict

from mmseg.models.decode_heads import FCNHead, PointHead
from .utils import to_cuda


def test_point_head():
    inputs = [torch.randn(1, 32, 45, 45)]
    point_head = PointHead(
        in_channels=[32], in_index=[0], channels=16, num_classes=19)
    assert len(point_head.fcs) == 3
    fcn_head = FCNHead(in_channels=32, channels=16, num_classes=19)
    if torch.cuda.is_available():
        head, inputs = to_cuda(point_head, inputs)
        head, inputs = to_cuda(fcn_head, inputs)
    prev_output = fcn_head(inputs)
    test_cfg = ConfigDict(
        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
    output = point_head.forward_test(inputs, prev_output, None, test_cfg)
    assert output.shape == (1, point_head.num_classes, 180, 180)

    # test multiple losses case
    inputs = [torch.randn(1, 32, 45, 45)]
    point_head_multiple_losses = PointHead(
        in_channels=[32],
        in_index=[0],
        channels=16,
        num_classes=19,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_1'),
            dict(type='CrossEntropyLoss', loss_name='loss_2')
        ])
    assert len(point_head_multiple_losses.fcs) == 3
    fcn_head_multiple_losses = FCNHead(
        in_channels=32,
        channels=16,
        num_classes=19,
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_name='loss_1'),
            dict(type='CrossEntropyLoss', loss_name='loss_2')
        ])
    if torch.cuda.is_available():
        head, inputs = to_cuda(point_head_multiple_losses, inputs)
        head, inputs = to_cuda(fcn_head_multiple_losses, inputs)
    prev_output = fcn_head_multiple_losses(inputs)
    test_cfg = ConfigDict(
        subdivision_steps=2, subdivision_num_points=8196, scale_factor=2)
    output = point_head_multiple_losses.forward_test(inputs, prev_output, None,
                                                     test_cfg)
    assert output.shape == (1, point_head.num_classes, 180, 180)

    fake_label = torch.ones([1, 180, 180], dtype=torch.long)

    if torch.cuda.is_available():
        fake_label = fake_label.cuda()
    loss = point_head_multiple_losses.losses(output, fake_label)
    assert 'pointloss_1' in loss
    assert 'pointloss_2' in loss
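
# A hedged note on the 180x180 assertions above: with scale_factor=2 and
# subdivision_steps=2, PointRend-style refinement doubles the coarse
# 45x45 map at each step, re-predicting only the most uncertain
# subdivision_num_points locations per step:
coarse, scale_factor, subdivision_steps = 45, 2, 2
assert coarse * scale_factor ** subdivision_steps == 180
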
@ -1,31 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmseg.models.decode_heads import STDCHead
from .utils import to_cuda


def test_stdc_head():
    inputs = [torch.randn(1, 32, 21, 21)]
    head = STDCHead(
        in_channels=32,
        channels=8,
        num_convs=1,
        num_classes=2,
        in_index=-1,
        loss_decode=[
            dict(
                type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
            dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
        ])
    if torch.cuda.is_available():
        head, inputs = to_cuda(head, inputs)
    outputs = head(inputs)
    assert isinstance(outputs, torch.Tensor) and len(outputs) == 1
    assert outputs.shape == torch.Size([1, head.num_classes, 21, 21])

    fake_label = torch.ones_like(
        outputs[:, 0:1, :, :], dtype=torch.int16).long()
    loss = head.losses(seg_logit=outputs, seg_label=fake_label)
    assert loss['loss_ce'] != torch.zeros_like(loss['loss_ce'])
    assert loss['loss_dice'] != torch.zeros_like(loss['loss_dice'])
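
# A hedged sketch of the multi-loss pattern exercised above: passing a
# list of configs to `loss_decode` makes `head.losses` emit one entry per
# loss_name ('loss_ce' and 'loss_dice' here), which a training loop would
# typically just sum:
def total_loss(losses):
    # sum every scalar whose key marks it as a loss term
    return sum(v for k, v in losses.items() if 'loss' in k)
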
@ -1 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
@ -1,294 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.losses.cross_entropy_loss import _expand_onehot_labels


@pytest.mark.parametrize('use_sigmoid', [True, False])
@pytest.mark.parametrize('reduction', ('mean', 'sum', 'none'))
@pytest.mark.parametrize('avg_non_ignore', [True, False])
@pytest.mark.parametrize('bce_input_same_dim', [True, False])
def test_ce_loss(use_sigmoid, reduction, avg_non_ignore, bce_input_same_dim):
    from mmseg.models import build_loss

    # use_mask and use_sigmoid cannot be true at the same time
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='CrossEntropyLoss',
            use_mask=True,
            use_sigmoid=True,
            loss_weight=1.0)
        build_loss(loss_cfg)

    # test loss with a simple case for ce/bce
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        loss_weight=1.0,
        avg_non_ignore=avg_non_ignore,
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    if use_sigmoid:
        assert torch.allclose(
            loss_cls(fake_pred, fake_label), torch.tensor(100.))
    else:
        assert torch.allclose(
            loss_cls(fake_pred, fake_label), torch.tensor(200.))

    # test loss with a complicated case for ce/bce
    # when avg_non_ignore is False, `avg_factor` is not calculated
    fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
    fake_label = torch.ones(2, 8, 8).long()
    fake_label[:, 0, 0] = 255
    fake_weight = None
    # extra test for bce loss when pred.shape == label.shape
    if use_sigmoid and bce_input_same_dim:
        fake_pred = torch.randn(2, 10).float()
        fake_label = torch.rand(2, 10).float()
        fake_weight = torch.rand(2, 10)  # set weight in forward function
        fake_label[0, [1, 2, 5, 7]] = 255  # set ignore_index
        fake_label[1, [0, 5, 8, 9]] = 255
    loss_cls = build_loss(loss_cls_cfg)
    loss = loss_cls(
        fake_pred, fake_label, weight=fake_weight, ignore_index=255)
    if use_sigmoid:
        if fake_pred.dim() != fake_label.dim():
            fake_label, weight, valid_mask = _expand_onehot_labels(
                labels=fake_label,
                label_weights=None,
                target_shape=fake_pred.shape,
                ignore_index=255)
        else:
            # should mask out the ignored elements
            valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
            weight = valid_mask
        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            fake_pred,
            fake_label.float(),
            reduction='none',
            weight=fake_weight)
        if avg_non_ignore:
            avg_factor = valid_mask.sum().item()
            torch_loss = (torch_loss * weight).sum() / avg_factor
        else:
            torch_loss = (torch_loss * weight).mean()
    else:
        if avg_non_ignore:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred, fake_label, reduction='mean', ignore_index=255)
        else:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred, fake_label, reduction='sum',
                ignore_index=255) / fake_label.numel()
    assert torch.allclose(loss, torch_loss)

    if use_sigmoid:
        # test bce loss with a complicated case and a spatial weight map;
        # when avg_non_ignore is False, `avg_factor` is not calculated
        fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
        fake_label = torch.ones(2, 8, 8).long()
        fake_label[:, 0, 0] = 255
        fake_weight = torch.rand(2, 8, 8)

        loss_cls = build_loss(loss_cls_cfg)
        loss = loss_cls(
            fake_pred, fake_label, weight=fake_weight, ignore_index=255)
        fake_label, weight, valid_mask = _expand_onehot_labels(
            labels=fake_label,
            label_weights=None,
            target_shape=fake_pred.shape,
            ignore_index=255)
        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            fake_pred,
            fake_label.float(),
            reduction='none',
            weight=fake_weight.unsqueeze(1).expand(fake_pred.shape))
        if avg_non_ignore:
            avg_factor = valid_mask.sum().item()
            torch_loss = (torch_loss * weight).sum() / avg_factor
        else:
            torch_loss = (torch_loss * weight).mean()
        assert torch.allclose(loss, torch_loss)

    # test loss with class weights from file
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    import os
    import tempfile

    import mmcv
    import numpy as np
    tmp_file = tempfile.NamedTemporaryFile()

    mmcv.dump([0.8, 0.2], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=f'{tmp_file.name}.pkl',
        loss_weight=1.0,
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))

    np.save(f'{tmp_file.name}.npy', np.array([0.8, 0.2]))  # from npy file
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=False,
        class_weight=f'{tmp_file.name}.npy',
        loss_weight=1.0,
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(40.))
    tmp_file.close()
    os.remove(f'{tmp_file.name}.pkl')
    os.remove(f'{tmp_file.name}.npy')

    loss_cls_cfg = dict(
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
    loss_cls = build_loss(loss_cls_cfg)
    assert torch.allclose(loss_cls(fake_pred, fake_label), torch.tensor(200.))

    # test that `avg_non_ignore` without an ignore index does not affect
    # the ce/bce loss when reduction='sum'/'none'/'mean'
    loss_cls_cfg1 = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        reduction=reduction,
        loss_weight=1.0,
        avg_non_ignore=True)
    loss_cls1 = build_loss(loss_cls_cfg1)
    loss_cls_cfg2 = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        reduction=reduction,
        loss_weight=1.0,
        avg_non_ignore=False)
    loss_cls2 = build_loss(loss_cls_cfg2)
    assert torch.allclose(
        loss_cls1(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
        loss_cls2(fake_pred, fake_label, ignore_index=255) / fake_pred.numel(),
        atol=1e-4)

    # test ce/bce loss with ignore index and class weight
    # in 5-way classification
    if use_sigmoid:
        # test bce loss when pred.shape == or != label.shape
        if bce_input_same_dim:
            fake_pred = torch.randn(2, 10).float()
            fake_label = torch.rand(2, 10).float()
            class_weight = torch.rand(2, 10)
        else:
            fake_pred = torch.full(size=(2, 21, 8, 8), fill_value=0.5)
            fake_label = torch.ones(2, 8, 8).long()
            class_weight = torch.randn(2, 21, 8, 8)
            fake_label, weight, valid_mask = _expand_onehot_labels(
                labels=fake_label,
                label_weights=None,
                target_shape=fake_pred.shape,
                ignore_index=-100)
        torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            fake_pred,
            fake_label.float(),
            reduction='mean',
            pos_weight=class_weight)
    else:
        fake_pred = torch.randn(2, 5, 10).float()  # 5-way classification
        fake_label = torch.randint(0, 5, (2, 10)).long()
        class_weight = torch.rand(5)
        class_weight /= class_weight.sum()
        torch_loss = torch.nn.functional.cross_entropy(
            fake_pred, fake_label, reduction='sum',
            weight=class_weight) / fake_label.numel()
    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=use_sigmoid,
        reduction='mean',
        class_weight=class_weight,
        loss_weight=1.0,
        avg_non_ignore=avg_non_ignore)
    loss_cls = build_loss(loss_cls_cfg)

    # test cross entropy loss has name `loss_ce`
    assert loss_cls.loss_name == 'loss_ce'
    # test avg_non_ignore is in extra_repr
    assert loss_cls.extra_repr() == f'avg_non_ignore={avg_non_ignore}'

    loss = loss_cls(fake_pred, fake_label)
    assert torch.allclose(loss, torch_loss)

    fake_label[0, [1, 2, 5, 7]] = 10  # set ignore_index
    fake_label[1, [0, 5, 8, 9]] = 10
    loss = loss_cls(fake_pred, fake_label, ignore_index=10)
    if use_sigmoid:
        if avg_non_ignore:
            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                fake_pred[fake_label != 10],
                fake_label[fake_label != 10].float(),
                pos_weight=class_weight[fake_label != 10],
                reduction='mean')
        else:
            torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                fake_pred[fake_label != 10],
                fake_label[fake_label != 10].float(),
                pos_weight=class_weight[fake_label != 10],
                reduction='sum') / fake_label.numel()
    else:
        if avg_non_ignore:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred,
                fake_label,
                ignore_index=10,
                reduction='sum',
                weight=class_weight) / fake_label[fake_label != 10].numel()
        else:
            torch_loss = torch.nn.functional.cross_entropy(
                fake_pred,
                fake_label,
                ignore_index=10,
                reduction='sum',
                weight=class_weight) / fake_label.numel()
    assert torch.allclose(loss, torch_loss)


@pytest.mark.parametrize('avg_non_ignore', [True, False])
@pytest.mark.parametrize('with_weight', [True, False])
def test_binary_class_ce_loss(avg_non_ignore, with_weight):
    from mmseg.models import build_loss

    fake_pred = torch.rand(3, 1, 10, 10)
    fake_label = torch.randint(0, 2, (3, 10, 10))
    fake_weight = torch.rand(3, 10, 10)
    valid_mask = ((fake_label >= 0) & (fake_label != 255)).float()
    weight = valid_mask

    torch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
        fake_pred,
        fake_label.unsqueeze(1).float(),
        reduction='none',
        weight=fake_weight.unsqueeze(1).float() if with_weight else None)
    if avg_non_ignore:
        eps = torch.finfo(torch.float32).eps
        avg_factor = valid_mask.sum().item()
        torch_loss = (torch_loss * weight.unsqueeze(1)).sum() / (
            avg_factor + eps)
    else:
        torch_loss = (torch_loss * weight.unsqueeze(1)).mean()

    loss_cls_cfg = dict(
        type='CrossEntropyLoss',
        use_sigmoid=True,
        loss_weight=1.0,
        avg_non_ignore=avg_non_ignore,
        reduction='mean',
        loss_name='loss_ce')
    loss_cls = build_loss(loss_cls_cfg)
    loss = loss_cls(
        fake_pred,
        fake_label,
        weight=fake_weight if with_weight else None,
        ignore_index=255)
    assert torch.allclose(loss, torch_loss)
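
# A hedged sketch of the equivalence the parametrization above checks:
# with avg_non_ignore=True, mmseg's CE averages over non-ignored pixels
# only, matching plain PyTorch CE with reduction='mean' and an
# ignore_index. Shapes and values below are illustrative.
import torch
import torch.nn.functional as F

from mmseg.models import build_loss

pred = torch.randn(2, 5, 4, 4)                 # (N, C, H, W) logits
label = torch.randint(0, 5, (2, 4, 4)).long()
label[:, 0, 0] = 255                           # ignored pixels

loss_ce = build_loss(
    dict(type='CrossEntropyLoss', use_sigmoid=False, avg_non_ignore=True))
assert torch.allclose(
    loss_ce(pred, label, ignore_index=255),
    F.cross_entropy(pred, label, reduction='mean', ignore_index=255))
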
@ -1,78 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def test_dice_lose():
    from mmseg.models import build_loss

    # test dice loss with loss_type = 'multi_class'
    loss_cfg = dict(
        type='DiceLoss',
        reduction='none',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0,
        ignore_index=1,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    logits = torch.rand(8, 3, 4, 4)
    labels = (torch.rand(8, 4, 4) * 3).long()
    dice_loss(logits, labels)

    # test loss with class weights from file
    import os
    import tempfile

    import mmcv
    import numpy as np
    tmp_file = tempfile.NamedTemporaryFile()

    mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
    loss_cfg = dict(
        type='DiceLoss',
        reduction='none',
        class_weight=f'{tmp_file.name}.pkl',
        loss_weight=1.0,
        ignore_index=1,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    dice_loss(logits, labels, ignore_index=None)

    np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0]))  # from npy file
    loss_cfg = dict(
        type='DiceLoss',
        reduction='none',
        class_weight=f'{tmp_file.name}.npy',
        loss_weight=1.0,
        ignore_index=1,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    dice_loss(logits, labels, ignore_index=None)
    tmp_file.close()
    os.remove(f'{tmp_file.name}.pkl')
    os.remove(f'{tmp_file.name}.npy')

    # test dice loss with loss_type = 'binary'
    loss_cfg = dict(
        type='DiceLoss',
        smooth=2,
        exponent=3,
        reduction='sum',
        loss_weight=1.0,
        ignore_index=0,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    logits = torch.rand(8, 2, 4, 4)
    labels = (torch.rand(8, 4, 4) * 2).long()
    dice_loss(logits, labels)

    # test dice loss has name `loss_dice`
    loss_cfg = dict(
        type='DiceLoss',
        smooth=2,
        exponent=3,
        reduction='sum',
        loss_weight=1.0,
        ignore_index=0,
        loss_name='loss_dice')
    dice_loss = build_loss(loss_cfg)
    assert dice_loss.loss_name == 'loss_dice'
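
# A hedged sketch of the class_weight-from-file convention tested above:
# mmseg losses also accept a path to a .pkl/.npy file and load the
# per-class weights when the loss is built. The path is hypothetical.
import numpy as np

from mmseg.models import build_loss

np.save('/tmp/dice_class_weight.npy', np.array([1.0, 2.0, 3.0]))
dice_loss = build_loss(
    dict(
        type='DiceLoss',
        class_weight='/tmp/dice_class_weight.npy',  # loaded at build time
        loss_name='loss_dice'))
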
@ -1,216 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F

from mmseg.models import build_loss


# test focal loss with use_sigmoid=False
def test_use_sigmoid():
    # can't init with use_sigmoid=False
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', use_sigmoid=False)
        build_loss(loss_cfg)

    # can't forward with use_sigmoid=False
    with pytest.raises(NotImplementedError):
        loss_cfg = dict(type='FocalLoss', use_sigmoid=True)
        focal_loss = build_loss(loss_cfg)
        focal_loss.use_sigmoid = False
        fake_pred = torch.rand(3, 4, 5, 6)
        fake_target = torch.randint(0, 4, (3, 5, 6))
        focal_loss(fake_pred, fake_target)


# reduction type must be 'none', 'mean' or 'sum'
def test_wrong_reduction_type():
    # can't init with a wrong reduction
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', reduction='test')
        build_loss(loss_cfg)

    # can't forward with a wrong reduction override
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss')
        focal_loss = build_loss(loss_cfg)
        fake_pred = torch.rand(3, 4, 5, 6)
        fake_target = torch.randint(0, 4, (3, 5, 6))
        focal_loss(fake_pred, fake_target, reduction_override='test')


# test focal loss can handle input parameters with unacceptable types
def test_unacceptable_parameters():
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', gamma='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', alpha='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', class_weight='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', loss_weight='test')
        build_loss(loss_cfg)
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='FocalLoss', loss_name=123)
        build_loss(loss_cfg)


# test if focal loss can be correctly initialized
def test_init_focal_loss():
    loss_cfg = dict(
        type='FocalLoss',
        use_sigmoid=True,
        gamma=3.0,
        alpha=3.0,
        class_weight=[1, 2, 3, 4],
        reduction='sum')
    focal_loss = build_loss(loss_cfg)
    assert focal_loss.use_sigmoid is True
    assert focal_loss.gamma == 3.0
    assert focal_loss.alpha == 3.0
    assert focal_loss.reduction == 'sum'
    assert focal_loss.class_weight == [1, 2, 3, 4]
    assert focal_loss.loss_weight == 1.0
    assert focal_loss.loss_name == 'loss_focal'


# test reduction override
def test_reduction_override():
    loss_cfg = dict(type='FocalLoss', reduction='mean')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss = focal_loss(fake_pred, fake_target, reduction_override='none')
    assert loss.shape == fake_pred.shape


# test wrong pred and target shapes
def test_wrong_pred_and_target_shape():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 2, 2))
    fake_target = F.one_hot(fake_target, num_classes=4)
    fake_target = fake_target.permute(0, 3, 1, 2)
    with pytest.raises(AssertionError):
        focal_loss(fake_pred, fake_target)


# test forward with different shapes of target
def test_forward_with_different_shape_of_target():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)

    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss1 = focal_loss(fake_pred, fake_target)

    fake_target = F.one_hot(fake_target, num_classes=4)
    fake_target = fake_target.permute(0, 3, 1, 2)
    loss2 = focal_loss(fake_pred, fake_target)
    assert loss1 == loss2


# test forward with weight
def test_forward_with_weight():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    weight = torch.rand(3 * 5 * 6, 1)
    loss1 = focal_loss(fake_pred, fake_target, weight=weight)

    weight2 = weight.view(-1)
    loss2 = focal_loss(fake_pred, fake_target, weight=weight2)

    weight3 = weight.expand(3 * 5 * 6, 4)
    loss3 = focal_loss(fake_pred, fake_target, weight=weight3)
    assert loss1 == loss2 == loss3


# test none reduction type
def test_none_reduction_type():
    loss_cfg = dict(type='FocalLoss', reduction='none')
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss = focal_loss(fake_pred, fake_target)
    assert loss.shape == fake_pred.shape


# test the usage of class weight
def test_class_weight():
    loss_cfg_cw = dict(
        type='FocalLoss', reduction='none', class_weight=[1.0, 2.0, 3.0, 4.0])
    loss_cfg = dict(type='FocalLoss', reduction='none')
    focal_loss_cw = build_loss(loss_cfg_cw)
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss_cw = focal_loss_cw(fake_pred, fake_target)
    loss = focal_loss(fake_pred, fake_target)
    weight = torch.tensor([1, 2, 3, 4]).view(1, 4, 1, 1)
    assert (loss * weight == loss_cw).all()


# test ignore index
def test_ignore_index():
    loss_cfg = dict(type='FocalLoss', reduction='none')
    # ignore_index within C classes
    focal_loss = build_loss(loss_cfg)
    fake_pred = torch.rand(3, 5, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    dim1 = torch.randint(0, 3, (4, ))
    dim2 = torch.randint(0, 5, (4, ))
    dim3 = torch.randint(0, 6, (4, ))
    fake_target[dim1, dim2, dim3] = 4
    loss1 = focal_loss(fake_pred, fake_target, ignore_index=4)
    one_hot_target = F.one_hot(fake_target, num_classes=5)
    one_hot_target = one_hot_target.permute(0, 3, 1, 2)
    loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=4)
    assert (loss1 == loss2).all()
    assert (loss1[dim1, :, dim2, dim3] == 0).all()
    assert (loss2[dim1, :, dim2, dim3] == 0).all()

    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    loss1 = focal_loss(fake_pred, fake_target, ignore_index=2)
    one_hot_target = F.one_hot(fake_target, num_classes=4)
    one_hot_target = one_hot_target.permute(0, 3, 1, 2)
    loss2 = focal_loss(fake_pred, one_hot_target, ignore_index=2)
    ignore_mask = one_hot_target == 2
    assert (loss1 == loss2).all()
    assert torch.sum(loss1 * ignore_mask) == 0
    assert torch.sum(loss2 * ignore_mask) == 0

    # ignore index is not in prediction's classes
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    dim1 = torch.randint(0, 3, (4, ))
    dim2 = torch.randint(0, 5, (4, ))
    dim3 = torch.randint(0, 6, (4, ))
    fake_target[dim1, dim2, dim3] = 255
    loss1 = focal_loss(fake_pred, fake_target, ignore_index=255)
    assert (loss1[dim1, :, dim2, dim3] == 0).all()


# test list alpha
def test_alpha():
    loss_cfg = dict(type='FocalLoss')
    focal_loss = build_loss(loss_cfg)
    alpha_float = 0.4
    alpha = [0.4, 0.4, 0.4, 0.4]
    alpha2 = [0.1, 0.3, 0.2, 0.1]
    fake_pred = torch.rand(3, 4, 5, 6)
    fake_target = torch.randint(0, 4, (3, 5, 6))
    focal_loss.alpha = alpha_float
    loss1 = focal_loss(fake_pred, fake_target)
    focal_loss.alpha = alpha
    loss2 = focal_loss(fake_pred, fake_target)
    assert loss1 == loss2
    focal_loss.alpha = alpha2
    focal_loss(fake_pred, fake_target)
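
# A hedged, plain-PyTorch reference for the hard-vs-easy weighting these
# tests probe: in the sigmoid/one-hot formulation, focal loss is
# alpha-balanced BCE scaled by (1 - p_t)^gamma. This mirrors the idea,
# not mmseg's exact kernel.
import torch
import torch.nn.functional as F

def focal_reference(pred, one_hot_target, gamma=2.0, alpha=0.5):
    target = one_hot_target.float()
    p = torch.sigmoid(pred)
    pt = p * target + (1 - p) * (1 - target)         # prob of the truth
    at = alpha * target + (1 - alpha) * (1 - target)
    bce = F.binary_cross_entropy_with_logits(pred, target, reduction='none')
    return at * (1 - pt) ** gamma * bce              # easy pixels decay fast
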
@ -1,118 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch


def test_lovasz_loss():
    from mmseg.models import build_loss

    # loss_type should be 'binary' or 'multi_class'
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LovaszLoss',
            loss_type='Binary',
            reduction='none',
            loss_weight=1.0,
            loss_name='loss_lovasz')
        build_loss(loss_cfg)

    # reduction should be 'none' when per_image is False
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LovaszLoss',
            loss_type='multi_class',
            loss_name='loss_lovasz')
        build_loss(loss_cfg)

    # test lovasz loss with loss_type = 'multi_class' and per_image = False
    loss_cfg = dict(
        type='LovaszLoss',
        reduction='none',
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'multi_class' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels, ignore_index=None)

    # test loss with class weights from file
    import os
    import tempfile

    import mmcv
    import numpy as np
    tmp_file = tempfile.NamedTemporaryFile()

    mmcv.dump([1.0, 2.0, 3.0], f'{tmp_file.name}.pkl', 'pkl')  # from pkl file
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=f'{tmp_file.name}.pkl',
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    lovasz_loss(logits, labels, ignore_index=None)

    np.save(f'{tmp_file.name}.npy', np.array([1.0, 2.0, 3.0]))  # from npy file
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=f'{tmp_file.name}.npy',
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    lovasz_loss(logits, labels, ignore_index=None)
    tmp_file.close()
    os.remove(f'{tmp_file.name}.pkl')
    os.remove(f'{tmp_file.name}.npy')

    # test lovasz loss with loss_type = 'binary' and per_image = False
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        reduction='none',
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = (torch.rand(2, 4, 4)).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'binary' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = (torch.rand(2, 4, 4)).long()
    lovasz_loss(logits, labels, ignore_index=None)

    # test lovasz loss has name `loss_lovasz`
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0,
        loss_name='loss_lovasz')
    lovasz_loss = build_loss(loss_cfg)
    assert lovasz_loss.loss_name == 'loss_lovasz'
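
# A hedged usage sketch of the API constraint asserted above: LovaszLoss
# optimizes an IoU surrogate over the sorted errors of a whole prediction,
# so reduction must stay 'none' unless per_image=True.
import torch

from mmseg.models import build_loss

lovasz = build_loss(dict(type='LovaszLoss', per_image=True, reduction='mean'))
logits = torch.randn(2, 3, 8, 8)
labels = torch.randint(0, 3, (2, 8, 8))
loss = lovasz(logits, labels, ignore_index=None)
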
@ -1,129 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

from mmseg.models.losses import Accuracy, reduce_loss, weight_reduce_loss


def test_weight_reduce_loss():
    loss = torch.rand(1, 3, 4, 4)
    weight = torch.zeros(1, 3, 4, 4)
    weight[:, :, :2, :2] = 1

    # test reduce_loss()
    reduced = reduce_loss(loss, 'none')
    assert reduced is loss

    reduced = reduce_loss(loss, 'mean')
    np.testing.assert_almost_equal(reduced.numpy(), loss.mean())

    reduced = reduce_loss(loss, 'sum')
    np.testing.assert_almost_equal(reduced.numpy(), loss.sum())

    # test weight_reduce_loss()
    reduced = weight_reduce_loss(loss, weight=None, reduction='none')
    assert reduced is loss

    reduced = weight_reduce_loss(loss, weight=weight, reduction='mean')
    target = (loss * weight).mean()
    np.testing.assert_almost_equal(reduced.numpy(), target)

    reduced = weight_reduce_loss(loss, weight=weight, reduction='sum')
    np.testing.assert_almost_equal(reduced.numpy(), (loss * weight).sum())

    with pytest.raises(AssertionError):
        weight_wrong = weight[0, 0, ...]
        weight_reduce_loss(loss, weight=weight_wrong, reduction='mean')

    with pytest.raises(AssertionError):
        weight_wrong = weight[:, 0:2, ...]
        weight_reduce_loss(loss, weight=weight_wrong, reduction='mean')


def test_accuracy():
    # test for empty pred
    pred = torch.empty(0, 4)
    label = torch.empty(0)
    accuracy = Accuracy(topk=1)
    acc = accuracy(pred, label)
    assert acc.item() == 0

    pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6],
                         [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1],
                         [0.0, 0.0, 0.99, 0]])
    # test for ignore_index
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1, ignore_index=None)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # test for ignore_index with a wrong prediction of that index
    true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
    accuracy = Accuracy(topk=1, ignore_index=1)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # test for ignore_index 1 with a wrong prediction of another index
    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1, ignore_index=1)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(75.0))

    # test for ignore_index 4 with a wrong prediction of another index
    true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1, ignore_index=4)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(80.0))

    # test for ignoring all the pixels
    true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
    accuracy = Accuracy(topk=1, ignore_index=2)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # test for top1
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # test for top1 with score thresh=0.8
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    accuracy = Accuracy(topk=1, thresh=0.8)
    acc = accuracy(pred, true_label)
    assert torch.allclose(acc, torch.tensor(40.0))

    # test for top2
    accuracy = Accuracy(topk=2)
    label = torch.Tensor([3, 2, 0, 0, 2]).long()
    acc = accuracy(pred, label)
    assert torch.allclose(acc, torch.tensor(100.0))

    # test for both top1 and top2
    accuracy = Accuracy(topk=(1, 2))
    true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
    acc = accuracy(pred, true_label)
    for a in acc:
        assert torch.allclose(a, torch.tensor(100.0))

    # topk is larger than pred class number
    with pytest.raises(AssertionError):
        accuracy = Accuracy(topk=5)
        accuracy(pred, true_label)

    # wrong topk type
    with pytest.raises(AssertionError):
        accuracy = Accuracy(topk='wrong type')
        accuracy(pred, true_label)

    # label size is larger than required
    with pytest.raises(AssertionError):
        label = torch.Tensor([2, 3, 0, 1, 2, 0]).long()  # size mismatch
        accuracy = Accuracy()
        accuracy(pred, label)

    # wrong pred dimension
    with pytest.raises(AssertionError):
        accuracy = Accuracy()
        accuracy(pred[:, :, None], true_label)
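
# A hedged sketch of the thresh semantics pinned down by the 40% case
# above: with thresh=0.8 a prediction counts only if it is both correct
# and scored above the threshold.
import torch

from mmseg.models.losses import Accuracy

pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5],
                     [0.9, 0.0, 0.0, 0.1]])
label = torch.Tensor([2, 0]).long()
assert torch.allclose(Accuracy(topk=1)(pred, label), torch.tensor(100.0))
# row 0 is correct but its top score 0.6 fails the 0.8 threshold
assert torch.allclose(
    Accuracy(topk=1, thresh=0.8)(pred, label), torch.tensor(50.0))
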
@ -1 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
@ -1,57 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import ConfigDict

from mmseg.models import build_segmentor
from .utils import _segmentor_forward_train_test


def test_cascade_encoder_decoder():
    # test 1 decode head, w.o. aux head
    cfg = ConfigDict(
        type='CascadeEncoderDecoder',
        num_stages=2,
        backbone=dict(type='ExampleBackbone'),
        decode_head=[
            dict(type='ExampleDecodeHead'),
            dict(type='ExampleCascadeDecodeHead')
        ])
    cfg.test_cfg = ConfigDict(mode='whole')
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)

    # test slide mode
    cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)

    # test 1 decode head, 1 aux head
    cfg = ConfigDict(
        type='CascadeEncoderDecoder',
        num_stages=2,
        backbone=dict(type='ExampleBackbone'),
        decode_head=[
            dict(type='ExampleDecodeHead'),
            dict(type='ExampleCascadeDecodeHead')
        ],
        auxiliary_head=dict(type='ExampleDecodeHead'))
    cfg.test_cfg = ConfigDict(mode='whole')
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)

    # test 1 decode head, 2 aux head
    cfg = ConfigDict(
        type='CascadeEncoderDecoder',
        num_stages=2,
        backbone=dict(type='ExampleBackbone'),
        decode_head=[
            dict(type='ExampleDecodeHead'),
            dict(type='ExampleCascadeDecodeHead')
        ],
        auxiliary_head=[
            dict(type='ExampleDecodeHead'),
            dict(type='ExampleDecodeHead')
        ])
    cfg.test_cfg = ConfigDict(mode='whole')
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)
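
# A hedged, minimal illustration of the staged forward a cascade
# segmentor performs, assuming heads with the signatures of the Example*
# stubs registered in .utils: stage 0 sees only backbone features, and
# each later stage also receives the previous prediction.
def cascade_forward(decode_heads, feats):
    out = decode_heads[0](feats)        # plain decode head
    for head in decode_heads[1:]:
        out = head(feats, out)          # cascade head takes prev_out
    return out
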
@ -1,47 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv import ConfigDict

from mmseg.models import build_segmentor
from .utils import _segmentor_forward_train_test


def test_encoder_decoder():
    # test 1 decode head, w.o. aux head
    cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        train_cfg=None,
        test_cfg=dict(mode='whole'))
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)

    # test slide mode
    cfg.test_cfg = ConfigDict(mode='slide', crop_size=(3, 3), stride=(2, 2))
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)

    # test 1 decode head, 1 aux head
    cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        auxiliary_head=dict(type='ExampleDecodeHead'))
    cfg.test_cfg = ConfigDict(mode='whole')
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)

    # test 1 decode head, 2 aux head
    cfg = ConfigDict(
        type='EncoderDecoder',
        backbone=dict(type='ExampleBackbone'),
        decode_head=dict(type='ExampleDecodeHead'),
        auxiliary_head=[
            dict(type='ExampleDecodeHead'),
            dict(type='ExampleDecodeHead')
        ])
    cfg.test_cfg = ConfigDict(mode='whole')
    segmentor = build_segmentor(cfg)
    _segmentor_forward_train_test(segmentor)
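
# A hedged, plain-PyTorch reference for the mode='slide' inference these
# tests configure: crop windows of crop_size at the given stride,
# accumulate logits and visit counts, then normalize. It mirrors the
# idea, not mmseg's exact implementation; stride <= crop_size is assumed
# so the windows tile the image.
import math

import torch

def slide_inference(model, img, crop_size, stride, num_classes):
    _, _, H, W = img.shape
    ch, cw = crop_size
    sh, sw = stride
    n_h = max(math.ceil((H - ch) / sh), 0) + 1
    n_w = max(math.ceil((W - cw) / sw), 0) + 1
    logits = img.new_zeros(img.size(0), num_classes, H, W)
    count = img.new_zeros(1, 1, H, W)
    for i in range(n_h):
        for j in range(n_w):
            y = min(i * sh, H - ch)     # clamp the last window inside
            x = min(j * sw, W - cw)
            window = img[:, :, y:y + ch, x:x + cw]
            logits[:, :, y:y + ch, x:x + cw] += model(window)
            count[:, :, y:y + ch, x:x + cw] += 1
    return logits / count               # average overlapping predictions
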
@ -1,140 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from torch import nn

from mmseg.models import BACKBONES, HEADS
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead
from mmseg.models.decode_heads.decode_head import BaseDecodeHead


def _demo_mm_inputs(input_shape=(1, 3, 8, 16), num_classes=10):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): input batch dimensions.
        num_classes (int): number of semantic classes.
    """
    (N, C, H, W) = input_shape

    rng = np.random.RandomState(0)

    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)

    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
        'flip_direction': 'horizontal'
    } for _ in range(N)]

    mm_inputs = {
        'imgs': torch.FloatTensor(imgs),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


@BACKBONES.register_module()
class ExampleBackbone(nn.Module):

    def __init__(self):
        super(ExampleBackbone, self).__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def init_weights(self, pretrained=None):
        pass

    def forward(self, x):
        return [self.conv(x)]


@HEADS.register_module()
class ExampleDecodeHead(BaseDecodeHead):

    def __init__(self):
        super(ExampleDecodeHead, self).__init__(3, 3, num_classes=19)

    def forward(self, inputs):
        return self.cls_seg(inputs[0])


@HEADS.register_module()
class ExampleCascadeDecodeHead(BaseCascadeDecodeHead):

    def __init__(self):
        super(ExampleCascadeDecodeHead, self).__init__(3, 3, num_classes=19)

    def forward(self, inputs, prev_out):
        return self.cls_seg(inputs[0])


def _segmentor_forward_train_test(segmentor):
    if isinstance(segmentor.decode_head, nn.ModuleList):
        num_classes = segmentor.decode_head[-1].num_classes
    else:
        num_classes = segmentor.decode_head.num_classes
    # batch_size=2 for BatchNorm
    mm_inputs = _demo_mm_inputs(num_classes=num_classes)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')
    gt_semantic_seg = mm_inputs['gt_semantic_seg']

    # convert to cuda Tensor if applicable
    if torch.cuda.is_available():
        segmentor = segmentor.cuda()
        imgs = imgs.cuda()
        gt_semantic_seg = gt_semantic_seg.cuda()

    # Test forward train
    losses = segmentor.forward(
        imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
    assert isinstance(losses, dict)

    # Test train_step
    data_batch = dict(
        img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
    outputs = segmentor.train_step(data_batch, None)
    assert isinstance(outputs, dict)
    assert 'loss' in outputs
    assert 'log_vars' in outputs
    assert 'num_samples' in outputs

    # Test val_step
    with torch.no_grad():
        segmentor.eval()
        data_batch = dict(
            img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
        outputs = segmentor.val_step(data_batch, None)
        assert isinstance(outputs, dict)
        assert 'loss' in outputs
        assert 'log_vars' in outputs
        assert 'num_samples' in outputs

    # Test forward simple test
    with torch.no_grad():
        segmentor.eval()
        # pack into lists
        img_list = [img[None, :] for img in imgs]
        img_meta_list = [[img_meta] for img_meta in img_metas]
        segmentor.forward(img_list, img_meta_list, return_loss=False)

    # Test forward aug test
    with torch.no_grad():
        segmentor.eval()
        # pack into lists
        img_list = [img[None, :] for img in imgs]
        img_list = img_list + img_list
        img_meta_list = [[img_meta] for img_meta in img_metas]
        img_meta_list = img_meta_list + img_meta_list
        segmentor.forward(img_list, img_meta_list, return_loss=False)
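
# A hedged usage sketch tying the helpers above together: a test can
# build a segmentor from the registered Example* stubs and drive the
# whole train/val/test cycle on the synthetic batch.
from mmcv import ConfigDict

from mmseg.models import build_segmentor

cfg = ConfigDict(
    type='EncoderDecoder',
    backbone=dict(type='ExampleBackbone'),
    decode_head=dict(type='ExampleDecodeHead'),
    test_cfg=dict(mode='whole'))
_segmentor_forward_train_test(build_segmentor(cfg))
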