dinov2/notebooks/semantic_segmentation.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b470389d-a897-416e-9601-aeacb39cd694",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates."
]
},
{
"cell_type": "markdown",
"id": "eb5c8577-7dff-41b1-9b04-2dca12940e02",
"metadata": {},
"source": [
"# Semantic Segmentation <a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/dinov2/blob/main/notebooks/segmentation.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "febdf412-5ad0-4bbc-9530-754f92dcc491",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"INSTALL = False # Switch this to install dependencies\n",
"if INSTALL: # Try installing package with extras\n",
" REPO_URL = \"https://github.com/facebookresearch/dinov2\"\n",
" !{sys.executable} -m pip install -e {REPO_URL}'[extras]' --extra-index-url https://download.pytorch.org/whl/cu117 --extra-index-url https://pypi.nvidia.com\n",
"else:\n",
" REPO_PATH = \"<FIXME>\" # Specify a local path to the repository (or use installed package instead)\n",
" sys.path.append(REPO_PATH)"
]
},
{
"cell_type": "markdown",
"id": "efdf378b-0591-4879-9db6-6a4ab582d49f",
"metadata": {},
"source": [
"## Utilities"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "90223c04-e7da-4738-bb16-d4f7025aa3eb",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import itertools\n",
"from functools import partial\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from mmseg.apis import init_segmentor, inference_segmentor\n",
"\n",
"import dinov2.eval.segmentation.models\n",
"\n",
"\n",
"class CenterPadding(torch.nn.Module):\n",
" def __init__(self, multiple):\n",
" super().__init__()\n",
" self.multiple = multiple\n",
"\n",
" def _get_pad(self, size):\n",
" new_size = math.ceil(size / self.multiple) * self.multiple\n",
" pad_size = new_size - size\n",
" pad_size_left = pad_size // 2\n",
" pad_size_right = pad_size - pad_size_left\n",
" return pad_size_left, pad_size_right\n",
"\n",
" @torch.inference_mode()\n",
" def forward(self, x):\n",
" pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))\n",
" output = F.pad(x, pads)\n",
" return output\n",
"\n",
"\n",
"def create_segmenter(cfg, backbone_model):\n",
" model = init_segmentor(cfg)\n",
" model.backbone.forward = partial(\n",
" backbone_model.get_intermediate_layers,\n",
" n=cfg.model.backbone.out_indices,\n",
" reshape=True,\n",
" )\n",
" if hasattr(backbone_model, \"patch_size\"):\n",
" model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone_model.patch_size)(x[0]))\n",
" model.init_weights()\n",
" return model"
]
},
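{
"cell_type": "markdown",
"id": "1f3c9a2e-7b4d-4e8a-9c5f-0a6d2b8e4c71",
"metadata": {},
"source": [
"The next cell is an optional sanity check, not part of the original notebook: it runs `CenterPadding` on a dummy tensor to confirm that the spatial dimensions are padded up to the nearest multiple of the patch size (14 for DINOv2 backbones)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d2e5f71-3a9c-4b06-8e1d-7c4f9a2b5d38",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check (not in the original notebook): pad a dummy tensor\n",
"# whose height and width are not multiples of 14 and inspect the output shape.\n",
"x = torch.zeros(1, 3, 37, 50) # arbitrary non-multiple height and width\n",
"padded = CenterPadding(14)(x)\n",
"print(x.shape, \"->\", padded.shape) # expect (1, 3, 42, 56)"
]
},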
{
"cell_type": "markdown",
"id": "a5724efc-b2b8-46ed-94e1-7fee59a39ed9",
"metadata": {},
"source": [
"## Load pretrained backbone"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2d51b932-1157-45ce-997f-572ad417a12f",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in /private/home/plabatut/.cache/torch/hub/facebookresearch_dinov2_main\n",
"/private/home/plabatut/github/patricklabatut/dinov2/dinov2/layers/swiglu_ffn.py:43: UserWarning: xFormers is available (SwiGLU)\n",
" warnings.warn(\"xFormers is available (SwiGLU)\")\n",
"/private/home/plabatut/github/patricklabatut/dinov2/dinov2/layers/attention.py:27: UserWarning: xFormers is available (Attention)\n",
" warnings.warn(\"xFormers is available (Attention)\")\n",
"/private/home/plabatut/github/patricklabatut/dinov2/dinov2/layers/block.py:33: UserWarning: xFormers is available (Block)\n",
" warnings.warn(\"xFormers is available (Block)\")\n"
]
},
{
"data": {
"text/plain": [
"DinoVisionTransformer(\n",
" (patch_embed): PatchEmbed(\n",
" (proj): Conv2d(3, 384, kernel_size=(14, 14), stride=(14, 14))\n",
" (norm): Identity()\n",
" )\n",
" (blocks): ModuleList(\n",
" (0-11): 12 x NestedTensorBlock(\n",
" (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
" (attn): MemEffAttention(\n",
" (qkv): Linear(in_features=384, out_features=1152, bias=True)\n",
" (attn_drop): Dropout(p=0.0, inplace=False)\n",
" (proj): Linear(in_features=384, out_features=384, bias=True)\n",
" (proj_drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ls1): LayerScale()\n",
" (drop_path1): Identity()\n",
" (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
" (mlp): Mlp(\n",
" (fc1): Linear(in_features=384, out_features=1536, bias=True)\n",
" (act): GELU(approximate='none')\n",
" (fc2): Linear(in_features=1536, out_features=384, bias=True)\n",
" (drop): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ls2): LayerScale()\n",
" (drop_path2): Identity()\n",
" )\n",
" )\n",
" (norm): LayerNorm((384,), eps=1e-06, elementwise_affine=True)\n",
" (head): Identity()\n",
")"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BACKBONE_SIZE = \"small\" # in (\"small\", \"base\", \"large\" or \"giant\")\n",
"\n",
"\n",
"BACKBONE_ARCHS = {\n",
" \"small\": \"vits14\",\n",
" \"base\": \"vitb14\",\n",
" \"large\": \"vitl14\",\n",
" \"giant\": \"vitg14\",\n",
"}\n",
"\n",
"\n",
"backbone_arch = BACKBONE_ARCHS[BACKBONE_SIZE]\n",
"backbone_name = f\"dinov2_{backbone_arch}\"\n",
"backbone_model = torch.hub.load(repo_or_dir=\"facebookresearch/dinov2\", model=backbone_name)\n",
"backbone_model.eval()\n",
"backbone_model.cuda()"
]
},
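{
"cell_type": "markdown",
"id": "5b8e2d04-9c1f-4a73-b6e2-3f7a0d9c8e15",
"metadata": {},
"source": [
"As an optional check (not part of the original notebook), the cell below extracts reshaped patch features from the last 4 blocks for a dummy input: with `reshape=True`, `get_intermediate_layers` returns `(B, C, H/14, W/14)` feature maps, which is the form the segmentation head consumes."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4a7f912-6e0b-4d58-a3c9-2b5d8f1e7a46",
"metadata": {},
"outputs": [],
"source": [
"# Optional check (not in the original notebook): extract reshaped patch\n",
"# features from the last 4 blocks for a dummy 448x448 input.\n",
"with torch.inference_mode():\n",
"    dummy_input = torch.zeros(1, 3, 448, 448).cuda()\n",
"    features = backbone_model.get_intermediate_layers(dummy_input, n=4, reshape=True)\n",
"for f in features:\n",
"    print(f.shape) # (1, 384, 32, 32) for ViT-S/14, since 448 / 14 = 32"
]
},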
{
"cell_type": "markdown",
"id": "c1c90501-d6ef-436e-b1a1-72e63b0534e3",
"metadata": {},
"source": [
"## Load pretrained segmentation head"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d0bf0b7f-ad98-4cfb-8120-f076df8f8933",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/private/home/plabatut/.conda/envs/dinov2-extras-conda/lib/python3.9/site-packages/mmseg/models/losses/cross_entropy_loss.py:235: UserWarning: Default ``avg_non_ignore`` is False, if you would like to ignore the certain label and average loss over non-ignore labels, which is the same with PyTorch official cross_entropy, set ``avg_non_ignore=True``.\n",
" warnings.warn(\n",
"2023-08-31 06:05:23,461 - mmcv - INFO - initialize BNHead with init_cfg {'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}\n",
"2023-08-31 06:05:23,463 - mmcv - INFO - \n",
"decode_head.conv_seg.weight - torch.Size([21, 1536, 1, 1]): \n",
"NormalInit: mean=0, std=0.01, bias=0 \n",
" \n",
"2023-08-31 06:05:23,464 - mmcv - INFO - \n",
"decode_head.conv_seg.bias - torch.Size([21]): \n",
"NormalInit: mean=0, std=0.01, bias=0 \n",
" \n",
"2023-08-31 06:05:23,464 - mmcv - INFO - \n",
"decode_head.bn.weight - torch.Size([1536]): \n",
"The value is the same before and after calling `init_weights` of EncoderDecoder \n",
" \n",
"2023-08-31 06:05:23,465 - mmcv - INFO - \n",
"decode_head.bn.bias - torch.Size([1536]): \n",
"The value is the same before and after calling `init_weights` of EncoderDecoder \n",
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"scales: [1.0, 1.32, 1.73]\n",
"load checkpoint from http path: https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_ms_head.pth\n"
]
},
{
"data": {
"text/plain": [
"EncoderDecoder(\n",
" (backbone): DinoVisionTransformer()\n",
" (decode_head): BNHead(\n",
" input_transform=resize_concat, ignore_index=255, align_corners=False\n",
" (loss_decode): CrossEntropyLoss(avg_non_ignore=False)\n",
" (conv_seg): Conv2d(1536, 21, kernel_size=(1, 1), stride=(1, 1))\n",
" (bn): SyncBatchNorm(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" init_cfg={'type': 'Normal', 'std': 0.01, 'override': {'name': 'conv_seg'}}\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import urllib\n",
"\n",
"import mmcv\n",
"from mmcv.runner import load_checkpoint\n",
"\n",
"\n",
"def load_config_from_url(url: str) -> str:\n",
" with urllib.request.urlopen(url) as f:\n",
" return f.read().decode()\n",
"\n",
"\n",
"HEAD_SCALE_COUNT = 3 # more scales: slower but better results, in (1,2,3,4,5)\n",
"HEAD_DATASET = \"voc2012\" # in (\"ade20k\", \"voc2012\")\n",
"HEAD_TYPE = \"ms\" # in (\"ms, \"linear\")\n",
"\n",
"\n",
"DINOV2_BASE_URL = \"https://dl.fbaipublicfiles.com/dinov2\"\n",
"head_config_url = f\"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_{HEAD_DATASET}_{HEAD_TYPE}_config.py\"\n",
"head_checkpoint_url = f\"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_{HEAD_DATASET}_{HEAD_TYPE}_head.pth\"\n",
"\n",
"cfg_str = load_config_from_url(head_config_url)\n",
"cfg = mmcv.Config.fromstring(cfg_str, file_format=\".py\")\n",
"if HEAD_TYPE == \"ms\":\n",
" cfg.data.test.pipeline[1][\"img_ratios\"] = cfg.data.test.pipeline[1][\"img_ratios\"][:HEAD_SCALE_COUNT]\n",
" print(\"scales:\", cfg.data.test.pipeline[1][\"img_ratios\"])\n",
"\n",
"model = create_segmenter(cfg, backbone_model=backbone_model)\n",
"load_checkpoint(model, head_checkpoint_url, map_location=\"cpu\")\n",
"model.cuda()\n",
"model.eval()"
]
},
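{
"cell_type": "markdown",
"id": "e9d3b7a5-2c84-4f16-9a0e-6b1c4d8f2e73",
"metadata": {},
"source": [
"Optional check, not part of the original notebook: `input_transform=resize_concat` in the printed model above means the `BNHead` resizes and concatenates the feature maps from all tapped backbone blocks, so the `conv_seg` input channel count should equal the ViT embedding dimension times the number of entries in `out_indices` (4 \u00d7 384 = 1536 for ViT-S)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a1e4c86-0d39-4b25-8f6a-9c2e5b7d4f01",
"metadata": {},
"outputs": [],
"source": [
"# Optional check (not in the original notebook): the head input channel count\n",
"# should equal embed_dim * len(out_indices) because features are concatenated.\n",
"print(\"out_indices:\", cfg.model.backbone.out_indices)\n",
"print(\"head input channels:\", model.decode_head.conv_seg.in_channels)"
]
},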
{
"cell_type": "markdown",
"id": "2dc1b106-d28c-41cc-9ddd-f558d66a4715",
"metadata": {},
"source": [
"## Load sample image"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "44511634-8243-4662-a512-4531014adb32",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAIAAAC6s0uzAAABamlDQ1BJQ0MgUHJvZmlsZQAAeJx1kL1Lw1AUxU+rUtA6iA4dHDKJQ9TSCnZxaCsURTBUBatTmn4JbXwkKVJxE1cp+B9YwVlwsIhUcHFwEEQHEd2cOim4aHjel1TaIt7H5f04nHO5XMAbUBkr9gIo6ZaRTMSktdS65HuDh55TqmayqKIsCv79u+vz0fXeT4hZTbt2ENlPXJfOLpd2ngJTf/1d1Z/Jmhr939RBjRkW4JGJlW2LCd4lHjFoKeKq4LzLx4LTLp87npVknPiWWNIKaoa4SSynO/R8B5eKZa21g9jen9VXl8Uc6lHMYRMmGIpQUYEEBeF//NOOP44tcldgUC6PAizKREkRE7LE89ChYRIycQhB6pC4c+t+D637yW1t7xWYbXDOL9raQgM4naGT1dvaeAQYGgBu6kw1VEfqofbmcsD7CTCYAobvKLNh5sIhd3t/DOh74fxjDPAdAnaV868jzu0ahZ+BK/0HFylqvLiAv9gAAQAASURBVHicrP1psHVLlhiErSEz997nnDt845vq1avqGrqGrpJoIfVUomXUjSSwDRFYAyAJQ8jwx+AfKMIORzjCxhgThiAc/mEby9gNEcYOhMMggxAt1JJopJ6kVnd1qau7a3j16s3vG+50pr0z11r+kXvvk2efc+/7GpxR9b1998mdw5pz5cqV+IUvflZVzQwAzExE8jMR5YfxfUrJzMAoxph/ijGqKhEhIiKamZmNz6qqqmeL6uzsjNAuLi6++PnPNU2z3qy6bSsiqppUzMwUFCF/nqRNUdu222wTGANQSl1Vh1deeeQDfeYHPmkmwXHqYkoJAJNKHk/uNA/SOdd1XR45IjRNc3Z2MpvNiIh9x+xjp++9+/Tn/tovv/rKYyMmDADgvd+0681m45wjItXEzITmvQeAPLMMECIa/1VNZOOzMnMGDg5lGAbml1CU8f3hT5M/i/cMAIg71ADu4WvE2tAHImIenoGOLY+Iy5+UYy5byB+OL8cZTcY5/GsTEA0TpLG+Y6+q5Ricc7nxZNp13bvvvvv9d578uT/3p1999VUzQ63UkDxGXfmaJBlTDagEgMaQfHCVwopdUlUwrw5VxHGFyKaYadlQAcBMzQRMCBQNyAAN0gFeiunrBBd5LiMQ8kMGo5qNcC5bI3JjgwPSdmAc+WVkIuiHatbjWjN6mDG1navqrk1J7T/4f/3Fv/W3f/HBo4d/+p/4Iw8e3COPH3548Vf+i59HqjtJ83lzOj996fEnTk7OPvWpT3zms59su9Vq0zbNiaFmBul7J2LmEEJKaT6fr1abDz/8sG1bYHLOzWazGr2BMtu22xDDYrGoqxkmapctdCrruF1vnj1/8vTph8+un3z07KMPbp7xAIBNlMW9xz/2tf9OMz89Pz///pvf/gv/x//DT/3DX/vH//v/WD2rzh8++ps//1/9O//7v/DoweKf/Cf/2OuffDl1aTF/8OGHl//6v/Vvf+ozb3z7zbf+mT/7Tz9+9OrV1c3rr762Wi7/7X/r33z51ddubm6IiBmlLzH3d+/+WdM0koyIRCSE8Du/8+0/+U//s5//wc89OL+3Wq3+6l/9L9rt9vr6crvdMiMRrdfrq6urjIWU1MzEehbOxJmpdMRsxh0zM3OWEgR2yCCIqAqDGFFVHUnLIIKRqmYBWPLL2EKWt1ks7whj6D2/qaoqf+6cy59kVkLE2WzWNE1Gca7jvUfEGKMqZAnZtu1yufTep9SllEQkN66qMUYRy882lPyMiISWTAGAAMEMEIHQJEsaBSa6WX1ufu+HX//U9eVFx+rZzefz0/Oztm1TSllEjMQ/Ch8zBDUAAEK0QiJp/8BEZjKOZ8dl0rNqKbVwmHgPW8D+vwCg3VE5OXm5E4OKAABkvTggbHyYOb8A9/rpvfNqVjE1909WKL/z4TuX3fLvzWs2QIP1ev38ydP19RI7AQCtwjWqgJ1yU3fmgbEJN6l143x2o79V9GehaZlAMyKJKJP72Eimj8m/psk5d3NzU9d1pl1EFBEFy7gfMUFEiAZAZoYAZpJp4v333//BL3y2aZqUOk0xS20RdY7NVERhh0tTVe99pnIzTSl1XfI+eu/JMEXxvj45OTGDNkVmIKdZsXl23aBBAcgMAW2iJku7ZAKx20A3QXNZ/2M/OdoCgN3yflomFkD5XL6cjKqUBXcYDXeP86i23uMfNEAzUECwXkYBEYnAZrORZCLi2BCJiBidJGFfAziQpCbBhyQEyGZkCEmsCnXSLTsPBiklQoeIznHXCTJm8QVmqIioCAB4CMjjACznVc5lYq8cgug2vEwo4WPHoKohBCByzhEgMxNR27ZZcGf6Z+9MMctuVb28fH59vbx//xQRU0pVVaWUkCF/O6AAcwshBJHe/hYR5zhzEHkCxJS2zOw8ISKYLW9uPnz7Q29+c3mzXi43m5uL5dXF1cV6vaZBU2U15IgREQnYkZmopcqHzPu5L9XU00lPMBBjzLOezSpmFhGH5L1XTYC9QlJNAKSqRMAciEg0AoCIbNstGA1CBnxweSLvvffOer2GQWSlFOu6Xq1WXdeFEFSBmVNKh3gcCxRiepAtBnicXwZKP+B3cqaD0VYo5gm67xARuXJeYJSLn1Fhr1artm3rup7NZsPiIePUZbMgV67rOstJKExwIvLee78zEXJlABhH248qazVEA8j9GpCqWo9WVQBksmEllusA6Dj3o7xTdnE3j/QDxp7cSjBOZFQpM2/jt3IwAwbHARjk1swAQPrZ6eX11ezMhXmz3XbhtH7l5ZfjR+9XCCYCTE3TnD+4P6/q7XKzWq027ZZrbwZd1zVcgWGMkZjcuPwtx30ImpEOEGEknRH3E0lUmjYZi2BWVdXlzfW9e/eY2QYDPNsviJhXaoMC1swkYD24VHW9bruuu76+ZkZH6OvKzMyiYf/5CNuMcmYex57NBRHznsCUyIHRYnHy6U+/cr3cDBMxM/Heu+iyCBhnNMxryocT4jiqhkswTrj0d6t69xv8ePHd/2rFV4OwuIPuD6fwux1eCaIJfBBRTQABEAx2smz4FoioqioAWK+2GYNmZpZim5DVBYfKMabZLMRuk1Kaze5fXl7OT7iNMlucbtbqXAU2GP6MbdsyGpKicY9PZEQxIwEj05F1J5i9bXYlmR0aKLcZNz00jrF/SQlmRjRtsySYLkZJhuyy+dt1HTMbQkrJCL33SZ1tLdsXiDifz/PCCJhiF5vmJGnMBnTfCyIAMHOMMYTsXMFRQ5sZqKBDBSPElFLbtqxudX3zvW9/59HpI4/k2UkIzhMwKEivgFERiIicJyQjopRSSgkBFycz5xwzpZSWy6VzLlQuhEBERgAA2+0WALz352entQ9qiYgcc7dtzQBUQEVFjBIAIBGiIUDlgyNmJEesqoRmqmDQNI1DWi6v33nnHSJSiQCQUnKON5vNdrsdBOCoZnc+oVEYjoJuQrGqirxzhpXIAgDVZJZ17c6biIjIYLr7Kvc1trmD/
KDYSioa36eUMqbyJ/nzrGVTStvtdrPZbDabqqq893kFzMwpdXkBX9f1fD4Xka7bdl3Xtm2MMctqZkbk0jgYSXR8uQNUroNgZjrIP96tYXoTYWykVDeHjHBM6exxEBQrCtjjsh0HmRlOpLftPicoRfp0AEeGRAXPA4CaiESARLRN8Waz9p4JjVsMjhe+Poe0hCSgVLkFzjDUm2ptjOubK1FtfO0UYkqCDEAE5Eq75nAQpSzAnlf79W7GfUppdNGUCngsbRu3221VVc651Wq13W5ns9lhR/lzRBQxRMpOHlPKHRHjem2bzWY2WzgH3XajqtnDs95uCp00kv6eK1jVui62bcfsAMExKSgCv/zyy+/86q+//PKrmZhym8zctm2hP2x4PgKW/qFUcreX0li5o84h5A8qFLyxY8+9AZRYK4h1WvMFmaGsf6h1Pm60oLpTxocUVXSBiOi9B4PLy8teE6DFGAM7QU1d8q7y7LebjXeokjabVQgBICG59Saa1qImEr1jJmBGZlTrHDqRpAhgBAA
"text/plain": [
"<PIL.Image.Image image mode=RGB size=640x480>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import urllib\n",
"\n",
"from PIL import Image\n",
"\n",
"\n",
"def load_image_from_url(url: str) -> Image:\n",
" with urllib.request.urlopen(url) as f:\n",
" return Image.open(f).convert(\"RGB\")\n",
"\n",
"\n",
"EXAMPLE_IMAGE_URL = \"https://dl.fbaipublicfiles.com/dinov2/images/example.jpg\"\n",
"\n",
"\n",
"image = load_image_from_url(EXAMPLE_IMAGE_URL)\n",
"display(image)"
]
},
{
"cell_type": "markdown",
"id": "7e3240cb-54d0-438d-99e8-8c1af534f830",
"metadata": {},
"source": [
"## Semantic segmentation on sample image"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "49226d5b-83fc-4cfb-ba06-407bb2c0d96f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAoAAAAHgCAIAAAC6s0uzAAAXVklEQVR4nO3dXXbiuraAUdcZaViaRtPStPPAHi6KgPGPpLUkzfl0x7m1EweIPi/ZgT+3BWAKP7cP/+D70z848TU/2vim5774963AUdVw4uFN4vfjWeRn+V+BrwEAHCTAABBAgAH+E7JzW/yb5tx/5jcBBqCFkc4MivwsAgxw0khFaeMn6w1iG+od8J9qXxkgkf3L6J4bXI8uynvuTP79fbtr1SFd3BS9/RRc/BEEGJhCkQAf+iLrP376gvtLPHaAV2lLvOfxv3LwtqCB8R0q2ct/fHrv9PcCnbY3URKeZ7TZKjcBA4M7vZLeS3niP1//w+3WvvvKs03Aj7Z3C2qfu5Q6x9pJgIGRhTRsZ4CXTw2eMMA7FS/xxYf63PF8XfqeAInlD9i7m7PyH3msd9fXz32RKAIMjCl8ed3JsHvF6RKXfcDvX+3oMdiCBgYU27PTY5kMX7Tzka/3OB966k3AAAzi5ccWnRtPTx/A/m/kz5AAGNaa5IRvwiXAwGiyrbMEWufR7yOzaRsCDJBFtkL07v54/k5vksfZNWCARPa8a/T6Lx+Z+58kqewGEzBASdfX/XVue/f/fbmbmr83LUU9Goe+rwADQwkfBIscwEZ9T/xX5CTAwHTeDZFlBb4L5uR6eRC8EQcwlI3sFfmcwZ2u/PnpiU8OPvoVxrbnMc/wXhwCDIzm5doauCgfavC5O7CufKkhnfscqgbf+pEtaGA0p3cgK21d7nwLiITvFMEJ+59EEzAwoCufIxuyOXn6U4c/mjnqgR90sefZEWBgWPWuwp728mBOfzsNXiX8UKmPz44taGBYp2917uU2Wh6lqu8eJmCAt2rflnXx65uAk9t+grwVJcBblTY2WxZRfdMSYIAPEl5f3KO7A56NAAPssv9jEgLlP0JWbsIC2CvzzVn+jLg7AgxwQKoGr8WV3h65CxrgMMFjD3dBAxT29HdEZd9eg3reFTHkyTIBA9Siwak0/oSGj1crBBigLhkO1/4dS/Z8RzdhAdSV6r6tCbV//Hd+RwEGqE6Du1DkafJ5wAC5aHCIzA+7AAMwphP1vRjsQ/+5AAM0knkaG0/aS78rAQZoR4PbuPI4N/sMaQEGYCi9nOUIMADjaHwn85VvKsAATfUyn1GbAAPAeafPqAQYoDVDMIsAA4TQ4BoKPqoNniABBoihwWV193j6NCSASD4r6bp66f347Fz51iZggEjft9eLeHfzXIh3j14Rtc+Nvup+eQB2+L79t9zr7n7hj9XFAxBggBTCc9KXAR4uW9AAcNj1MwABBqAzA4y/iy1oAO4eq5b53uwM9S1yDCZggKRalib/p+de+a9OaHAKIsAAeYVMe2nH3wyz71LuMAQYgP+i+3NrV9+0mV9aHZt3wgLILnOrmkmy+VzwMNyEBQDL0vxExxY0AARsMwgwQHZJbj4aWMgmvwADQAABBoAAAgxAB8a7FVyAAehDpQZHpV2AAejGSHOwAAPQk5Zv1/Wk7O3oAgyQ3UhjXymlHpPAx9Y7YZHC93J7+b//vPnfAX5uBUbS77h5+kyAH9fKPevju7X1EAvxkD6+NtZ/4AXAtIy/G4o0OMrnAG8vke9iXCS6ew7Dutypo6+Q+7/3dMPw1qAOf+ax9WlIxSOagRU8UKlXlCeReQwfodXvQXb/z359CN75vcpO228n4CHru/z6uSzlVVV6FX0vN08c0LvXE/Co9d1mTb+u8SvHU8bY5hl/l8s7z22G4LoT8JzpvTt6cxl3ga8ZozAMI//ZRvG7vf5Yv7Z5fJ4kPEXzHDGq/E1K5WIgPz7axQPs74A/yPlnMNsVPHGoCbO6nzkYrnjsStfJ7+5PkkzAh8U+YkdL+fJou87tO17JjKdqDt+1qusG353L8PYPXiPtAnxe44duyGqW5cXMYEq18Gg83n3fK38p1F6pn/rEl9pJgKso/qiq735e0gzjet4KzoI9Tsz7f/z29V1cA67kRC/fZUN6j3JJGIo3o69rq3f5j1mAsxDagjSYadWozvbXDPwwgw2lHoeqFRdgxqTBTKhILZ5qmn+OvLtynFEnEP+L+bZQn00FuKiX+l4ROL4LMCPTYGggvNOnD6D9nx49EmAGp8GwX8KruR9Vqm8DrgEzPteDYY/wIFWV8KczATMFczDdaRyMhH3aY+f4m/OnE2BmocHwTs4+fdR1fRcBZioaDL+l7VOsBneWCTBz0WAoLiTh4bdeXyfAAAwr83z/v5/l9uMeUWZiCCa/Qx+HUEnmEbP2sbX52f9OwDLMPDSY7tyT0CyKRT5HqJLMZwaH2IJmUhpMp1LlJ/MG73L28Jo9ws8BNgQDxGpZtQx73dN6DrCxgHl4tdOjBnnOPNeOdHJgC5qpaTAzyxzal47W90StWwZegAE6UDyW21+wuzYX0Xi8/ifApgEm5GVPR1TzqMxb1iZggD6c6+vP7cV/2GOqx/jb30cCDIZghrWG9p7hQ939+I97qXjaIViAAcb0MpCHMtxLYq8LifTfABsCmJnXP7w0T4PbMwHDfzSY3j2OcQXDmeEqctpt5Cu+og8AgAJqJ2rgUTiq7v98GEPMIUAahmB69H17TsjAsRzJP1vQGgwaTF+G3Jt9MurP6BowQK9elsn4+2T7AQmsuwADJCKf8/BxhPDMLjQh1puNd85ko27MzsMEDJDLlSG45QDtDOCiP7dX/6sJgMnZCiKJjaD+7l9UfYt/3+Jpf3eEsecQJmB4wTko7R19r+ZAVbvVrL7hXgfY6T9AiNNv1NwyM2mTdkj4Frp3wgJIapjPI9ovPIot2YIG4LDx2t+eAAOkcB/+1hFwwsJNNf4uAgyQx2wFmpwAAzCdDOc6AgyQUYZCbEh+eF0QYACOqVHfCYv+9s+Qfpab9yIAYDVMI5P8IP4OGCCp7wTvjZWkVUOyBQ1QRXg7L/q+Vfxcpv1f/Lqnb5TnlGIrwN6QEuCKThtco46PX/P3F68dxfW756nvYgIGKOvpMxW6a3ClRGX4qIlU9V3efRzhI7diMSc7QFx07831Rb+Xj/i9eJzZ6tiACRiglr6icvFo+/phM/g8AS+GYGZlCCaJBkNwqXxeOdTZEi7A8JYAk0rxDNe73HvaVA3eFeBFg5mSAJNQF5daNXgP14ABetLyL2hDhN8s3czeABsFmJCNH9JKnuGLxzZJg03AAL1KnuErZmjwgfeC9vEMAP1qnOr12x1N6ainFL+ZgGGLk07ym6dYgxFgABKZYfP57liA3YoFkNDHIThqSp6npieYgOEDu9B0YSOx3e1RT5LtwwE2BAPk9Pum6IFvkx7AgbugAciveHEf59FmOf+Z4NRBgAH4z8e931Kfschy7hqwXWiAwfzcDlx5bXONdvgrwW7Cgs/chwVPNPi6kwE2BAMM41znmjV41AybgAE4qVkah8zw+QAbggFo2cWN73UvdF+dNgEDcEnjBj99u76i++hSgA3BACxvKlivi+u36zS9dyZgAMpYc9hmKu26vsv1ABuCAbpWN
mP9bgi3ZwKGXfwpMHSho3fpKhBgQzAAHOW9oAHmtX+7+N1kacP5NAEGmNSedn7c0f120fesMteA7UIDdOfjpwXvvJ7a0WXXVEzAAFOTzyjuggbgKhU/oViA7UIDzCxDgzMcw34mYADK6Kt/4QQYgGICG9xd/ksG2C40AOxkAgagpPaT6Me/p8pJgGEXGzywX8sc9pjeu8IBtkgBsLTqYr/1XWpMwBoMwFK/jl3Xd7EFDUA9vTeyqj+3Ol/Xh6cyGFs7cFHZz2wYIO21JmCrFYNxTgkXdXqvcj22oAFoR4ZXFQNsCAbgJQ1eTMAAEEKAYS+XgaEgQ7AAA0CAugF2GRiA4saYnk3AAAQo+2fBPRJgAAggwHCA+7CgiCvj7xj7z0uDALsMDMAjm893JmAA2rlY32HG36VNgA3BAPDEBAzHuAwMFNEowIZgAK4b6fpxuwlYgwFg1XQLWoMZg11o4DrXgAFIZ6S7nd9pHWBDMMC0dl7Bvdd3+AabgAEgQECADcEAc/q+FZhrh7kR2gQMhzm
"text/plain": [
"<PIL.Image.Image image mode=RGB size=640x480>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"\n",
"import dinov2.eval.segmentation.utils.colormaps as colormaps\n",
"\n",
"\n",
"DATASET_COLORMAPS = {\n",
" \"ade20k\": colormaps.ADE20K_COLORMAP,\n",
" \"voc2012\": colormaps.VOC2012_COLORMAP,\n",
"}\n",
"\n",
"\n",
"colormap = DATASET_COLORMAPS[HEAD_DATASET]\n",
"\n",
"def render_segmentation(segmentation_logits):\n",
" colormap_array = np.array(colormap, dtype=np.uint8)\n",
" segmentation_values = colormap_array[segmentation_logits + 1]\n",
" return Image.fromarray(segmentation_values)\n",
"\n",
"\n",
"array = np.array(image)[:, :, ::-1] # BGR\n",
"segmentation_logits = inference_segmentor(model, array)[0]\n",
"segmented_image = render_segmentation(segmentation_logits)\n",
"display(segmented_image)"
]
}
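,
{
"cell_type": "markdown",
"id": "b6f2a8d1-4e57-4c90-a2b8-7d3f9e1c5a62",
"metadata": {},
"source": [
"As an optional final step, not part of the original notebook, the segmentation map can be blended with the input image for easier visual inspection; the 50/50 blend ratio below is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0c5d924-8b1a-4e67-93d2-6a4e8c0b7f19",
"metadata": {},
"outputs": [],
"source": [
"# Optional visualization (not in the original notebook): blend the rendered\n",
"# segmentation with the input image (the 0.5 alpha is an arbitrary choice).\n",
"overlay = Image.blend(image, segmented_image, alpha=0.5)\n",
"display(overlay)"
]
}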
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}