Expose DPT depth models via torch.hub.load() (#238)

Add streamlined versions of the depth models, without the mmcv dependency, so that they can be loaded directly via torch.hub.load().

parent 82185b17a8
commit e7df9fc95d
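A minimal load sketch for reference (hedged: the "facebookresearch/dinov2" repo path is an assumption here; the *_dd entrypoint names come from the hubconf.py hunk at the end of this diff):

import torch

# Load a DPT depther on the ViT-B/14 backbone with the default NYU head weights.
# The other entrypoints (dinov2_vits14_dd, dinov2_vitl14_dd, dinov2_vitg14_dd)
# are requested the same way.
depther = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_dd", pretrained=True)
depther.eval()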
@@ -3,5 +3,5 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

-from .decode_heads import BNHead
+from .decode_heads import BNHead, DPTHead
from .encoder_decoder import DepthEncoderDecoder
@@ -4,6 +4,9 @@
# found in the LICENSE file in the root directory of this source tree.

import copy
+from functools import partial
+import math
+import warnings

import torch
import torch.nn as nn
@@ -29,6 +32,8 @@ class DepthBaseDecodeHead(nn.Module):
    Args:
        in_channels (List): Input channels.
        channels (int): Channels after modules, before conv_depth.
+        conv_layer (nn.Module): Conv layers. Default: None.
+        act_layer (nn.Module): Activation layers. Default: nn.ReLU.
        loss_decode (dict): Config of decode loss.
            Default: ().
        sampler (dict|None): The config of depth map sampler.
@@ -39,7 +44,7 @@ class DepthBaseDecodeHead(nn.Module):
            Default: 1e-3.
        max_depth (int): Max depth in dataset setting.
            Default: None.
-        norm_cfg (dict|None): Config of norm layers.
+        norm_layer (nn.Module|None): Norm layers.
            Default: None.
        classify (bool): Whether to predict depth in a cls.-reg. manner.
            Default: False.
@@ -56,13 +61,15 @@ class DepthBaseDecodeHead(nn.Module):
    def __init__(
        self,
        in_channels,
+        conv_layer=None,
+        act_layer=nn.ReLU,
        channels=96,
        loss_decode=(),
        sampler=None,
        align_corners=False,
        min_depth=1e-3,
        max_depth=None,
-        norm_cfg=None,
+        norm_layer=None,
        classify=False,
        n_bins=256,
        bins_strategy="UD",
@@ -73,11 +80,13 @@ class DepthBaseDecodeHead(nn.Module):

        self.in_channels = in_channels
        self.channels = channels
+        self.conv_layer = conv_layer
+        self.act_layer = act_layer
        self.loss_decode = loss_decode
        self.align_corners = align_corners
        self.min_depth = min_depth
        self.max_depth = max_depth
-        self.norm_cfg = norm_cfg
+        self.norm_layer = norm_layer
        self.classify = classify
        self.n_bins = n_bins
        self.scale_up = scale_up
@@ -285,3 +294,454 @@ class BNHead(DepthBaseDecodeHead):
        output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
        output = self.depth_pred(output)
        return output

class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    In the original mmcv implementation it is built from `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`; in this streamlined
    copy the layers are passed in directly as ``conv_layer``, ``norm_layer``
    and ``act_layer``.

    Besides, we add some additional features in this module:
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. An explicit "reflect" padding mode is supported on top of the "zeros"
       and "circular" modes handled by the conv layer itself.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_layer. Bias will be set as True if `norm_layer` is None,
            otherwise False. Default: "auto".
        conv_layer (nn.Module): Convolution layer. Default: nn.Conv2d.
        norm_layer (nn.Module): Normalization layer. Default: None.
        act_layer (nn.Module): Activation layer. Default: nn.ReLU.
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether to use spectral norm on the conv
            layer. Default: False.
        padding_mode (str): If the `padding_mode` is not natively supported by
            `Conv2d` ("zeros" and "circular"), an explicit padding layer is
            used instead; currently only "reflect" is implemented this way.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = "conv_block"

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias="auto",
        conv_layer=nn.Conv2d,
        norm_layer=None,
        act_layer=nn.ReLU,
        inplace=True,
        with_spectral_norm=False,
        padding_mode="zeros",
        order=("conv", "norm", "act"),
    ):
        super(ConvModule, self).__init__()
        official_padding_mode = ["zeros", "circular"]
        self.conv_layer = conv_layer
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(["conv", "norm", "act"])

        self.with_norm = norm_layer is not None
        self.with_activation = act_layer is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == "auto":
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            if padding_mode == "reflect":
                padding_layer = nn.ReflectionPad2d
            else:
                raise AssertionError(f"Unsupported padding mode: {padding_mode}")
            self.pad = padding_layer(padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = self.conv_layer(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)
        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index("norm") > order.index("conv"):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            norm = norm_layer(num_features=norm_channels)
            # register under a name that does not shadow the `norm` property below
            self.norm_name = "bn"
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                from torch.nn.modules.batchnorm import _BatchNorm
                from torch.nn.modules.instancenorm import _InstanceNorm

                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn("Unnecessary conv bias before batch/instance norm")
        else:
            self.norm_name = None
        # build activation layer
        if self.with_activation:
            # act_layer is a class here; nn.Tanh, nn.PReLU, nn.Sigmoid and
            # nn.GELU take no 'inplace' argument
            if not issubclass(act_layer, (nn.Tanh, nn.PReLU, nn.Sigmoid, nn.GELU)):
                act_layer = partial(act_layer, inplace=inplace)
            self.activate = act_layer()

        # Use msra init by default
        self.init_weights()
    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None
    def init_weights(self):
        # 1. It is mainly for customized conv layers with their own
        #    initialization manners by calling their own ``init_weights()``,
        #    and we do not want ConvModule to override the initialization.
        # 2. For customized conv layers without their own initialization
        #    manners (that is, they don't have their own ``init_weights()``)
        #    and PyTorch's conv layers, they will be initialized by
        #    this method with default ``kaiming_init``.
        # Note: For PyTorch's conv layers, they will be overwritten by our
        #    initialization implementation using default ``kaiming_init``.
        if not hasattr(self.conv, "init_weights"):
            # self.act_layer is stored as a class, so use issubclass, not isinstance
            if self.with_activation and issubclass(self.act_layer, nn.LeakyReLU):
                nonlinearity = "leaky_relu"
                a = 0.01  # default negative_slope of nn.LeakyReLU
            else:
                nonlinearity = "relu"
                a = 0
            if hasattr(self.conv, "weight") and self.conv.weight is not None:
                nn.init.kaiming_normal_(self.conv.weight, a=a, mode="fan_out", nonlinearity=nonlinearity)
            if hasattr(self.conv, "bias") and self.conv.bias is not None:
                nn.init.constant_(self.conv.bias, 0)
        if self.with_norm:
            if hasattr(self.norm, "weight") and self.norm.weight is not None:
                nn.init.constant_(self.norm.weight, 1)
            if hasattr(self.norm, "bias") and self.norm.bias is not None:
                nn.init.constant_(self.norm.bias, 0)
    def forward(self, x, activate=True, norm=True):
        for layer in self.order:
            if layer == "conv":
                if self.with_explicit_padding:
                    x = self.pad(x)
                x = self.conv(x)
            elif layer == "norm" and norm and self.with_norm:
                x = self.norm(x)
            elif layer == "act" and activate and self.with_activation:
                x = self.activate(x)
        return x

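# Illustrative note (not part of this commit's API surface): with the default
# order, ConvModule(64, 128, 3, padding=1, norm_layer=nn.BatchNorm2d) applies
# conv -> norm -> act, and bias="auto" resolves to False because a norm is set;
# with order=("act", "conv", "norm") the same module behaves as a pre-activation
# block, which is how PreActResidualConvUnit below uses it.
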
class Interpolate(nn.Module):
    def __init__(self, scale_factor, mode, align_corners=False):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
        return x


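# HeadDepth halves the channel width, upsamples 2x, then maps to 32 channels and
# finally to a single depth channel.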
class HeadDepth(nn.Module):
    def __init__(self, features):
        super(HeadDepth, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        x = self.head(x)
        return x


class ReassembleBlocks(nn.Module):
    """ViTPostProcessBlock: processes the cls_token in the ViT backbone output
    and rearranges the feature vectors into feature maps.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
    """

    def __init__(self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16):
        super(ReassembleBlocks, self).__init__()

        assert readout_type in ["ignore", "add", "project"]
        self.readout_type = readout_type
        self.patch_size = patch_size

        self.projects = nn.ModuleList(
            [
                ConvModule(
                    in_channels=in_channels,
                    out_channels=out_channel,
                    kernel_size=1,
                    act_layer=None,
                )
                for out_channel in out_channels
            ]
        )

        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(
                    in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
                ),
            ]
        )
        if self.readout_type == "project":
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU()))

    def forward(self, inputs):
        assert isinstance(inputs, list)
        out = []
        for i, x in enumerate(inputs):
            assert len(x) == 2
            x, cls_token = x[0], x[1]
            feature_shape = x.shape
            if self.readout_type == "project":
                x = x.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(x)
                x = self.readout_projects[i](torch.cat((x, readout), -1))
                x = x.permute(0, 2, 1).reshape(feature_shape)
            elif self.readout_type == "add":
                x = x.flatten(2) + cls_token.unsqueeze(-1)
                x = x.reshape(feature_shape)
            else:
                pass
            x = self.projects[i](x)
            x = self.resize_layers[i](x)
            out.append(x)
        return out


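# Note on ReassembleBlocks.forward above: each input is a (feature map, cls_token)
# pair. For "project", tokens are flattened to (B, N, C), concatenated with the
# broadcast class token to (B, N, 2C), and projected back to C; for "add", the
# class token is added to every spatial token; for "ignore", it is discarded.
# The four stages are then resized by 4x, 2x, 1x and 0.5x to form a feature pyramid.
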
class PreActResidualConvUnit(nn.Module):
    """ResidualConvUnit, pre-activation residual unit.

    Args:
        in_channels (int): number of channels in the input feature map.
        act_layer (nn.Module): activation layer.
        norm_layer (nn.Module): norm layer.
        stride (int): stride of the first block. Default: 1.
        dilation (int): dilation rate for the conv layers. Default: 1.
    """

    def __init__(self, in_channels, act_layer, norm_layer, stride=1, dilation=1):
        super(PreActResidualConvUnit, self).__init__()

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

        self.conv2 = ConvModule(
            in_channels,
            in_channels,
            3,
            padding=1,
            norm_layer=norm_layer,
            act_layer=act_layer,
            bias=False,
            order=("act", "conv", "norm"),
        )

    def forward(self, inputs):
        inputs_ = inputs.clone()
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x + inputs_


class FeatureFusionBlock(nn.Module):
    """FeatureFusionBlock, merges feature maps from different stages.

    Args:
        in_channels (int): Input channels.
        act_layer (nn.Module): activation layer for ResidualConvUnit.
        norm_layer (nn.Module): normalization layer.
        expand (bool): Whether to expand the channels in the post process block.
            Default: False.
        align_corners (bool): align_corner setting for bilinear upsample.
            Default: True.
    """

    def __init__(self, in_channels, act_layer, norm_layer, expand=False, align_corners=True):
        super(FeatureFusionBlock, self).__init__()

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners

        self.out_channels = in_channels
        if self.expand:
            self.out_channels = in_channels // 2

        self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_layer=None, bias=True)

        self.res_conv_unit1 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )
        self.res_conv_unit2 = PreActResidualConvUnit(
            in_channels=self.in_channels, act_layer=act_layer, norm_layer=norm_layer
        )

    def forward(self, *inputs):
        x = inputs[0]
        if len(inputs) == 2:
            if x.shape != inputs[1].shape:
                res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
            else:
                res = inputs[1]
            x = x + self.res_conv_unit1(res)
        x = self.res_conv_unit2(x)
        x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
        x = self.project(x)
        return x


class DPTHead(DepthBaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head implements `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether to expand the channels in the post
            process block. Default: False.
    """

    def __init__(
        self,
        embed_dims=768,
        post_process_channels=[96, 192, 384, 768],
        readout_type="ignore",
        patch_size=16,
        expand_channels=False,
        **kwargs,
    ):
        super(DPTHead, self).__init__(**kwargs)

        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)

        self.post_process_channels = [
            int(channel * math.pow(2, i)) if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_layer=None, bias=False))
        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_layer, self.norm_layer))
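        # The first fusion block is called with a single input (the deepest feature),
        # so its first residual unit is never exercised and is dropped here.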
        self.fusion_blocks[0].res_conv_unit1 = None
        self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_layer=self.norm_layer)
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels
        self.conv_depth = HeadDepth(self.channels)
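
    # Data flow in forward below: the four (feature, cls_token) pairs from the ViT
    # backbone are reassembled into a feature pyramid, each level is mapped to
    # `channels` dims by a 3x3 conv, the levels are fused top-down starting from
    # the deepest one, and depth_pred (from DepthBaseDecodeHead, not shown in this
    # diff) turns the fused map into the final depth prediction.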
    def forward(self, inputs, img_metas):
        assert len(inputs) == self.num_reassemble_blocks
        x = list(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])
        out = self.project(out)
        out = self.depth_pred(out)
        return out
@@ -5,12 +5,12 @@

from enum import Enum
from functools import partial
-from typing import Union
+from typing import Optional, Tuple, Union

import torch

from .backbones import _make_dinov2_model
-from .depth import BNHead, DepthEncoderDecoder
+from .depth import BNHead, DepthEncoderDecoder, DPTHead
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name, CenterPadding
@@ -19,10 +19,26 @@ class Weights(Enum):
    KITTI = "KITTI"


def _get_depth_range(pretrained: bool, weights: Weights = Weights.NYU) -> Tuple[float, float]:
    if not pretrained:  # Default
        return (0.001, 10.0)

    # Pretrained, set according to the training dataset for the provided weights
    if weights == Weights.KITTI:
        return (0.001, 80.0)

    if weights == Weights.NYU:
        return (0.001, 10.0)

    return (0.001, 10.0)


def _make_dinov2_linear_depth_head(
    *,
-    embed_dim: int = 1024,
-    layers: int = 4,
+    embed_dim: int,
+    layers: int,
+    min_depth: float,
+    max_depth: float,
    **kwargs,
):
    if layers not in (1, 4):
@@ -46,7 +62,7 @@ def _make_dinov2_linear_depth_head(
        channels=embed_dim * len(in_index) * 2,
        align_corners=False,
        min_depth=0.001,
-        max_depth=10,
+        max_depth=80,
        loss_decode=(),
    )
@@ -57,6 +73,7 @@ def _make_dinov2_linear_depther(
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
+    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    if layers not in (1, 4):
@@ -67,15 +84,20 @@ def _make_dinov2_linear_depther(
    except KeyError:
        raise AssertionError(f"Unsupported weights: {weights}")

+    if depth_range is None:
+        depth_range = _get_depth_range(pretrained, weights)
+    min_depth, max_depth = depth_range
+
    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_depth_head = _make_dinov2_linear_depth_head(
-        arch_name=arch_name,
        embed_dim=embed_dim,
        layers=layers,
+        min_depth=min_depth,
+        max_depth=max_depth,
    )

    layer_count = {
@@ -140,3 +162,85 @@ def dinov2_vitg14_ld(*, layers: int = 4, pretrained: bool = True, weights: Union
    return _make_dinov2_linear_depther(
        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )


def _make_dinov2_dpt_depth_head(*, embed_dim: int, min_depth: float, max_depth: float):
    return DPTHead(
        in_channels=[embed_dim] * 4,
        channels=256,
        embed_dims=embed_dim,
        post_process_channels=[embed_dim // 2 ** (3 - i) for i in range(4)],
        readout_type="project",
        min_depth=min_depth,
        max_depth=max_depth,
        loss_decode=(),
    )


def _make_dinov2_dpt_depther(
    *,
    arch_name: str = "vit_large",
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.NYU,
    depth_range: Optional[Tuple[float, float]] = None,
    **kwargs,
):
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    if depth_range is None:
        depth_range = _get_depth_range(pretrained, weights)
    min_depth, max_depth = depth_range

    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    model_name = _make_dinov2_model_name(arch_name, backbone.patch_size)
    dpt_depth_head = _make_dinov2_dpt_depth_head(embed_dim=backbone.embed_dim, min_depth=min_depth, max_depth=max_depth)

    out_index = {
        "vit_small": [2, 5, 8, 11],
        "vit_base": [2, 5, 8, 11],
        "vit_large": [4, 11, 17, 23],
        "vit_giant2": [9, 19, 29, 39],
    }[arch_name]

    model = DepthEncoderDecoder(backbone=backbone, decode_head=dpt_depth_head)
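    # Make the backbone return the four intermediate feature maps (with class
    # tokens) that DPTHead expects, and pre-pad inputs so height and width are
    # multiples of the patch size.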
    model.backbone.forward = partial(
        backbone.get_intermediate_layers,
        n=out_index,
        reshape=True,
        return_class_token=True,
        norm=False,
    )
    model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone.patch_size)(x[0]))

    if pretrained:
        weights_str = weights.value.lower()
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_{weights_str}_dpt_head.pth"
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict, strict=False)

    return model


def dinov2_vits14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14_dd(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.NYU, **kwargs):
    return _make_dinov2_dpt_depther(
        arch_name="vit_giant2", ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )
@@ -7,6 +7,7 @@

from dinov2.hub.backbones import dinov2_vitb14, dinov2_vitg14, dinov2_vitl14, dinov2_vits14
from dinov2.hub.classifiers import dinov2_vitb14_lc, dinov2_vitg14_lc, dinov2_vitl14_lc, dinov2_vits14_lc
from dinov2.hub.depthers import dinov2_vitb14_ld, dinov2_vitg14_ld, dinov2_vitl14_ld, dinov2_vits14_ld
+from dinov2.hub.depthers import dinov2_vitb14_dd, dinov2_vitg14_dd, dinov2_vitl14_dd, dinov2_vits14_dd


dependencies = ["torch"]
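
A rough sketch of how the pieces above compose (all names are taken from this diff; the depth_range override is purely illustrative):

from dinov2.hub.depthers import _make_dinov2_dpt_depther

# Build an untrained ViT-S/14 DPT depther with a custom working range instead of
# the dataset-derived default that _get_depth_range would pick.
model = _make_dinov2_dpt_depther(
    arch_name="vit_small",
    pretrained=False,
    depth_range=(0.01, 20.0),
)
model.eval()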