dinov2/notebooks/depth_estimation.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Depth Estimation <a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/dinov2/blob/main/notebooks/depth_estimation.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"INSTALL = False # Switch this to install dependencies\n",
"if INSTALL: # Try installing package with extras\n",
" REPO_URL = \"https://github.com/facebookresearch/dinov2\"\n",
" !{sys.executable} -m pip install -e {REPO_URL}'[extras]' --extra-index-url https://download.pytorch.org/whl/cu117 --extra-index-url https://pypi.nvidia.com\n",
"else:\n",
" REPO_PATH = \"<FIXME>\" # Specify a local path to the repository (or use installed package instead)\n",
" sys.path.append(REPO_PATH)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utilities"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import itertools\n",
"from functools import partial\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"from dinov2.eval.depth.models import build_depther\n",
"\n",
"\n",
"class CenterPadding(torch.nn.Module):\n",
" def __init__(self, multiple):\n",
" super().__init__()\n",
" self.multiple = multiple\n",
"\n",
" def _get_pad(self, size):\n",
" new_size = math.ceil(size / self.multiple) * self.multiple\n",
" pad_size = new_size - size\n",
" pad_size_left = pad_size // 2\n",
" pad_size_right = pad_size - pad_size_left\n",
" return pad_size_left, pad_size_right\n",
"\n",
" @torch.inference_mode()\n",
" def forward(self, x):\n",
" pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))\n",
" output = F.pad(x, pads)\n",
" return output\n",
"\n",
"\n",
"def create_depther(cfg, backbone_model, backbone_size, head_type):\n",
" train_cfg = cfg.get(\"train_cfg\")\n",
" test_cfg = cfg.get(\"test_cfg\")\n",
" depther = build_depther(cfg.model, train_cfg=train_cfg, test_cfg=test_cfg)\n",
"\n",
" depther.backbone.forward = partial(\n",
" backbone_model.get_intermediate_layers,\n",
" n=cfg.model.backbone.out_indices,\n",
" reshape=True,\n",
" return_class_token=cfg.model.backbone.output_cls_token,\n",
" norm=cfg.model.backbone.final_norm,\n",
" )\n",
"\n",
" if hasattr(backbone_model, \"patch_size\"):\n",
" depther.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone_model.patch_size)(x[0]))\n",
"\n",
" return depther"
]
},
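{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of `CenterPadding` (a sketch, not part of the original notebook; the tensor shapes are illustrative): heights and widths are padded up to the next multiple of the given value, here the ViT patch size of 14."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check (hypothetical shapes): 37 -> 42 and 53 -> 56 with multiple=14\n",
"x = torch.zeros(1, 3, 37, 53)\n",
"padded = CenterPadding(14)(x)\n",
"print(padded.shape) # expected: torch.Size([1, 3, 42, 56])"
]
},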
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load pretrained backbone"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in /private/home/plabatut/.cache/torch/hub/facebookresearch_dinov2_main\n"
]
},
{
"data": {
"text/plain": [
"DinoVisionTransformer(\n",
" (patch_embed): PatchEmbed(\n",
" (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))\n",
" (norm): Identity()\n",
" )\n",
" (blocks): ModuleList(\n",
" (0-11): 12 x NestedTensorBlock(\n",
" (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
" (attn): MemEffAttention(\n",
" (qkv): Linear(in_features=384, out_features=1152, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=384, out_features=384, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ls1): LayerScale()\n",
" (drop_path1): Identity()\n",
" (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
" (act): GELU(approximate='none')\n",
" (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ls2): LayerScale()\n",
" (drop_path2): Identity()\n",
" )\n",
" )\n",
" (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
" (head): Identity()\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BACKBONE_SIZE = \"small\" # in (\"small\", \"base\", \"large\" or \"giant\")\n",
"\n",
"\n",
"backbone_archs = {\n",
" \"small\": \"vits14\",\n",
" \"base\": \"vitb14\",\n",
" \"large\": \"vitl14\",\n",
" \"giant\": \"vitg14\",\n",
"}\n",
"backbone_arch = backbone_archs[BACKBONE_SIZE]\n",
"backbone_name = f\"dinov2_{backbone_arch}\"\n",
"\n",
"backbone_model = torch.hub.load(repo_or_dir=\"facebookresearch/dinov2\", model=backbone_name)\n",
"backbone_model.eval()\n",
"backbone_model.cuda()"
]
},
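{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch (not part of the original notebook), the backbone's `get_intermediate_layers` can be probed directly: with `reshape=True`, a ViT-S/14 maps a 448x448 input to a 32x32 grid of 384-dimensional patch features (448 / 14 = 32). This is the same call that `create_depther` wires into the depther's backbone."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: probe raw backbone features on a dummy input (shapes are illustrative)\n",
"with torch.inference_mode():\n",
" dummy = torch.zeros(1, 3, 448, 448).cuda()\n",
" (features,) = backbone_model.get_intermediate_layers(dummy, n=1, reshape=True)\n",
"print(features.shape) # expected: torch.Size([1, 384, 32, 32])"
]
},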
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load pretrained depth head"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load checkpoint from http path: https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading: \"https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth\" to /private/home/plabatut/.cache/torch/hub/checkpoints/dinov2_vits14_nyu_dpt_head.pth\n",
"100%|██████████| 160M/160M [00:06<00:00, 27.2MB/s] \n"
]
},
{
"data": {
"text/plain": [
"DepthEncoderDecoder(\n",
" (backbone): DinoVisionTransformer()\n",
" (decode_head): DPTHead(\n",
" align_corners=False\n",
" (loss_decode): ModuleList(\n",
" (0): SigLoss()\n",
" (1): GradientLoss()\n",
" )\n",
" (conv_depth): HeadDepth(\n",
" (head): Sequential(\n",
" (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): Interpolate()\n",
" (2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (3): ReLU()\n",
" (4): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" )\n",
" (relu): ReLU()\n",
" (sigmoid): Sigmoid()\n",
" (reassemble_blocks): ReassembleBlocks(\n",
" (projects): ModuleList(\n",
" (0): ConvModule(\n",
" (conv): Conv2d(384, 48, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (1): ConvModule(\n",
" (conv): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (2): ConvModule(\n",
" (conv): Conv2d(384, 192, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (3): ConvModule(\n",
" (conv): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" )\n",
" (resize_layers): ModuleList(\n",
" (0): ConvTranspose2d(48, 48, kernel_size=(4, 4), stride=(4, 4))\n",
" (1): ConvTranspose2d(96, 96, kernel_size=(2, 2), stride=(2, 2))\n",
" (2): Identity()\n",
" (3): Conv2d(384, 384, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
" )\n",
" (readout_projects): ModuleList(\n",
" (0-3): 4 x Sequential(\n",
" (0): Linear(in_features=768, out_features=384, bias=True)\n",
" (1): GELU(approximate='none')\n",
" )\n",
" )\n",
" )\n",
" (convs): ModuleList(\n",
" (0): ConvModule(\n",
" (conv): Conv2d(48, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" )\n",
" (1): ConvModule(\n",
" (conv): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" )\n",
" (2): ConvModule(\n",
" (conv): Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" )\n",
" (3): ConvModule(\n",
" (conv): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" )\n",
" )\n",
" (fusion_blocks): ModuleList(\n",
" (0): FeatureFusionBlock(\n",
" (project): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (res_conv_unit1): None\n",
" (res_conv_unit2): PreActResidualConvUnit(\n",
" (conv1): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" (conv2): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" )\n",
" )\n",
" (1-3): 3 x FeatureFusionBlock(\n",
" (project): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))\n",
" )\n",
" (res_conv_unit1): PreActResidualConvUnit(\n",
" (conv1): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" (conv2): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" )\n",
" (res_conv_unit2): PreActResidualConvUnit(\n",
" (conv1): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" (conv2): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (project): ConvModule(\n",
" (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (activate): ReLU(inplace=True)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import urllib\n",
"\n",
"import mmcv\n",
"from mmcv.runner import load_checkpoint\n",
"\n",
"\n",
"def load_config_from_url(url: str) -> str:\n",
" with urllib.request.urlopen(url) as f:\n",
" return f.read().decode()\n",
"\n",
"\n",
"HEAD_DATASET = \"nyu\" # in (\"nyu\", \"kitti\")\n",
"HEAD_TYPE = \"dpt\" # in (\"linear\", \"linear4\", \"dpt\")\n",
"\n",
"\n",
"DINOV2_BASE_URL = \"https://dl.fbaipublicfiles.com/dinov2\"\n",
"head_config_url = f\"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_{HEAD_DATASET}_{HEAD_TYPE}_config.py\"\n",
"head_checkpoint_url = f\"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_{HEAD_DATASET}_{HEAD_TYPE}_head.pth\"\n",
"\n",
"cfg_str = load_config_from_url(head_config_url)\n",
"cfg = mmcv.Config.fromstring(cfg_str, file_format=\".py\")\n",
"\n",
"model = create_depther(\n",
" cfg,\n",
" backbone_model=backbone_model,\n",
" backbone_size=BACKBONE_SIZE,\n",
" head_type=HEAD_TYPE,\n",
")\n",
"\n",
"load_checkpoint(model, head_checkpoint_url, map_location=\"cpu\")\n",
"model.eval()\n",
"model.cuda()"
]
},
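{
"cell_type": "markdown",
"metadata": {},
"source": [
"The checkpoint and config URLs follow one scheme, so the other released heads can be addressed the same way (a hypothetical example below; the file is not fetched here, and its availability is inferred from the `HEAD_DATASET` / `HEAD_TYPE` option comments above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: the URL for a different dataset/head combination\n",
"alt_head_checkpoint_url = f\"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_kitti_linear_head.pth\"\n",
"print(alt_head_checkpoint_url)"
]
},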
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load sample image"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAIAAAC6s0uzAAABamlDQ1BJQ0MgUHJvZmlsZQAAeJx1kL1Lw1AUxU+rUtA6iA4dHDKJQ9TSCnZxaCsURTBUBatTmn4JbXwkKVJxE1cp+B9YwVlwsIhUcHFwEEQHEd2cOim4aHjel1TaIt7H5f04nHO5XMAbUBkr9gIo6ZaRTMSktdS65HuDh55TqmayqKIsCv79u+vz0fXeT4hZTbt2ENlPXJfOLpd2ngJTf/1d1Z/Jmhr939RBjRkW4JGJlW2LCd4lHjFoKeKq4LzLx4LTLp87npVknPiWWNIKaoa4SSynO/R8B5eKZa21g9jen9VXl8Uc6lHMYRMmGIpQUYEEBeF//NOOP44tcldgUC6PAizKREkRE7LE89ChYRIycQhB6pC4c+t+D637yW1t7xWYbXDOL9raQgM4naGT1dvaeAQYGgBu6kw1VEfqofbmcsD7CTCYAobvKLNh5sIhd3t/DOh74fxjDPAdAnaV868jzu0ahZ+BK/0HFylqvLiAv9gAAQAASURBVHicrP1psHVLlhiErSEz997nnDt845vq1avqGrqGrpJoIfVUomXUjSSwDRFYAyAJQ8jwx+AfKMIORzjCxhgThiAc/mEby9gNEcYOhMMggxAt1JJopJ6kVnd1qau7a3j16s3vG+50pr0z11r+kXvvk2efc+/7GpxR9b1998mdw5pz5cqV+IUvflZVzQwAzExE8jMR5YfxfUrJzMAoxph/ijGqKhEhIiKamZmNz6qqqmeL6uzsjNAuLi6++PnPNU2z3qy6bSsiqppUzMwUFCF/nqRNUdu222wTGANQSl1Vh1deeeQDfeYHPmkmwXHqYkoJAJNKHk/uNA/SOdd1XR45IjRNc3Z2MpvNiIh9x+xjp++9+/Tn/tovv/rKYyMmDADgvd+0681m45wjItXEzITmvQeAPLMMECIa/1VNZOOzMnMGDg5lGAbml1CU8f3hT5M/i/cMAIg71ADu4WvE2tAHImIenoGOLY+Iy5+UYy5byB+OL8cZTcY5/GsTEA0TpLG+Y6+q5Ricc7nxZNp13bvvvvv9d578uT/3p1999VUzQ63UkDxGXfmaJBlTDagEgMaQfHCVwopdUlUwrw5VxHGFyKaYadlQAcBMzQRMCBQNyAAN0gFeiunrBBd5LiMQ8kMGo5qNcC5bI3JjgwPSdmAc+WVkIuiHatbjWjN6mDG1navqrk1J7T/4f/3Fv/W3f/HBo4d/+p/4Iw8e3COPH3548Vf+i59HqjtJ83lzOj996fEnTk7OPvWpT3zms59su9Vq0zbNiaFmBul7J2LmEEJKaT6fr1abDz/8sG1bYHLOzWazGr2BMtu22xDDYrGoqxkmapctdCrruF1vnj1/8vTph8+un3z07KMPbp7xAIBNlMW9xz/2tf9OMz89Pz///pvf/gv/x//DT/3DX/vH//v/WD2rzh8++ps//1/9O//7v/DoweKf/Cf/2OuffDl1aTF/8OGHl//6v/Vvf+ozb3z7zbf+mT/7Tz9+9OrV1c3rr762Wi7/7X/r33z51ddubm6IiBmlLzH3d+/+WdM0koyIRCSE8Du/8+0/+U//s5//wc89OL+3Wq3+6l/9L9rt9vr6crvdMiMRrdfrq6urjIWU1MzEehbOxJmpdMRsxh0zM3OWEgR2yCCIqAqDGFFVHUnLIIKRqmYBWPLL2EKWt1ks7whj6D2/qaoqf+6cy59kVkLE2WzWNE1Gca7jvUfEGKMqZAnZtu1yufTep9SllEQkN66qMUYRy882lPyMiISWTAGAAMEMEIHQJEsaBSa6WX1ufu+HX//U9eVFx+rZzefz0/Oztm1TSllEjMQ/Ch8zBDUAAEK0QiJp/8BEZjKOZ8dl0rNqKbVwmHgPW8D+vwCg3VE5OXm5E4OKAABkvTggbHyYOb8A9/rpvfNqVjE1909WKL/z4TuX3fLvzWs2QIP1ev38ydP19RI7AQCtwjWqgJ1yU3fmgbEJN6l143x2o79V9GehaZlAMyKJKJP72Eimj8m/psk5d3NzU9d1pl1EFBEFy7gfMUFEiAZAZoYAZpJp4v333//BL3y2aZqUOk0xS20RdY7NVERhh0tTVe99pnIzTSl1XfI+eu/JMEXxvj45OTGDNkVmIKdZsXl23aBBAcgMAW2iJku7ZAKx20A3QXNZ/2M/OdoCgN3yflomFkD5XL6cjKqUBXcYDXeP86i23uMfNEAzUECwXkYBEYnAZrORZCLi2BCJiBidJGFfAziQpCbBhyQEyGZkCEmsCnXSLTsPBiklQoeIznHXCTJm8QVmqIioCAB4CMjjACznVc5lYq8cgug2vEwo4WPHoKohBCByzhEgMxNR27ZZcGf6Z+9MMctuVb28fH59vbx//xQRU0pVVaWUkCF/O6AAcwshBJHe/hYR5zhzEHkCxJS2zOw8ISKYLW9uPnz7Q29+c3mzXi43m5uL5dXF1cV6vaZBU2U15IgREQnYkZmopcqHzPu5L9XU00lPMBBjzLOezSpmFhGH5L1XTYC9QlJNAKSqRMAciEg0AoCIbNstGA1CBnxweSLvvffOer2GQWSlFOu6Xq1WXdeFEFSBmVNKh3gcCxRiepAtBnicXwZKP+B3cqaD0VYo5gm67xARuXJeYJSLn1Fhr1artm3rup7NZsPiIePUZbMgV67rOstJKExwIvLee78zEXJlABhH248qazVEA8j9GpCqWo9WVQBksmEllusA6Dj3o7xTdnE3j/QDxp7cSjBOZFQpM2/jt3IwAwbHARjk1swAQPrZ6eX11ezMhXmz3XbhtH7l5ZfjR+9XCCYCTE3TnD+4P6/q7XKzWq027ZZrbwZd1zVcgWGMkZjcuPwtx30ImpEOEGEknRH3E0lUmjYZi2BWVdXlzfW9e/eY2QYDPNsviJhXaoMC1swkYD24VHW9bruuu76+ZkZH6OvKzMyiYf/5CNuMcmYex57NBRHznsCUyIHRYnHy6U+/cr3cDBMxM/Heu+iyCBhnNMxryocT4jiqhkswTrj0d6t69xv8ePHd/2rFV4OwuIPuD6fwux1eCaIJfBBRTQABEAx2smz4FoioqioAWK+2GYNmZpZim5DVBYfKMabZLMRuk1Kaze5fXl7OT7iNMlucbtbqXAU2GP6MbdsyGpKicY9PZEQxIwEj05F1J5i9bXYlmR0aKLcZNz00jrF/SQlmRjRtsySYLkZJhuyy+dt1HTMbQkrJCL33SZ1tLdsXiDifz/PCCJhiF5vmJGnMBnTfCyIAMHOMMYTsXMFRQ5sZqKBDBSPElFLbtqxudX3zvW9/59HpI4/k2UkIzhMwKEivgFERiIicJyQjopRSSgkBFycz5xwzpZSWy6VzLlQuhEBERgAA2+0WALz352entQ9qiYgcc7dtzQBUQEVFjBIAIBGiIUDlgyNmJEesqoRmqmDQNI1DWi6v33nnHSJSiQCQUnKON5vNdrsdBOCoZnc+oVEYjoJuQrGqirxzhpXIAgDVZJZ17c6biIjIYLr7Kvc1trmD/
KDYSioa36eUMqbyJ/nzrGVTStvtdrPZbDabqqq893kFzMwpdXkBX9f1fD4Xka7bdl3Xtm2MMctqZkbk0jgYSXR8uQNUroNgZjrIP96tYXoTYWykVDeHjHBM6exxEBQrCtjjsh0HmRlOpLftPicoRfp0AEeGRAXPA4CaiESARLRN8Waz9p4JjVsMjhe+Poe0hCSgVLkFzjDUm2ptjOubK1FtfO0UYkqCDEAE5Eq75nAQpSzAnlf79W7GfUppdNGUCngsbRu3221VVc651Wq13W5ns9lhR/lzRBQxRMpOHlPKHRHjem2bzWY2WzgH3XajqtnDs95uCp00kv6eK1jVui62bcfsAMExKSgCv/zyy+/86q+//PKrmZhym8zctm2hP2x4PgKW/qFUcreX0li5o84h5A8qFLyxY8+9AZRYK4h1WvMFmaGsf6h1Pm60oLpTxocUVXSBiOi9B4PLy8teE6DFGAM7QU1d8q7y7LebjXeokjabVQgBICG59Saa1qImEr1jJmBGZlTrHDqRpAhgBAA
"text/plain": [
"<PIL.Image.Image image mode=RGB size=640x480>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import urllib\n",
"\n",
"from PIL import Image\n",
"\n",
"\n",
"def load_image_from_url(url: str) -> Image:\n",
" with urllib.request.urlopen(url) as f:\n",
" return Image.open(f).convert(\"RGB\")\n",
"\n",
"\n",
"EXAMPLE_IMAGE_URL = \"https://dl.fbaipublicfiles.com/dinov2/images/example.jpg\"\n",
"\n",
"\n",
"image = load_image_from_url(EXAMPLE_IMAGE_URL)\n",
"display(image)"
]
},
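{
"cell_type": "markdown",
"metadata": {},
"source": [
"To try the estimation on your own picture, any RGB `PIL.Image` works in place of the sample (a sketch below; the path is a placeholder)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional (placeholder path): use a local image instead of the sample\n",
"# image = Image.open(\"<path/to/your/image.jpg>\").convert(\"RGB\")"
]
},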
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Estimate depth on sample image"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAIAAAC6s0uzAAEAAElEQVR4nMz927LkOJItCC5VkGbbPTKrsupcep5mXuY/R+ZzR0r6dNfJiou7GaHzoIBCcSXN9o6sUgnxsE2CAEiCWFh6A/1//t//3z8OAHhGPAUCRJFDYCKCKADSKfsNgHIZJgyFiQhgAhEYIAJTVTSKpCZGbQGIqUsQ5D9L4XQEQEx/pR8xdTBVspCmPwwCYMcY4xvzF/UtWOvNqZhuQZAfoB6RXC4V0IOQprZB5133KP/WgwwionJH+U4Zerj6gfwG8yXlHu3NlvrdvVN9xB5mPzByW+7ayZjxcvb2Kon1nzwu9ZXiW7SuRv/tWEl7xdKe6q9yl0tzSge/iERBhByS/nvI8cDxwPMnHg96HHgc9HziceChn0vEASB/OqD8eBjB/vS/h2KXW20RhyAe8oh4RImCI3Utt0LEgXbGHmjfcN9x30X/3XfsG3hHCMQ7MYMCEVE1qu0ef8bjdzz+g379D/q/f5P/64/j//rj+e8/n38/4h8x/iFyQIcobYE/tvD9vv/zx/a3X/i//5P8j7/JP/0tfPx1D3/Z+Xuge8BHwD1gI7kzAkkgbCQ6Rw2nMnKvqx//+bjYWStfPg0tr2VyeXIXkqtBL+dcmIbHXUOpBnK1pWvLcd9JK9AcLPfVHWkkTxdCJIGFSEKIYYvbFnmPYZNwF9pAW249PwiywcX5T6b8m8ozZSpH7K2USurH6vvU/74y0VjZCZLJ+PsUREEUeQqeUX5G+RHjT8QfEh+IPykeJAfFI3db748FAO+y/yUiYgOwMz4C/jjwM+K3pxyCp5taPBACBpDpYOqwAN0cTQAgZVKmhMdeApEAERJ0CiMQpbmMQIb1BEgGEiEJ2g0iEQRCBAJIZyUtJEKKW/Yxz2CsgViPScA5EmtHm8pZSFu0yz2QMFGU/O0hDWStQe9CD1qd4XL/8+qhQmIPvXpHfpFh6Ouht5xN77EgNEbgusBdVM8w/anlqSt5KsOvABnMwujU9eovYr1f0vkW7Wj/0qMgEGkZvd/o1o4qk+dAURD0Q0qLMwRA9HXFcnsRHEUiwpYXqBGRwYIQM3CKWzDobwJHHDx6cuzmtmEBxeCIeOARaFekjziEouBIvUcI2Ik4YA/Yd9x3uQXsir43hADeKQQi/Y9BNhRT06JDNAbiTThI2GgPtDPtgW+Bb3oXgigSmTaibQsfe/jlHv7pg//5F/nbL/L9O90+QvgIfGO6B9wd+u4sgRBIOCNcP87zC2oxNRdrsdljrf1pKNv8tj973NVKeqBtSq5xt0dcX9LfzinoNkIkTEJkP8rl1CzhmAa4i+p5EZM9HfLP0UHvAHfdHOR7lrsx//rnlDH9qOeavrREARECgYUUCSPkEHoKMThAGDgAFgYkav9FfwAIuwCgDdtfd/njoEfE7890+4AcwOHW7IK0spWyxJWGnejMrt9PJrsgISZ5AhuT6KvqViSZubrfmYg7NlyxAXFziUhZnQx5sGecVsxw63BnKJcriJVP9qNzyG7Ln12jTRn/+ApfN+o8ob++57mkgV9C8XQLIJ0F9S7I3VEZ57qeIUDWhLhaRxsqH+lZ4cjrCStgGhQmP0GVCYioBfuZvESCFzIcdb3E8eF5+fr4gvKiGYTV8WbwjNuNTutziBg7fMhxIP7E8QM/HvR40I8Djwd+KCge8hBEwRFFYfiwCgmBiXUWAcAIRNyQY9f64f9UzDMGLDiiPEWOiCgSARAxg4kCgZn2QPtG943uO+477nd8HHLbsUWRTXgjPhSDhfRbsw9Z7/GHPH/g8ZP++IHfHvLbU35EeSKR7A2IiV4Qa9+iPA55POjHU46HHD+P8AcjELOu+AU3JoEcQoGg/9md2+B0YCyHpO/CxrZHakW1Q8CEQ/FS0vyvxY98XwzkqhIMH0JMcoitdKWuRBShiOQQYtKhpZcQE0inGvfBSqq5PEY7ZWfzQale6iui/U/0lyMLbTgiERHHEGkTipAnwPV8XYGrIIOxcEZiyXQzgW4pKR3cOlwXdypdXily3T0D87X8TAyZXOtZ+5oZ8DPiKfIUiYgHpNyyUO6e/jh+BPohIYJ32RR9D0EEGIgivi1k/JvdiJeoIwYSiUJ+GgRsTIHycsGtb3OnEEWCO6IkmECBEhJXDBiSeID2k4gnGoJ0+45xpiMv8KLc1eX7GqIv8idq5zy1JaJ2CWNtXdM5L/TP/kejiEaB3kKCh9DrFdHleKeIHrLeRrHfwPMrOqGrhX2Lvd2hPztUTkA/ga7yl9B3oW12R9zyq6u8Ad18sCxAs+VCsU70iCe4Mf8WGa8oCAGAoq8e8ehr0OvV0QHsVdCU/1REF4mGvqmY4EAkxEB7lAcTi0ShGLsH7IeoOwiRdI+qYP+Jx4N+PuT3R/z9eeh/f4g8FYnT/dJTb5lp2/j7g34c9ExvHEQZawNhY/tdVNBApYWm+ZcIQNzAjhlT8xRNsJGQqxUhovQ+2F2CwW9xlUBVgbo0iaJTqEQhLS+o/hyK+4pbQnxFhjVTXVVmw6AtchAigFXDjGSGtAs5PwJkFTSQ2BuSyg493x2g72hqMK1sA5YmPXG8Lk1tTIOvl8ngVgSIZPTXihyRwpEI0nZkuI2iylscgkPk6MyuRxRjlt7Omm9cSgfSWKedORAOwc4IVEhwumTyWJy6W8wGDIXe2gbsDcBlVsrEN1Xi7Kw912yfnoJQxutGnXsqw2obE6//IZD+YEWUR5TMcRSbuXRxTUfS2DsNvIgep4y7OmsrN2VQJNEfR2LDifUeevuSGksTR/7qYn5EE53N5PESHZIWU4ddK62a+qL4b2dt7G/ODhGuPT6htrM6G7htLDW+/rrFcbcboO1/R5EDIiJPxAPRrL8P+vHAjyd+HHh47istfz0IIUosrw+AJHV0xFMXYEN7sHHfiMchjyhP/VekMgAj0VOO9GTa0gJfQMQH7U85CM+AIJBDIlOIEBaKJPb1HSJPiT/l+IHn7/jjN/6P3/DvP+TvP4//eBy/Po7fYvwjxh8ot0aEIPwEEPj+jL8d/E9PHPaVJYNjxl1T1AGIIBaheuqW7qv3nKpXPh/O6AQYUoIpfz6iC1+Jk6HeLGcJ4jXSqGDIVvni/zQN8xVFdNX0iypoMwNjhGI0Qd+mTK7LDvXN5B8eWZmAtOjIZ0clm99DEVlh8PqsrbY2qAoaAl1h8g4dA8cTiNrbth45KDK2mBC3sM8ARKJIIqAoYEqqqx7vG0WrKkhVOZYQK0bVNrIUxb5ZAYOvh8r0xA6DS3frAx7VbBJT9LVeRYjhrvdsmsFwpbZVtpdMUPDHG+ns2oNOWv1NgWS9zmAZ6x8AJpbNqsIrVLjppLcKd10STu990XIrDccddHX0lMpMl2/3wACMTxepw8pxhsoLmemNh3X2hZvPxD4cZXV+ATocPCJZ2abOBJkAVb8poSgLCYjADObEXxVlmBHysgnDCZ+y15XSX2TEHRp9y/1qOYEgEiIlYzM3GH8qZmb2Lhemj0lTn4BBG3jDt
sm208eDPgLfA/+M8Sn0JEr29HQ7tBFtREzETHvAviVLM++MjWkjbISdsRM2xkaiAMHOY+OiKPVMz0QqaHTwqTeod1egdGYSrq9N5c0JqzreWYIX0Gsk1X7nrr52y430mE0sRJ0N+LoIQOlpJnw1jcFa+hUK5rfX6OdWTyGdGiuSXK+IYcxSBIhApHhQ+rcUEwCSaXE8ZHtESDfL6JFDlUqAiDxFJFuebPWNjGEeyRjEIAJFIaG0nGUCRRAT0erjnlng0llJM7Wp43ruq8RXoVdx14Nu7WbctnaUCajg08DI6oRBmJGYpf4KjR7yrPBCesOwP5hAXRc3WdmjVlvPgw9nDI7J+40g2RyVqfBRVoSS16MrKuwO+s++fLjUlZeMValAmddeni0qTWAnax/msqxcIi7qb6dXMjf017e7pr84Y8D9xabdZQRBFIqmNW5Mv/lHZfrlbAxGh8FmAG6+GiYeqdCTDVg/zWyPOhd1w/Ssg0m9fCioPxc4YA+0B75FfgoigaNspqBJflj8beNvO3+
"text/plain": [
"<PIL.Image.Image image mode=RGB size=640x480>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib\n",
"from torchvision import transforms\n",
"\n",
"\n",
"def make_depth_transform() -> transforms.Compose:\n",
" return transforms.Compose([\n",
" transforms.ToTensor(),\n",
" lambda x: 255.0 * x[:3], # Discard alpha component and scale by 255\n",
" transforms.Normalize(\n",
" mean=(123.675, 116.28, 103.53),\n",
" std=(58.395, 57.12, 57.375),\n",
" ),\n",
" ])\n",
"\n",
"\n",
"def render_depth(values, colormap_name=\"magma_r\") -> Image:\n",
" min_value, max_value = values.min(), values.max()\n",
" normalized_values = (values - min_value) / (max_value - min_value)\n",
"\n",
" colormap = matplotlib.colormaps[colormap_name]\n",
" colors = colormap(normalized_values, bytes=True) # ((1)xhxwx4)\n",
" colors = colors[:, :, :3] # Discard alpha component\n",
" return Image.fromarray(colors)\n",
"\n",
"\n",
"transform = make_depth_transform()\n",
"\n",
"scale_factor = 1\n",
"rescaled_image = image.resize((scale_factor * image.width, scale_factor * image.height))\n",
"transformed_image = transform(rescaled_image)\n",
"batch = transformed_image.unsqueeze(0).cuda() # Make a batch of one image\n",
"\n",
"with torch.inference_mode():\n",
" result = model.whole_inference(batch, img_meta=None, rescale=True)\n",
"\n",
"depth_image = render_depth(result.squeeze().cpu())\n",
"display(depth_image)"
]
}
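,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional inspection (a sketch, not part of the original notebook): the rendering above normalizes per image, but the raw prediction is a dense tensor of per-pixel depths; for the NYU head these should be metric depths on roughly a 0-10 m scale."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: inspect the raw depth values behind the rendered color map\n",
"depth = result.squeeze().cpu()\n",
"print(depth.shape, float(depth.min()), float(depth.max()))"
]
}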
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 4
}