123 lines
3.4 KiB
Plaintext
123 lines
3.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import cv2\n",
|
|
"\n",
|
|
"import sam # noqa: F401\n",
|
|
"from sam.sam_inferencer import SAMInferencer\n",
|
|
"\n",
|
|
"\n",
|
|
"def show_mask(mask, ax, random_color=False):\n",
"    \"\"\"Overlay a single mask on ``ax`` as a semi-transparent RGBA image.\n",
"\n",
"    Args:\n",
"        mask: Array whose last two dimensions are (H, W). It is reshaped to\n",
"            (H, W, 1), so any leading dimensions must multiply to 1.\n",
"        ax: Matplotlib axes to draw on (only ``ax.imshow`` is called).\n",
"        random_color: If True, use a random RGB color drawn from NumPy's\n",
"            global RNG; otherwise a fixed blue. Alpha is 0.6 in both cases.\n",
"    \"\"\"\n",
"    if random_color:\n",
"        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
"    else:\n",
"        color = np.array([30/255, 144/255, 255/255, 0.6])\n",
"    h, w = mask.shape[-2:]\n",
"    # Broadcast the (H, W, 1) mask against the (1, 1, 4) color to get an\n",
"    # (H, W, 4) RGBA image; pixels where the mask is 0/False get alpha 0.\n",
"    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
"    ax.imshow(mask_image)\n",
"\n",
"\n",
|
|
"def show_points(coords, labels, ax, marker_size=375):\n",
"    \"\"\"Draw prompt points on ``ax`` as star markers.\n",
"\n",
"    Points labelled 1 are drawn green (positive) and points labelled 0 red\n",
"    (negative); points with any other label value are not drawn.\n",
"\n",
"    Args:\n",
"        coords: (N, 2) array-like of (x, y) pixel coordinates.\n",
"        labels: Length-N array-like of point labels.\n",
"        ax: Matplotlib axes to draw on (only ``ax.scatter`` is called).\n",
"        marker_size: Marker area, passed to ``ax.scatter`` as ``s``.\n",
"    \"\"\"\n",
"    # Boolean-mask indexing below requires ndarrays; coerce so plain lists work too.\n",
"    coords = np.asarray(coords)\n",
"    labels = np.asarray(labels)\n",
"    pos_points = coords[labels == 1]\n",
"    neg_points = coords[labels == 0]\n",
"    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',\n",
"               s=marker_size, edgecolor='white', linewidth=1.25)\n",
"    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',\n",
"               s=marker_size, edgecolor='white', linewidth=1.25)\n",
"\n",
"\n",
|
|
"def show_box(box, ax):\n",
"    \"\"\"Outline an (x0, y0, x1, y1) box on ``ax`` as an unfilled green rectangle.\"\"\"\n",
"    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]\n",
"    outline = plt.Rectangle(\n",
"        (x0, y0), x1 - x0, y1 - y0,\n",
"        edgecolor='green', facecolor=(0, 0, 0, 0), lw=2,\n",
"    )\n",
"    ax.add_patch(outline)\n",
"\n",
"\n",
|
|
"IMAGE_PATH = '../../demo/demo.png'\n",
"\n",
"# cv2.imread does not raise on a missing/unreadable file -- it returns None,\n",
"# which would otherwise surface later as an opaque error inside cvtColor.\n",
"image_bgr = cv2.imread(IMAGE_PATH)\n",
"if image_bgr is None:\n",
"    raise FileNotFoundError(f'could not read image: {IMAGE_PATH}')\n",
"# OpenCV loads channels in BGR order; matplotlib and the later cells use RGB.\n",
"image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)\n",
"\n",
"plt.figure(figsize=(10, 10))\n",
"plt.imshow(image)\n",
"plt.axis('on')\n",
"plt.show()\n",
"image.shape"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Build the SAM inferencer for the 'huge' architecture and give it the RGB\n",
"# image loaded above; the point prompts in the following cells refer to it.\n",
"inferencer = SAMInferencer(arch=\"huge\")\n",
"inferencer.set_image(image)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Two (x, y) click prompts, both labelled 1 (drawn as green stars by show_points).\n",
"input_point = np.array([[280, 230], [500, 300]])\n",
"input_label = np.array([1, 1])\n",
"\n",
"fig, ax = plt.subplots(figsize=(10, 10))\n",
"ax.imshow(image)\n",
"show_points(input_point, input_label, ax)\n",
"ax.axis('on')\n",
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# multimask_output=True presumably yields several candidate masks (one score\n",
"# each); draw every candidate over the image together with the prompt points.\n",
"masks, scores, logits = inferencer.predict(\n",
"    point_coords=input_point,\n",
"    point_labels=input_label,\n",
"    multimask_output=True,\n",
")\n",
"for i, (mask, score) in enumerate(zip(masks, scores)):\n",
"    fig, ax = plt.subplots(figsize=(10, 10))\n",
"    ax.imshow(image)\n",
"    show_mask(mask, ax, random_color=True)\n",
"    show_points(input_point, input_label, ax)\n",
"    ax.set_title(f\"Mask {i+1}, Score: {score:.3f}\", fontsize=18)\n",
"    ax.axis('off')\n",
"    plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "pt1.13",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.10.9"
|
|
},
|
|
"orig_nbformat": 4
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|