# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import SampleList

from ..builder import build_loss
from ..utils import resize
from .decode_head import BaseDecodeHead


class VPDDepthDecoder(BaseModule):
"""VPD Depth Decoder class.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
num_deconv_layers (int): Number of deconvolution layers.
num_deconv_filters (List[int]): List of output channels for
deconvolution layers.
init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration
for weight initialization. Defaults to Normal for Conv2d and
ConvTranspose2d layers.
"""
def __init__(self,
in_channels: int,
out_channels: int,
num_deconv_layers: int,
num_deconv_filters: List[int],
init_cfg: Optional[Union[Dict, List[Dict]]] = dict(
type='Normal',
std=0.001,
layer=['Conv2d', 'ConvTranspose2d'])):
super().__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.deconv_layers = self._make_deconv_layer(
num_deconv_layers,
num_deconv_filters,
)

        # a final 3x3 conv + BN + ReLU block projects to the output channels
        conv_layers = []
conv_layers.append(
build_conv_layer(
dict(type='Conv2d'),
in_channels=num_deconv_filters[-1],
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1))
conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1])
conv_layers.append(nn.ReLU(inplace=True))
self.conv_layers = nn.Sequential(*conv_layers)
self.up_sample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x):
        """Forward pass through the decoder network."""
        out = self.deconv_layers(x)
        out = self.conv_layers(out)
        # two additional bilinear upsamplings restore the output resolution
        out = self.up_sample(out)
        out = self.up_sample(out)
        return out

    def _make_deconv_layer(self, num_layers, num_deconv_filters):
        """Stack ``num_layers`` deconvolution blocks (ConvTranspose2d + BN +
        ReLU), each of which upsamples the feature map by a factor of 2."""
layers = []
in_channels = self.in_channels
for i in range(num_layers):
num_channels = num_deconv_filters[i]
layers.append(
build_upsample_layer(
dict(type='deconv'),
in_channels=in_channels,
out_channels=num_channels,
kernel_size=2,
stride=2,
padding=0,
output_padding=0,
bias=False))
layers.append(nn.BatchNorm2d(num_channels))
layers.append(nn.ReLU(inplace=True))
in_channels = num_channels
        return nn.Sequential(*layers)


@MODELS.register_module()
class VPDDepthHead(BaseDecodeHead):
"""Depth Prediction Head for VPD.
.. _`VPD`: https://arxiv.org/abs/2303.02153
Args:
max_depth (float): Maximum depth value. Defaults to 10.0.
in_channels (Sequence[int]): Number of input channels for each
convolutional layer.
embed_dim (int): Dimension of embedding. Defaults to 192.
feature_dim (int): Dimension of aggregated feature. Defaults to 1536.
num_deconv_layers (int): Number of deconvolution layers in the
decoder. Defaults to 3.
num_deconv_filters (Sequence[int]): Number of filters for each deconv
layer. Defaults to (32, 32, 32).
fmap_border (Union[int, Sequence[int]]): Feature map border for
cropping. Defaults to 0.
align_corners (bool): Flag for align_corners in interpolation.
Defaults to False.
loss_decode (dict): Configurations for the loss function. Defaults to
dict(type='SiLogLoss').
init_cfg (dict): Initialization configurations. Defaults to
dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']).
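
    Example:
        A rough end-to-end sketch; the feature shapes are illustrative and
        follow the pyramid layout the forward pass expects (the default
        ``in_channels``, each level at half the resolution of the previous
        one):

        >>> import torch
        >>> head = VPDDepthHead()
        >>> feats = [
        ...     torch.rand(1, 320, 64, 64),
        ...     torch.rand(1, 640, 32, 32),
        ...     torch.rand(1, 1280, 16, 16),
        ...     torch.rand(1, 1280, 8, 8),
        ... ]
        >>> head(feats).shape
        torch.Size([1, 1, 512, 512])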
"""
num_classes = 1
out_channels = 1
    input_transform = None

    def __init__(
self,
max_depth: float = 10.0,
in_channels: Sequence[int] = [320, 640, 1280, 1280],
embed_dim: int = 192,
feature_dim: int = 1536,
num_deconv_layers: int = 3,
num_deconv_filters: Sequence[int] = (32, 32, 32),
fmap_border: Union[int, Sequence[int]] = 0,
align_corners: bool = False,
loss_decode: dict = dict(type='SiLogLoss'),
init_cfg=dict(
type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']),
):
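        # NOTE: BaseModule.__init__ is called directly on purpose, skipping
        # BaseDecodeHead.__init__, since this head builds its own layers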
super(BaseDecodeHead, self).__init__(init_cfg=init_cfg)
# initialize parameters
self.in_channels = in_channels
self.max_depth = max_depth
self.align_corners = align_corners
# feature map border
if isinstance(fmap_border, int):
fmap_border = (fmap_border, fmap_border)
self.fmap_border = fmap_border
        # define network layers: conv1 downsamples the first feature map by
        # 4x and conv2 downsamples the second by 2x, so that all four levels
        # can be concatenated at the resolution of the third one
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
            nn.GroupNorm(16, in_channels[0]),
            nn.ReLU(),
            nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1),
        )
        self.conv2 = nn.Conv2d(
            in_channels[1], in_channels[1], 3, stride=2, padding=1)
self.conv_aggregation = nn.Sequential(
nn.Conv2d(sum(in_channels), feature_dim, 1),
nn.GroupNorm(16, feature_dim),
nn.ReLU(),
)
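        # the decoder consumes the aggregated features; with the defaults,
        # embed_dim * 8 == feature_dim == 1536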
self.decoder = VPDDepthDecoder(
in_channels=embed_dim * 8,
out_channels=embed_dim,
num_deconv_layers=num_deconv_layers,
num_deconv_filters=num_deconv_filters)
self.depth_pred_layer = nn.Sequential(
nn.Conv2d(
embed_dim, embed_dim, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1))
# build loss
if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
        else:
            raise TypeError('loss_decode must be a dict or sequence of dict, '
                            f'but got {type(loss_decode)}')

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
        """Stack the ``gt_depth_map`` of each data sample into one tensor."""
gt_depth_maps = [
data_sample.gt_depth_map.data for data_sample in batch_data_samples
]
return torch.stack(gt_depth_maps, dim=0)

    def forward(self, x):
        """Aggregate multi-scale features and predict the depth map."""
        # upsample the deepest feature map and merge it with the third level
        x = [
            x[0], x[1],
            torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1)
        ]
        # bring all levels to a common resolution before concatenation
        x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1)
        x = self.conv_aggregation(x)
        # crop the padded feature-map border, if any
        x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) -
              self.fmap_border[1]].contiguous()
        x = self.decoder(x)
        out = self.depth_pred_layer(x)
        # squash predictions to (0, 1) and scale to the maximum depth
        depth = torch.sigmoid(out) * self.max_depth
        return depth

    def loss_by_feat(self, pred_depth_map: Tensor,
                     batch_data_samples: SampleList) -> dict:
        """Compute depth estimation loss.

        Args:
pred_depth_map (Tensor): The output from decode head forward
function.
            batch_data_samples (List[:obj:`SegDataSample`]): The seg
                data samples. It usually includes information such
                as `metainfo` and `gt_depth_map`.

        Returns:
dict[str, Tensor]: a dictionary of loss components
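
        Example:
            A hedged sketch of the expected inputs, built with random data
            purely for illustration (reusing the ``head`` from the
            class-level example above):

            >>> import torch
            >>> from mmengine.structures import PixelData
            >>> from mmseg.structures import SegDataSample
            >>> data_sample = SegDataSample()
            >>> data_sample.gt_depth_map = PixelData(
            ...     data=torch.rand(1, 512, 512))
            >>> pred = torch.rand(1, 1, 512, 512)
            >>> losses = head.loss_by_feat(pred, [data_sample])
            >>> sorted(losses)  # e.g. ['loss_silog'] with the default loss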
"""
gt_depth_map = self._stack_batch_gt(batch_data_samples)
loss = dict()
pred_depth_map = resize(
input=pred_depth_map,
size=gt_depth_map.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(
pred_depth_map, gt_depth_map)
else:
loss[loss_decode.loss_name] += loss_decode(
pred_depth_map, gt_depth_map)
return loss