parent 3a7bf1ca4b
commit d5c376b5b3
@@ -6,8 +6,6 @@ dist/
**/.ipynb_checkpoints
**/.ipynb_checkpoints/**

**/notebooks

*.swp

.vscode/
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .backbones import *  # noqa: F403
from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
from .decode_heads import *  # noqa: F403
from .depther import *  # noqa: F403
from .losses import *  # noqa: F403
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .vision_transformer import DinoVisionTransformer
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.runner import BaseModule

from ..builder import BACKBONES


@BACKBONES.register_module()
class DinoVisionTransformer(BaseModule):
    """Vision Transformer."""

    def __init__(self, *args, **kwargs):
        super().__init__()
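The class above is only a registration stub. A rough usage sketch, not part of the commit (the diff viewer suppressed the file names, so the import path below is a guess):

    # Hypothetical usage, assuming this package is importable as `models`.
    from models.builder import build_backbone  # registry defined in the next file

    cfg = dict(type="DinoVisionTransformer")   # "type" selects the registered class
    backbone = build_backbone(cfg)             # remaining cfg keys become __init__ kwargs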
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
from mmcv.utils import Registry

MODELS = Registry("models", parent=MMCV_MODELS)
ATTENTION = Registry("attention", parent=MMCV_ATTENTION)


BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
DEPTHER = MODELS


def build_backbone(cfg):
    """Build backbone."""
    return BACKBONES.build(cfg)


def build_neck(cfg):
    """Build neck."""
    return NECKS.build(cfg)


def build_head(cfg):
    """Build head."""
    return HEADS.build(cfg)


def build_loss(cfg):
    """Build loss."""
    return LOSSES.build(cfg)


def build_depther(cfg, train_cfg=None, test_cfg=None):
    """Build depther."""
    if train_cfg is not None or test_cfg is not None:
        warnings.warn("train_cfg and test_cfg are deprecated, please specify them in the model config", UserWarning)
    assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field"
    assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field"
    return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
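Since all five aliases point at the same `Registry`, one "models" namespace resolves every component. A minimal sketch of the pattern (`ToyDepther` is invented for illustration):

    from mmcv.utils import Registry

    MODELS = Registry("models")
    DEPTHER = MODELS

    @DEPTHER.register_module()
    class ToyDepther:
        def __init__(self, scale=1.0, train_cfg=None, test_cfg=None):
            self.scale, self.train_cfg, self.test_cfg = scale, train_cfg, test_cfg

    # build() pops "type", looks up the class, and passes the rest as kwargs;
    # default_args merges train_cfg/test_cfg exactly as build_depther does above.
    model = DEPTHER.build(dict(type="ToyDepther", scale=2.0), default_args=dict(test_cfg=dict(mode="whole")))
    assert model.scale == 2.0 and model.test_cfg["mode"] == "whole"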
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dpt_head import DPTHead
from .linear_head import BNHead
@@ -0,0 +1,225 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import copy
from abc import ABCMeta, abstractmethod

import mmcv
import numpy as np
import torch
import torch.nn as nn
from mmcv.runner import BaseModule, auto_fp16, force_fp32

from ...ops import resize
from ..builder import build_loss


class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
    """Base class for depth decode heads.

    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        loss_decode (dict): Config of decode loss.
            Default: dict(type='SigLoss').
        sampler (dict|None): The config of depth map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        min_depth (float): Min depth in dataset setting.
            Default: 1e-3.
        max_depth (float): Max depth in dataset setting.
            Default: None.
        norm_cfg (dict|None): Config of norm layers.
            Default: None.
        classify (bool): Whether to predict depth in a cls.-reg. manner.
            Default: False.
        n_bins (int): The number of bins used in the cls. step.
            Default: 256.
        bins_strategy (str): The discretization strategy used in the cls. step.
            Default: 'UD'.
        norm_strategy (str): The norm strategy on the cls. probability
            distribution. Default: 'linear'
        scale_up (bool): Whether to predict depth in a scale-up manner.
            Default: False.
    """

    def __init__(
        self,
        in_channels,
        channels=96,
        conv_cfg=None,
        act_cfg=dict(type="ReLU"),
        loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
        norm_cfg=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
        norm_strategy="linear",
        scale_up=False,
    ):
        super(DepthBaseDecodeHead, self).__init__()

        self.in_channels = in_channels
        self.channels = channels
        self.conv_cfg = conv_cfg
        self.act_cfg = act_cfg
        if isinstance(loss_decode, dict):
            self.loss_decode = build_loss(loss_decode)
        elif isinstance(loss_decode, (list, tuple)):
            self.loss_decode = nn.ModuleList()
            for loss in loss_decode:
                self.loss_decode.append(build_loss(loss))
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.norm_cfg = norm_cfg
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up

        if self.classify:
            assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
            assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"

            self.bins_strategy = bins_strategy
            self.norm_strategy = norm_strategy
            self.softmax = nn.Softmax(dim=1)
            self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
        else:
            self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)

        self.fp16_enabled = False
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def extra_repr(self):
        """Extra repr."""
        s = f"align_corners={self.align_corners}"
        return s

    @auto_fp16()
    @abstractmethod
    def forward(self, inputs, img_metas):
        """Placeholder of forward function."""
        pass

    def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): GT depth.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        depth_pred = self.forward(inputs, img_metas)
        losses = self.losses(depth_pred, depth_gt)

        log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
        losses.update(**log_imgs)

        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output depth map.
        """
        return self.forward(inputs, img_metas)

    def depth_pred(self, feat):
        """Predict depth for each pixel."""
        if self.classify:
            logit = self.conv_depth(feat)

            if self.bins_strategy == "UD":
                bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
            elif self.bins_strategy == "SID":
                bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)

            # following AdaBins, default linear
            if self.norm_strategy == "linear":
                logit = torch.relu(logit)
                eps = 0.1
                logit = logit + eps
                logit = logit / logit.sum(dim=1, keepdim=True)
            elif self.norm_strategy == "softmax":
                logit = torch.softmax(logit, dim=1)
            elif self.norm_strategy == "sigmoid":
                logit = torch.sigmoid(logit)
                logit = logit / logit.sum(dim=1, keepdim=True)

            # expectation over bin centers: (N, n_bins, H, W) x (n_bins,) -> (N, 1, H, W)
            output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)

        else:
            if self.scale_up:
                output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
            else:
                output = self.relu(self.conv_depth(feat)) + self.min_depth
        return output

    @force_fp32(apply_to=("depth_pred",))
    def losses(self, depth_pred, depth_gt):
        """Compute depth loss."""
        loss = dict()
        depth_pred = resize(
            input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
        )
        if not isinstance(self.loss_decode, nn.ModuleList):
            losses_decode = [self.loss_decode]
        else:
            losses_decode = self.loss_decode
        for loss_decode in losses_decode:
            if loss_decode.loss_name not in loss:
                loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
            else:
                loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
        return loss

    def log_images(self, img_path, depth_pred, depth_gt, img_meta):
        """Denormalize the input image tensor and pack it with the normalized
        predicted and ground-truth depth maps for logging."""
        show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
        show_img = show_img.numpy().astype(np.float32)
        show_img = mmcv.imdenormalize(
            show_img,
            img_meta["img_norm_cfg"]["mean"],
            img_meta["img_norm_cfg"]["std"],
            img_meta["img_norm_cfg"]["to_rgb"],
        )
        show_img = np.clip(show_img, 0, 255)
        show_img = show_img.astype(np.uint8)
        show_img = show_img[:, :, ::-1]
        show_img = show_img.transpose(0, 2, 1)
        show_img = show_img.transpose(1, 0, 2)

        depth_pred = depth_pred / torch.max(depth_pred)
        depth_gt = depth_gt / torch.max(depth_gt)

        depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
        depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())

        return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
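For clarity, a standalone sketch (shapes invented) of what the classification-regression path in `depth_pred` computes: per-pixel probabilities over `n_bins` depth bins, then the expected depth under those probabilities.

    import torch

    n, n_bins, h, w = 2, 256, 8, 8
    min_depth, max_depth = 1e-3, 10.0

    logit = torch.relu(torch.randn(n, n_bins, h, w)) + 0.1        # "linear" norm strategy
    prob = logit / logit.sum(dim=1, keepdim=True)                 # per-pixel distribution over bins
    bins = torch.linspace(min_depth, max_depth, n_bins)           # "UD": uniformly discretized bins
    depth = torch.einsum("ikmn,k->imn", prob, bins).unsqueeze(1)  # expectation -> (N, 1, H, W)
    assert depth.shape == (n, 1, h, w)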
@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, Linear, build_activation_layer
from mmcv.runner import BaseModule

from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead


class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
        return x


class HeadDepth(nn.Module):
    def __init__(self, features):
        super(HeadDepth, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.head(x)
        return x


class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock: process the cls_token in the ViT backbone output
    and rearrange the feature vector to a feature map.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(
        self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
    ):
        super(ReassembleBlocks, self).__init__(init_cfg)

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [
                ConvModule(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    act_cfg=None,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
                )

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit, pre-activate residual unit.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_cfg (dict): dictionary to construct and config activation layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        stride (int): stride of the first block. Default: 1.
        dilation (int): dilation rate for conv layers. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
        super(PreActResidualConvUnit, self).__init__(init_cfg)

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=("act", "conv", "norm"),
        )

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merge feature maps from different stages.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for normalization layer.
        expand (bool): Whether to expand the channels in the post process
            block. Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
        super(FeatureFusionBlock, self).__init__(init_cfg)

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x


@HEADS.register_module()
class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of the post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether to expand the channels in the post
            process block. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs
    ):
        super(DPTHead, self).__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        self.post_process_channels = [
            channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)

    def forward(self, inputs, img_metas):
        assert len(inputs) == self.num_reassemble_blocks
        x = [inp for inp in inputs]
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.depth_pred(out)
        return out
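A toy illustration (shapes invented) of the 'project' readout in `ReassembleBlocks.forward`: the class token is broadcast to every patch token, concatenated along channels, and projected back to the original width.

    import torch
    import torch.nn as nn

    c, h, w = 768, 14, 14
    x = torch.randn(1, c, h, w)            # patch tokens as a feature map
    cls_token = torch.randn(1, c)          # ViT class token

    tokens = x.flatten(2).permute(0, 2, 1)            # (1, H*W, C)
    readout = cls_token.unsqueeze(1).expand_as(tokens)
    proj = nn.Sequential(nn.Linear(2 * c, c), nn.GELU())
    fused = proj(torch.cat((tokens, readout), -1))    # (1, H*W, C)
    fused = fused.permute(0, 2, 1).reshape(1, c, h, w)
    assert fused.shape == x.shape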
@@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...ops import resize
from ..builder import HEADS
from .decode_head import DepthBaseDecodeHead


@HEADS.register_module()
class BNHead(DepthBaseDecodeHead):
    """Just a batchnorm."""

    def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
        super().__init__(**kwargs)
        self.input_transform = input_transform
        self.in_index = in_index
        self.upsample = upsample
        # self.bn = nn.SyncBatchNorm(self.in_channels)
        if self.classify:
            self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
        else:
            self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            Tensor: The transformed inputs
        """

        if "concat" in self.input_transform:
            inputs = [inputs[i] for i in self.in_index]
            if "resize" in self.input_transform:
                inputs = [
                    resize(
                        input=x,
                        size=[s * self.upsample for s in inputs[0].shape[2:]],
                        mode="bilinear",
                        align_corners=self.align_corners,
                    )
                    for x in inputs
                ]
            inputs = torch.cat(inputs, dim=1)
        elif self.input_transform == "multiple_select":
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def _forward_feature(self, inputs, img_metas=None, **kwargs):
        """Forward function for feature maps before predicting depth with
        ``self.conv_depth``.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
        Returns:
            feats (Tensor): A tensor of shape (batch_size, self.channels,
                H, W) which is the feature map of the last decoder layer.
        """
        # accept lists (for cls token)
        inputs = list(inputs)
        for i, x in enumerate(inputs):
            if len(x) == 2:
                x, cls_token = x[0], x[1]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                cls_token = cls_token[:, :, None, None].expand_as(x)
                inputs[i] = torch.cat((x, cls_token), 1)
            else:
                x = x[0]
                if len(x.shape) == 2:
                    x = x[:, :, None, None]
                inputs[i] = x
        x = self._transform_inputs(inputs)
        # feats = self.bn(x)
        return x

    def forward(self, inputs, img_metas=None, **kwargs):
        """Forward function."""
        output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        output = self.depth_pred(output)

        return output
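A small sketch (dummy tensors) of the default `resize_concat` transform above: the selected feature levels are upsampled to a common size and concatenated along channels before `conv_depth` consumes them.

    import torch
    import torch.nn.functional as F

    feats = [torch.randn(1, 384, s, s) for s in (32, 16, 8, 8)]   # made-up multi-level features
    target = feats[0].shape[2:]
    resized = [F.interpolate(f, size=target, mode="bilinear", align_corners=False) for f in feats]
    fused = torch.cat(resized, dim=1)    # (1, 4*384, 32, 32)
    assert fused.shape == (1, 1536, 32, 32)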
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .base import BaseDepther
from .encoder_decoder import DepthEncoderDecoder
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import torch
import torch.distributed as dist
from mmcv.runner import BaseModule, auto_fp16


class BaseDepther(BaseModule, metaclass=ABCMeta):
    """Base class for depther."""

    def __init__(self, init_cfg=None):
        super(BaseDepther, self).__init__(init_cfg)
        self.fp16_enabled = False

    @property
    def with_neck(self):
        """bool: whether the depther has a neck"""
        return hasattr(self, "neck") and self.neck is not None

    @property
    def with_auxiliary_head(self):
        """bool: whether the depther has an auxiliary head"""
        return hasattr(self, "auxiliary_head") and self.auxiliary_head is not None

    @property
    def with_decode_head(self):
        """bool: whether the depther has a decode head"""
        return hasattr(self, "decode_head") and self.decode_head is not None

    @abstractmethod
    def extract_feat(self, imgs):
        """Placeholder for extracting features from images."""
        pass

    @abstractmethod
    def encode_decode(self, img, img_metas):
        """Placeholder for encoding images with a backbone and decoding them
        into a depth map of the same size as the input."""
        pass

    @abstractmethod
    def forward_train(self, imgs, img_metas, **kwargs):
        """Placeholder for the forward function for training."""
        pass

    @abstractmethod
    def simple_test(self, img, img_meta, **kwargs):
        """Placeholder for single image test."""
        pass

    @abstractmethod
    def aug_test(self, imgs, img_metas, **kwargs):
        """Placeholder for augmentation test."""
        pass

    def forward_test(self, imgs, img_metas, **kwargs):
        """
        Args:
            imgs (List[Tensor]): the outer list indicates test-time
                augmentations and the inner Tensor should have shape NxCxHxW,
                which contains all images in the batch.
            img_metas (List[List[dict]]): the outer list indicates test-time
                augs (multiscale, flip, etc.) and the inner list indicates
                images in a batch.
        """
        for var, name in [(imgs, "imgs"), (img_metas, "img_metas")]:
            if not isinstance(var, list):
                raise TypeError(f"{name} must be a list, but got {type(var)}")
        num_augs = len(imgs)
        if num_augs != len(img_metas):
            raise ValueError(f"num of augmentations ({len(imgs)}) != num of image meta ({len(img_metas)})")
        # all images in the same aug batch must have the same ori_shape and
        # pad_shape
        for img_meta in img_metas:
            ori_shapes = [_["ori_shape"] for _ in img_meta]
            assert all(shape == ori_shapes[0] for shape in ori_shapes)
            img_shapes = [_["img_shape"] for _ in img_meta]
            assert all(shape == img_shapes[0] for shape in img_shapes)
            pad_shapes = [_["pad_shape"] for _ in img_meta]
            assert all(shape == pad_shapes[0] for shape in pad_shapes)

        if num_augs == 1:
            return self.simple_test(imgs[0], img_metas[0], **kwargs)
        else:
            return self.aug_test(imgs, img_metas, **kwargs)

    @auto_fp16(apply_to=("img",))
    def forward(self, img, img_metas, return_loss=True, **kwargs):
        """Calls either :func:`forward_train` or :func:`forward_test` depending
        on whether ``return_loss`` is ``True``.

        Note this setting will change the expected inputs. When
        ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
        and List[dict]), and when ``return_loss=False``, img and img_meta
        should be double nested (i.e. List[Tensor], List[List[dict]]), with
        the outer list indicating test time augmentations.
        """
        if return_loss:
            return self.forward_train(img, img_metas, **kwargs)
        else:
            return self.forward_test(img, img_metas, **kwargs)

    def train_step(self, data_batch, optimizer, **kwargs):
        """The iteration step during training.

        This method defines an iteration step during training, except for the
        back propagation and optimizer updating, which are done in an optimizer
        hook. Note that in some complicated cases or models, the whole process
        including back propagation and optimizer updating is also defined in
        this method, such as GAN.

        Args:
            data_batch (dict): The output of the dataloader.
            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
                the runner is passed to ``train_step()``. This argument is
                unused and reserved.

        Returns:
            dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
                ``num_samples``.
                ``loss`` is a tensor for back propagation, which can be a
                weighted sum of multiple losses.
                ``log_vars`` contains all the variables to be sent to the
                logger.
                ``num_samples`` indicates the batch size (when the model is
                DDP, it means the batch size on each GPU), which is used for
                averaging the logs.
        """
        losses = self(**data_batch)

        # split losses and images
        real_losses = {}
        log_imgs = {}
        for k, v in losses.items():
            if "img" in k:
                log_imgs[k] = v
            else:
                real_losses[k] = v

        loss, log_vars = self._parse_losses(real_losses)

        outputs = dict(loss=loss, log_vars=log_vars, num_samples=len(data_batch["img_metas"]), log_imgs=log_imgs)

        return outputs

    def val_step(self, data_batch, **kwargs):
        """The iteration step during validation.

        This method shares the same signature as :func:`train_step`, but is
        used during val epochs. Note that the evaluation after training epochs
        is not implemented with this method, but with an evaluation hook.
        """
        output = self(**data_batch, **kwargs)
        return output

    @staticmethod
    def _parse_losses(losses):
        """Parse the raw outputs (losses) of the network.

        Args:
            losses (dict): Raw output of the network, which usually contains
                losses and other necessary information.

        Returns:
            tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
                which may be a weighted sum of all losses, log_vars contains
                all the variables to be sent to the logger.
        """
        log_vars = OrderedDict()
        for loss_name, loss_value in losses.items():
            if isinstance(loss_value, torch.Tensor):
                log_vars[loss_name] = loss_value.mean()
            elif isinstance(loss_value, list):
                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
            else:
                raise TypeError(f"{loss_name} is not a tensor or list of tensors")

        loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)

        log_vars["loss"] = loss
        for loss_name, loss_value in log_vars.items():
            # reduce the loss when doing distributed training
            if dist.is_available() and dist.is_initialized():
                loss_value = loss_value.data.clone()
                dist.all_reduce(loss_value.div_(dist.get_world_size()))
            log_vars[loss_name] = loss_value.item()

        return loss, log_vars
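To make `_parse_losses` concrete, a toy run (values and keys invented): every entry is mean-reduced into `log_vars`, and only keys containing "loss" are summed into the backpropagated total.

    import torch

    losses = {
        "decode.loss_depth": torch.tensor([0.5, 0.7]),  # mean-reduced to 0.6
        "decode.loss_grad": torch.tensor(0.2),
        "decode.abs_rel": torch.tensor(0.11),           # logged, but not part of the total
    }
    log_vars = {k: v.mean() for k, v in losses.items()}
    loss = sum(v for k, v in log_vars.items() if "loss" in k)
    print(float(loss))  # 0.8 = 0.6 + 0.2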
@@ -0,0 +1,236 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from ...models import builder
from ...models.builder import DEPTHER
from ...ops import resize
from .base import BaseDepther


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """

    outputs = dict()
    for name, value in inputs.items():
        outputs[f"{prefix}.{name}"] = value

    return outputs


@DEPTHER.register_module()
class DepthEncoderDecoder(BaseDepther):
    """Encoder-decoder depther.

    EncoderDecoder typically consists of a backbone, (neck) and decode_head.
    """

    def __init__(self, backbone, decode_head, neck=None, train_cfg=None, test_cfg=None, pretrained=None, init_cfg=None):
        super(DepthEncoderDecoder, self).__init__(init_cfg)
        if pretrained is not None:
            assert backbone.get("pretrained") is None, "both backbone and depther set pretrained weight"
            backbone.pretrained = pretrained
        self.backbone = builder.build_backbone(backbone)
        self._init_decode_head(decode_head)

        if neck is not None:
            self.neck = builder.build_neck(neck)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head):
        """Initialize ``decode_head``"""
        self.decode_head = builder.build_head(decode_head)
        self.align_corners = self.decode_head.align_corners

    def extract_feat(self, img):
        """Extract features from images."""
        x = self.backbone(img)
        if self.with_neck:
            x = self.neck(x)
        return x

    def encode_decode(self, img, img_metas, rescale=True, size=None):
        """Encode images with the backbone and decode into a depth estimation
        map of the same size as the input."""
        x = self.extract_feat(img)
        out = self._decode_head_forward_test(x, img_metas)
        # clamp the predicted depth to the valid range
        out = torch.clamp(out, min=self.decode_head.min_depth, max=self.decode_head.max_depth)
        if rescale:
            if size is None:
                if img_metas is not None:
                    size = img_metas[0]["ori_shape"][:2]
                else:
                    size = img.shape[2:]
            out = resize(input=out, size=size, mode="bilinear", align_corners=self.align_corners)
        return out

    def _decode_head_forward_train(self, img, x, img_metas, depth_gt, **kwargs):
        """Run the forward function and calculate the loss for the decode head
        in training."""
        losses = dict()
        loss_decode = self.decode_head.forward_train(img, x, img_metas, depth_gt, self.train_cfg, **kwargs)
        losses.update(add_prefix(loss_decode, "decode"))
        return losses

    def _decode_head_forward_test(self, x, img_metas):
        """Run the forward function of the decode head in inference."""
        depth_pred = self.decode_head.forward_test(x, img_metas, self.test_cfg)
        return depth_pred

    def forward_dummy(self, img):
        """Dummy forward function."""
        depth = self.encode_decode(img, None)

        return depth

    def forward_train(self, img, img_metas, depth_gt, **kwargs):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            depth_gt (Tensor): Depth gt
                used if the architecture supports the depth estimation task.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """

        x = self.extract_feat(img)

        losses = dict()

        # the last of x saves the info from the neck
        loss_decode = self._decode_head_forward_train(img, x, img_metas, depth_gt, **kwargs)

        losses.update(loss_decode)

        return losses

    def whole_inference(self, img, img_meta, rescale, size=None):
        """Inference with the full image."""
        depth_pred = self.encode_decode(img, img_meta, rescale, size=size)

        return depth_pred

    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """

        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
        preds = img.new_zeros((batch_size, 1, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                depth_pred = self.encode_decode(crop_img, img_meta, rescale)
                preds += F.pad(depth_pred, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))

                count_mat[:, :, y1:y2, x1:x2] += 1
        assert (count_mat == 0).sum() == 0
        if torch.onnx.is_in_onnx_export():
            # cast count_mat to a constant while exporting to ONNX
            count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat
        return preds

    def inference(self, img, img_meta, rescale, size=None):
        """Inference with slide/whole style.

        Args:
            img (Tensor): The input image of shape (N, 3, H, W).
            img_meta (dict): Image info dict where each dict has: 'img_shape',
                'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `depth/datasets/pipelines/formatting.py:Collect`.
            rescale (bool): Whether rescale back to original shape.

        Returns:
            Tensor: The output depth map.
        """

        assert self.test_cfg.mode in ["slide", "whole"]
        ori_shape = img_meta[0]["ori_shape"]
        assert all(_["ori_shape"] == ori_shape for _ in img_meta)
        if self.test_cfg.mode == "slide":
            depth_pred = self.slide_inference(img, img_meta, rescale)
        else:
            depth_pred = self.whole_inference(img, img_meta, rescale, size=size)
        output = depth_pred
        flip = img_meta[0]["flip"]
        if flip:
            flip_direction = img_meta[0]["flip_direction"]
            assert flip_direction in ["horizontal", "vertical"]
            if flip_direction == "horizontal":
                output = output.flip(dims=(3,))
            elif flip_direction == "vertical":
                output = output.flip(dims=(2,))

        return output

    def simple_test(self, img, img_meta, rescale=True):
        """Simple test with a single image."""
        depth_pred = self.inference(img, img_meta, rescale)
        if torch.onnx.is_in_onnx_export():
            # our inference backend only supports 4D output
            depth_pred = depth_pred.unsqueeze(0)
            return depth_pred
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescales all imgs back to ori_shape for now
        assert rescale
        # to save memory, we accumulate the augmented depth predictions in place
        depth_pred = self.inference(imgs[0], img_metas[0], rescale)
        for i in range(1, len(imgs)):
            cur_depth_pred = self.inference(imgs[i], img_metas[i], rescale, size=depth_pred.shape[-2:])
            depth_pred += cur_depth_pred
        depth_pred /= len(imgs)
        depth_pred = depth_pred.cpu().numpy()
        # unravel batch dim
        depth_pred = list(depth_pred)
        return depth_pred
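A standalone sketch (toy sizes) of the window-grid arithmetic in `slide_inference`: windows of size `crop` are placed every `stride` pixels, the last one is shifted back inside the image, and overlapping predictions are later averaged via `count_mat`.

    h_img, h_crop, h_stride = 10, 4, 3  # invented sizes
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    windows = []
    for h_idx in range(h_grids):
        y1 = h_idx * h_stride
        y2 = min(y1 + h_crop, h_img)
        y1 = max(y2 - h_crop, 0)        # shift the last window back inside
        windows.append((y1, y2))
    print(windows)  # [(0, 4), (3, 7), (6, 10)]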
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .gradientloss import GradientLoss
from .sigloss import SigLoss
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...models.builder import LOSSES


@LOSSES.register_module()
class GradientLoss(nn.Module):
    """GradientLoss.

    Adapted from https://www.cs.cornell.edu/projects/megadepth/

    Args:
        valid_mask (bool): Whether to filter invalid gt (gt > 0). Default: True.
        loss_weight (float): Weight of the loss. Default: 1.0.
        max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
        loss_name (str): Name of the loss item. Default: 'loss_grad'.
    """

    def __init__(self, valid_mask=True, loss_weight=1.0, max_depth=None, loss_name="loss_grad"):
        super(GradientLoss, self).__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name

        self.eps = 0.001  # avoid log(0) and gradient explosion

    def gradientloss(self, input, target):
        input_downscaled = [input] + [input[:: 2 * i, :: 2 * i] for i in range(1, 4)]
        target_downscaled = [target] + [target[:: 2 * i, :: 2 * i] for i in range(1, 4)]

        gradient_loss = 0
        for input, target in zip(input_downscaled, target_downscaled):
            if self.valid_mask:
                mask = target > 0
                if self.max_depth is not None:
                    mask = torch.logical_and(target > 0, target <= self.max_depth)
                N = torch.sum(mask)
            else:
                mask = torch.ones_like(target)
                N = input.numel()
            input_log = torch.log(input + self.eps)
            target_log = torch.log(target + self.eps)
            log_d_diff = input_log - target_log

            log_d_diff = torch.mul(log_d_diff, mask)

            v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
            v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
            v_gradient = torch.mul(v_gradient, v_mask)

            h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
            h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
            h_gradient = torch.mul(h_gradient, h_mask)

            gradient_loss += (torch.sum(h_gradient) + torch.sum(v_gradient)) / N

        return gradient_loss

    def forward(self, depth_pred, depth_gt):
        """Forward function."""

        gradient_loss = self.loss_weight * self.gradientloss(depth_pred, depth_gt)
        return gradient_loss
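A compact sketch (random tensors) of one scale of the gradient term above: differences of masked log-depth errors two pixels apart, kept only where both endpoints are valid.

    import torch

    pred, gt = torch.rand(64, 64) + 0.1, torch.rand(64, 64) + 0.1  # dummy depth maps
    eps = 0.001
    mask = (gt > 0).float()
    log_diff = (torch.log(pred + eps) - torch.log(gt + eps)) * mask

    v_grad = torch.abs(log_diff[:-2, :] - log_diff[2:, :]) * mask[:-2, :] * mask[2:, :]
    h_grad = torch.abs(log_diff[:, :-2] - log_diff[:, 2:]) * mask[:, :-2] * mask[:, 2:]
    scale_loss = (h_grad.sum() + v_grad.sum()) / mask.sum()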
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from ...models.builder import LOSSES


@LOSSES.register_module()
class SigLoss(nn.Module):
    """SigLoss.

    This follows `AdaBins <https://arxiv.org/abs/2011.14141>`_.

    Args:
        valid_mask (bool): Whether to filter invalid gt (gt > 0). Default: True.
        loss_weight (float): Weight of the loss. Default: 1.0.
        max_depth (int): When filtering invalid gt, set a max threshold. Default: None.
        warm_up (bool): A simple warm-up stage to help convergence. Default: False.
        warm_iter (int): The number of warm-up iterations. Default: 100.
        loss_name (str): Name of the loss item. Default: 'sigloss'.
    """

    def __init__(
        self, valid_mask=True, loss_weight=1.0, max_depth=None, warm_up=False, warm_iter=100, loss_name="sigloss"
    ):
        super(SigLoss, self).__init__()
        self.valid_mask = valid_mask
        self.loss_weight = loss_weight
        self.max_depth = max_depth
        self.loss_name = loss_name

        self.eps = 0.001  # avoid log(0) and gradient explosion

        # HACK: a hack implementation for warm-up sigloss
        self.warm_up = warm_up
        self.warm_iter = warm_iter
        self.warm_up_counter = 0

    def sigloss(self, input, target):
        if self.valid_mask:
            valid_mask = target > 0
            if self.max_depth is not None:
                valid_mask = torch.logical_and(target > 0, target <= self.max_depth)
            input = input[valid_mask]
            target = target[valid_mask]

        if self.warm_up:
            if self.warm_up_counter < self.warm_iter:
                g = torch.log(input + self.eps) - torch.log(target + self.eps)
                g = 0.15 * torch.pow(torch.mean(g), 2)
                self.warm_up_counter += 1
                return torch.sqrt(g)

        g = torch.log(input + self.eps) - torch.log(target + self.eps)
        Dg = torch.var(g) + 0.15 * torch.pow(torch.mean(g), 2)
        return torch.sqrt(Dg)

    def forward(self, depth_pred, depth_gt):
        """Forward function."""

        loss_depth = self.loss_weight * self.sigloss(depth_pred, depth_gt)
        return loss_depth
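The scale-invariant log loss computed above, written out once with dummy tensors: `g` is the per-pixel log error, and the `0.15` term penalizes its mean as in AdaBins.

    import torch

    pred = torch.rand(1000) + 0.1  # dummy valid-pixel depths
    gt = torch.rand(1000) + 0.1
    eps = 0.001

    g = torch.log(pred + eps) - torch.log(gt + eps)
    silog = torch.sqrt(torch.var(g) + 0.15 * torch.mean(g) ** 2)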
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .wrappers import resize
@@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

import torch.nn.functional as F


def resize(input, size=None, scale_factor=None, mode="nearest", align_corners=None, warning=False):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if (
                    (output_h > 1 and output_w > 1 and input_h > 1 and input_w > 1)
                    and (output_h - 1) % (input_h - 1)
                    and (output_w - 1) % (input_w - 1)
                ):
                    warnings.warn(
                        f"When align_corners={align_corners}, "
                        "the output would be more aligned if "
                        f"input size {(input_h, input_w)} is `x+1` and "
                        f"out size {(output_h, output_w)} is `nx+1`"
                    )
    return F.interpolate(input, size, scale_factor, mode, align_corners)
File diff suppressed because one or more lines are too long