{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f400486b",
"metadata": {},
"outputs": [],
"source": [
"# Copyright (c) Meta Platforms, Inc. and affiliates."
]
},
{
"cell_type": "markdown",
"id": "a1ae39ff",
"metadata": {},
"source": [
"# Object masks from prompts with SAM"
]
},
{
"cell_type": "markdown",
"id": "b4a4b25c",
"metadata": {},
"source": [
"The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding, from which high-quality masks can be efficiently produced for any prompt.\n",
"\n",
"The `SamPredictor` class provides a simple interface for prompting the model. It lets the user first set an image with the `set_image` method, which computes the necessary image embedding. Prompts can then be passed to the `predict` method to efficiently predict masks from those prompts. The model accepts both point and box prompts as input, as well as masks from a previous iteration of prediction."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "18ab8c70",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from IPython.display import display, HTML\n",
"display(HTML(\n",
"\"\"\"\n",
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\"\"\"\n",
"))"
]
},
{
"cell_type": "markdown",
"id": "644532a8",
"metadata": {},
"source": [
"## Environment Set-up"
]
},
{
"cell_type": "markdown",
"id": "07fabfee",
"metadata": {},
"source": [
"If running locally using Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5ea65efc",
"metadata": {},
"outputs": [],
"source": [
"using_colab = False"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "91dd9a89",
"metadata": {},
"outputs": [],
"source": [
"if using_colab:\n",
" import torch\n",
" import torchvision\n",
" print(\"PyTorch version:\", torch.__version__)\n",
" print(\"Torchvision version:\", torchvision.__version__)\n",
" print(\"CUDA is available:\", torch.cuda.is_available())\n",
" import sys\n",
" !{sys.executable} -m pip install opencv-python matplotlib\n",
" !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'\n",
" \n",
" !mkdir images\n",
" !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg\n",
" !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg\n",
" \n",
" !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
]
},
{
"cell_type": "markdown",
"id": "0be845da",
"metadata": {},
"source": [
"## Set-up"
]
},
{
"cell_type": "markdown",
"id": "33681dd1",
"metadata": {},
"source": [
"Necessary imports and helper functions for displaying points, boxes, and masks."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "69b28288",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import matplotlib.pyplot as plt\n",
"import cv2"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "29bc90d5",
"metadata": {},
"outputs": [],
"source": [
"def show_mask(mask, ax, random_color=False):\n",
"    # Overlay a binary mask on `ax` as a semi-transparent (alpha 0.6) RGBA image.\n",
"    if random_color:\n",
"        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
"    else:\n",
"        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
"    h, w = mask.shape[-2:]\n",
"    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
"    ax.imshow(mask_image)\n",
"\n",
"def show_points(coords, labels, ax, marker_size=375):\n",
"    # Draw foreground points (label 1) as green stars and background points (label 0) as red stars.\n",
"    pos_points = coords[labels==1]\n",
"    neg_points = coords[labels==0]\n",
"    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
"    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
"\n",
"def show_box(box, ax):\n",
"    # `box` is in XYXY format: (x0, y0, x1, y1) in pixel coordinates.\n",
"    x0, y0 = box[0], box[1]\n",
"    w, h = box[2] - box[0], box[3] - box[1]\n",
"    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))\n"
]
},
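{
"cell_type": "markdown",
"id": "e1a7c0d1",
"metadata": {},
"source": [
"As a quick check of how `show_mask` builds its overlay, the following sketch repeats the same broadcasting on a tiny synthetic mask (the 4x5 array is made up for illustration). Multiplying an `(H, W, 1)` boolean mask by a `(1, 1, 4)` RGBA color gives an `(H, W, 4)` image that is fully transparent outside the mask."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1a7c0d2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical 4x5 binary mask; True marks the object pixels.\n",
"demo_mask = np.zeros((4, 5), dtype=bool)\n",
"demo_mask[1:3, 2:4] = True\n",
"\n",
"# Same default color as show_mask: blue with alpha 0.6.\n",
"demo_color = np.array([30/255, 144/255, 255/255, 0.6])\n",
"\n",
"h, w = demo_mask.shape[-2:]\n",
"# (H, W, 1) * (1, 1, 4) broadcasts to an (H, W, 4) RGBA image.\n",
"demo_overlay = demo_mask.reshape(h, w, 1) * demo_color.reshape(1, 1, -1)\n",
"\n",
"print(demo_overlay.shape)   # (4, 5, 4)\n",
"print(demo_overlay[0, 0])   # background pixel: all zeros, fully transparent\n",
"print(demo_overlay[1, 2])   # object pixel: the RGBA color itself"
]
},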
{
"cell_type": "markdown",
"id": "23842fb2",
"metadata": {},
"source": [
"## Example image"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3c2e4f6b",
"metadata": {},
"outputs": [],
"source": [
"image = cv2.imread('images/truck.jpg')\n",
"image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e30125fd",
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.imshow(image)\n",
"plt.axis('on')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "98b228b8",
"metadata": {},
"source": [
"## Selecting objects with SAM"
]
},
{
"cell_type": "markdown",
"id": "0bb1927b",
"metadata": {},
"source": [
"First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7e28150b",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append(\"..\")\n",
"from segment_anything import sam_model_registry, SamPredictor\n",
"\n",
"sam_checkpoint = \"sam_vit_h_4b8939.pth\"\n",
"model_type = \"vit_h\"\n",
"\n",
"device = \"cuda\"\n",
"\n",
"sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
"sam.to(device=device)\n",
"\n",
"predictor = SamPredictor(sam)"
]
},
{
"cell_type": "markdown",
"id": "c925e829",
"metadata": {},
"source": [
"Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction."
]
},
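{
"cell_type": "markdown",
"id": "e1a7c0d3",
"metadata": {},
"source": [
"According to its docstring, `set_image` expects an `HxWxC` `uint8` image in RGB order by default (its `image_format` argument also accepts `'BGR'`), which is why the image was converted from OpenCV's BGR order above. The helper `check_rgb_image` below is hypothetical and not part of `segment_anything`; it is a minimal sketch of checks one might run before calling `predictor.set_image`, assuming a 3-channel input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1a7c0d4",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def check_rgb_image(img):\n",
"    # Hypothetical pre-flight check; not part of the segment_anything API.\n",
"    # cv2.imread returns None (instead of raising) when the file cannot be read.\n",
"    if not isinstance(img, np.ndarray):\n",
"        raise TypeError('image must be a NumPy array; did cv2.imread fail?')\n",
"    if img.ndim != 3 or img.shape[2] != 3:\n",
"        raise ValueError(f'expected an HxWx3 array, got shape {img.shape}')\n",
"    if img.dtype != np.uint8:\n",
"        raise ValueError(f'expected dtype uint8, got {img.dtype}')\n",
"    return img.shape[:2]  # (H, W): the size at which masks are returned\n",
"\n",
"# Synthetic stand-in with the same size as the truck image.\n",
"print(check_rgb_image(np.zeros((1200, 1800, 3), dtype=np.uint8)))"
]
},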
{
"cell_type": "code",
"execution_count": 11,
"id": "d95d48dd",
"metadata": {},
"outputs": [],
"source": [
"predictor.set_image(image)"
]
},
{
"cell_type": "markdown",
"id": "d8fc7a46",
"metadata": {},
"source": [
"To select the truck, choose a point on it. Points are passed to the model in (x, y) pixel format, each with a label of 1 (foreground point) or 0 (background point). Multiple points can be provided; here we use only one. The chosen point will be shown as a star on the image."
]
},
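{
"cell_type": "markdown",
"id": "e1a7c0d5",
"metadata": {},
"source": [
"Since points are given as (x, y), the first coordinate indexes image columns and the second indexes rows, the reverse of NumPy's `[row, col]` indexing. The sketch below builds a prompt with one foreground and one background point (the background point at (100, 100) is a hypothetical example) on a synthetic image, and uses separate variable names so the `input_point` defined next is unaffected."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1a7c0d6",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# One foreground (label 1) and one hypothetical background (label 0) point, in (x, y) pixels.\n",
"example_points = np.array([[500, 375], [100, 100]])\n",
"example_labels = np.array([1, 0])\n",
"\n",
"# Synthetic 1200x1800 image whose value at each pixel is its row index.\n",
"row_index_image = np.repeat(np.arange(1200)[:, None], 1800, axis=1)\n",
"\n",
"for (x, y), label in zip(example_points, example_labels):\n",
"    # NumPy indexes [row, col], i.e. [y, x], so the order is swapped here.\n",
"    print(f'label={label}: pixel ({x}, {y}) lies in row {row_index_image[y, x]}')"
]
},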
{
"cell_type": "code",
"execution_count": 12,
"id": "5c69570c",
"metadata": {},
"outputs": [],
"source": [
"input_point = np.array([[500, 375]])\n",
"input_label = np.array([1])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "a91ba973",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10,10))\n",
"plt.imshow(image)\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('on')\n",
"plt.show() "
]
},
{
"cell_type": "markdown",
"id": "c765e952",
"metadata": {},
"source": [
"Predict with `SamPredictor.predict`. The model returns masks, quality predictions for those masks, and low-resolution mask logits that can be passed to the next iteration of prediction."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "5373fd68",
"metadata": {},
"outputs": [],
"source": [
"masks, scores, logits = predictor.predict(\n",
" point_coords=input_point,\n",
" point_labels=input_label,\n",
" multimask_output=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "c7f0e938",
"metadata": {},
"source": [
"With `multimask_output=True` (the default setting), SAM outputs 3 masks, and `scores` gives the model's own estimate of the quality of each mask. This setting is intended for ambiguous input prompts and helps the model disambiguate different objects consistent with the prompt. When `False`, it returns a single mask. For ambiguous prompts such as a single point, it is recommended to use `multimask_output=True` even if only a single mask is desired; the best single mask can be chosen by picking the one with the highest score in `scores`. This will often result in a better mask."
]
},
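{
"cell_type": "markdown",
"id": "e1a7c0d7",
"metadata": {},
"source": [
"Choosing the highest-scoring mask takes a single `np.argmax` over `scores`. The sketch below runs on scaled-down synthetic arrays shaped like the outputs above, so it works without loading the model; the score values are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1a7c0d8",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Scaled-down synthetic stand-ins for predict(..., multimask_output=True) outputs.\n",
"rng = np.random.default_rng(0)\n",
"demo_masks = rng.random((3, 12, 18)) > 0.5   # (num_masks, H, W) booleans\n",
"demo_scores = np.array([0.81, 0.97, 0.88])   # hypothetical quality scores\n",
"\n",
"# The best single mask is the one with the highest predicted quality score.\n",
"best = int(np.argmax(demo_scores))\n",
"best_mask = demo_masks[best]\n",
"print(best, best_mask.shape)"
]
},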
{
"cell_type": "code",
"execution_count": 11,
"id": "47821187",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(3, 1200, 1800)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masks.shape # (number_of_masks) x H x W"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "e9c227a6",
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"for i, (mask, score) in enumerate(zip(masks, scores)):\n",
" plt.figure(figsize=(10,10))\n",
" plt.imshow(image)\n",
" show_mask(mask, plt.gca())\n",
" show_points(input_point, input_label, plt.gca())\n",
" plt.title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
" plt.axis('off')\n",
" plt.show() \n",
" "
]
},
{
"cell_type": "markdown",
"id": "3fa31f7c",
"metadata": {},
"source": [
"## Specifying a single object with additional points"
]
},
{
"cell_type": "markdown",
"id": "88d6d29a",
"metadata": {},
"source": [
"The single input point is ambiguous, and the model has returned multiple objects consistent with it. To obtain a single object, multiple points can be provided. If available, a mask from a previous iteration can also be supplied to the model to aid in prediction. When specifying a single object with multiple prompts, a single mask can be requested by setting `multimask_output=False`."
]
},
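{
"cell_type": "markdown",
"id": "e1a7c0d9",
"metadata": {},
"source": [
"`mask_input` takes low-resolution logits rather than a full-size boolean mask, which is why the next cell indexes `logits` instead of `masks`. Per the `SamPredictor.predict` docstring, these logits are 256 x 256 per mask and `mask_input` has shape 1 x 256 x 256; the sketch below shows the indexing on synthetic arrays with made-up scores."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1a7c0da",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Synthetic stand-ins for the scores and low-resolution logits of a multimask prediction.\n",
"demo_scores = np.array([0.81, 0.97, 0.88])\n",
"demo_logits = np.zeros((3, 256, 256), dtype=np.float32)\n",
"\n",
"# Select the best mask's logits (256 x 256), then add a leading axis -> 1 x 256 x 256.\n",
"demo_mask_input = demo_logits[np.argmax(demo_scores), :, :]\n",
"print(demo_mask_input.shape, demo_mask_input[None, :, :].shape)"
]
},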
{
"cell_type": "code",
"execution_count": 13,
"id": "f6923b94",
"metadata": {},
"outputs": [],
"source": [
"input_point = np.array([[500, 375], [1125, 625]])\n",
"input_label = np.array([1, 1])\n",
"\n",
"mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "d98f96a1",
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=input_point,\n",
" point_labels=input_label,\n",
" mask_input=mask_input[None, :, :],\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "0ce8b82f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1, 1200, 1800)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masks.shape"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e06d5c8d",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAIYCAYAAADq/5rtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Wa8lS3agiX3LzH3vM0TEjTtn5s2RZCaZySRZTM5ksYpFVrO7utEPhW61BAgQ9Av0LkAQBEiAfoMeBDQg6UmAHlutQqO7WqqBYzLngTnnzTsPMZxpb3c3W3pYZubmtn2fiGxRECCEXcQ95+ztbsOyZWtey0RVlWftWXvWnrVn7Vl71p61Z+1Ze9aetX/A5v5/PYFn7Vl71p61Z+1Ze9aetWftWXvW/v+vPVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/D2TNF41p61Z+1Ze9aetWftWXvWnrVn7R+8PVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/DWPe2Df/zrv8U//k//Q57/5GtohB4PCMGBE8E701lijIQQUFU636EKqkqMkRgjAM45RIRxHAHYbreEEAB7Nn+f/xaR0kf+3TmH9x6AcRzL33n8GGPpR0QYw8QUAiEE+r4H4ObmBhFhs9kwDAMAfd/jvWeaJhs/Wp/DMDCOIyLCNE3cuXOHzWbDOI5M08SjR484Pz+n6zpCnBCJeO8ZhoHr62tEBO893jmmYSww8t6jqmw2G/bTnhe2p/zp7/0+2+dOie++z+N/+1VOxoD2sOn7AodpmnDOFbi2MOq6jhhjgV+GZb6fse97XNqzPBfn3OJnvRfOOXtehKix9O2cK/COMeKQAtvT09MFDuXx6nnUc8t7Vq/DpX91y2vqVbjZCBsVXnnto3z1v/pv2Sic3D2j7zt7BhCBPRNDp9z/+Kv82h/+Ju8/+ADf9Ui3xZ+eEjc9QQSiw0/wzje/z/tf+3vujkJA2XvFCYgqKsq0Ee79wmt88Y9+mwdXjxBRxHVEBKRDup6I49pv+MQ/+m1OT+4SgjKpsh93fPDoQ7ZOeO7Oc8j5KRIDlz/6Ke///bc4lQkEJhdRHxGNaPBcxA2v/PIX+fgvfIYoavPF4YPSB8ePvvVdpofv8Zf/x/8zdx9c8VhvCDFyGhyh8+n86cF5xAlB7fycbDdsNz2qIK7DO0cMA75zjPdO+Cf/8/+MHz14mzMVTtUTthv8dgPiiFOEMfLWT37GxcUFExHN5x0BAe8cfiP8wuc+g4gQI+A6EMMtVdApgio7p5zdu8uP/vVfcf2X3+E0RG5cYBLYug6GCbzj2iu/8Z//OeHOhq7rcJ0nAiEGphAYh4GwG7h68z1++O++zP3JQ4hMoriuQ3DsdjumKRCCEkJgmibGIZTfY7TzfP/+fU5OtogTMlpmGpNxs235u6hKrM5j/fyEsJsmXn/9dbbbLZvNppy5TdfjgN4Ld+/ewSOggRgDXmDsPS/88qd57hc/gbt7Ttdv8M7juo6N75EYCSGye/8hf/df/bfc34ETZecVnLANSiAyvXTOH/zn/wk3CFrR4kLXNSLDxLt/+Q3e/cb3UYHJCz6CD5GI4dfo4N4nP8of/sv/mHfiNRFhqx5X0agQArubG77+f//XnL53TR+UoIGhs/N6t98wjIHr/Vj4RNd1MyxZ0qeaBoYQFvOueQXOs/MTn/v938C/cE5/5w6bzRm+2yC9IFOk73vUOWSY+Jv/+r9j9+b7uADee8ZOkLunqBPCwyvisOPzf/glvvTHf8B7+0uuo/EIr4LHpxMQiVPg+uqa//q/+lc8enDFT99+h5tpZL8bmMaJj33kJU5PO37/93+LP/uzf8bV1SUhBL
z3FZ1XxikSI/z0xz/hX/0//hvCOEGEje+5Nygvb895+ewu027PTiLx3il/8B/9KV/8g99DOm/4p9CrsHt4wY+++R2+/Zdf5mevv852s2UYB0h7pKo4ZMGP81xq3M3fZ7qe97im+cda7mON7+c5qCocOWP2uxQ+mt9ZO5c1P8zzzd/VfDJqXJzT9kznd+v+7TkBdYvPFvi6st6Mt/U8WtjU/dRrqXn+wTO4RR817637XMhZGK81OhVtzhVM1t7N74dKHqjnuTZubu28630JVPvRPFfvVZY9VA2+IKtwq3Eiv5/7zXJjPY81WOf9Kv07QSr608Km5rM1jukK35hx0+NkKX/ZeEqUiETFBQcIsYv80q9+js9+4bOM4rja3/Dhwwc8fO8D3vr+jxk/eMzzJ+dc7fY82l1zHUfGGNg6z3ObU877LZ757C7gZaBcxZn2vPwfvv/XB3vbtqdWNMAE+qgRjRAVFCHiEFFCQ2BUlf2wp/P9AqAt8tSbvHbI6k3Lm51/AoUYt60mBOM4Mkwju/2+COF937PZbNjtdqgq42gMTVW5f/8+AM55gsI0TUUpyExyHMcyblYisjLiO4f3JOFlous69vt9UnBs/K7rinKVBX1VY6YYXU1zcCgTzvkFIfLel/XVykaGS1aKZkSdYVkz7JZY1kpGfcjKnogUBMyfZYFIVZmGcYGEbf9P024jTgslJR2SGCLTfm8w3Q8zEczPhwAScU744O13ePToETc3N5ydebou7a9zRO9ABUEZdSIITBrBGQGWGBEURYkh8t6bb3F9cUnUgPcOjZEpanre2ZMu0jlhd3ON4hlC4GZ/xc31FT98/XW+8IUvcv/eGV4cJ2dbU2I04oAQIkqwcdWhIXB+dorDcEWEQmDHcWScTEHZ9BuiXhiEJAtcRqzqNu8rKAGi4iZH7zsQ0BhMgfIeojJe7/jpd76PvHDG4NK5jAE/TeA8MVgf/ekJ2zDhxpGYcFvEhteoTGPg4cOHPHfvPlNUJNr7iBCiwhQRhZGJ8/NzbvZ7xDs0RJJcYeceCCiTRB48fszp9nkUwSlJIYYYIE6KIkjfE50nIvafGONwsmRmrZDTKut2APLPn7/VTDLvX1wRfAr+KyCaBBnAYUKBc6A27w8++IDuoy9yen5CmAac80gMRB/x4ogxEDoh9I4YBR2nsv5JI+qFi5trHl1cEPoeSXhTn90hTHiFx+OOwZnSTVQkgERFHKZqiPD+B+/z6NFD9KxLSp8SKxiGZPTxJ1smf02nAlFAIyrGX0KiazUNKk1ZKBP1HpVHGvplgFbw8JOf/IQX/cc5ccImCq6b6ILHI6gIURQJkeiFYIiGqBKi4mLEdT04IQr8+Cc/5hd+4ws82l8Re1cUDSeZL0U0Rm5ubsygJba2EIJhUoxkAenrX/86n//8r7DdbtPaQHWa6bEKMSqPH1+QsBrnHeKESSJTB/sOONvwy1/8PL/zZ/+E05efBwQJ4MURhpGf/vDHfP0v/5Yffffv0esBorK7uZn5qzhU40KgbIXrmicfU7Bbvr72TP6+3een4Rs1D2z7q/FgoUhUQnH97FKgPaSV9XyyYnPb2mq5RoUilIYQDhSdVqZp198KeLcpcO2Mnpr/ylLuYkXAPBhrRajPRplW+WxlwDUlYJ7K4dqfpq3tdf3dGizqfW/HWsP3uckCB1rFr4VPMeSKzAJes8Zj+J8VXyem3MSoeN/hEDREMzhd3rB7dEncj/hotHgaRqZxYtzvmeJE0Mgkkeg2uI0DZWGYKTi5Atd2Xk/C/7r9XIrGzc2NCcFO0ABZh6wJRCZUZgVUAvMiskCe38nImCddP1czj7xxXdctiF29wXkT20OYNd4s8APldxHh3r17BdBd1zFNU2GC261pl13X4b0nhMAwDOX3rIBki9k0TUUZyR6FVlA3HJuRN8+l73t2496eT0Jc522+gqDRLOm5v/qw1gTzNqtI3p98GGpLVK2t19p7vae53fZ7JqTtAfx5Wm
t5qltZd/qfJvhMo3mJukxg1IRVj1nluiQ4TePEsNsBhhthv7f1e88YJojO3js9YZDI5ASnikZwMSkZTstBHG5uCF5
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks, plt.gca())\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c93e2087",
"metadata": {},
"source": [
"To exclude the car and specify just the window, a background point (with label 0, here shown in red) can be supplied."
]
},
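{
"cell_type": "markdown",
"id": "sketch-point-prompt-md",
"metadata": {},
"source": [
"Point prompts are two parallel arrays: an `(N, 2)` array of `(x, y)` pixel coordinates and an `(N,)` array of labels, where 1 marks a foreground point and 0 a background point. The helper below, `make_point_prompt`, is hypothetical (it is not part of `segment_anything`); it is a sketch of that convention that builds both arrays from separate foreground and background lists, assuming plain NumPy inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-point-prompt-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def make_point_prompt(foreground, background=()):\n",
"    # Hypothetical helper: foreground points get label 1, background points get label 0\n",
"    points = list(foreground) + list(background)\n",
"    coords = np.array(points, dtype=np.float32).reshape(-1, 2)\n",
"    labels = np.array([1] * len(foreground) + [0] * len(background), dtype=np.int64)\n",
"    return coords, labels\n",
"\n",
"# Same prompt as the next cells: keep the point at (500, 375), exclude the point at (1125, 625)\n",
"demo_coords, demo_labels = make_point_prompt([(500, 375)], [(1125, 625)])\n",
"print(demo_coords.tolist(), demo_labels.tolist())"
]
},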
{
"cell_type": "code",
"execution_count": 17,
"id": "9a196f68",
"metadata": {},
"outputs": [],
"source": [
"input_point = np.array([[500, 375], [1125, 625]])\n",
"input_label = np.array([1, 0])\n",
"\n",
"mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "81a52282",
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=input_point,\n",
" point_labels=input_label,\n",
" mask_input=mask_input[None, :, :],\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "bfca709f",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAIYCAYAAADq/5rtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Wa8lS3agiX3LzH3vM0TEjTtn5s2RZCaZySRZTM5ksYpFVrO7utEPhW61BAgQ9Av0LkAQBEiAfoMeBDQg6UmAHlutQqO7WqqBYzLngTnnzTsPMZxpb3c3W3pYZubmtn2fiGxRECCEXcQ95+ztbsOyZWtey0RVlWftWXvWnrVn7Vl71p61Z+1Ze9aetX/A5v5/PYFn7Vl71p61Z+1Ze9aetWftWXvW/v+vPVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/D2TNF41p61Z+1Ze9aetWftWXvWnrVn7R+8PVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/DWPe2Df/zrv8U//k//Q57/5GtohB4PCMGBE8E701lijIQQUFU636EKqkqMkRgjAM45RIRxHAHYbreEEAB7Nn+f/xaR0kf+3TmH9x6AcRzL33n8GGPpR0QYw8QUAiEE+r4H4ObmBhFhs9kwDAMAfd/jvWeaJhs/Wp/DMDCOIyLCNE3cuXOHzWbDOI5M08SjR484Pz+n6zpCnBCJeO8ZhoHr62tEBO893jmmYSww8t6jqmw2G/bTnhe2p/zp7/0+2+dOie++z+N/+1VOxoD2sOn7AodpmnDOFbi2MOq6jhhjgV+GZb6fse97XNqzPBfn3OJnvRfOOXtehKix9O2cK/COMeKQAtvT09MFDuXx6nnUc8t7Vq/DpX91y2vqVbjZCBsVXnnto3z1v/pv2Sic3D2j7zt7BhCBPRNDp9z/+Kv82h/+Ju8/+ADf9Ui3xZ+eEjc9QQSiw0/wzje/z/tf+3vujkJA2XvFCYgqKsq0Ee79wmt88Y9+mwdXjxBRxHVEBKRDup6I49pv+MQ/+m1OT+4SgjKpsh93fPDoQ7ZOeO7Oc8j5KRIDlz/6Ke///bc4lQkEJhdRHxGNaPBcxA2v/PIX+fgvfIYoavPF4YPSB8ePvvVdpofv8Zf/x/8zdx9c8VhvCDFyGhyh8+n86cF5xAlB7fycbDdsNz2qIK7DO0cMA75zjPdO+Cf/8/+MHz14mzMVTtUTthv8dgPiiFOEMfLWT37GxcUFExHN5x0BAe8cfiP8wuc+g4gQI+A6EMMtVdApgio7p5zdu8uP/vVfcf2X3+E0RG5cYBLYug6GCbzj2iu/8Z//OeHOhq7rcJ0nAiEGphAYh4GwG7h68z1++O++zP3JQ4hMoriuQ3DsdjumKRCCEkJgmibGIZTfY7TzfP/+fU5OtogTMlpmGpNxs235u6hKrM5j/fyEsJsmXn/9dbbbLZvNppy5TdfjgN4Ld+/ewSOggRgDXmDsPS/88qd57hc/gbt7Ttdv8M7juo6N75EYCSGye/8hf/df/bfc34ETZecVnLANSiAyvXTOH/zn/wk3CFrR4kLXNSLDxLt/+Q3e/cb3UYHJCz6CD5GI4dfo4N4nP8of/sv/mHfiNRFhqx5X0agQArubG77+f//XnL53TR+UoIGhs/N6t98wjIHr/Vj4RNd1MyxZ0qeaBoYQFvOueQXOs/MTn/v938C/cE5/5w6bzRm+2yC9IFOk73vUOWSY+Jv/+r9j9+b7uADee8ZOkLunqBPCwyvisOPzf/glvvTHf8B7+0uuo/EIr4LHpxMQiVPg+uqa//q/+lc8enDFT99+h5tpZL8bmMaJj33kJU5PO37/93+LP/uzf8bV1SUhBL
z3FZ1XxikSI/z0xz/hX/0//hvCOEGEje+5Nygvb895+ewu027PTiLx3il/8B/9KV/8g99DOm/4p9CrsHt4wY+++R2+/Zdf5mevv852s2UYB0h7pKo4ZMGP81xq3M3fZ7qe97im+cda7mON7+c5qCocOWP2uxQ+mt9ZO5c1P8zzzd/VfDJqXJzT9kznd+v+7TkBdYvPFvi6st6Mt/U8WtjU/dRrqXn+wTO4RR817637XMhZGK81OhVtzhVM1t7N74dKHqjnuTZubu28630JVPvRPFfvVZY9VA2+IKtwq3Eiv5/7zXJjPY81WOf9Kv07QSr608Km5rM1jukK35hx0+NkKX/ZeEqUiETFBQcIsYv80q9+js9+4bOM4rja3/Dhwwc8fO8D3vr+jxk/eMzzJ+dc7fY82l1zHUfGGNg6z3ObU877LZ757C7gZaBcxZn2vPwfvv/XB3vbtqdWNMAE+qgRjRAVFCHiEFFCQ2BUlf2wp/P9AqAt8tSbvHbI6k3Lm51/AoUYt60mBOM4Mkwju/2+COF937PZbNjtdqgq42gMTVW5f/8+AM55gsI0TUUpyExyHMcyblYisjLiO4f3JOFlous69vt9UnBs/K7rinKVBX1VY6YYXU1zcCgTzvkFIfLel/XVykaGS1aKZkSdYVkz7JZY1kpGfcjKnogUBMyfZYFIVZmGcYGEbf9P024jTgslJR2SGCLTfm8w3Q8zEczPhwAScU744O13ePToETc3N5ydebou7a9zRO9ABUEZdSIITBrBGQGWGBEURYkh8t6bb3F9cUnUgPcOjZEpanre2ZMu0jlhd3ON4hlC4GZ/xc31FT98/XW+8IUvcv/eGV4cJ2dbU2I04oAQIkqwcdWhIXB+dorDcEWEQmDHcWScTEHZ9BuiXhiEJAtcRqzqNu8rKAGi4iZH7zsQ0BhMgfIeojJe7/jpd76PvHDG4NK5jAE/TeA8MVgf/ekJ2zDhxpGYcFvEhteoTGPg4cOHPHfvPlNUJNr7iBCiwhQRhZGJ8/NzbvZ7xDs0RJJcYeceCCiTRB48fszp9nkUwSlJIYYYIE6KIkjfE50nIvafGONwsmRmrZDTKut2APLPn7/VTDLvX1wRfAr+KyCaBBnAYUKBc6A27w8++IDuoy9yen5CmAac80gMRB/x4ogxEDoh9I4YBR2nsv5JI+qFi5trHl1cEPoeSXhTn90hTHiFx+OOwZnSTVQkgERFHKZqiPD+B+/z6NFD9KxLSp8SKxiGZPTxJ1smf02nAlFAIyrGX0KiazUNKk1ZKBP1HpVHGvplgFbw8JOf/IQX/cc5ccImCq6b6ILHI6gIURQJkeiFYIiGqBKi4mLEdT04IQr8+Cc/5hd+4ws82l8Re1cUDSeZL0U0Rm5ubsygJba2EIJhUoxkAenrX/86n//8r7DdbtPaQHWa6bEKMSqPH1+QsBrnHeKESSJTB/sOONvwy1/8PL/zZ/+E05efBwQJ4MURhpGf/vDHfP0v/5Yffffv0esBorK7uZn5qzhU40KgbIXrmicfU7Bbvr72TP6+3een4Rs1D2z7q/FgoUhUQnH97FKgPaSV9XyyYnPb2mq5RoUilIYQDhSdVqZp198KeLcpcO2Mnpr/ylLuYkXAPBhrRajPRplW+WxlwDUlYJ7K4dqfpq3tdf3dGizqfW/HWsP3uckCB1rFr4VPMeSKzAJes8Zj+J8VXyem3MSoeN/hEDREMzhd3rB7dEncj/hotHgaRqZxYtzvmeJE0Mgkkeg2uI0DZWGYKTi5Atd2Xk/C/7r9XIrGzc2NCcFO0ABZh6wJRCZUZgVUAvMiskCe38nImCddP1czj7xxXdctiF29wXkT20OYNd4s8APldxHh3r17BdBd1zFNU2GC261pl13X4b0nhMAwDOX3rIBki9k0TUUZyR6FVlA3HJuRN8+l73t2496eT0Jc522+gqDRLOm5v/qw1gTzNqtI3p98GGpLVK2t19p7vae53fZ7JqTtAfx5Wm
t5qltZd/qfJvhMo3mJukxg1IRVj1nluiQ4TePEsNsBhhthv7f1e88YJojO3js9YZDI5ASnikZwMSkZTstBHG5uCF5
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks, plt.gca())\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "41e2d5a9",
"metadata": {},
"source": [
"## Selecting a specific object with a box"
]
},
{
"cell_type": "markdown",
"id": "d61ca7ac",
"metadata": {},
"source": [
"The model can also take a box as input, provided in xyxy format, i.e. `[x_min, y_min, x_max, y_max]` in pixel coordinates of the original image."
]
},
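{
"cell_type": "markdown",
"id": "sketch-xywh-md",
"metadata": {},
"source": [
"Many detectors and annotation formats describe boxes as `[x, y, width, height]` instead. A minimal conversion to the xyxy layout expected here might look like the following; `xywh_to_xyxy` is a hypothetical helper written for this example, not part of the SAM API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-xywh-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def xywh_to_xyxy(box):\n",
"    # Hypothetical helper: [x_min, y_min, w, h] -> [x_min, y_min, x_max, y_max]\n",
"    x, y, w, h = np.asarray(box, dtype=np.float32)\n",
"    return np.array([x, y, x + w, y + h], dtype=np.float32)\n",
"\n",
"# A 275x275 box at (425, 600) matches the xyxy box used in the next cell\n",
"print(xywh_to_xyxy([425, 600, 275, 275]))"
]
},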
{
"cell_type": "code",
"execution_count": 20,
"id": "8ea92a7b",
"metadata": {},
"outputs": [],
"source": [
"input_box = np.array([425, 600, 700, 875])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "b35a8814",
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=None,\n",
" point_labels=None,\n",
" box=input_box[None, :],\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "984b79c1",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAIYCAYAAADq/5rtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Wa8lS3agiX3LzH3vM0TEjTtn5s2RZCaZySRZTM5ksYpFVrO7utEPhW61BAgQ9Av0LkAQBEiAfoMeBDQg6UmAHlutQqO7WqqBYzLngTnnzTsPMZxpb3c3W3pYZubmtn2fiGxRECCEXcQ95+ztbsOyZWtey0RVlWftWXvWnrVn7Vl71p61Z+1Ze9aetX/A5v5/PYFn7Vl71p61Z+1Ze9aetWftWXvW/v+vPVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/D2TNF41p61Z+1Ze9aetWftWXvWnrVn7R+8PVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/DWPe2Df/zrv8U//k//Q57/5GtohB4PCMGBE8E701lijIQQUFU636EKqkqMkRgjAM45RIRxHAHYbreEEAB7Nn+f/xaR0kf+3TmH9x6AcRzL33n8GGPpR0QYw8QUAiEE+r4H4ObmBhFhs9kwDAMAfd/jvWeaJhs/Wp/DMDCOIyLCNE3cuXOHzWbDOI5M08SjR484Pz+n6zpCnBCJeO8ZhoHr62tEBO893jmmYSww8t6jqmw2G/bTnhe2p/zp7/0+2+dOie++z+N/+1VOxoD2sOn7AodpmnDOFbi2MOq6jhhjgV+GZb6fse97XNqzPBfn3OJnvRfOOXtehKix9O2cK/COMeKQAtvT09MFDuXx6nnUc8t7Vq/DpX91y2vqVbjZCBsVXnnto3z1v/pv2Sic3D2j7zt7BhCBPRNDp9z/+Kv82h/+Ju8/+ADf9Ui3xZ+eEjc9QQSiw0/wzje/z/tf+3vujkJA2XvFCYgqKsq0Ee79wmt88Y9+mwdXjxBRxHVEBKRDup6I49pv+MQ/+m1OT+4SgjKpsh93fPDoQ7ZOeO7Oc8j5KRIDlz/6Ke///bc4lQkEJhdRHxGNaPBcxA2v/PIX+fgvfIYoavPF4YPSB8ePvvVdpofv8Zf/x/8zdx9c8VhvCDFyGhyh8+n86cF5xAlB7fycbDdsNz2qIK7DO0cMA75zjPdO+Cf/8/+MHz14mzMVTtUTthv8dgPiiFOEMfLWT37GxcUFExHN5x0BAe8cfiP8wuc+g4gQI+A6EMMtVdApgio7p5zdu8uP/vVfcf2X3+E0RG5cYBLYug6GCbzj2iu/8Z//OeHOhq7rcJ0nAiEGphAYh4GwG7h68z1++O++zP3JQ4hMoriuQ3DsdjumKRCCEkJgmibGIZTfY7TzfP/+fU5OtogTMlpmGpNxs235u6hKrM5j/fyEsJsmXn/9dbbbLZvNppy5TdfjgN4Ld+/ewSOggRgDXmDsPS/88qd57hc/gbt7Ttdv8M7juo6N75EYCSGye/8hf/df/bfc34ETZecVnLANSiAyvXTOH/zn/wk3CFrR4kLXNSLDxLt/+Q3e/cb3UYHJCz6CD5GI4dfo4N4nP8of/sv/mHfiNRFhqx5X0agQArubG77+f//XnL53TR+UoIGhs/N6t98wjIHr/Vj4RNd1MyxZ0qeaBoYQFvOueQXOs/MTn/v938C/cE5/5w6bzRm+2yC9IFOk73vUOWSY+Jv/+r9j9+b7uADee8ZOkLunqBPCwyvisOPzf/glvvTHf8B7+0uuo/EIr4LHpxMQiVPg+uqa//q/+lc8enDFT99+h5tpZL8bmMaJj33kJU5PO37/93+LP/uzf8bV1SUhBL
z3FZ1XxikSI/z0xz/hX/0//hvCOEGEje+5Nygvb895+ewu027PTiLx3il/8B/9KV/8g99DOm/4p9CrsHt4wY+++R2+/Zdf5mevv852s2UYB0h7pKo4ZMGP81xq3M3fZ7qe97im+cda7mON7+c5qCocOWP2uxQ+mt9ZO5c1P8zzzd/VfDJqXJzT9kznd+v+7TkBdYvPFvi6st6Mt/U8WtjU/dRrqXn+wTO4RR817637XMhZGK81OhVtzhVM1t7N74dKHqjnuTZubu28630JVPvRPFfvVZY9VA2+IKtwq3Eiv5/7zXJjPY81WOf9Kv07QSr608Km5rM1jukK35hx0+NkKX/ZeEqUiETFBQcIsYv80q9+js9+4bOM4rja3/Dhwwc8fO8D3vr+jxk/eMzzJ+dc7fY82l1zHUfGGNg6z3ObU877LZ757C7gZaBcxZn2vPwfvv/XB3vbtqdWNMAE+qgRjRAVFCHiEFFCQ2BUlf2wp/P9AqAt8tSbvHbI6k3Lm51/AoUYt60mBOM4Mkwju/2+COF937PZbNjtdqgq42gMTVW5f/8+AM55gsI0TUUpyExyHMcyblYisjLiO4f3JOFlous69vt9UnBs/K7rinKVBX1VY6YYXU1zcCgTzvkFIfLel/XVykaGS1aKZkSdYVkz7JZY1kpGfcjKnogUBMyfZYFIVZmGcYGEbf9P024jTgslJR2SGCLTfm8w3Q8zEczPhwAScU744O13ePToETc3N5ydebou7a9zRO9ABUEZdSIITBrBGQGWGBEURYkh8t6bb3F9cUnUgPcOjZEpanre2ZMu0jlhd3ON4hlC4GZ/xc31FT98/XW+8IUvcv/eGV4cJ2dbU2I04oAQIkqwcdWhIXB+dorDcEWEQmDHcWScTEHZ9BuiXhiEJAtcRqzqNu8rKAGi4iZH7zsQ0BhMgfIeojJe7/jpd76PvHDG4NK5jAE/TeA8MVgf/ekJ2zDhxpGYcFvEhteoTGPg4cOHPHfvPlNUJNr7iBCiwhQRhZGJ8/NzbvZ7xDs0RJJcYeceCCiTRB48fszp9nkUwSlJIYYYIE6KIkjfE50nIvafGONwsmRmrZDTKut2APLPn7/VTDLvX1wRfAr+KyCaBBnAYUKBc6A27w8++IDuoy9yen5CmAac80gMRB/x4ogxEDoh9I4YBR2nsv5JI+qFi5trHl1cEPoeSXhTn90hTHiFx+OOwZnSTVQkgERFHKZqiPD+B+/z6NFD9KxLSp8SKxiGZPTxJ1smf02nAlFAIyrGX0KiazUNKk1ZKBP1HpVHGvplgFbw8JOf/IQX/cc5ccImCq6b6ILHI6gIURQJkeiFYIiGqBKi4mLEdT04IQr8+Cc/5hd+4ws82l8Re1cUDSeZL0U0Rm5ubsygJba2EIJhUoxkAenrX/86n//8r7DdbtPaQHWa6bEKMSqPH1+QsBrnHeKESSJTB/sOONvwy1/8PL/zZ/+E05efBwQJ4MURhpGf/vDHfP0v/5Yffffv0esBorK7uZn5qzhU40KgbIXrmicfU7Bbvr72TP6+3een4Rs1D2z7q/FgoUhUQnH97FKgPaSV9XyyYnPb2mq5RoUilIYQDhSdVqZp198KeLcpcO2Mnpr/ylLuYkXAPBhrRajPRplW+WxlwDUlYJ7K4dqfpq3tdf3dGizqfW/HWsP3uckCB1rFr4VPMeSKzAJes8Zj+J8VXyem3MSoeN/hEDREMzhd3rB7dEncj/hotHgaRqZxYtzvmeJE0Mgkkeg2uI0DZWGYKTi5Atd2Xk/C/7r9XIrGzc2NCcFO0ABZh6wJRCZUZgVUAvMiskCe38nImCddP1czj7xxXdctiF29wXkT20OYNd4s8APldxHh3r17BdBd1zFNU2GC261pl13X4b0nhMAwDOX3rIBki9k0TUUZyR6FVlA3HJuRN8+l73t2496eT0Jc522+gqDRLOm5v/qw1gTzNqtI3p98GGpLVK2t19p7vae53fZ7JqTtAfx5Wm
t5qltZd/qfJvhMo3mJukxg1IRVj1nluiQ4TePEsNsBhhthv7f1e88YJojO3js9YZDI5ASnikZwMSkZTstBHG5uCF5
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks[0], plt.gca())\n",
"show_box(input_box, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "c1ed9f0a",
"metadata": {},
"source": [
"## Combining points and boxes"
]
},
{
"cell_type": "markdown",
"id": "8455d1c5",
"metadata": {},
"source": [
"Points and boxes can be combined by passing both types of prompt to the predictor. Here, a background point inside the box is used to select just the truck's tire instead of the entire wheel."
]
},
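{
"cell_type": "markdown",
"id": "sketch-point-in-box-md",
"metadata": {},
"source": [
"In this example the background point lies inside the box, near the center of the wheel, which is what lets it carve the center out of the region the box would otherwise select. A quick sanity check that a point falls within an xyxy box can be sketched as below; `point_in_box` is a hypothetical helper, not part of the SAM API."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-point-in-box-code",
"metadata": {},
"outputs": [],
"source": [
"def point_in_box(point, box):\n",
"    # Hypothetical helper: True if (x, y) lies inside [x_min, y_min, x_max, y_max], edges included\n",
"    x, y = point\n",
"    x_min, y_min, x_max, y_max = box\n",
"    return x_min <= x <= x_max and y_min <= y <= y_max\n",
"\n",
"print(point_in_box((575, 750), (425, 600, 700, 875)))"
]
},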
{
"cell_type": "code",
"execution_count": 23,
"id": "90e2e547",
"metadata": {},
"outputs": [],
"source": [
"input_box = np.array([425, 600, 700, 875])\n",
"input_point = np.array([[575, 750]])\n",
"input_label = np.array([0])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "6956d8c4",
"metadata": {},
"outputs": [],
"source": [
"masks, _, _ = predictor.predict(\n",
" point_coords=input_point,\n",
" point_labels=input_label,\n",
" box=input_box,\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "8e13088a",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAIYCAYAAADq/5rtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Wa8lS3agiX3LzH3vM0TEjTtn5s2RZCaZySRZTM5ksYpFVrO7utEPhW61BAgQ9Av0LkAQBEiAfoMeBDQg6UmAHlutQqO7WqqBYzLngTnnzTsPMZxpb3c3W3pYZubmtn2fiGxRECCEXcQ95+ztbsOyZWtey0RVlWftWXvWnrVn7Vl71p61Z+1Ze9aetX/A5v5/PYFn7Vl71p61Z+1Ze9aetWftWXvW/v+vPVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/D2TNF41p61Z+1Ze9aetWftWXvWnrVn7R+8PVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/DWPe2Df/zrv8U//k//Q57/5GtohB4PCMGBE8E701lijIQQUFU636EKqkqMkRgjAM45RIRxHAHYbreEEAB7Nn+f/xaR0kf+3TmH9x6AcRzL33n8GGPpR0QYw8QUAiEE+r4H4ObmBhFhs9kwDAMAfd/jvWeaJhs/Wp/DMDCOIyLCNE3cuXOHzWbDOI5M08SjR484Pz+n6zpCnBCJeO8ZhoHr62tEBO893jmmYSww8t6jqmw2G/bTnhe2p/zp7/0+2+dOie++z+N/+1VOxoD2sOn7AodpmnDOFbi2MOq6jhhjgV+GZb6fse97XNqzPBfn3OJnvRfOOXtehKix9O2cK/COMeKQAtvT09MFDuXx6nnUc8t7Vq/DpX91y2vqVbjZCBsVXnnto3z1v/pv2Sic3D2j7zt7BhCBPRNDp9z/+Kv82h/+Ju8/+ADf9Ui3xZ+eEjc9QQSiw0/wzje/z/tf+3vujkJA2XvFCYgqKsq0Ee79wmt88Y9+mwdXjxBRxHVEBKRDup6I49pv+MQ/+m1OT+4SgjKpsh93fPDoQ7ZOeO7Oc8j5KRIDlz/6Ke///bc4lQkEJhdRHxGNaPBcxA2v/PIX+fgvfIYoavPF4YPSB8ePvvVdpofv8Zf/x/8zdx9c8VhvCDFyGhyh8+n86cF5xAlB7fycbDdsNz2qIK7DO0cMA75zjPdO+Cf/8/+MHz14mzMVTtUTthv8dgPiiFOEMfLWT37GxcUFExHN5x0BAe8cfiP8wuc+g4gQI+A6EMMtVdApgio7p5zdu8uP/vVfcf2X3+E0RG5cYBLYug6GCbzj2iu/8Z//OeHOhq7rcJ0nAiEGphAYh4GwG7h68z1++O++zP3JQ4hMoriuQ3DsdjumKRCCEkJgmibGIZTfY7TzfP/+fU5OtogTMlpmGpNxs235u6hKrM5j/fyEsJsmXn/9dbbbLZvNppy5TdfjgN4Ld+/ewSOggRgDXmDsPS/88qd57hc/gbt7Ttdv8M7juo6N75EYCSGye/8hf/df/bfc34ETZecVnLANSiAyvXTOH/zn/wk3CFrR4kLXNSLDxLt/+Q3e/cb3UYHJCz6CD5GI4dfo4N4nP8of/sv/mHfiNRFhqx5X0agQArubG77+f//XnL53TR+UoIGhs/N6t98wjIHr/Vj4RNd1MyxZ0qeaBoYQFvOueQXOs/MTn/v938C/cE5/5w6bzRm+2yC9IFOk73vUOWSY+Jv/+r9j9+b7uADee8ZOkLunqBPCwyvisOPzf/glvvTHf8B7+0uuo/EIr4LHpxMQiVPg+uqa//q/+lc8enDFT99+h5tpZL8bmMaJj33kJU5PO37/93+LP/uzf8bV1SUhBL
z3FZ1XxikSI/z0xz/hX/0//hvCOEGEje+5Nygvb895+ewu027PTiLx3il/8B/9KV/8g99DOm/4p9CrsHt4wY+++R2+/Zdf5mevv852s2UYB0h7pKo4ZMGP81xq3M3fZ7qe97im+cda7mON7+c5qCocOWP2uxQ+mt9ZO5c1P8zzzd/VfDJqXJzT9kznd+v+7TkBdYvPFvi6st6Mt/U8WtjU/dRrqXn+wTO4RR817637XMhZGK81OhVtzhVM1t7N74dKHqjnuTZubu28630JVPvRPFfvVZY9VA2+IKtwq3Eiv5/7zXJjPY81WOf9Kv07QSr608Km5rM1jukK35hx0+NkKX/ZeEqUiETFBQcIsYv80q9+js9+4bOM4rja3/Dhwwc8fO8D3vr+jxk/eMzzJ+dc7fY82l1zHUfGGNg6z3ObU877LZ757C7gZaBcxZn2vPwfvv/XB3vbtqdWNMAE+qgRjRAVFCHiEFFCQ2BUlf2wp/P9AqAt8tSbvHbI6k3Lm51/AoUYt60mBOM4Mkwju/2+COF937PZbNjtdqgq42gMTVW5f/8+AM55gsI0TUUpyExyHMcyblYisjLiO4f3JOFlous69vt9UnBs/K7rinKVBX1VY6YYXU1zcCgTzvkFIfLel/XVykaGS1aKZkSdYVkz7JZY1kpGfcjKnogUBMyfZYFIVZmGcYGEbf9P024jTgslJR2SGCLTfm8w3Q8zEczPhwAScU744O13ePToETc3N5ydebou7a9zRO9ABUEZdSIITBrBGQGWGBEURYkh8t6bb3F9cUnUgPcOjZEpanre2ZMu0jlhd3ON4hlC4GZ/xc31FT98/XW+8IUvcv/eGV4cJ2dbU2I04oAQIkqwcdWhIXB+dorDcEWEQmDHcWScTEHZ9BuiXhiEJAtcRqzqNu8rKAGi4iZH7zsQ0BhMgfIeojJe7/jpd76PvHDG4NK5jAE/TeA8MVgf/ekJ2zDhxpGYcFvEhteoTGPg4cOHPHfvPlNUJNr7iBCiwhQRhZGJ8/NzbvZ7xDs0RJJcYeceCCiTRB48fszp9nkUwSlJIYYYIE6KIkjfE50nIvafGONwsmRmrZDTKut2APLPn7/VTDLvX1wRfAr+KyCaBBnAYUKBc6A27w8++IDuoy9yen5CmAac80gMRB/x4ogxEDoh9I4YBR2nsv5JI+qFi5trHl1cEPoeSXhTn90hTHiFx+OOwZnSTVQkgERFHKZqiPD+B+/z6NFD9KxLSp8SKxiGZPTxJ1smf02nAlFAIyrGX0KiazUNKk1ZKBP1HpVHGvplgFbw8JOf/IQX/cc5ccImCq6b6ILHI6gIURQJkeiFYIiGqBKi4mLEdT04IQr8+Cc/5hd+4ws82l8Re1cUDSeZL0U0Rm5ubsygJba2EIJhUoxkAenrX/86n//8r7DdbtPaQHWa6bEKMSqPH1+QsBrnHeKESSJTB/sOONvwy1/8PL/zZ/+E05efBwQJ4MURhpGf/vDHfP0v/5Yffffv0esBorK7uZn5qzhU40KgbIXrmicfU7Bbvr72TP6+3een4Rs1D2z7q/FgoUhUQnH97FKgPaSV9XyyYnPb2mq5RoUilIYQDhSdVqZp198KeLcpcO2Mnpr/ylLuYkXAPBhrRajPRplW+WxlwDUlYJ7K4dqfpq3tdf3dGizqfW/HWsP3uckCB1rFr4VPMeSKzAJes8Zj+J8VXyem3MSoeN/hEDREMzhd3rB7dEncj/hotHgaRqZxYtzvmeJE0Mgkkeg2uI0DZWGYKTi5Atd2Xk/C/7r9XIrGzc2NCcFO0ABZh6wJRCZUZgVUAvMiskCe38nImCddP1czj7xxXdctiF29wXkT20OYNd4s8APldxHh3r17BdBd1zFNU2GC261pl13X4b0nhMAwDOX3rIBki9k0TUUZyR6FVlA3HJuRN8+l73t2496eT0Jc522+gqDRLOm5v/qw1gTzNqtI3p98GGpLVK2t19p7vae53fZ7JqTtAfx5Wm
t5qltZd/qfJvhMo3mJukxg1IRVj1nluiQ4TePEsNsBhhthv7f1e88YJojO3js9YZDI5ASnikZwMSkZTstBHG5uCF5
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"show_mask(masks[0], plt.gca())\n",
"show_box(input_box, plt.gca())\n",
"show_points(input_point, input_label, plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "45ddbca3",
"metadata": {},
"source": [
"## Batched prompt inputs"
]
},
{
"cell_type": "markdown",
"id": "df6f18a0",
"metadata": {},
"source": [
"`SamPredictor` can take multiple input prompts for the same image using the `predict_torch` method. This method assumes the prompts are already PyTorch tensors and have already been transformed to the input frame. For example, imagine we have several box outputs from an object detector."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0a06681b",
"metadata": {},
"outputs": [],
"source": [
"input_boxes = torch.tensor([\n",
" [75, 275, 1725, 850],\n",
" [425, 600, 700, 875],\n",
" [1375, 550, 1650, 800],\n",
" [1240, 675, 1400, 750],\n",
"], device=predictor.device)"
]
},
{
"cell_type": "markdown",
"id": "bf957d16",
"metadata": {},
"source": [
"Transform the boxes to the input frame, then predict masks. `SamPredictor` stores the necessary transform as the `transform` field for easy access, though it can also be instantiated directly for use in e.g. a dataloader (see `segment_anything.utils.transforms`)."
]
},
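{
"cell_type": "markdown",
"id": "sketch-resize-boxes-md",
"metadata": {},
"source": [
"To make the input frame concrete: the transform rescales the image so that its longest side equals the encoder's input size, and prompts are rescaled by the same per-axis factors. The NumPy sketch below mirrors that arithmetic for boxes. It is an illustrative reimplementation, assuming a 1024-pixel target length (the input size of the released SAM image encoders), not the library code itself; in practice use `predictor.transform.apply_boxes_torch` as in the next cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-resize-boxes-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def preprocess_shape(old_h, old_w, target_length=1024):\n",
"    # Scale so the longest side becomes target_length, rounding to the nearest pixel\n",
"    scale = target_length / max(old_h, old_w)\n",
"    return int(old_h * scale + 0.5), int(old_w * scale + 0.5)\n",
"\n",
"def resize_boxes(boxes, original_size, target_length=1024):\n",
"    # Rescale xyxy boxes from the original image frame into the resized input frame\n",
"    old_h, old_w = original_size\n",
"    new_h, new_w = preprocess_shape(old_h, old_w, target_length)\n",
"    boxes = np.array(boxes, dtype=np.float64).reshape(-1, 2, 2)  # copy as (N, corner, xy)\n",
"    boxes[..., 0] *= new_w / old_w\n",
"    boxes[..., 1] *= new_h / old_h\n",
"    return boxes.reshape(-1, 4)\n",
"\n",
"print(preprocess_shape(1200, 1800))  # truck.jpg is 1200 x 1800 (H x W)\n",
"print(resize_boxes([[425, 600, 700, 875]], (1200, 1800)).round(2))"
]
},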
{
"cell_type": "code",
"execution_count": 27,
"id": "117521a3",
"metadata": {},
"outputs": [],
"source": [
"transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])\n",
"masks, _, _ = predictor.predict_torch(\n",
" point_coords=None,\n",
" point_labels=None,\n",
" boxes=transformed_boxes,\n",
" multimask_output=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "6a8f5d49",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([4, 1, 1200, 1800])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"masks.shape # (batch_size) x (num_predicted_masks_per_input) x H x W"
]
},
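{
"cell_type": "markdown",
"id": "sketch-mask-boxes-md",
"metadata": {},
"source": [
"With `multimask_output=False` each of the four boxes yields one mask, hence the `(4, 1, H, W)` shape. As a hypothetical post-processing step (not something `SamPredictor` provides), the tight xyxy box around each predicted mask can be recovered. The sketch below uses NumPy on a small synthetic batch so it runs without the model; on real outputs one would first call `masks.cpu().numpy()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-mask-boxes-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def masks_to_boxes(masks):\n",
"    # masks: boolean array of shape (B, 1, H, W); returns (B, 4) xyxy boxes, NaN rows for empty masks\n",
"    boxes = np.full((masks.shape[0], 4), np.nan)\n",
"    for i, mask in enumerate(masks[:, 0]):\n",
"        ys, xs = np.nonzero(mask)\n",
"        if xs.size:\n",
"            boxes[i] = [xs.min(), ys.min(), xs.max(), ys.max()]\n",
"    return boxes\n",
"\n",
"# Synthetic batch of two 6x8 masks standing in for real predictions; the second one is empty\n",
"demo = np.zeros((2, 1, 6, 8), dtype=bool)\n",
"demo[0, 0, 1:4, 2:5] = True  # rows 1-3, columns 2-4\n",
"print(masks_to_boxes(demo))"
]
},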
{
"cell_type": "code",
"execution_count": 29,
"id": "c00c3681",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAIYCAYAAADq/5rtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9Wa8lS3agiX3LzH3vM0TEjTtn5s2RZCaZySRZTM5ksYpFVrO7utEPhW61BAgQ9Av0LkAQBEiAfoMeBDQg6UmAHlutQqO7WqqBYzLngTnnzTsPMZxpb3c3W3pYZubmtn2fiGxRECCEXcQ95+ztbsOyZWtey0RVlWftWXvWnrVn7Vl71p61Z+1Ze9aetX/A5v5/PYFn7Vl71p61Z+1Ze9aetWftWXvW/v+vPVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/D2TNF41p61Z+1Ze9aetWftWXvWnrVn7R+8PVM0nrVn7Vl71p61Z+1Ze9aetWftWfsHb88UjWftWXvWnrVn7Vl71p61Z+1Ze9b+wdszReNZe9aetWftWXvWnrVn7Vl71p61f/DWPe2Df/zrv8U//k//Q57/5GtohB4PCMGBE8E701lijIQQUFU636EKqkqMkRgjAM45RIRxHAHYbreEEAB7Nn+f/xaR0kf+3TmH9x6AcRzL33n8GGPpR0QYw8QUAiEE+r4H4ObmBhFhs9kwDAMAfd/jvWeaJhs/Wp/DMDCOIyLCNE3cuXOHzWbDOI5M08SjR484Pz+n6zpCnBCJeO8ZhoHr62tEBO893jmmYSww8t6jqmw2G/bTnhe2p/zp7/0+2+dOie++z+N/+1VOxoD2sOn7AodpmnDOFbi2MOq6jhhjgV+GZb6fse97XNqzPBfn3OJnvRfOOXtehKix9O2cK/COMeKQAtvT09MFDuXx6nnUc8t7Vq/DpX91y2vqVbjZCBsVXnnto3z1v/pv2Sic3D2j7zt7BhCBPRNDp9z/+Kv82h/+Ju8/+ADf9Ui3xZ+eEjc9QQSiw0/wzje/z/tf+3vujkJA2XvFCYgqKsq0Ee79wmt88Y9+mwdXjxBRxHVEBKRDup6I49pv+MQ/+m1OT+4SgjKpsh93fPDoQ7ZOeO7Oc8j5KRIDlz/6Ke///bc4lQkEJhdRHxGNaPBcxA2v/PIX+fgvfIYoavPF4YPSB8ePvvVdpofv8Zf/x/8zdx9c8VhvCDFyGhyh8+n86cF5xAlB7fycbDdsNz2qIK7DO0cMA75zjPdO+Cf/8/+MHz14mzMVTtUTthv8dgPiiFOEMfLWT37GxcUFExHN5x0BAe8cfiP8wuc+g4gQI+A6EMMtVdApgio7p5zdu8uP/vVfcf2X3+E0RG5cYBLYug6GCbzj2iu/8Z//OeHOhq7rcJ0nAiEGphAYh4GwG7h68z1++O++zP3JQ4hMoriuQ3DsdjumKRCCEkJgmibGIZTfY7TzfP/+fU5OtogTMlpmGpNxs235u6hKrM5j/fyEsJsmXn/9dbbbLZvNppy5TdfjgN4Ld+/ewSOggRgDXmDsPS/88qd57hc/gbt7Ttdv8M7juo6N75EYCSGye/8hf/df/bfc34ETZecVnLANSiAyvXTOH/zn/wk3CFrR4kLXNSLDxLt/+Q3e/cb3UYHJCz6CD5GI4dfo4N4nP8of/sv/mHfiNRFhqx5X0agQArubG77+f//XnL53TR+UoIGhs/N6t98wjIHr/Vj4RNd1MyxZ0qeaBoYQFvOueQXOs/MTn/v938C/cE5/5w6bzRm+2yC9IFOk73vUOWSY+Jv/+r9j9+b7uADee8ZOkLunqBPCwyvisOPzf/glvvTHf8B7+0uuo/EIr4LHpxMQiVPg+uqa//q/+lc8enDFT99+h5tpZL8bmMaJj33kJU5PO37/93+LP/uzf8bV1SUhBL
z3FZ1XxikSI/z0xz/hX/0//hvCOEGEje+5Nygvb895+ewu027PTiLx3il/8B/9KV/8g99DOm/4p9CrsHt4wY+++R2+/Zdf5mevv852s2UYB0h7pKo4ZMGP81xq3M3fZ7qe97im+cda7mON7+c5qCocOWP2uxQ+mt9ZO5c1P8zzzd/VfDJqXJzT9kznd+v+7TkBdYvPFvi6st6Mt/U8WtjU/dRrqXn+wTO4RR817637XMhZGK81OhVtzhVM1t7N74dKHqjnuTZubu28630JVPvRPFfvVZY9VA2+IKtwq3Eiv5/7zXJjPY81WOf9Kv07QSr608Km5rM1jukK35hx0+NkKX/ZeEqUiETFBQcIsYv80q9+js9+4bOM4rja3/Dhwwc8fO8D3vr+jxk/eMzzJ+dc7fY82l1zHUfGGNg6z3ObU877LZ757C7gZaBcxZn2vPwfvv/XB3vbtqdWNMAE+qgRjRAVFCHiEFFCQ2BUlf2wp/P9AqAt8tSbvHbI6k3Lm51/AoUYt60mBOM4Mkwju/2+COF937PZbNjtdqgq42gMTVW5f/8+AM55gsI0TUUpyExyHMcyblYisjLiO4f3JOFlous69vt9UnBs/K7rinKVBX1VY6YYXU1zcCgTzvkFIfLel/XVykaGS1aKZkSdYVkz7JZY1kpGfcjKnogUBMyfZYFIVZmGcYGEbf9P024jTgslJR2SGCLTfm8w3Q8zEczPhwAScU744O13ePToETc3N5ydebou7a9zRO9ABUEZdSIITBrBGQGWGBEURYkh8t6bb3F9cUnUgPcOjZEpanre2ZMu0jlhd3ON4hlC4GZ/xc31FT98/XW+8IUvcv/eGV4cJ2dbU2I04oAQIkqwcdWhIXB+dorDcEWEQmDHcWScTEHZ9BuiXhiEJAtcRqzqNu8rKAGi4iZH7zsQ0BhMgfIeojJe7/jpd76PvHDG4NK5jAE/TeA8MVgf/ekJ2zDhxpGYcFvEhteoTGPg4cOHPHfvPlNUJNr7iBCiwhQRhZGJ8/NzbvZ7xDs0RJJcYeceCCiTRB48fszp9nkUwSlJIYYYIE6KIkjfE50nIvafGONwsmRmrZDTKut2APLPn7/VTDLvX1wRfAr+KyCaBBnAYUKBc6A27w8++IDuoy9yen5CmAac80gMRB/x4ogxEDoh9I4YBR2nsv5JI+qFi5trHl1cEPoeSXhTn90hTHiFx+OOwZnSTVQkgERFHKZqiPD+B+/z6NFD9KxLSp8SKxiGZPTxJ1smf02nAlFAIyrGX0KiazUNKk1ZKBP1HpVHGvplgFbw8JOf/IQX/cc5ccImCq6b6ILHI6gIURQJkeiFYIiGqBKi4mLEdT04IQr8+Cc/5hd+4ws82l8Re1cUDSeZL0U0Rm5ubsygJba2EIJhUoxkAenrX/86n//8r7DdbtPaQHWa6bEKMSqPH1+QsBrnHeKESSJTB/sOONvwy1/8PL/zZ/+E05efBwQJ4MURhpGf/vDHfP0v/5Yffffv0esBorK7uZn5qzhU40KgbIXrmicfU7Bbvr72TP6+3een4Rs1D2z7q/FgoUhUQnH97FKgPaSV9XyyYnPb2mq5RoUilIYQDhSdVqZp198KeLcpcO2Mnpr/ylLuYkXAPBhrRajPRplW+WxlwDUlYJ7K4dqfpq3tdf3dGizqfW/HWsP3uckCB1rFr4VPMeSKzAJes8Zj+J8VXyem3MSoeN/hEDREMzhd3rB7dEncj/hotHgaRqZxYtzvmeJE0Mgkkeg2uI0DZWGYKTi5Atd2Xk/C/7r9XIrGzc2NCcFO0ABZh6wJRCZUZgVUAvMiskCe38nImCddP1czj7xxXdctiF29wXkT20OYNd4s8APldxHh3r17BdBd1zFNU2GC261pl13X4b0nhMAwDOX3rIBki9k0TUUZyR6FVlA3HJuRN8+l73t2496eT0Jc522+gqDRLOm5v/qw1gTzNqtI3p98GGpLVK2t19p7vae53fZ7JqTtAfx5Wm
t5qltZd/qfJvhMo3mJukxg1IRVj1nluiQ4TePEsNsBhhthv7f1e88YJojO3js9YZDI5ASnikZwMSkZTstBHG5uCF5
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"for mask in masks:\n",
" show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)\n",
"for box in input_boxes:\n",
" show_box(box.cpu().numpy(), plt.gca())\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "8bea70c0",
"metadata": {},
"source": [
"## End-to-end batched inference"
]
},
{
"cell_type": "markdown",
"id": "89c3ba52",
"metadata": {},
"source": [
"If all prompts are available in advance, it is possible to run SAM directly in an end-to-end fashion. This also allows batching over images."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "45c01ae4",
"metadata": {},
"outputs": [],
"source": [
"image1 = image # truck.jpg from above\n",
"image1_boxes = torch.tensor([\n",
" [75, 275, 1725, 850],\n",
" [425, 600, 700, 875],\n",
" [1375, 550, 1650, 800],\n",
" [1240, 675, 1400, 750],\n",
"], device=sam.device)\n",
"\n",
"image2 = cv2.imread('images/groceries.jpg')\n",
"image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
"image2_boxes = torch.tensor([\n",
" [450, 170, 520, 350],\n",
" [350, 190, 450, 350],\n",
" [500, 170, 580, 350],\n",
" [580, 170, 640, 350],\n",
"], device=sam.device)"
]
},
{
"cell_type": "markdown",
"id": "ce56c57d",
"metadata": {},
"source": [
"Both images and prompts are input as PyTorch tensors that are already transformed to the correct frame. Inputs are packaged as a list over images, where each element is a dict with the following keys:\n",
"* `image`: The input image as a PyTorch tensor in CHW format.\n",
"* `original_size`: The size of the image before transforming for input to SAM, in (H, W) format.\n",
"* `point_coords`: Batched coordinates of point prompts.\n",
"* `point_labels`: Batched labels of point prompts.\n",
"* `boxes`: Batched input boxes.\n",
"* `mask_inputs`: Batched input masks.\n",
"\n",
"If a prompt is not present, the key can be excluded."
]
},
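{
"cell_type": "markdown",
"id": "sketch-check-entry-md",
"metadata": {},
"source": [
"Before calling the model it can help to validate each entry against the list above. The function below is a hypothetical checker, not part of `segment_anything`: it treats `image` and `original_size` as mandatory (they are not prompts), accepts only the keys listed above, and assumes an RGB image in CHW layout. It uses NumPy arrays in place of tensors so that it runs standalone; the same checks work on `torch.Tensor` inputs, which also expose `.ndim` and `.shape`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-check-entry-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"ALLOWED_KEYS = {'image', 'original_size', 'point_coords', 'point_labels', 'boxes', 'mask_inputs'}\n",
"REQUIRED_KEYS = {'image', 'original_size'}\n",
"\n",
"def check_entry(entry):\n",
"    # Hypothetical validator for one element of batched_input\n",
"    missing = REQUIRED_KEYS - set(entry)\n",
"    if missing:\n",
"        raise KeyError(f'missing required keys: {sorted(missing)}')\n",
"    unknown = set(entry) - ALLOWED_KEYS\n",
"    if unknown:\n",
"        raise KeyError(f'unexpected keys: {sorted(unknown)}')\n",
"    # Assumes an RGB image in CHW layout\n",
"    image = entry['image']\n",
"    if image.ndim != 3 or image.shape[0] != 3:\n",
"        raise ValueError(f'expected a 3xHxW image, got shape {tuple(image.shape)}')\n",
"    # Point coordinates need their labels, and vice versa\n",
"    if ('point_coords' in entry) != ('point_labels' in entry):\n",
"        raise KeyError('point_coords and point_labels must be given together')\n",
"    return True\n",
"\n",
"# NumPy stand-in for a prepared truck.jpg entry (1200x1800 resized to 683x1024)\n",
"demo_entry = {'image': np.zeros((3, 683, 1024)), 'original_size': (1200, 1800), 'boxes': np.zeros((4, 4))}\n",
"print(check_entry(demo_entry))"
]
},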
{
"cell_type": "code",
"execution_count": 31,
"id": "79f908ca",
"metadata": {},
"outputs": [],
"source": [
"from segment_anything.utils.transforms import ResizeLongestSide\n",
"resize_transform = ResizeLongestSide(sam.image_encoder.img_size)\n",
"\n",
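"def prepare_image(image, transform, model):\n",
"    # Resize so the longest side matches the encoder input size, then move to the model's device as CHW\n",
"    image = transform.apply_image(image)\n",
"    image = torch.as_tensor(image, device=model.device)\n",
"    return image.permute(2, 0, 1).contiguous()"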
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "23f63723",
"metadata": {},
"outputs": [],
"source": [
"batched_input = [\n",
" {\n",
" 'image': prepare_image(image1, resize_transform, sam),\n",
" 'boxes': resize_transform.apply_boxes_torch(image1_boxes, image1.shape[:2]),\n",
" 'original_size': image1.shape[:2]\n",
" },\n",
" {\n",
" 'image': prepare_image(image2, resize_transform, sam),\n",
" 'boxes': resize_transform.apply_boxes_torch(image2_boxes, image2.shape[:2]),\n",
" 'original_size': image2.shape[:2]\n",
" }\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "6fbeb831",
"metadata": {},
"source": [
"Run the model."
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "f3b311b1",
"metadata": {},
"outputs": [],
"source": [
"batched_output = sam(batched_input, multimask_output=False)"
]
},
{
"cell_type": "markdown",
"id": "27bb50fd",
"metadata": {},
"source": [
"The output is a list over results for each input image, where list elements are dictionaries with the following keys:\n",
"* `masks`: A batched torch tensor of predicted binary masks, with the same spatial size as the original image.\n",
"* `iou_predictions`: The model's prediction of the quality for each mask.\n",
"* `low_res_logits`: Low-resolution logits for each mask, which can be passed back to the model as mask input on a later iteration."
]
},
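{
"cell_type": "markdown",
"id": "sketch-best-candidate-md",
"metadata": {},
"source": [
"The call above used `multimask_output=False`, so each prompt has a single mask. If the model is instead run with `multimask_output=True`, each prompt returns several candidate masks, and `iou_predictions` can be used to keep the best one per prompt. A sketch of that selection with NumPy stand-ins, assuming `masks` of shape `(B, C, H, W)` and `iou_predictions` of shape `(B, C)`; the `demo_` arrays are placeholders, not model outputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-best-candidate-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"demo_masks = rng.random((4, 3, 6, 8)) > 0.5   # stand-in for (B, C, H, W) boolean masks\n",
"demo_iou = np.array([[0.2, 0.9, 0.5],\n",
"                     [0.8, 0.1, 0.3],\n",
"                     [0.4, 0.4, 0.7],\n",
"                     [0.6, 0.95, 0.2]])       # stand-in for (B, C) iou_predictions\n",
"\n",
"best_idx = demo_iou.argmax(axis=1)            # best candidate index per prompt\n",
"best_masks = demo_masks[np.arange(demo_masks.shape[0]), best_idx]  # shape (B, H, W)\n",
"print(best_idx.tolist(), best_masks.shape)"
]
},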
{
"cell_type": "code",
"execution_count": 34,
"id": "eb3dba0f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['masks', 'iou_predictions', 'low_res_logits'])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batched_output[0].keys()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "e1108f48",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8YAAAKgCAYAAADpkhewAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9aZN0SXbfif2Ou997IyIzn7X2tbuBRi9AswECJEESI4JDcTZSJjPJxuaFjGb6AvoqeimTvsKMFo6NTJQww21IAUOgAbCxsZfqru6uqq6qp541l4i497r70YvjfuNGZD5V1QA5MoPFscrKfCLu4uvxc/5nE1VVjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pCMd6UhHOtKRjnSkIx3pSEf6S0ru/98NONKRjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pCMd6UhHOtKRjnSkIx3pPyQdDeNHOtKRjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pCMd6UhHOtKRjnSkv9R0NIwf6UhHOtKRjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pCMd6UhHOtKR/lLT0TB+pCMd6UhHOtKRjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pCMd6UhH+ktNR8P4kY50pCMd6UhHOtKRjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pCMd6S81HQ3jRzrSkY50pCMd6UhHOtKRjnSkIx3pSEc60pGOdKQjHelIRzrSkY50pL/UdDSMH+lIRzrSkY50pCMd6UhHOtKRjnSkIx3pSEc60pGOdKQjHelIRzrSkf5S09EwfqQjHelIRzrSkY50pCMd6UhHOtKRjnSkIx3pSEc60pGOdKQjHelIR/pLTUfD+JGOdKQjHelIRzrSkY50pCMd6UhHOtKRjnSkIx3pSEc60pGOdKQjHekvNYXPe+Hf+Su/ymtf+Tn++t//TQaApDR4RIQEqBOcKt57nHPEGHHOMY4jOWeapkGzAoKqklKye8tv5xwi9t04jogIXdehqsQY8d6Tc8Y5s+WLCACqOv07pQSAc256rvce7/30znEcWSwWAOScyTlP185/APo4oqrknBERQgjEGNlut1P7APq+xzlHCAHvPQDjOFq7CKjau4ZhYBgGQghTH09OTmjblmEYEBE2mw2Xl5ecnp4SQkBEGMYt3gtt2zKOI33fs91u8d5PY+cQxmGY+pJzpm1bYox0XUfKifXmil968wt886tfw582nLUtH/7OH8CPPqJTyAGaJkxjnXMmxgiA954YI6qKc24aF1WdxrnOQb2mtkNVp3kKIRDCbtnV+Q4hTNfVMcw5T/M5XyP1d8oJnV1X31d/NCsCNE3DZrMhxshisZjmt1JdU7U98zVVn1nbM/9sfp2frZubKKigTtg2sFLPaz//BX7/v/8X+AfP0JOO1aKzcXGCB4KAAIMkejKDz3z1V7/Jm195i4ePH5KzEpoF6jx+uYS2JQIqgiYhZNDzDb//T/4Zd6KjGyGJMjjITnGAqAJKboS1RH7x7/4GL7/1Mk/WF+Rkew7nUQUVj/iAuEBSWIvn7pe/xksvv4qqJyfIKEkT51cXXG4uuL1Y0vmW5tYZIxnXD5z/+D0e/+D7LGWEHMF7omSyTzgyTh0xClfa4u6+yDf+xq/jRMkCWYSM4HHImGnV8/4P3uXJgw+5+uE7fO8f/xa3E1xoT6+RLgkNjugdWTOU+apzvDeHzhFVcM4TgmO5aHHeoxmc84jzOFHiONCEQN8ob/+Nb/L2r3+TH33yPks8p1FIzqGLltA0IA7NiqYMMfPww0948NHH0HiS5mlf1XUzrW2nvPrWy7zwwgtcXl7iQws4VAQRhwqggsYMKZMc9E5pFx3xwyf84X/z/+KV3qNkepcZsf3YiSf3Ay4E1kTu/PybfPk/+Ztc9VvjNd4Z79RMKnwxpkQaRnI/Mj4550/+xe9we3Qs1BFTIorigid4403DME
48NWdIKdlzxvpZnngfwMnJCaenp/jgcQduWnKwp+pcHV4z/z6X3zftRVUlizAqPHn6lMePH7NYLGjbduIB3nsaH/AioJmubTk9WSEo5FTWTsaLktsWvb3k537tl0knLW65wPlAKGegCy2N8ziFMY5IVs4/eMAf/Q//irvRE2ImBZs7nNBlxWdYS8S/eo/f+N/+F6xVGWLCOUfOeY9P5pzJKKRM83jNt/7b36IdMlEgeuMfTQJyImO8MKKMHt74xV/gb/znf48PNs8YAvgkNOqmcat7o55bP/7T7/Lwd/6Y277Dx0zWxOAhO/OuO2s70phYj7Zm6nlW5YF6tmd0mp85L61/1zUyPz9yztM5XmYdnGPrI299/ed59Stf4Nm4ZXl2RggdzjX4NpA147LShABFBmDd89v/z99Cn16R+0QbGlQgBsGdrRjJBIT45AJNkbPX7/Nf/u//EZep59H2kl7tzPEIXgUnvvTJ+EuKkWEY+a1/8lusL3s+efiEDx4+RIJnu96Ss6I58fZbr+Fc5s03X+Uf/aP/HTlHNpvNNDZ1nnfzIIzRzuMP3vuA3/qt/x5NmdjHac2uton7zYJ7qzNWEojDyCjKFRF354T/xX/xn/KLf/3XUC9QxtYjBLX5P//kIR/+4Ef8yb/5fX7yk5+Y3FDmw87TjKqdG4cy23y+Dvfq/LpD/ju/5vPS/IzerQmmdXutDTM+chMPsTF3e3LP/Fq54f657HPIg+b9rN+rKlkzefZZXffz+6psVen68wF1gOw9e/63Air7bTy8Zt7em8blpjbMr53LXIey4D7vlcJ5Pn08n99fRRC8c6ScUcq7BBBB7RV7Z/nhs6Z3zeZtvibn94UQro3b4bjc1I95m2POiJO9sZm3aerX7O8qN09jXMdu9q7Dc+1539X3Vb3n8L7Ds7H2dy537z1bTD467M9h/+frYL7HRQRxZa4O5mVv7eAQ2dcbah9Kj8mSISsuC5JNSFYP3dmSv/Wbf4tm0ZK8p48jF5eXPH32jPXTZ3z87ntsPnnCWdPRhY6LzZqroWejkSGZXrPwDadNxzK0tOL35qeumWnsp8HcXxvzMb5prfyfv/+7/Hnp23/0fyQnkycdmRCUGEe+8+4nvPvhwIcPMw+fJC6fDYiCiIOs5JTt3FGlH1NpbKbOyJgTOZme1C0WrE5Op3M7ZWWMkb7fstls2GyvGLcbhqFn7LcM256h34IkMg5wdi66E/7+f/Yb/MO/81Vevd2xaCGcOZJbcNLepwt3WbS36dpTmqYh6kCMAzFHNuOGR1eP6NfnnDTKSdfQOJPF+2FgPW6IOjKkzBAjMQ5s11cgmeVyyXJ5wmpxyunpPW7duo1zwc6uoi8qicv+gqeXj3m2fsZFf0Efey43Gy4vr4gxk7OgKihKEzyC3R/jyJCEzXZgvRkQ51l0C85OVpy2DS/dPuHFe3dYNh4dB9bbnmHMxAzP1msen5/z8Nk5fYz4NvDyyYK7pytOTpasFh3eB9R7tmPkctNzvt4wqnC2OMM7k0tdFkQ93nlCExBn+s122/Pe44+4PH9GcI5lF1h2LV6sz1ltT0UVNmp87t7du7xw+x73T29z9+QWnW+JQySmjLiRrnOchDP+6f/lv+Gj3/oW90bHqQ9A5mqzoR8jH4fET8g8zJmTDH+vvUV68yX+ycNP+PZPH/Dk4TO6uwHCipPkyALSBk5v3+J8eMavvPeU/8PyTX55ecojueT/9tGH/Gs551EHzjWoCg+HDSnDrRTYirKRTAROBF7xHR7P45x5oJHHkkAzbcqTzlj34OE+reQkkzJo0QFPTm8T2hZyhnKCqZpcO8neEy4Qr8k2zjekGGmahqZp8AXLmZ8p4j3OBzxq+I4PNp/iUSBlk4ty8KhY21AH4vG+JYSW4Fu8D3gfcAEEwTZ/xdoSOZsOllMi5UQeNsQ4Mo6ROI7T56arRdNXcwYKLsZOftFsv2UaN2fvLH2PKU16NkBWJeWESMZJATV0/x
x2IlRGqnL9fJrmTkHryeiun7vXzn/b7Hvf1/7I7ExLeXde3SzDlGd7P/XV2lJ19mDyBNY/EeOd3rvyDmfnn9h7EYc
"text/plain": [
"<Figure size 2000x2000 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots(1, 2, figsize=(20, 20))\n",
"\n",
"# Left panel: first image with its predicted masks and input box prompts\n",
"ax[0].imshow(image1)\n",
"for mask in batched_output[0]['masks']:\n",
" show_mask(mask.cpu().numpy(), ax[0], random_color=True)\n",
"for box in image1_boxes:\n",
" show_box(box.cpu().numpy(), ax[0])\n",
"ax[0].axis('off')\n",
"\n",
"# Right panel: second image with its predicted masks and input box prompts\n",
"ax[1].imshow(image2)\n",
"for mask in batched_output[1]['masks']:\n",
" show_mask(mask.cpu().numpy(), ax[1], random_color=True)\n",
"for box in image2_boxes:\n",
" show_box(box.cpu().numpy(), ax[1])\n",
"ax[1].axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}