mirror of https://github.com/UX-Decoder/DINOv.git
add demo link
parent
f5723f6f4a
commit
8ea2645d23
|
@ -1,5 +1,5 @@
|
|||
# IntelliJ project files
|
||||
repo.diff
|
||||
#repo.diff
|
||||
.idea
|
||||
.vscode
|
||||
.amltignore
|
||||
|
|
46
README.md
46
README.md
|
@ -1,10 +1,26 @@
|
|||
# Visual In-Context Prompting
|
||||
:grapes: \[[Read our arXiv Paper](https://arxiv.org/pdf/2311.13601.pdf)\] :apple: \[[Try our Demo](http://semantic-sam.xyzou.net:6099/)\]
|
||||
|
||||
In this work, we introduce [DINOv](https://arxiv.org/pdf/2311.13601.pdf), a Visual In-Context Prompting framework for referring and generic segmentation tasks.
|
||||
|
||||
For visualization and demos, we recommend using [T-Rex demo link](https://deepdataspace.com/playground/ivp), which is another visual prompting tool in our team with similar properties as DINOv.
|
||||
For visualization and demos, we also recommend trying [T-Rex demo link](https://deepdataspace.com/playground/ivp), which is another visual prompting tool in our team with similar properties as DINOv.
|
||||
|
||||

|
||||
|
||||
### :hammer_and_wrench: Installation
|
||||
```shell
|
||||
pip3 install torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
|
||||
pip install git+https://github.com/cocodataset/panopticapi.git
|
||||
git clone https://github.com/UX-Decoder/DINOv
|
||||
cd DINOv
|
||||
python -m pip install -r requirements.txt
|
||||
```
|
||||
#### :point_right: Launch a demo for visual in-context prompting
|
||||
```shell
|
||||
python demo_openset.py --ckpt /path/to/swinL/ckpt
|
||||
```
|
||||
|
||||
# Openset segmentation
|
||||

|
||||
|
||||
|
@ -19,19 +35,13 @@ For visualization and demos, we recommend using [T-Rex demo link](https://deepda
|
|||
|
||||
## :unicorn: Getting Started
|
||||
|
||||
### :hammer_and_wrench: Installation
|
||||
```shell
|
||||
pip3 install torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113
|
||||
python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
|
||||
pip install git+https://github.com/cocodataset/panopticapi.git
|
||||
git clone https://github.com/UX-Decoder/DINOv
|
||||
cd DINOv
|
||||
python -m pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### :mosque: Data preparation
|
||||
We jointly train on COCO and SA-1B data. Please refer to [prepare SA-1B data](https://github.com/UX-Decoder/Semantic-SAM/blob/main/DATASET.md) and [prepare coco data](https://github.com/IDEA-Research/MaskDINO/blob/main/README.md).
|
||||
|
||||
For evaluation, you need to prepare
|
||||
* [ADE20K](https://github.com/IDEA-Research/MaskDINO/blob/main/datasets/README.md) for open-set segmentation evaluation.
|
||||
* [DAVIS2017](https://davischallenge.org/davis2017/code.html) for refering segmentation (video object segmentation).
|
||||
|
||||
### :volcano: Model Zoo
|
||||
The currently released checkpoints are trained with SA-1B and COCO data.
|
||||
<table><tbody>
|
||||
|
@ -65,6 +75,7 @@ We do detection evaluation on COCO val2017.
|
|||
`$n` is the number of gpus you use
|
||||
|
||||
Process visual prompt embeddings for inference. We calculate the all the instance prompt embeddings of the validate set (you can also use the training set, but the processing time is much longer) and store them. Then we infrence by randomly selecting some visual prompts as in-context examples.
|
||||
#### Evaluate Open-set detection and segmentation
|
||||
* Infenrence script to get and store visual prompts
|
||||
```shell
|
||||
python train_net.py --eval_only --resume --eval_get_content_features --num-gpus 8 --config-file /path/to/configs COCO.TEST.BATCH_SIZE_TOTAL=8 MODEL.WEIGHTS=/path/to/weights OUTPUT_DIR=/path/to/outputs
|
||||
|
@ -74,8 +85,21 @@ python train_net.py --eval_only --resume --eval_get_content_features --num-gpus
|
|||
python train_net.py --eval_only --resume --eval_visual_openset --num-gpus 8 --config-file /path/to/configs COCO.TEST.BATCH_SIZE_TOTAL=8 MODEL.WEIGHTS=/path/to/weights MODEL.DECODER.INFERENCE_EXAMPLE=16 OUTPUT_DIR=/path/to/outputs
|
||||
```
|
||||
* **configs** to use are `configs/dinov_sam_coco_train.yaml` for swinT and `configs/dinov_sam_coco_swinl_train.yaml` for swinL.
|
||||
* For ADE20K data, use `configs/dinov_sam_ade_eval.yaml` and adjust the batchsize of ADE evaluation to the correct number.
|
||||
* `OUTPUT_DIR` is the dir to store the visual prompt embeddings
|
||||
* `INFERENCE_EXAMPLE` number of in-context examples to represent a category. Default set to 16.
|
||||
#### Evaluate Refering segmentation on VOS
|
||||
We evaluate under the `DAVIS 2017 Semi-supervised` setting, please refer to [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation) for more details.
|
||||
|
||||
The first step is to compute and store the results of DAVIS2017. We implement a navie memory-aware approach with our in-context visual prompting.
|
||||
```shell
|
||||
python train_net.py --eval_track_prev --eval_only --resume --num-gpus 8 --config-file configs/dinov_sam_coco_train.yaml DAVIS.TEST.BATCH_SIZE_TOTAL=8 OUTPUT_DIR=$outdir MODEL.WEIGHTS=/path/to/weights MODEL.DECODER.NMS_THRESHOLD=0.9 MODEL.DECODER.MAX_MEMORY_SIZE=9 OUTPUT_DIR=/path/to/outputs
|
||||
```
|
||||
The second step is to evaluate the semi-supervised results.
|
||||
```shell
|
||||
python evaluation_method.py --task semi-supervised --results_path /path/to/results --davis_path /path/to/davis/data
|
||||
```
|
||||
* We use MAX_MEMORY_SIZE = 9 by default (1 current frame token and 8 previous memory tokens)
|
||||
### :star: Training
|
||||
We currently release the code of training on SA-1B and COCO. It can also support Objects365 and other datasets with minimal modifications.
|
||||
`$n` is the number of gpus you use
|
||||
|
|
|
@ -176,7 +176,7 @@ MODEL:
|
|||
NMS_THRESHOLD: 0.7
|
||||
REFER_IMAGE: False
|
||||
point_per_side: 20
|
||||
max_memory_size: 8
|
||||
MAX_MEMORY_SIZE: 9
|
||||
softmax: False
|
||||
MANY2MANY: True
|
||||
VIS: False
|
||||
|
|
|
@ -166,7 +166,6 @@ MODEL:
|
|||
NUM_CONTENT_TOKENS: 1
|
||||
MAX_NUM_INSTANCE: 50
|
||||
MAX_NUM_INSTANCE_CONTENT: 8
|
||||
# CONTENT_INDEPENDENT: True
|
||||
USE_SHAPE_SAMPLER: True
|
||||
OPENSET_USE_SHAPE_SAMPLER: False
|
||||
INFERENCE_EXAMPLE: 4
|
||||
|
@ -176,7 +175,7 @@ MODEL:
|
|||
NMS_THRESHOLD: 0.7
|
||||
REFER_IMAGE: False
|
||||
point_per_side: 20
|
||||
max_memory_size: 8
|
||||
MAX_MEMORY_SIZE: 9
|
||||
softmax: False
|
||||
MANY2MANY: True
|
||||
VIS: False
|
||||
|
|
|
@ -176,7 +176,7 @@ MODEL:
|
|||
NMS_THRESHOLD: 0.7
|
||||
REFER_IMAGE: False
|
||||
point_per_side: 20
|
||||
max_memory_size: 8
|
||||
MAX_MEMORY_SIZE: 9
|
||||
softmax: True
|
||||
MANY2MANY: True
|
||||
VIS: False
|
||||
|
@ -379,7 +379,7 @@ ADE20K:
|
|||
|
||||
DAVIS:
|
||||
INPUT:
|
||||
MIN_SIZE_TEST: 800
|
||||
MIN_SIZE_TEST: 640
|
||||
MAX_SIZE_TEST: 1333
|
||||
DATALOADER:
|
||||
FILTER_EMPTY_ANNOTATIONS: False
|
||||
|
@ -392,7 +392,7 @@ DAVIS:
|
|||
|
||||
VOS:
|
||||
INPUT:
|
||||
MIN_SIZE_TEST: 800
|
||||
MIN_SIZE_TEST: 640
|
||||
MAX_SIZE_TEST: 1333
|
||||
DATALOADER:
|
||||
FILTER_EMPTY_ANNOTATIONS: False
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .openset_task import task_openset
|
Binary file not shown.
After Width: | Height: | Size: 140 KiB |
Binary file not shown.
After Width: | Height: | Size: 35 KiB |
Binary file not shown.
After Width: | Height: | Size: 8.1 MiB |
Binary file not shown.
After Width: | Height: | Size: 207 KiB |
|
@ -0,0 +1,142 @@
|
|||
# --------------------------------------------------------
|
||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
# Copyright (c) 2024 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Feng Li (fliay@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from torchvision import transforms
|
||||
from utils.visualizer import Visualizer
|
||||
from typing import Tuple
|
||||
from PIL import Image
|
||||
from detectron2.data import MetadataCatalog
|
||||
import os
|
||||
import cv2
|
||||
|
||||
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-5):
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
x2 = (1 - x).clamp(min=eps)
|
||||
return torch.log(x1/x2)
|
||||
|
||||
def task_openset(model,generic_vp1, generic_vp2, generic_vp3, generic_vp4,
|
||||
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image_tgt=None, text_size=640,hole_scale=100,island_scale=100):
|
||||
in_context_examples = [generic_vp1, generic_vp2, generic_vp3, generic_vp4,
|
||||
generic_vp5, generic_vp6, generic_vp7, generic_vp8]
|
||||
in_context_examples = [x for x in in_context_examples if x is not None]
|
||||
t = []
|
||||
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
|
||||
def prepare_image(image_ori):
|
||||
width = image_ori.size[0]
|
||||
height = image_ori.size[1]
|
||||
image_ori = np.asarray(image_ori)
|
||||
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
|
||||
return images, height, width
|
||||
transform1 = transforms.Compose(t)
|
||||
image_ori_tgt = transform1(image_tgt)
|
||||
images_tgt, height_tgt, width_tgt = prepare_image(image_ori_tgt)
|
||||
data_tgt = {"image": images_tgt, "height": height_tgt, "width": width_tgt}
|
||||
batched_inputs = []
|
||||
batched_inputs_tgt = [data_tgt]
|
||||
multi_scale_features2, mask_features2, _, _ = model.model.get_encoder_feature(batched_inputs_tgt)
|
||||
input_query_label_content_all = []
|
||||
point_coords = torch.ones(1, 4).cuda().float()
|
||||
point_coords[:, :2] = 0.
|
||||
input_query_bbox_content_init = inverse_sigmoid(point_coords[None])
|
||||
for image in in_context_examples:
|
||||
image_ori = transform1(image['image'])
|
||||
mask_ori = transform1(image['mask'])
|
||||
images, height, width = prepare_image(image_ori)
|
||||
|
||||
data = {"image": images, "height": height, "width": width}
|
||||
data['seg_image'] = data_tgt
|
||||
|
||||
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
|
||||
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)
|
||||
|
||||
data['targets'] = [dict()]
|
||||
data['targets'][0]['rand_shape']=mask_ori
|
||||
data['targets'][0]['pb']=torch.tensor([1.]) # FIXME 0 or 1
|
||||
|
||||
frame = data
|
||||
rand_shape = mask_ori
|
||||
frame['targets'][0]['rand_shape'] = rand_shape
|
||||
|
||||
batched_inputs.append(frame)
|
||||
|
||||
multi_scale_features, _, padded_h, padded_w = model.model.get_encoder_feature([frame])
|
||||
input_query_label_content, input_query_bbox_content, attn_mask_content = model.model. \
|
||||
get_visual_prompt_content_feature(multi_scale_features, frame['targets'][0]['rand_shape'], padded_h, padded_w)
|
||||
input_query_label_content_all.append(input_query_label_content)
|
||||
|
||||
# prompt to tgt image
|
||||
input_query_label_content_current = torch.stack(input_query_label_content_all).mean(0)
|
||||
masks, ious, ori_masks, scores_per_image_openset = model.model.evaluate_demo_content_openset_multi_with_content_features(
|
||||
batched_inputs_tgt, mask_features2, multi_scale_features2, input_query_label_content_current,
|
||||
input_query_bbox_content_init, attn_mask_content, padded_h, padded_w)
|
||||
if len(ious.shape)>1:
|
||||
ious=ious[0]
|
||||
ids=torch.argsort(scores_per_image_openset,descending=True)
|
||||
areas=[]
|
||||
image_ori = image_ori_tgt
|
||||
new_pred_mask = []
|
||||
new_pred_class_score = []
|
||||
for i in ids:
|
||||
new_pred_class_score.append(scores_per_image_openset[i])
|
||||
new_pred_mask.append(masks[i])
|
||||
pred_masks_poses = new_pred_mask
|
||||
ious = new_pred_class_score
|
||||
visual = Visualizer(image_ori, metadata=metadata)
|
||||
for i,(pred_masks_pos,iou, _, _) in enumerate(zip(pred_masks_poses,ious, pred_masks_poses, pred_masks_poses)):
|
||||
iou=round(float(iou),2)
|
||||
texts=f'{iou}'
|
||||
mask=(pred_masks_pos>0.0).cpu().numpy()
|
||||
area=mask.sum()
|
||||
areas.append(area)
|
||||
# uncomment for additional postprocessing
|
||||
# mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
|
||||
# mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
|
||||
mask=(mask).astype(np.float)
|
||||
color=[0.,0.,1.0]
|
||||
color=[0.502, 0.0, 0.502]
|
||||
demo = visual.draw_binary_mask(mask, text='', alpha=0.7, edge_color=color)
|
||||
res = demo.get_image()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return res
|
||||
|
||||
def remove_small_regions(
|
||||
mask: np.ndarray, area_thresh: float, mode: str
|
||||
) -> Tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Removes small disconnected regions and holes in a mask. Returns the
|
||||
mask and an indicator of if the mask has been modified.
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in ["holes", "islands"]
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if len(small_regions) == 0:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
||||
# If every region is below threshold, keep largest
|
||||
if len(fill_labels) == 0:
|
||||
fill_labels = [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
|
@ -0,0 +1,140 @@
|
|||
# --------------------------------------------------------
|
||||
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
|
||||
# Copyright (c) 2023 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
# Copyright (c) 2024 Microsoft
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# Written by Feng Li (fliay@connect.ust.hk)
|
||||
# --------------------------------------------------------
|
||||
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
from dinov.BaseModel import BaseModel
|
||||
from dinov import build_model
|
||||
from utils.arguments import load_opt_from_config_file
|
||||
|
||||
from demo import task_openset
|
||||
|
||||
def parse_option():
|
||||
parser = argparse.ArgumentParser('DINOv Demo', add_help=False)
|
||||
parser.add_argument('--conf_files', default="configs/dinov_sam_coco_swinl_train.yaml", metavar="FILE", help='path to config file', )
|
||||
parser.add_argument('--ckpt', default="", metavar="FILE", help='path to ckpt', required=True)
|
||||
parser.add_argument('--port', default=6099, type=int, help='path to ckpt', )
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class ImageMask(gr.components.Image):
|
||||
"""
|
||||
Sets: source="canvas", tool="sketch"
|
||||
"""
|
||||
|
||||
is_template = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
|
||||
|
||||
def preprocess(self, x):
|
||||
return super().preprocess(x)
|
||||
|
||||
|
||||
'''
|
||||
build args
|
||||
'''
|
||||
args = parse_option()
|
||||
|
||||
'''
|
||||
build model
|
||||
'''
|
||||
|
||||
sam_cfg=args.conf_files
|
||||
|
||||
opt = load_opt_from_config_file(sam_cfg)
|
||||
|
||||
model_sam = BaseModel(opt, build_model(opt)).from_pretrained(args.ckpt).eval().cuda()
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(generic_vp1, generic_vp2, generic_vp3, generic_vp4,
|
||||
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image2,*args, **kwargs):
|
||||
with torch.autocast(device_type='cuda', dtype=torch.float16):
|
||||
model=model_sam
|
||||
a= task_openset(model, generic_vp1, generic_vp2, generic_vp3, generic_vp4,
|
||||
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image2, *args, **kwargs)
|
||||
return a
|
||||
|
||||
|
||||
'''
|
||||
launch app
|
||||
'''
|
||||
title = "DINOv: Visual In-Context Prompting"
|
||||
|
||||
article = "The Demo is Run on DINOv."
|
||||
|
||||
demo = gr.Blocks()
|
||||
image_tgt=gr.components.Image(label="Target Image ",type="pil",brush_radius=15.0)
|
||||
gallery_output=gr.components.Image(label="Results Image ",type="pil",brush_radius=15.0)
|
||||
|
||||
generic_vp1 = ImageMask(label="scribble on refer Image 1",type="pil",brush_radius=15.0)
|
||||
generic_vp2 = ImageMask(label="scribble on refer Image 2",type="pil",brush_radius=15.0)
|
||||
generic_vp3 = ImageMask(label="scribble on refer Image 3",type="pil",brush_radius=15.0)
|
||||
generic_vp4 = ImageMask(label="scribble on refer Image 5",type="pil",brush_radius=15.0)
|
||||
generic_vp5 = ImageMask(label="scribble on refer Image 6",type="pil",brush_radius=15.0)
|
||||
generic_vp6 = ImageMask(label="scribble on refer Image 7",type="pil",brush_radius=15.0)
|
||||
generic_vp7 = ImageMask(label="scribble on refer Image 8",type="pil",brush_radius=15.0)
|
||||
generic_vp8 = ImageMask(label="scribble on refer Image 9",type="pil",brush_radius=15.0)
|
||||
generic = gr.TabbedInterface([
|
||||
generic_vp1, generic_vp2, generic_vp3, generic_vp4,
|
||||
generic_vp5, generic_vp6, generic_vp7, generic_vp8
|
||||
], ["1", "2", "3", "4", "5", "6", "7", "8"])
|
||||
|
||||
title='''
|
||||
# DINOv: Visual In-Context Prompting
|
||||
|
||||
# [[Read our arXiv Paper](https://arxiv.org/pdf/2311.13601.pdf)\] \[[Github page](https://github.com/UX-Decoder/DINOv)\]
|
||||
'''
|
||||
|
||||
with demo:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3.0):
|
||||
generation_tittle = gr.Markdown(title)
|
||||
image_tgt.render()
|
||||
generic.render()
|
||||
with gr.Row(scale=2.0):
|
||||
clearBtn = gr.ClearButton(
|
||||
components=[image_tgt])
|
||||
runBtn = gr.Button("Run")
|
||||
with gr.Column(scale=5.0):
|
||||
|
||||
gallery_tittle = gr.Markdown("# Open-set results.")
|
||||
with gr.Row(scale=9.0):
|
||||
gallery_output.render()
|
||||
|
||||
example = gr.Examples(
|
||||
examples=[
|
||||
["demo/examples/bags.jpg"],
|
||||
["demo/examples/img.png"],
|
||||
["demo/examples/corgi2.jpg"],
|
||||
["demo/examples/ref_cat.jpeg"],
|
||||
],
|
||||
inputs=image_tgt,
|
||||
cache_examples=False,
|
||||
)
|
||||
|
||||
title = title,
|
||||
article = article,
|
||||
allow_flagging = 'never',
|
||||
|
||||
runBtn.click(inference, inputs=[generic_vp1, generic_vp2, generic_vp3, generic_vp4,
|
||||
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image_tgt],
|
||||
outputs = [gallery_output])
|
||||
|
||||
|
||||
|
||||
demo.queue().launch(share=True,server_port=args.port)
|
||||
|
|
@ -191,7 +191,7 @@ class DINOv(nn.Module):
|
|||
|
||||
# freeze some parameters
|
||||
to_freeze_dict = ['label_enc', 'pb_embedding']
|
||||
if freeze_all:
|
||||
if freeze_all:
|
||||
for (name, param) in self.named_parameters():
|
||||
param.requires_grad = False
|
||||
print("!!!!!!!!freeze_all!!!!!!!!, except ", to_freeze_dict)
|
||||
|
@ -1139,9 +1139,52 @@ class DINOv(nn.Module):
|
|||
|
||||
return src_boxes, mask_pred_results, src_ious, pred_ious
|
||||
|
||||
def filter_data_openset(self, src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset):
|
||||
def keep_data(box, mask, iou, match_score, pred_score_openset, keep):
|
||||
return box[keep], mask[keep], iou[keep], match_score[:, keep], pred_score_openset[:, keep]
|
||||
# advanced filtering
|
||||
# uncomment below for better filltering
|
||||
|
||||
# print('filter iou score')
|
||||
# keep = src_ious > 0.5
|
||||
# src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset = keep_data(src_boxes, mask_pred_results, src_ious, pred_ious,pred_score_openset,
|
||||
# keep)
|
||||
#
|
||||
# stability_score = calculate_stability_score(
|
||||
# mask_pred_results, 0.0, self.stability_score_offset
|
||||
# )
|
||||
# keep = stability_score >= self.stability_score_thresh
|
||||
#
|
||||
# src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset = keep_data(src_boxes, mask_pred_results, src_ious,pred_score_openset,
|
||||
# pred_ious, keep)
|
||||
|
||||
# print('using nms')
|
||||
# item_indice = nms(box_ops.box_cxcywh_to_xyxy(src_boxes), src_ious, self.nms_thersh) # FIXME iou threshold
|
||||
# mask_pred_results = mask_pred_results[item_indice]
|
||||
# src_boxes = src_boxes[item_indice]
|
||||
# src_ious = src_ious[item_indice]
|
||||
# pred_ious = torch.index_select(pred_ious, -1, item_indice)
|
||||
# pred_score_openset = torch.index_select(pred_score_openset, -1, item_indice)
|
||||
|
||||
# print("remove small objects")
|
||||
# keep = (mask_pred_results > 0).flatten(-2, -1).sum(-1) > 50
|
||||
# # keep = (mask_pred_results > 0).flatten(-2, -1).sum(-1) > 100
|
||||
# src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset = keep_data(src_boxes, mask_pred_results, src_ious,
|
||||
# pred_ious, pred_score_openset, keep)
|
||||
|
||||
scores_per_image_openset, label_openset = pred_score_openset.sigmoid().max(-1)
|
||||
thresh = 0.12
|
||||
keep = scores_per_image_openset>thresh
|
||||
while sum(keep)<1:
|
||||
thresh = thresh-0.04
|
||||
keep = scores_per_image_openset > thresh
|
||||
scores_per_image_openset, label_openset = scores_per_image_openset[keep], label_openset[keep]
|
||||
mask_pred_results = mask_pred_results[keep]
|
||||
|
||||
return src_boxes, mask_pred_results, src_ious, pred_ious, scores_per_image_openset
|
||||
|
||||
def get_encoder_feature(self, batched_inputs):
|
||||
# get the image encoder features (multi-scale)
|
||||
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
||||
images = self.prepare_image(batched_inputs)
|
||||
padded_h = images.tensor.shape[-2] # divisable to 32
|
||||
padded_w = images.tensor.shape[-1]
|
||||
|
@ -1238,6 +1281,70 @@ class DINOv(nn.Module):
|
|||
)
|
||||
return pred_masks, pred_ious, ori_masks
|
||||
|
||||
def evaluate_demo_content_openset_multi_with_content_features(self, batched_inputs, mask_features, multi_scale_features,
|
||||
input_query_label_content,
|
||||
input_query_bbox_content, attn_mask_content,
|
||||
padded_h, padded_w,
|
||||
level=[0,1,2,3,4,5], return_src_ious=False):
|
||||
assert len(batched_inputs) == 1, "only support batch size equal to 1"
|
||||
prediction_switch = {'part': False, 'whole': False, 'seg': True, 'det': True}
|
||||
|
||||
def prepare_image(batched_inputs, key='image'):
|
||||
images = [x['image'].to(self.device) for x in batched_inputs]
|
||||
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
|
||||
images = ImageList.from_tensors(images, self.size_divisibility)
|
||||
return images
|
||||
|
||||
images = prepare_image(batched_inputs)
|
||||
|
||||
# image_size_xyxy = torch.tensor([padded_w, padded_h, padded_w, padded_h]).cuda()
|
||||
# auto_points = torch.tensor(self.build_point_grid(self.point_per_side)).cuda() * image_size_xyxy[:2]
|
||||
|
||||
# boxes_dn = box_ops.box_xyxy_to_cxcywh(
|
||||
# torch.cat([auto_points - 3, auto_points + 3], 1)) / image_size_xyxy
|
||||
# boxes_dn = torch.tensor(boxes_dn, dtype=torch.float32)
|
||||
# pb = torch.ones(boxes_dn.shape[0]) # FIXME: use 1 for pb
|
||||
# targets_p = [{}]
|
||||
# targets_p[0]['boxes_dn'] = boxes_dn
|
||||
# targets_p[0]['pb'] = pb
|
||||
|
||||
# targets = targets_p
|
||||
targets = None
|
||||
outputs, mask_dict = self.sem_seg_head.predictor.forward_openset_image_with_extracted_content(multi_scale_features,
|
||||
mask_features, None, input_query_label_content, input_query_bbox_content, attn_mask_content, targets,
|
||||
extra=prediction_switch)
|
||||
|
||||
src_boxes = outputs["pred_boxes"][0]
|
||||
mask_pred_results = outputs["pred_masks"][0]
|
||||
pred_score_openset = outputs["pred_logits"][0]
|
||||
# level = torch.tensor(level).cuda()
|
||||
src_ious = pred_score_openset.flatten(0, 1)
|
||||
pred_ious = pred_score_openset
|
||||
src_boxes, mask_pred_results, src_ious, pred_ious, scores_per_image_openset = self.filter_data_openset(src_boxes, mask_pred_results, src_ious,
|
||||
pred_ious, pred_score_openset)
|
||||
|
||||
# upsample masks
|
||||
mask_pred_results = F.interpolate(
|
||||
mask_pred_results[None],
|
||||
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
pred_masks = mask_pred_results
|
||||
image_size = images.image_sizes[0]
|
||||
|
||||
height = batched_inputs[0].get('height', image_size[0])
|
||||
width = batched_inputs[0].get('width', image_size[1])
|
||||
ori_masks = pred_masks[:, : image_size[0], : image_size[1]].expand(1, -1, -1, -1)[0]
|
||||
# import ipdb; ipdb.set_trace()
|
||||
if self.sem_seg_postprocess_before_inference:
|
||||
pred_masks = retry_if_cuda_oom(sem_seg_postprocess)(
|
||||
pred_masks, image_size, height, width
|
||||
)
|
||||
return pred_masks, pred_ious, ori_masks, scores_per_image_openset
|
||||
|
||||
|
||||
def semantic_inference(self, mask_cls, mask_pred):
|
||||
# if use cross-entropy loss in training, evaluate with softmax
|
||||
if self.semantic_ce_loss:
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
diff --git openseed/architectures/joint_oi_model.py openseed/architectures/joint_oi_model.py
|
||||
index 8086690..a0679fe 100644
|
||||
--- openseed/architectures/joint_oi_model.py
|
||||
+++ openseed/architectures/joint_oi_model.py
|
||||
@@ -286,6 +286,7 @@ class GeneralizedMaskDINO(nn.Module):
|
||||
"coco_on": dec_cfg.get('COCO', True),
|
||||
"coco_mask_on": dec_cfg.get('COCO_MASK', True),
|
||||
"o365_on": dec_cfg.get('O365', True),
|
||||
+ "regenerate_point": dec_cfg.get('RE_POINT', False),
|
||||
}
|
||||
|
||||
@property
|
||||
@@ -531,7 +532,7 @@ class GeneralizedMaskDINO(nn.Module):
|
||||
|
||||
# if not self.training:
|
||||
# box_start = int(num_mask/4*3)
|
||||
- box_start = random.randint(0, self.max_num_instance - 1) # box based interactive after this number; about 1/4
|
||||
+ box_start = random.randint(1, self.max_num_instance - 1) # box based interactive after this number; about 1/4
|
||||
point_coords = targets_per_image.point_coords[index[:box_start]]
|
||||
# FIXME randomly sample one point as the user input
|
||||
if self.regenerate_point:
|
24
train_net.py
24
train_net.py
|
@ -309,7 +309,7 @@ class Trainer(DefaultTrainer):
|
|||
from PIL import Image
|
||||
from queue import Queue, LifoQueue, PriorityQueue
|
||||
cfg['DATASETS']['TEST'] = ['davis17_val']
|
||||
maxsize = cfg['MODEL']['DECODER']['max_memory_size']
|
||||
maxsize = cfg['MODEL']['DECODER']['MAX_MEMORY_SIZE']
|
||||
dataloaders = cls.build_test_loader(cfg, dataset_name=None)
|
||||
model = model.eval().cuda()
|
||||
|
||||
|
@ -612,10 +612,7 @@ def main(args=None):
|
|||
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
|
||||
cfg.MODEL.WEIGHTS, resume=args.resume
|
||||
)
|
||||
if args.original_load:
|
||||
print("using original loading")
|
||||
model = model.from_pretrained(cfg.MODEL.WEIGHTS)
|
||||
elif args.eval_visual_openset:
|
||||
if args.eval_visual_openset:
|
||||
res = Trainer.test_visual_openset(cfg, model, args.eval_visual_openset_combine)
|
||||
elif args.eval_track_prev:
|
||||
res = Trainer.test_tracking_prev(cfg, model)
|
||||
|
@ -630,13 +627,6 @@ def main(args=None):
|
|||
init_wandb(cfg, cfg['OUTPUT_DIR'], entity=args.wandb_usr_name, job_name=cfg['OUTPUT_DIR'])
|
||||
|
||||
trainer = Trainer(cfg)
|
||||
if len(args.lang_weight)>0:
|
||||
import copy
|
||||
weight = copy.deepcopy(trainer.cfg.MODEL.WEIGHTS)
|
||||
trainer.cfg.MODEL.WEIGHTS = args.lang_weight
|
||||
print("load original language language weight!!!!!!")
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
trainer.cfg.MODEL.WEIGHTS = weight
|
||||
|
||||
print("load pretrained model weight!!!!!!")
|
||||
trainer.resume_or_load(resume=args.resume)
|
||||
|
@ -647,19 +637,9 @@ def main(args=None):
|
|||
if __name__ == "__main__":
|
||||
parser = default_argument_parser()
|
||||
parser.add_argument('--eval_only', action='store_true')
|
||||
parser.add_argument('--eval_track', action='store_true')
|
||||
parser.add_argument('--eval_visual_openset', action='store_true')
|
||||
parser.add_argument('--eval_visual_openset_combine', action='store_true')
|
||||
parser.add_argument('--eval_track_prev', action='store_true')
|
||||
parser.add_argument('--eval_track_prev_davis16', action='store_true')
|
||||
parser.add_argument('--eval_track_gt_frame', action='store_true')
|
||||
parser.add_argument('--eval_track_prev_ytvos', action='store_true')
|
||||
parser.add_argument('--eval_track_prev_ytvos_v2', action='store_true')
|
||||
parser.add_argument('--eval_track_gt', action='store_true')
|
||||
parser.add_argument('--eval_get_content_features', action='store_true')
|
||||
parser.add_argument('--original_load', action='store_true')
|
||||
parser.add_argument('--EVAL_FLAG', type=int, default=1)
|
||||
parser.add_argument('--lang_weight', type=str, default='')
|
||||
parser.add_argument('--WANDB', action='store_true')
|
||||
parser.add_argument('--wandb_usr_name', type=str, default='')
|
||||
parser.add_argument('--wandb_key', type=str, default='')
|
||||
|
|
Loading…
Reference in New Issue