# DINOv/demo/openset_task.py
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Feng Li (fliay@connect.ust.hk)
# --------------------------------------------------------
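"""Open-set in-context segmentation demo for DINOv.

Encodes up to eight visual-prompt image/mask pairs, averages their content
embeddings, and uses the result to segment and score matching regions in a
target image.
"""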
import torch
import numpy as np
from torchvision import transforms
from utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import cv2

# COCO panoptic metadata supplies the category names/colors used by the Visualizer.
metadata = MetadataCatalog.get('coco_2017_train_panoptic')


def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable logit: clamp x to [0, 1] and away from the endpoints."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def task_openset(model, generic_vp1, generic_vp2, generic_vp3, generic_vp4,
                 generic_vp5, generic_vp6, generic_vp7, generic_vp8,
                 image_tgt=None, text_size=640, hole_scale=100, island_scale=100):
    # Collect the visual prompts that were actually provided (up to eight).
    in_context_examples = [generic_vp1, generic_vp2, generic_vp3, generic_vp4,
                           generic_vp5, generic_vp6, generic_vp7, generic_vp8]
    in_context_examples = [x for x in in_context_examples if x is not None]
    t = []
    t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))

    def prepare_image(image_ori):
        # PIL size is (width, height); convert the image to a CHW CUDA tensor.
        width = image_ori.size[0]
        height = image_ori.size[1]
        image_ori = np.asarray(image_ori)
        images = torch.from_numpy(image_ori.copy()).permute(2, 0, 1).cuda()
        return images, height, width

    transform1 = transforms.Compose(t)
    image_ori_tgt = transform1(image_tgt)
    images_tgt, height_tgt, width_tgt = prepare_image(image_ori_tgt)
    data_tgt = {"image": images_tgt, "height": height_tgt, "width": width_tgt}
    batched_inputs = []
    batched_inputs_tgt = [data_tgt]
    # Precompute encoder features for the target image once; reused when decoding.
    multi_scale_features2, mask_features2, _, _ = model.model.get_encoder_feature(batched_inputs_tgt)
    input_query_label_content_all = []
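    # Fixed box prior for the content queries: point_coords is [0, 0, 1, 1] in
    # normalized coordinates, mapped through inverse_sigmoid into logit space.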
    point_coords = torch.ones(1, 4).cuda().float()
    point_coords[:, :2] = 0.
    input_query_bbox_content_init = inverse_sigmoid(point_coords[None])
    for image in in_context_examples:
        image_ori = transform1(image['image'])
        mask_ori = transform1(image['mask'])
        images, height, width = prepare_image(image_ori)
        data = {"image": images, "height": height, "width": width}
        data['seg_image'] = data_tgt
        # Keep a single channel of the prompt mask as a 1xHxW tensor.
        mask_ori = np.asarray(mask_ori)[:, :, 0:1].copy()
        mask_ori = torch.from_numpy(mask_ori).permute(2, 0, 1)
        data['targets'] = [dict()]
        data['targets'][0]['rand_shape'] = mask_ori
        data['targets'][0]['pb'] = torch.tensor([1.])  # FIXME 0 or 1
        frame = data
        batched_inputs.append(frame)
        # Encode this prompt image and pool its features under the prompt mask.
        multi_scale_features, _, padded_h, padded_w = model.model.get_encoder_feature([frame])
        input_query_label_content, input_query_bbox_content, attn_mask_content = \
            model.model.get_visual_prompt_content_feature(
                multi_scale_features, frame['targets'][0]['rand_shape'], padded_h, padded_w)
        input_query_label_content_all.append(input_query_label_content)
    # Average the content embeddings of all prompts, then decode on the target image.
    input_query_label_content_current = torch.stack(input_query_label_content_all).mean(0)
    masks, ious, ori_masks, scores_per_image_openset = model.model.evaluate_demo_content_openset_multi_with_content_features(
        batched_inputs_tgt, mask_features2, multi_scale_features2, input_query_label_content_current,
        input_query_bbox_content_init, attn_mask_content, padded_h, padded_w)
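    # `masks` holds per-query mask predictions on the target image;
    # `scores_per_image_openset` scores each one against the averaged prompt features.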
    if len(ious.shape) > 1:
        ious = ious[0]
    # Rank predictions by open-set score, highest first.
    ids = torch.argsort(scores_per_image_openset, descending=True)
    areas = []
    image_ori = image_ori_tgt
    new_pred_mask = []
    new_pred_class_score = []
    for i in ids:
        new_pred_class_score.append(scores_per_image_openset[i])
        new_pred_mask.append(masks[i])
    pred_masks_poses = new_pred_mask
    ious = new_pred_class_score
    visual = Visualizer(image_ori, metadata=metadata)
    for pred_masks_pos, iou in zip(pred_masks_poses, ious):
        iou = round(float(iou), 2)
        texts = f'{iou}'  # score label; not drawn below (text='')
        mask = (pred_masks_pos > 0.0).cpu().numpy()
        area = mask.sum()
        areas.append(area)
        # uncomment for additional postprocessing
        # mask, _ = remove_small_regions(mask, int(hole_scale), mode="holes")
        # mask, _ = remove_small_regions(mask, int(island_scale), mode="islands")
        mask = mask.astype(float)  # np.float was removed in NumPy >= 1.24
        color = [0.502, 0.0, 0.502]  # purple overlay
        demo = visual.draw_binary_mask(mask, text='', alpha=0.7, edge_color=color)
    res = demo.get_image()
    torch.cuda.empty_cache()
    return res
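
# A hypothetical driver for the function above -- untested sketch assuming `model`
# is a loaded DINOv checkpoint and each visual prompt is a dict of PIL images:
#
#   from PIL import Image
#   prompt = {"image": Image.open("ref.jpg"), "mask": Image.open("ref_mask.png")}
#   target = Image.open("target.jpg")
#   res = task_openset(model, prompt, None, None, None, None, None, None, None,
#                      image_tgt=target)
#   cv2.imwrite("overlay.png", res[:, :, ::-1])  # Visualizer returns RGB; cv2 wants BGR
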
def remove_small_regions(
    mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
    """
    Removes small disconnected regions and holes in a mask. Returns the
    mask and an indicator of whether the mask has been modified.
    """
    assert mode in ["holes", "islands"]
    correct_holes = mode == "holes"
    # XOR inverts the mask when filling holes, so holes become foreground components.
    working_mask = (correct_holes ^ mask).astype(np.uint8)
    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
    sizes = stats[:, -1][1:]  # Row 0 is the background label
    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
    if len(small_regions) == 0:
        return mask, False
    fill_labels = [0] + small_regions
    if not correct_holes:
        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
        # If every region is below threshold, keep the largest
        if len(fill_labels) == 0:
            fill_labels = [int(np.argmax(sizes)) + 1]
    mask = np.isin(regions, fill_labels)
    return mask, True