# Copyright (c) Alibaba, Inc. and its affiliates.
import mmcv
import numpy as np
import torch
import torch.nn.functional as F

from easycv.models import builder
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.models.segmentation.utils.criterion import SetCriterion
from easycv.models.segmentation.utils.matcher import MaskHungarianMatcher
from easycv.models.segmentation.utils.panoptic_gt_processing import (
    multi_apply, preprocess_panoptic_gt)
from easycv.utils.checkpoint import load_checkpoint
from easycv.utils.logger import get_root_logger, print_log

INSTANCE_OFFSET = 1000


@MODELS.register_module()
class Mask2Former(BaseModel):

    def __init__(
        self,
        backbone,
        head,
        train_cfg,
        test_cfg,
        pretrained=None,
    ):
        """Mask2Former model.

        Args:
            backbone (dict): config to build the backbone
            head (dict): config to build the mask2former head
            train_cfg (dict): config of the training strategy.
            test_cfg (dict): config of the test strategy.
            pretrained (str, optional): path of model weights. Defaults to None.
        """
        super(Mask2Former, self).__init__()
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.instance_on = test_cfg.get('instance_on', False)
        self.panoptic_on = test_cfg.get('panoptic_on', False)
        self.pretrained = pretrained
        self.backbone = builder.build_backbone(backbone)
        self.head = builder.build_head(head)
        # build the criterion
        self.num_classes = head.num_things_classes + head.num_stuff_classes
        self.num_things_classes = head.num_things_classes
        self.num_stuff_classes = head.num_stuff_classes

        matcher = MaskHungarianMatcher(
            cost_class=train_cfg.class_weight,
            cost_mask=train_cfg.mask_weight,
            cost_dice=train_cfg.dice_weight,
            num_points=train_cfg.num_points,
        )
        weight_dict = {
            'loss_ce': train_cfg.class_weight,
            'loss_mask': train_cfg.mask_weight,
            'loss_dice': train_cfg.dice_weight
        }

        if train_cfg.deep_supervision:
            dec_layers = train_cfg.dec_layers
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update(
                    {k + f'_{i}': v
                     for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
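
        # Illustration (hypothetical config with dec_layers=3): the loop
        # above gives each of the (dec_layers - 1) intermediate decoder
        # layers its own weighted copy of every loss term, so weight_dict
        # ends up with keys:
        #   loss_ce,   loss_mask,   loss_dice       (final layer)
        #   loss_ce_0, loss_mask_0, loss_dice_0     (aux layer 0)
        #   loss_ce_1, loss_mask_1, loss_dice_1     (aux layer 1)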

        losses = ['labels', 'masks']
        self.criterion = SetCriterion(
            self.head.num_classes,
            matcher=matcher,
            weight_dict=weight_dict,
            eos_coef=train_cfg.no_object_weight,
            losses=losses,
            num_points=train_cfg.num_points,
            oversample_ratio=train_cfg.oversample_ratio,
            importance_sample_ratio=train_cfg.importance_sample_ratio,
        )

        self.init_weights()

    def init_weights(self):
        logger = get_root_logger()
        if isinstance(self.pretrained, str):
            load_checkpoint(
                self.backbone, self.pretrained, strict=False, logger=logger)
        elif self.pretrained:
            if self.backbone.__class__.__name__ == 'PytorchImageModelWrapper':
                self.backbone.init_weights(pretrained=self.pretrained)
            elif hasattr(self.backbone, 'default_pretrained_model_path'
                         ) and self.backbone.default_pretrained_model_path:
                print_log(
                    'load model from default path: {}'.format(
                        self.backbone.default_pretrained_model_path), logger)
                load_checkpoint(
                    self.backbone,
                    self.backbone.default_pretrained_model_path,
                    strict=False,
                    logger=logger)
            else:
                print_log('load model from init weights')
                self.backbone.init_weights()
        else:
            print_log('load model from init weights')
            self.backbone.init_weights()

    def forward_train(self, img, gt_labels, gt_masks, gt_semantic_seg,
                      img_metas):
        features = self.backbone(img)
        outputs = self.head(features)
        targets = self.preprocess_gt(gt_labels, gt_masks, gt_semantic_seg,
                                     img_metas)
        losses = self.criterion(outputs, targets)
        for k in list(losses.keys()):
            if k in self.criterion.weight_dict:
                losses[k] *= self.criterion.weight_dict[k]
            else:
                # remove this loss if not specified in `weight_dict`
                losses.pop(k)
        return losses

    def forward_test(self, img, img_metas, rescale=True, encode=True):
        features = self.backbone(img[0])
        outputs = self.head(features)
        mask_cls_results = outputs['pred_logits']
        mask_pred_results = outputs['pred_masks']
        detection_boxes = []
        detection_scores = []
        detection_classes = []
        detection_masks = []
        pan_masks = []
        for mask_cls_result, mask_pred_result, meta in zip(
                mask_cls_results, mask_pred_results, img_metas[0]):
            # upsample masks to the padded input resolution
            pad_height, pad_width = meta['pad_shape'][:2]
            mask_pred_result = F.interpolate(
                mask_pred_result[:, None],
                size=(pad_height, pad_width),
                mode='bilinear',
                align_corners=False)[:, 0]
            # remove padding
            img_height, img_width = meta['img_shape'][:2]
            mask_pred_result = mask_pred_result[:, :img_height, :img_width]
            # resize back to the original image shape
            ori_height, ori_width = meta['ori_shape'][:2]
            mask_pred_result = F.interpolate(
                mask_pred_result[:, None],
                size=(ori_height, ori_width),
                mode='bilinear',
                align_corners=False)[:, 0]

            # instance segmentation results
            if self.instance_on:
                from easycv.utils.mmlab_utils import encode_mask_results
                labels_per_image, bboxes, mask_pred_binary = self.instance_postprocess(
                    mask_cls_result, mask_pred_result)
                segms = []
                if mask_pred_binary is not None and labels_per_image.shape[
                        0] > 0:
                    mask_pred_binary = [mask_pred_binary]
                    if encode:
                        mask_pred_binary = encode_mask_results(
                            mask_pred_binary)
                    segms = mmcv.concat_list(mask_pred_binary)
                    segms = np.stack(segms, axis=0)
                scores = bboxes[:, 4] if bboxes.shape[1] == 5 else None
                bboxes = bboxes[:, 0:4] if bboxes.shape[1] == 5 else bboxes
                detection_boxes.append(bboxes)
                detection_scores.append(scores)
                detection_classes.append(labels_per_image)
                detection_masks.append(segms)
            # panoptic segmentation results
            if self.panoptic_on:
                pan_results = self.panoptic_postprocess(
                    mask_cls_result, mask_pred_result)
                pan_masks.append(pan_results.cpu().numpy())
        assert len(img_metas) == 1
        outputs = {
            'detection_boxes': detection_boxes,
            'detection_scores': detection_scores,
            'detection_classes': detection_classes,
            'detection_masks': detection_masks,
            'img_metas': img_metas[0]
        }
        outputs['pan_results'] = pan_masks
        return outputs
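
    # Shape sketch of the test-time output dict above (illustrative; shapes
    # are hypothetical, one list entry per image):
    #
    #     {
    #         'detection_boxes': [ndarray (n, 4)],
    #         'detection_scores': [ndarray (n, )],
    #         'detection_classes': [ndarray (n, )],
    #         'detection_masks': [ndarray (n, h, w)],  # RLE-encoded when
    #                                                  # encode=True
    #         'img_metas': [...],
    #         'pan_results': [ndarray (h, w)],  # empty unless panoptic_on
    #     }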

    def forward(self,
                img,
                mode='train',
                gt_labels=None,
                gt_masks=None,
                gt_semantic_seg=None,
                img_metas=None,
                **kwargs):
        if mode == 'train':
            return self.forward_train(img, gt_labels, gt_masks,
                                      gt_semantic_seg, img_metas)
        elif mode == 'test':
            return self.forward_test(img, img_metas)
        else:
            raise Exception('No such mode: {}'.format(mode))

    def instance_postprocess(self, mask_cls, mask_pred):
        """Instance segmentation postprocess.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include the
                background class.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            tuple[Tensor]: Instance segmentation results.

            - labels_per_image (Tensor): Predicted labels,\
                shape (n, ).
            - bboxes (Tensor): Bboxes and scores with shape (n, 5) of \
                positive region in binary mask, the last column is scores.
            - mask_pred_binary (Tensor): Instance masks of \
                shape (n, h, w).
        """
        from easycv.utils.mmlab_utils import mask2bbox
        max_per_image = self.test_cfg.get('max_per_image', 100)
        num_queries = mask_cls.shape[0]
        # shape (num_queries, num_class)
        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
        # shape (num_queries * num_class, )
        labels = torch.arange(self.num_classes, device=mask_cls.device).\
            unsqueeze(0).repeat(num_queries, 1).flatten(0, 1)
        scores_per_image, top_indices = scores.flatten(0, 1).topk(
            max_per_image, sorted=False)
        labels_per_image = labels[top_indices]
        query_indices = top_indices // self.num_classes
        mask_pred = mask_pred[query_indices]

        # extract things
        is_thing = labels_per_image < self.num_things_classes
        scores_per_image = scores_per_image[is_thing]
        labels_per_image = labels_per_image[is_thing]
        mask_pred = mask_pred[is_thing]

        mask_pred_binary = (mask_pred > 0).float()
        mask_scores_per_image = (mask_pred.sigmoid() *
                                 mask_pred_binary).flatten(1).sum(1) / (
                                     mask_pred_binary.flatten(1).sum(1) + 1e-6)
        det_scores = scores_per_image * mask_scores_per_image
        mask_pred_binary = mask_pred_binary.bool()
        bboxes = mask2bbox(mask_pred_binary)
        bboxes = torch.cat([bboxes, det_scores[:, None]], dim=-1)

        labels_per_image = labels_per_image.detach().cpu().numpy()
        bboxes = bboxes.detach().cpu().numpy()
        mask_pred_binary = mask_pred_binary.detach().cpu().numpy()
        return labels_per_image, bboxes, mask_pred_binary
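
    # Worked example (hypothetical values, not part of the model) of the
    # flattened top-k used above, with num_queries=2 and num_classes=3:
    #
    #     scores.flatten(0, 1)        # shape (6, ):
    #                                 # [q0c0, q0c1, q0c2, q1c0, q1c1, q1c2]
    #     top_indices = [4, 2]        # suppose top-2 picks q1c1 and q0c2
    #     labels[top_indices]         # -> [1, 2]  (index % num_classes)
    #     top_indices // num_classes  # -> [1, 0]  (owning query index)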

    def panoptic_postprocess(self, mask_cls, mask_pred):
        """Panoptic segmentation inference.

        Args:
            mask_cls (Tensor): Classification outputs of shape
                (num_queries, cls_out_channels) for an image.
                Note `cls_out_channels` should include the
                background class.
            mask_pred (Tensor): Mask outputs of shape
                (num_queries, h, w) for an image.

        Returns:
            Tensor: Panoptic segmentation result of shape \
                (h, w), each element in the Tensor means: \
                ``segment_id = _cls + instance_id * INSTANCE_OFFSET``.
        """
        object_mask_thr = self.test_cfg.get('object_mask_thr', 0.8)
        iou_thr = self.test_cfg.get('iou_thr', 0.8)
        filter_low_score = self.test_cfg.get('filter_low_score', False)

        scores, labels = F.softmax(mask_cls, dim=-1).max(-1)
        mask_pred = mask_pred.sigmoid()

        keep = labels.ne(self.num_classes) & (scores > object_mask_thr)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]
        cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks

        h, w = cur_masks.shape[-2:]
        panoptic_seg = torch.full((h, w),
                                  self.num_classes,
                                  dtype=torch.int32,
                                  device=cur_masks.device)
        if cur_masks.shape[0] == 0:
            # we didn't detect any mask
            pass
        else:
            cur_mask_ids = cur_prob_masks.argmax(0)
            instance_id = 1
            for k in range(cur_classes.shape[0]):
                pred_class = int(cur_classes[k].item())
                isthing = pred_class < self.num_things_classes
                mask = cur_mask_ids == k
                mask_area = mask.sum().item()
                original_area = (cur_masks[k] >= 0.5).sum().item()

                if filter_low_score:
                    mask = mask & (cur_masks[k] >= 0.5)

                if mask_area > 0 and original_area > 0:
                    if mask_area / original_area < iou_thr:
                        continue

                    if not isthing:
                        # different stuff regions of the same class are
                        # merged here, and stuff shares the instance_id 0
                        panoptic_seg[mask] = pred_class
                    else:
                        panoptic_seg[mask] = (
                            pred_class + instance_id * INSTANCE_OFFSET)
                        instance_id += 1

        return panoptic_seg
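
    # Decoding sketch (illustrative): assuming num_classes < INSTANCE_OFFSET,
    # a panoptic segment id produced above splits back into its parts as
    #
    #     cls_id = segment_id % INSTANCE_OFFSET        # semantic class
    #     instance_id = segment_id // INSTANCE_OFFSET  # 0 for stuff,
    #                                                  # >= 1 for things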

    def preprocess_gt(self, gt_labels_list, gt_masks_list, gt_semantic_segs,
                      img_metas):
        """Preprocess the ground truth for all images.

        Args:
            gt_labels_list (list[Tensor]): Each is the ground truth
                labels of each bbox, with shape (num_gts, ).
            gt_masks_list (list[BitmapMasks]): Each is the ground truth
                masks of the instances of an image, shape
                (num_gts, h, w).
            gt_semantic_segs (list[Tensor] | None): Each is the ground
                truth of semantic segmentation with shape (n, h, w).
                [0, num_thing_class - 1] means things,
                [num_thing_class, num_class - 1] means stuff,
                255 means VOID.
            img_metas (list[dict]): List of image meta information.

        Returns:
            list[dict]: A list of targets, one per image, each containing:

            - labels (Tensor): Ground truth class indices,\
                with shape (n, ), n is the sum of the number of stuff\
                types and the number of instances in an image.
            - masks (Tensor): Ground truth masks, with shape (n, h, w).
        """
        num_things_list = [self.num_things_classes] * len(gt_labels_list)
        num_stuff_list = [self.num_stuff_classes] * len(gt_labels_list)
        if gt_semantic_segs is None:
            gt_semantic_segs = [None] * len(gt_labels_list)

        targets = multi_apply(preprocess_panoptic_gt, gt_labels_list,
                              gt_masks_list, gt_semantic_segs,
                              num_things_list, num_stuff_list, img_metas)
        labels, masks = targets
        new_targets = []
        for label, mask in zip(labels, masks):
            new_targets.append({
                'labels': label,
                'masks': mask,
            })
        return new_targets
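

# A minimal usage sketch (illustrative only, not part of the library). The
# backbone/head/train_cfg/test_cfg dicts are placeholders for the fields an
# EasyCV config file would supply (Config-style, since __init__ reads them
# via attribute access), and `img_tensor`/`img_metas` are hypothetical
# inputs:
#
#     model = Mask2Former(
#         backbone=dict(...),   # backbone config, e.g. a ResNet or Swin
#         head=dict(...),       # Mask2Former head config
#         train_cfg=dict(...),  # class/mask/dice weights, num_points, ...
#         test_cfg=dict(instance_on=True, panoptic_on=True),
#     )
#     model.eval()
#     with torch.no_grad():
#         # forward_test reads img[0] and img_metas[0], so both arguments
#         # are wrapped in single-element lists
#         results = model([img_tensor], mode='test', img_metas=[img_metas])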