# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Feng Li (fliay@connect.ust.hk)
# --------------------------------------------------------
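
# Open-set in-context segmentation task for the DINOv demo: visual-prompt
# content features are extracted from up to eight in-context examples (each
# an image/mask pair), averaged into a single prompt, and decoded against the
# target image's encoder features to produce score-ranked open-set masks.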

import os
from typing import Tuple

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from detectron2.data import MetadataCatalog
from utils.visualizer import Visualizer

# COCO panoptic metadata provides the category names/colors for visualization.
metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable logit: map x in [0, 1] to log(x / (1 - x))."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
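
# Example: inverse_sigmoid(torch.tensor([0.25, 0.5])) ~= tensor([-1.0986, 0.0000]).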


def task_openset(model, generic_vp1, generic_vp2, generic_vp3, generic_vp4,
                 generic_vp5, generic_vp6, generic_vp7, generic_vp8,
                 image_tgt=None, text_size=640, hole_scale=100, island_scale=100):
    """Segment image_tgt in an open-set fashion, prompted by up to eight
    in-context examples (dicts holding a PIL 'image' and its 'mask')."""
    in_context_examples = [generic_vp1, generic_vp2, generic_vp3, generic_vp4,
                           generic_vp5, generic_vp6, generic_vp7, generic_vp8]
    # Unused prompt slots arrive as None; keep only the provided examples.
    in_context_examples = [x for x in in_context_examples if x is not None]
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))

    def prepare_image(image_ori):
        # PIL size is (width, height); return a CHW uint8 CUDA tensor.
        width = image_ori.size[0]
        height = image_ori.size[1]
        image_ori = np.asarray(image_ori)
        images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
        return images, height, width

    transform1 = transforms.Compose(t)

    # Encode the target image once; its features are reused for every prompt.
    image_ori_tgt = transform1(image_tgt)
    images_tgt, height_tgt, width_tgt = prepare_image(image_ori_tgt)
    data_tgt = {"image": images_tgt, "height": height_tgt, "width": width_tgt}
    batched_inputs = []
    batched_inputs_tgt = [data_tgt]
    multi_scale_features2, mask_features2, _, _ = model.model.get_encoder_feature(batched_inputs_tgt)
    input_query_label_content_all = []
    # Initialize the content-query box to (0, 0, 1, 1), mapped to logit space.
    point_coords = torch.ones(1, 4).cuda().float()
    point_coords[:, :2] = 0.
    input_query_bbox_content_init = inverse_sigmoid(point_coords[None])

    # Extract a visual-prompt content feature from each in-context example.
    for image in in_context_examples:
        image_ori = transform1(image['image'])
        mask_ori = transform1(image['mask'])
        images, height, width = prepare_image(image_ori)

        data = {"image": images, "height": height, "width": width}
        data['seg_image'] = data_tgt

        # Keep a single mask channel and convert to a CHW tensor.
        mask_ori = np.asarray(mask_ori)[:, :, 0:1].copy()
        mask_ori = torch.from_numpy(mask_ori).permute(2, 0, 1)

        data['targets'] = [dict()]
        data['targets'][0]['rand_shape'] = mask_ori
        data['targets'][0]['pb'] = torch.tensor([1.])  # FIXME 0 or 1

        frame = data
        batched_inputs.append(frame)

        multi_scale_features, _, padded_h, padded_w = model.model.get_encoder_feature([frame])
        input_query_label_content, input_query_bbox_content, attn_mask_content = model.model. \
            get_visual_prompt_content_feature(multi_scale_features, frame['targets'][0]['rand_shape'],
                                              padded_h, padded_w)
        input_query_label_content_all.append(input_query_label_content)

    # Prompt the target image: average the per-example content features into a
    # single query and decode it against the target's encoder features.
    input_query_label_content_current = torch.stack(input_query_label_content_all).mean(0)
    masks, ious, ori_masks, scores_per_image_openset = model.model.evaluate_demo_content_openset_multi_with_content_features(
        batched_inputs_tgt, mask_features2, multi_scale_features2, input_query_label_content_current,
        input_query_bbox_content_init, attn_mask_content, padded_h, padded_w)
    if len(ious.shape) > 1:
        ious = ious[0]
    # Rank predictions by open-set score, highest first.
    ids = torch.argsort(scores_per_image_openset, descending=True)
    areas = []
    image_ori = image_ori_tgt
    new_pred_mask = []
    new_pred_class_score = []
    for i in ids:
        new_pred_class_score.append(scores_per_image_openset[i])
        new_pred_mask.append(masks[i])
    pred_masks_poses = new_pred_mask
    ious = new_pred_class_score
    visual = Visualizer(image_ori, metadata=metadata)
    res = np.asarray(image_ori)  # fallback in case no masks are drawn
    for pred_masks_pos, iou in zip(pred_masks_poses, ious):
        iou = round(float(iou), 2)  # rounded score; not drawn (text='' below)
        mask = (pred_masks_pos > 0.0).cpu().numpy()
        areas.append(mask.sum())
        # uncomment for additional postprocessing
        # mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
        # mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
        mask = mask.astype(float)  # np.float was removed in NumPy 1.24
        color = [0.502, 0.0, 0.502]  # purple
        demo = visual.draw_binary_mask(mask, text='', alpha=0.7, edge_color=color)
        res = demo.get_image()

    torch.cuda.empty_cache()

    return res
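
# A minimal usage sketch (hypothetical file names and `model`; in the repo
# these arguments are supplied by the demo UI). Each visual prompt is a dict
# with PIL images under 'image' and 'mask'; unused slots are passed as None.
#
#   example = {'image': Image.open('ref.jpg'), 'mask': Image.open('ref_mask.png')}
#   result = task_openset(model, example, None, None, None,
#                         None, None, None, None,
#                         image_tgt=Image.open('target.jpg'))
#   cv2.imwrite('result.png', result[:, :, ::-1])  # RGB -> BGR for OpenCV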


def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    # XOR inverts the mask in "holes" mode so holes become foreground components.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is the background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep the largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True
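
# Usage matching the commented-out postprocessing in task_openset:
#   mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
#   mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
# i.e. first fill small holes, then drop small islands.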