77 lines
2.9 KiB
Python
77 lines
2.9 KiB
Python
import sys
|
|
import random
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from .point import Point
|
|
from .polygon import Polygon
|
|
from .scribble import Scribble
|
|
from .circle import Circle
|
|
|
|
from modeling.utils import configurable
|
|
|
|
|
|
class ShapeSampler(nn.Module):
    """Draw random spatial-prompt shapes from ground-truth instance masks.

    For every selected instance, one shape drawer is picked at random from
    ``shape_candidate`` (weighted by ``shape_prob``). Its ``draw(mask, box)``
    method then produces a random shape for that instance.
    """

    @configurable
    def __init__(self, max_candidate=1, shape_prob=None, shape_candidate=None, is_train=True):
        """
        Args:
            max_candidate: maximum number of instances kept per call in training mode.
            shape_prob: sampling weights, one per entry of ``shape_candidate``.
            shape_candidate: drawer objects, each providing ``draw(mask, box)``.
            is_train: if True, instances are shuffled and truncated to
                ``max_candidate``; otherwise every instance is used.
        """
        super().__init__()
        self.max_candidate = max_candidate
        # Fresh lists per instance: a shared mutable default would leak state between samplers.
        self.shape_prob = [] if shape_prob is None else shape_prob
        self.shape_candidate = [] if shape_candidate is None else shape_candidate
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        """Build ``__init__`` keyword arguments from ``cfg['STROKE_SAMPLER']``.

        Args:
            cfg: config mapping whose ``STROKE_SAMPLER`` section holds
                ``MAX_CANDIDATE``, ``CANDIDATE_PROBS`` and ``CANDIDATE_NAMES``.
            is_train: training flag stored on the sampler and, unless
                ``mode == 'hack_train'``, passed to every shape class.
            mode: ``'hack_train'`` builds every shape class with ``True`` and keeps
                the configured probabilities. Otherwise, when ``is_train`` is False,
                it is the name of the single shape class that receives all probability.

        Raises:
            ValueError: in evaluation mode, if ``mode`` is not one of the configured
                candidate names.
        """
        sampler_cfg = cfg['STROKE_SAMPLER']
        max_candidate = sampler_cfg['MAX_CANDIDATE']
        candidate_probs = sampler_cfg['CANDIDATE_PROBS']
        candidate_names = sampler_cfg['CANDIDATE_NAMES']

        # Shape classes are resolved by name among this module's globals
        # (the module imports Point, Polygon, Scribble and Circle).
        module = sys.modules[__name__]
        if mode == 'hack_train':
            candidate_classes = [getattr(module, name)(cfg, True) for name in candidate_names]
        else:
            if not is_train:
                # Evaluation always picks the shape type named by `mode`.
                if mode not in candidate_names:
                    raise ValueError(
                        f"Evaluation mode {mode!r} is not one of the configured "
                        f"STROKE_SAMPLER.CANDIDATE_NAMES {list(candidate_names)}"
                    )
                candidate_probs = [0.0] * len(candidate_names)
                candidate_probs[candidate_names.index(mode)] = 1.0
            candidate_classes = [getattr(module, name)(cfg, is_train) for name in candidate_names]

        # Build augmentation
        return {
            "max_candidate": max_candidate,
            "shape_prob": candidate_probs,
            "shape_candidate": candidate_classes,
            "is_train": is_train,
        }

    def forward(self, instances):
        """Draw one random shape per selected ground-truth instance.

        Args:
            instances: object exposing ``gt_masks.tensor`` and ``gt_boxes.tensor``
                (presumably detectron2 ``Instances`` with ``BitMasks``/``Boxes`` - verify).
                It is not modified.

        Returns:
            dict with keys:
              * ``gt_masks``: masks of the selected instances, zeroed where the drawn
                shape came out empty;
              * ``rand_shape``: the drawn shapes stacked and cast to bool;
              * ``types``: ``repr`` of the drawer used per instance, or ``'none'``;
              * ``sampler``: this sampler.
            With no instances, a single all-False mask and shape of the input's
            spatial size are returned with type ``'none'``.
        """
        masks = instances.gt_masks.tensor
        boxes = instances.gt_boxes.tensor

        if len(masks) == 0:
            gt_masks = torch.zeros(masks.shape[-2:], dtype=torch.bool)
            rand_masks = torch.zeros(masks.shape[-2:], dtype=torch.bool)
            # 'sampler' is included for parity with the non-empty return below.
            return {'gt_masks': gt_masks[None, :], 'rand_shape': torch.stack([rand_masks]), 'types': ['none'], 'sampler': self}

        if self.is_train:
            indices = list(range(len(masks)))
            random.shuffle(indices)
            keep = indices[:self.max_candidate]
            # Indexing with a list returns copies, not views of `masks`.
            candidate_mask = masks[keep]
            candidate_box = boxes[keep]
        else:
            candidate_mask = masks
            candidate_box = boxes

        draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
        rand_shapes = [drawer.draw(mask, box) for drawer, mask, box in zip(draw_funcs, candidate_mask, candidate_box)]
        types = [repr(drawer) for drawer in draw_funcs]

        empty = [i for i, shape in enumerate(rand_shapes) if shape.sum() == 0]
        if empty:
            if candidate_mask is masks:
                # Eval mode aliases the caller's tensor; copy before zeroing so
                # `instances` is not modified as a side effect.
                candidate_mask = candidate_mask.clone()
            for i in empty:
                candidate_mask[i] = 0
                types[i] = 'none'

        # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
        return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
|
|
|
|
def build_shape_sampler(cfg, **kwargs):
    """Create a ``ShapeSampler`` from ``cfg``.

    Any extra keyword arguments are forwarded unchanged to the ``ShapeSampler``
    constructor.
    """
    sampler = ShapeSampler(cfg, **kwargs)
    return sampler