Add depth estimation code (#184)

Add depth estimation code + demo notebook
pull/185/head
Patrick Labatut 2023-08-31 14:57:50 +02:00 committed by GitHub
parent 3a7bf1ca4b
commit d5c376b5b3
19 changed files with 1770 additions and 2 deletions

.gitignore (vendored): 2 lines changed
View File

@ -6,8 +6,6 @@ dist/
**/.ipynb_checkpoints
**/.ipynb_checkpoints/**
**/notebooks
*.swp
.vscode/

View File

@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .backbones import * # noqa: F403
from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
from .decode_heads import * # noqa: F403
from .depther import * # noqa: F403
from .losses import * # noqa: F403

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .vision_transformer import DinoVisionTransformer

View File

@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.runner import BaseModule
from ..builder import BACKBONES
@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
"""Vision Transformer."""
def __init__(self, *args, **kwargs):
super().__init__()
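
Note that the backbone registered above is only a shell: downstream, both decode heads unpack each backbone output stage as a (patch feature map, class token) pair. Below is a minimal sketch of an adapter producing that structure from a DINOv2 model; the wrapper name and layer indices are assumptions, not part of this change.

# Hypothetical sketch: wrap a DINOv2 model so each selected stage yields the
# (patch_feature_map, cls_token) pair expected by the decode heads below.
import torch.nn as nn


class DinoBackboneAdapter(nn.Module):
    def __init__(self, dinov2_model, out_indices=(8, 9, 10, 11)):
        super().__init__()
        self.model = dinov2_model
        self.out_indices = out_indices

    def forward(self, x):
        # get_intermediate_layers(..., reshape=True, return_class_token=True)
        # returns a (B x C x H x W patch map, B x C class token) per layer
        outs = self.model.get_intermediate_layers(
            x, n=self.out_indices, reshape=True, return_class_token=True
        )
        return [(feat, cls_tok) for feat, cls_tok in outs]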

View File

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry
MODELS = Registry("models", parent=MMCV_MODELS)
ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
DEPTHER = MODELS
def build_backbone(cfg):
"""Build backbone."""
return BACKBONES.build(cfg)
def build_neck(cfg):
"""Build neck."""
return NECKS.build(cfg)
def build_head(cfg):
"""Build head."""
return HEADS.build(cfg)
def build_loss(cfg):
"""Build loss."""
return LOSSES.build(cfg)
def build_depther(cfg, train_cfg=None, test_cfg=None):
"""Build depther."""
if train_cfg is not None or test_cfg is not None:
warnings.warn("train_cfg and test_cfg are deprecated, please specify them in the model config", UserWarning)
assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
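
All model components are declared as config dicts and instantiated through these registries. A minimal usage sketch follows; the import path and every config value are illustrative assumptions, not taken from this change.

# Illustrative only: the registries above turn config dicts into module instances.
# from <package>.models import build_head, build_loss   # import path is an assumption
loss = build_loss(dict(type="SigLoss", valid_mask=True, loss_weight=10, max_depth=10))
head = build_head(
    dict(
        type="BNHead",
        in_channels=[768, 768, 768, 768],   # hypothetical per-stage feature dims
        in_index=(0, 1, 2, 3),
        input_transform="resize_concat",
        channels=768 * 2 * 4,               # patch + cls token channels, concatenated over 4 stages (assumption)
        align_corners=False,
        min_depth=0.001,
        max_depth=10,
        loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
    )
)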

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .dpt_head import DPTHead
from .linear_head import BNHead

View File

@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import copy
from abc import ABCMeta, abstractmethod
import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16, force_fp32
from ...ops import resize
from ..builder import build_loss
class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
"""Base class for depth estimation decode heads.
Args:
in_channels (List): Input channels.
channels (int): Channels after modules, before conv_depth.
conv_cfg (dict|None): Config of conv layers. Default: None.
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU')
loss_decode (dict): Config of decode loss.
Default: dict(type='SigLoss').
sampler (dict|None): The config of depth map sampler.
Default: None.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
min_depth (float): Min depth in dataset setting.
Default: 1e-3.
max_depth (float): Max depth in dataset setting.
Default: None.
norm_cfg (dict|None): Config of norm layers.
Default: None.
classify (bool): Whether to predict depth in a classification-regression
manner. Default: False.
n_bins (int): The number of bins used in the classification step.
Default: 256.
bins_strategy (str): The discretization strategy used in the
classification step. Default: 'UD'.
norm_strategy (str): The normalization strategy applied to the class
probabilities. Default: 'linear'.
scale_up (bool): Whether to predict depth in a scale-up manner
(sigmoid output scaled by max_depth). Default: False.
"""
def __init__(
self,
in_channels,
channels=96,
conv_cfg=None,
act_cfg=dict(type="ReLU"),
loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
sampler=None,
align_corners=False,
min_depth=1e-3,
max_depth=None,
norm_cfg=None,
classify=False,
n_bins=256,
bins_strategy="UD",
norm_strategy="linear",
scale_up=False,
):
super(DepthBaseDecodeHead, self).__init__()
self.in_channels = in_channels
self.channels = channels
self.conv_cfg = conv_cfg
self.act_cfg = act_cfg
if isinstance(loss_decode, dict):
self.loss_decode = build_loss(loss_decode)
elif isinstance(loss_decode, (list, tuple)):
self.loss_decode = nn.ModuleList()
for loss in loss_decode:
self.loss_decode.append(build_loss(loss))
self.align_corners = align_corners
self.min_depth = min_depth
self.max_depth = max_depth
self.norm_cfg = norm_cfg
self.classify = classify
self.n_bins = n_bins
self.scale_up = scale_up
if self.classify:
assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
self.bins_strategy = bins_strategy
self.norm_strategy = norm_strategy
self.softmax = nn.Softmax(dim=1)
self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
else:
self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
self.fp16_enabled = False
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def extra_repr(self):
"""Extra repr."""
s = f"align_corners={self.align_corners}"
return s
@auto_fp16()
@abstractmethod
def forward(self, inputs, img_metas):
"""Placeholder of forward function."""
pass
def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
"""Forward function for training.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
depth_gt (Tensor): GT depth
train_cfg (dict): The training config.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
depth_pred = self.forward(inputs, img_metas)
losses = self.losses(depth_pred, depth_gt)
log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
losses.update(**log_imgs)
return losses
def forward_test(self, inputs, img_metas, test_cfg):
"""Forward function for testing.
Args:
inputs (list[Tensor]): List of multi-level img features.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
test_cfg (dict): The testing config.
Returns:
Tensor: Output depth map.
"""
return self.forward(inputs, img_metas)
def depth_pred(self, feat):
"""Predict depth at each pixel."""
if self.classify:
logit = self.conv_depth(feat)
if self.bins_strategy == "UD":
bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
elif self.bins_strategy == "SID":
bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
# following Adabins, default linear
if self.norm_strategy == "linear":
logit = torch.relu(logit)
eps = 0.1
logit = logit + eps
logit = logit / logit.sum(dim=1, keepdim=True)
elif self.norm_strategy == "softmax":
logit = torch.softmax(logit, dim=1)
elif self.norm_strategy == "sigmoid":
logit = torch.sigmoid(logit)
logit = logit / logit.sum(dim=1, keepdim=True)
output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
else:
if self.scale_up:
output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
else:
output = self.relu(self.conv_depth(feat)) + self.min_depth
return output
@force_fp32(apply_to=("depth_pred",))
def losses(self, depth_pred, depth_gt):
"""Compute depth loss."""
loss = dict()
depth_pred = resize(
input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
)
if not isinstance(self.loss_decode, nn.ModuleList):
losses_decode = [self.loss_decode]
else:
losses_decode = self.loss_decode
for loss_decode in losses_decode:
if loss_decode.loss_name not in loss:
loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
else:
loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
return loss
def log_images(self, img_path, depth_pred, depth_gt, img_meta):
show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
show_img = show_img.numpy().astype(np.float32)
show_img = mmcv.imdenormalize(
show_img,
img_meta["img_norm_cfg"]["mean"],
img_meta["img_norm_cfg"]["std"],
img_meta["img_norm_cfg"]["to_rgb"],
)
show_img = np.clip(show_img, 0, 255)
show_img = show_img.astype(np.uint8)
show_img = show_img[:, :, ::-1]
show_img = show_img.transpose(0, 2, 1)
show_img = show_img.transpose(1, 0, 2)
depth_pred = depth_pred / torch.max(depth_pred)
depth_gt = depth_gt / torch.max(depth_gt)
depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
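
In the classification branch of `depth_pred`, the normalized per-bin scores are reduced to a single depth value per pixel as an expectation over bin centers. A self-contained sketch of that step with illustrative shapes:

# Sketch of the cls.-reg. depth prediction above (shapes are illustrative).
import torch

n_bins, min_depth, max_depth = 256, 1e-3, 10.0
logit = torch.rand(2, n_bins, 32, 32)                 # B x n_bins x H x W

# "linear" norm strategy: ReLU, add eps, normalize over the bin dimension
prob = torch.relu(logit) + 0.1
prob = prob / prob.sum(dim=1, keepdim=True)

bins = torch.linspace(min_depth, max_depth, n_bins)   # "UD" bins strategy
depth = torch.einsum("ikmn,k->imn", prob, bins).unsqueeze(dim=1)
print(depth.shape)                                    # torch.Size([2, 1, 32, 32])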

View File

@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmcv.runner import BaseModule
from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead
class Interpolate(nn.Module):
def __init__(self, scale_factor, mode, align_corners=False):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
return x
class HeadDepth(nn.Module):
def __init__(self, features):
super(HeadDepth, self).__init__()
self.head = nn.Sequential(
nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
)
def forward(self, x):
x = self.head(x)
return x
class ReassembleBlocks(BaseModule):
"""ViTPostProcessBlock: processes the cls_token in the ViT backbone output
and rearranges the feature vectors into feature maps.
Args:
in_channels (int): ViT feature channels. Default: 768.
out_channels (List): output channels of each stage.
Default: [96, 192, 384, 768].
readout_type (str): Type of readout operation. Default: 'ignore'.
patch_size (int): The patch size. Default: 16.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(
self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
):
super(ReassembleBlocks, self).__init__(init_cfg)
assert readout_type in ["ignore", "add", "project"]
self.readout_type = readout_type
self.patch_size = patch_size
self.projects = nn.ModuleList(
[
ConvModule(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
act_cfg=None,
)
for out_channel in out_channels
]
)
self.resize_layers = nn.ModuleList(
[
nn.ConvTranspose2d(
in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
),
nn.ConvTranspose2d(
in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
),
]
)
if self.readout_type == "project":
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
)
def forward(self, inputs):
assert isinstance(inputs, list)
out = []
for i, x in enumerate(inputs):
assert len(x) == 2
x, cls_token = x[0], x[1]
feature_shape = x.shape
if self.readout_type == "project":
x = x.flatten(2).permute((0, 2, 1))
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
x = x.permute(0, 2, 1).reshape(feature_shape)
elif self.readout_type == "add":
x = x.flatten(2) + cls_token.unsqueeze(-1)
x = x.reshape(feature_shape)
else:
pass
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
return out
class PreActResidualConvUnit(BaseModule):
"""ResidualConvUnit, pre-activate residual unit.
Args:
in_channels (int): number of channels in the input feature map.
act_cfg (dict): dictionary to construct and config activation layer.
norm_cfg (dict): dictionary to construct and config norm layer.
stride (int): stride of the first block. Default: 1
dilation (int): dilation rate for convs layers. Default: 1.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
super(PreActResidualConvUnit, self).__init__(init_cfg)
self.conv1 = ConvModule(
in_channels,
in_channels,
3,
stride=stride,
padding=dilation,
dilation=dilation,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=False,
order=("act", "conv", "norm"),
)
self.conv2 = ConvModule(
in_channels,
in_channels,
3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
bias=False,
order=("act", "conv", "norm"),
)
def forward(self, inputs):
inputs_ = inputs.clone()
x = self.conv1(inputs)
x = self.conv2(x)
return x + inputs_
class FeatureFusionBlock(BaseModule):
"""FeatureFusionBlock: merges feature maps from different stages.
Args:
in_channels (int): Input channels.
act_cfg (dict): The activation config for ResidualConvUnit.
norm_cfg (dict): Config dict for normalization layer.
expand (bool): Whether expand the channels in post process block.
Default: False.
align_corners (bool): align_corner setting for bilinear upsample.
Default: True.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
super(FeatureFusionBlock, self).__init__(init_cfg)
self.in_channels = in_channels
self.expand = expand
self.align_corners = align_corners
self.out_channels = in_channels
if self.expand:
self.out_channels = in_channels // 2
self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
def forward(self, *inputs):
x = inputs[0]
if len(inputs) == 2:
if x.shape != inputs[1].shape:
res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
else:
res = inputs[1]
x = x + self.res_conv_unit1(res)
x = self.res_conv_unit2(x)
x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
x = self.project(x)
return x
@HEADS.register_module()
class DPTHead(DepthBaseDecodeHead):
"""Vision Transformers for Dense Prediction.
This head is the implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
Args:
embed_dims (int): The embed dimension of the ViT backbone.
Default: 768.
post_process_channels (List): Out channels of post process conv
layers. Default: [96, 192, 384, 768].
readout_type (str): Type of readout operation. Default: 'ignore'.
patch_size (int): The patch size. Default: 16.
expand_channels (bool): Whether expand the channels in post process
block. Default: False.
"""
def __init__(
self,
embed_dims=768,
post_process_channels=[96, 192, 384, 768],
readout_type="ignore",
patch_size=16,
expand_channels=False,
**kwargs
):
super(DPTHead, self).__init__(**kwargs)
self.in_channels = self.in_channels
self.expand_channels = expand_channels
self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
self.post_process_channels = [
channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
]
self.convs = nn.ModuleList()
for channel in self.post_process_channels:
self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
self.fusion_blocks = nn.ModuleList()
for _ in range(len(self.convs)):
self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
self.fusion_blocks[0].res_conv_unit1 = None
self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
self.num_fusion_blocks = len(self.fusion_blocks)
self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
self.num_post_process_channels = len(self.post_process_channels)
assert self.num_fusion_blocks == self.num_reassemble_blocks
assert self.num_reassemble_blocks == self.num_post_process_channels
self.conv_depth = HeadDepth(self.channels)
def forward(self, inputs, img_metas):
assert len(inputs) == self.num_reassemble_blocks
x = [inp for inp in inputs]
x = self.reassemble_blocks(x)
x = [self.convs[i](feature) for i, feature in enumerate(x)]
out = self.fusion_blocks[0](x[-1])
for i in range(1, len(self.fusion_blocks)):
out = self.fusion_blocks[i](out, x[-(i + 1)])
out = self.project(out)
out = self.depth_pred(out)
return out
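
The DPT head consumes one (patch map, class token) pair per stage; ReassembleBlocks projects and rescales the four stages to different resolutions, and the fusion blocks merge them from coarsest to finest. A torch-only sketch of the shapes and the fusion order; the dimensions are illustrative, not from this change.

# Shape/ordering sketch of DPTHead.forward above (dims are illustrative).
import torch

B, C, h, w = 1, 768, 16, 16
stages = [(torch.randn(B, C, h, w), torch.randn(B, C)) for _ in range(4)]

# resize_layers above: ConvTranspose2d(stride=4), ConvTranspose2d(stride=2),
# Identity, Conv2d(stride=2) -> scales 4x, 2x, 1x, 0.5x of the token grid
scales = [4, 2, 1, 0.5]
grids = [(int(h * s), int(w * s)) for s in scales]
print("reassembled grids:", grids)   # [(64, 64), (32, 32), (16, 16), (8, 8)]

# fusion starts from the coarsest stage (x[-1]) and merges the next-finer
# stage at each step, upsampling by 2x inside every FeatureFusionBlock
print("fusion consumes stages in order:", [3, 2, 1, 0])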

View File

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead
@HEADS.register_module()
class BNHead(DepthBaseDecodeHead):
"""Just a batchnorm."""
def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
super().__init__(**kwargs)
self.input_transform = input_transform
self.in_index = in_index
self.upsample = upsample
# self.bn = nn.SyncBatchNorm(self.in_channels)
if self.classify:
self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
else:
self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
def _transform_inputs(self, inputs):
"""Transform inputs for decoder.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
Tensor: The transformed inputs
"""
if "concat" in self.input_transform:
inputs = [inputs[i] for i in self.in_index]
if "resize" in self.input_transform:
inputs = [
resize(
input=x,
size=[s * self.upsample for s in inputs[0].shape[2:]],
mode="bilinear",
align_corners=self.align_corners,
)
for x in inputs
]
inputs = torch.cat(inputs, dim=1)
elif self.input_transform == "multiple_select":
inputs = [inputs[i] for i in self.in_index]
else:
inputs = inputs[self.in_index]
return inputs
def _forward_feature(self, inputs, img_metas=None, **kwargs):
"""Forward function for the feature maps fed to the per-pixel depth
prediction in ``self.conv_depth``.
Args:
inputs (list[Tensor]): List of multi-level img features.
Returns:
feats (Tensor): A tensor of shape (batch_size, self.channels,
H, W) which is feature map for last layer of decoder head.
"""
# accept lists (for cls token)
inputs = list(inputs)
for i, x in enumerate(inputs):
if len(x) == 2:
x, cls_token = x[0], x[1]
if len(x.shape) == 2:
x = x[:, :, None, None]
cls_token = cls_token[:, :, None, None].expand_as(x)
inputs[i] = torch.cat((x, cls_token), 1)
else:
x = x[0]
if len(x.shape) == 2:
x = x[:, :, None, None]
inputs[i] = x
x = self._transform_inputs(inputs)
# feats = self.bn(x)
return x
def forward(self, inputs, img_metas=None, **kwargs):
"""Forward function."""
output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
output = self.depth_pred(output)
return output
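
`_forward_feature` broadcasts the class token over the patch grid and concatenates it with the patch features of every selected stage, so the depth convolution sees `2 * C * len(in_index)` channels after `resize_concat`. A torch-only sketch of that channel bookkeeping (dimensions are illustrative):

# Sketch of the channel bookkeeping in BNHead above (dims are illustrative).
import torch

B, C, h, w = 1, 384, 32, 32
stages = [(torch.randn(B, C, h, w), torch.randn(B, C)) for _ in range(4)]

feats = []
for patch_map, cls_token in stages:
    cls_map = cls_token[:, :, None, None].expand_as(patch_map)   # broadcast class token
    feats.append(torch.cat((patch_map, cls_map), dim=1))         # B x 2C x h x w

x = torch.cat(feats, dim=1)   # "resize_concat" with equal grids reduces to a plain concat
print(x.shape)                # torch.Size([1, 3072, 32, 32]) == B x (2*C*4) x h x w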

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .base import BaseDepther
from .encoder_decoder import DepthEncoderDecoder

View File

@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16
class BaseDepther(BaseModule, metaclass=ABCMeta):
"""Base class for depther."""
def __init__(self, init_cfg=None):
super(BaseDepther, self).__init__(init_cfg)
self.fp16_enabled = False
@property
def with_neck(self):
"""bool: whether the depther has neck"""
return hasattr(self, "neck") and self.neck is not None
@property
def with_auxiliary_head(self):
"""bool: whether the depther has auxiliary head"""
return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None
@property
def with_decode_head(self):
"""bool: whether the depther has decode head"""
return hasattr(self, "decode_head") and self.decode_head is not None
@abstractmethod
def extract_feat(self, imgs):
"""Placeholder for extract features from images."""
pass
@abstractmethod
def encode_decode(self, img, img_metas):
"""Placeholder for encode images with backbone and decode into a
semantic depth map of the same size as input."""
pass
@abstractmethod
def forward_train(self, imgs, img_metas, **kwargs):
"""Placeholder for Forward function for training."""
pass
@abstractmethod
def simple_test(self, img, img_meta, **kwargs):
"""Placeholder for single image test."""
pass
@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
"""Placeholder for augmentation test."""
pass
def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:
imgs (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
if not isinstance(var, list):
raise TypeError(f"{name} must be a list, but got " f"{type(var)}")
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(f"num of augmentations ({len(imgs)}) != " f"num of image meta ({len(img_metas)})")
# all images in the same aug batch should share the same ori_shape and
# pad_shape
for img_meta in img_metas:
ori_shapes = [_["ori_shape"] for _ in img_meta]
assert all(shape == ori_shapes[0] for shape in ori_shapes)
img_shapes = [_["img_shape"] for _ in img_meta]
assert all(shape == img_shapes[0] for shape in img_shapes)
pad_shapes = [_["pad_shape"] for _ in img_meta]
assert all(shape == pad_shapes[0] for shape in pad_shapes)
if num_augs == 1:
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
return self.aug_test(imgs, img_metas, **kwargs)
@auto_fp16(apply_to=("img",))
def forward(self, img, img_metas, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note this setting will change the expected inputs. When
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
and List[dict]), and when ``return_loss=False``, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
return self.forward_test(img, img_metas, **kwargs)
def train_step(self, data_batch, optimizer, **kwargs):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data_batch (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
``loss`` is a tensor for back propagation, which can be a
weighted sum of multiple losses.
``log_vars`` contains all the variables to be sent to the
logger.
``num_samples`` indicates the batch size (when the model is
DDP, it means the batch size on each GPU), which is used for
averaging the logs.
"""
losses = self(**data_batch)
# split losses and images
real_losses = {}
log_imgs = {}
for k, v in losses.items():
if "img" in k:
log_imgs[k] = v
else:
real_losses[k] = v
loss, log_vars = self._parse_losses(real_losses)
outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)
return outputs
def val_step(self, data_batch, **kwargs):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
output = self(**data_batch, **kwargs)
return output
@staticmethod
def _parse_losses(losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(f"{loss_name} is not a tensor or list of tensors")
loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
log_vars["loss"] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
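
`train_step` splits image logs from scalar losses and relies on `_parse_losses` to mean-reduce tensors, sum lists, and add up every entry whose key contains "loss" into the value used for backpropagation. A small numeric sketch; keys and values are made up.

# Sketch of the reduction in _parse_losses above (keys/values are made up).
from collections import OrderedDict
import torch

raw = {
    "decode.loss_depth": torch.tensor([0.8, 1.2]),                 # tensor -> mean = 1.0
    "decode.loss_grad": [torch.tensor(0.2), torch.tensor(0.4)],    # list -> sum of means = 0.6
    "decode.abs_rel": torch.tensor(0.09),                          # logged, excluded from the total
}

log_vars = OrderedDict()
for name, value in raw.items():
    log_vars[name] = value.mean() if isinstance(value, torch.Tensor) else sum(v.mean() for v in value)
total = sum(v for k, v in log_vars.items() if "loss" in k)
print(float(total))   # ~1.6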

View File

@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from ...models import builder
from ...models.builder import DEPTHER
from ...ops import resize
from .base import BaseDepther
def add_prefix(inputs, prefix):
"""Add prefix for dict.
Args:
inputs (dict): The input dict with str keys.
prefix (str): The prefix to add.
Returns:
dict: The dict with keys updated with ``prefix``.
"""
outputs = dict()
for name, value in inputs.items():
outputs[f"{prefix}.{name}"] = value
return outputs
@DEPTHER.register_module()
class DepthEncoderDecoder(BaseDepther):
"""Encoder Decoder depther.
EncoderDecoder typically consists of a backbone, an optional neck and a decode_head.
"""
def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
super(DepthEncoderDecoder, self).__init__(init_cfg)
if pretrained is not None:
assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
backbone.pretrained = pretrained
self.backbone = builder.build_backbone(backbone)
self._init_decode_head(decode_head)
if neck is not None:
self.neck = builder.build_neck(neck)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
assert self.with_decode_head
def _init_decode_head(self, decode_head):
"""Initialize ``decode_head``"""
self.decode_head = builder.build_head(decode_head)
self.align_corners = self.decode_head.align_corners
def extract_feat(self, img):
"""Extract features from images."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def encode_decode(self, img, img_metas, rescale=True, size=None):
"""Encode images with backbone and decode into a depth estimation
map of the same size as input."""
x = self.extract_feat(img)
out = self._decode_head_forward_test(x, img_metas)
# clamp the predicted depth to the valid range
out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
if rescale:
if size is None:
if img_metas is not None:
size = img_metas[0]["ori_shape"][:2]
else:
size = img.shape[2:]
out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
return out
def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
losses.update(add_prefix(loss_decode, "decode"))
return losses
def _decode_head_forward_test(self, x, img_metas):
"""Run forward function and calculate loss for decode head in
inference."""
depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
return depth_pred
def forward_dummy(self, img):
"""Dummy forward function."""
depth = self.encode_decode(img, None)
return depth
def forward_train(self, img, img_metas, depth_gt, **kwargs):
"""Forward function for training.
Args:
img (Tensor): Input images.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
depth_gt (Tensor): Ground-truth depth map,
used if the architecture supports the depth estimation task.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
x = self.extract_feat(img)
losses = dict()
# the last of x saves the info from neck
loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)
losses.update(loss_decode)
return losses
def whole_inference(self, img, img_meta, rescale, size=None):
"""Inference with full image."""
depth_pred = self.encode_decode(img, img_meta, rescale, size=size)
return depth_pred
def slide_inference(self, img, img_meta, rescale):
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
decode without padding.
"""
h_stride, w_stride = self.test_cfg.stride
h_crop, w_crop = self.test_cfg.crop_size
batch_size, _, h_img, w_img = img.size()
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
preds = img.new_zeros((batch_size, 1, h_img, w_img))
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
for h_idx in range(h_grids):
for w_idx in range(w_grids):
y1 = h_idx * h_stride
x1 = w_idx * w_stride
y2 = min(y1 + h_crop, h_img)
x2 = min(x1 + w_crop, w_img)
y1 = max(y2 - h_crop, 0)
x1 = max(x2 - w_crop, 0)
crop_img = img[:, :, y1:y2, x1:x2]
depth_pred = self.encode_decode(crop_img, img_meta, rescale)
preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
count_mat[:, :, y1:y2, x1:x2] += 1
assert (count_mat == 0).sum() == 0
if torch.onnx.is_in_onnx_export():
# cast count_mat to constant while exporting to ONNX
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
return preds
def inference(self, img, img_meta, rescale, size=None):
"""Inference with slide/whole style.
Args:
img (Tensor): The input image of shape (N, 3, H, W).
img_meta (dict): Image info dict where each dict has: 'img_shape',
'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
`depth/datasets/pipelines/formatting.py:Collect`.
rescale (bool): Whether rescale back to original shape.
Returns:
Tensor: The output depth map.
"""
assert self.test_cfg.mode in ["slide", "whole"]
ori_shape = img_meta[0]["ori_shape"]
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
if self.test_cfg.mode == "slide":
depth_pred = self.slide_inference(img, img_meta, rescale)
else:
depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
output = depth_pred
flip = img_meta[0]["flip"]
if flip:
flip_direction = img_meta[0]["flip_direction"]
assert flip_direction in ["horizontal", "vertical"]
if flip_direction == "horizontal":
output = output.flip(dims=(3,))
elif flip_direction == "vertical":
output = output.flip(dims=(2,))
return output
def simple_test(self, img, img_meta, rescale=True):
"""Simple test with single image."""
depth_pred = self.inference(img, img_meta, rescale)
if torch.onnx.is_in_onnx_export():
# our inference backend only support 4D output
depth_pred = depth_pred.unsqueeze(0)
return depth_pred
depth_pred = depth_pred.cpu().numpy()
# unravel batch dim
depth_pred = list(depth_pred)
return depth_pred
def aug_test(self, imgs, img_metas, rescale=True):
"""Test with augmentations.
Only rescale=True is supported.
"""
# aug_test rescale all imgs back to ori_shape for now
assert rescale
# to save memory, accumulate the augmented depth predictions in place
depth_pred = self.inference(imgs[0], img_metas[0], rescale)
for i in range(1, len(imgs)):
cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
depth_pred += cur_depth_pred
depth_pred /= len(imgs)
depth_pred = depth_pred.cpu().numpy()
# unravel batch dim
depth_pred = list(depth_pred)
return depth_pred
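
`slide_inference` tiles the input with a crop/stride grid, predicts depth per window, and averages overlapping windows via `count_mat`. A standalone sketch of the grid arithmetic follows; the image, crop and stride sizes are illustrative, not from this change.

# Sketch of the sliding-window grid used by slide_inference above
# (image/crop/stride sizes are illustrative).
h_img, w_img = 480, 640
h_crop, w_crop = 384, 384
h_stride, w_stride = 256, 256

h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
print(h_grids, w_grids)   # 2 2

for h_idx in range(h_grids):
    for w_idx in range(w_grids):
        y1, x1 = h_idx * h_stride, w_idx * w_stride
        y2, x2 = min(y1 + h_crop, h_img), min(x1 + w_crop, w_img)
        y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
        # every pixel is covered; overlapping regions are averaged via count_mat
        print((y1, y2), (x1, x2))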

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .gradientloss import GradientLoss
from .sigloss import SigLoss

View File

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...models.builder import LOSSES
@LOSSES.register_module()
class GradientLoss(nn.Module):
"""GradientLoss.
Adapted from https://www.cs.cornell.edu/projects/megadepth/
Args:
valid_mask (bool): Whether to filter out invalid gt (gt > 0). Default: True.
loss_weight (float): Weight of the loss. Default: 1.0.
max_depth (float): Max depth threshold used when filtering invalid gt. Default: None.
"""
def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
super(GradientLoss, self).__init__()
self.valid_mask = valid_mask
self.loss_weight = loss_weight
self.max_depth = max_depth
self.loss_name = loss_name
self.eps = 0.001 # avoid grad explode
def gradientloss(self, input, target):
input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]
gradient_loss = 0
for input, target in zip(input_downscaled, target_downscaled):
if self.valid_mask:
mask = target > 0
if self.max_depth is not None:
mask = torch.logical_and(target > 0, target <= self.max_depth)
N = torch.sum(mask)
else:
mask = torch.ones_like(target)
N = input.numel()
input_log = torch.log(input + self.eps)
target_log = torch.log(target + self.eps)
log_d_diff = input_log - target_log
log_d_diff = torch.mul(log_d_diff, mask)
v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
v_gradient = torch.mul(v_gradient, v_mask)
h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
h_gradient = torch.mul(h_gradient, h_mask)
gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N
return gradient_loss
def forward(self, depth_pred, depth_gt):
"""Forward function."""
gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
return gradient_loss
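
`GradientLoss` compares masked log-depth finite differences (offset 2) over several downscaled copies of the prediction, so a globally scaled prediction incurs almost no penalty. A tiny numeric sketch of the per-scale term; the values are made up.

# Numeric sketch of the per-scale term in GradientLoss above (values made up).
# A uniform 2x over-estimate gives a near-constant log-difference, so the
# gradient term is ~0: only relative (local) depth errors are penalized.
import torch

gt = torch.tensor([[1.0, 2.0, 4.0, 8.0],
                   [2.0, 4.0, 8.0, 16.0]])
pred = 2.0 * gt
mask = (gt > 0).float()
N = mask.sum()

log_diff = (torch.log(pred + 1e-3) - torch.log(gt + 1e-3)) * mask
v_grad = torch.abs(log_diff[0:-2, :] - log_diff[2:, :]) * mask[0:-2, :] * mask[2:, :]
h_grad = torch.abs(log_diff[:, 0:-2] - log_diff[:, 2:]) * mask[:, 0:-2] * mask[:, 2:]
print(((v_grad.sum() + h_grad.sum()) / N).item())   # ~0.0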

View File

@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from ...models.builder import LOSSES
@LOSSES.register_module()
class SigLoss(nn.Module):
"""SigLoss.
This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.
Args:
valid_mask (bool): Whether to filter out invalid gt (gt > 0). Default: True.
loss_weight (float): Weight of the loss. Default: 1.0.
max_depth (float): Max depth threshold used when filtering invalid gt. Default: None.
warm_up (bool): Whether to apply a simple warm-up stage to help convergence. Default: False.
warm_iter (int): The number of warm-up iterations. Default: 100.
"""
def __init__(
self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
):
super(SigLoss, self).__init__()
self.valid_mask = valid_mask
self.loss_weight = loss_weight
self.max_depth = max_depth
self.loss_name = loss_name
self.eps = 0.001 # avoid grad explode
# HACK: a hack implementation for warmup sigloss
self.warm_up = warm_up
self.warm_iter = warm_iter
self.warm_up_counter = 0
def sigloss(self, input, target):
if self.valid_mask:
valid_mask = target > 0
if self.max_depth is not None:
valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
input = input[valid_mask]
target = target[valid_mask]
if self.warm_up:
if self.warm_up_counter < self.warm_iter:
g = torch.log(input + self.eps) - torch.log(target + self.eps)
g = 0.15 * torch.pow(torch.mean(g), 2)
self.warm_up_counter += 1
return torch.sqrt(g)
g = torch.log(input + self.eps) - torch.log(target + self.eps)
Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
return torch.sqrt(Dg)
def forward(self, depth_pred, depth_gt):
"""Forward function."""
loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
return loss_depth
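
`SigLoss` is the scale-invariant log loss used by AdaBins: with g = log(pred) - log(gt), it returns sqrt(var(g) + 0.15 * mean(g)^2). A tiny numeric check; the values are made up.

# Numeric sketch of the SigLoss formula above (values are made up).
import torch

pred = torch.tensor([1.0, 2.0, 4.0])
gt = torch.tensor([1.0, 1.0, 2.0])

g = torch.log(pred + 1e-3) - torch.log(gt + 1e-3)
loss = torch.sqrt(torch.var(g) + 0.15 * torch.mean(g) ** 2)
print(loss.item())   # ~0.44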

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .wrappers import resize

View File

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
import torch.nn.functional as F
def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
if output_h > input_h or output_w > input_w:
if (
(output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)
):
warnings.warn(
f"When align_corners={align_corners}, "
"the output would be more aligned if "
f"input size {(input_h, input_w)} is `x+1` and "
f"out size {(output_h, output_w)} is `nx+1`"
)
return F.interpolate(input, size, scale_factor, mode, align_corners)
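
`resize` is a thin wrapper around `F.interpolate` that can warn when `align_corners=True` is used with an output size not of the form n*(input-1)+1. A short usage sketch; the import path is an assumption based on this package layout.

# Usage sketch for the resize wrapper above (shapes are illustrative).
# from <package>.ops import resize   # import path is an assumption
import torch

x = torch.randn(1, 1, 60, 80)
y = resize(x, size=(120, 160), mode="bilinear", align_corners=False)
print(y.shape)   # torch.Size([1, 1, 120, 160])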

File diff suppressed because one or more lines are too long