# Mirror of https://github.com/UX-Decoder/DINOv.git (80 lines, 3.0 KiB, Python)
import sys
|
|
import random
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from scipy import ndimage
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from kornia.contrib import distance_transform
|
|
|
|
from .scribble import Scribble
|
|
from dinov.utils import configurable
|
|
|
|
|
|
class SimpleClickSampler(nn.Module):
    """Sample the next interactive prompt (a scribble seeded at one click).

    The sampler picks, for each ground-truth mask, the pixel that lies deepest
    inside the region still missing from the prediction and turns it into a
    scribble via ``Scribble.draw_by_points``.

    NOTE(review): ``forward`` only delegates for ``mask_mode == 'Point'`` and
    calls ``self.forward_point``, which is not defined in this class as seen
    here -- confirm that a subclass or mixin provides it.
    """

    @configurable
    def __init__(self, mask_mode='point', sample_negtive=False, is_train=True, dilation=None, dilation_kernel=None):
        """Store the sampler configuration.

        Args:
            mask_mode: Name of the prompt type used by ``forward`` to dispatch.
            sample_negtive: Flag stored on the instance; not read by the
                methods of this class.
            is_train: Flag stored on the instance; not read by the methods of
                this class.
            dilation: Dilation size stored on the instance.
            dilation_kernel: Tensor (or ``None``) registered as the
                ``dilation_kernel`` buffer.
        """
        super().__init__()
        self.mask_mode = mask_mode
        self.sample_negtive = sample_negtive
        self.is_train = is_train
        self.dilation = dilation
        self.register_buffer("dilation_kernel", dilation_kernel)

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        """Translate a config mapping into ``__init__`` keyword arguments.

        Args:
            cfg: Mapping providing ``cfg['STROKE_SAMPLER']['EVAL']['NEGATIVE']``
                and ``cfg['STROKE_SAMPLER']['DILATION']``.
            is_train: Forwarded unchanged.
            mode: Forwarded as ``mask_mode``.

        Returns:
            dict with the keys ``mask_mode``, ``sample_negtive``, ``is_train``,
            ``dilation`` and ``dilation_kernel``.
        """
        stroke_cfg = cfg['STROKE_SAMPLER']
        dilation = stroke_cfg['DILATION']

        # Keep the kernel on the current CUDA device when there is one, but do
        # not fail on CPU-only hosts where querying the CUDA device raises.
        if torch.cuda.is_available():
            device = torch.cuda.current_device()
        else:
            device = torch.device('cpu')
        dilation_kernel = torch.ones((1, 1, dilation, dilation), device=device)

        return {
            "mask_mode": mode,
            "sample_negtive": stroke_cfg['EVAL']['NEGATIVE'],
            "is_train": is_train,
            "dilation": dilation,
            "dilation_kernel": dilation_kernel,
        }

    def forward_scribble(self, instances, pred_masks=None, prev_masks=None):
        """Draw one scribble per ground-truth mask, seeded at the best click.

        Args:
            instances: Object whose ``gt_masks`` attribute is a rank-3 boolean
                CPU tensor (it is unpacked as ``(_, h, w)`` and converted with
                ``.numpy()``).
            pred_masks: Optional predicted masks; cropped to ``[:, :h, :w]``
                because they may be padded. Defaults to all-False.
            prev_masks: Optional masks of previously sampled prompts.
                Defaults to all-False.

        Returns:
            dict with ``gt_masks`` (the input masks), ``rand_shape`` (list with
            one ``prev_masks | scribble`` entry per ground-truth mask),
            ``types`` (``'scribble'`` per mask) and ``sampler`` (``self``).

        NOTE(review): when ``pred_masks`` / ``prev_masks`` hold more than one
        mask they broadcast against the single ground-truth mask and only
        index 0 of the result is used -- confirm callers pass one mask.
        """
        gt_masks_batch = instances.gt_masks
        _, h, w = gt_masks_batch.shape

        # The defaults and the crop do not depend on the loop index, so they
        # are resolved once here instead of rebinding the arguments per pass.
        if pred_masks is None:
            pred = torch.zeros((1, h, w), dtype=torch.bool)
        else:
            pred = pred_masks[:, :h, :w]  # pred_masks is after padding
        if prev_masks is None:
            prev = torch.zeros((1, h, w), dtype=torch.bool)
        else:
            prev = prev_masks

        rand_shapes = []
        for i in range(len(gt_masks_batch)):
            gt_masks = gt_masks_batch[i:i + 1]
            # It only supports a single image, with a single candidate mask.
            assert len(gt_masks) == 1

            # We only consider positive points: ground-truth pixels that are
            # neither predicted nor covered by an earlier prompt.
            missed = gt_masks & (~(gt_masks & pred)) & (~prev)

            # The pixel with the largest distance-transform value is the one
            # farthest from the border of the missed region.
            mask_dt = torch.from_numpy(
                cv2.distanceTransform(missed[0].numpy().astype(np.uint8), cv2.DIST_L2, 0)[None, :]
            )
            peak = (mask_dt == mask_dt.max()).nonzero()[0:1]  # one row: (0, y, x)

            # Read the click straight from the index row as (x, y). This is
            # equivalent to marking it in a scratch mask and calling nonzero()
            # again, and it avoids indexing a tensor with a nested list, which
            # newer PyTorch releases deprecate.
            points = peak[:, 1:].flip(dims=[-1])

            next_mask = Scribble.draw_by_points(points, gt_masks, h, w)
            rand_shapes.append(prev | next_mask)

        types = ['scribble'] * len(gt_masks_batch)
        return {'gt_masks': instances.gt_masks, 'rand_shape': rand_shapes, 'types': types, 'sampler': self}

    def forward(self, instances, *args, **kwargs):
        """Dispatch on ``self.mask_mode``.

        ``'Point'`` delegates to ``self.forward_point``; ``'Circle'``,
        ``'Scribble'`` and ``'Polygon'`` are rejected. Any other value
        (including the lowercase constructor default ``'point'``) matches no
        branch and yields ``None``.

        Raises:
            AssertionError: for the ``'Circle'``, ``'Scribble'`` and
                ``'Polygon'`` modes.
        """
        mode = self.mask_mode
        if mode == 'Point':
            return self.forward_point(instances, *args, **kwargs)

        # Raised explicitly rather than via ``assert False`` so the rejection
        # is not stripped when Python runs with -O.
        if mode == 'Circle':
            raise AssertionError("Circle not support best path.")
        if mode == 'Scribble':
            raise AssertionError("Scribble not support best path.")
        if mode == 'Polygon':
            raise AssertionError("Polygon not support best path.")
        return None
|