Add semantic segmentation (Mask2Former) code (#186)

Add semantic segmentation (Mask2Former based on ViT-Adapter) code + update demo notebook for segmentation with a dedicated section.
pull/189/head
Patrick Labatut 2023-08-31 15:36:47 +02:00 committed by GitHub
parent d5b0405eff
commit 91d8cd81c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 6335 additions and 13 deletions

View File

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .core import * # noqa: F403
from .models import * # noqa: F403
from .ops import * # noqa: F403

View File

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmseg.core.evaluation import * # noqa: F403
from mmseg.core.seg import * # noqa: F403
from .anchor import * # noqa: F403
from .box import * # noqa: F403
from .utils import * # noqa: F403

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .point_generator import MlvlPointGenerator # noqa: F403

View File

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
from mmcv.utils import Registry, build_from_cfg
# mmcv registry in which prior generators (anchor and point generators) are registered.
PRIOR_GENERATORS = Registry("Generator for anchors and points")
# Second name bound to the very same registry object (not a copy).
ANCHOR_GENERATORS = PRIOR_GENERATORS
def build_prior_generator(cfg, default_args=None):
    """Build a prior generator from a config through ``PRIOR_GENERATORS``.

    Args:
        cfg: Config forwarded unchanged to ``build_from_cfg``.
        default_args: Optional default arguments forwarded to
            ``build_from_cfg``. Defaults to None.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    return build_from_cfg(cfg, PRIOR_GENERATORS, default_args=default_args)
def build_anchor_generator(cfg, default_args=None):
    """Legacy entry point that warns and delegates to ``build_prior_generator``.

    Args:
        cfg: Config forwarded to ``build_prior_generator``.
        default_args: Optional default arguments forwarded as well.

    Returns:
        The result of ``build_prior_generator(cfg, default_args=default_args)``.
    """
    notice = "``build_anchor_generator`` would be deprecated soon, please use ``build_prior_generator`` "
    warnings.warn(notice)
    return build_prior_generator(cfg, default_args=default_args)

View File

@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from torch.nn.modules.utils import _pair
from .builder import PRIOR_GENERATORS
@PRIOR_GENERATORS.register_module()
class MlvlPointGenerator:
    """Point generator for multi-level (Mlvl) 2D feature maps.

    Args:
        strides (list[int] | list[tuple[int, int]]): Stride of each feature
            level, in order (w, h). Every entry is normalised with ``_pair``.
        offset (float): Offset added to each grid index before it is
            multiplied by the stride. Defaults to 0.5.
    """

    def __init__(self, strides, offset=0.5):
        self.strides = [_pair(stride) for stride in strides]
        self.offset = offset

    @property
    def num_levels(self):
        """int: Number of feature levels (one per configured stride)."""
        return len(self.strides)

    @property
    def num_base_priors(self):
        """list[int]: Number of priors per grid location; 1 for each level."""
        return [1] * len(self.strides)

    def _meshgrid(self, x, y, row_major=True):
        """Build the flattened 2D grid spanned by ``x`` and ``y``.

        Returns ``(xx, yy)`` when ``row_major`` is true, else ``(yy, xx)``.
        """
        grid_y, grid_x = torch.meshgrid(y, x)
        # ``reshape`` rather than ``flatten``: the latter breaks ONNX export.
        flat_x = grid_x.reshape(-1)
        flat_y = grid_y.reshape(-1)
        return (flat_x, flat_y) if row_major else (flat_y, flat_x)

    def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False):
        """Generate the grid points of every feature level.

        Args:
            featmap_sizes (list[tuple]): One (h, w) size per feature level.
            dtype (:obj:`dtype`): Dtype of the points. Default: torch.float32.
            device (str): Device the points are created on.
            with_stride (bool): Whether to append the stride to the last
                dimension of the points.

        Return:
            list[torch.Tensor]: One tensor per level, shaped (N, 2) as
            (coord_x, coord_y) when ``with_stride`` is False, otherwise
            (N, 4) as (coord_x, coord_y, stride_w, stride_h), where
            N = width * height of that level.
        """
        assert self.num_levels == len(featmap_sizes)
        return [
            self.single_level_grid_priors(
                level_size, level_idx=level, dtype=dtype, device=device, with_stride=with_stride
            )
            for level, level_size in enumerate(featmap_sizes)
        ]

    def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False):
        """Generate the grid points of one feature level.

        Note:
            Normally reached through ``self.grid_priors``.

        Args:
            featmap_size (tuple[int]): Feature map size as (h, w).
            level_idx (int): Index of the feature level (selects the stride).
            dtype (:obj:`dtype`): Dtype of the points. Default: torch.float32.
            device (str, optional): Device of the result. Defaults to 'cuda'.
            with_stride (bool): Append the stride to the last dimension.

        Return:
            Tensor: Shape (N, 2) as (coord_x, coord_y) when ``with_stride``
            is False, otherwise (N, 4) as
            (coord_x, coord_y, stride_w, stride_h); N = width * height.
        """
        feat_h, feat_w = featmap_size
        stride_w, stride_h = self.strides[level_idx]
        # Coordinates stay tensors throughout so the graph converts to ONNX.
        xs = ((torch.arange(0, feat_w, device=device) + self.offset) * stride_w).to(dtype)
        ys = ((torch.arange(0, feat_h, device=device) + self.offset) * stride_h).to(dtype)
        grid_x, grid_y = self._meshgrid(xs, ys)

        columns = [grid_x, grid_y]
        if with_stride:
            # ``shape[0]`` instead of ``len(...)`` for ONNX export.
            columns.append(grid_x.new_full((grid_x.shape[0],), stride_w).to(dtype))
            columns.append(grid_x.new_full((grid_y.shape[0],), stride_h).to(dtype))
        return torch.stack(columns, dim=-1).to(device)

    def valid_flags(self, featmap_sizes, pad_shape, device="cuda"):
        """Generate the valid flags of the points of every feature level.

        Args:
            featmap_sizes (list(tuple)): One (h, w) size per feature level.
            pad_shape (tuple(int)): Padded image shape; the first two entries
                are read as (h, w).
            device (str): Device the flags are created on.

        Return:
            list(torch.Tensor): Valid flags of the points, one per level.
        """
        assert self.num_levels == len(featmap_sizes)
        per_level_flags = []
        for point_stride, (feat_h, feat_w) in zip(self.strides, featmap_sizes):
            pad_h, pad_w = pad_shape[:2]
            # Number of grid cells covered by the padded image, capped at the map size.
            covered_h = min(int(np.ceil(pad_h / point_stride[1])), feat_h)
            covered_w = min(int(np.ceil(pad_w / point_stride[0])), feat_w)
            per_level_flags.append(
                self.single_level_valid_flags((feat_h, feat_w), (covered_h, covered_w), device=device)
            )
        return per_level_flags

    def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"):
        """Generate the valid flags of the points of one feature map.

        Args:
            featmap_size (tuple[int]): Feature map size as (h, w).
            valid_size (tuple[int]): Valid extent as (h, w); must not exceed
                ``featmap_size`` in either dimension.
            device (str, optional): Device of the flags. Defaults to 'cuda'.

        Returns:
            torch.Tensor: Flattened boolean flags, true where both the column
            and the row lie inside the valid extent.
        """
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        col_ok = torch.zeros(feat_w, dtype=torch.bool, device=device)
        row_ok = torch.zeros(feat_h, dtype=torch.bool, device=device)
        col_ok[:valid_w] = 1
        row_ok[:valid_h] = 1
        col_grid, row_grid = self._meshgrid(col_ok, row_ok)
        return col_grid & row_grid

    def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"):
        """Generate the points selected by ``prior_idxs``.

        Args:
            prior_idxs (Tensor): Flat indices into the feature map.
            featmap_size (tuple[int]): Feature map size, unpacked by the code
                as (height, width).
            level_idx (int): Index of the feature level (selects the stride).
            dtype (obj:`torch.dtype`): Data type of the points. Defaults to
                ``torch.float32``.
            device (obj:`torch.device`): Device the points are moved to.

        Returns:
            Tensor: Points stacked along dim 1 as (coord_x, coord_y), one
            per entry of ``prior_idxs``.
        """
        height, width = featmap_size
        level_stride = self.strides[level_idx]
        coord_x = (prior_idxs % width + self.offset) * level_stride[0]
        coord_y = ((prior_idxs // width) % height + self.offset) * level_stride[1]
        return torch.stack([coord_x, coord_y], 1).to(dtype).to(device)

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .builder import * # noqa: F403
from .samplers import MaskPseudoSampler # noqa: F403

View File

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.utils import Registry, build_from_cfg
# mmcv registry used by ``build_sampler`` to look up sampler classes.
BBOX_SAMPLERS = Registry("bbox_sampler")
# mmcv registry used by ``build_bbox_coder`` to look up box coder classes.
BBOX_CODERS = Registry("bbox_coder")
def build_sampler(cfg, **default_args):
    """Build a sampler from ``cfg`` through the ``BBOX_SAMPLERS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build_from_cfg``.
        **default_args: Collected into a dict and passed as the default
            arguments of ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    return build_from_cfg(cfg, BBOX_SAMPLERS, default_args=default_args)
def build_bbox_coder(cfg, **default_args):
    """Build a box coder from ``cfg`` through the ``BBOX_CODERS`` registry.

    Args:
        cfg: Config forwarded unchanged to ``build_from_cfg``.
        **default_args: Collected into a dict and passed as the default
            arguments of ``build_from_cfg``.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg``.
    """
    return build_from_cfg(cfg, BBOX_CODERS, default_args=default_args)

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403

View File

@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
import torch
from .sampling_result import SamplingResult
class BaseSampler(metaclass=ABCMeta):
    """Abstract base for samplers.

    Subclasses provide ``_sample_pos`` and ``_sample_neg``; ``sample``
    combines them into a ``SamplingResult``.

    Args:
        num (int): Total number of samples to aim for.
        pos_fraction (float): Fraction of ``num`` expected to be positive.
        neg_pos_ub (int | float): When non-negative, caps the number of
            negatives at ``neg_pos_ub * max(1, num_sampled_pos)``.
            Defaults to -1 (no cap).
        add_gt_as_proposals (bool): Whether ground-truth boxes are prepended
            to the candidates before sampling. Defaults to True.
        **kwargs: Accepted and ignored by this base class.
    """

    def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        # By default the sampler samples both positives and negatives itself.
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples."""
        pass

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples."""
        pass

    def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
        """Sample positive and negative bboxes.

        Args:
            assign_result (:obj:`AssignResult`): Bbox assigning results.
            bboxes (Tensor): Candidate boxes; a 1-D tensor gets a leading
                dimension, and only the first 4 columns are kept.
            gt_bboxes (Tensor): Ground truth bboxes.
            gt_labels (Tensor, optional): Class labels of the ground truth
                bboxes. Required when ground truths are added as proposals.
            **kwargs: Forwarded to ``_sample_pos`` and ``_sample_neg``.

        Returns:
            :obj:`SamplingResult`: Sampling result.

        Raises:
            ValueError: If ``add_gt_as_proposals`` is set, ``gt_bboxes`` is
                non-empty and ``gt_labels`` is None.
        """
        if bboxes.dim() < 2:
            bboxes = bboxes[None, :]
        bboxes = bboxes[:, :4]

        # 1 marks a candidate that is a ground-truth box, 0 marks a proposal.
        gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
        if self.add_gt_as_proposals and len(gt_bboxes) > 0:
            if gt_labels is None:
                raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
            bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
            assign_result.add_gt_(gt_labels)
            gt_flags = torch.cat([bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8), gt_flags])

        expected_pos = int(self.num * self.pos_fraction)
        # Sampled indices occasionally contain duplicates, hence ``unique``.
        pos_inds = self.pos_sampler._sample_pos(assign_result, expected_pos, bboxes=bboxes, **kwargs).unique()
        sampled_pos = pos_inds.numel()

        expected_neg = self.num - sampled_pos
        if self.neg_pos_ub >= 0:
            neg_cap = int(self.neg_pos_ub * max(1, sampled_pos))
            expected_neg = min(expected_neg, neg_cap)
        neg_inds = self.neg_sampler._sample_neg(assign_result, expected_neg, bboxes=bboxes, **kwargs).unique()

        return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)

View File

@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
import torch
from ..builder import BBOX_SAMPLERS
from .base_sampler import BaseSampler
from .mask_sampling_result import MaskSamplingResult
@BBOX_SAMPLERS.register_module()
class MaskPseudoSampler(BaseSampler):
    """A pseudo sampler that performs no sampling.

    It turns an assignment directly into positive and negative index sets.
    """

    def __init__(self, **kwargs):
        # Deliberately does not call the base initialiser; kwargs are ignored.
        pass

    def _sample_pos(self, **kwargs):
        """Sample positive samples (not supported by this sampler)."""
        raise NotImplementedError

    def _sample_neg(self, **kwargs):
        """Sample negative samples (not supported by this sampler)."""
        raise NotImplementedError

    def sample(self, assign_result, masks, gt_masks, **kwargs):
        """Directly return the positive and negative indices of the samples.

        Args:
            assign_result (:obj:`AssignResult`): Assignment; entries of
                ``gt_inds`` greater than 0 count as positive, entries equal
                to 0 as negative.
            masks (torch.Tensor): Candidate masks, indexed along dim 0.
            gt_masks (torch.Tensor): Ground truth masks.
            **kwargs: Ignored.

        Returns:
            :obj:`MaskSamplingResult`: Sampler results.
        """
        assigned = assign_result.gt_inds
        pos_inds = torch.nonzero(assigned > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(assigned == 0, as_tuple=False).squeeze(-1).unique()
        # No candidate is a ground truth here, so every flag is zero.
        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
        return MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags)

View File

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
import torch
from .sampling_result import SamplingResult
class MaskSamplingResult(SamplingResult):
    """Sampling result whose samples are masks.

    Args:
        pos_inds (Tensor): Indices of the positive samples in ``masks``.
        neg_inds (Tensor): Indices of the negative samples in ``masks``.
        masks (Tensor): Candidate masks, indexed along dim 0.
        gt_masks (Tensor): Ground truth masks, indexed along dim 0.
        assign_result: Object exposing ``gt_inds`` (1-based assignment
            indices) and ``labels`` (may be None).
        gt_flags (Tensor): Per-candidate flags, indexed by ``pos_inds``.
    """

    def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_masks = masks[pos_inds]
        self.neg_masks = masks[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_masks.shape[0]
        # ``gt_inds`` is 1-based for assigned samples; shift to 0-based.
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_masks.numel() != 0:
            self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
        else:
            # Indexing an empty tensor would fail, so return an empty one.
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_masks = torch.empty_like(gt_masks)

        labels = assign_result.labels
        self.pos_gt_labels = labels[pos_inds] if labels is not None else None

    @property
    def masks(self):
        """torch.Tensor: concatenated positive and negative masks"""
        return torch.cat([self.pos_masks, self.neg_masks])

    def __nice__(self):
        """Format ``info`` as text, showing mask tensors by shape only."""
        summary = self.info.copy()
        for key in ("pos_masks", "neg_masks"):
            summary[key] = summary.pop(key).shape
        entries = [f"'{k}': {v!r}" for k, v in sorted(summary.items())]
        return "{\n" + " " + ",\n ".join(entries) + "\n}"

    @property
    def info(self):
        """Returns a dictionary of info about the object."""
        exported = (
            "pos_inds",
            "neg_inds",
            "pos_masks",
            "neg_masks",
            "pos_is_gt",
            "num_gts",
            "pos_assigned_gt_inds",
        )
        return {name: getattr(self, name) for name in exported}

View File

@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
class SamplingResult:
    """Bbox sampling result.

    ``str(result)`` and ``repr(result)`` show the summary produced by
    ``__nice__``.

    Args:
        pos_inds (Tensor): Indices of the positive samples in ``bboxes``.
        neg_inds (Tensor): Indices of the negative samples in ``bboxes``.
        bboxes (Tensor): Candidate boxes, indexed along dim 0.
        gt_bboxes (Tensor): Ground truth boxes; a 1-D non-empty tensor is
            viewed as (-1, 4).
        assign_result: Object exposing ``gt_inds`` (1-based assignment
            indices) and ``labels`` (may be None).
        gt_flags (Tensor): Per-candidate flags, indexed by ``pos_inds``.

    Example:
        >>> # xdoctest: +IGNORE_WANT
        >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
        >>> self = SamplingResult.random(rng=10)
        >>> print(f'self = {self}')
        self = <SamplingResult({
            'neg_bboxes': torch.Size([12, 4]),
            'neg_inds': tensor([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10, 11, 12]),
            'num_gts': 4,
            'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
            'pos_bboxes': torch.Size([0, 4]),
            'pos_inds': tensor([], dtype=torch.int64),
            'pos_is_gt': tensor([], dtype=torch.uint8)
        })>
    """

    def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
        self.pos_inds = pos_inds
        self.neg_inds = neg_inds
        self.pos_bboxes = bboxes[pos_inds]
        self.neg_bboxes = bboxes[neg_inds]
        self.pos_is_gt = gt_flags[pos_inds]

        self.num_gts = gt_bboxes.shape[0]
        # ``gt_inds`` is 1-based for assigned samples; shift to 0-based.
        self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1

        if gt_bboxes.numel() == 0:
            # Indexing an empty tensor would fail, so return an empty (0, 4) one.
            assert self.pos_assigned_gt_inds.numel() == 0
            self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
        else:
            if len(gt_bboxes.shape) < 2:
                gt_bboxes = gt_bboxes.view(-1, 4)
            self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]

        if assign_result.labels is not None:
            self.pos_gt_labels = assign_result.labels[pos_inds]
        else:
            self.pos_gt_labels = None

    @property
    def bboxes(self):
        """torch.Tensor: concatenated positive and negative boxes"""
        return torch.cat([self.pos_bboxes, self.neg_bboxes])

    def to(self, device):
        """Move every tensor attribute to ``device`` in place.

        Args:
            device: Anything accepted by ``torch.Tensor.to``.

        Returns:
            SamplingResult: ``self``, to allow chaining.

        Example:
            >>> self = SamplingResult.random()
            >>> print(f'self = {self.to(None)}')
            >>> # xdoctest: +REQUIRES(--gpu)
            >>> print(f'self = {self.to(0)}')
        """
        _dict = self.__dict__
        for key, value in _dict.items():
            # Only existing keys are reassigned, so the dict size is stable.
            if isinstance(value, torch.Tensor):
                _dict[key] = value.to(device)
        return self

    def __nice__(self):
        """Format ``info`` as text, showing box tensors by shape only."""
        data = self.info.copy()
        data["pos_bboxes"] = data.pop("pos_bboxes").shape
        data["neg_bboxes"] = data.pop("neg_bboxes").shape
        parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
        body = " " + ",\n ".join(parts)
        return "{\n" + body + "\n}"

    def __repr__(self):
        """Return ``<ClassName(summary) at 0x...>``.

        Falls back to the default object repr when the instance is only
        partially initialised (e.g. inspected while ``__init__`` is failing).
        """
        try:
            return f"<{self.__class__.__name__}({self.__nice__()}) at {hex(id(self))}>"
        except AttributeError:
            return object.__repr__(self)

    def __str__(self):
        """Return ``<ClassName(summary)>``, the form shown in the class docstring."""
        try:
            return f"<{self.__class__.__name__}({self.__nice__()})>"
        except AttributeError:
            return object.__repr__(self)

    @property
    def info(self):
        """Returns a dictionary of info about the object."""
        return {
            "pos_inds": self.pos_inds,
            "neg_inds": self.neg_inds,
            "pos_bboxes": self.pos_bboxes,
            "neg_bboxes": self.neg_bboxes,
            "pos_is_gt": self.pos_is_gt,
            "num_gts": self.num_gts,
            "pos_assigned_gt_inds": self.pos_assigned_gt_inds,
        }

    @classmethod
    def random(cls, rng=None, **kwargs):
        """Create a random sampling result (needs ``mmdet`` to be installed).

        Args:
            rng (None | int | numpy.random.RandomState): seed or state.
            kwargs (keyword arguments):
                - num_preds: number of predicted boxes
                - num_gts: number of true boxes
                - p_ignore (float): probability of a predicted box assigned to \
                    an ignored truth.
                - p_assigned (float): probability of a predicted box not being \
                    assigned.
                - p_use_label (float | bool): with labels or not.

        Returns:
            :obj:`SamplingResult`: Randomly generated sampling result.

        Example:
            >>> from mmdet.core.bbox.samplers.sampling_result import *  # NOQA
            >>> self = SamplingResult.random()
            >>> print(self.__dict__)
        """
        # Imported lazily: these helpers live in mmdet, not in this package.
        from mmdet.core.bbox import demodata
        from mmdet.core.bbox.assigners.assign_result import AssignResult
        from mmdet.core.bbox.samplers.random_sampler import RandomSampler

        rng = demodata.ensure_rng(rng)

        # make probabalistic?
        num = 32
        pos_fraction = 0.5
        neg_pos_ub = -1

        assign_result = AssignResult.random(rng=rng, **kwargs)

        # Note we could just compute an assignment
        bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
        gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

        if rng.rand() > 0.2:
            # sometimes algorithms squeeze their data, be robust to that
            gt_bboxes = gt_bboxes.squeeze()
            bboxes = bboxes.squeeze()

        # NOTE(review): both branches yield None, so ground-truth labels are
        # never generated and ``add_gt_as_proposals`` below is always False.
        if assign_result.labels is None:
            gt_labels = None
        else:
            gt_labels = None

        if gt_labels is None:
            add_gt_as_proposals = False
        else:
            add_gt_as_proposals = True  # make probabalistic?

        sampler = RandomSampler(
            num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng
        )
        self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
        return self

View File

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .dist_utils import reduce_mean
from .misc import add_prefix, multi_apply

View File

@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch.distributed as dist
def reduce_mean(tensor):
    """Obtain the mean of tensor on different GPUs.

    When ``torch.distributed`` is unavailable or not initialised, the input
    tensor is returned as is. Otherwise a clone is divided by the world size
    and sum-reduced across processes, leaving the input untouched.
    """
    if dist.is_available() and dist.is_initialized():
        averaged = tensor.clone()
        # Divide first, then sum across ranks: the sum of the shares is the mean.
        dist.all_reduce(averaged.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
        return averaged
    return tensor

View File

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from functools import partial
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` element-wise over several argument sequences.

    ``func`` is called once per position with one item from each sequence in
    ``args`` (plus ``kwargs`` on every call). It is expected to return a tuple;
    the i-th outputs of all calls are gathered into the i-th result list.

    Args:
        func (Function): Function applied to each group of arguments.
        *args: Argument sequences iterated in lockstep.
        **kwargs: Keyword arguments bound to every call.

    Returns:
        tuple(list): One list per output of ``func``; an empty tuple when
        there is nothing to iterate over.
    """
    bound = partial(func, **kwargs) if kwargs else func
    return tuple(list(column) for column in zip(*map(bound, *args)))
def add_prefix(inputs, prefix):
    """Return a copy of ``inputs`` whose keys are prefixed with ``prefix``.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add; joined to each key with a dot.

    Returns:
        dict: New dict mapping ``"<prefix>.<key>"`` to the original values.
    """
    return {f"{prefix}.{name}": value for name, value in inputs.items()}

View File

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .backbones import * # noqa: F403
from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
from .decode_heads import * # noqa: F403
from .losses import * # noqa: F403
from .plugins import * # noqa: F403
from .segmentors import * # noqa: F403

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .vit_adapter import ViTAdapter

View File

@ -0,0 +1,442 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from ...ops.modules import MSDeformAttn
from .drop_path import DropPath
def get_reference_points(spatial_shapes, device):
    """Build normalised cell-centre reference points for a set of feature maps.

    For every (H, W) in ``spatial_shapes`` the centres of the H x W grid cells
    are computed, divided by H and W respectively, and flattened row by row.
    All levels are concatenated along the point dimension.

    Args:
        spatial_shapes: Iterable of (H, W) pairs.
        device: Device on which the tensors are created.

    Returns:
        torch.Tensor: float32 tensor of shape (1, sum(H * W), 1, 2) whose last
        dimension is ordered (x, y).
    """
    per_level = []
    for height, width in spatial_shapes:
        centres_y = torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device)
        centres_x = torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device)
        grid_y, grid_x = torch.meshgrid(centres_y, centres_x)
        norm_y = grid_y.reshape(-1)[None] / height
        norm_x = grid_x.reshape(-1)[None] / width
        per_level.append(torch.stack((norm_x, norm_y), -1))
    # Insert a singleton axis after the point dimension.
    return torch.cat(per_level, 1)[:, :, None]
def deform_inputs(x, patch_size):
    """Prepare the two argument triples used for deformable attention.

    Two grids are derived from the spatial size (h, w) of ``x``: a three-level
    pyramid at 1/8, 1/16 and 1/32 resolution, and a single level at
    1/``patch_size`` resolution.

    Args:
        x (torch.Tensor): 4-D tensor; only its last two sizes and its device
            are used.
        patch_size (int): Divisor giving the single-level grid size.

    Returns:
        tuple(list, list): ``deform_inputs1`` pairs reference points of the
        single-level grid with the pyramid's shapes and level start indices;
        ``deform_inputs2`` pairs reference points of the pyramid with the
        single-level grid's shapes and level start indices. Each list is
        ``[reference_points, spatial_shapes, level_start_index]``.
    """
    _, _, h, w = x.shape
    pyramid_sizes = [(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)]
    token_sizes = [(h // patch_size, w // patch_size)]

    def _shapes_and_starts(sizes):
        # Level i starts at the summed H * W of all previous levels.
        shapes = torch.as_tensor(sizes, dtype=torch.long, device=x.device)
        starts = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
        return shapes, starts

    pyramid_shapes, pyramid_starts = _shapes_and_starts(pyramid_sizes)
    deform_inputs1 = [get_reference_points(token_sizes, x.device), pyramid_shapes, pyramid_starts]

    token_shapes, token_starts = _shapes_and_starts(token_sizes)
    deform_inputs2 = [get_reference_points(pyramid_sizes, x.device), token_shapes, token_starts]

    return deform_inputs1, deform_inputs2
class ConvFFN(nn.Module):
    """Feed-forward block with a depthwise convolution between two linears.

    Args:
        in_features (int): Input width.
        hidden_features (int, optional): Hidden width; falls back to
            ``in_features`` when falsy.
        out_features (int, optional): Output width; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument factory of the activation module.
            Defaults to ``nn.GELU``.
        drop (float): Dropout probability, applied after the activation and
            after the second linear layer. Defaults to 0.0.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.dwconv = DWConv(hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """Apply fc1 -> dwconv -> act -> drop -> fc2 -> drop.

        ``H`` and ``W`` are passed to the depthwise convolution, which uses
        them to lay the tokens out spatially.
        """
        hidden = self.act(self.dwconv(self.fc1(x), H, W))
        return self.drop(self.fc2(self.drop(hidden)))
class DWConv(nn.Module):
    """Depthwise 3x3 convolution over a three-level token sequence.

    The same convolution (one group per channel, stride 1, padding 1) is
    applied to each level separately.

    Args:
        dim (int): Number of channels. Defaults to 768.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """Convolve each level of ``x`` and concatenate the results.

        Args:
            x (torch.Tensor): Tokens of shape (B, N, C). The sequence is split
                into three consecutive parts of 16n, 4n and the remaining
                tokens, with n = N // 21 (16 + 4 + 1 = 21).
            H (int): Height of the middle level; the first level is laid out
                as (2H, 2W) and the last as (H // 2, W // 2).
            W (int): Width of the middle level.

        Returns:
            torch.Tensor: Tokens in the same level order, concatenated along
            dim 1.
        """
        B, N, C = x.shape
        n = N // 21
        levels = (
            (slice(0, 16 * n), H * 2, W * 2),
            (slice(16 * n, 20 * n), H, W),
            (slice(20 * n, None), H // 2, W // 2),
        )
        convolved = []
        for token_range, height, width in levels:
            # (B, tokens, C) -> (B, C, height, width) for the 2D convolution.
            feature_map = x[:, token_range, :].transpose(1, 2).view(B, C, height, width).contiguous()
            convolved.append(self.dwconv(feature_map).flatten(2).transpose(1, 2))
        return torch.cat(convolved, dim=1)
class Extractor(nn.Module):
    """Residual attention block that updates ``query`` from ``feat``.

    The normalised query and feature are passed to an ``MSDeformAttn`` module
    and the result is added to the query. Optionally a ``ConvFFN`` with drop
    path follows as a second residual branch.

    Args:
        dim (int): Channel width.
        num_heads (int): Forwarded to ``MSDeformAttn`` as ``n_heads``.
        n_points (int): Forwarded to ``MSDeformAttn``.
        n_levels (int): Forwarded to ``MSDeformAttn``.
        deform_ratio (float): Forwarded to ``MSDeformAttn`` as ``ratio``.
        with_cffn (bool): Whether to add the ConvFFN branch.
        cffn_ratio (float): Hidden width of the ConvFFN relative to ``dim``.
        drop (float): Dropout probability inside the ConvFFN.
        drop_path (float): Drop-path rate; identity when not positive.
        norm_layer: Factory called with ``dim`` to create each norm layer.
        with_cp (bool): Use activation checkpointing when the query requires
            gradients.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        with_cffn=True,
        cffn_ratio=0.25,
        drop=0.0,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        with_cp=False,
    ):
        super().__init__()
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim,
            n_levels=n_levels,
            n_heads=num_heads,
            n_points=n_points,
            ratio=deform_ratio,
        )
        self.with_cffn = with_cffn
        self.with_cp = with_cp
        if with_cffn:
            hidden_width = int(dim * cffn_ratio)
            self.ffn = ConvFFN(in_features=dim, hidden_features=hidden_width, drop=drop)
            self.ffn_norm = norm_layer(dim)
            self.drop_path = nn.Identity() if not drop_path > 0.0 else DropPath(drop_path)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
        """Return the updated query.

        ``reference_points``, ``spatial_shapes`` and ``level_start_index`` are
        handed to the attention module unchanged; ``H`` and ``W`` are handed
        to the ConvFFN.
        """

        def _attend_then_ffn(q, kv):
            attended = self.attn(
                self.query_norm(q), reference_points, self.feat_norm(kv), spatial_shapes, level_start_index, None
            )
            q = q + attended
            if self.with_cffn:
                q = q + self.drop_path(self.ffn(self.ffn_norm(q), H, W))
            return q

        # Checkpointing only pays off when gradients flow through the query.
        if self.with_cp and query.requires_grad:
            return cp.checkpoint(_attend_then_ffn, query, feat)
        return _attend_then_ffn(query, feat)
class Injector(nn.Module):
    """Residual attention block whose update is scaled by a learnable gamma.

    The normalised query and feature are passed to an ``MSDeformAttn`` module;
    the result is multiplied channel-wise by ``gamma`` and added to the query.

    Args:
        dim (int): Channel width.
        num_heads (int): Forwarded to ``MSDeformAttn`` as ``n_heads``.
        n_points (int): Forwarded to ``MSDeformAttn``.
        n_levels (int): Forwarded to ``MSDeformAttn``.
        deform_ratio (float): Forwarded to ``MSDeformAttn`` as ``ratio``.
        norm_layer: Factory called with ``dim`` to create each norm layer.
        init_values (float): Initial value of every entry of ``gamma``; with
            the default 0.0 the block starts out as the identity on ``query``.
        with_cp (bool): Use activation checkpointing when the query requires
            gradients.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        n_levels=1,
        deform_ratio=1.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        init_values=0.0,
        with_cp=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.query_norm = norm_layer(dim)
        self.feat_norm = norm_layer(dim)
        self.attn = MSDeformAttn(
            d_model=dim,
            n_levels=n_levels,
            n_heads=num_heads,
            n_points=n_points,
            ratio=deform_ratio,
        )
        self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
        """Return ``query + gamma * attention(query, feat)``."""

        def _inject(q, kv):
            attended = self.attn(
                self.query_norm(q), reference_points, self.feat_norm(kv), spatial_shapes, level_start_index, None
            )
            return q + self.gamma * attended

        # Checkpointing only pays off when gradients flow through the query.
        if self.with_cp and query.requires_grad:
            return cp.checkpoint(_inject, query, feat)
        return _inject(query, feat)
class InteractionBlock(nn.Module):
    """Inject ``c`` into ``x``, run the given blocks, then extract back into ``c``.

    Args:
        dim (int): Channel width.
        num_heads (int): Number of attention heads for injector/extractors.
        n_points (int): Forwarded to the injector and extractors.
        norm_layer: Norm layer factory shared by all sub-modules.
        drop (float): Dropout probability of the extractors' ConvFFN.
        drop_path (float): Drop-path rate of the extractors.
        with_cffn (bool): Whether the extractors include a ConvFFN.
        cffn_ratio (float): Hidden-width ratio of that ConvFFN.
        init_values (float): Initial gamma of the injector.
        deform_ratio (float): Forwarded to the injector and extractors.
        extra_extractor (bool): When true, two additional extractors are
            applied after the main one.
        with_cp (bool): Enable activation checkpointing in the sub-modules.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        # Settings shared by the main extractor and the optional extra ones.
        extractor_settings = dict(
            dim=dim,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(n_levels=1, **extractor_settings)
        if extra_extractor:
            self.extra_extractors = nn.Sequential(*[Extractor(**extractor_settings) for _ in range(2)])
        else:
            self.extra_extractors = None

    def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Run one interaction step.

        Args:
            x: Tokens updated by the injector and then by ``blocks``.
            c: Features read by the injector and updated by the extractors.
            blocks: Iterable of callables invoked as ``blk(x, H_toks, W_toks)``.
            deform_inputs1: ``[reference_points, spatial_shapes,
                level_start_index]`` for the injector.
            deform_inputs2: Same triple for the extractors.
            H_c, W_c: Passed to the extractors as ``H`` and ``W``.
            H_toks, W_toks: Passed to every block.

        Returns:
            tuple: ``(x, c)`` after the step.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )
        for blk in blocks:
            x = blk(x, H_toks, W_toks)

        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c
class InteractionBlockWithCls(nn.Module):
    """Interaction block that carries a separate leading token through the blocks.

    Same structure as an injector / blocks / extractor step, except that
    ``cls`` is prepended to ``x`` before the blocks run and split off again
    afterwards, so the injector and extractors never see it.

    Args:
        dim (int): Channel width.
        num_heads (int): Number of attention heads for injector/extractors.
        n_points (int): Forwarded to the injector and extractors.
        norm_layer: Norm layer factory shared by all sub-modules.
        drop (float): Dropout probability of the extractors' ConvFFN.
        drop_path (float): Drop-path rate of the extractors.
        with_cffn (bool): Whether the extractors include a ConvFFN.
        cffn_ratio (float): Hidden-width ratio of that ConvFFN.
        init_values (float): Initial gamma of the injector.
        deform_ratio (float): Forwarded to the injector and extractors.
        extra_extractor (bool): When true, two additional extractors are
            applied after the main one.
        with_cp (bool): Enable activation checkpointing in the sub-modules.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        n_points=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        drop=0.0,
        drop_path=0.0,
        with_cffn=True,
        cffn_ratio=0.25,
        init_values=0.0,
        deform_ratio=1.0,
        extra_extractor=False,
        with_cp=False,
    ):
        super().__init__()

        # Settings shared by the main extractor and the optional extra ones.
        extractor_settings = dict(
            dim=dim,
            num_heads=num_heads,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cffn=with_cffn,
            cffn_ratio=cffn_ratio,
            drop=drop,
            drop_path=drop_path,
            with_cp=with_cp,
        )

        self.injector = Injector(
            dim=dim,
            n_levels=3,
            num_heads=num_heads,
            init_values=init_values,
            n_points=n_points,
            norm_layer=norm_layer,
            deform_ratio=deform_ratio,
            with_cp=with_cp,
        )
        self.extractor = Extractor(n_levels=1, **extractor_settings)
        if extra_extractor:
            self.extra_extractors = nn.Sequential(*[Extractor(**extractor_settings) for _ in range(2)])
        else:
            self.extra_extractors = None

    def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
        """Run one interaction step.

        Args:
            x: Tokens updated by the injector and then by ``blocks``.
            c: Features read by the injector and updated by the extractors.
            cls: Tensor concatenated in front of ``x`` along dim 1 while the
                blocks run; the first position is split off again afterwards.
            blocks: Iterable of callables invoked as ``blk(x, H_toks, W_toks)``.
            deform_inputs1: ``[reference_points, spatial_shapes,
                level_start_index]`` for the injector.
            deform_inputs2: Same triple for the extractors.
            H_c, W_c: Passed to the extractors as ``H`` and ``W``.
            H_toks, W_toks: Passed to every block.

        Returns:
            tuple: ``(x, c, cls)`` after the step.
        """
        x = self.injector(
            query=x,
            reference_points=deform_inputs1[0],
            feat=c,
            spatial_shapes=deform_inputs1[1],
            level_start_index=deform_inputs1[2],
        )

        tokens = torch.cat((cls, x), dim=1)
        for blk in blocks:
            tokens = blk(tokens, H_toks, W_toks)
        # Position 0 goes back to ``cls``; the rest is the token sequence.
        cls, x = tokens[:, :1], tokens[:, 1:]

        extractors = [self.extractor]
        if self.extra_extractors is not None:
            extractors.extend(self.extra_extractors)
        for extractor in extractors:
            c = extractor(
                query=c,
                reference_points=deform_inputs2[0],
                feat=x,
                spatial_shapes=deform_inputs2[1],
                level_start_index=deform_inputs2[2],
                H=H_c,
                W=W_c,
            )
        return x, c, cls
class SpatialPriorModule(nn.Module):
    """Convolutional pyramid whose four levels are projected to ``embed_dim`` channels.

    ``stem`` downsamples by 4 (a stride-2 conv followed by a stride-2 max
    pool); ``conv2``, ``conv3`` and ``conv4`` each halve the resolution
    again. Every level goes through its own 1x1 convolution (``fc1``..``fc4``).
    """

    def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
        super().__init__()
        self.with_cp = with_cp

        def conv_bn_relu(in_ch, out_ch, stride):
            # 3x3 conv (no bias) -> SyncBatchNorm -> in-place ReLU.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.SyncBatchNorm(out_ch),
                nn.ReLU(inplace=True),
            ]

        def project(in_ch):
            # 1x1 conv mapping a pyramid level to embed_dim channels.
            return nn.Conv2d(in_ch, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)

        self.stem = nn.Sequential(
            *conv_bn_relu(3, inplanes, stride=2),
            *conv_bn_relu(inplanes, inplanes, stride=1),
            *conv_bn_relu(inplanes, inplanes, stride=1),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.conv2 = nn.Sequential(*conv_bn_relu(inplanes, 2 * inplanes, stride=2))
        self.conv3 = nn.Sequential(*conv_bn_relu(2 * inplanes, 4 * inplanes, stride=2))
        self.conv4 = nn.Sequential(*conv_bn_relu(4 * inplanes, 4 * inplanes, stride=2))

        self.fc1 = project(inplanes)
        self.fc2 = project(2 * inplanes)
        self.fc3 = project(4 * inplanes)
        self.fc4 = project(4 * inplanes)

    def forward(self, x):
        """Return ``(c1, c2, c3, c4)`` for the image batch ``x``.

        ``c1`` stays a 4-D feature map; ``c2``, ``c3`` and ``c4`` are
        flattened over their spatial dimensions and transposed so that the
        channel dimension comes last.
        """

        def _inner_forward(x):
            feat4 = self.stem(x)
            feat8 = self.conv2(feat4)
            feat16 = self.conv3(feat8)
            feat32 = self.conv4(feat16)

            c1 = self.fc1(feat4)
            c2 = self.fc2(feat8)
            c3 = self.fc3(feat16)
            c4 = self.fc4(feat32)

            bs, dim, _, _ = c1.shape
            # c1 (stride 4) is intentionally left as a feature map.
            c2 = c2.view(bs, dim, -1).transpose(1, 2)  # stride 8
            c3 = c3.view(bs, dim, -1).transpose(1, 2)  # stride 16
            c4 = c4.view(bs, dim, -1).transpose(1, 2)  # stride 32
            return c1, c2, c3, c4

        # Checkpointing is only used when requested and the input needs grad.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)

View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples of ``x`` (stochastic depth).

    ``x`` is returned untouched when ``training`` is false or ``drop_prob``
    is 0. Otherwise one Bernoulli draw per entry of the first dimension
    decides whether that sample is kept; kept samples are scaled by
    ``1 / (1 - drop_prob)``.
    """
    if not training or drop_prob == 0.0:
        return x
    survival_prob = 1 - drop_prob
    # One mask value per sample, broadcast over all remaining dimensions.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival_prob)
    if survival_prob > 0.0:
        # Skipped when everything is dropped, which would divide by zero.
        mask.div_(survival_prob)
    return x * mask
class DropPath(nn.Module):
    """Module wrapper around ``drop_path``: per-sample stochastic depth.

    Dropping follows the module's ``training`` flag.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, drop_prob=self.drop_prob, training=self.training)

View File

@ -0,0 +1,552 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
"""Vision Transformer (ViT) in PyTorch.
A PyTorch implementation of Vision Transformers as described in:
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
- https://arxiv.org/abs/2010.11929
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
- https://arxiv.org/abs/2106.10270
The official jax code is released and available at https://github.com/google-research/vision_transformer
DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
Hacked together by / Copyright 2021 Ross Wightman
"""
import logging
import math
from functools import partial
from itertools import repeat
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.runner import BaseModule, load_checkpoint
from mmseg.ops import resize
from mmseg.utils import get_root_logger
from torch import Tensor
from .drop_path import DropPath
def to_2tuple(x):
    """Return ``x`` as a tuple suitable for a 2-D size argument.

    A list or tuple is converted element-wise, so an explicit ``(h, w)`` pair
    is passed through instead of being nested as ``((h, w), (h, w))``. Any
    other value is repeated twice.
    """
    if isinstance(x, (list, tuple)):
        return tuple(x)
    return tuple(repeat(x, 2))
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> dropout -> linear -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when left unset. The same dropout module is applied after both linear
    layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class SwiGLUFFN(nn.Module):
    """Gated feed-forward network computing ``w3(silu(w1(x)) * w2(x))``.

    The gated width is ``int(2 * hidden_features / 3)`` rounded up to a
    multiple of 8. ``act_layer`` and ``drop`` are accepted so the constructor
    can be called like ``Mlp``, but they are not used.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features

        align_as = 8
        gated_width = int(2 * hidden_features / 3)
        remainder = gated_width % align_as
        if remainder:
            # Round up to the next multiple of align_as.
            gated_width += align_as - remainder

        self.w1 = nn.Linear(in_features, gated_width)
        self.w2 = nn.Linear(in_features, gated_width)
        self.w3 = nn.Linear(gated_width, out_features)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)
class PatchEmbed(nn.Module):
    """Split a 2-D image into patches with a strided convolution and embed them."""

    def __init__(
        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True
    ):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        grid_rows = self.img_size[0] // self.patch_size[0]
        grid_cols = self.img_size[1] // self.patch_size[1]
        self.grid_size = (grid_rows, grid_cols)
        self.num_patches = grid_rows * grid_cols
        self.flatten = flatten

        # Kernel and stride both equal the patch size, so patches never overlap.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)

    def forward(self, x):
        """Embed ``x`` and return ``(tokens, H, W)``.

        ``H`` and ``W`` are the spatial size of the convolution output. With
        ``flatten`` enabled the output is reshaped from BCHW to BNC before
        the norm is applied.
        """
        x = self.proj(x)
        _, _, grid_h, grid_w = x.shape
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(x), grid_h, grid_w
class Attention(nn.Module):
    """Multi-head self-attention over the full token sequence."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, H, W):
        """Attend over ``x`` of shape ``(B, N, C)``; ``H`` and ``W`` are not used here."""
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
class MemEffAttention(nn.Module):
    """Self-attention that delegates the attention product to xformers.

    ``xformers`` is imported inside ``forward``, so it is only required when
    this module is actually run. ``forward`` calls
    ``memory_efficient_attention(q, k, v)`` with no further arguments and
    does not reference ``self.scale`` or ``self.attn_drop``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, H, W) -> Tensor:
        """Attend over ``x`` of shape ``(B, N, C)``; ``H`` and ``W`` are not used here."""
        from xformers.ops import memory_efficient_attention, unbind

        B, N, C = x.shape
        # (B, N, 3C) -> (B, N, 3, heads, head_dim); q, k, v are split on dim 2.
        packed = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = unbind(packed, 2)

        out = memory_efficient_attention(q, k, v)
        out = out.reshape([B, N, C])
        return self.proj_drop(self.proj(out))
def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    rows, cols = H // window_size, W // window_size
    tiled = x.view(B, rows, window_size, cols, window_size, C)
    # Bring the two window-grid axes together ahead of the in-window axes.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    rows, cols = H // window_size, W // window_size
    grid = windows.view(batch, rows, cols, window_size, window_size, -1)
    # Interleave the grid axes with the in-window axes again.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)
class WindowedAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The projected qkv map is padded on the right and bottom up to a multiple
    of ``window_size``, split into windows with ``F.unfold``, attended per
    window, stitched back with ``F.fold`` and cropped to the original size.
    """

    def __init__(
        self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant"
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.window_size = window_size
        self.pad_mode = pad_mode

    def forward(self, x, H, W):
        """Attend over ``x`` of shape ``(B, N, C)`` laid out as an ``H`` x ``W`` grid."""
        B, N, C = x.shape
        ws = self.window_size
        window = (ws, ws)
        tokens_per_window = ws * ws
        padded_h = math.ceil(H / ws) * ws
        padded_w = math.ceil(W / ws) * ws

        qkv = self.qkv(x)  # [B, N, 3C]
        qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W)  # [B, 3C, H, W]
        qkv = F.pad(qkv, [0, padded_w - W, 0, padded_h - H], mode=self.pad_mode)
        qkv = F.unfold(qkv, kernel_size=window, stride=window)

        B, unfolded_channels, num_windows = qkv.shape
        # [B, 3C * tokens, L] -> [B, L, tokens, 3C]
        qkv = qkv.reshape(B, C * 3, tokens_per_window, num_windows).permute(0, 3, 2, 1)
        qkv = qkv.reshape(B, num_windows, tokens_per_window, 3, self.num_heads, C // self.num_heads)
        # q, k, v: [B, L, heads, tokens, head_dim]
        q, k, v = qkv.permute(3, 0, 1, 4, 2, 5).unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, L, heads, tokens, tokens]
        attn = self.attn_drop(attn.softmax(dim=-1))

        # [B, L, heads, tokens, head_dim] -> [B, C * tokens, L], the layout F.fold expects.
        out = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, unfolded_channels // 3, num_windows)
        out = F.fold(out, output_size=(padded_h, padded_w), kernel_size=window, stride=window)  # [B, C, H_, W_]

        # Drop the padded border and return to token layout.
        out = out[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2)
        return self.proj_drop(self.proj(out))
# class WindowedAttention(nn.Module):
# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"):
# super().__init__()
# self.num_heads = num_heads
# head_dim = dim // num_heads
# self.scale = head_dim ** -0.5
#
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
# self.attn_drop = nn.Dropout(attn_drop)
# self.proj = nn.Linear(dim, dim)
# self.proj_drop = nn.Dropout(proj_drop)
# self.window_size = window_size
# self.pad_mode = pad_mode
#
# def forward(self, x, H, W):
# B, N, C = x.shape
#
# N_ = self.window_size * self.window_size
# H_ = math.ceil(H / self.window_size) * self.window_size
# W_ = math.ceil(W / self.window_size) * self.window_size
# x = x.view(B, H, W, C)
# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode)
#
# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C
# x = x.view(-1, N_, C)
#
# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C)
#
# x = window_reverse(x, self.window_size, H_, W_)
# x = x[:, :H, :W, :].reshape(B, N, C).contiguous()
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
class Block(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    The attention module is ``WindowedAttention`` when ``windowed`` is set,
    otherwise ``MemEffAttention`` when ``memeff`` is set, otherwise
    ``Attention``. With ``layer_scale`` enabled each branch is multiplied by
    a learnable per-channel vector initialised to ones.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        windowed=False,
        window_size=14,
        pad_mode="constant",
        layer_scale=False,
        with_cp=False,
        ffn_layer=Mlp,
        memeff=False,
    ):
        super().__init__()
        self.with_cp = with_cp
        self.norm1 = norm_layer(dim)

        attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        if windowed:
            self.attn = WindowedAttention(dim, window_size=window_size, pad_mode=pad_mode, **attn_kwargs)
        elif memeff:
            self.attn = MemEffAttention(dim, **attn_kwargs)
        else:
            self.attn = Attention(dim, **attn_kwargs)

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        self.layer_scale = layer_scale
        if layer_scale:
            self.gamma1 = nn.Parameter(torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(torch.ones(dim), requires_grad=True)

    def forward(self, x, H, W):
        """Apply the block to ``x``; ``H`` and ``W`` are forwarded to the attention module."""

        def _inner_forward(x):
            attn_out = self.attn(self.norm1(x), H, W)
            if self.layer_scale:
                attn_out = self.gamma1 * attn_out
            x = x + self.drop_path(attn_out)

            ffn_out = self.mlp(self.norm2(x))
            if self.layer_scale:
                ffn_out = self.gamma2 * ffn_out
            return x + self.drop_path(ffn_out)

        # Checkpointing is only used when requested and the input needs grad.
        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
class TIMMVisionTransformer(BaseModule):
    """Vision Transformer backbone.

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929

    The model consists of a patch embedding, a class token, a learned
    position embedding and ``depth`` transformer blocks. No classification
    head and no final norm layer are created by this class.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        layer_scale=True,
        embed_layer=PatchEmbed,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        window_attn=False,
        window_size=14,
        pretrained=None,
        with_cp=False,
        pre_norm=False,
        ffn_type="mlp",
        memeff=False,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes; only stored on the instance
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate; the per-block rate
                grows linearly from 0 to this value
            layer_scale (bool): forwarded to every block
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): activation layer
            window_attn (bool, list): whether a block uses windowed
                attention; a non-list value is repeated for every block
            window_size (int, list): window size per block; a non-list value
                is repeated for every block
            pretrained: (str): pretrained path
            with_cp (bool): forwarded to every block
            pre_norm (bool): if True, the patch embedding has no bias and a
                norm layer is applied before the blocks
            ffn_type (str): ``"mlp"`` or ``"swiglu"``
            memeff (bool): forwarded to every block
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.norm_layer = norm_layer
        self.act_layer = act_layer
        self.pretrain_size = img_size
        self.drop_path_rate = drop_path_rate
        self.drop_rate = drop_rate
        self.patch_size = patch_size

        window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn
        window_size = [window_size] * depth if not isinstance(window_size, list) else window_size
        # logging formats lazily with %-style placeholders; passing the value
        # without a placeholder produces a formatting error once INFO is enabled.
        logging.info("window attention: %s", window_attn)
        logging.info("window size: %s", window_size)
        logging.info("layer scale: %s", layer_scale)

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm
        )
        num_patches = self.patch_embed.num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN}

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(
            *[
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    windowed=window_attn[i],
                    window_size=window_size[i],
                    layer_scale=layer_scale,
                    with_cp=with_cp,
                    ffn_layer=ffn_types[ffn_type],
                    memeff=memeff,
                )
                for i in range(depth)
            ]
        )

        # self.norm = norm_layer(embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # For CLIP
        if pre_norm:
            norm_pre = norm_layer(embed_dim)
            self.norm_pre = norm_pre
        else:
            self.norm_pre = nn.Identity()

        self.init_weights(pretrained)

    def init_weights(self, pretrained=None):
        """Load ``pretrained`` non-strictly when it is a string; otherwise do nothing."""
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)

    def forward_features(self, x):
        """Embed ``x``, prepend the class token and run all blocks.

        A final ``self.norm`` is applied only when such an attribute exists.
        This class does not create one, and the previous unconditional call
        raised ``AttributeError``.
        """
        x, H, W = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        # For CLIP
        x = self.norm_pre(x)

        for blk in self.blocks:
            x = blk(x, H, W)

        final_norm = getattr(self, "norm", None)
        if final_norm is not None:
            x = final_norm(x)
        return x

    def forward(self, x):
        """Return the output of ``forward_features``."""
        x = self.forward_features(x)
        return x

    @staticmethod
    def resize_pos_embed(pos_embed, input_shpae, pos_shape, mode):
        """Resize pos_embed weights.

        The first entry along dim 1 is kept as-is; the trailing
        ``pos_h * pos_w`` entries are reshaped to a grid and resized with
        ``resize``.

        Args:
            pos_embed (torch.Tensor): Position embedding weights.
            input_shpae (tuple): Tuple for (downsampled input image height,
                downsampled input image width).
            pos_shape (tuple): The resolution of downsampled origin training
                image.
            mode (str): Algorithm used for upsampling:
                ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
                ``'trilinear'``. Required; there is no default.

        Return:
            torch.Tensor: The resized pos_embed of shape [B, L_new, C]
        """
        assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
        pos_h, pos_w = pos_shape
        # keep dim for easy deployment
        cls_token_weight = pos_embed[:, 0:1]
        pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
        pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
        pos_embed_weight = resize(pos_embed_weight, size=input_shpae, align_corners=False, mode=mode)
        pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
        pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
        return pos_embed

View File

@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import BACKBONES
from torch.nn.init import normal_
from ...ops.modules import MSDeformAttn
from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
from .vit import TIMMVisionTransformer
@BACKBONES.register_module()
class ViTAdapter(TIMMVisionTransformer):
    """ViT backbone with a convolutional spatial prior and interaction blocks.

    ``forward`` returns four normalised feature maps with spatial sizes
    ``(4*H_c, 4*W_c)``, ``(2*H_c, 2*W_c)``, ``(H_c, W_c)`` and
    ``(H_c // 2, W_c // 2)``, where ``H_c`` and ``W_c`` are the input height
    and width integer-divided by 16.
    """

    def __init__(
        self,
        pretrain_size=224,
        num_heads=12,
        conv_inplane=64,
        n_points=4,
        deform_num_heads=6,
        init_values=0.0,
        interaction_indexes=None,
        with_cffn=True,
        cffn_ratio=0.25,
        deform_ratio=1.0,
        add_vit_feature=True,
        pretrained=None,
        use_extra_extractor=True,
        freeze_vit=False,
        use_cls=True,
        with_cp=False,
        *args,
        **kwargs
    ):
        super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs)

        # Freezing happens before the adapter modules exist, so only the
        # parameters created by the parent class are affected.
        if freeze_vit:
            for param in self.parameters():
                param.requires_grad = False

        # self.num_classes = 80
        self.use_cls = use_cls
        if not self.use_cls:
            self.cls_token = None
        self.num_block = len(self.blocks)
        self.pretrain_size = (pretrain_size, pretrain_size)
        self.interaction_indexes = interaction_indexes
        self.add_vit_feature = add_vit_feature
        embed_dim = self.embed_dim

        block_fn = InteractionBlockWithCls if use_cls else InteractionBlock

        self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
        self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)

        last_interaction = len(interaction_indexes) - 1
        interaction_blocks = [
            block_fn(
                dim=embed_dim,
                num_heads=deform_num_heads,
                n_points=n_points,
                init_values=init_values,
                drop_path=self.drop_path_rate,
                norm_layer=self.norm_layer,
                with_cffn=with_cffn,
                cffn_ratio=cffn_ratio,
                deform_ratio=deform_ratio,
                # Only the final interaction may carry the extra extractors.
                extra_extractor=((i == last_interaction) and use_extra_extractor),
                with_cp=with_cp,
            )
            for i in range(len(interaction_indexes))
        ]
        self.interactions = nn.Sequential(*interaction_blocks)

        self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
        self.norm1 = nn.SyncBatchNorm(embed_dim)
        self.norm2 = nn.SyncBatchNorm(embed_dim)
        self.norm3 = nn.SyncBatchNorm(embed_dim)
        self.norm4 = nn.SyncBatchNorm(embed_dim)

        # Initialise only the newly added adapter modules, then reset every
        # deformable attention module in the whole model.
        self.up.apply(self._init_weights)
        self.spm.apply(self._init_weights)
        self.interactions.apply(self._init_weights)
        self.apply(self._init_deform_weights)
        normal_(self.level_embed)

    def _init_weights(self, m):
        """Initialise ``m`` according to its layer type."""
        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def _get_pos_embed(self, pos_embed, H, W):
        """Bicubically resize the patch position embedding to an ``H`` x ``W`` grid."""
        grid_h = self.pretrain_size[0] // self.patch_size
        grid_w = self.pretrain_size[1] // self.patch_size
        as_map = pos_embed.reshape(1, grid_h, grid_w, -1).permute(0, 3, 1, 2)
        resized = F.interpolate(as_map, size=(H, W), mode="bicubic", align_corners=False)
        return resized.reshape(1, -1, H * W).permute(0, 2, 1)

    def _init_deform_weights(self, m):
        """Reset the parameters of ``m`` when it is an ``MSDeformAttn``."""
        if isinstance(m, MSDeformAttn):
            m._reset_parameters()

    def _add_level_embed(self, c2, c3, c4):
        """Add one row of ``level_embed`` to each of the three inputs."""
        return (
            c2 + self.level_embed[0],
            c3 + self.level_embed[1],
            c4 + self.level_embed[2],
        )

    def forward(self, x):
        """Return the list ``[f1, f2, f3, f4]`` of feature maps for the image batch ``x``."""
        deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)

        # SPM forward
        c1, c2, c3, c4 = self.spm(x)
        c2, c3, c4 = self._add_level_embed(c2, c3, c4)
        # Remember the per-level lengths so c can be split again later.
        len_c2, len_c3 = c2.size(1), c3.size(1)
        c = torch.cat([c2, c3, c4], dim=1)

        # Patch Embedding forward
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        x, H_toks, W_toks = self.patch_embed(x)
        bs, _, dim = x.shape
        pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
        if self.use_cls:
            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_token, x), dim=1)
            pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
        x = self.pos_drop(x + pos_embed)

        # For CLIP
        x = self.norm_pre(x)

        # Interaction
        if self.use_cls:
            cls, x = x[:, :1], x[:, 1:]
        outs = []
        for layer, indexes in zip(self.interactions, self.interaction_indexes):
            vit_blocks = self.blocks[indexes[0] : indexes[-1] + 1]
            if self.use_cls:
                x, c, cls = layer(
                    x, c, cls, vit_blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks
                )
            else:
                x, c = layer(x, c, vit_blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks)
            outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())

        # Split & Reshape
        c2 = c[:, 0:len_c2, :]
        c3 = c[:, len_c2 : len_c2 + len_c3, :]
        c4 = c[:, len_c2 + len_c3 :, :]

        c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
        c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
        c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
        c1 = self.up(c2) + c1

        if self.add_vit_feature:
            # Exactly four interaction outputs are expected here.
            x1, x2, x3, x4 = outs
            target_sizes = (
                (4 * H_c, 4 * W_c),
                (2 * H_c, 2 * W_c),
                (1 * H_c, 1 * W_c),
                (H_c // 2, W_c // 2),
            )
            x1, x2, x3, x4 = (
                F.interpolate(feat, size=size, mode="bilinear", align_corners=False)
                for feat, size in zip((x1, x2, x3, x4), target_sizes)
            )
            c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4

        # Final Norm
        f1 = self.norm1(c1)
        f2 = self.norm2(c2)
        f3 = self.norm3(c3)
        f4 = self.norm4(c4)
        return [f1, f2, f3, f4]

View File

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from mmcv.utils import Registry
# Registry queried by build_transformer() in this module.
TRANSFORMER = Registry("Transformer")
# Registry queried by build_assigner() in this module.
MASK_ASSIGNERS = Registry("mask_assigner")
# Registry queried by build_match_cost() in this module.
MATCH_COST = Registry("match_cost")
def build_match_cost(cfg):
    """Build a match cost from ``cfg`` through the ``MATCH_COST`` registry."""
    match_cost = MATCH_COST.build(cfg)
    return match_cost
def build_assigner(cfg):
    """Build an assigner from ``cfg`` through the ``MASK_ASSIGNERS`` registry."""
    assigner = MASK_ASSIGNERS.build(cfg)
    return assigner
def build_transformer(cfg):
    """Build a transformer from ``cfg`` through the ``TRANSFORMER`` registry."""
    transformer = TRANSFORMER.build(cfg)
    return transformer

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .mask2former_head import Mask2FormerHead

View File

@ -0,0 +1,544 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
from mmcv.ops import point_sample
from mmcv.runner import ModuleList, force_fp32
from mmseg.models.builder import HEADS, build_loss
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from ...core import build_sampler, multi_apply, reduce_mean
from ..builder import build_assigner
from ..utils import get_uncertain_point_coords_with_randomness
@HEADS.register_module()
class Mask2FormerHead(BaseDecodeHead):
"""Implements the Mask2Former head.
See `Masked-attention Mask Transformer for Universal Image
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
Args:
in_channels (list[int]): Number of channels in the input feature map.
feat_channels (int): Number of channels for features.
out_channels (int): Number of channels for output.
num_things_classes (int): Number of things.
num_stuff_classes (int): Number of stuff.
num_queries (int): Number of query in Transformer decoder.
pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
decoder. Defaults to None.
enforce_decoder_input_project (bool, optional): Whether to add
a layer to change the embed_dim of transformer encoder in
pixel decoder to the embed_dim of transformer decoder.
Defaults to False.
transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
transformer decoder. Defaults to None.
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
transformer decoder position encoding. Defaults to None.
loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
loss. Defaults to None.
loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
Defaults to None.
loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
Defaults to None.
train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
Mask2Former head.
test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
Mask2Former head.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
    self,
    in_channels,
    feat_channels,
    out_channels,
    num_things_classes=80,
    num_stuff_classes=53,
    num_queries=100,
    num_transformer_feat_level=3,
    pixel_decoder=None,
    enforce_decoder_input_project=False,
    transformer_decoder=None,
    positional_encoding=None,
    loss_cls=None,
    loss_mask=None,
    loss_dice=None,
    train_cfg=None,
    test_cfg=None,
    init_cfg=None,
    **kwargs,
):
    """Build the pixel decoder, transformer decoder, embeddings, heads and losses.

    ``pixel_decoder``, ``transformer_decoder`` and ``loss_cls`` are accessed
    by attribute here, so they have to be attribute-style configs rather
    than ``None``. The assigner, sampler and point-sampling settings are
    only created when ``train_cfg`` is truthy.
    """
    total_classes = num_things_classes + num_stuff_classes
    super(Mask2FormerHead, self).__init__(
        in_channels=in_channels,
        channels=feat_channels,
        num_classes=total_classes,
        init_cfg=init_cfg,
        input_transform="multiple_select",
        **kwargs,
    )
    self.num_things_classes = num_things_classes
    self.num_stuff_classes = num_stuff_classes
    self.num_classes = self.num_things_classes + self.num_stuff_classes
    self.num_queries = num_queries
    self.num_transformer_feat_level = num_transformer_feat_level
    self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
    self.num_transformer_decoder_layers = transformer_decoder.num_layers
    assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level

    # Work on a copy so the caller's config is not modified.
    pixel_decoder_cfg = copy.deepcopy(pixel_decoder)
    pixel_decoder_cfg.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
    self.pixel_decoder = build_plugin_layer(pixel_decoder_cfg)[1]
    self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
    self.decoder_embed_dims = self.transformer_decoder.embed_dims

    # One input projection per feature level, from low resolution to high
    # resolution; a 1x1 conv when the widths differ or a projection is
    # enforced, otherwise an identity.
    needs_projection = self.decoder_embed_dims != feat_channels or enforce_decoder_input_project
    self.decoder_input_projs = ModuleList()
    for _ in range(num_transformer_feat_level):
        if needs_projection:
            projection = Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1)
        else:
            projection = nn.Identity()
        self.decoder_input_projs.append(projection)

    self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
    self.query_embed = nn.Embedding(self.num_queries, feat_channels)
    self.query_feat = nn.Embedding(self.num_queries, feat_channels)
    # from low resolution to high resolution
    self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)

    # One extra output on top of the configured classes.
    self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
    self.mask_embed = nn.Sequential(
        nn.Linear(feat_channels, feat_channels),
        nn.ReLU(inplace=True),
        nn.Linear(feat_channels, feat_channels),
        nn.ReLU(inplace=True),
        nn.Linear(feat_channels, out_channels),
    )
    self.conv_seg = None  # fix a bug here (conv_seg is not used)

    self.test_cfg = test_cfg
    self.train_cfg = train_cfg
    if train_cfg:
        self.assigner = build_assigner(self.train_cfg.assigner)
        self.sampler = build_sampler(self.train_cfg.sampler, context=self)
        self.num_points = self.train_cfg.get("num_points", 12544)
        self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
        self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)

    self.class_weight = loss_cls.class_weight
    self.loss_cls = build_loss(loss_cls)
    self.loss_mask = build_loss(loss_mask)
    self.loss_dice = build_loss(loss_dice)
def init_weights(self):
    """Initialise the learnable parts of the head.

    * Every ``Conv2d`` entry of ``decoder_input_projs`` is initialised with
      ``caffe2_xavier_init`` and a zero bias; other entries are left untouched.
    * The pixel decoder initialises itself through its own ``init_weights``.
    * Every transformer-decoder parameter with more than one dimension gets
      Xavier-normal initialisation; 0-D/1-D parameters are left as they are.
    """
    for proj in self.decoder_input_projs:
        if not isinstance(proj, Conv2d):
            continue
        caffe2_xavier_init(proj, bias=0)

    self.pixel_decoder.init_weights()

    for param in self.transformer_decoder.parameters():
        if param.dim() <= 1:
            continue
        nn.init.xavier_normal_(param)
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
    """Build classification and mask targets for every image of one decoder layer.

    The per-image work is delegated to ``_get_target_single`` through
    ``multi_apply``; this method only gathers the per-image results and counts
    the positive / negative samples.

    Args:
        cls_scores_list (list[Tensor]): Per-image classification logits of a
            single decoder layer, each (num_queries, cls_out_channels).
        mask_preds_list (list[Tensor]): Per-image mask logits of a single
            decoder layer, each (num_queries, h, w).
        gt_labels_list (list[Tensor]): Per-image ground-truth class indices,
            each (n, ).
        gt_masks_list (list[Tensor]): Per-image ground-truth masks, each
            (n, h, w).
        img_metas (list[dict]): Per-image meta information.

    Returns:
        tuple: ``(labels_list, label_weights_list, mask_targets_list,
        mask_weights_list, num_total_pos, num_total_neg)`` where the four
        lists hold one tensor per image and the two counts are the total
        number of positive / negative indices over all images.
    """
    per_image_targets = multi_apply(
        self._get_target_single,
        cls_scores_list,
        mask_preds_list,
        gt_labels_list,
        gt_masks_list,
        img_metas,
    )
    (
        labels_list,
        label_weights_list,
        mask_targets_list,
        mask_weights_list,
        pos_inds_list,
        neg_inds_list,
    ) = per_image_targets

    # Totals are plain Python ints summed over the batch.
    num_total_pos = sum(pos_inds.numel() for pos_inds in pos_inds_list)
    num_total_neg = sum(neg_inds.numel() for neg_inds in neg_inds_list)
    return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
    """Build classification and mask targets for a single image.

    Predictions and ground-truth masks are both sampled at the same set of
    ``self.num_points`` random locations; the assigner matches queries to
    ground truths on those sampled values, and the sampler turns the
    assignment into positive / negative query indices.

    Args:
        cls_score (Tensor): Classification logits of one decoder layer for
            one image, (num_queries, cls_out_channels).
        mask_pred (Tensor): Mask logits of one decoder layer for one image,
            (num_queries, h, w).
        gt_labels (Tensor): Ground-truth class indices, (num_gts, ).
        gt_masks (Tensor): Ground-truth masks, (num_gts, h, w).
        img_metas (dict): Image information, forwarded to the assigner.

    Returns:
        tuple[Tensor]: ``(labels, label_weights, mask_targets, mask_weights,
        pos_inds, neg_inds)``:

            - labels: (self.num_queries, ), ``self.num_classes`` everywhere
              except at positive queries, which carry the matched gt label.
            - label_weights: (self.num_queries, ), all ones.
            - mask_targets: the gt masks selected by the positive assignment.
            - mask_weights: (self.num_queries, ), 1.0 at positive queries and
              0 elsewhere.
            - pos_inds / neg_inds: indices reported by the sampler.
    """
    n_query = cls_score.shape[0]
    n_gt = gt_labels.shape[0]

    # One shared set of random points, tiled for the predictions and the gts.
    coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
    # shape (num_queries, num_points)
    pred_at_points = point_sample(mask_pred.unsqueeze(1), coords.repeat(n_query, 1, 1)).squeeze(1)
    # shape (num_gts, num_points)
    gt_at_points = point_sample(gt_masks.unsqueeze(1).float(), coords.repeat(n_gt, 1, 1)).squeeze(1)

    # Match on the sampled points, then sample on the full-resolution masks.
    assign_result = self.assigner.assign(cls_score, pred_at_points, gt_labels, gt_at_points, img_metas)
    sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
    pos_inds, neg_inds = sampling_result.pos_inds, sampling_result.neg_inds
    matched_gt_inds = sampling_result.pos_assigned_gt_inds

    # Classification target: every query starts with the fill value num_classes.
    labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
    labels[pos_inds] = gt_labels[matched_gt_inds]
    label_weights = gt_labels.new_ones((self.num_queries,))

    # Mask target: only matched queries contribute.
    mask_targets = gt_masks[matched_gt_inds]
    mask_weights = mask_pred.new_zeros((self.num_queries,))
    mask_weights[pos_inds] = 1.0

    return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)
def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
    """Compute the losses for the outputs of a single decoder layer.

    Args:
        cls_scores (Tensor): Classification logits of one decoder layer for
            all images, (batch_size, num_queries, cls_out_channels).
        mask_preds (Tensor): Mask logits of one decoder layer for all images,
            (batch_size, num_queries, h, w).
        gt_labels_list (list[Tensor]): Per-image ground-truth class indices,
            each (num_gts, ).
        gt_masks_list (list[Tensor]): Per-image ground-truth masks, each
            (num_gts, h, w).
        img_metas (list[dict]): Per-image meta information.

    Returns:
        tuple[Tensor]: ``(loss_cls, loss_mask, loss_dice)``.
    """
    # Split the batch into per-image tensors for target assignment.
    per_image_scores = list(cls_scores.unbind(0))
    per_image_masks = list(mask_preds.unbind(0))
    (
        labels_list,
        label_weights_list,
        mask_targets_list,
        mask_weights_list,
        num_total_pos,
        num_total_neg,
    ) = self.get_targets(per_image_scores, per_image_masks, gt_labels_list, gt_masks_list, img_metas)

    # (batch_size, num_queries) -> (batch_size * num_queries, )
    labels = torch.stack(labels_list, dim=0).flatten(0, 1)
    label_weights = torch.stack(label_weights_list, dim=0).flatten(0, 1)
    # (num_total_gts, h, w)
    mask_targets = torch.cat(mask_targets_list, dim=0)
    # (batch_size, num_queries)
    mask_weights = torch.stack(mask_weights_list, dim=0)

    # Classification loss, averaged by the summed class weight of the targets.
    cls_scores = cls_scores.flatten(0, 1)
    class_weight = cls_scores.new_tensor(self.class_weight)
    loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())

    # Number of matched masks, passed through reduce_mean and clamped to >= 1.
    num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
    num_total_masks = max(num_total_masks, 1)

    # Keep only matched queries:
    # (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
    mask_preds = mask_preds[mask_weights > 0]

    if mask_targets.shape[0] == 0:
        # Zero match: both mask losses are the sum over the (empty) selection.
        # NOTE(review): presumably this keeps the losses tied to the
        # predictions' graph instead of returning constants -- confirm.
        loss_dice = mask_preds.sum()
        loss_mask = mask_preds.sum()
        return loss_cls, loss_mask, loss_dice

    with torch.no_grad():
        # Point selection and target sampling need no gradient.
        points_coords = get_uncertain_point_coords_with_randomness(
            mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
        )
        # (num_total_gts, h, w) -> (num_total_gts, num_points)
        mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
    # Sampled outside no_grad so gradients reach the mask predictions.
    # (num_total_gts, h, w) -> (num_total_gts, num_points)
    mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)

    # Dice loss on the per-mask point sets.
    loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)

    # Mask loss on individual points:
    # preds -> (num_total_gts * num_points, 1), targets -> (num_total_gts * num_points, )
    flat_point_preds = mask_point_preds.reshape(-1, 1)
    flat_point_targets = mask_point_targets.reshape(-1)
    loss_mask = self.loss_mask(
        flat_point_preds, flat_point_targets, avg_factor=num_total_masks * self.num_points
    )
    return loss_cls, loss_mask, loss_dice
@force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
    """Compute the losses of every decoder layer.

    Args:
        all_cls_scores (Tensor): Classification scores of all decoder layers,
            (num_decoder, batch_size, num_queries, cls_out_channels).
        all_mask_preds (Tensor): Mask scores of all decoder layers,
            (num_decoder, batch_size, num_queries, h, w).
        gt_labels_list (list[Tensor]): Per-image ground-truth class indices,
            each (n, ).
        gt_masks_list (list[Tensor]): Per-image ground-truth masks, each
            (n, h, w).
        img_metas (list[dict]): Per-image meta information.

    Returns:
        dict[str, Tensor]: The last layer's losses under ``loss_cls``,
        ``loss_mask`` and ``loss_dice``; the losses of earlier layer ``k``
        under ``d{k}.loss_cls``, ``d{k}.loss_mask`` and ``d{k}.loss_dice``.
    """
    num_dec_layers = len(all_cls_scores)
    # The same ground truth is handed to every decoder layer.
    losses_cls, losses_mask, losses_dice = multi_apply(
        self.loss_single,
        all_cls_scores,
        all_mask_preds,
        [gt_labels_list] * num_dec_layers,
        [gt_masks_list] * num_dec_layers,
        [img_metas] * num_dec_layers,
    )

    # Losses of the last decoder layer use the un-prefixed keys.
    loss_dict = {
        "loss_cls": losses_cls[-1],
        "loss_mask": losses_mask[-1],
        "loss_dice": losses_dice[-1],
    }

    # Losses of the earlier layers are prefixed with their layer index.
    earlier_layers = zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1])
    for layer_idx, (cls_i, mask_i, dice_i) in enumerate(earlier_layers):
        loss_dict[f"d{layer_idx}.loss_cls"] = cls_i
        loss_dict[f"d{layer_idx}.loss_mask"] = mask_i
        loss_dict[f"d{layer_idx}.loss_dice"] = dice_i
    return loss_dict
def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
    """Turn decoder queries into class logits, mask logits and an attention mask.

    Called once on the initial queries and once after every decoder layer.

    Args:
        decoder_out (Tensor): Queries, (num_queries, batch_size, c).
        mask_feature (Tensor): Per-pixel features, (batch_size, c, h, w).
        attn_mask_target_size (tuple[int, int]): Spatial size the mask
            prediction is resized to before it becomes the attention mask.

    Returns:
        tuple: ``(cls_pred, mask_pred, attn_mask)``:

            - cls_pred (Tensor): (batch_size, num_queries, cls_out_channels).
            - mask_pred (Tensor): (batch_size, num_queries, h, w).
            - attn_mask (Tensor): boolean, detached,
              (batch_size * num_heads, num_queries, th * tw) where
              ``(th, tw) = attn_mask_target_size``. ``True`` marks positions
              whose mask probability is below 0.5.
    """
    normed = self.transformer_decoder.post_norm(decoder_out)
    # (num_queries, batch_size, c) -> (batch_size, num_queries, c)
    queries = normed.transpose(0, 1)

    # (batch_size, num_queries, cls_out_channels)
    cls_pred = self.cls_embed(queries)
    # (batch_size, num_queries, c)
    mask_embed = self.mask_embed(queries)
    # (batch_size, num_queries, h, w)
    mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)

    resized = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
    # (batch_size, num_queries, th, tw) -> (batch_size * num_heads, num_queries, th * tw)
    per_head = resized.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
    attn_mask = (per_head.sigmoid() < 0.5).detach()

    return cls_pred, mask_pred, attn_mask
def forward(self, feats, img_metas):
    """Run the pixel decoder and the masked-attention transformer decoder.

    Args:
        feats (list[Tensor]): Multi-scale features from the upstream
            network, handed to the pixel decoder unchanged.
        img_metas (list[dict]): Per-image information; only its length is
            used here, as the batch size.

    Returns:
        tuple: ``(cls_pred_list, mask_pred_list)``. Both lists hold
        ``num_transformer_decoder_layers + 1`` entries: the prediction made
        from the initial queries followed by one prediction per decoder
        layer, exactly as returned by ``forward_head``.
    """
    batch_size = len(img_metas)
    num_levels = self.num_transformer_feat_level
    mask_features, multi_scale_memorys = self.pixel_decoder(feats)

    # Per-level decoder inputs and positional encodings
    # (multi_scale_memorys goes from low resolution to high resolution).
    decoder_inputs = []
    decoder_positional_encodings = []
    for level in range(num_levels):
        memory = multi_scale_memorys[level]
        # (batch_size, c, h, w) -> (h*w, batch_size, c)
        level_input = self.decoder_input_projs[level](memory).flatten(2).permute(2, 0, 1)
        level_input = level_input + self.level_embed.weight[level].view(1, 1, -1)
        # All-False boolean mask of shape (batch_size, h, w) fed to the
        # positional encoding.
        padding_mask = level_input.new_zeros((batch_size,) + memory.shape[-2:], dtype=torch.bool)
        # (batch_size, c, h, w) -> (h*w, batch_size, c)
        level_pos = self.decoder_positional_encoding(padding_mask).flatten(2).permute(2, 0, 1)
        decoder_inputs.append(level_input)
        decoder_positional_encodings.append(level_pos)

    # (num_queries, c) -> (num_queries, batch_size, c)
    query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
    query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))

    # Prediction from the initial queries, before any decoder layer.
    cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
    cls_pred_list = [cls_pred]
    mask_pred_list = [mask_pred]

    for layer_idx in range(self.num_transformer_decoder_layers):
        level_idx = layer_idx % num_levels
        # A row that is entirely True is reset to all False.
        fully_masked = torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])
        attn_mask[fully_masked] = False

        # cross_attn + self_attn; only the first attention gets a mask.
        decoder_layer = self.transformer_decoder.layers[layer_idx]
        query_feat = decoder_layer(
            query=query_feat,
            key=decoder_inputs[level_idx],
            value=decoder_inputs[level_idx],
            query_pos=query_embed,
            key_pos=decoder_positional_encodings[level_idx],
            attn_masks=[attn_mask, None],
            query_key_padding_mask=None,
            # here we do not apply masking on padded region
            key_padding_mask=None,
        )

        # The next attention mask is sized for the level the next layer uses.
        next_size = multi_scale_memorys[(layer_idx + 1) % num_levels].shape[-2:]
        cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, next_size)
        cls_pred_list.append(cls_pred)
        mask_pred_list.append(mask_pred)

    return cls_pred_list, mask_pred_list
def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
    """Forward pass in training mode: predict, then compute the losses.

    Args:
        x (list[Tensor]): Multi-level features from the upstream network.
        img_metas (list[dict]): Per-image information.
        gt_semantic_seg (list[Tensor]): Ground-truth semantic segmentation.
            Accepted as part of the signature but not read by this method.
        gt_labels (list[Tensor]): Per-image ground-truth labels, each
            (num_gts, ).
        gt_masks (list): Per-image ground-truth masks, each
            (num_gts, h, w).

    Returns:
        dict[str, Tensor]: Loss components as returned by ``self.loss``.
    """
    # Predictions of every decoder layer, then their losses.
    predictions = self(x, img_metas)
    all_cls_scores, all_mask_preds = predictions
    return self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas)
def forward_test(self, inputs, img_metas, test_cfg):
    """Predict semantic segmentation scores without test-time augmentation.

    Only the output of the last decoder layer is used. For every class the
    per-query mask probabilities are summed, weighted by the probability the
    query assigns to that class.

    Args:
        inputs (list[Tensor]): Multi-level features from the upstream
            network.
        img_metas (list[dict]): Per-image information; only forwarded to
            ``self(...)``. No key (such as ``ori_shape``) is required by this
            method itself.
        test_cfg (dict): Testing config. Accepted as part of the signature
            but not read by this method.

    Returns:
        Tensor: Per-class scores of shape (batch_size, num_classes, h, w),
        at the resolution of the mask predictions. No resizing happens here.
    """
    all_cls_scores, all_mask_preds = self(inputs, img_metas)
    cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]

    # Semantic inference: drop the last class channel after the softmax,
    # then mix the query masks with the remaining class probabilities.
    cls_prob = F.softmax(cls_score, dim=-1)[..., :-1]
    mask_prob = mask_pred.sigmoid()
    seg_mask = torch.einsum("bqc,bqhw->bchw", cls_prob, mask_prob)
    return seg_mask

View File

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy
from .dice_loss import DiceLoss
from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost

View File

@ -0,0 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss
def cross_entropy(
    pred,
    label,
    weight=None,
    class_weight=None,
    reduction="mean",
    avg_factor=None,
    ignore_index=-100,
    avg_non_ignore=False,
):
    """Wrapper around :func:`F.cross_entropy` with weighted reduction.

    The element-wise loss is computed with ``reduction="none"`` and then
    reduced by ``weight_reduce_loss``.

    Args:
        pred (torch.Tensor): Class logits, passed to ``F.cross_entropy``
            unchanged.
        label (torch.Tensor): Target class indices.
        weight (torch.Tensor, optional): Sample-wise loss weight; cast to
            float before the reduction. Default: None.
        class_weight (torch.Tensor, optional): Per-class rescaling weight
            passed as ``weight`` to ``F.cross_entropy``. Default: None.
        reduction (str, optional): 'none', 'mean' or 'sum'. Default: 'mean'.
        avg_factor (int, optional): Factor used to average the loss.
            Default: None.
        ignore_index (int): Target value that is ignored. Default: -100.
        avg_non_ignore (bool): If True, ``reduction`` is 'mean' and no
            ``avg_factor`` is given, the loss is averaged over the number of
            targets different from ``ignore_index``. Default: False.

    Returns:
        torch.Tensor: The reduced loss.
    """
    # Element-wise loss; class_weight, if given, must be a tensor of size C.
    elementwise = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index)

    # Average over non-ignored elements only when explicitly requested, which
    # mirrors what PyTorch's own 'mean' reduction does.
    wants_non_ignore_mean = avg_non_ignore and reduction == "mean"
    if avg_factor is None and wants_non_ignore_mean:
        num_ignored = (label == ignore_index).sum().item()
        avg_factor = label.numel() - num_ignored

    sample_weight = None if weight is None else weight.float()
    return weight_reduce_loss(elementwise, weight=sample_weight, reduction=reduction, avg_factor=avg_factor)
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
"""Expand onehot labels to match the size of prediction."""
bin_labels = labels.new_zeros(target_shape)
valid_mask = (labels >= 0) & (labels != ignore_index)
inds = torch.nonzero(valid_mask, as_tuple=True)
if inds[0].numel() > 0:
if labels.dim() == 3:
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
else:
bin_labels[inds[0], labels[valid_mask]] = 1
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
if label_weights is None:
bin_label_weights = valid_mask
else:
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
bin_label_weights = bin_label_weights * valid_mask
return bin_labels, bin_label_weights, valid_mask
def binary_cross_entropy(
    pred,
    label,
    weight=None,
    reduction="mean",
    avg_factor=None,
    class_weight=None,
    ignore_index=-100,
    avg_non_ignore=False,
    **kwargs,
):
    """Calculate the binary cross-entropy loss on logits.

    Args:
        pred (torch.Tensor): Logits. Supported layouts are [N, C] / [N, C, H, W]
            with index labels of shape [N] / [N, H, W], the single-channel
            variants [N, 1] / [N, 1, H, W] with binary labels, or any shape
            equal to the shape of ``label``.
        label (torch.Tensor): Targets. Values < 0 or equal to
            ``ignore_index`` are invalid and get zero weight.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): "none", "mean" or "sum".
        avg_factor (int, optional): Factor used to average the loss.
            Defaults to None.
        class_weight (torch.Tensor, optional): Forwarded as ``pos_weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (int): The label index to be ignored. Default: -100.
        avg_non_ignore (bool): If True, ``reduction`` is "mean" and no
            ``avg_factor`` is given, average over valid elements only.
            Default: False.
        **kwargs: Accepted for interface compatibility and not used.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if pred.size(1) == 1:
        # Binary segmentation: pred is [N, 1, ...] and label is [N, ...].
        # Ignored labels (e.g. 255) are excluded from the "at most 2 classes"
        # check, otherwise any ignored pixel would trip the assertion.
        checked = label[label != ignore_index]
        assert (
            checked.numel() == 0 or checked.max() <= 1
        ), "For pred with shape [N, 1, H, W], its label must have at most 2 classes"
        # Drop only the channel dimension. A bare squeeze() would also drop a
        # batch (or spatial) dimension of size 1 and break the shape checks below.
        pred = pred.squeeze(1)

    if pred.dim() != label.dim():
        assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
            "Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported"
        )
        # `weight` returned from `_expand_onehot_labels` already has the
        # invalid (ignored) positions zeroed out.
        label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
    else:
        # Same-shape targets: mask out the ignored elements through the weight.
        valid_mask = ((label >= 0) & (label != ignore_index)).float()
        if weight is not None:
            weight = weight * valid_mask
        else:
            weight = valid_mask

    # Average over non-ignored and valid elements only when requested.
    if reduction == "mean" and avg_factor is None and avg_non_ignore:
        avg_factor = valid_mask.sum().item()

    loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
    # Reduce the weighted element-wise loss.
    loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
    return loss
def mask_cross_entropy(
    pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
):
    """Calculate the binary cross-entropy loss for class-specific masks.

    For every row ``i`` of ``pred`` the slice ``pred[i, label[i]]`` is
    selected and compared against ``target`` with
    ``F.binary_cross_entropy_with_logits``.

    Args:
        pred (torch.Tensor): Mask logits with the class dimension at
            position 1.
        target (torch.Tensor): The learning target of the selected slices.
        label (torch.Tensor): Class index per row of ``pred``, used to pick
            the mask of the class each object belongs to.
        reduction (str, optional): Must be "mean".
        avg_factor (None): Must be None.
        class_weight (torch.Tensor, optional): Forwarded as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder for interface consistency; must be
            None.
        **kwargs: Accepted for interface compatibility and not used.

    Returns:
        torch.Tensor: The mean loss as a tensor of shape (1, ).
    """
    assert ignore_index is None, "BCE loss does not support ignore_index"
    assert reduction == "mean" and avg_factor is None

    row_ids = torch.arange(pred.size(0), dtype=torch.long, device=pred.device)
    # Pick, for every row, the slice of its own class.
    selected = pred[row_ids, label].squeeze(1)
    mean_loss = F.binary_cross_entropy_with_logits(selected, target, weight=class_weight, reduction="mean")
    return mean_loss.unsqueeze(0)
@LOSSES.register_module(force=True)
class CrossEntropyLoss(nn.Module):
    """Configurable cross-entropy loss module.

    Depending on the flags, the loss is computed by ``binary_cross_entropy``
    (``use_sigmoid``), ``mask_cross_entropy`` (``use_mask``) or
    ``cross_entropy`` (neither).

    Args:
        use_sigmoid (bool, optional): Use the sigmoid-based binary
            cross-entropy instead of softmax. Defaults to False.
        use_mask (bool, optional): Use the mask cross-entropy loss. Cannot be
            combined with ``use_sigmoid``. Defaults to False.
        reduction (str, optional): "none", "mean" or "sum".
            Defaults to 'mean'.
        class_weight (list[float] | str, optional): Weight of each class,
            resolved through ``get_class_weight``. Defaults to None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        loss_name (str, optional): Name of the loss item, exposed through the
            ``loss_name`` property. Defaults to 'loss_ce'.
        avg_non_ignore (bool): Whether the loss is only averaged over
            non-ignored targets. Default: False.
    """

    def __init__(
        self,
        use_sigmoid=False,
        use_mask=False,
        reduction="mean",
        class_weight=None,
        loss_weight=1.0,
        loss_name="loss_ce",
        avg_non_ignore=False,
    ):
        super().__init__()
        # The two alternative criteria are mutually exclusive.
        assert (use_sigmoid is False) or (use_mask is False)

        self.use_sigmoid = use_sigmoid
        self.use_mask = use_mask
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = get_class_weight(class_weight)
        self.avg_non_ignore = avg_non_ignore
        self._loss_name = loss_name

        if self.reduction == "mean" and not self.avg_non_ignore:
            warnings.warn(
                "Default ``avg_non_ignore`` is False, if you would like to "
                "ignore the certain label and average loss over non-ignore "
                "labels, which is the same with PyTorch official "
                "cross_entropy, set ``avg_non_ignore=True``."
            )

        # Select the functional criterion once, at construction time.
        if use_sigmoid:
            criterion = binary_cross_entropy
        elif use_mask:
            criterion = mask_cross_entropy
        else:
            criterion = cross_entropy
        self.cls_criterion = criterion

    def extra_repr(self):
        """Return the extra representation string of the module."""
        return f"avg_non_ignore={self.avg_non_ignore}"

    def forward(
        self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs
    ):
        """Compute the weighted loss.

        Args:
            cls_score (torch.Tensor): The prediction.
            label (torch.Tensor): The target.
            weight (torch.Tensor, optional): Sample-wise loss weight.
            avg_factor (int, optional): Factor used to average the loss.
            reduction_override (str, optional): "none", "mean" or "sum";
                overrides the configured reduction when given.
            ignore_index (int): Label value to ignore. Default: -100.
            **kwargs: Forwarded to the selected criterion.

        Returns:
            torch.Tensor: ``loss_weight`` times the criterion output.
        """
        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override or self.reduction

        class_weight = None
        if self.class_weight is not None:
            # Materialise the weights on the prediction's device / dtype.
            class_weight = cls_score.new_tensor(self.class_weight)

        # Note: for BCE loss, label < 0 is invalid.
        # NOTE(review): with use_mask=True the third positional argument lands
        # on mask_cross_entropy's `label` parameter and ignore_index is
        # forwarded as -100 by default, which that function asserts against --
        # confirm whether that path is meant to be usable.
        raw_loss = self.cls_criterion(
            cls_score,
            label,
            weight,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            avg_non_ignore=self.avg_non_ignore,
            ignore_index=ignore_index,
            **kwargs,
        )
        return self.loss_weight * raw_loss

    @property
    def loss_name(self):
        """Name of this loss item.

        The name is used to combine different loss items by a simple sum. To
        be included in the backward graph, the name must start with
        ``loss_``.

        Returns:
            str: The name of this loss item.
        """
        return self._loss_name

View File

@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from mmseg.models.builder import LOSSES
from mmseg.models.losses.utils import weight_reduce_loss
def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
    """Calculate the dice loss proposed in
    `V-Net: Fully Convolutional Neural Networks for Volumetric
    Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.

    Both denominator terms use squared values, and ``eps`` is added to each
    of them.

    Args:
        pred (torch.Tensor): The prediction, shape (n, *).
        target (torch.Tensor): The learning target, shape (n, *), same shape
            as ``pred``.
        weight (torch.Tensor, optional): Per-prediction loss weight, shape
            (n,). Defaults to None.
        eps (float): Avoids dividing by zero. Default: 1e-3.
        reduction (str, optional): "none", "mean" or "sum".
            Defaults to 'mean'.
        avg_factor (int, optional): Factor used to average the loss.
            Defaults to None.

    Returns:
        torch.Tensor: The reduced loss.
    """
    flat_pred = pred.flatten(1)
    flat_target = target.flatten(1).float()

    overlap = torch.sum(flat_pred * flat_target, 1)
    pred_sq = torch.sum(flat_pred * flat_pred, 1) + eps
    target_sq = torch.sum(flat_target * flat_target, 1) + eps
    loss = 1 - (2 * overlap) / (pred_sq + target_sq)

    if weight is not None:
        # One weight per prediction.
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    return weight_reduce_loss(loss, weight, reduction, avg_factor)
def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
    """Calculate the naive dice loss.

    Unlike ``dice_loss``, the denominator sums the first power of the values
    instead of their squares, and ``eps`` is added to both numerator and
    denominator.

    Args:
        pred (torch.Tensor): The prediction, shape (n, *).
        target (torch.Tensor): The learning target, shape (n, *), same shape
            as ``pred``.
        weight (torch.Tensor, optional): Per-prediction loss weight, shape
            (n,). Defaults to None.
        eps (float): Avoids dividing by zero. Default: 1e-3.
        reduction (str, optional): "none", "mean" or "sum".
            Defaults to 'mean'.
        avg_factor (int, optional): Factor used to average the loss.
            Defaults to None.

    Returns:
        torch.Tensor: The reduced loss.
    """
    flat_pred = pred.flatten(1)
    flat_target = target.flatten(1).float()

    overlap = torch.sum(flat_pred * flat_target, 1)
    pred_total = torch.sum(flat_pred, 1)
    target_total = torch.sum(flat_target, 1)
    loss = 1 - (2 * overlap + eps) / (pred_total + target_total + eps)

    if weight is not None:
        # One weight per prediction.
        assert weight.ndim == loss.ndim
        assert len(weight) == len(pred)
    return weight_reduce_loss(loss, weight, reduction, avg_factor)
@LOSSES.register_module(force=True)
class DiceLoss(nn.Module):
    """Dice loss module supporting two formulations.

    - The one proposed in `V-Net: Fully Convolutional Neural Networks for
      Volumetric Medical Image Segmentation
      <https://arxiv.org/abs/1606.04797>`_ (``dice_loss``).
    - The naive dice loss, in which the denominator uses the first power
      instead of the second power (``naive_dice_loss``).

    Args:
        use_sigmoid (bool, optional): Whether the internal activation is a
            sigmoid. Any other activation is not implemented.
            Defaults to True.
        activate (bool): Whether to apply the activation to the predictions
            inside ``forward``. When False the predictions are used as given.
            Defaults to True.
        reduction (str, optional): "none", "mean" or "sum".
            Defaults to 'mean'.
        naive_dice (bool, optional): If False, use the V-Net dice loss,
            otherwise the naive dice loss. Defaults to False.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
        eps (float): Avoids dividing by zero. Defaults to 1e-3.
    """

    def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.activate = activate
        self.reduction = reduction
        self.naive_dice = naive_dice
        self.loss_weight = loss_weight
        self.eps = eps

    def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None):
        """Compute the weighted dice loss.

        Args:
            pred (torch.Tensor): The prediction, shape (n, *).
            target (torch.Tensor): The target, shape (n, *), same shape as
                ``pred``.
            weight (torch.Tensor, optional): Per-prediction loss weight,
                shape (n,). Defaults to None.
            reduction_override (str, optional): "none", "mean" or "sum";
                overrides the configured reduction when given.
            avg_factor (int, optional): Factor used to average the loss.
                Defaults to None.

        Returns:
            torch.Tensor: The calculated loss.

        Raises:
            NotImplementedError: If ``activate`` is True and ``use_sigmoid``
                is False.
        """
        assert reduction_override in (None, "none", "mean", "sum")
        reduction = reduction_override or self.reduction

        if self.activate:
            if not self.use_sigmoid:
                raise NotImplementedError
            pred = pred.sigmoid()

        # Both formulations share the same call signature.
        loss_fn = naive_dice_loss if self.naive_dice else dice_loss
        return self.loss_weight * loss_fn(
            pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
        )

View File

@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from ..builder import MATCH_COST
@MATCH_COST.register_module()
class ClassificationCost:
    """Softmax classification matching cost.

    Borrowed from
    mmdet.core.bbox.match_costs.match_cost.ClassificationCost.

    Args:
        weight (int | float, optional): Multiplier applied to the cost.

    Examples:
        >>> import torch
        >>> self = ClassificationCost()
        >>> cls_pred = torch.zeros(2, 3)
        >>> gt_labels = torch.tensor([0, 2])
        >>> self(cls_pred, gt_labels)
        tensor([[-0.3333, -0.3333],
                [-0.3333, -0.3333]])
    """

    def __init__(self, weight=1.0):
        self.weight = weight

    def __call__(self, cls_pred, gt_labels):
        """Compute the pairwise classification cost.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            gt_labels (Tensor): Ground-truth labels, shape (num_gt,).

        Returns:
            torch.Tensor: Weighted cost matrix of shape (num_query, num_gt).
        """
        # Following the official DETR repo, the cost approximates the NLL
        # with 1 - prob[gt_label]; the constant 1 does not change the
        # matching and is omitted.
        probs = cls_pred.softmax(-1)
        return -probs[:, gt_labels] * self.weight
@MATCH_COST.register_module()
class DiceCost:
    """Mask matching cost based on the (naive) dice loss.

    Args:
        weight (int | float, optional): Multiplier applied to the cost.
            Defaults to 1.
        pred_act (bool, optional): Whether to apply a sigmoid to the mask
            predictions before computing the cost. Defaults to False.
        eps (float, optional): Smoothing term added to numerator and
            denominator. Defaults to 1e-3.
    """

    def __init__(self, weight=1.0, pred_act=False, eps=1e-3):
        self.weight = weight
        self.pred_act = pred_act
        self.eps = eps

    def binary_mask_dice_loss(self, mask_preds, gt_masks):
        """Compute the pairwise dice loss between predictions and ground truths.

        Args:
            mask_preds (Tensor): Mask predictions, shape (N1, H, W).
            gt_masks (Tensor): Ground truth, shape (N2, H, W), storing 0 for
                the negative class and 1 for the positive class.

        Returns:
            Tensor: Dice cost matrix of shape (N1, N2).
        """
        flat_preds = mask_preds.reshape((mask_preds.shape[0], -1))
        flat_gts = gt_masks.reshape((gt_masks.shape[0], -1)).float()

        # Every prediction against every ground truth.
        overlap = 2 * torch.einsum("nc,mc->nm", flat_preds, flat_gts)
        totals = flat_preds.sum(-1)[:, None] + flat_gts.sum(-1)[None, :]
        return 1 - (overlap + self.eps) / (totals + self.eps)

    def __call__(self, mask_preds, gt_masks):
        """Compute the weighted dice cost matrix.

        Args:
            mask_preds (Tensor): Mask prediction logits, shape (N1, H, W).
            gt_masks (Tensor): Ground truth, shape (N2, H, W).

        Returns:
            Tensor: Weighted dice cost matrix of shape (N1, N2).
        """
        preds = mask_preds.sigmoid() if self.pred_act else mask_preds
        return self.binary_mask_dice_loss(preds, gt_masks) * self.weight
@MATCH_COST.register_module()
class CrossEntropyLossCost:
    """Pairwise binary cross-entropy matching cost.

    Args:
        weight (int | float, optional): Multiplier applied to the cost.
            Defaults to 1.
        use_sigmoid (bool, optional): Whether the prediction uses sigmoid
            or softmax. Only ``True`` is supported. Defaults to True.
    """

    def __init__(self, weight=1.0, use_sigmoid=True):
        assert use_sigmoid, "use_sigmoid = False is not supported yet."
        self.weight = weight
        self.use_sigmoid = use_sigmoid

    def _binary_cross_entropy(self, cls_pred, gt_labels):
        """Compute the mean binary cross-entropy for every (query, gt) pair.

        Args:
            cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
                (num_query, *).
            gt_labels (Tensor): The learning label of prediction with
                shape (num_gt, *).

        Returns:
            Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
        """
        logits = cls_pred.flatten(1).float()
        targets = gt_labels.flatten(1).float()
        num_elements = logits.shape[1]

        # Per-element loss of each query if the target were 1, resp. 0.
        loss_if_one = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits), reduction="none")
        loss_if_zero = F.binary_cross_entropy_with_logits(logits, torch.zeros_like(logits), reduction="none")

        # Select the applicable loss per element through the targets and sum
        # over elements for every (query, gt) combination.
        positive_part = torch.einsum("nc,mc->nm", loss_if_one, targets)
        negative_part = torch.einsum("nc,mc->nm", loss_if_zero, 1 - targets)
        pairwise_cost = positive_part + negative_part
        return pairwise_cost / num_elements

    def __call__(self, cls_pred, gt_labels):
        """Compute the weighted cross-entropy cost matrix.

        Args:
            cls_pred (Tensor): Predicted classification logits.
            gt_labels (Tensor): Labels.

        Returns:
            Tensor: Cross entropy cost matrix with weight in
            shape (num_query, num_gt).
        """
        if not self.use_sigmoid:
            raise NotImplementedError
        return self._binary_cross_entropy(cls_pred, gt_labels) * self.weight

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder

View File

@ -0,0 +1,242 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
from mmcv.runner import BaseModule, ModuleList
from ...core.anchor import MlvlPointGenerator
from ..utils.transformer import MultiScaleDeformableAttention
@PLUGIN_LAYERS.register_module()
class MSDeformAttnPixelDecoder(BaseModule):
    """Pixel decoder with multi-scale deformable attention.

    The ``num_levels`` lowest-resolution input feature maps are projected to
    ``feat_channels`` and jointly refined by a transformer encoder. The
    remaining higher-resolution maps are merged top-down with lateral and
    output convolutions (FPN-like). The last (highest-resolution) output is
    projected to ``out_channels`` to produce the mask feature.

    Args:
        in_channels (list[int] | tuple[int]): Number of channels in the
            input feature maps.
        strides (list[int] | tuple[int]): Output strides of feature from
            backbone.
        feat_channels (int): Number of channels for feature.
        out_channels (int): Number of channels for output.
        num_outs (int): Number of output scales.
        norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
            Defaults to dict(type='GN', num_groups=32).
        act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
            Defaults to dict(type='ReLU').
        encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
            encoder. Defaults to `DetrTransformerEncoder`. The number of
            encoder levels is read from
            ``encoder["transformerlayers"]["attn_cfgs"]["num_levels"]``.
        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
            transformer encoder position encoding. Defaults to
            dict(type='SinePositionalEncoding', num_feats=128,
            normalize=True).
        init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
    """

    def __init__(
        self,
        in_channels=[256, 512, 1024, 2048],
        strides=[4, 8, 16, 32],
        feat_channels=256,
        out_channels=256,
        num_outs=3,
        norm_cfg=dict(type="GN", num_groups=32),
        act_cfg=dict(type="ReLU"),
        encoder=dict(
            type="DetrTransformerEncoder",
            num_layers=6,
            transformerlayers=dict(
                type="BaseTransformerLayer",
                attn_cfgs=dict(
                    type="MultiScaleDeformableAttention",
                    embed_dims=256,
                    num_heads=8,
                    num_levels=3,
                    num_points=4,
                    im2col_step=64,
                    dropout=0.0,
                    batch_first=False,
                    norm_cfg=None,
                    init_cfg=None,
                ),
                feedforward_channels=1024,
                ffn_dropout=0.0,
                operation_order=("self_attn", "norm", "ffn", "norm"),
            ),
            init_cfg=None,
        ),
        positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True),
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.strides = strides
        self.num_input_levels = len(in_channels)
        # Key-based lookup (instead of attribute access) so that plain dicts,
        # such as the default ``encoder`` argument above, are accepted too.
        self.num_encoder_levels = encoder["transformerlayers"]["attn_cfgs"]["num_levels"]
        assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one"
        input_conv_list = []
        # from top to down (low to high resolution)
        for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1):
            input_conv = ConvModule(
                in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True
            )
            input_conv_list.append(input_conv)
        self.input_convs = ModuleList(input_conv_list)

        self.encoder = build_transformer_layer_sequence(encoder)
        # NOTE: the attribute name (including its spelling) is kept as is,
        # since it is part of the module's attribute interface.
        self.postional_encoding = build_positional_encoding(positional_encoding)
        # high resolution to low resolution
        self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels)

        # fpn-like structure
        self.lateral_convs = ModuleList()
        self.output_convs = ModuleList()
        # The conv bias is only enabled when no normalization layer follows.
        self.use_bias = norm_cfg is None
        # from top to down (low to high resolution)
        # fpn for the rest features that didn't pass in encoder
        # NOTE(review): the convs are appended from the highest remaining
        # level down to level 0, while `forward` indexes them with the level
        # index itself. Both orders coincide when exactly one level bypasses
        # the encoder - confirm before using more than one such level.
        for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
            lateral_conv = ConvModule(
                in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None
            )
            output_conv = ConvModule(
                feat_channels,
                feat_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=self.use_bias,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )
            self.lateral_convs.append(lateral_conv)
            self.output_convs.append(output_conv)

        self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0)

        self.num_outs = num_outs
        self.point_generator = MlvlPointGenerator(strides)

    def init_weights(self):
        """Initialize weights."""
        for i in range(0, self.num_encoder_levels):
            xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform")

        for i in range(0, self.num_input_levels - self.num_encoder_levels):
            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
            caffe2_xavier_init(self.output_convs[i].conv, bias=0)

        caffe2_xavier_init(self.mask_feature, bias=0)

        normal_init(self.level_encoding, mean=0, std=1)
        for p in self.encoder.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p)

        # init_weights defined in MultiScaleDeformableAttention; this runs
        # after the generic xavier pass above and therefore overrides it for
        # the deformable attention modules.
        for layer in self.encoder.layers:
            for attn in layer.attentions:
                if isinstance(attn, MultiScaleDeformableAttention):
                    attn.init_weights()

    def forward(self, feats):
        """Decode backbone features into mask and multi-scale features.

        Args:
            feats (list[Tensor]): Feature maps of each level. Each has
                shape of (batch_size, c, h, w).

        Returns:
            tuple: A tuple containing the following:

            - mask_feature (Tensor): shape (batch_size, c, h, w).
            - multi_scale_features (list[Tensor]): Multi scale \
                    features, each in shape (batch_size, c, h, w).
        """
        # generate padding mask for each level, for each image
        batch_size = feats[0].shape[0]
        encoder_input_list = []
        padding_mask_list = []
        level_positional_encoding_list = []
        spatial_shapes = []
        reference_points_list = []
        for i in range(self.num_encoder_levels):
            # Encoder levels are visited from the last (lowest resolution)
            # input level backwards.
            level_idx = self.num_input_levels - i - 1
            feat = feats[level_idx]
            feat_projected = self.input_convs[i](feat)
            h, w = feat.shape[-2:]

            # no padding
            padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool)
            pos_embed = self.postional_encoding(padding_mask_resized)
            level_embed = self.level_encoding.weight[i]
            level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
            # (h_i * w_i, 2)
            reference_points = self.point_generator.single_level_grid_priors(
                feat.shape[-2:], level_idx, device=feat.device
            )
            # normalize
            factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
            reference_points = reference_points / factor

            # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
            feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
            level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
            padding_mask_resized = padding_mask_resized.flatten(1)

            encoder_input_list.append(feat_projected)
            padding_mask_list.append(padding_mask_resized)
            level_positional_encoding_list.append(level_pos_embed)
            spatial_shapes.append(feat.shape[-2:])
            reference_points_list.append(reference_points)
        # shape (batch_size, total_num_query),
        # total_num_query=sum([., h_i * w_i,.])
        padding_masks = torch.cat(padding_mask_list, dim=1)
        # shape (total_num_query, batch_size, c)
        encoder_inputs = torch.cat(encoder_input_list, dim=0)
        level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0)
        device = encoder_inputs.device
        # shape (num_encoder_levels, 2), from low
        # resolution to high resolution
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device)
        # shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        reference_points = torch.cat(reference_points_list, dim=0)
        # Broadcast the same reference points to every image and every level.
        reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1)
        # All ones, consistent with the all-valid (no padding) masks above.
        # The keyword name `valid_radios` is forwarded to the encoder as is.
        valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2))
        # shape (num_total_query, batch_size, c)
        memory = self.encoder(
            query=encoder_inputs,
            key=None,
            value=None,
            query_pos=level_positional_encodings,
            key_pos=None,
            attn_masks=None,
            key_padding_mask=None,
            query_key_padding_mask=padding_masks,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_radios=valid_radios,
        )
        # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
        memory = memory.permute(1, 2, 0)

        # from low resolution to high resolution
        num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
        outs = torch.split(memory, num_query_per_level, dim=-1)
        outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)]

        # Top-down pathway over the levels that did not go through the
        # encoder: upsample the latest output and fuse it with the lateral
        # projection of the next higher-resolution input.
        for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
            x = feats[i]
            cur_feat = self.lateral_convs[i](x)
            y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False)
            y = self.output_convs[i](y)
            outs.append(y)
        multi_scale_features = outs[: self.num_outs]

        mask_feature = self.mask_feature(outs[-1])
        return mask_feature, multi_scale_features

View File

@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .encoder_decoder_mask2former import EncoderDecoderMask2Former

View File

@ -0,0 +1,271 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.core import add_prefix
from mmseg.models import builder
from mmseg.models.builder import SEGMENTORS
from mmseg.models.segmentors.base import BaseSegmentor
from mmseg.ops import resize
@SEGMENTORS.register_module()
class EncoderDecoderMask2Former(BaseSegmentor):
    """Encoder-decoder segmentor.

    The segmentor is assembled from a backbone, an optional neck, a decode
    head and optional auxiliary head(s). Auxiliary heads are only used for
    deep supervision during training and are not involved in inference.
    """

    def __init__(
        self,
        backbone,
        decode_head,
        neck=None,
        auxiliary_head=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        init_cfg=None,
    ):
        super().__init__(init_cfg)
        if pretrained is not None:
            assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight"
            backbone.pretrained = pretrained

        self.backbone = builder.build_backbone(backbone)
        if neck is not None:
            self.neck = builder.build_neck(neck)

        # The decode head receives the train/test configs through its own
        # config.
        decode_head.update(train_cfg=train_cfg, test_cfg=test_cfg)
        self._init_decode_head(decode_head)
        self._init_auxiliary_head(auxiliary_head)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        assert self.with_decode_head

    def _init_decode_head(self, decode_head):
        """Build ``decode_head`` and mirror some of its attributes."""
        head = builder.build_head(decode_head)
        self.decode_head = head
        self.align_corners = head.align_corners
        self.num_classes = head.num_classes

    def _init_auxiliary_head(self, auxiliary_head):
        """Build ``auxiliary_head`` (one head or a list of heads), if any."""
        if auxiliary_head is None:
            return
        if not isinstance(auxiliary_head, list):
            self.auxiliary_head = builder.build_head(auxiliary_head)
            return
        self.auxiliary_head = nn.ModuleList()
        for head_cfg in auxiliary_head:
            self.auxiliary_head.append(builder.build_head(head_cfg))

    def extract_feat(self, img):
        """Extract features from images."""
        features = self.backbone(img)
        return self.neck(features) if self.with_neck else features

    def encode_decode(self, img, img_metas):
        """Encode images with backbone and decode into a semantic segmentation
        map of the same size as input."""
        features = self.extract_feat(img)
        logits = self._decode_head_forward_test(features, img_metas)
        return resize(input=logits, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners)

    def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs):
        """Run the decode head in training mode and return its losses with
        the ``decode`` prefix."""
        loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs)
        return dict(add_prefix(loss_decode, "decode"))

    def _decode_head_forward_test(self, x, img_metas):
        """Run the decode head in inference mode and return its logits."""
        return self.decode_head.forward_test(x, img_metas, self.test_cfg)

    def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
        """Run the auxiliary head(s) in training mode and collect their
        losses with an ``aux`` / ``aux_{idx}`` prefix."""
        losses = dict()
        if not isinstance(self.auxiliary_head, nn.ModuleList):
            loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
            losses.update(add_prefix(loss_aux, "aux"))
            return losses

        for idx, aux_head in enumerate(self.auxiliary_head):
            loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
            losses.update(add_prefix(loss_aux, f"aux_{idx}"))
        return losses

    def forward_dummy(self, img):
        """Dummy forward function."""
        return self.encode_decode(img, None)

    def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs):
        """Forward function for training.

        Args:
            img (Tensor): Input images.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            **kwargs: Forwarded to the decode head's ``forward_train``.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        features = self.extract_feat(img)

        losses = dict()
        losses.update(self._decode_head_forward_train(features, img_metas, gt_semantic_seg, **kwargs))

        if self.with_auxiliary_head:
            losses.update(self._auxiliary_head_forward_train(features, img_metas, gt_semantic_seg))

        return losses

    def slide_inference(self, img, img_meta, rescale):
        """Inference by sliding-window with overlap.

        If h_crop > h_img or w_crop > w_img, the small patch will be used to
        decode without padding.
        """
        h_stride, w_stride = self.test_cfg.stride
        h_crop, w_crop = self.test_cfg.crop_size
        batch_size, _, h_img, w_img = img.size()
        num_rows = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
        num_cols = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1

        # Running sum of window logits and the number of windows covering
        # each pixel.
        preds = img.new_zeros((batch_size, self.num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
        for row in range(num_rows):
            for col in range(num_cols):
                # Clamp the window to the image, then shift its origin back
                # so that the window keeps the full crop size when possible.
                bottom = min(row * h_stride + h_crop, h_img)
                right = min(col * w_stride + w_crop, w_img)
                top = max(bottom - h_crop, 0)
                left = max(right - w_crop, 0)

                window = img[:, :, top:bottom, left:right]
                window_logit = self.encode_decode(window, img_meta)
                # Zero-pad the window logits to the full image size before
                # accumulating.
                padding = (int(left), int(preds.shape[3] - right), int(top), int(preds.shape[2] - bottom))
                preds += F.pad(window_logit, padding)
                count_mat[:, :, top:bottom, left:right] += 1
        assert (count_mat == 0).sum() == 0

        if torch.onnx.is_in_onnx_export():
            # cast count_mat to constant while exporting to ONNX
            count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
        preds = preds / count_mat

        if not rescale:
            return preds
        return resize(
            preds,
            size=img_meta[0]["ori_shape"][:2],
            mode="bilinear",
            align_corners=self.align_corners,
            warning=False,
        )

    def whole_inference(self, img, img_meta, rescale):
        """Inference with full image."""
        seg_logit = self.encode_decode(img, img_meta)
        if not rescale:
            return seg_logit

        # support dynamic shape for onnx
        size = img.shape[2:] if torch.onnx.is_in_onnx_export() else img_meta[0]["ori_shape"][:2]
        return resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False)

    def inference(self, img, img_meta, rescale):
        """Inference with slide/whole style.

        Args:
            img (Tensor): The input image of shape (N, 3, H, W).
            img_meta (dict): Image info dict where each dict has: 'img_shape',
                'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            rescale (bool): Whether rescale back to original shape.

        Returns:
            Tensor: The output segmentation map (softmax over the class
            dimension, flipped back if the input was flipped).
        """
        mode = self.test_cfg.mode
        assert mode in ["slide", "whole"]
        ori_shape = img_meta[0]["ori_shape"]
        assert all(_["ori_shape"] == ori_shape for _ in img_meta)

        run_inference = self.slide_inference if mode == "slide" else self.whole_inference
        seg_logit = run_inference(img, img_meta, rescale)
        output = F.softmax(seg_logit, dim=1)

        # Undo test-time flipping so that the output aligns with the
        # unflipped image.
        if img_meta[0]["flip"]:
            flip_direction = img_meta[0]["flip_direction"]
            assert flip_direction in ["horizontal", "vertical"]
            if flip_direction == "horizontal":
                output = output.flip(dims=(3,))
            elif flip_direction == "vertical":
                output = output.flip(dims=(2,))

        return output

    def simple_test(self, img, img_meta, rescale=True):
        """Simple test with single image."""
        seg_pred = self.inference(img, img_meta, rescale).argmax(dim=1)
        if torch.onnx.is_in_onnx_export():
            # our inference backend only support 4D output
            return seg_pred.unsqueeze(0)
        # unravel batch dim
        return list(seg_pred.cpu().numpy())

    def aug_test(self, imgs, img_metas, rescale=True):
        """Test with augmentations.

        Only rescale=True is supported.
        """
        # aug_test rescale all imgs back to ori_shape for now
        assert rescale
        # to save memory, we get augmented seg logit inplace
        seg_logit = self.inference(imgs[0], img_metas[0], rescale)
        for aug_idx in range(1, len(imgs)):
            seg_logit += self.inference(imgs[aug_idx], img_metas[aug_idx], rescale)
        seg_logit /= len(imgs)
        # unravel batch dim
        return list(seg_logit.argmax(dim=1).cpu().numpy())

View File

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from .assigner import MaskHungarianAssigner
from .point_sample import get_uncertain_point_coords_with_randomness
from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding
from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer

View File

@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from abc import ABCMeta, abstractmethod
import torch
from ..builder import MASK_ASSIGNERS, build_match_cost
try:
from scipy.optimize import linear_sum_assignment
except ImportError:
linear_sum_assignment = None
class AssignResult(metaclass=ABCMeta):
    """Container for the outcome of an assignment.

    Args:
        num_gts: Number of ground truths considered in the assignment.
        gt_inds: Per-prediction assigned ground truth indices.
        labels: Per-prediction assigned labels.
    """

    def __init__(self, num_gts, gt_inds, labels):
        self.num_gts = num_gts
        self.gt_inds = gt_inds
        self.labels = labels

    @property
    def info(self):
        """dict: The stored fields keyed by ``num_gts``, ``gt_inds`` and
        ``labels``."""
        return dict(num_gts=self.num_gts, gt_inds=self.gt_inds, labels=self.labels)
class BaseAssigner(metaclass=ABCMeta):
    """Abstract interface of assigners that match predictions to ground
    truths."""

    @abstractmethod
    def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None):
        """Assign each prediction to a ground truth or to the negatives.

        Subclasses must override this method.
        """
@MASK_ASSIGNERS.register_module()
class MaskHungarianAssigner(BaseAssigner):
    """Computes one-to-one matching between predictions and ground truth for
    mask.

    This class computes an assignment between the targets and the predictions
    based on the costs. The costs are a weighted sum of three components:
    classification cost, mask cost and dice cost. The targets don't include
    the no_object, so generally there are more predictions than targets.
    After the one-to-one matching, the un-matched are treated as backgrounds.
    Thus each query prediction will be assigned with `0` or a positive
    integer indicating the ground truth index:

    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config.
        mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config.
        dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config.
    """

    def __init__(
        self,
        cls_cost=dict(type="ClassificationCost", weight=1.0),
        dice_cost=dict(type="DiceCost", weight=1.0),
        mask_cost=dict(type="MaskFocalCost", weight=1.0),
    ):
        self.cls_cost = build_match_cost(cls_cost)
        self.dice_cost = build_match_cost(dice_cost)
        self.mask_cost = build_match_cost(mask_cost)

    def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7):
        """Computes one-to-one matching based on the weighted costs.

        This method assign each query prediction to a ground truth or
        background. The `assigned_gt_inds` with -1 means don't care,
        0 means negative sample, and positive number is the index (1-based)
        of assigned gt.
        The assignment is done in the following steps, the order matters.

        1. assign every prediction to -1
        2. compute the weighted costs
        3. do Hungarian matching on CPU based on the costs
        4. assign all to 0 (background) first, then for each matched pair
           between predictions and gts, treat this prediction as foreground
           and assign the corresponding gt index (plus 1) to it.

        Args:
            cls_pred (Tensor): Predicted classification logits, shape
                [num_query, num_class].
            mask_pred (Tensor): Predicted mask, shape [num_query, h, w]
            gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,).
            gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w].
            img_meta (dict): Meta information for current image. Not used
                by this implementation.
            gt_masks_ignore (Tensor, optional): Ground truth masks that are
                labelled as `ignored`. Must be None. Default None.
            eps (int | float, optional): Kept for interface compatibility;
                not used by this implementation. Default 1e-7.

        Returns:
            :obj:`AssignResult`: The assigned result.

        Raises:
            ImportError: If a matching has to be computed but scipy is not
                installed.
        """
        assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported."
        num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0]

        # 1. assign -1 by default
        assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
        assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
        if num_gts == 0 or num_queries == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)

        # Fail early, before spending time on the cost computation.
        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" ' "to install scipy first.")

        # 2. compute the weighted costs
        # Accumulate into a zero tensor so that `cost` is a tensor even when
        # every cost weight is 0 (a plain int would break `.detach()` below).
        cost = cls_pred.new_zeros((num_queries, num_gts))
        if self.cls_cost.weight != 0:
            cost = cost + self.cls_cost(cls_pred, gt_labels)
        if self.mask_cost.weight != 0:
            # mask_pred shape = [nq, h, w]
            # gt_mask shape = [ng, h, w]
            # mask_cost shape = [nq, ng]
            cost = cost + self.mask_cost(mask_pred, gt_masks)
        if self.dice_cost.weight != 0:
            cost = cost + self.dice_cost(mask_pred, gt_masks)

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)

View File

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import torch
from mmcv.ops import point_sample
def get_uncertainty(mask_pred, labels):
    """Estimate uncertainty from prediction logits.

    The uncertainty is the negated L1 distance between 0.0 and the logit of
    the class given in ``labels``: logits close to zero yield the highest
    scores.

    Args:
        mask_pred (Tensor): mask prediction logits, shape (num_rois,
            num_classes, mask_height, mask_width).
        labels (list[Tensor]): Either predicted or ground truth label for
            each predicted mask, of length num_rois. Ignored when
            ``mask_pred`` has a single class channel.

    Returns:
        scores (Tensor): Uncertainty scores with the most uncertain
            locations having the highest uncertainty score,
            shape (num_rois, 1, mask_height, mask_width)
    """
    if mask_pred.shape[1] != 1:
        # Class-specific prediction: pick the channel of each roi's label.
        roi_inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
        selected_logits = mask_pred[roi_inds, labels].unsqueeze(1)
    else:
        # Class-agnostic prediction: the single channel is used directly.
        selected_logits = mask_pred.clone()
    return -selected_logits.abs()
def get_uncertain_point_coords_with_randomness(
    mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
):
    """Get ``num_points`` most uncertain points with random points during
    train.

    Points are sampled in [0, 1] x [0, 1] coordinate space. A pool of
    ``num_points * oversample_ratio`` random candidates is drawn, the
    ``importance_sample_ratio * num_points`` candidates with the highest
    'get_uncertainty()' score are kept, and the remainder is filled with
    fresh random points.

    Args:
        mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
            mask_height, mask_width) for class-specific or class-agnostic
            prediction.
        labels (list): The ground truth class for each instance.
        num_points (int): The number of points to sample.
        oversample_ratio (int): Oversampling parameter.
        importance_sample_ratio (float): Ratio of points that are sampled
            via importance sampling.

    Returns:
        point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
            that contains the coordinates sampled points.
    """
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1

    device = mask_pred.device
    num_rois = mask_pred.shape[0]
    num_candidates = int(num_points * oversample_ratio)

    candidate_coords = torch.rand(num_rois, num_candidates, 2, device=device)
    candidate_logits = point_sample(mask_pred, candidate_coords)
    # The uncertainty must be computed from the prediction values sampled at
    # the points. Computing the uncertainty of the coarse predictions first
    # and sampling it afterwards gives wrong results: with
    # func(logits) = -abs(logits), a point sampled between two coarse
    # predictions with logits -1 and 1 has logit 0 and hence uncertainty 0,
    # whereas sampling the per-prediction uncertainties (-1 for both) would
    # give the point an uncertainty of -1.
    candidate_uncertainties = get_uncertainty(candidate_logits, labels)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points

    _, top_inds = torch.topk(candidate_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)
    # Turn per-roi candidate indices into indices of the flattened
    # (num_rois * num_candidates) candidate list.
    roi_offsets = num_candidates * torch.arange(num_rois, dtype=torch.long, device=device)
    flat_inds = (top_inds + roi_offsets[:, None]).view(-1)
    selected_coords = candidate_coords.view(-1, 2)[flat_inds, :].view(num_rois, num_uncertain_points, 2)

    if num_random_points > 0:
        random_coords = torch.rand(num_rois, num_random_points, 2, device=device)
        selected_coords = torch.cat((selected_coords, random_coords), dim=1)
    return selected_coords

View File

@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
from mmcv.runner import BaseModule
@POSITIONAL_ENCODING.register_module()
class SinePositionalEncoding(BaseModule):
    """Sine/cosine position encoding.

    See `End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        num_feats (int): The feature dimension for each position
            along x-axis or y-axis. Note the final returned dimension
            for each position is 2 times of this value.
        temperature (int, optional): The temperature used for scaling
            the position embedding. Defaults to 10000.
        normalize (bool, optional): Whether to normalize the position
            embedding. Defaults to False.
        scale (float, optional): A scale factor that scales the position
            embedding. The scale will be used only when `normalize` is True.
            Defaults to 2*pi.
        eps (float, optional): A value added to the denominator for
            numerical stability. Defaults to 1e-6.
        offset (float): offset add to embed when do the normalization.
            Defaults to 0.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
    """

    def __init__(
        self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None
    ):
        super().__init__(init_cfg)
        if normalize:
            assert isinstance(
                scale, (float, int)
            ), f"when normalize is set,scale should be provided and in float or int type, found {type(scale)}"
        self.num_feats = num_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.eps = eps
        self.offset = offset

    def forward(self, mask):
        """Forward function for `SinePositionalEncoding`.

        Args:
            mask (Tensor): ByteTensor mask. Non-zero values representing
                ignored positions, while zero values means valid positions
                for this image. Shape [bs, h, w].

        Returns:
            pos (Tensor): Returned position embedding with shape
                [bs, num_feats*2, h, w].
        """
        # For convenience of exporting to ONNX, it's required to convert
        # `masks` from bool to int.
        mask = mask.to(torch.int)
        valid = 1 - mask  # logical_not
        # Running count of valid positions along the height and the width.
        row_pos = valid.cumsum(1, dtype=torch.float32)
        col_pos = valid.cumsum(2, dtype=torch.float32)
        if self.normalize:
            # Divide by the count at the last row / column, then rescale.
            row_pos = (row_pos + self.offset) / (row_pos[:, -1:, :] + self.eps) * self.scale
            col_pos = (col_pos + self.offset) / (col_pos[:, :, -1:] + self.eps) * self.scale

        channel = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device)
        # Consecutive channel pairs share one divisor.
        divisor = self.temperature ** (2 * (channel // 2) / self.num_feats)

        # use `view` instead of `flatten` for dynamically exporting to ONNX
        B, H, W = mask.size()

        def interleave_sin_cos(angles):
            # Sine of the even channels and cosine of the odd channels,
            # interleaved along the last dimension.
            stacked = torch.stack((angles[:, :, :, 0::2].sin(), angles[:, :, :, 1::2].cos()), dim=4)
            return stacked.view(B, H, W, -1)

        pos_x = interleave_sin_cos(col_pos[..., None] / divisor)
        pos_y = interleave_sin_cos(row_pos[..., None] / divisor)
        # Channel layout: y encoding first, then x encoding.
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

    def __repr__(self):
        """str: a string that describes the module"""
        return (
            f"{self.__class__.__name__}"
            f"(num_feats={self.num_feats}, "
            f"temperature={self.temperature}, "
            f"normalize={self.normalize}, "
            f"scale={self.scale}, "
            f"eps={self.eps})"
        )
@POSITIONAL_ENCODING.register_module()
class LearnedPositionalEncoding(BaseModule):
    """2D position embedding looked up from learnable row/column tables.

    Args:
        num_feats (int): Embedding width of each table. The returned
            embedding has ``2 * num_feats`` channels (column part followed
            by row part).
        row_num_embed (int, optional): Number of entries in the row table.
            Default 50.
        col_num_embed (int, optional): Number of entries in the column
            table. Default 50.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")):
        super().__init__(init_cfg)
        self.row_embed = nn.Embedding(row_num_embed, num_feats)
        self.col_embed = nn.Embedding(col_num_embed, num_feats)
        self.num_feats = num_feats
        self.row_num_embed = row_num_embed
        self.col_num_embed = col_num_embed

    def forward(self, mask):
        """Forward function for `LearnedPositionalEncoding`.

        Only the shape and device of ``mask`` are used; its values are not
        read.

        Args:
            mask (Tensor): Mask whose last two dimensions are (h, w) and
                whose first dimension is the batch size.

        Returns:
            pos (Tensor): Position embedding with shape
                [bs, num_feats*2, h, w].
        """
        height, width = mask.shape[-2:]
        col_index = torch.arange(width, device=mask.device)
        row_index = torch.arange(height, device=mask.device)
        col_table = self.col_embed(col_index)  # (w, num_feats)
        row_table = self.row_embed(row_index)  # (h, num_feats)

        # Tile each table over the other axis to obtain two (h, w, num_feats) grids.
        col_grid = col_table.unsqueeze(0).repeat(height, 1, 1)
        row_grid = row_table.unsqueeze(1).repeat(1, width, 1)

        grid = torch.cat((col_grid, row_grid), dim=-1)  # (h, w, 2 * num_feats)
        return grid.permute(2, 0, 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)

    def __repr__(self):
        """str: a string that describes the module"""
        fields = (
            f"num_feats={self.num_feats}",
            f"row_num_embed={self.row_num_embed}",
            f"col_num_embed={self.col_num_embed}",
        )
        return f"{self.__class__.__name__}({', '.join(fields)})"

View File

@ -0,0 +1,989 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import warnings
from typing import Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init
from mmcv.cnn.bricks.drop import build_dropout
from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE
from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence
from mmcv.runner.base_module import BaseModule, Sequential
from mmcv.utils import deprecated_api_warning, to_2tuple
from torch.nn.init import normal_
from ..builder import TRANSFORMER
try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
except ImportError:
warnings.warn(
"`MultiScaleDeformableAttention` in MMCV has been moved to "
"`mmcv.ops.multi_scale_deform_attn`, please update your MMCV"
)
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
class AdaptivePadding(nn.Module):
    """Zero-pad an input (if needed) so a sliding filter covers it fully.

    Two modes are supported. The "same" mode mirrors the "SAME" padding mode
    of TensorFlow and splits the zeros around the input. The "corner" mode
    puts all the zeros at the bottom and right.

    Args:
        kernel_size (int | tuple): Size of the kernel. Default: 1.
        stride (int | tuple): Stride of the filter. Default: 1.
        dilation (int | tuple): Spacing between kernel elements.
            Default: 1.
        padding (str): Either "same" or "corner". Default: "corner".

    Example:
        >>> kernel_size = 16
        >>> stride = 16
        >>> dilation = 1
        >>> input = torch.rand(1, 1, 15, 17)
        >>> adap_pad = AdaptivePadding(
        >>>     kernel_size=kernel_size,
        >>>     stride=stride,
        >>>     dilation=dilation,
        >>>     padding="corner")
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
        >>> input = torch.rand(1, 1, 16, 17)
        >>> out = adap_pad(input)
        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
    """

    def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
        super().__init__()
        assert padding in ("same", "corner")

        # NOTE(review): `forward` compares `self.padding` against the mode
        # strings, so this relies on `to_2tuple` handing a str back
        # unchanged -- confirm against the installed mmcv.
        self.padding = to_2tuple(padding)
        self.kernel_size = to_2tuple(kernel_size)
        self.stride = to_2tuple(stride)
        self.dilation = to_2tuple(dilation)

    def get_pad_shape(self, input_shape):
        """Return the total (pad_h, pad_w) needed for an (h, w) input."""

        def _missing(size, kernel, stride, dilation):
            # Extent required by ceil(size / stride) windows, minus what
            # the input already provides; never negative.
            num_windows = math.ceil(size / stride)
            return max((num_windows - 1) * stride + (kernel - 1) * dilation + 1 - size, 0)

        input_h, input_w = input_shape
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.stride
        pad_h = _missing(input_h, kernel_h, stride_h, self.dilation[0])
        pad_w = _missing(input_w, kernel_w, stride_w, self.dilation[1])
        return pad_h, pad_w

    def forward(self, x):
        """Pad the last two dimensions of ``x`` according to the mode."""
        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
        if pad_h <= 0 and pad_w <= 0:
            return x
        if self.padding == "corner":
            return F.pad(x, [0, pad_w, 0, pad_h])
        if self.padding == "same":
            left, top = pad_w // 2, pad_h // 2
            # Any odd remainder goes to the right / bottom side.
            return F.pad(x, [left, pad_w - left, top, pad_h - top])
        return x
class PatchMerging(BaseModule):
    """Merge patch feature map.

    This layer groups the feature map by kernel_size, then applies an
    optional norm and a linear layer to each group. The grouping is done
    with `nn.Unfold`, which is about 25% faster than the original
    implementation, at the price of having to modify pretrained models for
    compatibility.

    Args:
        in_channels (int): The num of input channels.
        out_channels (int): The num of output channels.
        kernel_size (int | tuple, optional): the kernel size in the unfold
            layer. Defaults to 2.
        stride (int | tuple, optional): the stride of the sliding blocks in
            the unfold layer. Default: None. (Would be set as `kernel_size`)
        padding (int | tuple | string ): The padding length of the unfold
            layer. When it is a string, it selects the mode of adaptive
            padding ("same" or "corner") and the unfold layer itself pads
            nothing. Default: "corner".
        dilation (int | tuple, optional): dilation parameter in the unfold
            layer. Default: 1.
        bias (bool, optional): Whether to add bias in linear layer or not.
            Defaults: False.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (dict, optional): The extra config for initialization.
            Default: None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=2,
        stride=None,
        padding="corner",
        dilation=1,
        bias=False,
        norm_cfg=dict(type="LN"),
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels

        # A falsy stride (None, 0) falls back to the kernel size.
        stride = stride if stride else kernel_size

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)

        if isinstance(padding, str):
            self.adap_padding = AdaptivePadding(
                kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding
            )
            # The adaptive padding does the work, so unfold pads nothing.
            padding = 0
        else:
            self.adap_padding = None
        # `forward` indexes `self.sampler.padding`, so it has to be a pair
        # in both branches.
        padding = to_2tuple(padding)

        self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)

        # Every unfolded column holds one full kernel window per channel.
        sample_dim = kernel_size[0] * kernel_size[1] * in_channels
        self.norm = build_norm_layer(norm_cfg, sample_dim)[1] if norm_cfg is not None else None
        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)

    def forward(self, x, input_size):
        """
        Args:
            x (Tensor): Has shape (B, H*W, C_in).
            input_size (tuple[int]): The spatial shape of x, arrange as
                (H, W).

        Returns:
            tuple: Contains merged results and its spatial shape.
                - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
                - out_size (tuple[int]): Spatial shape of x, arrange as
                    (Merged_H, Merged_W).
        """
        batch, num_tokens, channels = x.shape
        assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}"
        height, width = input_size
        assert num_tokens == height * width, "input feature has wrong size"

        feat = x.view(batch, height, width, channels).permute([0, 3, 1, 2])  # B, C, H, W
        if self.adap_padding:
            feat = self.adap_padding(feat)
            height, width = feat.shape[-2:]

        # Use nn.Unfold to merge patch. About 25% faster than original method,
        # but need to modify pretrained model for compatibility
        # With kernel_size=2 and stride=2 this has shape (B, 4*C, H/2*W/2).
        feat = self.sampler(feat)

        sampler = self.sampler

        def _out_len(size, axis):
            # Sliding-window output length along one spatial axis.
            span = sampler.dilation[axis] * (sampler.kernel_size[axis] - 1)
            return (size + 2 * sampler.padding[axis] - span - 1) // sampler.stride[axis] + 1

        output_size = (_out_len(height, 0), _out_len(width, 1))

        feat = feat.transpose(1, 2)  # B, H/2*W/2, 4*C
        if self.norm:
            feat = self.norm(feat)
        return self.reduction(feat), output_size
def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid (the logit), with clamping.

    The input is first clamped to [0, 1]; both ``x`` and ``1 - x`` are then
    clamped from below by ``eps`` so the logarithm stays finite.

    Args:
        x (Tensor): The tensor to invert.
        eps (float): Lower bound applied to the numerator and the
            denominator to avoid numerical overflow. Defaults 1e-5.

    Returns:
        Tensor: ``log(x / (1 - x))`` computed on the clamped values, with
            the same shape as the input.
    """
    prob = x.clamp(min=0, max=1)
    numerator = prob.clamp(min=eps)
    denominator = (1 - prob).clamp(min=eps)
    return torch.log(numerator / denominator)
@FEEDFORWARD_NETWORK.register_module(force=True)
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Must be at least 2. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True)
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut. When falsy, an identity is used.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        with_cp (bool, optional): Whether to run the layers through
            `torch.utils.checkpoint` when the input requires grad.
            Default: False.
        **kwargs: Accepted and ignored by this constructor.
    """

    @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN")
    def __init__(
        self,
        embed_dims=256,
        feedforward_channels=1024,
        num_fcs=2,
        act_cfg=dict(type="ReLU", inplace=True),
        ffn_drop=0.0,
        dropout_layer=None,
        add_identity=True,
        init_cfg=None,
        with_cp=False,
        **kwargs,
    ):
        super().__init__(init_cfg)
        assert num_fcs >= 2, "num_fcs should be no less " f"than 2. got {num_fcs}."
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        # One activation instance is shared by every hidden stage.
        self.activate = build_activation_layer(act_cfg)
        self.with_cp = with_cp

        # num_fcs - 1 hidden stages (Linear -> activation -> dropout); only
        # the first one reads `embed_dims` channels.
        stages = []
        fan_in = embed_dims
        for _ in range(num_fcs - 1):
            stages.append(Sequential(Linear(fan_in, feedforward_channels), self.activate, nn.Dropout(ffn_drop)))
            fan_in = feedforward_channels
        # Output projection back to `embed_dims`, followed by dropout.
        self.layers = Sequential(*stages, Linear(feedforward_channels, embed_dims), nn.Dropout(ffn_drop))

        self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({"residual": "identity"}, cls_name="FFN")
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        When the identity connection is enabled, `identity` is added to the
        output; `x` itself is used when `identity` is None.
        """
        checkpointed = self.with_cp and x.requires_grad
        out = cp.checkpoint(self.layers, x) if checkpointed else self.layers(x)
        out = self.dropout_layer(out)
        if not self.add_identity:
            return out
        shortcut = x if identity is None else identity
        return shortcut + out
@TRANSFORMER_LAYER.register_module()
class DetrTransformerDecoderLayer(BaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it would be expand to the number of attention in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Must hold exactly six entries that together use
            'self_attn', 'norm', 'cross_attn' and 'ffn'. Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(
        self,
        attn_cfgs,
        feedforward_channels,
        ffn_dropout=0.0,
        operation_order=None,
        act_cfg=dict(type="ReLU", inplace=True),
        norm_cfg=dict(type="LN"),
        ffn_num_fcs=2,
        **kwargs,
    ):
        super().__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs,
        )
        # The layout is checked only after the base class has been built.
        assert len(operation_order) == 6
        assert set(operation_order) == {"self_attn", "norm", "cross_attn", "ffn"}
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerEncoder(TransformerLayerSequence):
    """TransformerEncoder of DETR.

    Args:
        post_norm_cfg (dict): Config of last normalization layer. Default
            `LN`. Only used when `self.pre_norm` is `True`; it must not be
            None in that case.
    """

    def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs):
        super().__init__(*args, **kwargs)
        if post_norm_cfg is None:
            assert not self.pre_norm, f"Use prenorm in " f"{self.__class__.__name__}," f"Please specify post_norm_cfg"
            self.post_norm = None
        elif self.pre_norm:
            self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1]
        else:
            # No trailing norm is created when the layers are not pre-norm.
            self.post_norm = None

    def forward(self, *args, **kwargs):
        """Forward function for `TransformerCoder`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        encoded = super().forward(*args, **kwargs)
        if self.post_norm is None:
            return encoded
        return self.post_norm(encoded)
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        post_norm_cfg (dict): Config of last normalization layer. Default
            `LN`. Pass None to disable the final normalization.
    """

    def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs):
        super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1]
        else:
            self.post_norm = None

    def forward(self, query, *args, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        if not self.return_intermediate:
            x = super().forward(query, *args, **kwargs)
            if self.post_norm is not None:
                x = self.post_norm(x)
            # Always add the leading layer axis, also when no post-norm is
            # configured, so the result is 4-D as documented regardless of
            # `post_norm_cfg`.
            return x[None]

        # Collect the (optionally normalized) output of every layer.
        intermediate = []
        for layer in self.layers:
            query = layer(query, *args, **kwargs)
            intermediate.append(query if self.post_norm is None else self.post_norm(query))
        return torch.stack(intermediate)
@TRANSFORMER.register_module()
class Transformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module copy-paste
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerEncoder. Defaults to None.
        decoder ((`mmcv.ConfigDict` | Dict)): Config of
            TransformerDecoder. Defaults to None
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.encoder.embed_dims

    def init_weights(self):
        """Xavier-initialize every submodule whose `weight` has more than one dim."""
        # follow the official DETR to init parameters
        for module in self.modules():
            if hasattr(module, "weight") and module.weight.dim() > 1:
                xavier_init(module, distribution="uniform")
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec
                    is True output has shape [num_dec_layers, bs,
                    num_query, embed_dims], else has shape [1, bs,
                    num_query, embed_dims].
                - memory: Output results from encoder, with shape
                    [bs, embed_dims, h, w].
        """
        bs, c, h, w = x.shape
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        tokens = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
        pos_tokens = pos_embed.view(bs, c, -1).permute(2, 0, 1)
        queries_pos = query_embed.unsqueeze(1).repeat(1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
        flat_mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]

        memory = self.encoder(
            query=tokens, key=None, value=None, query_pos=pos_tokens, query_key_padding_mask=flat_mask
        )

        # The decoder starts from all-zero queries; the learned query
        # embedding is passed as their positional encoding.
        target = torch.zeros_like(queries_pos)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target,
            key=memory,
            value=memory,
            key_pos=pos_tokens,
            query_pos=queries_pos,
            key_padding_mask=flat_mask,
        )
        out_dec = out_dec.transpose(1, 2)
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory
@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in Deformable DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate

    def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference
                points of offset. has shape
                (bs, num_query, 4) when as_two_stage,
                otherwise has shape ((bs, num_query, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2)
            reg_branches: (obj:`nn.ModuleList`): Used for
                refining the regression results. Only would
                be passed when with_box_refine is True,
                otherwise would be passed a `None`.

        Returns:
            tuple[Tensor, Tensor]: The decoder output and the reference
                points. When return_intermediate is `True` both are stacked
                over the layers (the output then has shape
                [num_layers, num_query, bs, embed_dims]); otherwise only
                the values of the last layer are returned, unstacked.
        """
        hidden = query
        layer_outputs = []
        layer_references = []
        for layer_idx, layer in enumerate(self.layers):
            # Re-read on every iteration: a refinement step may change the
            # last dimension of the reference points.
            ref_dim = reference_points.shape[-1]

            # Scale the reference points by the valid ratio of each level.
            if ref_dim == 4:
                level_scale = torch.cat([valid_ratios, valid_ratios], -1)[:, None]
            else:
                assert ref_dim == 2
                level_scale = valid_ratios[:, None]
            reference_points_input = reference_points[:, :, None] * level_scale

            hidden = layer(hidden, *args, reference_points=reference_points_input, **kwargs)
            hidden = hidden.permute(1, 0, 2)

            if reg_branches is not None:
                delta = reg_branches[layer_idx](hidden)
                if ref_dim == 4:
                    refined = delta + inverse_sigmoid(reference_points)
                else:
                    assert ref_dim == 2
                    # Written in place into the branch output: only the
                    # first two channels receive the previous reference.
                    refined = delta
                    refined[..., :2] = delta[..., :2] + inverse_sigmoid(reference_points)
                refined = refined.sigmoid()
                # Detached so the next layer starts from a gradient-free
                # reference.
                reference_points = refined.detach()

            hidden = hidden.permute(1, 0, 2)
            if self.return_intermediate:
                layer_outputs.append(hidden)
                layer_references.append(reference_points)

        if self.return_intermediate:
            return torch.stack(layer_outputs), torch.stack(layer_references)
        return hidden, reference_points
@TRANSFORMER.register_module()
class DeformableDetrTransformer(Transformer):
    """Implements the DeformableDETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN:
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
        **kwargs: Forwarded to the `Transformer` constructor.
    """

    def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs):
        super(DeformableDetrTransformer, self).__init__(**kwargs)
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        self.embed_dims = self.encoder.embed_dims
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the DeformableDetrTransformer."""
        # One learnable embedding per feature level. `torch.Tensor(...)`
        # leaves it uninitialized; `init_weights` fills it with `normal_`.
        self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))
        if self.as_two_stage:
            # Two-stage: layers that turn encoder proposals into queries.
            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            # Single-stage: 2-D reference points are regressed from the
            # positional part of the query embedding.
            self.reference_points = nn.Linear(self.embed_dims, 2)

    def init_weights(self):
        """Initialize the transformer weights."""
        # Xavier-uniform for every parameter with more than one dimension.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Deformable attention modules then re-apply their own init scheme.
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points, distribution="uniform", bias=0.0)
        normal_(self.level_embeds)

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        """Generate proposals from encoded memory.

        Args:
            memory (Tensor) : The output of encoder,
                has shape (bs, num_key, embed_dim).  num_key is
                equal the number of points on feature map from
                all level.
            memory_padding_mask (Tensor): Padding mask for memory.
                has shape (bs, num_key).
            spatial_shapes (Tensor): The shape of all feature maps.
                has shape (num_level, 2).

        Returns:
            tuple: A tuple of feature map and bbox prediction.

                - output_memory (Tensor): The input of decoder,
                    has shape (bs, num_key, embed_dim).  num_key is
                    equal the number of points on feature map from
                    all levels.
                - output_proposals (Tensor): The normalized proposal
                    after a inverse sigmoid, has shape
                    (bs, num_keys, 4).
        """
        N, S, C = memory.shape
        proposals = []
        # Offset of the current level inside the flattened key dimension.
        _cur = 0
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
            # Unmasked extent per sample, counted along the first column
            # and the first row of the level's mask.
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
            grid_y, grid_x = torch.meshgrid(
                torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
            )
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
            # Cell centres, normalized by the valid extent of each sample.
            grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
            # Proposal width/height: 0.05 at level 0, doubling per level.
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
            proposals.append(proposal)
            _cur += H * W
        output_proposals = torch.cat(proposals, 1)
        # Keep only proposals whose four values all lie strictly inside
        # (0.01, 0.99).
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        # Inverse sigmoid (logit) of the normalized proposals.
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        # Padded or invalid proposals are set to +inf.
        output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
        output_memory = memory
        # The matching memory entries are zeroed before the projection.
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all
                feature maps, has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2)
            device (obj:`device`): The device where
                reference_points should be.

        Returns:
            Tensor: reference points used in decoder, has
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # Cell-centre coordinates (0.5, 1.5, ...) of this level.
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
            )
            # Normalize by the valid extent (valid ratio * size) per sample.
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        # Broadcast against the valid ratios of every level.
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps of all level.

        Args:
            mask (Tensor): Padding mask with shape (bs, H, W). It is
                inverted with `~`, so it is presumably boolean -- confirm
                against the caller.

        Returns:
            Tensor: Valid ratios with shape (bs, 2), ordered as
                (ratio_w, ratio_h).
        """
        _, H, W = mask.shape
        # Unmasked entries are counted along the first column (height) and
        # the first row (width) only.
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
        """Get the position embedding of proposal.

        Args:
            proposals (Tensor): Unactivated proposals; a sigmoid is applied
                here before encoding. The inline shape comments assume
                (N, L, 4).
            num_pos_feats (int): Number of encoding channels produced for
                each proposal coordinate. Defaults to 128.
            temperature (int): Base of the frequency progression.
                Defaults to 10000.

        Returns:
            Tensor: Sine/cosine embedding with everything after the second
                dimension flattened together.
        """
        scale = 2 * math.pi
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos

    def forward(
        self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs
    ):
        """Forward function for `Transformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from
                different level. Each element has shape
                [bs, embed_dims, h, w].
            mlvl_masks (list(Tensor)): The key_padding_mask from
                different level used for encoder and decoder,
                each element has shape  [bs, h, w].
            query_embed (Tensor): The query embedding for decoder. May only
                be None when `as_two_stage` is True. It is split along
                dim 1 into two chunks of width embed_dims (positional part
                first), so it is presumably
                [num_query, 2 * embed_dims] -- confirm against the caller.
            mlvl_pos_embeds (list(Tensor)): The positional encoding
                of feats from different level, has the shape
                 [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only would
                be passed when
                `with_box_refine` is True. Default to None.
            cls_branches (obj:`nn.ModuleList`): Classification heads
                for feature maps from each decoder layer. Only would
                 be passed when `as_two_stage`
                 is True. Default to None.
            **kwargs: Forwarded to both the encoder and the decoder.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - inter_states: First value returned by the decoder (its
                    output states).
                - init_reference_out: The initial value of reference
                    points. Its last dimension is 4 when `as_two_stage`
                    is True and 2 otherwise.
                - inter_references_out: Second value returned by the
                    decoder (its reference points).
                - enc_outputs_class: The classification score of
                    proposals generated from
                    encoder's feature maps, has shape
                    (batch, h*w, num_classes).
                    Only would be returned when `as_two_stage` is True,
                    otherwise None.
                - enc_outputs_coord_unact: The regression results
                    generated from encoder's feature maps., has shape
                    (batch, h*w, 4). Only would
                    be returned when `as_two_stage` is True,
                    otherwise None.
        """
        assert self.as_two_stage or query_embed is not None
        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        # Flatten every level to (bs, h*w, ...) and record its shape.
        for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            # The level embedding is added to the positional encoding.
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        # Concatenate all levels along the key dimension.
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        # Start offset of each level inside the concatenated sequence.
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1)
        # `feat` here is the flattened feature of the last level; only its
        # device is used.
        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device)
        feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        memory = self.encoder(
            query=feat_flatten,
            key=None,
            value=None,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            **kwargs,
        )
        # Back to batch-first for the proposal / query construction.
        memory = memory.permute(1, 0, 2)
        bs, _, c = memory.shape
        if self.as_two_stage:
            output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
            # The branch at index `num_layers` (one past the per-layer
            # branches) scores and refines the encoder proposals.
            enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals
            topk = self.two_stage_num_proposals
            # Proposals are ranked by the score of class channel 0.
            topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            # First chunk is the positional part, second the content query.
            query_pos, query = torch.split(pos_trans_out, c, dim=2)
        else:
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
            query = query.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_pos).sigmoid()
            init_reference_out = reference_points
        # decoder
        # The decoder takes sequence-first tensors.
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs,
        )
        inter_references_out = inter_references
        if self.as_two_stage:
            return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
        return inter_states, init_reference_out, inter_references_out, None, None
@TRANSFORMER.register_module()
class DynamicConv(BaseModule):
    """Dynamic convolution with per-sample generated 1x1 kernels.

    A linear layer turns each proposal feature into two weight matrices,
    which are then applied to that proposal's spatial feature map with
    batched matrix multiplication (i.e. two sample-specific 1x1
    convolutions). Adapted from the `Sparse R-CNN reference code
    <https://github.com/PeizeSun/SparseR-CNN/blob/main/projects/SparseRCNN/
    sparsercnn/head.py#L258>`_.

    Args:
        in_channels (int): Channels of the input feature. Defaults to 256.
        feat_channels (int): Channels of the intermediate feature.
            Defaults to 64.
        out_channels (int, optional): Channels of the output feature.
            Falls back to `in_channels` when not given (or falsy).
        input_feat_shape (int): Side length of the (square) input feature
            map; only used to size the final projection. Defaults to 7.
        with_proj (bool): Whether to flatten the two-dimensional feature and
            project it to a one-dimensional feature. Defaults to True.
        act_cfg (dict): Config of the activation layer.
        norm_cfg (dict): Config of the normalization layers. Defaults to
            layer normalization.
        init_cfg (obj:`mmcv.ConfigDict`): Initialization config.
            Default: None.
    """

    def __init__(
        self,
        in_channels=256,
        feat_channels=64,
        out_channels=None,
        input_feat_shape=7,
        with_proj=True,
        act_cfg=dict(type="ReLU", inplace=True),
        norm_cfg=dict(type="LN"),
        init_cfg=None,
    ):
        super().__init__(init_cfg)

        # Plain configuration bookkeeping.
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels or in_channels

        # Sizes of the two generated weight matrices (flattened).
        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels

        # One linear layer emits both matrices at once.
        total_params = self.num_params_in + self.num_params_out
        self.dynamic_layer = nn.Linear(self.in_channels, total_params)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
        self.activation = build_activation_layer(act_cfg)

        if self.with_proj:
            flattened_size = self.out_channels * input_feat_shape**2
            self.fc_layer = nn.Linear(flattened_size, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature, input_feature):
        """Apply the dynamically generated convolutions.

        Args:
            param_feature (Tensor): Feature used to generate the weights,
                shape (num_all_proposals, in_channels).
            input_feature (Tensor): Feature the weights are applied to,
                shape (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: Output feature of shape (num_all_proposals, out_channels)
            when `with_proj` is set; otherwise the un-projected feature of
            shape (num_all_proposals, H*W, out_channels).
        """
        # (num_all_proposals, C, H, W) -> (num_all_proposals, H*W, C)
        tokens = input_feature.flatten(2).permute(0, 2, 1)

        generated = self.dynamic_layer(param_feature)
        # Leading slice -> (num_all_proposals, in_channels, feat_channels)
        weight_in = generated[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels)
        # Trailing slice -> (num_all_proposals, feat_channels, out_channels)
        weight_out = generated[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels)

        # First dynamic 1x1 conv: (.., H*W, C) x (.., C, feat) -> (.., H*W, feat)
        hidden = self.activation(self.norm_in(torch.bmm(tokens, weight_in)))
        # Second dynamic 1x1 conv: (.., H*W, feat) x (.., feat, out) -> (.., H*W, out)
        hidden = self.activation(self.norm_out(torch.bmm(hidden, weight_out)))

        if not self.with_proj:
            return hidden

        # Collapse the spatial positions and project down to out_channels.
        projected = self.fc_norm(self.fc_layer(hidden.flatten(1)))
        return self.activation(projected)

View File

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules
# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
from .ms_deform_attn import MSDeformAttn

View File

@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import math
import warnings
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd
from torch.nn.init import constant_, xavier_uniform_
class MSDeformAttnFunction(Function):
    """Autograd ``Function`` front-end for multi-scale deformable attention.

    Only ``forward`` is defined in this class; it delegates to the
    pure-PyTorch routine ``ms_deform_attn_core_pytorch``.
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(
        ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
    ):
        """Compute deformable attention with the PyTorch fallback.

        ``value_level_start_index`` and ``im2col_step`` are accepted so the
        call signature stays the same, but the PyTorch routine does not use
        them (the per-level split is derived from ``value_spatial_shapes``).
        """
        return ms_deform_attn_core_pytorch(
            value=value,
            value_spatial_shapes=value_spatial_shapes,
            sampling_locations=sampling_locations,
            attention_weights=attention_weights,
        )
def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    """Pure-PyTorch multi-scale deformable attention.

    Originally intended for debugging/testing in place of a CUDA kernel.

    :param value (N, sum_l H_l*W_l, n_heads, head_dim) flattened multi-level values
    :param value_spatial_shapes iterable of (H_l, W_l) pairs, one per level
    :param sampling_locations (N, Len_q, n_heads, n_levels, n_points, 2), (x, y) in [0, 1]
    :param attention_weights (N, Len_q, n_heads, n_levels, n_points)

    :return output (N, Len_q, n_heads * head_dim), contiguous
    """
    batch, _, _, head_dim = value.shape
    _, num_queries, heads, num_levels, num_points, _ = sampling_locations.shape

    # Cut the flattened value tensor back into one chunk per feature level.
    level_sizes = [height * width for height, width in value_spatial_shapes]
    per_level_values = value.split(level_sizes, dim=1)

    # grid_sample expects coordinates in [-1, 1] rather than [0, 1].
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (level_value, (height, width)) in enumerate(zip(per_level_values, value_spatial_shapes)):
        # (N, H*W, heads, D) -> (N, H*W, heads*D) -> (N, heads*D, H*W) -> (N*heads, D, H, W)
        feature_map = level_value.flatten(2).transpose(1, 2).reshape(batch * heads, head_dim, height, width)
        # (N, Lq, heads, P, 2) -> (N, heads, Lq, P, 2) -> (N*heads, Lq, P, 2)
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # Bilinear lookup -> (N*heads, D, Lq, P)
        sampled_per_level.append(
            F.grid_sample(feature_map, level_grid, mode="bilinear", padding_mode="zeros", align_corners=False)
        )

    # (N, Lq, heads, L, P) -> (N, heads, Lq, L, P) -> (N*heads, 1, Lq, L*P)
    weights = attention_weights.transpose(1, 2).reshape(batch * heads, 1, num_queries, num_levels * num_points)
    # (N*heads, D, Lq, L, P) -> (N*heads, D, Lq, L*P)
    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
    # Weighted sum over all (level, point) samples, then fold the heads back into channels.
    attended = (stacked * weights).sum(-1).view(batch, heads * head_dim, num_queries)
    return attended.transpose(1, 2).contiguous()
def _is_power_of_2(n):
    """Return True when ``n`` is a positive integer power of two.

    Zero yields False. Anything that is not a non-negative ``int`` raises
    ``ValueError``.
    """
    is_valid = isinstance(n, int) and n >= 0
    if not is_valid:
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    if n == 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit leaves zero.
    return n & (n - 1) == 0
class MSDeformAttn(nn.Module):
    """Multi-Scale Deformable Attention Module."""

    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0):
        """Build the projection layers and initialise their parameters.

        :param d_model     hidden dimension
        :param n_levels    number of feature levels
        :param n_heads     number of attention heads
        :param n_points    number of sampling points per attention head per feature level
        :param ratio       scale applied to d_model to get the width of the value projection
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads))

        # A power-of-2 head dimension is what the CUDA implementation prefers.
        head_dim = d_model // n_heads
        if not _is_power_of_2(head_dim):
            warnings.warn(
                "You'd better set d_model in MSDeformAttn to make "
                "the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation."
            )

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points
        self.ratio = ratio

        samples_per_query = n_heads * n_levels * n_points
        value_dim = int(d_model * ratio)
        self.sampling_offsets = nn.Linear(d_model, samples_per_query * 2)
        self.attention_weights = nn.Linear(d_model, samples_per_query)
        self.value_proj = nn.Linear(d_model, value_dim)
        self.output_proj = nn.Linear(value_dim, d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise all four linear layers.

        The sampling-offset bias gives every head its own direction (evenly
        spaced angles, scaled so the larger component is +/-1) and places the
        k-th sampling point at k times that direction. Offset and attention
        weights start at zero; value and output projections use Xavier-uniform
        weights with zero bias.
        """
        constant_(self.sampling_offsets.weight.data, 0.0)

        angle_step = 2.0 * math.pi / self.n_heads
        angles = torch.arange(self.n_heads, dtype=torch.float32) * angle_step
        directions = torch.stack([angles.cos(), angles.sin()], -1)
        directions = directions / directions.abs().max(-1, keepdim=True)[0]
        offsets = directions.view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
        # Point index k (1-based) scales the unit offset by k.
        point_scale = torch.arange(1, self.n_points + 1, dtype=torch.float32).view(1, 1, self.n_points, 1)
        offsets = offsets * point_scale
        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(offsets.view(-1))

        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)

        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)

        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        query,
        reference_points,
        input_flatten,
        input_spatial_shapes,
        input_level_start_index,
        input_padding_mask=None,
    ):
        """Run multi-scale deformable attention.

        :param query                       (N, Length_{query}, C)
        :param reference_points            (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten               (N, sum over levels of H_l * W_l, C)
        :param input_spatial_shapes        (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index     (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
        :param input_padding_mask          (N, sum over levels of H_l * W_l), True for padding elements, False for non-padding elements

        :return output                     (N, Length_{query}, C)
        """
        _, num_queries, _ = query.shape
        batch, num_inputs, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == num_inputs

        # Project the values, zero out padded positions, then split into heads.
        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], 0.0)
        value = value.view(batch, num_inputs, self.n_heads, int(self.ratio * self.d_model) // self.n_heads)

        offsets = self.sampling_offsets(query).view(
            batch, num_queries, self.n_heads, self.n_levels, self.n_points, 2
        )

        # Softmax runs jointly over every (level, point) sample of a head.
        weights = self.attention_weights(query).view(
            batch, num_queries, self.n_heads, self.n_levels * self.n_points
        )
        weights = F.softmax(weights, -1).view(batch, num_queries, self.n_heads, self.n_levels, self.n_points)

        ref_dim = reference_points.shape[-1]
        if ref_dim == 2:
            # Offsets are in pixels of each level; divide by (W_l, H_l) to normalise them.
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :] + offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif ref_dim == 4:
            # Reference boxes: offsets are scaled by half of the box (w, h).
            centers = reference_points[:, :, None, :, None, :2]
            sizes = reference_points[:, :, None, :, None, 2:]
            sampling_locations = centers + offsets / self.n_points * sizes * 0.5
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
            )

        attended = MSDeformAttnFunction.apply(
            value,
            input_spatial_shapes,
            input_level_start_index,
            sampling_locations,
            weights,
            self.im2col_step,
        )
        return self.output_proj(attended)

File diff suppressed because one or more lines are too long