DINOv/datasets/shapes/simpleclick_sampler.py

80 lines
3.0 KiB
Python

import sys
import random
import cv2
import numpy as np
from scipy import ndimage
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.contrib import distance_transform
from .scribble import Scribble
from dinov.utils import configurable
class SimpleClickSampler(nn.Module):
@configurable
def __init__(self, mask_mode='point', sample_negtive=False, is_train=True, dilation=None, dilation_kernel=None):
super().__init__()
self.mask_mode = mask_mode
self.sample_negtive = sample_negtive
self.is_train = is_train
self.dilation = dilation
self.register_buffer("dilation_kernel", dilation_kernel)
@classmethod
def from_config(cls, cfg, is_train=True, mode=None):
mask_mode = mode
sample_negtive = cfg['STROKE_SAMPLER']['EVAL']['NEGATIVE']
dilation = cfg['STROKE_SAMPLER']['DILATION']
dilation_kernel = torch.ones((1, 1, dilation, dilation), device=torch.cuda.current_device())
# Build augmentation
return {
"mask_mode": mask_mode,
"sample_negtive": sample_negtive,
"is_train": is_train,
"dilation": dilation,
"dilation_kernel": dilation_kernel,
}
def forward_scribble(self, instances, pred_masks=None, prev_masks=None):
gt_masks_batch = instances.gt_masks
_,h,w = gt_masks_batch.shape
rand_shapes = []
for i in range(len(gt_masks_batch)):
gt_masks = gt_masks_batch[i:i+1]
assert len(gt_masks) == 1 # it only supports a single image, with a single candidate mask.
# pred_masks is after padding
# We only consider positive points
pred_masks = torch.zeros(gt_masks.shape).bool() if pred_masks is None else pred_masks[:,:h,:w]
prev_masks = torch.zeros(gt_masks.shape).bool() if prev_masks is None else prev_masks
fp = gt_masks & (~(gt_masks & pred_masks)) & (~prev_masks)
next_mask = torch.zeros(gt_masks.shape).bool()
mask_dt = torch.from_numpy(cv2.distanceTransform(fp[0].numpy().astype(np.uint8), cv2.DIST_L2, 0)[None,:])
max_value = mask_dt.max()
next_mask[(mask_dt==max_value).nonzero()[0:1].t().tolist()] = True
points = next_mask[0].nonzero().flip(dims=[-1])
next_mask = Scribble.draw_by_points(points, gt_masks, h, w)
rand_shapes += [(prev_masks | next_mask)]
types = ['scribble' for i in range(len(gt_masks_batch))]
return {'gt_masks': instances.gt_masks, 'rand_shape': rand_shapes, 'types': types, 'sampler': self}
def forward(self, instances, *args, **kwargs):
if self.mask_mode == 'Point':
return self.forward_point(instances, *args, **kwargs)
elif self.mask_mode == 'Circle':
assert False, "Circle not support best path."
elif self.mask_mode == 'Scribble':
assert False, "Scribble not support best path."
elif self.mask_mode == 'Polygon':
assert False, "Polygon not support best path."