add demo link

pull/8/head
Feng Li 2024-03-21 10:14:43 +08:00
parent f5723f6f4a
commit 8ea2645d23
15 changed files with 456 additions and 42 deletions

.gitignore (vendored, 2 lines changed)

@ -1,5 +1,5 @@
# IntelliJ project files
repo.diff
#repo.diff
.idea
.vscode
.amltignore

README.md

@ -1,10 +1,26 @@
# Visual In-Context Prompting
:grapes: \[[Read our arXiv Paper](https://arxiv.org/pdf/2311.13601.pdf)\]   :apple: \[[Try our Demo](http://semantic-sam.xyzou.net:6099/)\]
In this work, we introduce [DINOv](https://arxiv.org/pdf/2311.13601.pdf), a Visual In-Context Prompting framework for referring and generic segmentation tasks.
For visualization and demos, we also recommend trying the [T-Rex demo](https://deepdataspace.com/playground/ivp), another visual prompting tool from our team with capabilities similar to DINOv.
![teaser](https://github.com/UX-Decoder/DINOv/assets/34880758/f686dd20-a5aa-40fa-ad57-c4c69575853b)
### :hammer_and_wrench: Installation
```shell
pip3 install torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113
python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
pip install git+https://github.com/cocodataset/panopticapi.git
git clone https://github.com/UX-Decoder/DINOv
cd DINOv
python -m pip install -r requirements.txt
```
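As an optional sanity check after installation (a minimal sketch, not part of the official setup), you can verify that PyTorch sees the GPU and that Detectron2 imports correctly:
```python
# Optional environment check; assumes the installation steps above completed successfully.
import torch
import detectron2

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("detectron2:", detectron2.__version__)
```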
#### :point_right: Launch a demo for visual in-context prompting
```shell
python demo_openset.py --ckpt /path/to/swinL/ckpt
```
# Open-set segmentation
![generic_seg_vis](https://github.com/UX-Decoder/DINOv/assets/34880758/bfbe4d90-5be9-4fa5-a4e7-83f5c25f7f23)
@ -19,19 +35,13 @@ For visualization and demos, we recommend using [T-Rex demo link](https://deepda
## :unicorn: Getting Started
### :hammer_and_wrench: Installation
```shell
pip3 install torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113
python -m pip install 'git+https://github.com/MaureenZOU/detectron2-xyz.git'
pip install git+https://github.com/cocodataset/panopticapi.git
git clone https://github.com/UX-Decoder/DINOv
cd DINOv
python -m pip install -r requirements.txt
```
### :mosque: Data preparation
We jointly train on COCO and SA-1B data. Please refer to [prepare SA-1B data](https://github.com/UX-Decoder/Semantic-SAM/blob/main/DATASET.md) and [prepare COCO data](https://github.com/IDEA-Research/MaskDINO/blob/main/README.md).
For evaluation, you need to prepare:
* [ADE20K](https://github.com/IDEA-Research/MaskDINO/blob/main/datasets/README.md) for open-set segmentation evaluation.
* [DAVIS2017](https://davischallenge.org/davis2017/code.html) for referring segmentation (video object segmentation).
### :volcano: Model Zoo
The currently released checkpoints are trained with SA-1B and COCO data.
<table><tbody>
@ -65,6 +75,7 @@ We do detection evaluation on COCO val2017.
`$n` is the number of GPUs you use.
Process visual prompt embeddings for inference: we compute the instance prompt embeddings for all instances in the validation set (the training set also works, but processing takes much longer) and store them. We then run inference by randomly selecting some of these visual prompts as in-context examples.
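As a rough illustration of this two-stage procedure (the file name and dictionary layout below are assumptions for illustration, not the repository's actual storage format), the stored per-instance embeddings can be sampled and averaged into per-category in-context prototypes:
```python
# Sketch of the two-stage open-set inference described above (hypothetical storage format).
import random
import torch

store = torch.load("/path/to/outputs/prompt_embeddings.pt")  # {category_id: [embedding, ...]} from the first pass
num_examples = 16                                            # MODEL.DECODER.INFERENCE_EXAMPLE

prototypes = {}
for category_id, embeddings in store.items():
    sampled = random.sample(embeddings, min(num_examples, len(embeddings)))
    prototypes[category_id] = torch.stack(sampled).mean(dim=0)  # one in-context prototype per category
# the second pass then classifies query proposals against these prototypes
```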
#### Evaluate Open-set detection and segmentation
* Inference script to compute and store visual prompt embeddings
```shell
python train_net.py --eval_only --resume --eval_get_content_features --num-gpus 8 --config-file /path/to/configs COCO.TEST.BATCH_SIZE_TOTAL=8 MODEL.WEIGHTS=/path/to/weights OUTPUT_DIR=/path/to/outputs
@ -74,8 +85,21 @@ python train_net.py --eval_only --resume --eval_get_content_features --num-gpus
python train_net.py --eval_only --resume --eval_visual_openset --num-gpus 8 --config-file /path/to/configs COCO.TEST.BATCH_SIZE_TOTAL=8 MODEL.WEIGHTS=/path/to/weights MODEL.DECODER.INFERENCE_EXAMPLE=16 OUTPUT_DIR=/path/to/outputs
```
* **configs** to use are `configs/dinov_sam_coco_train.yaml` for SwinT and `configs/dinov_sam_coco_swinl_train.yaml` for SwinL.
* For ADE20K data, use `configs/dinov_sam_ade_eval.yaml` and adjust the batch size for ADE evaluation to the correct number.
* `OUTPUT_DIR` is the directory where the visual prompt embeddings are stored.
* `INFERENCE_EXAMPLE` is the number of in-context examples used to represent a category (default: 16).
#### Evaluate referring segmentation on VOS
We evaluate under the `DAVIS 2017 Semi-supervised` setting; please refer to [davis2017-evaluation](https://github.com/davisvideochallenge/davis2017-evaluation) for more details.
The first step is to compute and store the results on DAVIS2017. We implement a naive memory-aware approach with our in-context visual prompting.
```shell
python train_net.py --eval_track_prev --eval_only --resume --num-gpus 8 --config-file configs/dinov_sam_coco_train.yaml DAVIS.TEST.BATCH_SIZE_TOTAL=8 OUTPUT_DIR=$outdir MODEL.WEIGHTS=/path/to/weights MODEL.DECODER.NMS_THRESHOLD=0.9 MODEL.DECODER.MAX_MEMORY_SIZE=9 OUTPUT_DIR=/path/to/outputs
```
The second step is to evaluate the semi-supervised results.
```shell
python evaluation_method.py --task semi-supervised --results_path /path/to/results --davis_path /path/to/davis/data
```
* We use `MAX_MEMORY_SIZE=9` by default (1 current-frame token and 8 previous memory tokens); a minimal sketch of this bounded memory is shown below.
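Assuming a simple first-in-first-out buffer over per-frame prompt tokens (the actual tracking implementation may manage memory differently), the idea is roughly:
```python
# Naive memory-aware propagation: 1 current-frame token + up to 8 previous memory tokens.
from collections import deque
import torch

MAX_MEMORY_SIZE = 9
memory = deque(maxlen=MAX_MEMORY_SIZE - 1)  # 8 slots for previous frames

def prompt_tokens_for(current_token: torch.Tensor) -> torch.Tensor:
    """Return the visual prompt tokens for the current frame and update the memory."""
    tokens = [current_token] + list(memory)
    memory.append(current_token)  # the oldest token is dropped automatically once the buffer is full
    return torch.stack(tokens)
```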
### :star: Training
We currently release the training code for SA-1B and COCO. It can also support Objects365 and other datasets with minimal modifications.
`$n` is the number of GPUs you use.


@ -176,7 +176,7 @@ MODEL:
NMS_THRESHOLD: 0.7
REFER_IMAGE: False
point_per_side: 20
max_memory_size: 8
MAX_MEMORY_SIZE: 9
softmax: False
MANY2MANY: True
VIS: False


@ -166,7 +166,6 @@ MODEL:
NUM_CONTENT_TOKENS: 1
MAX_NUM_INSTANCE: 50
MAX_NUM_INSTANCE_CONTENT: 8
# CONTENT_INDEPENDENT: True
USE_SHAPE_SAMPLER: True
OPENSET_USE_SHAPE_SAMPLER: False
INFERENCE_EXAMPLE: 4
@ -176,7 +175,7 @@ MODEL:
NMS_THRESHOLD: 0.7
REFER_IMAGE: False
point_per_side: 20
max_memory_size: 8
MAX_MEMORY_SIZE: 9
softmax: False
MANY2MANY: True
VIS: False


@ -176,7 +176,7 @@ MODEL:
NMS_THRESHOLD: 0.7
REFER_IMAGE: False
point_per_side: 20
max_memory_size: 8
MAX_MEMORY_SIZE: 9
softmax: True
MANY2MANY: True
VIS: False
@ -379,7 +379,7 @@ ADE20K:
DAVIS:
INPUT:
MIN_SIZE_TEST: 800
MIN_SIZE_TEST: 640
MAX_SIZE_TEST: 1333
DATALOADER:
FILTER_EMPTY_ANNOTATIONS: False
@ -392,7 +392,7 @@ DAVIS:
VOS:
INPUT:
MIN_SIZE_TEST: 800
MIN_SIZE_TEST: 640
MAX_SIZE_TEST: 1333
DATALOADER:
FILTER_EMPTY_ANNOTATIONS: False

demo/__init__.py (new file, 1 line)

@ -0,0 +1 @@
from .openset_task import task_openset

Four binary image files added (not shown): 140 KiB, 35 KiB, 8.1 MiB, and 207 KiB.

demo/openset_task.py

@ -0,0 +1,142 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Feng Li (fliay@connect.ust.hk)
# --------------------------------------------------------
import torch
import numpy as np
from torchvision import transforms
from utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
import os
import cv2
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
def inverse_sigmoid(x, eps=1e-5):
x = x.clamp(min=0, max=1)
x1 = x.clamp(min=eps)
x2 = (1 - x).clamp(min=eps)
return torch.log(x1/x2)
def task_openset(model,generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image_tgt=None, text_size=640,hole_scale=100,island_scale=100):
in_context_examples = [generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8]
in_context_examples = [x for x in in_context_examples if x is not None]
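# keep only the reference examples the user actually provided (unused tabs are None)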
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
def prepare_image(image_ori):
width = image_ori.size[0]
height = image_ori.size[1]
image_ori = np.asarray(image_ori)
images = torch.from_numpy(image_ori.copy()).permute(2,0,1).cuda()
return images, height, width
transform1 = transforms.Compose(t)
image_ori_tgt = transform1(image_tgt)
images_tgt, height_tgt, width_tgt = prepare_image(image_ori_tgt)
data_tgt = {"image": images_tgt, "height": height_tgt, "width": width_tgt}
batched_inputs = []
batched_inputs_tgt = [data_tgt]
multi_scale_features2, mask_features2, _, _ = model.model.get_encoder_feature(batched_inputs_tgt)
input_query_label_content_all = []
point_coords = torch.ones(1, 4).cuda().float()
point_coords[:, :2] = 0.
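# point_coords is now [0., 0., 1., 1.] (normalized); inverse_sigmoid below turns it into the initial content-query box in logit space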
input_query_bbox_content_init = inverse_sigmoid(point_coords[None])
for image in in_context_examples:
image_ori = transform1(image['image'])
mask_ori = transform1(image['mask'])
images, height, width = prepare_image(image_ori)
data = {"image": images, "height": height, "width": width}
data['seg_image'] = data_tgt
mask_ori = np.asarray(mask_ori)[:,:,0:1].copy()
mask_ori = torch.from_numpy(mask_ori).permute(2,0,1)
data['targets'] = [dict()]
data['targets'][0]['rand_shape']=mask_ori
data['targets'][0]['pb']=torch.tensor([1.]) # FIXME 0 or 1
frame = data
rand_shape = mask_ori
frame['targets'][0]['rand_shape'] = rand_shape
batched_inputs.append(frame)
multi_scale_features, _, padded_h, padded_w = model.model.get_encoder_feature([frame])
input_query_label_content, input_query_bbox_content, attn_mask_content = model.model. \
get_visual_prompt_content_feature(multi_scale_features, frame['targets'][0]['rand_shape'], padded_h, padded_w)
input_query_label_content_all.append(input_query_label_content)
# average the content embeddings of all in-context examples into a single visual prompt for the target image
input_query_label_content_current = torch.stack(input_query_label_content_all).mean(0)
masks, ious, ori_masks, scores_per_image_openset = model.model.evaluate_demo_content_openset_multi_with_content_features(
batched_inputs_tgt, mask_features2, multi_scale_features2, input_query_label_content_current,
input_query_bbox_content_init, attn_mask_content, padded_h, padded_w)
if len(ious.shape)>1:
ious=ious[0]
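# rank predictions by their open-set classification score, highest first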
ids=torch.argsort(scores_per_image_openset,descending=True)
areas=[]
image_ori = image_ori_tgt
new_pred_mask = []
new_pred_class_score = []
for i in ids:
new_pred_class_score.append(scores_per_image_openset[i])
new_pred_mask.append(masks[i])
pred_masks_poses = new_pred_mask
ious = new_pred_class_score
visual = Visualizer(image_ori, metadata=metadata)
for i,(pred_masks_pos,iou, _, _) in enumerate(zip(pred_masks_poses,ious, pred_masks_poses, pred_masks_poses)):
iou=round(float(iou),2)
texts=f'{iou}'
mask=(pred_masks_pos>0.0).cpu().numpy()
area=mask.sum()
areas.append(area)
# uncomment for additional postprocessing
# mask,_=remove_small_regions(mask,int(hole_scale),mode="holes")
# mask,_=remove_small_regions(mask,int(island_scale),mode="islands")
mask = mask.astype(float)  # np.float was removed in recent NumPy; the builtin float keeps the float64 dtype
color = [0.502, 0.0, 0.502]  # overlay color for the predicted masks (purple)
demo = visual.draw_binary_mask(mask, text='', alpha=0.7, edge_color=color)
res = demo.get_image()
torch.cuda.empty_cache()
return res
def remove_small_regions(
mask: np.ndarray, area_thresh: float, mode: str
) -> Tuple[np.ndarray, bool]:
"""
Removes small disconnected regions and holes in a mask. Returns the
mask and an indicator of if the mask has been modified.
"""
import cv2 # type: ignore
assert mode in ["holes", "islands"]
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if len(small_regions) == 0:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
# If every region is below threshold, keep largest
if len(fill_labels) == 0:
fill_labels = [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True

demo_openset.py (new file, 140 lines)

@ -0,0 +1,140 @@
# --------------------------------------------------------
# Semantic-SAM: Segment and Recognize Anything at Any Granularity
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Hao Zhang (hzhangcx@connect.ust.hk)
# --------------------------------------------------------
# Copyright (c) 2024 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Feng Li (fliay@connect.ust.hk)
# --------------------------------------------------------
import gradio as gr
import torch
import argparse
from dinov.BaseModel import BaseModel
from dinov import build_model
from utils.arguments import load_opt_from_config_file
from demo import task_openset
def parse_option():
parser = argparse.ArgumentParser('DINOv Demo', add_help=False)
parser.add_argument('--conf_files', default="configs/dinov_sam_coco_swinl_train.yaml", metavar="FILE", help='path to config file', )
parser.add_argument('--ckpt', default="", metavar="FILE", help='path to ckpt', required=True)
parser.add_argument('--port', default=6099, type=int, help='port for the demo server', )
args = parser.parse_args()
return args
class ImageMask(gr.components.Image):
"""
Sets: source="canvas", tool="sketch"
"""
is_template = True
def __init__(self, **kwargs):
super().__init__(source="upload", tool="sketch", interactive=True, **kwargs)
def preprocess(self, x):
return super().preprocess(x)
'''
build args
'''
args = parse_option()
'''
build model
'''
sam_cfg=args.conf_files
opt = load_opt_from_config_file(sam_cfg)
model_sam = BaseModel(opt, build_model(opt)).from_pretrained(args.ckpt).eval().cuda()
@torch.no_grad()
def inference(generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image2,*args, **kwargs):
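# run inference under fp16 autocast to reduce GPU memory usage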
with torch.autocast(device_type='cuda', dtype=torch.float16):
model=model_sam
a= task_openset(model, generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image2, *args, **kwargs)
return a
'''
launch app
'''
title = "DINOv: Visual In-Context Prompting"
article = "This demo runs on DINOv."
demo = gr.Blocks()
image_tgt=gr.components.Image(label="Target Image ",type="pil",brush_radius=15.0)
gallery_output=gr.components.Image(label="Results Image ",type="pil",brush_radius=15.0)
generic_vp1 = ImageMask(label="scribble on refer Image 1",type="pil",brush_radius=15.0)
generic_vp2 = ImageMask(label="scribble on refer Image 2",type="pil",brush_radius=15.0)
generic_vp3 = ImageMask(label="scribble on refer Image 3",type="pil",brush_radius=15.0)
generic_vp4 = ImageMask(label="scribble on refer Image 4",type="pil",brush_radius=15.0)
generic_vp5 = ImageMask(label="scribble on refer Image 5",type="pil",brush_radius=15.0)
generic_vp6 = ImageMask(label="scribble on refer Image 6",type="pil",brush_radius=15.0)
generic_vp7 = ImageMask(label="scribble on refer Image 7",type="pil",brush_radius=15.0)
generic_vp8 = ImageMask(label="scribble on refer Image 8",type="pil",brush_radius=15.0)
generic = gr.TabbedInterface([
generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8
], ["1", "2", "3", "4", "5", "6", "7", "8"])
title='''
# DINOv: Visual In-Context Prompting
# \[[Read our arXiv Paper](https://arxiv.org/pdf/2311.13601.pdf)\] &nbsp; \[[Github page](https://github.com/UX-Decoder/DINOv)\]
'''
with demo:
with gr.Row():
with gr.Column(scale=3.0):
generation_title = gr.Markdown(title)
image_tgt.render()
generic.render()
with gr.Row(scale=2.0):
clearBtn = gr.ClearButton(
components=[image_tgt])
runBtn = gr.Button("Run")
with gr.Column(scale=5.0):
gallery_title = gr.Markdown("# Open-set results.")
with gr.Row(scale=9.0):
gallery_output.render()
example = gr.Examples(
examples=[
["demo/examples/bags.jpg"],
["demo/examples/img.png"],
["demo/examples/corgi2.jpg"],
["demo/examples/ref_cat.jpeg"],
],
inputs=image_tgt,
cache_examples=False,
)
# Leftover gr.Interface keyword arguments; they have no effect inside gr.Blocks:
# title=title, article=article, allow_flagging='never'
runBtn.click(inference, inputs=[generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image_tgt],
outputs = [gallery_output])
demo.queue().launch(share=True,server_port=args.port)


@ -191,7 +191,7 @@ class DINOv(nn.Module):
# freeze some parameters
to_freeze_dict = ['label_enc', 'pb_embedding']
if freeze_all:
for (name, param) in self.named_parameters():
param.requires_grad = False
print("!!!!!!!!freeze_all!!!!!!!!, except ", to_freeze_dict)
@ -1139,9 +1139,52 @@ class DINOv(nn.Module):
return src_boxes, mask_pred_results, src_ious, pred_ious
def filter_data_openset(self, src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset):
def keep_data(box, mask, iou, match_score, pred_score_openset, keep):
return box[keep], mask[keep], iou[keep], match_score[:, keep], pred_score_openset[:, keep]
# advanced filtering
# uncomment below for better filtering
# print('filter iou score')
# keep = src_ious > 0.5
# src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset = keep_data(src_boxes, mask_pred_results, src_ious, pred_ious,pred_score_openset,
# keep)
#
# stability_score = calculate_stability_score(
# mask_pred_results, 0.0, self.stability_score_offset
# )
# keep = stability_score >= self.stability_score_thresh
#
# src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset = keep_data(src_boxes, mask_pred_results, src_ious,pred_score_openset,
# pred_ious, keep)
# print('using nms')
# item_indice = nms(box_ops.box_cxcywh_to_xyxy(src_boxes), src_ious, self.nms_thersh) # FIXME iou threshold
# mask_pred_results = mask_pred_results[item_indice]
# src_boxes = src_boxes[item_indice]
# src_ious = src_ious[item_indice]
# pred_ious = torch.index_select(pred_ious, -1, item_indice)
# pred_score_openset = torch.index_select(pred_score_openset, -1, item_indice)
# print("remove small objects")
# keep = (mask_pred_results > 0).flatten(-2, -1).sum(-1) > 50
# # keep = (mask_pred_results > 0).flatten(-2, -1).sum(-1) > 100
# src_boxes, mask_pred_results, src_ious, pred_ious, pred_score_openset = keep_data(src_boxes, mask_pred_results, src_ious,
# pred_ious, pred_score_openset, keep)
scores_per_image_openset, label_openset = pred_score_openset.sigmoid().max(-1)
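# adaptive confidence threshold: start at 0.12 and relax it by 0.04 until at least one prediction survives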
thresh = 0.12
keep = scores_per_image_openset>thresh
while sum(keep)<1:
thresh = thresh-0.04
keep = scores_per_image_openset > thresh
scores_per_image_openset, label_openset = scores_per_image_openset[keep], label_openset[keep]
mask_pred_results = mask_pred_results[keep]
return src_boxes, mask_pred_results, src_ious, pred_ious, scores_per_image_openset
def get_encoder_feature(self, batched_inputs):
# get the image encoder features (multi-scale)
assert len(batched_inputs) == 1, "only support batch size equal to 1"
images = self.prepare_image(batched_inputs)
padded_h = images.tensor.shape[-2] # divisible by 32
padded_w = images.tensor.shape[-1]
@ -1238,6 +1281,70 @@ class DINOv(nn.Module):
)
return pred_masks, pred_ious, ori_masks
def evaluate_demo_content_openset_multi_with_content_features(self, batched_inputs, mask_features, multi_scale_features,
input_query_label_content,
input_query_bbox_content, attn_mask_content,
padded_h, padded_w,
level=[0,1,2,3,4,5], return_src_ious=False):
assert len(batched_inputs) == 1, "only support batch size equal to 1"
prediction_switch = {'part': False, 'whole': False, 'seg': True, 'det': True}
def prepare_image(batched_inputs, key='image'):
images = [x['image'].to(self.device) for x in batched_inputs]
images = [(x - self.pixel_mean) / self.pixel_std for x in images]
images = ImageList.from_tensors(images, self.size_divisibility)
return images
images = prepare_image(batched_inputs)
# image_size_xyxy = torch.tensor([padded_w, padded_h, padded_w, padded_h]).cuda()
# auto_points = torch.tensor(self.build_point_grid(self.point_per_side)).cuda() * image_size_xyxy[:2]
# boxes_dn = box_ops.box_xyxy_to_cxcywh(
# torch.cat([auto_points - 3, auto_points + 3], 1)) / image_size_xyxy
# boxes_dn = torch.tensor(boxes_dn, dtype=torch.float32)
# pb = torch.ones(boxes_dn.shape[0]) # FIXME: use 1 for pb
# targets_p = [{}]
# targets_p[0]['boxes_dn'] = boxes_dn
# targets_p[0]['pb'] = pb
# targets = targets_p
targets = None
outputs, mask_dict = self.sem_seg_head.predictor.forward_openset_image_with_extracted_content(multi_scale_features,
mask_features, None, input_query_label_content, input_query_bbox_content, attn_mask_content, targets,
extra=prediction_switch)
src_boxes = outputs["pred_boxes"][0]
mask_pred_results = outputs["pred_masks"][0]
pred_score_openset = outputs["pred_logits"][0]
# level = torch.tensor(level).cuda()
src_ious = pred_score_openset.flatten(0, 1)
pred_ious = pred_score_openset
src_boxes, mask_pred_results, src_ious, pred_ious, scores_per_image_openset = self.filter_data_openset(src_boxes, mask_pred_results, src_ious,
pred_ious, pred_score_openset)
# upsample masks
mask_pred_results = F.interpolate(
mask_pred_results[None],
size=(images.tensor.shape[-2], images.tensor.shape[-1]),
mode="bilinear",
align_corners=False,
)
pred_masks = mask_pred_results
image_size = images.image_sizes[0]
height = batched_inputs[0].get('height', image_size[0])
width = batched_inputs[0].get('width', image_size[1])
ori_masks = pred_masks[:, : image_size[0], : image_size[1]].expand(1, -1, -1, -1)[0]
# import ipdb; ipdb.set_trace()
if self.sem_seg_postprocess_before_inference:
pred_masks = retry_if_cuda_oom(sem_seg_postprocess)(
pred_masks, image_size, height, width
)
return pred_masks, pred_ious, ori_masks, scores_per_image_openset
def semantic_inference(self, mask_cls, mask_pred):
# if use cross-entropy loss in training, evaluate with softmax
if self.semantic_ce_loss:

repo.diff (new file, 21 lines)

@ -0,0 +1,21 @@
diff --git openseed/architectures/joint_oi_model.py openseed/architectures/joint_oi_model.py
index 8086690..a0679fe 100644
--- openseed/architectures/joint_oi_model.py
+++ openseed/architectures/joint_oi_model.py
@@ -286,6 +286,7 @@ class GeneralizedMaskDINO(nn.Module):
"coco_on": dec_cfg.get('COCO', True),
"coco_mask_on": dec_cfg.get('COCO_MASK', True),
"o365_on": dec_cfg.get('O365', True),
+ "regenerate_point": dec_cfg.get('RE_POINT', False),
}
@property
@@ -531,7 +532,7 @@ class GeneralizedMaskDINO(nn.Module):
# if not self.training:
# box_start = int(num_mask/4*3)
- box_start = random.randint(0, self.max_num_instance - 1) # box based interactive after this number; about 1/4
+ box_start = random.randint(1, self.max_num_instance - 1) # box based interactive after this number; about 1/4
point_coords = targets_per_image.point_coords[index[:box_start]]
# FIXME randomly sample one point as the user input
if self.regenerate_point:

train_net.py

@ -309,7 +309,7 @@ class Trainer(DefaultTrainer):
from PIL import Image
from queue import Queue, LifoQueue, PriorityQueue
cfg['DATASETS']['TEST'] = ['davis17_val']
maxsize = cfg['MODEL']['DECODER']['max_memory_size']
maxsize = cfg['MODEL']['DECODER']['MAX_MEMORY_SIZE']
dataloaders = cls.build_test_loader(cfg, dataset_name=None)
model = model.eval().cuda()
@ -612,10 +612,7 @@ def main(args=None):
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
cfg.MODEL.WEIGHTS, resume=args.resume
)
if args.original_load:
print("using original loading")
model = model.from_pretrained(cfg.MODEL.WEIGHTS)
elif args.eval_visual_openset:
if args.eval_visual_openset:
res = Trainer.test_visual_openset(cfg, model, args.eval_visual_openset_combine)
elif args.eval_track_prev:
res = Trainer.test_tracking_prev(cfg, model)
@ -630,13 +627,6 @@ def main(args=None):
init_wandb(cfg, cfg['OUTPUT_DIR'], entity=args.wandb_usr_name, job_name=cfg['OUTPUT_DIR'])
trainer = Trainer(cfg)
if len(args.lang_weight)>0:
import copy
weight = copy.deepcopy(trainer.cfg.MODEL.WEIGHTS)
trainer.cfg.MODEL.WEIGHTS = args.lang_weight
print("load original language language weight!!!!!!")
trainer.resume_or_load(resume=args.resume)
trainer.cfg.MODEL.WEIGHTS = weight
print("load pretrained model weight!!!!!!")
trainer.resume_or_load(resume=args.resume)
@ -647,19 +637,9 @@ def main(args=None):
if __name__ == "__main__":
parser = default_argument_parser()
parser.add_argument('--eval_only', action='store_true')
parser.add_argument('--eval_track', action='store_true')
parser.add_argument('--eval_visual_openset', action='store_true')
parser.add_argument('--eval_visual_openset_combine', action='store_true')
parser.add_argument('--eval_track_prev', action='store_true')
parser.add_argument('--eval_track_prev_davis16', action='store_true')
parser.add_argument('--eval_track_gt_frame', action='store_true')
parser.add_argument('--eval_track_prev_ytvos', action='store_true')
parser.add_argument('--eval_track_prev_ytvos_v2', action='store_true')
parser.add_argument('--eval_track_gt', action='store_true')
parser.add_argument('--eval_get_content_features', action='store_true')
parser.add_argument('--original_load', action='store_true')
parser.add_argument('--EVAL_FLAG', type=int, default=1)
parser.add_argument('--lang_weight', type=str, default='')
parser.add_argument('--WANDB', action='store_true')
parser.add_argument('--wandb_usr_name', type=str, default='')
parser.add_argument('--wandb_key', type=str, default='')