mirror of
https://github.com/alibaba/EasyCV.git
synced 2025-06-03 14:49:00 +08:00
363 lines
14 KiB
Python
363 lines
14 KiB
Python
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
from typing import List, Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
import torch.nn.functional as F
|
||
|
import torchvision
|
||
|
from mmcv.runner import get_dist_info
|
||
|
from torch import Tensor, nn
|
||
|
|
||
|
from .point_rend import (get_uncertain_point_coords_with_randomness,
|
||
|
point_sample)
|
||
|
|
||
|
|
||
|
def _max_by_axis(the_list):
|
||
|
# type: (List[List[int]]) -> List[int]
|
||
|
maxes = the_list[0]
|
||
|
for sublist in the_list[1:]:
|
||
|
for index, item in enumerate(sublist):
|
||
|
maxes[index] = max(maxes[index], item)
|
||
|
return maxes
|
||
|
|
||
|
|
||
|
class NestedTensor:
    """Container pairing a batched tensor with an optional mask.

    Attributes are stored exactly as given:
        tensors: the batched data tensor.
        mask: a tensor accompanying ``tensors``, or ``None``.
    """

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors, self.mask = tensors, mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new ``NestedTensor`` with both members passed through ``.to(device)``.

        A missing mask stays ``None``.
        """
        moved_tensors = self.tensors.to(device)
        # Bind the attribute to a local so the None check narrows its type.
        pad_mask = self.mask
        if pad_mask is None:
            return NestedTensor(moved_tensors, None)
        return NestedTensor(moved_tensors, pad_mask.to(device))

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    def __repr__(self):
        """Show only the data tensor; the mask is not included."""
        return str(self.tensors)
|
||
|
|
||
|
|
||
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Batch a list of 3-D tensors into one zero-padded ``NestedTensor``.

    Every tensor is copied into the leading corner of a slot sized to the
    per-axis maximum over the list. The boolean mask is ``False`` where data
    was copied and ``True`` over the padding.

    Args:
        tensor_list: Tensors to batch. Only ``ndim == 3`` (checked on the
            first element) is supported; dtype and device of the result are
            taken from the first element.

    Returns:
        NestedTensor: the padded batch and its mask.

    Raises:
        ValueError: If the first tensor is not 3-dimensional.
    """
    # TODO make this more general
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError('not supported')

    if torchvision._is_tracing():
        # The in-place slice writes below do not export well to ONNX, so
        # tracing is routed through the pad-based implementation instead.
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # TODO make it support different-sized images
    padded_shape = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_size = len(tensor_list)
    _, height, width = padded_shape

    tensor = torch.zeros(
        [batch_size] + padded_shape, dtype=first.dtype, device=first.device)
    # Start fully masked; the region covered by each input is cleared below.
    mask = torch.ones(
        (batch_size, height, width), dtype=torch.bool, device=first.device)

    for slot, img in enumerate(tensor_list):
        shape = img.shape
        tensor[slot, :shape[0], :shape[1], :shape[2]].copy_(img)
        mask[slot, :shape[1], :shape[2]] = False

    return NestedTensor(tensor, mask)
|
||
|
|
||
|
|
||
|
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||
|
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||
|
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(
        tensor_list: List[Tensor]) -> NestedTensor:
    """Pad-based variant of ``nested_tensor_from_tensor_list`` for ONNX tracing.

    Instead of writing into slices of a pre-allocated batch, every input is
    padded up to the per-axis maximum with ``torch.nn.functional.pad`` and the
    results are stacked.

    Args:
        tensor_list: Tensors to batch. The padding below indexes three axes,
            so 3-D inputs are expected.

    Returns:
        NestedTensor: stacked padded tensors plus a boolean mask that is
        ``True`` over the padded region and ``False`` elsewhere.
    """
    # Per-axis maximum size over all inputs.
    # NOTE(review): torch.stack needs tensors, so this presumably relies on
    # shape entries being tensors while tracing - confirm before calling it
    # outside of tracing.
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list
                         ]).to(torch.float32)).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        # Amount each axis has to grow to reach the batch-wide maximum.
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (before, after) pairs starting from the last axis; only
        # the trailing side of each axis is padded.
        padded_img = torch.nn.functional.pad(
            img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        # Zeros over the original extent, padded with ones, then cast to
        # bool: the mask ends up True exactly on the padding.
        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m,
                                              (0, padding[2], 0, padding[1]),
                                              'constant', 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
|
||
|
|
||
|
|
||
|
def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Compute the DICE loss, similar to generalized IOU for masks.

    Args:
        inputs: A float tensor of logits with at least two dimensions. A
                sigmoid is applied and everything after the first dimension
                is flattened.
        targets: A float tensor matching the flattened shape of inputs.
                 Stores the binary classification label for each element
                 (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer the summed per-row loss is divided by.
    Returns:
        Scalar loss tensor.
    """
    probs = inputs.sigmoid().flatten(1)
    overlap = (probs * targets).sum(-1)
    total = probs.sum(-1) + targets.sum(-1)
    # The +1 on both sides smooths the ratio and avoids a 0/0 when both the
    # prediction and the target are empty.
    per_mask = 1 - (2 * overlap + 1) / (total + 1)
    return per_mask.sum() / num_masks
|
||
|
|
||
|
|
||
|
# TorchScript-compiled version of dice_loss; compiled once at import time.
dice_loss_jit = torch.jit.script(dice_loss)  # type: torch.jit.ScriptModule
|
||
|
|
||
|
|
||
|
def sigmoid_ce_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    num_masks: float,
):
    """
    Binary cross-entropy on logits, averaged per row and normalized.

    Args:
        inputs: A float tensor of logits with at least two dimensions.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        num_masks: Normalizer the summed per-row loss is divided by.
    Returns:
        Loss tensor
    """
    elementwise = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction='none')
    # Average over dimension 1 first, then sum what is left.
    per_mask = elementwise.mean(1)
    return per_mask.sum() / num_masks
|
||
|
|
||
|
|
||
|
# TorchScript-compiled version of sigmoid_ce_loss; compiled once at import
# time.
sigmoid_ce_loss_jit = torch.jit.script(
    sigmoid_ce_loss)  # type: torch.jit.ScriptModule
|
||
|
|
||
|
|
||
|
def calculate_uncertainty(logits):
    """
    Estimate uncertainty as the negated L1 distance between 0.0 and the
    logit prediction in `logits`.

    Logits close to zero are the least confident, so they receive the
    highest (closest to zero) score.

    Args:
        logits (Tensor): A tensor of shape (R, 1, ...), i.e. dimension 1 must
            have size one. The values are logits.
    Returns:
        scores (Tensor): A new tensor of the same shape that contains
            uncertainty scores with the most uncertain locations having the
            highest uncertainty score. The input is not modified.
    """
    assert logits.shape[1] == 1
    # torch.abs already allocates a fresh tensor, so no defensive clone of the
    # input is needed.
    return -torch.abs(logits)
|
||
|
|
||
|
|
||
|
# Modified from https://github.com/facebookresearch/detr/blob/master/models/detr.py
|
||
|
class SetCriterion(nn.Module):
    """This class computes the loss for Mask2former.

    The process happens in two steps:
        1) the matcher computes an assignment between ground-truth targets
           and the outputs of the model
        2) each matched ground-truth / prediction pair is supervised with the
           configured losses (class labels and masks)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses,
                 num_points, oversample_ratio, importance_sample_ratio):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
                It is only stored here; the weights are not applied by this class.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
            num_points: forwarded to get_uncertain_point_coords_with_randomness in loss_masks
            oversample_ratio: forwarded to get_uncertain_point_coords_with_randomness in loss_masks
            importance_sample_ratio: forwarded to get_uncertain_point_coords_with_randomness in loss_masks
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

        # Per-class weights for the cross entropy: 1 for every real class and
        # eos_coef for the trailing no-object class.
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

        # pointwise mask loss parameters
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio

    def loss_labels(self, outputs, targets, indices, num_masks):
        """Classification loss (class-weighted cross entropy).

        targets dicts must contain the key "labels" containing a tensor of
        class indices, one per target. ``num_masks`` is unused; it is accepted
        so that every loss shares the same signature.

        Returns a dict with the single key ``'loss_ce'``.
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits'].float()

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat(
            [t['labels'][J] for t, (_, J) in zip(targets, indices)])
        # Every prediction defaults to the no-object class (index
        # num_classes); matched predictions then get their target's label.
        target_classes = torch.full(
            src_logits.shape[:2],
            self.num_classes,
            dtype=torch.int64,
            device=src_logits.device)
        target_classes[idx] = target_classes_o
        # cross_entropy expects the class dimension at position 1.
        loss_ce = F.cross_entropy(
            src_logits.transpose(1, 2), target_classes, self.empty_weight)
        return {'loss_ce': loss_ce}

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the sigmoid cross-entropy loss and the dice loss.

        Both losses are evaluated on sampled points rather than on full masks.
        targets dicts must contain the key "masks" containing a tensor of dim [nb_targets, h, w]

        Returns a dict with the keys ``'loss_mask'`` and ``'loss_dice'``.
        """
        assert 'pred_masks' in outputs

        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs['pred_masks']
        src_masks = src_masks[src_idx]
        masks = [t['masks'] for t in targets]
        # TODO use the padding mask to ignore invalid areas due to padding in loss
        target_masks, _ = nested_tensor_from_tensor_list(masks).decompose()
        # Match dtype and device of the predictions.
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]

        # No need to upsample predictions as we are using normalized coordinates :)
        # N x 1 x H x W
        src_masks = src_masks[:, None]
        target_masks = target_masks[:, None]

        # Point selection and ground-truth sampling need no gradients.
        with torch.no_grad():
            # sample point_coords
            point_coords = get_uncertain_point_coords_with_randomness(
                src_masks,
                calculate_uncertainty,
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
            # get gt labels
            point_labels = point_sample(
                target_masks,
                point_coords,
                align_corners=False,
            ).squeeze(1)

        # Sampled outside no_grad so gradients reach the predicted masks.
        point_logits = point_sample(
            src_masks,
            point_coords,
            align_corners=False,
        ).squeeze(1)

        return {
            'loss_mask': sigmoid_ce_loss_jit(point_logits, point_labels,
                                             num_masks),
            'loss_dice': dice_loss_jit(point_logits, point_labels, num_masks),
        }

    def _get_src_permutation_idx(self, indices):
        """Flatten per-image (src, tgt) index pairs into (batch_idx, src_idx) for indexing predictions."""
        # permute predictions following indices
        batch_idx = torch.cat(
            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        """Flatten per-image (src, tgt) index pairs into (batch_idx, tgt_idx) for indexing targets."""
        # permute targets following indices
        batch_idx = torch.cat(
            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks):
        """Dispatch to the loss named ``loss`` (``'labels'`` or ``'masks'``) and return its dict."""
        loss_map = {
            'labels': self.loss_labels,
            'masks': self.loss_masks,
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_masks)

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format.
                      An optional 'aux_outputs' entry holds one such dict per intermediate layer.
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc

        Returns:
             dict mapping loss names to loss values; losses computed on the
             i-th auxiliary output carry the suffix ``_{i}``.
        """
        outputs_without_aux = {
            k: v
            for k, v in outputs.items() if k != 'aux_outputs'
        }

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target masks across all nodes, for normalization purposes
        num_masks = sum(len(t['labels']) for t in targets)
        num_masks = torch.as_tensor([num_masks],
                                    dtype=torch.float,
                                    device=next(iter(outputs.values())).device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_masks)
        _, world_size = get_dist_info()
        # Clamp to at least 1 so batches without targets do not divide by zero.
        num_masks = torch.clamp(num_masks / world_size, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(
                self.get_loss(loss, outputs, targets, indices, num_masks))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices,
                                           num_masks)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses

    def __repr__(self):
        """Multi-line summary of the configuration.

        The matcher's ``__repr__`` must accept a ``_repr_indent`` keyword.
        """
        head = 'Criterion ' + self.__class__.__name__
        body = [
            f'matcher: {self.matcher.__repr__(_repr_indent=8)}',
            f'losses: {self.losses}',
            f'weight_dict: {self.weight_dict}',
            f'num_classes: {self.num_classes}',
            f'eos_coef: {self.eos_coef}',
            f'num_points: {self.num_points}',
            f'oversample_ratio: {self.oversample_ratio}',
            f'importance_sample_ratio: {self.importance_sample_ratio}',
        ]
        _repr_indent = 4
        lines = [head] + [' ' * _repr_indent + line for line in body]
        return '\n'.join(lines)
|