mmsegmentation/projects/sam_inference_demo/sam_image_demo.ipynb

123 lines
3.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"\n",
"import sam # noqa: F401\n",
"from sam.sam_inferencer import SAMInferencer\n",
"\n",
"\n",
"def show_mask(mask, ax, random_color=False):\n",
"    \"\"\"Draw ``mask`` on ``ax`` as a tinted, semi-transparent RGBA overlay.\n",
"\n",
"    Args:\n",
"        mask: Array (needs ``shape`` and ``reshape``) whose last two dimensions\n",
"            are taken as (h, w); it must hold exactly h * w elements so it can\n",
"            be reshaped to (h, w, 1).\n",
"        ax: Matplotlib axes that receives the overlay via ``imshow``.\n",
"        random_color: If True, use a random RGB tint drawn from the global\n",
"            NumPy RNG; otherwise use a fixed blue. Alpha is 0.6 either way.\n",
"    \"\"\"\n",
"    if not random_color:\n",
"        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])\n",
"    else:\n",
"        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
"    height, width = mask.shape[-2:]\n",
"    # (h, w, 1) * (1, 1, 4) broadcasts the tint per pixel; pixels where the\n",
"    # mask is 0 get alpha 0 and stay fully transparent.\n",
"    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)\n",
"    ax.imshow(overlay)\n",
"\n",
"\n",
"def show_points(coords, labels, ax, marker_size=375):\n",
"    \"\"\"Scatter prompt points on ``ax`` as white-edged star markers.\n",
"\n",
"    Points labelled 1 are drawn green and points labelled 0 red; rows with any\n",
"    other label value are not drawn.\n",
"\n",
"    Args:\n",
"        coords: 2-D NumPy array of points; column 0 is plotted as x and\n",
"            column 1 as y.\n",
"        labels: NumPy array with one entry per row of ``coords``, compared\n",
"            elementwise with 1 and 0 to select rows.\n",
"        ax: Matplotlib axes to draw on.\n",
"        marker_size: Marker area forwarded to ``scatter`` as ``s``.\n",
"    \"\"\"\n",
"    marker_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
"    for label_value, point_color in ((1, 'green'), (0, 'red')):\n",
"        selected = coords[labels == label_value]\n",
"        ax.scatter(selected[:, 0], selected[:, 1], color=point_color, **marker_style)\n",
"\n",
"\n",
"def show_box(box, ax):\n",
"    \"\"\"Outline a rectangle on ``ax`` with a green, unfilled 2-pt border.\n",
"\n",
"    Args:\n",
"        box: Indexable (x0, y0, x1, y1). (x0, y0) is the rectangle's anchor\n",
"            corner; width is x1 - x0 and height is y1 - y0.\n",
"        ax: Matplotlib axes that receives the patch.\n",
"    \"\"\"\n",
"    corner = (box[0], box[1])\n",
"    width = box[2] - box[0]\n",
"    height = box[3] - box[1]\n",
"    # Fully transparent facecolor so only the edge is drawn over the image.\n",
"    frame = plt.Rectangle(corner, width, height,\n",
"                          edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)\n",
"    ax.add_patch(frame)\n",
"\n",
"\n",
"# Demo image path, resolved against the working directory (normally this notebook's folder).\n",
"IMAGE_PATH = '../../demo/demo.png'\n",
"\n",
"# cv2.imread does not raise on a missing or unreadable file; it returns None,\n",
"# which would otherwise fail later inside cvtColor with an opaque cv2.error.\n",
"image_bgr = cv2.imread(IMAGE_PATH)\n",
"if image_bgr is None:\n",
"    raise FileNotFoundError(f'Could not read image: {IMAGE_PATH}')\n",
"# OpenCV decodes to BGR channel order; convert to RGB for matplotlib.\n",
"# This RGB array is also what is later passed to inferencer.set_image.\n",
"image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 10))\n",
"ax.imshow(image)\n",
"ax.axis('on')\n",
"plt.show()\n",
"print(image.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Architecture name passed to SAMInferencer as `arch`.\n",
"SAM_ARCH = 'huge'\n",
"inferencer = SAMInferencer(arch=SAM_ARCH)\n",
"# predict() below takes no image argument, so the image is bound here first.\n",
"inferencer.set_image(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Two prompt points as (x, y) image coordinates, both labelled 1 (drawn green by show_points).\n",
"input_point = np.array([[280, 230], [500, 300]])\n",
"input_label = np.array([1, 1])\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 10))\n",
"ax.imshow(image)\n",
"show_points(input_point, input_label, ax)\n",
"ax.axis('on')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"masks, scores, logits = inferencer.predict(\n",
"    point_coords=input_point,\n",
"    point_labels=input_label,\n",
"    multimask_output=True,\n",
")\n",
"\n",
"# One figure per returned mask, titled with its 1-based index and score.\n",
"for mask_number, (mask, score) in enumerate(zip(masks, scores), start=1):\n",
"    fig, ax = plt.subplots(figsize=(10, 10))\n",
"    ax.imshow(image)\n",
"    show_mask(mask, ax, random_color=True)\n",
"    show_points(input_point, input_label, ax)\n",
"    ax.set_title(f'Mask {mask_number}, Score: {score:.3f}', fontsize=18)\n",
"    ax.axis('off')\n",
"    plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pt1.13",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}