Mirror of https://github.com/open-mmlab/mmsegmentation.git, synced 2025-06-03 22:03:48 +08:00
Thanks for your contribution; we appreciate it a lot. The following instructions will make your pull request healthier and help it get feedback more easily. If you do not understand some items, don't worry: just open the pull request and seek help from the maintainers.

## Motivation

Support the depth estimation algorithm [VPD](https://github.com/wl-zhao/VPD).

## Modification

1. Add the VPD backbone.
2. Add the VPD decoder head for depth estimation.
3. Add a new segmentor `DepthEstimator`, based on `EncoderDecoder`, for depth estimation.
4. Add an integrated metric that calculates the common depth estimation metrics.
5. Add the SiLog loss for depth estimation (a sketch of this loss is given below).
6. Add a config for VPD.

## BC-breaking (Optional)

Does the modification introduce changes that break the backward compatibility of the downstream repos? If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here and update the documentation.

## Checklist

1. Pre-commit or other linting tools are used to fix potential lint issues.
2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
3. If the modification may influence downstream projects, this PR should be tested with them, like MMDet or MMDet3D.
4. The documentation has been modified accordingly, e.g. docstrings or example tutorials.
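For reference, the SiLog (scale-invariant logarithmic) loss in item 5 follows the formulation of Eigen et al. (2014). The snippet below is a minimal PyTorch sketch rather than the exact implementation in this PR; the name `silog_loss` and the `lambd`/`eps` parameters are illustrative.

```python
import torch


def silog_loss(pred: torch.Tensor, target: torch.Tensor,
               lambd: float = 0.5, eps: float = 1e-6) -> torch.Tensor:
    """Scale-invariant log loss: sqrt(mean(d^2) - lambd * mean(d)^2),
    where d = log(pred) - log(target) over valid depth pixels."""
    valid = target > eps  # mask out invalid (non-positive) depth pixels
    diff = torch.log(pred[valid] + eps) - torch.log(target[valid] + eps)
    return torch.sqrt((diff ** 2).mean() - lambd * diff.mean() ** 2)
```

Since `mean(d^2) - lambd * mean(d)^2 = Var(d) + (1 - lambd) * mean(d)^2`, the value under the square root stays non-negative for any `lambd <= 1`.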
176 lines · 6.6 KiB · Python
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
from os.path import dirname, exists, isdir, join, relpath

import numpy as np
from mmengine import Config
from mmengine.dataset import Compose
from mmengine.registry import init_default_scope
from torch import nn

from mmseg.models import build_segmentor

def _get_config_directory():
    """Find the predefined segmentor config directory."""
    try:
        # Assume we are running in the source mmsegmentation repo
        repo_dpath = dirname(dirname(__file__))
    except NameError:
        # For IPython development, where __file__ is not defined
        import mmseg
        repo_dpath = dirname(dirname(mmseg.__file__))
    config_dpath = join(repo_dpath, 'configs')
    if not exists(config_dpath):
        raise Exception('Cannot find config path')
    return config_dpath

def test_config_build_segmentor():
    """Test that all segmentation models defined in the configs can be
    initialized."""
    init_default_scope('mmseg')
    config_dpath = _get_config_directory()
    print(f'Found config_dpath = {config_dpath!r}')

    # pick one config from each sub folder
    config_fpaths = []
    for sub_folder in os.listdir(config_dpath):
        if isdir(join(config_dpath, sub_folder)):
            config_fpaths.append(
                list(glob.glob(join(config_dpath, sub_folder, '*.py')))[0])
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print(f'Using {len(config_names)} config files')

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        config_mod = Config.fromfile(config_fpath)
        print(f'Building segmentor, config_fpath = {config_fpath!r}')

        # Remove pretrained keys to allow testing in an offline environment
        if 'pretrained' in config_mod.model:
            config_mod.model['pretrained'] = None

        segmentor = build_segmentor(config_mod.model)
        assert segmentor is not None

        head_config = config_mod.model['decode_head']
        _check_decode_head(head_config, segmentor.decode_head)

def test_config_data_pipeline():
    """Test whether the data pipeline is valid and can process corner cases.

    CommandLine:
        xdoctest -m tests/test_config.py test_config_data_pipeline
    """
    init_default_scope('mmseg')
    config_dpath = _get_config_directory()
    print(f'Found config_dpath = {config_dpath!r}')

    config_fpaths = list(glob.glob(join(config_dpath, '**', '*.py')))
    config_fpaths = [p for p in config_fpaths if p.find('_base_') == -1]
    config_names = [relpath(p, config_dpath) for p in config_fpaths]

    print(f'Using {len(config_names)} config files')

    for config_fname in config_names:
        config_fpath = join(config_dpath, config_fname)
        print(f'Building data pipeline, config_fpath = {config_fpath!r}')
        config_mod = Config.fromfile(config_fpath)

        # remove image loading from the train pipeline, recording whether it
        # casts images to float32, then remove annotation loading from the
        # train pipeline and image loading from the test pipeline
        load_img_pipeline = config_mod.train_pipeline.pop(0)
        to_float32 = load_img_pipeline.get('to_float32', False)
        del config_mod.train_pipeline[0]
        del config_mod.test_pipeline[0]
        # remove annotation loading from the test pipeline, if present
        load_anno_idx = -1
        for i in range(len(config_mod.test_pipeline)):
            if config_mod.test_pipeline[i].type in ('LoadAnnotations',
                                                    'LoadDepthAnnotation'):
                load_anno_idx = i
        if load_anno_idx != -1:
            del config_mod.test_pipeline[load_anno_idx]

        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)

        # synthetic inputs at a Cityscapes-like 1024x2048 resolution
        img = np.random.randint(0, 255, size=(1024, 2048, 3), dtype=np.uint8)
        if to_float32:
            img = img.astype(np.float32)
        seg = np.random.randint(0, 255, size=(1024, 2048, 1), dtype=np.uint8)
        depth = np.random.rand(1024, 2048).astype(np.float32)

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape,
            gt_seg_map=seg,
            gt_depth_map=depth)
        results['seg_fields'] = ['gt_seg_map']

        _check_concat_cd_input(config_mod, results)
        print(f'Test training data pipeline: \n{train_pipeline!r}')
        output_results = train_pipeline(results)
        assert output_results is not None

        _check_concat_cd_input(config_mod, results)
        print(f'Test testing data pipeline: \n{test_pipeline!r}')
        output_results = test_pipeline(results)
        assert output_results is not None

def _check_concat_cd_input(config_mod: Config, results: dict):
    """Add a second input image when the pipeline contains ConcatCDInput,
    which change-detection transforms expect."""
    keys = []
    pipeline = config_mod.train_pipeline.copy()
    pipeline.extend(config_mod.test_pipeline)
    for t in pipeline:
        keys.append(t.type)
    if 'ConcatCDInput' in keys:
        results.update({'img2': results['img']})

def _check_decode_head(decode_head_cfg, decode_head):
    if isinstance(decode_head_cfg, list):
        # a list of head configs maps to an nn.ModuleList of heads
        assert isinstance(decode_head, nn.ModuleList)
        assert len(decode_head_cfg) == len(decode_head)
        num_heads = len(decode_head)
        for i in range(num_heads):
            _check_decode_head(decode_head_cfg[i], decode_head[i])
        return
    # check consistency between the head config and the built decode head
    assert decode_head_cfg['type'] == decode_head.__class__.__name__

    in_channels = decode_head_cfg.in_channels
    input_transform = decode_head.input_transform
    assert input_transform in ['resize_concat', 'multiple_select', None]
    if input_transform is not None:
        assert isinstance(in_channels, (list, tuple))
        assert isinstance(decode_head.in_index, (list, tuple))
        assert len(in_channels) == len(decode_head.in_index)
        if input_transform == 'resize_concat':
            # resized inputs are concatenated along the channel dimension
            assert sum(in_channels) == decode_head.in_channels
    else:
        assert in_channels == decode_head.in_channels

    if decode_head_cfg['type'] == 'PointHead':
        assert decode_head_cfg.channels + decode_head_cfg.num_classes == \
            decode_head.fc_seg.in_channels
        assert decode_head.fc_seg.out_channels == decode_head_cfg.num_classes
    elif decode_head_cfg['type'] == 'VPDDepthHead':
        # the depth estimation head predicts a single-channel depth map
        assert decode_head.out_channels == 1
    else:
        assert decode_head_cfg.channels == decode_head.conv_seg.in_channels
        assert decode_head.conv_seg.out_channels == decode_head_cfg.num_classes
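As a usage sketch, the build path these tests exercise can also be run directly. The following mirrors `test_config_build_segmentor` for a single config; the config path is illustrative and assumes a VPD config shipped under `configs/`.

```python
from mmengine import Config
from mmengine.registry import init_default_scope

from mmseg.models import build_segmentor

init_default_scope('mmseg')
# Illustrative path; any non-'_base_' config under configs/ works the same way.
cfg = Config.fromfile('configs/vpd/vpd_sd_4xb8-25k_nyu-480x480.py')
if 'pretrained' in cfg.model:
    cfg.model['pretrained'] = None  # allow offline construction, as in the test
segmentor = build_segmentor(cfg.model)
print(type(segmentor).__name__)  # expected: 'DepthEstimator' for VPD configs
```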