import torch
from torch.nn import functional as F


def point_sample(input, point_coords, **kwargs):
    """
    A wrapper around :func:`torch.nn.functional.grid_sample` to support 3D point_coords tensors.
    Unlike :func:`torch.nn.functional.grid_sample` it assumes `point_coords` to lie inside
    the [0, 1] x [0, 1] square.

    Args:
        input (Tensor): A tensor of shape (N, C, H, W) that contains a feature map on a
            H x W grid.
        point_coords (Tensor): A tensor of shape (N, P, 2) or (N, Hgrid, Wgrid, 2) that
            contains [0, 1] x [0, 1] normalized point coordinates.
        **kwargs: Forwarded unchanged to :func:`torch.nn.functional.grid_sample`
            (e.g. ``align_corners``, ``mode``, ``padding_mode``).

    Returns:
        output (Tensor): A tensor of shape (N, C, P) or (N, C, Hgrid, Wgrid) that contains
            features for points in `point_coords`. The features are obtained via
            interpolation from `input` the same way as
            :func:`torch.nn.functional.grid_sample` (bilinear unless ``mode`` is overridden).
    """
    add_dim = False
    if point_coords.dim() == 3:
        # grid_sample expects a 4D grid; add a dummy Wgrid axis of size 1.
        add_dim = True
        point_coords = point_coords.unsqueeze(2)
    # Fix type mismatch between the coordinates and the feature map.
    point_coords = point_coords.type_as(input)
    # Map [0, 1] coordinates to the [-1, 1] range used by grid_sample.
    output = F.grid_sample(input, 2.0 * point_coords - 1.0, **kwargs)
    if add_dim:
        # Drop the dummy axis again so the result is (N, C, P).
        output = output.squeeze(3)
    return output


def get_uncertain_point_coords_with_randomness(coarse_logits, uncertainty_func,
                                               num_points, oversample_ratio,
                                               importance_sample_ratio):
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty.
    The uncertainties are calculated for each point using the 'uncertainty_func' function
    that takes the point's logit prediction as input. See the PointRend paper for details.

    Args:
        coarse_logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask)
            for class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a
            Tensor of shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling parameter; must be >= 1.
        importance_sample_ratio (float): Ratio of points that are sampled via importance
            sampling; must lie in [0, 1].

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of
            P sampled points.

    Raises:
        AssertionError: If `oversample_ratio` < 1 or `importance_sample_ratio` is outside
            [0, 1].
    """
    assert oversample_ratio >= 1, (
        f"oversample_ratio must be >= 1, got {oversample_ratio}")
    assert 0 <= importance_sample_ratio <= 1, (
        f"importance_sample_ratio must be in [0, 1], got {importance_sample_ratio}")

    num_boxes = coarse_logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    # Over-generate uniformly random candidate points.
    point_coords = torch.rand(
        num_boxes, num_sampled, 2, device=coarse_logits.device)
    point_logits = point_sample(
        coarse_logits, point_coords, align_corners=False)
    # It is crucial to calculate uncertainty based on the sampled prediction value for the points.
    # Calculating uncertainties of the coarse predictions first and sampling them for points leads
    # to incorrect results.
    # To illustrate this: assume uncertainty_func(logits)=-abs(logits), a sampled point between
    # two coarse predictions with -1 and 1 logits has 0 logits, and therefore 0 uncertainty value.
    # However, if we calculate uncertainties for the coarse predictions first,
    # both will have -1 uncertainty, and the sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points

    # Indices of the most uncertain candidates for each box: (N, num_uncertain_points).
    idx = torch.topk(
        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    # Pick the (x, y) pair of every selected candidate, box by box.
    point_coords = torch.gather(
        point_coords, 1, idx.unsqueeze(-1).expand(-1, -1, 2))

    if num_random_points > 0:
        # Fill the remaining slots with fresh uniformly random points.
        point_coords = torch.cat(
            [
                point_coords,
                torch.rand(
                    num_boxes,
                    num_random_points,
                    2,
                    device=coarse_logits.device),
            ],
            dim=1,
        )
    return point_coords