Add semantic segmentation (Mask2Former) code (#186)

Add semantic segmentation (Mask2Former, based on ViT-Adapter) code and update the demo notebook with a dedicated segmentation section.

branch: pull/189/head
parent: d5b0405eff
commit: 91d8cd81c2
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .core import *  # noqa: F403
from .models import *  # noqa: F403
from .ops import *  # noqa: F403

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmseg.core.evaluation import *  # noqa: F403
from mmseg.core.seg import *  # noqa: F403

from .anchor import *  # noqa: F403
from .box import *  # noqa: F403
from .utils import *  # noqa: F403

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .point_generator import MlvlPointGenerator  # noqa: F403

@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

from mmcv.utils import Registry, build_from_cfg

PRIOR_GENERATORS = Registry("Generator for anchors and points")

ANCHOR_GENERATORS = PRIOR_GENERATORS


def build_prior_generator(cfg, default_args=None):
    return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)


def build_anchor_generator(cfg, default_args=None):
    warnings.warn("``build_anchor_generator`` will be deprecated soon, please use ``build_prior_generator``")
    return build_prior_generator(cfg, default_args=default_args)

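For reference, a registered generator can then be built from a plain config dict; a minimal sketch, assuming the ``MlvlPointGenerator`` defined in point_generator.py below has been imported (and is therefore registered in ``PRIOR_GENERATORS``):

# hypothetical usage sketch
cfg = dict(type="MlvlPointGenerator", strides=[8, 16, 32], offset=0.5)
prior_generator = build_prior_generator(cfg)
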
@ -0,0 +1,205 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.modules.utils import _pair
|
||||
|
||||
from .builder import PRIOR_GENERATORS
|
||||
|
||||
|
||||
@PRIOR_GENERATORS.register_module()
|
||||
class MlvlPointGenerator:
|
||||
"""Standard points generator for multi-level (Mlvl) feature maps in 2D
|
||||
points-based detectors.
|
||||
|
||||
Args:
|
||||
strides (list[int] | list[tuple[int, int]]): Strides of anchors
|
||||
in multiple feature levels in order (w, h).
|
||||
offset (float): The offset of points, the value is normalized with
|
||||
corresponding stride. Defaults to 0.5.
|
||||
"""
|
||||
|
||||
def __init__(self, strides, offset=0.5):
|
||||
self.strides = [_pair(stride) for stride in strides]
|
||||
self.offset = offset
|
||||
|
||||
@property
|
||||
def num_levels(self):
    """int: number of feature levels to which the generator will be applied"""
|
||||
return len(self.strides)
|
||||
|
||||
@property
|
||||
def num_base_priors(self):
|
||||
"""list[int]: The number of priors (points) at a point
|
||||
on the feature grid"""
|
||||
return [1 for _ in range(len(self.strides))]
|
||||
|
||||
def _meshgrid(self, x, y, row_major=True):
|
||||
yy, xx = torch.meshgrid(y, x)
|
||||
if row_major:
|
||||
# warning .flatten() would cause error in ONNX exporting
|
||||
# have to use reshape here
|
||||
return xx.reshape(-1), yy.reshape(-1)
|
||||
|
||||
else:
|
||||
return yy.reshape(-1), xx.reshape(-1)
|
||||
|
||||
def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False):
|
||||
"""Generate grid points of multiple feature levels.
|
||||
|
||||
Args:
|
||||
featmap_sizes (list[tuple]): List of feature map sizes in
|
||||
multiple feature levels, each size arranged as (h, w).
|
||||
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
|
||||
device (str): The device where the anchors will be put on.
|
||||
with_stride (bool): Whether to concatenate the stride to
|
||||
the last dimension of points.
|
||||
|
||||
Return:
|
||||
list[torch.Tensor]: Points of multiple feature levels.
|
||||
The sizes of each tensor should be (N, 2) when with stride is
|
||||
``False``, where N = width * height, width and height
|
||||
are the sizes of the corresponding feature level,
|
||||
and the last dimension 2 represent (coord_x, coord_y),
|
||||
otherwise the shape should be (N, 4),
|
||||
and the last dimension 4 represent
|
||||
(coord_x, coord_y, stride_w, stride_h).
|
||||
"""
|
||||
|
||||
assert self.num_levels == len(featmap_sizes)
|
||||
multi_level_priors = []
|
||||
for i in range(self.num_levels):
|
||||
priors = self.single_level_grid_priors(
|
||||
featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride
|
||||
)
|
||||
multi_level_priors.append(priors)
|
||||
return multi_level_priors
|
||||
|
||||
def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False):
|
||||
"""Generate grid Points of a single level.
|
||||
|
||||
Note:
|
||||
This function is usually called by method ``self.grid_priors``.
|
||||
|
||||
Args:
|
||||
featmap_size (tuple[int]): Size of the feature maps, arranged as (h, w).
|
||||
level_idx (int): The index of corresponding feature map level.
|
||||
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
|
||||
device (str, optional): The device the tensor will be put on.
|
||||
Defaults to 'cuda'.
|
||||
with_stride (bool): Concatenate the stride to the last dimension
|
||||
of points.
|
||||
|
||||
Return:
|
||||
Tensor: Points of single feature levels.
|
||||
The shape of tensor should be (N, 2) when with stride is
|
||||
``False``, where N = width * height, width and height
|
||||
are the sizes of the corresponding feature level,
|
||||
and the last dimension 2 represent (coord_x, coord_y),
|
||||
otherwise the shape should be (N, 4),
|
||||
and the last dimension 4 represent
|
||||
(coord_x, coord_y, stride_w, stride_h).
|
||||
"""
|
||||
feat_h, feat_w = featmap_size
|
||||
stride_w, stride_h = self.strides[level_idx]
|
||||
shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w
|
||||
# keep featmap_size as Tensor instead of int, so that we
|
||||
# can convert to ONNX correctly
|
||||
shift_x = shift_x.to(dtype)
|
||||
|
||||
shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h
|
||||
# keep featmap_size as Tensor instead of int, so that we
|
||||
# can convert to ONNX correctly
|
||||
shift_y = shift_y.to(dtype)
|
||||
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
|
||||
if not with_stride:
|
||||
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
|
||||
else:
|
||||
# use `shape[0]` instead of `len(shift_xx)` for ONNX export
|
||||
stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype)
|
||||
stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype)
|
||||
shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1)
|
||||
all_points = shifts.to(device)
|
||||
return all_points
|
||||
|
||||
def valid_flags(self, featmap_sizes, pad_shape, device="cuda"):
|
||||
"""Generate valid flags of points of multiple feature levels.
|
||||
|
||||
Args:
|
||||
featmap_sizes (list(tuple)): List of feature map sizes in
|
||||
multiple feature levels, each size arranged as (h, w).
|
||||
pad_shape (tuple(int)): The padded shape of the image,
|
||||
arrange as (h, w).
|
||||
device (str): The device where the anchors will be put on.
|
||||
|
||||
Return:
|
||||
list(torch.Tensor): Valid flags of points of multiple levels.
|
||||
"""
|
||||
assert self.num_levels == len(featmap_sizes)
|
||||
multi_level_flags = []
|
||||
for i in range(self.num_levels):
|
||||
point_stride = self.strides[i]
|
||||
feat_h, feat_w = featmap_sizes[i]
|
||||
h, w = pad_shape[:2]
|
||||
valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
|
||||
valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
|
||||
flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device)
|
||||
multi_level_flags.append(flags)
|
||||
return multi_level_flags
|
||||
|
||||
def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"):
|
||||
"""Generate the valid flags of points of a single feature map.
|
||||
|
||||
Args:
|
||||
featmap_size (tuple[int]): The size of feature maps, arranged as (h, w).
|
||||
valid_size (tuple[int]): The valid size of the feature maps.
|
||||
The size arranged as (h, w).
|
||||
device (str, optional): The device where the flags will be put on.
|
||||
Defaults to 'cuda'.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The valid flags of each points in a single level \
|
||||
feature map.
|
||||
"""
|
||||
feat_h, feat_w = featmap_size
|
||||
valid_h, valid_w = valid_size
|
||||
assert valid_h <= feat_h and valid_w <= feat_w
|
||||
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
|
||||
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
|
||||
valid_x[:valid_w] = 1
|
||||
valid_y[:valid_h] = 1
|
||||
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
|
||||
valid = valid_xx & valid_yy
|
||||
return valid
|
||||
|
||||
def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"):
|
||||
"""Generate sparse points according to the ``prior_idxs``.
|
||||
|
||||
Args:
|
||||
prior_idxs (Tensor): The index of corresponding anchors
|
||||
in the feature map.
|
||||
featmap_size (tuple[int]): feature map size arranged as (h, w).
|
||||
level_idx (int): The level index of corresponding feature
|
||||
map.
|
||||
dtype (obj:`torch.dtype`): Data type of points. Defaults to
|
||||
``torch.float32``.
|
||||
device (obj:`torch.device`): The device where the points is
|
||||
located.
|
||||
Returns:
|
||||
Tensor: Anchor with shape (N, 2), N should be equal to
|
||||
the length of ``prior_idxs``. And last dimension
|
||||
2 represent (coord_x, coord_y).
|
||||
"""
|
||||
height, width = featmap_size
|
||||
x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
|
||||
y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1]
|
||||
prioris = torch.stack([x, y], 1).to(dtype)
|
||||
prioris = prioris.to(device)
|
||||
return prioris
|
|
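A short usage sketch for the generator above (illustrative sizes, CPU tensors; assumes ``MlvlPointGenerator`` is imported from the module above):

import torch

gen = MlvlPointGenerator(strides=[8, 16, 32])
featmap_sizes = [(64, 64), (32, 32), (16, 16)]  # (h, w) for three levels
priors = gen.grid_priors(featmap_sizes, dtype=torch.float32, device="cpu")
assert priors[0].shape == (64 * 64, 2)  # (coord_x, coord_y) per location
priors_ws = gen.grid_priors(featmap_sizes, device="cpu", with_stride=True)
assert priors_ws[0].shape == (64 * 64, 4)  # plus (stride_w, stride_h)
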
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .builder import *  # noqa: F403
from .samplers import MaskPseudoSampler  # noqa: F403

@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.utils import Registry, build_from_cfg

BBOX_SAMPLERS = Registry("bbox_sampler")
BBOX_CODERS = Registry("bbox_coder")


def build_sampler(cfg, **default_args):
    """Builder of box sampler."""
    return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)


def build_bbox_coder(cfg, **default_args):
    """Builder of box coder."""
    return build_from_cfg(cfg, BBOX_CODERS, default_args)

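For example, the ``MaskPseudoSampler`` registered in the samplers package below can be built from a config dict (a sketch, assuming the sampler module has been imported so it is registered in ``BBOX_SAMPLERS``):

sampler_cfg = dict(type="MaskPseudoSampler")
sampler = build_sampler(sampler_cfg)
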
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .mask_pseudo_sampler import MaskPseudoSampler  # noqa: F403

@ -0,0 +1,92 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from .sampling_result import SamplingResult
|
||||
|
||||
|
||||
class BaseSampler(metaclass=ABCMeta):
|
||||
"""Base class of samplers."""
|
||||
|
||||
def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
|
||||
self.num = num
|
||||
self.pos_fraction = pos_fraction
|
||||
self.neg_pos_ub = neg_pos_ub
|
||||
self.add_gt_as_proposals = add_gt_as_proposals
|
||||
self.pos_sampler = self
|
||||
self.neg_sampler = self
|
||||
|
||||
@abstractmethod
|
||||
def _sample_pos(self, assign_result, num_expected, **kwargs):
|
||||
"""Sample positive samples."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _sample_neg(self, assign_result, num_expected, **kwargs):
|
||||
"""Sample negative samples."""
|
||||
pass
|
||||
|
||||
def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
|
||||
"""Sample positive and negative bboxes.
|
||||
|
||||
This is a simple implementation of bbox sampling given candidates,
|
||||
assigning results and ground truth bboxes.
|
||||
|
||||
Args:
|
||||
assign_result (:obj:`AssignResult`): Bbox assigning results.
|
||||
bboxes (Tensor): Boxes to be sampled from.
|
||||
gt_bboxes (Tensor): Ground truth bboxes.
|
||||
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
|
||||
|
||||
Returns:
|
||||
:obj:`SamplingResult`: Sampling result.
|
||||
|
||||
Example:
|
||||
>>> from mmdet.core.bbox import RandomSampler
|
||||
>>> from mmdet.core.bbox import AssignResult
|
||||
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
|
||||
>>> rng = ensure_rng(None)
|
||||
>>> assign_result = AssignResult.random(rng=rng)
|
||||
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
|
||||
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
|
||||
>>> gt_labels = None
|
||||
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
|
||||
>>> add_gt_as_proposals=False)
|
||||
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
|
||||
"""
|
||||
if len(bboxes.shape) < 2:
|
||||
bboxes = bboxes[None, :]
|
||||
|
||||
bboxes = bboxes[:, :4]
|
||||
|
||||
gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
|
||||
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
|
||||
if gt_labels is None:
|
||||
raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
|
||||
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
|
||||
assign_result.add_gt_(gt_labels)
|
||||
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
|
||||
gt_flags = torch.cat([gt_ones, gt_flags])
|
||||
|
||||
num_expected_pos = int(self.num * self.pos_fraction)
|
||||
pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
|
||||
# We found that sampled indices have duplicated items occasionally.
|
||||
# (may be a bug of PyTorch)
|
||||
pos_inds = pos_inds.unique()
|
||||
num_sampled_pos = pos_inds.numel()
|
||||
num_expected_neg = self.num - num_sampled_pos
|
||||
if self.neg_pos_ub >= 0:
|
||||
_pos = max(1, num_sampled_pos)
|
||||
neg_upper_bound = int(self.neg_pos_ub * _pos)
|
||||
if num_expected_neg > neg_upper_bound:
|
||||
num_expected_neg = neg_upper_bound
|
||||
neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
|
||||
neg_inds = neg_inds.unique()
|
||||
|
||||
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
|
||||
return sampling_result
|
|
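To make the positive/negative budget in ``sample`` concrete, this is the arithmetic it performs for one illustrative configuration (the numbers below are assumptions, not defaults):

num, pos_fraction, neg_pos_ub = 32, 0.5, 3
num_expected_pos = int(num * pos_fraction)                   # 16 positives requested
num_sampled_pos = 5                                          # e.g. only 5 unique positives found
num_expected_neg = num - num_sampled_pos                     # 27
neg_upper_bound = int(neg_pos_ub * max(1, num_sampled_pos))  # 15
num_expected_neg = min(num_expected_neg, neg_upper_bound)    # capped at 15 negatives
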
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py

import torch

from ..builder import BBOX_SAMPLERS
from .base_sampler import BaseSampler
from .mask_sampling_result import MaskSamplingResult


@BBOX_SAMPLERS.register_module()
class MaskPseudoSampler(BaseSampler):
    """A pseudo sampler that does not actually do any sampling."""

    def __init__(self, **kwargs):
        pass

    def _sample_pos(self, **kwargs):
        """Sample positive samples."""
        raise NotImplementedError

    def _sample_neg(self, **kwargs):
        """Sample negative samples."""
        raise NotImplementedError

    def sample(self, assign_result, masks, gt_masks, **kwargs):
        """Directly returns the positive and negative indices of samples.

        Args:
            assign_result (:obj:`AssignResult`): Assigned results
            masks (torch.Tensor): Predicted masks
            gt_masks (torch.Tensor): Ground truth masks
        Returns:
            :obj:`SamplingResult`: sampler results
        """
        pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
        neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
        gt_flags = masks.new_zeros(masks.shape[0], dtype=torch.uint8)
        sampling_result = MaskSamplingResult(pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags)
        return sampling_result

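A minimal sketch of how the pseudo sampler is used. The ``_FakeAssign`` stand-in below is hypothetical; in the real pipeline the assign result comes from the mask assigner:

import torch

class _FakeAssign:  # hypothetical stand-in for an AssignResult-like object
    def __init__(self, gt_inds, labels=None):
        self.gt_inds = gt_inds
        self.labels = labels

masks = torch.rand(4, 128, 128)      # 4 predicted masks
gt_masks = torch.rand(2, 128, 128)   # 2 ground-truth masks
assign = _FakeAssign(torch.tensor([1, 0, 2, 0]))  # 0 = unassigned, k > 0 = matched to gt k-1
result = MaskPseudoSampler().sample(assign, masks, gt_masks)
# result.pos_inds -> tensor([0, 2]); result.neg_inds -> tensor([1, 3])
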
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/ZwwWayne/K-Net/blob/main/knet/det/mask_pseudo_sampler.py
|
||||
|
||||
import torch
|
||||
|
||||
from .sampling_result import SamplingResult
|
||||
|
||||
|
||||
class MaskSamplingResult(SamplingResult):
|
||||
"""Mask sampling result."""
|
||||
|
||||
def __init__(self, pos_inds, neg_inds, masks, gt_masks, assign_result, gt_flags):
|
||||
self.pos_inds = pos_inds
|
||||
self.neg_inds = neg_inds
|
||||
self.pos_masks = masks[pos_inds]
|
||||
self.neg_masks = masks[neg_inds]
|
||||
self.pos_is_gt = gt_flags[pos_inds]
|
||||
|
||||
self.num_gts = gt_masks.shape[0]
|
||||
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
|
||||
|
||||
if gt_masks.numel() == 0:
|
||||
# hack for index error case
|
||||
assert self.pos_assigned_gt_inds.numel() == 0
|
||||
self.pos_gt_masks = torch.empty_like(gt_masks)
|
||||
else:
|
||||
self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :]
|
||||
|
||||
if assign_result.labels is not None:
|
||||
self.pos_gt_labels = assign_result.labels[pos_inds]
|
||||
else:
|
||||
self.pos_gt_labels = None
|
||||
|
||||
@property
|
||||
def masks(self):
    """torch.Tensor: concatenated positive and negative masks"""
|
||||
return torch.cat([self.pos_masks, self.neg_masks])
|
||||
|
||||
def __nice__(self):
|
||||
data = self.info.copy()
|
||||
data["pos_masks"] = data.pop("pos_masks").shape
|
||||
data["neg_masks"] = data.pop("neg_masks").shape
|
||||
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
|
||||
body = " " + ",\n ".join(parts)
|
||||
return "{\n" + body + "\n}"
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
"""Returns a dictionary of info about the object."""
|
||||
return {
|
||||
"pos_inds": self.pos_inds,
|
||||
"neg_inds": self.neg_inds,
|
||||
"pos_masks": self.pos_masks,
|
||||
"neg_masks": self.neg_masks,
|
||||
"pos_is_gt": self.pos_is_gt,
|
||||
"num_gts": self.num_gts,
|
||||
"pos_assigned_gt_inds": self.pos_assigned_gt_inds,
|
||||
}
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class SamplingResult:
|
||||
"""Bbox sampling result.
|
||||
|
||||
Example:
|
||||
>>> # xdoctest: +IGNORE_WANT
|
||||
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
|
||||
>>> self = SamplingResult.random(rng=10)
|
||||
>>> print(f'self = {self}')
|
||||
self = <SamplingResult({
|
||||
'neg_bboxes': torch.Size([12, 4]),
|
||||
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
|
||||
'num_gts': 4,
|
||||
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
|
||||
'pos_bboxes': torch.Size([0, 4]),
|
||||
'pos_inds': tensor([], dtype=torch.int64),
|
||||
'pos_is_gt': tensor([], dtype=torch.uint8)
|
||||
})>
|
||||
"""
|
||||
|
||||
def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags):
|
||||
self.pos_inds = pos_inds
|
||||
self.neg_inds = neg_inds
|
||||
self.pos_bboxes = bboxes[pos_inds]
|
||||
self.neg_bboxes = bboxes[neg_inds]
|
||||
self.pos_is_gt = gt_flags[pos_inds]
|
||||
|
||||
self.num_gts = gt_bboxes.shape[0]
|
||||
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
|
||||
|
||||
if gt_bboxes.numel() == 0:
|
||||
# hack for index error case
|
||||
assert self.pos_assigned_gt_inds.numel() == 0
|
||||
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
|
||||
else:
|
||||
if len(gt_bboxes.shape) < 2:
|
||||
gt_bboxes = gt_bboxes.view(-1, 4)
|
||||
|
||||
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long(), :]
|
||||
|
||||
if assign_result.labels is not None:
|
||||
self.pos_gt_labels = assign_result.labels[pos_inds]
|
||||
else:
|
||||
self.pos_gt_labels = None
|
||||
|
||||
@property
|
||||
def bboxes(self):
|
||||
"""torch.Tensor: concatenated positive and negative boxes"""
|
||||
return torch.cat([self.pos_bboxes, self.neg_bboxes])
|
||||
|
||||
def to(self, device):
|
||||
"""Change the device of the data inplace.
|
||||
|
||||
Example:
|
||||
>>> self = SamplingResult.random()
|
||||
>>> print(f'self = {self.to(None)}')
|
||||
>>> # xdoctest: +REQUIRES(--gpu)
|
||||
>>> print(f'self = {self.to(0)}')
|
||||
"""
|
||||
_dict = self.__dict__
|
||||
for key, value in _dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
_dict[key] = value.to(device)
|
||||
return self
|
||||
|
||||
def __nice__(self):
|
||||
data = self.info.copy()
|
||||
data["pos_bboxes"] = data.pop("pos_bboxes").shape
|
||||
data["neg_bboxes"] = data.pop("neg_bboxes").shape
|
||||
parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
|
||||
body = " " + ",\n ".join(parts)
|
||||
return "{\n" + body + "\n}"
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
"""Returns a dictionary of info about the object."""
|
||||
return {
|
||||
"pos_inds": self.pos_inds,
|
||||
"neg_inds": self.neg_inds,
|
||||
"pos_bboxes": self.pos_bboxes,
|
||||
"neg_bboxes": self.neg_bboxes,
|
||||
"pos_is_gt": self.pos_is_gt,
|
||||
"num_gts": self.num_gts,
|
||||
"pos_assigned_gt_inds": self.pos_assigned_gt_inds,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def random(cls, rng=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
rng (None | int | numpy.random.RandomState): seed or state.
|
||||
kwargs (keyword arguments):
|
||||
- num_preds: number of predicted boxes
|
||||
- num_gts: number of true boxes
|
||||
- p_ignore (float): probability of a predicted box assigned to \
|
||||
an ignored truth.
|
||||
- p_assigned (float): probability of a predicted box not being \
|
||||
assigned.
|
||||
- p_use_label (float | bool): with labels or not.
|
||||
|
||||
Returns:
|
||||
:obj:`SamplingResult`: Randomly generated sampling result.
|
||||
|
||||
Example:
|
||||
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
|
||||
>>> self = SamplingResult.random()
|
||||
>>> print(self.__dict__)
|
||||
"""
|
||||
from mmdet.core.bbox import demodata
|
||||
from mmdet.core.bbox.assigners.assign_result import AssignResult
|
||||
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
|
||||
|
||||
rng = demodata.ensure_rng(rng)
|
||||
|
||||
# make probabilistic?
|
||||
num = 32
|
||||
pos_fraction = 0.5
|
||||
neg_pos_ub = -1
|
||||
|
||||
assign_result = AssignResult.random(rng=rng, **kwargs)
|
||||
|
||||
# Note we could just compute an assignment
|
||||
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
|
||||
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
|
||||
|
||||
if rng.rand() > 0.2:
|
||||
# sometimes algorithms squeeze their data, be robust to that
|
||||
gt_bboxes = gt_bboxes.squeeze()
|
||||
bboxes = bboxes.squeeze()
|
||||
|
||||
# class labels are not needed for this randomly generated sampling result
gt_labels = None
|
||||
|
||||
if gt_labels is None:
|
||||
add_gt_as_proposals = False
|
||||
else:
|
||||
add_gt_as_proposals = True  # make probabilistic?
|
||||
|
||||
sampler = RandomSampler(
|
||||
num, pos_fraction, neg_pos_ub=neg_pos_ub, add_gt_as_proposals=add_gt_as_proposals, rng=rng
|
||||
)
|
||||
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
|
||||
return self
|
|
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dist_utils import reduce_mean
from .misc import add_prefix, multi_apply

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch.distributed as dist


def reduce_mean(tensor):
    """Obtain the mean of tensor on different GPUs."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor

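Typical use during distributed training is to average a per-GPU statistic before normalising a loss; a small sketch:

import torch

num_pos = torch.tensor(42.0)          # e.g. positive-sample count on this rank
num_pos_avg = reduce_mean(num_pos)    # identical value on every rank after the all-reduce
loss_weight = 1.0 / num_pos_avg.clamp(min=1.0)
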
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from functools import partial


def multi_apply(func, *args, **kwargs):
    """Apply function to a list of arguments.

    Note:
        This function applies the ``func`` to multiple inputs and
        maps the multiple outputs of the ``func`` into different
        lists. Each list contains the same type of outputs corresponding
        to different inputs.

    Args:
        func (Function): A function that will be applied to a list of
            arguments

    Returns:
        tuple(list): A tuple containing multiple lists, each list contains \
            a kind of returned results by the function
    """
    pfunc = partial(func, **kwargs) if kwargs else func
    map_results = map(pfunc, *args)
    return tuple(map(list, zip(*map_results)))


def add_prefix(inputs, prefix):
    """Add prefix for dict.

    Args:
        inputs (dict): The input dict with str keys.
        prefix (str): The prefix to add.

    Returns:
        dict: The dict with keys updated with ``prefix``.
    """

    outputs = dict()
    for name, value in inputs.items():
        outputs[f"{prefix}.{name}"] = value

    return outputs

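Quick examples of the two helpers above:

def square_and_cube(v):
    return v * v, v * v * v

squares, cubes = multi_apply(square_and_cube, [1, 2, 3])
# squares == [1, 4, 9], cubes == [1, 8, 27]

losses = add_prefix({"loss_mask": 0.3, "loss_cls": 0.1}, "decode")
# {"decode.loss_mask": 0.3, "decode.loss_cls": 0.1}
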
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .backbones import *  # noqa: F403
from .builder import MASK_ASSIGNERS, MATCH_COST, TRANSFORMER, build_assigner, build_match_cost
from .decode_heads import *  # noqa: F403
from .losses import *  # noqa: F403
from .plugins import *  # noqa: F403
from .segmentors import *  # noqa: F403

@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .vit_adapter import ViTAdapter

@ -0,0 +1,442 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
|
||||
from ...ops.modules import MSDeformAttn
|
||||
from .drop_path import DropPath
|
||||
|
||||
|
||||
def get_reference_points(spatial_shapes, device):
|
||||
reference_points_list = []
|
||||
for lvl, (H_, W_) in enumerate(spatial_shapes):
|
||||
ref_y, ref_x = torch.meshgrid(
|
||||
torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
|
||||
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
|
||||
)
|
||||
ref_y = ref_y.reshape(-1)[None] / H_
|
||||
ref_x = ref_x.reshape(-1)[None] / W_
|
||||
ref = torch.stack((ref_x, ref_y), -1)
|
||||
reference_points_list.append(ref)
|
||||
reference_points = torch.cat(reference_points_list, 1)
|
||||
reference_points = reference_points[:, :, None]
|
||||
return reference_points
|
||||
|
||||
|
||||
def deform_inputs(x, patch_size):
|
||||
bs, c, h, w = x.shape
|
||||
spatial_shapes = torch.as_tensor(
|
||||
[(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], dtype=torch.long, device=x.device
|
||||
)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
reference_points = get_reference_points([(h // patch_size, w // patch_size)], x.device)
|
||||
deform_inputs1 = [reference_points, spatial_shapes, level_start_index]
|
||||
|
||||
spatial_shapes = torch.as_tensor([(h // patch_size, w // patch_size)], dtype=torch.long, device=x.device)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
reference_points = get_reference_points([(h // 8, w // 8), (h // 16, w // 16), (h // 32, w // 32)], x.device)
|
||||
deform_inputs2 = [reference_points, spatial_shapes, level_start_index]
|
||||
|
||||
return deform_inputs1, deform_inputs2
|
||||
|
||||
|
||||
class ConvFFN(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.dwconv = DWConv(hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
x = self.fc1(x)
|
||||
x = self.dwconv(x, H, W)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class DWConv(nn.Module):
|
||||
def __init__(self, dim=768):
|
||||
super().__init__()
|
||||
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
n = N // 21
|
||||
x1 = x[:, 0 : 16 * n, :].transpose(1, 2).view(B, C, H * 2, W * 2).contiguous()
|
||||
x2 = x[:, 16 * n : 20 * n, :].transpose(1, 2).view(B, C, H, W).contiguous()
|
||||
x3 = x[:, 20 * n :, :].transpose(1, 2).view(B, C, H // 2, W // 2).contiguous()
|
||||
x1 = self.dwconv(x1).flatten(2).transpose(1, 2)
|
||||
x2 = self.dwconv(x2).flatten(2).transpose(1, 2)
|
||||
x3 = self.dwconv(x3).flatten(2).transpose(1, 2)
|
||||
x = torch.cat([x1, x2, x3], dim=1)
|
||||
return x
|
||||
|
||||
|
||||
class Extractor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=6,
|
||||
n_points=4,
|
||||
n_levels=1,
|
||||
deform_ratio=1.0,
|
||||
with_cffn=True,
|
||||
cffn_ratio=0.25,
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
with_cp=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.query_norm = norm_layer(dim)
|
||||
self.feat_norm = norm_layer(dim)
|
||||
self.attn = MSDeformAttn(
|
||||
d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
|
||||
)
|
||||
self.with_cffn = with_cffn
|
||||
self.with_cp = with_cp
|
||||
if with_cffn:
|
||||
self.ffn = ConvFFN(in_features=dim, hidden_features=int(dim * cffn_ratio), drop=drop)
|
||||
self.ffn_norm = norm_layer(dim)
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
def forward(self, query, reference_points, feat, spatial_shapes, level_start_index, H, W):
|
||||
def _inner_forward(query, feat):
|
||||
|
||||
attn = self.attn(
|
||||
self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
|
||||
)
|
||||
query = query + attn
|
||||
|
||||
if self.with_cffn:
|
||||
query = query + self.drop_path(self.ffn(self.ffn_norm(query), H, W))
|
||||
return query
|
||||
|
||||
if self.with_cp and query.requires_grad:
|
||||
query = cp.checkpoint(_inner_forward, query, feat)
|
||||
else:
|
||||
query = _inner_forward(query, feat)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class Injector(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=6,
|
||||
n_points=4,
|
||||
n_levels=1,
|
||||
deform_ratio=1.0,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
init_values=0.0,
|
||||
with_cp=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.with_cp = with_cp
|
||||
self.query_norm = norm_layer(dim)
|
||||
self.feat_norm = norm_layer(dim)
|
||||
self.attn = MSDeformAttn(
|
||||
d_model=dim, n_levels=n_levels, n_heads=num_heads, n_points=n_points, ratio=deform_ratio
|
||||
)
|
||||
self.gamma = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
|
||||
|
||||
def forward(self, query, reference_points, feat, spatial_shapes, level_start_index):
|
||||
def _inner_forward(query, feat):
|
||||
|
||||
attn = self.attn(
|
||||
self.query_norm(query), reference_points, self.feat_norm(feat), spatial_shapes, level_start_index, None
|
||||
)
|
||||
return query + self.gamma * attn
|
||||
|
||||
if self.with_cp and query.requires_grad:
|
||||
query = cp.checkpoint(_inner_forward, query, feat)
|
||||
else:
|
||||
query = _inner_forward(query, feat)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
class InteractionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=6,
|
||||
n_points=4,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
with_cffn=True,
|
||||
cffn_ratio=0.25,
|
||||
init_values=0.0,
|
||||
deform_ratio=1.0,
|
||||
extra_extractor=False,
|
||||
with_cp=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.injector = Injector(
|
||||
dim=dim,
|
||||
n_levels=3,
|
||||
num_heads=num_heads,
|
||||
init_values=init_values,
|
||||
n_points=n_points,
|
||||
norm_layer=norm_layer,
|
||||
deform_ratio=deform_ratio,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
self.extractor = Extractor(
|
||||
dim=dim,
|
||||
n_levels=1,
|
||||
num_heads=num_heads,
|
||||
n_points=n_points,
|
||||
norm_layer=norm_layer,
|
||||
deform_ratio=deform_ratio,
|
||||
with_cffn=with_cffn,
|
||||
cffn_ratio=cffn_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
if extra_extractor:
|
||||
self.extra_extractors = nn.Sequential(
|
||||
*[
|
||||
Extractor(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
n_points=n_points,
|
||||
norm_layer=norm_layer,
|
||||
with_cffn=with_cffn,
|
||||
cffn_ratio=cffn_ratio,
|
||||
deform_ratio=deform_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.extra_extractors = None
|
||||
|
||||
def forward(self, x, c, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
|
||||
x = self.injector(
|
||||
query=x,
|
||||
reference_points=deform_inputs1[0],
|
||||
feat=c,
|
||||
spatial_shapes=deform_inputs1[1],
|
||||
level_start_index=deform_inputs1[2],
|
||||
)
|
||||
for idx, blk in enumerate(blocks):
|
||||
x = blk(x, H_toks, W_toks)
|
||||
c = self.extractor(
|
||||
query=c,
|
||||
reference_points=deform_inputs2[0],
|
||||
feat=x,
|
||||
spatial_shapes=deform_inputs2[1],
|
||||
level_start_index=deform_inputs2[2],
|
||||
H=H_c,
|
||||
W=W_c,
|
||||
)
|
||||
if self.extra_extractors is not None:
|
||||
for extractor in self.extra_extractors:
|
||||
c = extractor(
|
||||
query=c,
|
||||
reference_points=deform_inputs2[0],
|
||||
feat=x,
|
||||
spatial_shapes=deform_inputs2[1],
|
||||
level_start_index=deform_inputs2[2],
|
||||
H=H_c,
|
||||
W=W_c,
|
||||
)
|
||||
return x, c
|
||||
|
||||
|
||||
class InteractionBlockWithCls(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=6,
|
||||
n_points=4,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
drop=0.0,
|
||||
drop_path=0.0,
|
||||
with_cffn=True,
|
||||
cffn_ratio=0.25,
|
||||
init_values=0.0,
|
||||
deform_ratio=1.0,
|
||||
extra_extractor=False,
|
||||
with_cp=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.injector = Injector(
|
||||
dim=dim,
|
||||
n_levels=3,
|
||||
num_heads=num_heads,
|
||||
init_values=init_values,
|
||||
n_points=n_points,
|
||||
norm_layer=norm_layer,
|
||||
deform_ratio=deform_ratio,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
self.extractor = Extractor(
|
||||
dim=dim,
|
||||
n_levels=1,
|
||||
num_heads=num_heads,
|
||||
n_points=n_points,
|
||||
norm_layer=norm_layer,
|
||||
deform_ratio=deform_ratio,
|
||||
with_cffn=with_cffn,
|
||||
cffn_ratio=cffn_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
if extra_extractor:
|
||||
self.extra_extractors = nn.Sequential(
|
||||
*[
|
||||
Extractor(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
n_points=n_points,
|
||||
norm_layer=norm_layer,
|
||||
with_cffn=with_cffn,
|
||||
cffn_ratio=cffn_ratio,
|
||||
deform_ratio=deform_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path,
|
||||
with_cp=with_cp,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.extra_extractors = None
|
||||
|
||||
def forward(self, x, c, cls, blocks, deform_inputs1, deform_inputs2, H_c, W_c, H_toks, W_toks):
|
||||
x = self.injector(
|
||||
query=x,
|
||||
reference_points=deform_inputs1[0],
|
||||
feat=c,
|
||||
spatial_shapes=deform_inputs1[1],
|
||||
level_start_index=deform_inputs1[2],
|
||||
)
|
||||
x = torch.cat((cls, x), dim=1)
|
||||
for idx, blk in enumerate(blocks):
|
||||
x = blk(x, H_toks, W_toks)
|
||||
cls, x = (
|
||||
x[
|
||||
:,
|
||||
:1,
|
||||
],
|
||||
x[
|
||||
:,
|
||||
1:,
|
||||
],
|
||||
)
|
||||
c = self.extractor(
|
||||
query=c,
|
||||
reference_points=deform_inputs2[0],
|
||||
feat=x,
|
||||
spatial_shapes=deform_inputs2[1],
|
||||
level_start_index=deform_inputs2[2],
|
||||
H=H_c,
|
||||
W=W_c,
|
||||
)
|
||||
if self.extra_extractors is not None:
|
||||
for extractor in self.extra_extractors:
|
||||
c = extractor(
|
||||
query=c,
|
||||
reference_points=deform_inputs2[0],
|
||||
feat=x,
|
||||
spatial_shapes=deform_inputs2[1],
|
||||
level_start_index=deform_inputs2[2],
|
||||
H=H_c,
|
||||
W=W_c,
|
||||
)
|
||||
return x, c, cls
|
||||
|
||||
|
||||
class SpatialPriorModule(nn.Module):
|
||||
def __init__(self, inplanes=64, embed_dim=384, with_cp=False):
|
||||
super().__init__()
|
||||
self.with_cp = with_cp
|
||||
|
||||
self.stem = nn.Sequential(
|
||||
*[
|
||||
nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.SyncBatchNorm(inplanes),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.SyncBatchNorm(inplanes),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(inplanes, inplanes, kernel_size=3, stride=1, padding=1, bias=False),
|
||||
nn.SyncBatchNorm(inplanes),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
|
||||
]
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
*[
|
||||
nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.SyncBatchNorm(2 * inplanes),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
)
|
||||
self.conv3 = nn.Sequential(
|
||||
*[
|
||||
nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.SyncBatchNorm(4 * inplanes),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
)
|
||||
self.conv4 = nn.Sequential(
|
||||
*[
|
||||
nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
|
||||
nn.SyncBatchNorm(4 * inplanes),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
)
|
||||
self.fc1 = nn.Conv2d(inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
self.fc2 = nn.Conv2d(2 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
self.fc3 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
self.fc4 = nn.Conv2d(4 * inplanes, embed_dim, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
|
||||
def forward(self, x):
|
||||
def _inner_forward(x):
|
||||
c1 = self.stem(x)
|
||||
c2 = self.conv2(c1)
|
||||
c3 = self.conv3(c2)
|
||||
c4 = self.conv4(c3)
|
||||
c1 = self.fc1(c1)
|
||||
c2 = self.fc2(c2)
|
||||
c3 = self.fc3(c3)
|
||||
c4 = self.fc4(c4)
|
||||
|
||||
bs, dim, _, _ = c1.shape
|
||||
# c1 = c1.view(bs, dim, -1).transpose(1, 2) # 4s
|
||||
c2 = c2.view(bs, dim, -1).transpose(1, 2) # 8s
|
||||
c3 = c3.view(bs, dim, -1).transpose(1, 2) # 16s
|
||||
c4 = c4.view(bs, dim, -1).transpose(1, 2) # 32s
|
||||
|
||||
return c1, c2, c3, c4
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
outs = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
outs = _inner_forward(x)
|
||||
return outs
|
|
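The ``N // 21`` split in ``DWConv.forward`` and the three spatial shapes in ``deform_inputs`` both come from the 8/16/32-stride features produced by ``SpatialPriorModule``; a small arithmetic sketch, assuming a 512x512 input:

h = w = 512                                   # assumed input resolution
tokens_8s = (h // 8) * (w // 8)               # 4096
tokens_16s = (h // 16) * (w // 16)            # 1024
tokens_32s = (h // 32) * (w // 32)            # 256
N = tokens_8s + tokens_16s + tokens_32s       # 5376 == 21 * 256
n = N // 21                                   # 256, matching DWConv.forward
assert (tokens_8s, tokens_16s, tokens_32s) == (16 * n, 4 * n, n)
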
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py

from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: float = 0.0):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

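Usage sketch for ``DropPath``: applied to a residual branch, each sample's branch output is either zeroed or rescaled by 1/keep_prob during training:

import torch

drop = DropPath(drop_prob=0.1)
drop.train()                             # drop_path is a no-op in eval mode
x = torch.randn(2, 196, 768)             # (batch, tokens, dim)
branch_out = torch.randn(2, 196, 768)    # e.g. attention or MLP output
y = x + drop(branch_out)                 # roughly 10% of samples skip the branch
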
@ -0,0 +1,552 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""Vision Transformer (ViT) in PyTorch.
|
||||
|
||||
A PyTorch implementation of Vision Transformers as described in:
|
||||
|
||||
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
|
||||
`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
|
||||
- https://arxiv.org/abs/2106.10270
|
||||
|
||||
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
||||
|
||||
DeiT model defs and weights from https://github.com/facebookresearch/deit,
|
||||
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
||||
|
||||
Acknowledgments:
|
||||
* The paper authors for releasing code and weights, thanks!
|
||||
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
||||
for some einops/einsum fun
|
||||
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
||||
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
|
||||
|
||||
Hacked together by / Copyright 2021 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from itertools import repeat
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.runner import BaseModule, load_checkpoint
|
||||
from mmseg.ops import resize
|
||||
from mmseg.utils import get_root_logger
|
||||
from torch import Tensor
|
||||
|
||||
from .drop_path import DropPath
|
||||
|
||||
|
||||
def to_2tuple(x):
|
||||
return tuple(repeat(x, 2))
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = None,
|
||||
drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
swiglu_hidden_features = int(2 * hidden_features / 3)
|
||||
align_as = 8
|
||||
swiglu_hidden_features = (swiglu_hidden_features + align_as - 1) // align_as * align_as
|
||||
self.w1 = nn.Linear(in_features, swiglu_hidden_features)
|
||||
self.w2 = nn.Linear(in_features, swiglu_hidden_features)
|
||||
self.w3 = nn.Linear(swiglu_hidden_features, out_features)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1 = self.w1(x)
|
||||
x2 = self.w2(x)
|
||||
hidden = F.silu(x1) * x2
|
||||
return self.w3(hidden)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding."""
|
||||
|
||||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, bias=True
|
||||
):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
self.flatten = flatten
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
if self.flatten:
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
x = self.norm(x)
|
||||
return x, H, W
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class MemEffAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: Tensor, H, W) -> Tensor:
|
||||
from xformers.ops import memory_efficient_attention, unbind
|
||||
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||
|
||||
q, k, v = unbind(qkv, 2)
|
||||
|
||||
x = memory_efficient_attention(q, k, v)
|
||||
x = x.reshape([B, N, C])
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, H, W, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
H (int): Height of image
|
||||
W (int): Width of image
|
||||
Returns:
|
||||
x: (B, H, W, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowedAttention(nn.Module):
|
||||
def __init__(
|
||||
self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0, window_size=14, pad_mode="constant"
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.window_size = window_size
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
N_ = self.window_size * self.window_size
|
||||
H_ = math.ceil(H / self.window_size) * self.window_size
|
||||
W_ = math.ceil(W / self.window_size) * self.window_size
|
||||
|
||||
qkv = self.qkv(x) # [B, N, C]
|
||||
qkv = qkv.transpose(1, 2).reshape(B, C * 3, H, W) # [B, C, H, W]
|
||||
qkv = F.pad(qkv, [0, W_ - W, 0, H_ - H], mode=self.pad_mode)
|
||||
|
||||
qkv = F.unfold(
|
||||
qkv, kernel_size=(self.window_size, self.window_size), stride=(self.window_size, self.window_size)
|
||||
)
|
||||
B, C_kw_kw, L = qkv.shape # L - the num of windows
|
||||
qkv = qkv.reshape(B, C * 3, N_, L).permute(0, 3, 2, 1) # [B, L, N_, C]
|
||||
qkv = qkv.reshape(B, L, N_, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# q,k,v [B, L, num_head, N_, C/num_head]
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
|
||||
# if self.mask:
|
||||
# attn = attn * mask
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
|
||||
# attn @ v = [B, L, num_head, N_, C/num_head]
|
||||
x = (attn @ v).permute(0, 2, 4, 3, 1).reshape(B, C_kw_kw // 3, L)
|
||||
|
||||
x = F.fold(
|
||||
x,
|
||||
output_size=(H_, W_),
|
||||
kernel_size=(self.window_size, self.window_size),
|
||||
stride=(self.window_size, self.window_size),
|
||||
) # [B, C, H_, W_]
|
||||
x = x[:, :, :H, :W].reshape(B, C, N).transpose(-1, -2)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
# class WindowedAttention(nn.Module):
|
||||
# def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., window_size=14, pad_mode="constant"):
|
||||
# super().__init__()
|
||||
# self.num_heads = num_heads
|
||||
# head_dim = dim // num_heads
|
||||
# self.scale = head_dim ** -0.5
|
||||
#
|
||||
# self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
# self.attn_drop = nn.Dropout(attn_drop)
|
||||
# self.proj = nn.Linear(dim, dim)
|
||||
# self.proj_drop = nn.Dropout(proj_drop)
|
||||
# self.window_size = window_size
|
||||
# self.pad_mode = pad_mode
|
||||
#
|
||||
# def forward(self, x, H, W):
|
||||
# B, N, C = x.shape
|
||||
#
|
||||
# N_ = self.window_size * self.window_size
|
||||
# H_ = math.ceil(H / self.window_size) * self.window_size
|
||||
# W_ = math.ceil(W / self.window_size) * self.window_size
|
||||
# x = x.view(B, H, W, C)
|
||||
# x = F.pad(x, [0, 0, 0, W_ - W, 0, H_- H], mode=self.pad_mode)
|
||||
#
|
||||
# x = window_partition(x, window_size=self.window_size)# nW*B, window_size, window_size, C
|
||||
# x = x.view(-1, N_, C)
|
||||
#
|
||||
# qkv = self.qkv(x).view(-1, N_, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
# q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
# attn = (q @ k.transpose(-2, -1)) * self.scale # [B, L, num_head, N_, N_]
|
||||
# attn = attn.softmax(dim=-1)
|
||||
# attn = self.attn_drop(attn) # [B, L, num_head, N_, N_]
|
||||
# x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size, self.window_size, C)
|
||||
#
|
||||
# x = window_reverse(x, self.window_size, H_, W_)
|
||||
# x = x[:, :H, :W, :].reshape(B, N, C).contiguous()
|
||||
# x = self.proj(x)
|
||||
# x = self.proj_drop(x)
|
||||
# return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
windowed=False,
|
||||
window_size=14,
|
||||
pad_mode="constant",
|
||||
layer_scale=False,
|
||||
with_cp=False,
|
||||
ffn_layer=Mlp,
|
||||
memeff=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.with_cp = with_cp
|
||||
self.norm1 = norm_layer(dim)
|
||||
if windowed:
|
||||
self.attn = WindowedAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
pad_mode=pad_mode,
|
||||
)
|
||||
elif memeff:
|
||||
self.attn = MemEffAttention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop
|
||||
)
|
||||
else:
|
||||
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = ffn_layer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
self.layer_scale = layer_scale
|
||||
if layer_scale:
|
||||
self.gamma1 = nn.Parameter(torch.ones((dim)), requires_grad=True)
|
||||
self.gamma2 = nn.Parameter(torch.ones((dim)), requires_grad=True)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
def _inner_forward(x):
|
||||
if self.layer_scale:
|
||||
x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x), H, W))
|
||||
x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
x = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
x = _inner_forward(x)
|
||||
|
||||
return x
|
||||
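# --- Illustrative sketch (not part of the original diff) ---
# A Block is a standard pre-norm transformer block: attention (global,
# windowed, or memory-efficient) plus an FFN, each wrapped in a residual
# connection with optional LayerScale (gamma1/gamma2) and DropPath.
# Minimal usage sketch, assuming the Block class defined above is passed in
# so that no import path is hard-coded:
import torch

def _demo_block(block_cls):
    dim, num_heads, H, W = 384, 6, 14, 14
    block = block_cls(dim=dim, num_heads=num_heads, mlp_ratio=4.0,
                      qkv_bias=True, drop_path=0.1, layer_scale=True)
    x = torch.randn(2, H * W, dim)   # (B, N, C) patch tokens, N == H * W
    out = block(x, H, W)             # residual output, same shape as the input
    assert out.shape == x.shape
    return out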
|
||||
|
||||
class TIMMVisionTransformer(BaseModule):
|
||||
"""Vision Transformer.
|
||||
|
||||
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
||||
- https://arxiv.org/abs/2010.11929
|
||||
|
||||
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
|
||||
- https://arxiv.org/abs/2012.12877
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
layer_scale=True,
|
||||
embed_layer=PatchEmbed,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
||||
act_layer=nn.GELU,
|
||||
window_attn=False,
|
||||
window_size=14,
|
||||
pretrained=None,
|
||||
with_cp=False,
|
||||
pre_norm=False,
|
||||
ffn_type="mlp",
|
||||
memeff=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
num_classes (int): number of classes for classification head
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
drop_rate (float): dropout rate
|
||||
attn_drop_rate (float): attention dropout rate
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
norm_layer: (nn.Module): normalization layer
|
||||
pretrained: (str): pretrained path
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_classes = num_classes
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_tokens = 1
|
||||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
||||
act_layer = act_layer or nn.GELU
|
||||
self.norm_layer = norm_layer
|
||||
self.act_layer = act_layer
|
||||
self.pretrain_size = img_size
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.drop_rate = drop_rate
|
||||
self.patch_size = patch_size
|
||||
|
||||
window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn
|
||||
window_size = [window_size] * depth if not isinstance(window_size, list) else window_size
|
||||
logging.info("window attention: %s", window_attn)
|
||||
logging.info("window size: %s", window_size)
|
||||
logging.info("layer scale: %s", layer_scale)
|
||||
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, bias=not pre_norm
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
ffn_types = {"mlp": Mlp, "swiglu": SwiGLUFFN}
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.Sequential(
|
||||
*[
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
windowed=window_attn[i],
|
||||
window_size=window_size[i],
|
||||
layer_scale=layer_scale,
|
||||
with_cp=with_cp,
|
||||
ffn_layer=ffn_types[ffn_type],
|
||||
memeff=memeff,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = norm_layer(embed_dim)  # final norm used by forward_features
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
# For CLIP
|
||||
if pre_norm:
|
||||
norm_pre = norm_layer(embed_dim)
|
||||
self.norm_pre = norm_pre
|
||||
else:
|
||||
self.norm_pre = nn.Identity()
|
||||
self.init_weights(pretrained)
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if isinstance(pretrained, str):
|
||||
logger = get_root_logger()
|
||||
load_checkpoint(self, pretrained, map_location="cpu", strict=False, logger=logger)
|
||||
|
||||
def forward_features(self, x):
|
||||
x, H, W = self.patch_embed(x)
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
|
||||
# For CLIP
|
||||
x = self.norm_pre(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x, H, W)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.forward_features(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def resize_pos_embed(pos_embed, input_shape, pos_shape, mode):
|
||||
"""Resize pos_embed weights.
|
||||
|
||||
Resize pos_embed using bicubic interpolate method.
|
||||
Args:
|
||||
pos_embed (torch.Tensor): Position embedding weights.
|
||||
input_shape (tuple): Tuple for (downsampled input image height,
|
||||
downsampled input image width).
|
||||
pos_shape (tuple): The resolution of downsampled origin training
|
||||
image.
|
||||
mode (str): Algorithm used for upsampling:
|
||||
``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
|
||||
``'trilinear'``. Default: ``'nearest'``
|
||||
Return:
|
||||
torch.Tensor: The resized pos_embed of shape [B, L_new, C]
|
||||
"""
|
||||
assert pos_embed.ndim == 3, "shape of pos_embed must be [B, L, C]"
|
||||
pos_h, pos_w = pos_shape
|
||||
# keep dim for easy deployment
|
||||
cls_token_weight = pos_embed[:, 0:1]
|
||||
pos_embed_weight = pos_embed[:, (-1 * pos_h * pos_w) :]
|
||||
pos_embed_weight = pos_embed_weight.reshape(1, pos_h, pos_w, pos_embed.shape[2]).permute(0, 3, 1, 2)
|
||||
pos_embed_weight = resize(pos_embed_weight, size=input_shape, align_corners=False, mode=mode)
|
||||
pos_embed_weight = torch.flatten(pos_embed_weight, 2).transpose(1, 2)
|
||||
pos_embed = torch.cat((cls_token_weight, pos_embed_weight), dim=1)
|
||||
return pos_embed
|
|
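# --- Illustrative sketch (not part of the original diff) ---
# How the per-block settings in TIMMVisionTransformer.__init__ expand:
# scalar `window_attn` / `window_size` values are broadcast to one entry per
# block, and the stochastic-depth rates ramp linearly from 0 to
# `drop_path_rate` over the depth of the network.
import torch

depth, drop_path_rate = 12, 0.3
window_attn, window_size = True, 14

window_attn = [window_attn] * depth if not isinstance(window_attn, list) else window_attn
window_size = [window_size] * depth if not isinstance(window_size, list) else window_size
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# window_attn == [True] * 12, window_size == [14] * 12
# dpr[0] == 0.0 and dpr[-1] == 0.3, with evenly spaced values in between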
@ -0,0 +1,217 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmseg.models.builder import BACKBONES
|
||||
from torch.nn.init import normal_
|
||||
|
||||
from ...ops.modules import MSDeformAttn
|
||||
from .adapter_modules import InteractionBlock, InteractionBlockWithCls, SpatialPriorModule, deform_inputs
|
||||
from .vit import TIMMVisionTransformer
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ViTAdapter(TIMMVisionTransformer):
|
||||
def __init__(
|
||||
self,
|
||||
pretrain_size=224,
|
||||
num_heads=12,
|
||||
conv_inplane=64,
|
||||
n_points=4,
|
||||
deform_num_heads=6,
|
||||
init_values=0.0,
|
||||
interaction_indexes=None,
|
||||
with_cffn=True,
|
||||
cffn_ratio=0.25,
|
||||
deform_ratio=1.0,
|
||||
add_vit_feature=True,
|
||||
pretrained=None,
|
||||
use_extra_extractor=True,
|
||||
freeze_vit=False,
|
||||
use_cls=True,
|
||||
with_cp=False,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
super().__init__(num_heads=num_heads, pretrained=pretrained, with_cp=with_cp, *args, **kwargs)
|
||||
if freeze_vit:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# self.num_classes = 80
|
||||
self.use_cls = use_cls
|
||||
if not self.use_cls:
|
||||
self.cls_token = None
|
||||
self.num_block = len(self.blocks)
|
||||
self.pretrain_size = (pretrain_size, pretrain_size)
|
||||
self.interaction_indexes = interaction_indexes
|
||||
self.add_vit_feature = add_vit_feature
|
||||
embed_dim = self.embed_dim
|
||||
|
||||
block_fn = InteractionBlockWithCls if use_cls else InteractionBlock
|
||||
|
||||
self.level_embed = nn.Parameter(torch.zeros(3, embed_dim))
|
||||
self.spm = SpatialPriorModule(inplanes=conv_inplane, embed_dim=embed_dim, with_cp=False)
|
||||
self.interactions = nn.Sequential(
|
||||
*[
|
||||
block_fn(
|
||||
dim=embed_dim,
|
||||
num_heads=deform_num_heads,
|
||||
n_points=n_points,
|
||||
init_values=init_values,
|
||||
drop_path=self.drop_path_rate,
|
||||
norm_layer=self.norm_layer,
|
||||
with_cffn=with_cffn,
|
||||
cffn_ratio=cffn_ratio,
|
||||
deform_ratio=deform_ratio,
|
||||
extra_extractor=((True if i == len(interaction_indexes) - 1 else False) and use_extra_extractor),
|
||||
with_cp=with_cp,
|
||||
)
|
||||
for i in range(len(interaction_indexes))
|
||||
]
|
||||
)
|
||||
self.up = nn.ConvTranspose2d(embed_dim, embed_dim, 2, 2)
|
||||
self.norm1 = nn.SyncBatchNorm(embed_dim)
|
||||
self.norm2 = nn.SyncBatchNorm(embed_dim)
|
||||
self.norm3 = nn.SyncBatchNorm(embed_dim)
|
||||
self.norm4 = nn.SyncBatchNorm(embed_dim)
|
||||
|
||||
self.up.apply(self._init_weights)
|
||||
self.spm.apply(self._init_weights)
|
||||
self.interactions.apply(self._init_weights)
|
||||
self.apply(self._init_deform_weights)
|
||||
normal_(self.level_embed)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
torch.nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm) or isinstance(m, nn.BatchNorm2d):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _get_pos_embed(self, pos_embed, H, W):
|
||||
pos_embed = pos_embed.reshape(
|
||||
1, self.pretrain_size[0] // self.patch_size, self.pretrain_size[1] // self.patch_size, -1
|
||||
).permute(0, 3, 1, 2)
|
||||
pos_embed = (
|
||||
F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
|
||||
.reshape(1, -1, H * W)
|
||||
.permute(0, 2, 1)
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
def _init_deform_weights(self, m):
|
||||
if isinstance(m, MSDeformAttn):
|
||||
m._reset_parameters()
|
||||
|
||||
def _add_level_embed(self, c2, c3, c4):
|
||||
c2 = c2 + self.level_embed[0]
|
||||
c3 = c3 + self.level_embed[1]
|
||||
c4 = c4 + self.level_embed[2]
|
||||
return c2, c3, c4
|
||||
|
||||
def forward(self, x):
|
||||
deform_inputs1, deform_inputs2 = deform_inputs(x, self.patch_size)
|
||||
|
||||
# SPM forward
|
||||
c1, c2, c3, c4 = self.spm(x)
|
||||
c2, c3, c4 = self._add_level_embed(c2, c3, c4)
|
||||
c = torch.cat([c2, c3, c4], dim=1)
|
||||
|
||||
# Patch Embedding forward
|
||||
H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
|
||||
x, H_toks, W_toks = self.patch_embed(x)
|
||||
# print("H_toks, W_toks =", H_toks, W_toks)
|
||||
bs, n, dim = x.shape
|
||||
pos_embed = self._get_pos_embed(self.pos_embed[:, 1:], H_toks, W_toks)
|
||||
if self.use_cls:
|
||||
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_token, x), dim=1)
|
||||
pos_embed = torch.cat((self.pos_embed[:, :1], pos_embed), dim=1)
|
||||
x = self.pos_drop(x + pos_embed)
|
||||
# For CLIP
|
||||
x = self.norm_pre(x)
|
||||
|
||||
# Interaction
|
||||
if self.use_cls:
|
||||
cls, x = (
|
||||
x[
|
||||
:,
|
||||
:1,
|
||||
],
|
||||
x[
|
||||
:,
|
||||
1:,
|
||||
],
|
||||
)
|
||||
outs = list()
|
||||
for i, layer in enumerate(self.interactions):
|
||||
indexes = self.interaction_indexes[i]
|
||||
if self.use_cls:
|
||||
x, c, cls = layer(
|
||||
x,
|
||||
c,
|
||||
cls,
|
||||
self.blocks[indexes[0] : indexes[-1] + 1],
|
||||
deform_inputs1,
|
||||
deform_inputs2,
|
||||
H_c,
|
||||
W_c,
|
||||
H_toks,
|
||||
W_toks,
|
||||
)
|
||||
else:
|
||||
x, c = layer(
|
||||
x,
|
||||
c,
|
||||
self.blocks[indexes[0] : indexes[-1] + 1],
|
||||
deform_inputs1,
|
||||
deform_inputs2,
|
||||
H_c,
|
||||
W_c,
|
||||
H_toks,
|
||||
W_toks,
|
||||
)
|
||||
outs.append(x.transpose(1, 2).view(bs, dim, H_toks, W_toks).contiguous())
|
||||
|
||||
# Split & Reshape
|
||||
c2 = c[:, 0 : c2.size(1), :]
|
||||
c3 = c[:, c2.size(1) : c2.size(1) + c3.size(1), :]
|
||||
c4 = c[:, c2.size(1) + c3.size(1) :, :]
|
||||
|
||||
c2 = c2.transpose(1, 2).view(bs, dim, H_c * 2, W_c * 2).contiguous()
|
||||
c3 = c3.transpose(1, 2).view(bs, dim, H_c, W_c).contiguous()
|
||||
c4 = c4.transpose(1, 2).view(bs, dim, H_c // 2, W_c // 2).contiguous()
|
||||
c1 = self.up(c2) + c1
|
||||
|
||||
if self.add_vit_feature:
|
||||
x1, x2, x3, x4 = outs
|
||||
|
||||
x1 = F.interpolate(x1, size=(4 * H_c, 4 * W_c), mode="bilinear", align_corners=False)
|
||||
x2 = F.interpolate(x2, size=(2 * H_c, 2 * W_c), mode="bilinear", align_corners=False)
|
||||
x3 = F.interpolate(x3, size=(1 * H_c, 1 * W_c), mode="bilinear", align_corners=False)
|
||||
x4 = F.interpolate(x4, size=(H_c // 2, W_c // 2), mode="bilinear", align_corners=False)
|
||||
# print(c1.shape, c2.shape, c3.shape, c4.shape, x1.shape, x2.shape, x3.shape, x4.shape, H_c, H_toks)
|
||||
c1, c2, c3, c4 = c1 + x1, c2 + x2, c3 + x3, c4 + x4
|
||||
|
||||
# Final Norm
|
||||
f1 = self.norm1(c1)
|
||||
f2 = self.norm2(c2)
|
||||
f3 = self.norm3(c3)
|
||||
f4 = self.norm4(c4)
|
||||
return [f1, f2, f3, f4]
|
|
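# --- Illustrative sketch (not part of the original diff) ---
# ViTAdapter is normally instantiated through the mmseg BACKBONES registry
# from a config dict. The values below are hypothetical (a ViT-Base/16
# layout with four interaction stages of three blocks each), not the exact
# settings of the released configs.
backbone = dict(
    type="ViTAdapter",
    img_size=512,
    pretrain_size=512,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
    drop_path_rate=0.3,
    conv_inplane=64,
    n_points=4,
    deform_num_heads=12,
    cffn_ratio=0.25,
    deform_ratio=0.5,
    interaction_indexes=[[0, 2], [3, 5], [6, 8], [9, 11]],
)
# Typically embedded in a full segmentor config, e.g.:
#   model = dict(type="EncoderDecoder", backbone=backbone, decode_head=...)
# The forward pass then returns the four pyramid features [f1, f2, f3, f4]
# at strides 4, 8, 16 and 32.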
@ -0,0 +1,25 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from mmcv.utils import Registry
|
||||
|
||||
TRANSFORMER = Registry("Transformer")
|
||||
MASK_ASSIGNERS = Registry("mask_assigner")
|
||||
MATCH_COST = Registry("match_cost")
|
||||
|
||||
|
||||
def build_match_cost(cfg):
|
||||
"""Build Match Cost."""
|
||||
return MATCH_COST.build(cfg)
|
||||
|
||||
|
||||
def build_assigner(cfg):
|
||||
"""Build Assigner."""
|
||||
return MASK_ASSIGNERS.build(cfg)
|
||||
|
||||
|
||||
def build_transformer(cfg):
|
||||
"""Build Transformer."""
|
||||
return TRANSFORMER.build(cfg)
|
|
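# --- Illustrative sketch (not part of the original diff) ---
# These registries are consumed via config dicts, e.g. when Mask2FormerHead
# builds its Hungarian assigner from `train_cfg.assigner`. The type names
# and weights below are an assumption about the rest of this package
# (the match costs are registered in models/losses/match_costs.py).
assigner_cfg = dict(
    type="MaskHungarianAssigner",
    cls_cost=dict(type="ClassificationCost", weight=2.0),
    mask_cost=dict(type="CrossEntropyLossCost", weight=5.0, use_sigmoid=True),
    dice_cost=dict(type="DiceCost", weight=5.0, pred_act=True, eps=1.0),
)
# assigner = build_assigner(assigner_cfg)  # resolves the type through MASK_ASSIGNERS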
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .mask2former_head import Mask2FormerHead
|
|
@ -0,0 +1,544 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import Conv2d, build_plugin_layer, caffe2_xavier_init
|
||||
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
|
||||
from mmcv.ops import point_sample
|
||||
from mmcv.runner import ModuleList, force_fp32
|
||||
from mmseg.models.builder import HEADS, build_loss
|
||||
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
|
||||
|
||||
from ...core import build_sampler, multi_apply, reduce_mean
|
||||
from ..builder import build_assigner
|
||||
from ..utils import get_uncertain_point_coords_with_randomness
|
||||
|
||||
|
||||
@HEADS.register_module()
|
||||
class Mask2FormerHead(BaseDecodeHead):
|
||||
"""Implements the Mask2Former head.
|
||||
|
||||
See `Masked-attention Mask Transformer for Universal Image
|
||||
Segmentation <https://arxiv.org/pdf/2112.01527>`_ for details.
|
||||
|
||||
Args:
|
||||
in_channels (list[int]): Number of channels in the input feature map.
|
||||
feat_channels (int): Number of channels for features.
|
||||
out_channels (int): Number of channels for output.
|
||||
num_things_classes (int): Number of things.
|
||||
num_stuff_classes (int): Number of stuff.
|
||||
num_queries (int): Number of query in Transformer decoder.
|
||||
pixel_decoder (:obj:`mmcv.ConfigDict` | dict): Config for pixel
|
||||
decoder. Defaults to None.
|
||||
enforce_decoder_input_project (bool, optional): Whether to add
|
||||
a layer to change the embed_dim of the transformer encoder in
|
||||
pixel decoder to the embed_dim of transformer decoder.
|
||||
Defaults to False.
|
||||
transformer_decoder (:obj:`mmcv.ConfigDict` | dict): Config for
|
||||
transformer decoder. Defaults to None.
|
||||
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
|
||||
transformer decoder position encoding. Defaults to None.
|
||||
loss_cls (:obj:`mmcv.ConfigDict` | dict): Config of the classification
|
||||
loss. Defaults to None.
|
||||
loss_mask (:obj:`mmcv.ConfigDict` | dict): Config of the mask loss.
|
||||
Defaults to None.
|
||||
loss_dice (:obj:`mmcv.ConfigDict` | dict): Config of the dice loss.
|
||||
Defaults to None.
|
||||
train_cfg (:obj:`mmcv.ConfigDict` | dict): Training config of
|
||||
Mask2Former head.
|
||||
test_cfg (:obj:`mmcv.ConfigDict` | dict): Testing config of
|
||||
Mask2Former head.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
feat_channels,
|
||||
out_channels,
|
||||
num_things_classes=80,
|
||||
num_stuff_classes=53,
|
||||
num_queries=100,
|
||||
num_transformer_feat_level=3,
|
||||
pixel_decoder=None,
|
||||
enforce_decoder_input_project=False,
|
||||
transformer_decoder=None,
|
||||
positional_encoding=None,
|
||||
loss_cls=None,
|
||||
loss_mask=None,
|
||||
loss_dice=None,
|
||||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
init_cfg=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(Mask2FormerHead, self).__init__(
|
||||
in_channels=in_channels,
|
||||
channels=feat_channels,
|
||||
num_classes=(num_things_classes + num_stuff_classes),
|
||||
init_cfg=init_cfg,
|
||||
input_transform="multiple_select",
|
||||
**kwargs,
|
||||
)
|
||||
self.num_things_classes = num_things_classes
|
||||
self.num_stuff_classes = num_stuff_classes
|
||||
self.num_classes = self.num_things_classes + self.num_stuff_classes
|
||||
self.num_queries = num_queries
|
||||
self.num_transformer_feat_level = num_transformer_feat_level
|
||||
self.num_heads = transformer_decoder.transformerlayers.attn_cfgs.num_heads
|
||||
self.num_transformer_decoder_layers = transformer_decoder.num_layers
|
||||
assert pixel_decoder.encoder.transformerlayers.attn_cfgs.num_levels == num_transformer_feat_level
|
||||
pixel_decoder_ = copy.deepcopy(pixel_decoder)
|
||||
pixel_decoder_.update(in_channels=in_channels, feat_channels=feat_channels, out_channels=out_channels)
|
||||
self.pixel_decoder = build_plugin_layer(pixel_decoder_)[1]
|
||||
self.transformer_decoder = build_transformer_layer_sequence(transformer_decoder)
|
||||
self.decoder_embed_dims = self.transformer_decoder.embed_dims
|
||||
|
||||
self.decoder_input_projs = ModuleList()
|
||||
# from low resolution to high resolution
|
||||
for _ in range(num_transformer_feat_level):
|
||||
if self.decoder_embed_dims != feat_channels or enforce_decoder_input_project:
|
||||
self.decoder_input_projs.append(Conv2d(feat_channels, self.decoder_embed_dims, kernel_size=1))
|
||||
else:
|
||||
self.decoder_input_projs.append(nn.Identity())
|
||||
self.decoder_positional_encoding = build_positional_encoding(positional_encoding)
|
||||
self.query_embed = nn.Embedding(self.num_queries, feat_channels)
|
||||
self.query_feat = nn.Embedding(self.num_queries, feat_channels)
|
||||
# from low resolution to high resolution
|
||||
self.level_embed = nn.Embedding(self.num_transformer_feat_level, feat_channels)
|
||||
|
||||
self.cls_embed = nn.Linear(feat_channels, self.num_classes + 1)
|
||||
self.mask_embed = nn.Sequential(
|
||||
nn.Linear(feat_channels, feat_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(feat_channels, feat_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(feat_channels, out_channels),
|
||||
)
|
||||
self.conv_seg = None # fix a bug here (conv_seg is not used)
|
||||
|
||||
self.test_cfg = test_cfg
|
||||
self.train_cfg = train_cfg
|
||||
if train_cfg:
|
||||
self.assigner = build_assigner(self.train_cfg.assigner)
|
||||
self.sampler = build_sampler(self.train_cfg.sampler, context=self)
|
||||
self.num_points = self.train_cfg.get("num_points", 12544)
|
||||
self.oversample_ratio = self.train_cfg.get("oversample_ratio", 3.0)
|
||||
self.importance_sample_ratio = self.train_cfg.get("importance_sample_ratio", 0.75)
|
||||
|
||||
self.class_weight = loss_cls.class_weight
|
||||
self.loss_cls = build_loss(loss_cls)
|
||||
self.loss_mask = build_loss(loss_mask)
|
||||
self.loss_dice = build_loss(loss_dice)
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.decoder_input_projs:
|
||||
if isinstance(m, Conv2d):
|
||||
caffe2_xavier_init(m, bias=0)
|
||||
|
||||
self.pixel_decoder.init_weights()
|
||||
|
||||
for p in self.transformer_decoder.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_normal_(p)
|
||||
|
||||
def get_targets(self, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas):
|
||||
"""Compute classification and mask targets for all images for a decoder
|
||||
layer.
|
||||
|
||||
Args:
|
||||
cls_scores_list (list[Tensor]): Mask score logits from a single
|
||||
decoder layer for all images. Each with shape [num_queries,
|
||||
cls_out_channels].
|
||||
mask_preds_list (list[Tensor]): Mask logits from a single decoder
|
||||
layer for all images. Each with shape [num_queries, h, w].
|
||||
gt_labels_list (list[Tensor]): Ground truth class indices for all
|
||||
images. Each with shape (n, ), where n is the sum of the number of
|
||||
stuff types and the number of instances in an image.
|
||||
gt_masks_list (list[Tensor]): Ground truth mask for each image,
|
||||
each with shape (n, h, w).
|
||||
img_metas (list[dict]): List of image meta information.
|
||||
|
||||
Returns:
|
||||
tuple[list[Tensor]]: a tuple containing the following targets.
|
||||
|
||||
- labels_list (list[Tensor]): Labels of all images.
|
||||
Each with shape [num_queries, ].
|
||||
- label_weights_list (list[Tensor]): Label weights of all
|
||||
images.Each with shape [num_queries, ].
|
||||
- mask_targets_list (list[Tensor]): Mask targets of all images.
|
||||
Each with shape [num_queries, h, w].
|
||||
- mask_weights_list (list[Tensor]): Mask weights of all images.
|
||||
Each with shape [num_queries, ].
|
||||
- num_total_pos (int): Number of positive samples in all
|
||||
images.
|
||||
- num_total_neg (int): Number of negative samples in all
|
||||
images.
|
||||
"""
|
||||
(
|
||||
labels_list,
|
||||
label_weights_list,
|
||||
mask_targets_list,
|
||||
mask_weights_list,
|
||||
pos_inds_list,
|
||||
neg_inds_list,
|
||||
) = multi_apply(
|
||||
self._get_target_single, cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas
|
||||
)
|
||||
|
||||
num_total_pos = sum((inds.numel() for inds in pos_inds_list))
|
||||
num_total_neg = sum((inds.numel() for inds in neg_inds_list))
|
||||
return (labels_list, label_weights_list, mask_targets_list, mask_weights_list, num_total_pos, num_total_neg)
|
||||
|
||||
def _get_target_single(self, cls_score, mask_pred, gt_labels, gt_masks, img_metas):
|
||||
"""Compute classification and mask targets for one image.
|
||||
|
||||
Args:
|
||||
cls_score (Tensor): Mask score logits from a single decoder layer
|
||||
for one image. Shape (num_queries, cls_out_channels).
|
||||
mask_pred (Tensor): Mask logits for a single decoder layer for one
|
||||
image. Shape (num_queries, h, w).
|
||||
gt_labels (Tensor): Ground truth class indices for one image with
|
||||
shape (num_gts, ).
|
||||
gt_masks (Tensor): Ground truth mask for each image, each with
|
||||
shape (num_gts, h, w).
|
||||
img_metas (dict): Image information.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: A tuple containing the following for one image.
|
||||
|
||||
- labels (Tensor): Labels of each image. \
|
||||
shape (num_queries, ).
|
||||
- label_weights (Tensor): Label weights of each image. \
|
||||
shape (num_queries, ).
|
||||
- mask_targets (Tensor): Mask targets of each image. \
|
||||
shape (num_queries, h, w).
|
||||
- mask_weights (Tensor): Mask weights of each image. \
|
||||
shape (num_queries, ).
|
||||
- pos_inds (Tensor): Sampled positive indices for each \
|
||||
image.
|
||||
- neg_inds (Tensor): Sampled negative indices for each \
|
||||
image.
|
||||
"""
|
||||
# sample points
|
||||
num_queries = cls_score.shape[0]
|
||||
num_gts = gt_labels.shape[0]
|
||||
|
||||
point_coords = torch.rand((1, self.num_points, 2), device=cls_score.device)
|
||||
# shape (num_queries, num_points)
|
||||
mask_points_pred = point_sample(mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1, 1)).squeeze(1)
|
||||
# shape (num_gts, num_points)
|
||||
gt_points_masks = point_sample(gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1, 1)).squeeze(1)
|
||||
|
||||
# assign and sample
|
||||
assign_result = self.assigner.assign(cls_score, mask_points_pred, gt_labels, gt_points_masks, img_metas)
|
||||
sampling_result = self.sampler.sample(assign_result, mask_pred, gt_masks)
|
||||
pos_inds = sampling_result.pos_inds
|
||||
neg_inds = sampling_result.neg_inds
|
||||
|
||||
# label target
|
||||
labels = gt_labels.new_full((self.num_queries,), self.num_classes, dtype=torch.long)
|
||||
labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
|
||||
label_weights = gt_labels.new_ones((self.num_queries,))
|
||||
|
||||
# mask target
|
||||
mask_targets = gt_masks[sampling_result.pos_assigned_gt_inds]
|
||||
mask_weights = mask_pred.new_zeros((self.num_queries,))
|
||||
mask_weights[pos_inds] = 1.0
|
||||
|
||||
return (labels, label_weights, mask_targets, mask_weights, pos_inds, neg_inds)
|
||||
|
||||
def loss_single(self, cls_scores, mask_preds, gt_labels_list, gt_masks_list, img_metas):
|
||||
"""Loss function for outputs from a single decoder layer.
|
||||
|
||||
Args:
|
||||
cls_scores (Tensor): Mask score logits from a single decoder layer
|
||||
for all images. Shape (batch_size, num_queries,
|
||||
cls_out_channels). Note `cls_out_channels` should include
|
||||
background.
|
||||
mask_preds (Tensor): Mask logits for a pixel decoder for all
|
||||
images. Shape (batch_size, num_queries, h, w).
|
||||
gt_labels_list (list[Tensor]): Ground truth class indices for each
|
||||
image, each with shape (num_gts, ).
|
||||
gt_masks_list (list[Tensor]): Ground truth mask for each image,
|
||||
each with shape (num_gts, h, w).
|
||||
img_metas (list[dict]): List of image meta information.
|
||||
|
||||
Returns:
|
||||
tuple[Tensor]: Loss components for outputs from a single \
|
||||
decoder layer.
|
||||
"""
|
||||
num_imgs = cls_scores.size(0)
|
||||
cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
|
||||
mask_preds_list = [mask_preds[i] for i in range(num_imgs)]
|
||||
(
|
||||
labels_list,
|
||||
label_weights_list,
|
||||
mask_targets_list,
|
||||
mask_weights_list,
|
||||
num_total_pos,
|
||||
num_total_neg,
|
||||
) = self.get_targets(cls_scores_list, mask_preds_list, gt_labels_list, gt_masks_list, img_metas)
|
||||
# shape (batch_size, num_queries)
|
||||
labels = torch.stack(labels_list, dim=0)
|
||||
# shape (batch_size, num_queries)
|
||||
label_weights = torch.stack(label_weights_list, dim=0)
|
||||
# shape (num_total_gts, h, w)
|
||||
mask_targets = torch.cat(mask_targets_list, dim=0)
|
||||
# shape (batch_size, num_queries)
|
||||
mask_weights = torch.stack(mask_weights_list, dim=0)
|
||||
|
||||
# classification loss
|
||||
# shape (batch_size * num_queries, )
|
||||
cls_scores = cls_scores.flatten(0, 1)
|
||||
labels = labels.flatten(0, 1)
|
||||
label_weights = label_weights.flatten(0, 1)
|
||||
|
||||
class_weight = cls_scores.new_tensor(self.class_weight)
|
||||
loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=class_weight[labels].sum())
|
||||
|
||||
num_total_masks = reduce_mean(cls_scores.new_tensor([num_total_pos]))
|
||||
num_total_masks = max(num_total_masks, 1)
|
||||
|
||||
# extract positive ones
|
||||
# shape (batch_size, num_queries, h, w) -> (num_total_gts, h, w)
|
||||
mask_preds = mask_preds[mask_weights > 0]
|
||||
|
||||
if mask_targets.shape[0] == 0:
|
||||
# zero match
|
||||
loss_dice = mask_preds.sum()
|
||||
loss_mask = mask_preds.sum()
|
||||
return loss_cls, loss_mask, loss_dice
|
||||
|
||||
with torch.no_grad():
|
||||
points_coords = get_uncertain_point_coords_with_randomness(
|
||||
mask_preds.unsqueeze(1), None, self.num_points, self.oversample_ratio, self.importance_sample_ratio
|
||||
)
|
||||
# shape (num_total_gts, h, w) -> (num_total_gts, num_points)
|
||||
mask_point_targets = point_sample(mask_targets.unsqueeze(1).float(), points_coords).squeeze(1)
|
||||
# shape (num_queries, h, w) -> (num_queries, num_points)
|
||||
mask_point_preds = point_sample(mask_preds.unsqueeze(1), points_coords).squeeze(1)
|
||||
|
||||
# dice loss
|
||||
loss_dice = self.loss_dice(mask_point_preds, mask_point_targets, avg_factor=num_total_masks)
|
||||
|
||||
# mask loss
|
||||
# shape (num_queries, num_points) -> (num_queries * num_points, )
|
||||
mask_point_preds = mask_point_preds.reshape(-1, 1)
|
||||
# shape (num_total_gts, num_points) -> (num_total_gts * num_points, )
|
||||
mask_point_targets = mask_point_targets.reshape(-1)
|
||||
loss_mask = self.loss_mask(mask_point_preds, mask_point_targets, avg_factor=num_total_masks * self.num_points)
|
||||
|
||||
return loss_cls, loss_mask, loss_dice
|
||||
|
||||
@force_fp32(apply_to=("all_cls_scores", "all_mask_preds"))
|
||||
def loss(self, all_cls_scores, all_mask_preds, gt_labels_list, gt_masks_list, img_metas):
|
||||
"""Loss function.
|
||||
|
||||
Args:
|
||||
all_cls_scores (Tensor): Classification scores for all decoder
|
||||
layers with shape [num_decoder, batch_size, num_queries,
|
||||
cls_out_channels].
|
||||
all_mask_preds (Tensor): Mask scores for all decoder layers with
|
||||
shape [num_decoder, batch_size, num_queries, h, w].
|
||||
gt_labels_list (list[Tensor]): Ground truth class indices for each
|
||||
image with shape (n, ). n is the sum of the number of stuff types
|
||||
and the number of instances in an image.
|
||||
gt_masks_list (list[Tensor]): Ground truth mask for each image with
|
||||
shape (n, h, w).
|
||||
img_metas (list[dict]): List of image meta information.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: A dictionary of loss components.
|
||||
"""
|
||||
num_dec_layers = len(all_cls_scores)
|
||||
all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
|
||||
all_gt_masks_list = [gt_masks_list for _ in range(num_dec_layers)]
|
||||
img_metas_list = [img_metas for _ in range(num_dec_layers)]
|
||||
losses_cls, losses_mask, losses_dice = multi_apply(
|
||||
self.loss_single, all_cls_scores, all_mask_preds, all_gt_labels_list, all_gt_masks_list, img_metas_list
|
||||
)
|
||||
|
||||
loss_dict = dict()
|
||||
# loss from the last decoder layer
|
||||
loss_dict["loss_cls"] = losses_cls[-1]
|
||||
loss_dict["loss_mask"] = losses_mask[-1]
|
||||
loss_dict["loss_dice"] = losses_dice[-1]
|
||||
# loss from other decoder layers
|
||||
num_dec_layer = 0
|
||||
for loss_cls_i, loss_mask_i, loss_dice_i in zip(losses_cls[:-1], losses_mask[:-1], losses_dice[:-1]):
|
||||
loss_dict[f"d{num_dec_layer}.loss_cls"] = loss_cls_i
|
||||
loss_dict[f"d{num_dec_layer}.loss_mask"] = loss_mask_i
|
||||
loss_dict[f"d{num_dec_layer}.loss_dice"] = loss_dice_i
|
||||
num_dec_layer += 1
|
||||
return loss_dict
|
||||
|
||||
def forward_head(self, decoder_out, mask_feature, attn_mask_target_size):
|
||||
"""Forward for head part which is called after every decoder layer.
|
||||
|
||||
Args:
|
||||
decoder_out (Tensor): in shape (num_queries, batch_size, c).
|
||||
mask_feature (Tensor): in shape (batch_size, c, h, w).
|
||||
attn_mask_target_size (tuple[int, int]): target attention
|
||||
mask size.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple contain three elements.
|
||||
|
||||
- cls_pred (Tensor): Classification scores in shape \
|
||||
(batch_size, num_queries, cls_out_channels). \
|
||||
Note `cls_out_channels` should include background.
|
||||
- mask_pred (Tensor): Mask scores in shape \
|
||||
(batch_size, num_queries,h, w).
|
||||
- attn_mask (Tensor): Attention mask in shape \
|
||||
(batch_size * num_heads, num_queries, h, w).
|
||||
"""
|
||||
decoder_out = self.transformer_decoder.post_norm(decoder_out)
|
||||
decoder_out = decoder_out.transpose(0, 1)
|
||||
# shape (num_queries, batch_size, c)
|
||||
cls_pred = self.cls_embed(decoder_out)
|
||||
# shape (num_queries, batch_size, c)
|
||||
mask_embed = self.mask_embed(decoder_out)
|
||||
# shape (num_queries, batch_size, h, w)
|
||||
mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
|
||||
attn_mask = F.interpolate(mask_pred, attn_mask_target_size, mode="bilinear", align_corners=False)
|
||||
# shape (num_queries, batch_size, h, w) ->
|
||||
# (batch_size * num_head, num_queries, h, w)
|
||||
attn_mask = attn_mask.flatten(2).unsqueeze(1).repeat((1, self.num_heads, 1, 1)).flatten(0, 1)
|
||||
attn_mask = attn_mask.sigmoid() < 0.5
|
||||
attn_mask = attn_mask.detach()
|
||||
|
||||
return cls_pred, mask_pred, attn_mask
|
||||
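# --- Illustrative sketch (not part of the original diff) ---
# Shape walk-through for the einsum and attention mask in forward_head above
# (interpolation to attn_mask_target_size is skipped here, so H * W stands in
# for the target resolution). Runnable with plain torch:
import torch

B, Q, C, H, W, num_heads = 2, 100, 256, 32, 32, 8
mask_embed = torch.randn(B, Q, C)            # per-query mask embeddings
mask_feature = torch.randn(B, C, H, W)       # pixel-decoder mask features
mask_pred = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feature)
attn_mask = mask_pred.flatten(2).unsqueeze(1).repeat((1, num_heads, 1, 1)).flatten(0, 1)
attn_mask = attn_mask.sigmoid() < 0.5        # True = predicted background, masked out
assert mask_pred.shape == (B, Q, H, W)
assert attn_mask.shape == (B * num_heads, Q, H * W)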
|
||||
def forward(self, feats, img_metas):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
feats (list[Tensor]): Multi scale Features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
img_metas (list[dict]): List of image information.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple contains two elements.
|
||||
|
||||
- cls_pred_list (list[Tensor)]: Classification logits \
|
||||
for each decoder layer. Each is a 3D-tensor with shape \
|
||||
(batch_size, num_queries, cls_out_channels). \
|
||||
Note `cls_out_channels` should include background.
|
||||
- mask_pred_list (list[Tensor]): Mask logits for each \
|
||||
decoder layer. Each with shape (batch_size, num_queries, \
|
||||
h, w).
|
||||
"""
|
||||
batch_size = len(img_metas)
|
||||
mask_features, multi_scale_memorys = self.pixel_decoder(feats)
|
||||
# multi_scale_memorys (from low resolution to high resolution)
|
||||
decoder_inputs = []
|
||||
decoder_positional_encodings = []
|
||||
for i in range(self.num_transformer_feat_level):
|
||||
decoder_input = self.decoder_input_projs[i](multi_scale_memorys[i])
|
||||
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
|
||||
decoder_input = decoder_input.flatten(2).permute(2, 0, 1)
|
||||
level_embed = self.level_embed.weight[i].view(1, 1, -1)
|
||||
decoder_input = decoder_input + level_embed
|
||||
# shape (batch_size, c, h, w) -> (h*w, batch_size, c)
|
||||
mask = decoder_input.new_zeros((batch_size,) + multi_scale_memorys[i].shape[-2:], dtype=torch.bool)
|
||||
decoder_positional_encoding = self.decoder_positional_encoding(mask)
|
||||
decoder_positional_encoding = decoder_positional_encoding.flatten(2).permute(2, 0, 1)
|
||||
decoder_inputs.append(decoder_input)
|
||||
decoder_positional_encodings.append(decoder_positional_encoding)
|
||||
# shape (num_queries, c) -> (num_queries, batch_size, c)
|
||||
query_feat = self.query_feat.weight.unsqueeze(1).repeat((1, batch_size, 1))
|
||||
query_embed = self.query_embed.weight.unsqueeze(1).repeat((1, batch_size, 1))
|
||||
|
||||
cls_pred_list = []
|
||||
mask_pred_list = []
|
||||
cls_pred, mask_pred, attn_mask = self.forward_head(query_feat, mask_features, multi_scale_memorys[0].shape[-2:])
|
||||
cls_pred_list.append(cls_pred)
|
||||
mask_pred_list.append(mask_pred)
|
||||
|
||||
for i in range(self.num_transformer_decoder_layers):
|
||||
level_idx = i % self.num_transformer_feat_level
|
||||
# if a mask is all True(all background), then set it all False.
|
||||
attn_mask[torch.where(attn_mask.sum(-1) == attn_mask.shape[-1])] = False
|
||||
|
||||
# cross_attn + self_attn
|
||||
layer = self.transformer_decoder.layers[i]
|
||||
attn_masks = [attn_mask, None]
|
||||
query_feat = layer(
|
||||
query=query_feat,
|
||||
key=decoder_inputs[level_idx],
|
||||
value=decoder_inputs[level_idx],
|
||||
query_pos=query_embed,
|
||||
key_pos=decoder_positional_encodings[level_idx],
|
||||
attn_masks=attn_masks,
|
||||
query_key_padding_mask=None,
|
||||
# here we do not apply masking on padded region
|
||||
key_padding_mask=None,
|
||||
)
|
||||
cls_pred, mask_pred, attn_mask = self.forward_head(
|
||||
query_feat, mask_features, multi_scale_memorys[(i + 1) % self.num_transformer_feat_level].shape[-2:]
|
||||
)
|
||||
|
||||
cls_pred_list.append(cls_pred)
|
||||
mask_pred_list.append(mask_pred)
|
||||
|
||||
return cls_pred_list, mask_pred_list
|
||||
|
||||
def forward_train(self, x, img_metas, gt_semantic_seg, gt_labels, gt_masks):
|
||||
"""Forward function for training mode.
|
||||
|
||||
Args:
|
||||
x (list[Tensor]): Multi-level features from the upstream network,
|
||||
each is a 4D-tensor.
|
||||
img_metas (list[Dict]): List of image information.
|
||||
gt_semantic_seg (list[tensor]):Each element is the ground truth
|
||||
of semantic segmentation with the shape (N, H, W).
|
||||
train_cfg (dict): The training config, which has not been used in
|
||||
maskformer.
|
||||
gt_labels (list[Tensor]): Each element is ground truth labels of
|
||||
each box, shape (num_gts,).
|
||||
gt_masks (list[BitmapMasks]): Each element is masks of instances
|
||||
of a image, shape (num_gts, h, w).
|
||||
|
||||
Returns:
|
||||
losses (dict[str, Tensor]): a dictionary of loss components
|
||||
"""
|
||||
|
||||
# forward
|
||||
all_cls_scores, all_mask_preds = self(x, img_metas)
|
||||
|
||||
# loss
|
||||
losses = self.loss(all_cls_scores, all_mask_preds, gt_labels, gt_masks, img_metas)
|
||||
|
||||
return losses
|
||||
|
||||
def forward_test(self, inputs, img_metas, test_cfg):
|
||||
"""Test segmentation without test-time augmentation.
|
||||
|
||||
Only the output of the last decoder layer is used.
|
||||
|
||||
Args:
|
||||
inputs (list[Tensor]): Multi-level features from the
|
||||
upstream network, each is a 4D-tensor.
|
||||
img_metas (list[dict]): List of image information.
|
||||
test_cfg (dict): Testing config.
|
||||
|
||||
Returns:
|
||||
seg_mask (Tensor): Predicted semantic segmentation logits.
|
||||
"""
|
||||
all_cls_scores, all_mask_preds = self(inputs, img_metas)
|
||||
cls_score, mask_pred = all_cls_scores[-1], all_mask_preds[-1]
|
||||
ori_h, ori_w, _ = img_metas[0]["ori_shape"]
|
||||
|
||||
# semantic inference
|
||||
cls_score = F.softmax(cls_score, dim=-1)[..., :-1]
|
||||
mask_pred = mask_pred.sigmoid()
|
||||
seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred)
|
||||
return seg_mask
|
|
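# --- Illustrative sketch (not part of the original diff) ---
# The semantic inference step in forward_test combines per-query class
# probabilities (dropping the "no object" column) with per-query sigmoid
# masks into dense per-class logits:
#   seg_mask[b, c, h, w] = sum_q cls_score[b, q, c] * mask_pred[b, q, h, w]
import torch
import torch.nn.functional as F

B, Q, num_classes, H, W = 1, 100, 150, 128, 128
cls_logits = torch.randn(B, Q, num_classes + 1)        # +1 for the void class
mask_logits = torch.randn(B, Q, H, W)

cls_score = F.softmax(cls_logits, dim=-1)[..., :-1]
mask_pred = mask_logits.sigmoid()
seg_mask = torch.einsum("bqc,bqhw->bchw", cls_score, mask_pred)
assert seg_mask.shape == (B, num_classes, H, W)
# A hard per-pixel prediction would then be seg_mask.argmax(dim=1).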
@ -0,0 +1,8 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .cross_entropy_loss import CrossEntropyLoss, binary_cross_entropy, cross_entropy, mask_cross_entropy
|
||||
from .dice_loss import DiceLoss
|
||||
from .match_costs import ClassificationCost, CrossEntropyLossCost, DiceCost
|
|
@ -0,0 +1,279 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmseg.models.builder import LOSSES
|
||||
from mmseg.models.losses.utils import get_class_weight, weight_reduce_loss
|
||||
|
||||
|
||||
def cross_entropy(
|
||||
pred,
|
||||
label,
|
||||
weight=None,
|
||||
class_weight=None,
|
||||
reduction="mean",
|
||||
avg_factor=None,
|
||||
ignore_index=-100,
|
||||
avg_non_ignore=False,
|
||||
):
|
||||
"""cross_entropy. The wrapper function for :func:`F.cross_entropy`
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, 1).
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
Default: None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
Default: None.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are 'none', 'mean' and 'sum'. Default: 'mean'.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Default: None.
|
||||
ignore_index (int): Specifies a target value that is ignored and
|
||||
does not contribute to the input gradients. When
|
||||
``avg_non_ignore`` is ``True`` and the ``reduction`` is
|
||||
``'mean'``, the loss is averaged over non-ignored targets.
|
||||
Defaults: -100.
|
||||
avg_non_ignore (bool): The flag decides whether the loss is
|
||||
only averaged over non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
"""
|
||||
|
||||
# class_weight is a manual rescaling weight given to each class.
|
||||
# If given, has to be a Tensor of size C element-wise losses
|
||||
loss = F.cross_entropy(pred, label, weight=class_weight, reduction="none", ignore_index=ignore_index)
|
||||
|
||||
# apply weights and do the reduction
|
||||
# average loss over non-ignored elements
|
||||
# pytorch's official cross_entropy average loss over non-ignored elements
|
||||
# refer to https://github.com/pytorch/pytorch/blob/56b43f4fec1f76953f15a627694d4bba34588969/torch/nn/functional.py#L2660 # noqa
|
||||
if (avg_factor is None) and avg_non_ignore and reduction == "mean":
|
||||
avg_factor = label.numel() - (label == ignore_index).sum().item()
|
||||
if weight is not None:
|
||||
weight = weight.float()
|
||||
loss = weight_reduce_loss(loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
|
||||
|
||||
return loss
|
||||
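# --- Illustrative sketch (not part of the original diff) ---
# What `avg_non_ignore` changes: with reduction="mean" and no explicit
# avg_factor, ignored targets are removed from the denominator, matching
# PyTorch's own mean in F.cross_entropy.
import torch
import torch.nn.functional as F

pred = torch.randn(4, 3)                  # 4 samples, 3 classes
label = torch.tensor([0, 1, 2, -100])     # the last sample is ignored

per_sample = F.cross_entropy(pred, label, reduction="none", ignore_index=-100)
loss_default = per_sample.sum() / label.numel()               # avg_non_ignore=False: divide by 4
loss_non_ignore = per_sample.sum() / (label != -100).sum()    # avg_non_ignore=True: divide by 3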
|
||||
|
||||
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
|
||||
"""Expand onehot labels to match the size of prediction."""
|
||||
bin_labels = labels.new_zeros(target_shape)
|
||||
valid_mask = (labels >= 0) & (labels != ignore_index)
|
||||
inds = torch.nonzero(valid_mask, as_tuple=True)
|
||||
|
||||
if inds[0].numel() > 0:
|
||||
if labels.dim() == 3:
|
||||
bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
|
||||
else:
|
||||
bin_labels[inds[0], labels[valid_mask]] = 1
|
||||
|
||||
valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
|
||||
|
||||
if label_weights is None:
|
||||
bin_label_weights = valid_mask
|
||||
else:
|
||||
bin_label_weights = label_weights.unsqueeze(1).expand(target_shape)
|
||||
bin_label_weights = bin_label_weights * valid_mask
|
||||
|
||||
return bin_labels, bin_label_weights, valid_mask
|
||||
|
||||
|
||||
def binary_cross_entropy(
|
||||
pred,
|
||||
label,
|
||||
weight=None,
|
||||
reduction="mean",
|
||||
avg_factor=None,
|
||||
class_weight=None,
|
||||
ignore_index=-100,
|
||||
avg_non_ignore=False,
|
||||
**kwargs,
|
||||
):
|
||||
"""Calculate the binary CrossEntropy loss.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, 1).
|
||||
label (torch.Tensor): The learning label of the prediction.
|
||||
Note: In bce loss, label < 0 is invalid.
|
||||
weight (torch.Tensor, optional): Sample-wise loss weight.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (int): The label index to be ignored. Default: -100.
|
||||
avg_non_ignore (bool): The flag decides whether the loss is
|
||||
only averaged over non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
if pred.size(1) == 1:
|
||||
# For binary class segmentation, the shape of pred is
|
||||
# [N, 1, H, W] and that of label is [N, H, W].
|
||||
assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
|
||||
pred = pred.squeeze()
|
||||
if pred.dim() != label.dim():
|
||||
assert (pred.dim() == 2 and label.dim() == 1) or (pred.dim() == 4 and label.dim() == 3), (
|
||||
"Only pred shape [N, C], label shape [N] or pred shape [N, C, " "H, W], label shape [N, H, W] are supported"
|
||||
)
|
||||
# `weight` returned from `_expand_onehot_labels`
|
||||
# has been treated for valid (non-ignore) pixels
|
||||
label, weight, valid_mask = _expand_onehot_labels(label, weight, pred.shape, ignore_index)
|
||||
else:
|
||||
# should mask out the ignored elements
|
||||
valid_mask = ((label >= 0) & (label != ignore_index)).float()
|
||||
if weight is not None:
|
||||
weight = weight * valid_mask
|
||||
else:
|
||||
weight = valid_mask
|
||||
# average loss over non-ignored and valid elements
|
||||
if reduction == "mean" and avg_factor is None and avg_non_ignore:
|
||||
avg_factor = valid_mask.sum().item()
|
||||
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label.float(), pos_weight=class_weight, reduction="none")
|
||||
# do the reduction for the weighted loss
|
||||
loss = weight_reduce_loss(loss, weight, reduction=reduction, avg_factor=avg_factor)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def mask_cross_entropy(
|
||||
pred, target, label, reduction="mean", avg_factor=None, class_weight=None, ignore_index=None, **kwargs
|
||||
):
|
||||
"""Calculate the CrossEntropy loss for masks.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction with shape (N, C), C is the number
|
||||
of classes.
|
||||
target (torch.Tensor): The learning label of the prediction.
|
||||
label (torch.Tensor): ``label`` indicates the class label of the mask's
|
||||
corresponding object. It is used to select the mask of the
|
||||
class which the object belongs to when the mask prediction
|
||||
is not class-agnostic.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
class_weight (list[float], optional): The weight for each class.
|
||||
ignore_index (None): Placeholder, to be consistent with other loss.
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
assert ignore_index is None, "BCE loss does not support ignore_index"
|
||||
assert reduction == "mean" and avg_factor is None
|
||||
num_rois = pred.size()[0]
|
||||
inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
|
||||
pred_slice = pred[inds, label].squeeze(1)
|
||||
return F.binary_cross_entropy_with_logits(pred_slice, target, weight=class_weight, reduction="mean")[None]
|
||||
|
||||
|
||||
@LOSSES.register_module(force=True)
|
||||
class CrossEntropyLoss(nn.Module):
|
||||
"""CrossEntropyLoss.
|
||||
|
||||
Args:
|
||||
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
|
||||
instead of softmax. Defaults to False.
|
||||
use_mask (bool, optional): Whether to use mask cross entropy loss.
|
||||
Defaults to False.
|
||||
reduction (str, optional): The method used to reduce the loss.
|
||||
Options are "none", "mean" and "sum". Defaults to 'mean'.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
|
||||
loss_name (str, optional): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_ce'.
|
||||
avg_non_ignore (bool): The flag decides whether the loss is
|
||||
only averaged over non-ignored targets. Default: False.
|
||||
`New in version 0.23.0.`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_sigmoid=False,
|
||||
use_mask=False,
|
||||
reduction="mean",
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
loss_name="loss_ce",
|
||||
avg_non_ignore=False,
|
||||
):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
assert (use_sigmoid is False) or (use_mask is False)
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.use_mask = use_mask
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self.avg_non_ignore = avg_non_ignore
|
||||
if not self.avg_non_ignore and self.reduction == "mean":
|
||||
warnings.warn(
|
||||
"Default ``avg_non_ignore`` is False, if you would like to "
|
||||
"ignore the certain label and average loss over non-ignore "
|
||||
"labels, which is the same with PyTorch official "
|
||||
"cross_entropy, set ``avg_non_ignore=True``."
|
||||
)
|
||||
|
||||
if self.use_sigmoid:
|
||||
self.cls_criterion = binary_cross_entropy
|
||||
elif self.use_mask:
|
||||
self.cls_criterion = mask_cross_entropy
|
||||
else:
|
||||
self.cls_criterion = cross_entropy
|
||||
self._loss_name = loss_name
|
||||
|
||||
def extra_repr(self):
|
||||
"""Extra repr."""
|
||||
s = f"avg_non_ignore={self.avg_non_ignore}"
|
||||
return s
|
||||
|
||||
def forward(
|
||||
self, cls_score, label, weight=None, avg_factor=None, reduction_override=None, ignore_index=-100, **kwargs
|
||||
):
|
||||
"""Forward function."""
|
||||
assert reduction_override in (None, "none", "mean", "sum")
|
||||
reduction = reduction_override if reduction_override else self.reduction
|
||||
if self.class_weight is not None:
|
||||
class_weight = cls_score.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
# Note: for BCE loss, label < 0 is invalid.
|
||||
loss_cls = self.loss_weight * self.cls_criterion(
|
||||
cls_score,
|
||||
label,
|
||||
weight,
|
||||
class_weight=class_weight,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
avg_non_ignore=self.avg_non_ignore,
|
||||
ignore_index=ignore_index,
|
||||
**kwargs,
|
||||
)
|
||||
return loss_cls
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
"""Loss Name.
|
||||
|
||||
This function must be implemented and will return the name of this
|
||||
loss function. This name will be used to combine different loss items
|
||||
by simple sum operation. In addition, if you want this loss item to be
|
||||
included into the backward graph, `loss_` must be the prefix of the
|
||||
name.
|
||||
|
||||
Returns:
|
||||
str: The name of this loss item.
|
||||
"""
|
||||
return self._loss_name
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmseg.models.builder import LOSSES
|
||||
from mmseg.models.losses.utils import weight_reduce_loss
|
||||
|
||||
|
||||
def dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
|
||||
"""Calculate dice loss, which is proposed in
|
||||
`V-Net: Fully Convolutional Neural Networks for Volumetric
|
||||
Medical Image Segmentation <https://arxiv.org/abs/1606.04797>`_.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction, has a shape (n, *)
|
||||
target (torch.Tensor): The learning label of the prediction,
|
||||
shape (n, *), same shape of pred.
|
||||
weight (torch.Tensor, optional): The weight of loss for each
|
||||
prediction, has a shape (n,). Defaults to None.
|
||||
eps (float): Avoid dividing by zero. Default: 1e-3.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
"""
|
||||
|
||||
input = pred.flatten(1)
|
||||
target = target.flatten(1).float()
|
||||
|
||||
a = torch.sum(input * target, 1)
|
||||
b = torch.sum(input * input, 1) + eps
|
||||
c = torch.sum(target * target, 1) + eps
|
||||
d = (2 * a) / (b + c)
|
||||
loss = 1 - d
|
||||
if weight is not None:
|
||||
assert weight.ndim == loss.ndim
|
||||
assert len(weight) == len(pred)
|
||||
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
||||
return loss
|
||||
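# --- Illustrative sketch (not part of the original diff) ---
# Tiny worked example of the V-Net dice term above on one flattened mask.
# When the prediction overlaps the target well, 2*a approaches b + c and the
# loss approaches 0; a disjoint prediction pushes the loss towards 1.
import torch

pred = torch.tensor([[0.9, 0.8, 0.1, 0.0]])     # sigmoid probabilities
target = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
eps = 1e-3

a = torch.sum(pred * target, 1)                 # intersection: 1.7
b = torch.sum(pred * pred, 1) + eps             # ~1.461
c = torch.sum(target * target, 1) + eps         # ~2.001
loss = 1 - (2 * a) / (b + c)                    # ~0.018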
|
||||
|
||||
def naive_dice_loss(pred, target, weight=None, eps=1e-3, reduction="mean", avg_factor=None):
|
||||
"""Calculate naive dice loss, the coefficient in the denominator is the
|
||||
first power instead of the second power.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction, has a shape (n, *)
|
||||
target (torch.Tensor): The learning label of the prediction,
|
||||
shape (n, *), same shape of pred.
|
||||
weight (torch.Tensor, optional): The weight of loss for each
|
||||
prediction, has a shape (n,). Defaults to None.
|
||||
eps (float): Avoid dividing by zero. Default: 1e-3.
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'.
|
||||
Options are "none", "mean" and "sum".
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
"""
|
||||
input = pred.flatten(1)
|
||||
target = target.flatten(1).float()
|
||||
|
||||
a = torch.sum(input * target, 1)
|
||||
b = torch.sum(input, 1)
|
||||
c = torch.sum(target, 1)
|
||||
d = (2 * a + eps) / (b + c + eps)
|
||||
loss = 1 - d
|
||||
if weight is not None:
|
||||
assert weight.ndim == loss.ndim
|
||||
assert len(weight) == len(pred)
|
||||
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
|
||||
return loss
|
||||
|
||||
|
||||
@LOSSES.register_module(force=True)
|
||||
class DiceLoss(nn.Module):
|
||||
def __init__(self, use_sigmoid=True, activate=True, reduction="mean", naive_dice=False, loss_weight=1.0, eps=1e-3):
|
||||
"""Dice Loss. There are two forms of dice loss supported:
|
||||
|
||||
- the one proposed in `V-Net: Fully Convolutional Neural
|
||||
Networks for Volumetric Medical Image Segmentation
|
||||
<https://arxiv.org/abs/1606.04797>`_.
|
||||
- the dice loss in which the power of the number in the
|
||||
denominator is the first power instead of the second
|
||||
power.
|
||||
|
||||
Args:
|
||||
use_sigmoid (bool, optional): Whether the prediction is activated
|
||||
with sigmoid instead of softmax. Defaults to True.
|
||||
activate (bool): Whether to activate the predictions inside the
|
||||
loss; setting this to False disables the internal sigmoid.
|
||||
Defaults to True.
|
||||
reduction (str, optional): The method used
|
||||
to reduce the loss. Options are "none",
|
||||
"mean" and "sum". Defaults to 'mean'.
|
||||
naive_dice (bool, optional): If false, use the dice
|
||||
loss defined in the V-Net paper, otherwise, use the
|
||||
naive dice loss in which the power of the number in the
|
||||
denominator is the first power instead of the second
|
||||
power. Defaults to False.
|
||||
loss_weight (float, optional): Weight of loss. Defaults to 1.0.
|
||||
eps (float): Avoid dividing by zero. Defaults to 1e-3.
|
||||
"""
|
||||
|
||||
super(DiceLoss, self).__init__()
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.reduction = reduction
|
||||
self.naive_dice = naive_dice
|
||||
self.loss_weight = loss_weight
|
||||
self.eps = eps
|
||||
self.activate = activate
|
||||
|
||||
def forward(self, pred, target, weight=None, reduction_override=None, avg_factor=None):
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The prediction, has a shape (n, *).
|
||||
target (torch.Tensor): The label of the prediction,
|
||||
shape (n, *), same shape of pred.
|
||||
weight (torch.Tensor, optional): The weight of loss for each
|
||||
prediction, has a shape (n,). Defaults to None.
|
||||
avg_factor (int, optional): Average factor that is used to average
|
||||
the loss. Defaults to None.
|
||||
reduction_override (str, optional): The reduction method used to
|
||||
override the original reduction method of the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The calculated loss
|
||||
"""
|
||||
|
||||
assert reduction_override in (None, "none", "mean", "sum")
|
||||
reduction = reduction_override if reduction_override else self.reduction
|
||||
|
||||
if self.activate:
|
||||
if self.use_sigmoid:
|
||||
pred = pred.sigmoid()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if self.naive_dice:
|
||||
loss = self.loss_weight * naive_dice_loss(
|
||||
pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
|
||||
)
|
||||
else:
|
||||
loss = self.loss_weight * dice_loss(
|
||||
pred, target, weight, eps=self.eps, reduction=reduction, avg_factor=avg_factor
|
||||
)
|
||||
|
||||
return loss
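
# For reference, the naive_dice_loss above reduces to the following arithmetic.
# This is a standalone sketch with made-up values, not part of the diff.
import torch

pred = torch.tensor([[0.9, 0.1, 0.8, 0.2]])  # already sigmoid-activated
target = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
eps = 1e-3

a = (pred * target).sum(1)                   # overlap term
b = pred.sum(1)                              # first-power denominator terms
c = target.sum(1)
loss = 1 - (2 * a + eps) / (b + c + eps)     # same formula as naive_dice_loss
print(loss)                                  # small loss -> good overlap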
|
|
@@ -0,0 +1,153 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..builder import MATCH_COST
|
||||
|
||||
|
||||
@MATCH_COST.register_module()
|
||||
class ClassificationCost:
|
||||
"""ClsSoftmaxCost.Borrow from
|
||||
mmdet.core.bbox.match_costs.match_cost.ClassificationCost.
|
||||
|
||||
Args:
|
||||
        weight (int | float, optional): loss weight. Defaults to 1.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> self = ClassificationCost()
|
||||
>>> cls_pred = torch.rand(4, 3)
|
||||
>>> gt_labels = torch.tensor([0, 1, 2])
|
||||
>>> factor = torch.tensor([10, 8, 10, 8])
|
||||
>>> self(cls_pred, gt_labels)
|
||||
tensor([[-0.3430, -0.3525, -0.3045],
|
||||
[-0.3077, -0.2931, -0.3992],
|
||||
[-0.3664, -0.3455, -0.2881],
|
||||
[-0.3343, -0.2701, -0.3956]])
|
||||
"""
|
||||
|
||||
def __init__(self, weight=1.0):
|
||||
self.weight = weight
|
||||
|
||||
def __call__(self, cls_pred, gt_labels):
|
||||
"""
|
||||
Args:
|
||||
cls_pred (Tensor): Predicted classification logits, shape
|
||||
[num_query, num_class].
|
||||
gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: cls_cost value with weight
|
||||
"""
|
||||
# Following the official DETR repo, contrary to the loss that
|
||||
# NLL is used, we approximate it in 1 - cls_score[gt_label].
|
||||
# The 1 is a constant that doesn't change the matching,
|
||||
# so it can be omitted.
|
||||
cls_score = cls_pred.softmax(-1)
|
||||
cls_cost = -cls_score[:, gt_labels]
|
||||
return cls_cost * self.weight
|
||||
|
||||
|
||||
@MATCH_COST.register_module()
|
||||
class DiceCost:
|
||||
"""Cost of mask assignments based on dice losses.
|
||||
|
||||
Args:
|
||||
weight (int | float, optional): loss_weight. Defaults to 1.
|
||||
pred_act (bool, optional): Whether to apply sigmoid to mask_pred.
|
||||
Defaults to False.
|
||||
        eps (float, optional): Defaults to 1e-3.
|
||||
"""
|
||||
|
||||
def __init__(self, weight=1.0, pred_act=False, eps=1e-3):
|
||||
self.weight = weight
|
||||
self.pred_act = pred_act
|
||||
self.eps = eps
|
||||
|
||||
def binary_mask_dice_loss(self, mask_preds, gt_masks):
|
||||
"""
|
||||
Args:
|
||||
mask_preds (Tensor): Mask prediction in shape (N1, H, W).
|
||||
gt_masks (Tensor): Ground truth in shape (N2, H, W)
|
||||
store 0 or 1, 0 for negative class and 1 for
|
||||
positive class.
|
||||
|
||||
Returns:
|
||||
Tensor: Dice cost matrix in shape (N1, N2).
|
||||
"""
|
||||
mask_preds = mask_preds.reshape((mask_preds.shape[0], -1))
|
||||
gt_masks = gt_masks.reshape((gt_masks.shape[0], -1)).float()
|
||||
numerator = 2 * torch.einsum("nc,mc->nm", mask_preds, gt_masks)
|
||||
denominator = mask_preds.sum(-1)[:, None] + gt_masks.sum(-1)[None, :]
|
||||
loss = 1 - (numerator + self.eps) / (denominator + self.eps)
|
||||
return loss
|
||||
|
||||
def __call__(self, mask_preds, gt_masks):
|
||||
"""
|
||||
Args:
|
||||
mask_preds (Tensor): Mask prediction logits in shape (N1, H, W).
|
||||
gt_masks (Tensor): Ground truth in shape (N2, H, W).
|
||||
|
||||
Returns:
|
||||
Tensor: Dice cost matrix in shape (N1, N2).
|
||||
"""
|
||||
if self.pred_act:
|
||||
mask_preds = mask_preds.sigmoid()
|
||||
dice_cost = self.binary_mask_dice_loss(mask_preds, gt_masks)
|
||||
return dice_cost * self.weight
|
||||
|
||||
|
||||
@MATCH_COST.register_module()
|
||||
class CrossEntropyLossCost:
|
||||
"""CrossEntropyLossCost.
|
||||
|
||||
Args:
|
||||
weight (int | float, optional): loss weight. Defaults to 1.
|
||||
use_sigmoid (bool, optional): Whether the prediction uses sigmoid
|
||||
            or softmax. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, weight=1.0, use_sigmoid=True):
|
||||
assert use_sigmoid, "use_sigmoid = False is not supported yet."
|
||||
self.weight = weight
|
||||
self.use_sigmoid = use_sigmoid
|
||||
|
||||
def _binary_cross_entropy(self, cls_pred, gt_labels):
|
||||
"""
|
||||
Args:
|
||||
cls_pred (Tensor): The prediction with shape (num_query, 1, *) or
|
||||
(num_query, *).
|
||||
gt_labels (Tensor): The learning label of prediction with
|
||||
shape (num_gt, *).
|
||||
Returns:
|
||||
Tensor: Cross entropy cost matrix in shape (num_query, num_gt).
|
||||
"""
|
||||
cls_pred = cls_pred.flatten(1).float()
|
||||
gt_labels = gt_labels.flatten(1).float()
|
||||
n = cls_pred.shape[1]
|
||||
pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none")
|
||||
neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none")
|
||||
cls_cost = torch.einsum("nc,mc->nm", pos, gt_labels) + torch.einsum("nc,mc->nm", neg, 1 - gt_labels)
|
||||
cls_cost = cls_cost / n
|
||||
|
||||
return cls_cost
|
||||
|
||||
def __call__(self, cls_pred, gt_labels):
|
||||
"""
|
||||
Args:
|
||||
cls_pred (Tensor): Predicted classification logits.
|
||||
gt_labels (Tensor): Labels.
|
||||
Returns:
|
||||
Tensor: Cross entropy cost matrix with weight in
|
||||
shape (num_query, num_gt).
|
||||
"""
|
||||
if self.use_sigmoid:
|
||||
cls_cost = self._binary_cross_entropy(cls_pred, gt_labels)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return cls_cost * self.weight
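
# The pairwise binary cross-entropy trick above can be checked in isolation: per-pixel
# BCE against all-ones and all-zeros targets is recombined with einsum into a
# (num_query, num_gt) matrix. A standalone sketch with toy shapes, not part of the diff.
import torch
import torch.nn.functional as F

cls_pred = torch.randn(3, 16)              # 3 query masks, 16 pixels (logits)
gt = (torch.rand(2, 16) > 0.5).float()     # 2 binary ground-truth masks

pos = F.binary_cross_entropy_with_logits(cls_pred, torch.ones_like(cls_pred), reduction="none")
neg = F.binary_cross_entropy_with_logits(cls_pred, torch.zeros_like(cls_pred), reduction="none")
cls_cost = (torch.einsum("nc,mc->nm", pos, gt) + torch.einsum("nc,mc->nm", neg, 1 - gt)) / 16
print(cls_cost.shape)                      # torch.Size([3, 2])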
|
|
@@ -0,0 +1,6 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder
|
|
@@ -0,0 +1,242 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmcv.cnn import PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init, normal_init, xavier_init
|
||||
from mmcv.cnn.bricks.transformer import build_positional_encoding, build_transformer_layer_sequence
|
||||
from mmcv.runner import BaseModule, ModuleList
|
||||
|
||||
from ...core.anchor import MlvlPointGenerator
|
||||
from ..utils.transformer import MultiScaleDeformableAttention
|
||||
|
||||
|
||||
@PLUGIN_LAYERS.register_module()
|
||||
class MSDeformAttnPixelDecoder(BaseModule):
|
||||
"""Pixel decoder with multi-scale deformable attention.
|
||||
|
||||
Args:
|
||||
in_channels (list[int] | tuple[int]): Number of channels in the
|
||||
input feature maps.
|
||||
strides (list[int] | tuple[int]): Output strides of feature from
|
||||
backbone.
|
||||
feat_channels (int): Number of channels for feature.
|
||||
out_channels (int): Number of channels for output.
|
||||
num_outs (int): Number of output scales.
|
||||
norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
|
||||
Defaults to dict(type='GN', num_groups=32).
|
||||
act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
|
||||
Defaults to dict(type='ReLU').
|
||||
encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
|
||||
encoder. Defaults to `DetrTransformerEncoder`.
|
||||
positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
|
||||
transformer encoder position encoding. Defaults to
|
||||
dict(type='SinePositionalEncoding', num_feats=128,
|
||||
normalize=True).
|
||||
init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=[256, 512, 1024, 2048],
|
||||
strides=[4, 8, 16, 32],
|
||||
feat_channels=256,
|
||||
out_channels=256,
|
||||
num_outs=3,
|
||||
norm_cfg=dict(type="GN", num_groups=32),
|
||||
act_cfg=dict(type="ReLU"),
|
||||
encoder=dict(
|
||||
type="DetrTransformerEncoder",
|
||||
num_layers=6,
|
||||
transformerlayers=dict(
|
||||
type="BaseTransformerLayer",
|
||||
attn_cfgs=dict(
|
||||
type="MultiScaleDeformableAttention",
|
||||
embed_dims=256,
|
||||
num_heads=8,
|
||||
num_levels=3,
|
||||
num_points=4,
|
||||
im2col_step=64,
|
||||
dropout=0.0,
|
||||
batch_first=False,
|
||||
norm_cfg=None,
|
||||
init_cfg=None,
|
||||
),
|
||||
feedforward_channels=1024,
|
||||
ffn_dropout=0.0,
|
||||
operation_order=("self_attn", "norm", "ffn", "norm"),
|
||||
),
|
||||
init_cfg=None,
|
||||
),
|
||||
positional_encoding=dict(type="SinePositionalEncoding", num_feats=128, normalize=True),
|
||||
init_cfg=None,
|
||||
):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.strides = strides
|
||||
self.num_input_levels = len(in_channels)
|
||||
self.num_encoder_levels = encoder.transformerlayers.attn_cfgs.num_levels
|
||||
assert self.num_encoder_levels >= 1, "num_levels in attn_cfgs must be at least one"
|
||||
input_conv_list = []
|
||||
# from top to down (low to high resolution)
|
||||
for i in range(self.num_input_levels - 1, self.num_input_levels - self.num_encoder_levels - 1, -1):
|
||||
input_conv = ConvModule(
|
||||
in_channels[i], feat_channels, kernel_size=1, norm_cfg=norm_cfg, act_cfg=None, bias=True
|
||||
)
|
||||
input_conv_list.append(input_conv)
|
||||
self.input_convs = ModuleList(input_conv_list)
|
||||
|
||||
self.encoder = build_transformer_layer_sequence(encoder)
|
||||
self.postional_encoding = build_positional_encoding(positional_encoding)
|
||||
# high resolution to low resolution
|
||||
self.level_encoding = nn.Embedding(self.num_encoder_levels, feat_channels)
|
||||
|
||||
# fpn-like structure
|
||||
self.lateral_convs = ModuleList()
|
||||
self.output_convs = ModuleList()
|
||||
self.use_bias = norm_cfg is None
|
||||
# from top to down (low to high resolution)
|
||||
        # fpn for the remaining features that were not passed through the encoder
|
||||
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
|
||||
lateral_conv = ConvModule(
|
||||
in_channels[i], feat_channels, kernel_size=1, bias=self.use_bias, norm_cfg=norm_cfg, act_cfg=None
|
||||
)
|
||||
output_conv = ConvModule(
|
||||
feat_channels,
|
||||
feat_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=self.use_bias,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
self.lateral_convs.append(lateral_conv)
|
||||
self.output_convs.append(output_conv)
|
||||
|
||||
self.mask_feature = Conv2d(feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.num_outs = num_outs
|
||||
self.point_generator = MlvlPointGenerator(strides)
|
||||
|
||||
def init_weights(self):
|
||||
"""Initialize weights."""
|
||||
for i in range(0, self.num_encoder_levels):
|
||||
xavier_init(self.input_convs[i].conv, gain=1, bias=0, distribution="uniform")
|
||||
|
||||
for i in range(0, self.num_input_levels - self.num_encoder_levels):
|
||||
caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
|
||||
caffe2_xavier_init(self.output_convs[i].conv, bias=0)
|
||||
|
||||
caffe2_xavier_init(self.mask_feature, bias=0)
|
||||
|
||||
normal_init(self.level_encoding, mean=0, std=1)
|
||||
for p in self.encoder.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_normal_(p)
|
||||
|
||||
# init_weights defined in MultiScaleDeformableAttention
|
||||
for layer in self.encoder.layers:
|
||||
for attn in layer.attentions:
|
||||
if isinstance(attn, MultiScaleDeformableAttention):
|
||||
attn.init_weights()
|
||||
|
||||
def forward(self, feats):
|
||||
"""
|
||||
Args:
|
||||
feats (list[Tensor]): Feature maps of each level. Each has
|
||||
shape of (batch_size, c, h, w).
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the following:
|
||||
|
||||
- mask_feature (Tensor): shape (batch_size, c, h, w).
|
||||
- multi_scale_features (list[Tensor]): Multi scale \
|
||||
features, each in shape (batch_size, c, h, w).
|
||||
"""
|
||||
# generate padding mask for each level, for each image
|
||||
batch_size = feats[0].shape[0]
|
||||
encoder_input_list = []
|
||||
padding_mask_list = []
|
||||
level_positional_encoding_list = []
|
||||
spatial_shapes = []
|
||||
reference_points_list = []
|
||||
for i in range(self.num_encoder_levels):
|
||||
level_idx = self.num_input_levels - i - 1
|
||||
feat = feats[level_idx]
|
||||
feat_projected = self.input_convs[i](feat)
|
||||
h, w = feat.shape[-2:]
|
||||
|
||||
# no padding
|
||||
padding_mask_resized = feat.new_zeros((batch_size,) + feat.shape[-2:], dtype=torch.bool)
|
||||
pos_embed = self.postional_encoding(padding_mask_resized)
|
||||
level_embed = self.level_encoding.weight[i]
|
||||
level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
|
||||
# (h_i * w_i, 2)
|
||||
reference_points = self.point_generator.single_level_grid_priors(
|
||||
feat.shape[-2:], level_idx, device=feat.device
|
||||
)
|
||||
# normalize
|
||||
factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
|
||||
reference_points = reference_points / factor
|
||||
|
||||
# shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
|
||||
feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
|
||||
level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
|
||||
padding_mask_resized = padding_mask_resized.flatten(1)
|
||||
|
||||
encoder_input_list.append(feat_projected)
|
||||
padding_mask_list.append(padding_mask_resized)
|
||||
level_positional_encoding_list.append(level_pos_embed)
|
||||
spatial_shapes.append(feat.shape[-2:])
|
||||
reference_points_list.append(reference_points)
|
||||
# shape (batch_size, total_num_query),
|
||||
        # total_num_query = sum(h_i * w_i over all levels)
|
||||
padding_masks = torch.cat(padding_mask_list, dim=1)
|
||||
# shape (total_num_query, batch_size, c)
|
||||
encoder_inputs = torch.cat(encoder_input_list, dim=0)
|
||||
level_positional_encodings = torch.cat(level_positional_encoding_list, dim=0)
|
||||
device = encoder_inputs.device
|
||||
# shape (num_encoder_levels, 2), from low
|
||||
# resolution to high resolution
|
||||
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=device)
|
||||
# shape (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
|
||||
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
|
||||
reference_points = torch.cat(reference_points_list, dim=0)
|
||||
reference_points = reference_points[None, :, None].repeat(batch_size, 1, self.num_encoder_levels, 1)
|
||||
valid_radios = reference_points.new_ones((batch_size, self.num_encoder_levels, 2))
|
||||
# shape (num_total_query, batch_size, c)
|
||||
memory = self.encoder(
|
||||
query=encoder_inputs,
|
||||
key=None,
|
||||
value=None,
|
||||
query_pos=level_positional_encodings,
|
||||
key_pos=None,
|
||||
attn_masks=None,
|
||||
key_padding_mask=None,
|
||||
query_key_padding_mask=padding_masks,
|
||||
spatial_shapes=spatial_shapes,
|
||||
reference_points=reference_points,
|
||||
level_start_index=level_start_index,
|
||||
valid_radios=valid_radios,
|
||||
)
|
||||
# (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
|
||||
memory = memory.permute(1, 2, 0)
|
||||
|
||||
# from low resolution to high resolution
|
||||
num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
|
||||
outs = torch.split(memory, num_query_per_level, dim=-1)
|
||||
outs = [x.reshape(batch_size, -1, spatial_shapes[i][0], spatial_shapes[i][1]) for i, x in enumerate(outs)]
|
||||
|
||||
for i in range(self.num_input_levels - self.num_encoder_levels - 1, -1, -1):
|
||||
x = feats[i]
|
||||
cur_feat = self.lateral_convs[i](x)
|
||||
y = cur_feat + F.interpolate(outs[-1], size=cur_feat.shape[-2:], mode="bilinear", align_corners=False)
|
||||
y = self.output_convs[i](y)
|
||||
outs.append(y)
|
||||
multi_scale_features = outs[: self.num_outs]
|
||||
|
||||
mask_feature = self.mask_feature(outs[-1])
|
||||
return mask_feature, multi_scale_features
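
# The flattened encoder queries are split back per level using level_start_index and
# spatial_shapes; the bookkeeping can be sanity-checked on toy shapes.
# A standalone sketch, not part of the diff.
import torch

spatial_shapes = torch.tensor([[16, 16], [32, 32], [64, 64]])  # (h_i, w_i), low to high resolution
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
print(level_start_index)                 # tensor([   0,  256, 1280])
num_query_per_level = [int(h * w) for h, w in spatial_shapes]
print(num_query_per_level)               # [256, 1024, 4096]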
|
|
@@ -0,0 +1,6 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .encoder_decoder_mask2former import EncoderDecoderMask2Former
|
|
@@ -0,0 +1,271 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from mmseg.core import add_prefix
|
||||
from mmseg.models import builder
|
||||
from mmseg.models.builder import SEGMENTORS
|
||||
from mmseg.models.segmentors.base import BaseSegmentor
|
||||
from mmseg.ops import resize
|
||||
|
||||
|
||||
@SEGMENTORS.register_module()
|
||||
class EncoderDecoderMask2Former(BaseSegmentor):
|
||||
"""Encoder Decoder segmentors.
|
||||
|
||||
EncoderDecoder typically consists of backbone, decode_head, auxiliary_head.
|
||||
Note that auxiliary_head is only used for deep supervision during training,
|
||||
    which can be discarded during inference.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone,
|
||||
decode_head,
|
||||
neck=None,
|
||||
auxiliary_head=None,
|
||||
train_cfg=None,
|
||||
test_cfg=None,
|
||||
pretrained=None,
|
||||
init_cfg=None,
|
||||
):
|
||||
super(EncoderDecoderMask2Former, self).__init__(init_cfg)
|
||||
if pretrained is not None:
|
||||
assert backbone.get("pretrained") is None, "both backbone and segmentor set pretrained weight"
|
||||
backbone.pretrained = pretrained
|
||||
self.backbone = builder.build_backbone(backbone)
|
||||
if neck is not None:
|
||||
self.neck = builder.build_neck(neck)
|
||||
decode_head.update(train_cfg=train_cfg)
|
||||
decode_head.update(test_cfg=test_cfg)
|
||||
self._init_decode_head(decode_head)
|
||||
self._init_auxiliary_head(auxiliary_head)
|
||||
|
||||
self.train_cfg = train_cfg
|
||||
self.test_cfg = test_cfg
|
||||
|
||||
assert self.with_decode_head
|
||||
|
||||
def _init_decode_head(self, decode_head):
|
||||
"""Initialize ``decode_head``"""
|
||||
self.decode_head = builder.build_head(decode_head)
|
||||
self.align_corners = self.decode_head.align_corners
|
||||
self.num_classes = self.decode_head.num_classes
|
||||
|
||||
def _init_auxiliary_head(self, auxiliary_head):
|
||||
"""Initialize ``auxiliary_head``"""
|
||||
if auxiliary_head is not None:
|
||||
if isinstance(auxiliary_head, list):
|
||||
self.auxiliary_head = nn.ModuleList()
|
||||
for head_cfg in auxiliary_head:
|
||||
self.auxiliary_head.append(builder.build_head(head_cfg))
|
||||
else:
|
||||
self.auxiliary_head = builder.build_head(auxiliary_head)
|
||||
|
||||
def extract_feat(self, img):
|
||||
"""Extract features from images."""
|
||||
x = self.backbone(img)
|
||||
if self.with_neck:
|
||||
x = self.neck(x)
|
||||
return x
|
||||
|
||||
def encode_decode(self, img, img_metas):
|
||||
"""Encode images with backbone and decode into a semantic segmentation
|
||||
map of the same size as input."""
|
||||
x = self.extract_feat(img)
|
||||
out = self._decode_head_forward_test(x, img_metas)
|
||||
out = resize(input=out, size=img.shape[2:], mode="bilinear", align_corners=self.align_corners)
|
||||
return out
|
||||
|
||||
def _decode_head_forward_train(self, x, img_metas, gt_semantic_seg, **kwargs):
|
||||
"""Run forward function and calculate loss for decode head in
|
||||
training."""
|
||||
losses = dict()
|
||||
loss_decode = self.decode_head.forward_train(x, img_metas, gt_semantic_seg, **kwargs)
|
||||
|
||||
losses.update(add_prefix(loss_decode, "decode"))
|
||||
return losses
|
||||
|
||||
def _decode_head_forward_test(self, x, img_metas):
|
||||
"""Run forward function and calculate loss for decode head in
|
||||
inference."""
|
||||
seg_logits = self.decode_head.forward_test(x, img_metas, self.test_cfg)
|
||||
return seg_logits
|
||||
|
||||
def _auxiliary_head_forward_train(self, x, img_metas, gt_semantic_seg):
|
||||
"""Run forward function and calculate loss for auxiliary head in
|
||||
training."""
|
||||
losses = dict()
|
||||
if isinstance(self.auxiliary_head, nn.ModuleList):
|
||||
for idx, aux_head in enumerate(self.auxiliary_head):
|
||||
loss_aux = aux_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
|
||||
losses.update(add_prefix(loss_aux, f"aux_{idx}"))
|
||||
else:
|
||||
loss_aux = self.auxiliary_head.forward_train(x, img_metas, gt_semantic_seg, self.train_cfg)
|
||||
losses.update(add_prefix(loss_aux, "aux"))
|
||||
|
||||
return losses
|
||||
|
||||
def forward_dummy(self, img):
|
||||
"""Dummy forward function."""
|
||||
seg_logit = self.encode_decode(img, None)
|
||||
|
||||
return seg_logit
|
||||
|
||||
def forward_train(self, img, img_metas, gt_semantic_seg, **kwargs):
|
||||
"""Forward function for training.
|
||||
|
||||
Args:
|
||||
img (Tensor): Input images.
|
||||
img_metas (list[dict]): List of image info dict where each dict
|
||||
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
gt_semantic_seg (Tensor): Semantic segmentation masks
|
||||
used if the architecture supports semantic segmentation task.
|
||||
|
||||
Returns:
|
||||
dict[str, Tensor]: a dictionary of loss components
|
||||
"""
|
||||
|
||||
x = self.extract_feat(img)
|
||||
|
||||
losses = dict()
|
||||
|
||||
loss_decode = self._decode_head_forward_train(x, img_metas, gt_semantic_seg, **kwargs)
|
||||
losses.update(loss_decode)
|
||||
|
||||
if self.with_auxiliary_head:
|
||||
loss_aux = self._auxiliary_head_forward_train(x, img_metas, gt_semantic_seg)
|
||||
losses.update(loss_aux)
|
||||
|
||||
return losses
|
||||
|
||||
def slide_inference(self, img, img_meta, rescale):
|
||||
"""Inference by sliding-window with overlap.
|
||||
|
||||
If h_crop > h_img or w_crop > w_img, the small patch will be used to
|
||||
decode without padding.
|
||||
"""
|
||||
|
||||
h_stride, w_stride = self.test_cfg.stride
|
||||
h_crop, w_crop = self.test_cfg.crop_size
|
||||
batch_size, _, h_img, w_img = img.size()
|
||||
num_classes = self.num_classes
|
||||
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
|
||||
w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
|
||||
preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
|
||||
count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
|
||||
for h_idx in range(h_grids):
|
||||
for w_idx in range(w_grids):
|
||||
y1 = h_idx * h_stride
|
||||
x1 = w_idx * w_stride
|
||||
y2 = min(y1 + h_crop, h_img)
|
||||
x2 = min(x1 + w_crop, w_img)
|
||||
y1 = max(y2 - h_crop, 0)
|
||||
x1 = max(x2 - w_crop, 0)
|
||||
crop_img = img[:, :, y1:y2, x1:x2]
|
||||
crop_seg_logit = self.encode_decode(crop_img, img_meta)
|
||||
preds += F.pad(crop_seg_logit, (int(x1), int(preds.shape[3] - x2), int(y1), int(preds.shape[2] - y2)))
|
||||
|
||||
count_mat[:, :, y1:y2, x1:x2] += 1
|
||||
assert (count_mat == 0).sum() == 0
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
# cast count_mat to constant while exporting to ONNX
|
||||
count_mat = torch.from_numpy(count_mat.cpu().detach().numpy()).to(device=img.device)
|
||||
preds = preds / count_mat
|
||||
if rescale:
|
||||
preds = resize(
|
||||
preds,
|
||||
size=img_meta[0]["ori_shape"][:2],
|
||||
mode="bilinear",
|
||||
align_corners=self.align_corners,
|
||||
warning=False,
|
||||
)
|
||||
return preds
|
||||
|
||||
def whole_inference(self, img, img_meta, rescale):
|
||||
"""Inference with full image."""
|
||||
|
||||
seg_logit = self.encode_decode(img, img_meta)
|
||||
if rescale:
|
||||
# support dynamic shape for onnx
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
size = img.shape[2:]
|
||||
else:
|
||||
size = img_meta[0]["ori_shape"][:2]
|
||||
seg_logit = resize(seg_logit, size=size, mode="bilinear", align_corners=self.align_corners, warning=False)
|
||||
|
||||
return seg_logit
|
||||
|
||||
def inference(self, img, img_meta, rescale):
|
||||
"""Inference with slide/whole style.
|
||||
|
||||
Args:
|
||||
img (Tensor): The input image of shape (N, 3, H, W).
|
||||
img_meta (dict): Image info dict where each dict has: 'img_shape',
|
||||
'scale_factor', 'flip', and may also contain
|
||||
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
||||
For details on the values of these keys see
|
||||
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
||||
rescale (bool): Whether rescale back to original shape.
|
||||
|
||||
Returns:
|
||||
Tensor: The output segmentation map.
|
||||
"""
|
||||
|
||||
assert self.test_cfg.mode in ["slide", "whole"]
|
||||
ori_shape = img_meta[0]["ori_shape"]
|
||||
assert all(_["ori_shape"] == ori_shape for _ in img_meta)
|
||||
if self.test_cfg.mode == "slide":
|
||||
seg_logit = self.slide_inference(img, img_meta, rescale)
|
||||
else:
|
||||
seg_logit = self.whole_inference(img, img_meta, rescale)
|
||||
output = F.softmax(seg_logit, dim=1)
|
||||
flip = img_meta[0]["flip"]
|
||||
if flip:
|
||||
flip_direction = img_meta[0]["flip_direction"]
|
||||
assert flip_direction in ["horizontal", "vertical"]
|
||||
if flip_direction == "horizontal":
|
||||
output = output.flip(dims=(3,))
|
||||
elif flip_direction == "vertical":
|
||||
output = output.flip(dims=(2,))
|
||||
|
||||
return output
|
||||
|
||||
def simple_test(self, img, img_meta, rescale=True):
|
||||
"""Simple test with single image."""
|
||||
seg_logit = self.inference(img, img_meta, rescale)
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
if torch.onnx.is_in_onnx_export():
|
||||
            # our inference backend only supports 4D output
|
||||
seg_pred = seg_pred.unsqueeze(0)
|
||||
return seg_pred
|
||||
seg_pred = seg_pred.cpu().numpy()
|
||||
# unravel batch dim
|
||||
seg_pred = list(seg_pred)
|
||||
return seg_pred
|
||||
|
||||
def aug_test(self, imgs, img_metas, rescale=True):
|
||||
"""Test with augmentations.
|
||||
|
||||
Only rescale=True is supported.
|
||||
"""
|
||||
# aug_test rescale all imgs back to ori_shape for now
|
||||
assert rescale
|
||||
# to save memory, we get augmented seg logit inplace
|
||||
seg_logit = self.inference(imgs[0], img_metas[0], rescale)
|
||||
for i in range(1, len(imgs)):
|
||||
cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale)
|
||||
seg_logit += cur_seg_logit
|
||||
seg_logit /= len(imgs)
|
||||
seg_pred = seg_logit.argmax(dim=1)
|
||||
seg_pred = seg_pred.cpu().numpy()
|
||||
# unravel batch dim
|
||||
seg_pred = list(seg_pred)
|
||||
return seg_pred
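
# The grid arithmetic in slide_inference above guarantees full coverage with
# overlapping crops. A toy check with assumed sizes, not part of the diff.
h_img, h_crop, h_stride = 1024, 512, 341
h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
for h_idx in range(h_grids):
    y1 = h_idx * h_stride
    y2 = min(y1 + h_crop, h_img)
    y1 = max(y2 - h_crop, 0)
    print(y1, y2)  # (0, 512), (341, 853), (512, 1024): every row is covered at least once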
|
|
@@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .assigner import MaskHungarianAssigner
|
||||
from .point_sample import get_uncertain_point_coords_with_randomness
|
||||
from .positional_encoding import LearnedPositionalEncoding, SinePositionalEncoding
|
||||
from .transformer import DetrTransformerDecoder, DetrTransformerDecoderLayer, DynamicConv, Transformer
|
|
@@ -0,0 +1,157 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from ..builder import MASK_ASSIGNERS, build_match_cost
|
||||
|
||||
try:
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
except ImportError:
|
||||
linear_sum_assignment = None
|
||||
|
||||
|
||||
class AssignResult(metaclass=ABCMeta):
|
||||
"""Collection of assign results."""
|
||||
|
||||
def __init__(self, num_gts, gt_inds, labels):
|
||||
self.num_gts = num_gts
|
||||
self.gt_inds = gt_inds
|
||||
self.labels = labels
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
info = {
|
||||
"num_gts": self.num_gts,
|
||||
"gt_inds": self.gt_inds,
|
||||
"labels": self.labels,
|
||||
}
|
||||
return info
|
||||
|
||||
|
||||
class BaseAssigner(metaclass=ABCMeta):
|
||||
"""Base assigner that assigns boxes to ground truth boxes."""
|
||||
|
||||
@abstractmethod
|
||||
def assign(self, masks, gt_masks, gt_masks_ignore=None, gt_labels=None):
|
||||
"""Assign boxes to either a ground truth boxes or a negative boxes."""
|
||||
pass
|
||||
|
||||
|
||||
@MASK_ASSIGNERS.register_module()
|
||||
class MaskHungarianAssigner(BaseAssigner):
|
||||
"""Computes one-to-one matching between predictions and ground truth for
|
||||
mask.
|
||||
|
||||
This class computes an assignment between the targets and the predictions
|
||||
    based on the costs. The costs are a weighted sum of three components:
|
||||
    classification cost, mask cost and dice cost. The
|
||||
targets don't include the no_object, so generally there are more
|
||||
predictions than targets. After the one-to-one matching, the un-matched
|
||||
are treated as backgrounds. Thus each query prediction will be assigned
|
||||
with `0` or a positive integer indicating the ground truth index:
|
||||
|
||||
- 0: negative sample, no assigned gt
|
||||
- positive integer: positive sample, index (1-based) of assigned gt
|
||||
|
||||
Args:
|
||||
cls_cost (obj:`mmcv.ConfigDict`|dict): Classification cost config.
|
||||
mask_cost (obj:`mmcv.ConfigDict`|dict): Mask cost config.
|
||||
dice_cost (obj:`mmcv.ConfigDict`|dict): Dice cost config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cls_cost=dict(type="ClassificationCost", weight=1.0),
|
||||
dice_cost=dict(type="DiceCost", weight=1.0),
|
||||
mask_cost=dict(type="MaskFocalCost", weight=1.0),
|
||||
):
|
||||
self.cls_cost = build_match_cost(cls_cost)
|
||||
self.dice_cost = build_match_cost(dice_cost)
|
||||
self.mask_cost = build_match_cost(mask_cost)
|
||||
|
||||
def assign(self, cls_pred, mask_pred, gt_labels, gt_masks, img_meta, gt_masks_ignore=None, eps=1e-7):
|
||||
"""Computes one-to-one matching based on the weighted costs.
|
||||
|
||||
This method assign each query prediction to a ground truth or
|
||||
background. The `assigned_gt_inds` with -1 means don't care,
|
||||
0 means negative sample, and positive number is the index (1-based)
|
||||
of assigned gt.
|
||||
The assignment is done in the following steps, the order matters.
|
||||
|
||||
1. assign every prediction to -1
|
||||
2. compute the weighted costs
|
||||
3. do Hungarian matching on CPU based on the costs
|
||||
4. assign all to 0 (background) first, then for each matched pair
|
||||
between predictions and gts, treat this prediction as foreground
|
||||
and assign the corresponding gt index (plus 1) to it.
|
||||
|
||||
Args:
|
||||
mask_pred (Tensor): Predicted mask, shape [num_query, h, w]
|
||||
cls_pred (Tensor): Predicted classification logits, shape
|
||||
[num_query, num_class].
|
||||
gt_masks (Tensor): Ground truth mask, shape [num_gt, h, w].
|
||||
gt_labels (Tensor): Label of `gt_masks`, shape (num_gt,).
|
||||
img_meta (dict): Meta information for current image.
|
||||
gt_masks_ignore (Tensor, optional): Ground truth masks that are
|
||||
labelled as `ignored`. Default None.
|
||||
eps (int | float, optional): A value added to the denominator for
|
||||
numerical stability. Default 1e-7.
|
||||
|
||||
Returns:
|
||||
:obj:`AssignResult`: The assigned result.
|
||||
"""
|
||||
assert gt_masks_ignore is None, "Only case when gt_masks_ignore is None is supported."
|
||||
num_gts, num_queries = gt_labels.shape[0], cls_pred.shape[0]
|
||||
|
||||
# 1. assign -1 by default
|
||||
assigned_gt_inds = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
|
||||
assigned_labels = cls_pred.new_full((num_queries,), -1, dtype=torch.long)
|
||||
if num_gts == 0 or num_queries == 0:
|
||||
# No ground truth or boxes, return empty assignment
|
||||
if num_gts == 0:
|
||||
# No ground truth, assign all to background
|
||||
assigned_gt_inds[:] = 0
|
||||
return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
|
||||
|
||||
# 2. compute the weighted costs
|
||||
        # classification and mask cost.
|
||||
if self.cls_cost.weight != 0 and cls_pred is not None:
|
||||
cls_cost = self.cls_cost(cls_pred, gt_labels)
|
||||
else:
|
||||
cls_cost = 0
|
||||
|
||||
if self.mask_cost.weight != 0:
|
||||
# mask_pred shape = [nq, h, w]
|
||||
# gt_mask shape = [ng, h, w]
|
||||
# mask_cost shape = [nq, ng]
|
||||
mask_cost = self.mask_cost(mask_pred, gt_masks)
|
||||
else:
|
||||
mask_cost = 0
|
||||
|
||||
if self.dice_cost.weight != 0:
|
||||
dice_cost = self.dice_cost(mask_pred, gt_masks)
|
||||
else:
|
||||
dice_cost = 0
|
||||
cost = cls_cost + mask_cost + dice_cost
|
||||
|
||||
# 3. do Hungarian matching on CPU using linear_sum_assignment
|
||||
cost = cost.detach().cpu()
|
||||
if linear_sum_assignment is None:
|
||||
raise ImportError('Please run "pip install scipy" ' "to install scipy first.")
|
||||
|
||||
matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
|
||||
matched_row_inds = torch.from_numpy(matched_row_inds).to(cls_pred.device)
|
||||
matched_col_inds = torch.from_numpy(matched_col_inds).to(cls_pred.device)
|
||||
|
||||
# 4. assign backgrounds and foregrounds
|
||||
# assign all indices to backgrounds first
|
||||
assigned_gt_inds[:] = 0
|
||||
# assign foregrounds based on matching results
|
||||
assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
|
||||
assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
|
||||
return AssignResult(num_gts, assigned_gt_inds, labels=assigned_labels)
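
# Step 3 above (Hungarian matching) can be exercised on a toy cost matrix;
# made-up numbers, not part of the diff.
import torch
from scipy.optimize import linear_sum_assignment

cost = torch.tensor([[0.9, 0.1],   # 3 queries x 2 ground-truth masks
                     [0.4, 0.8],
                     [0.2, 0.7]])
rows, cols = linear_sum_assignment(cost.numpy())
print(rows, cols)                  # [0 2] [1 0]: query 0 -> gt 1, query 2 -> gt 0
# assigned_gt_inds would then store cols + 1 at positions rows, and 0 elsewhere.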
|
|
@@ -0,0 +1,86 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from mmcv.ops import point_sample
|
||||
|
||||
|
||||
def get_uncertainty(mask_pred, labels):
|
||||
"""Estimate uncertainty based on pred logits.
|
||||
|
||||
We estimate uncertainty as L1 distance between 0.0 and the logits
|
||||
    prediction in 'mask_pred' for the foreground class in `labels`.
|
||||
|
||||
Args:
|
||||
        mask_pred (Tensor): mask prediction logits, shape (num_rois,
|
||||
num_classes, mask_height, mask_width).
|
||||
|
||||
labels (list[Tensor]): Either predicted or ground truth label for
|
||||
each predicted mask, of length num_rois.
|
||||
|
||||
Returns:
|
||||
scores (Tensor): Uncertainty scores with the most uncertain
|
||||
locations having the highest uncertainty score,
|
||||
shape (num_rois, 1, mask_height, mask_width)
|
||||
"""
|
||||
if mask_pred.shape[1] == 1:
|
||||
gt_class_logits = mask_pred.clone()
|
||||
else:
|
||||
inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
|
||||
gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
|
||||
return -torch.abs(gt_class_logits)
|
||||
|
||||
|
||||
def get_uncertain_point_coords_with_randomness(
|
||||
mask_pred, labels, num_points, oversample_ratio, importance_sample_ratio
|
||||
):
|
||||
"""Get ``num_points`` most uncertain points with random points during
|
||||
train.
|
||||
|
||||
Sample points in [0, 1] x [0, 1] coordinate space based on their
|
||||
uncertainty. The uncertainties are calculated for each point using
|
||||
'get_uncertainty()' function that takes point's logit prediction as
|
||||
input.
|
||||
|
||||
Args:
|
||||
mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
|
||||
mask_height, mask_width) for class-specific or class-agnostic
|
||||
prediction.
|
||||
labels (list): The ground truth class for each instance.
|
||||
num_points (int): The number of points to sample.
|
||||
oversample_ratio (int): Oversampling parameter.
|
||||
importance_sample_ratio (float): Ratio of points that are sampled
|
||||
            via importance sampling.
|
||||
|
||||
Returns:
|
||||
point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
|
||||
            that contains the coordinates of sampled points.
|
||||
"""
|
||||
assert oversample_ratio >= 1
|
||||
assert 0 <= importance_sample_ratio <= 1
|
||||
batch_size = mask_pred.shape[0]
|
||||
num_sampled = int(num_points * oversample_ratio)
|
||||
point_coords = torch.rand(batch_size, num_sampled, 2, device=mask_pred.device)
|
||||
point_logits = point_sample(mask_pred, point_coords)
|
||||
# It is crucial to calculate uncertainty based on the sampled
|
||||
# prediction value for the points. Calculating uncertainties of the
|
||||
# coarse predictions first and sampling them for points leads to
|
||||
# incorrect results. To illustrate this: assume uncertainty func(
|
||||
# logits)=-abs(logits), a sampled point between two coarse
|
||||
# predictions with -1 and 1 logits has 0 logits, and therefore 0
|
||||
# uncertainty value. However, if we calculate uncertainties for the
|
||||
# coarse predictions first, both will have -1 uncertainty,
|
||||
# and sampled point will get -1 uncertainty.
|
||||
point_uncertainties = get_uncertainty(point_logits, labels)
|
||||
num_uncertain_points = int(importance_sample_ratio * num_points)
|
||||
num_random_points = num_points - num_uncertain_points
|
||||
idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
|
||||
shift = num_sampled * torch.arange(batch_size, dtype=torch.long, device=mask_pred.device)
|
||||
idx += shift[:, None]
|
||||
point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(batch_size, num_uncertain_points, 2)
|
||||
if num_random_points > 0:
|
||||
rand_roi_coords = torch.rand(batch_size, num_random_points, 2, device=mask_pred.device)
|
||||
point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
|
||||
return point_coords
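
# The uncertainty heuristic above simply ranks points by how close their logit is to
# zero (the decision boundary). A standalone toy check, not part of the diff.
import torch

point_logits = torch.tensor([[-2.0, -0.1, 0.05, 1.5]])
uncertainty = -point_logits.abs()           # same as get_uncertainty for a single class
print(uncertainty.topk(2).indices)          # tensor([[2, 1]]) -> points nearest the boundary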
|
|
@@ -0,0 +1,152 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
|
||||
@POSITIONAL_ENCODING.register_module()
|
||||
class SinePositionalEncoding(BaseModule):
|
||||
"""Position encoding with sine and cosine functions.
|
||||
|
||||
See `End-to-End Object Detection with Transformers
|
||||
<https://arxiv.org/pdf/2005.12872>`_ for details.
|
||||
|
||||
Args:
|
||||
num_feats (int): The feature dimension for each position
|
||||
along x-axis or y-axis. Note the final returned dimension
|
||||
for each position is 2 times of this value.
|
||||
temperature (int, optional): The temperature used for scaling
|
||||
the position embedding. Defaults to 10000.
|
||||
normalize (bool, optional): Whether to normalize the position
|
||||
embedding. Defaults to False.
|
||||
scale (float, optional): A scale factor that scales the position
|
||||
embedding. The scale will be used only when `normalize` is True.
|
||||
Defaults to 2*pi.
|
||||
eps (float, optional): A value added to the denominator for
|
||||
numerical stability. Defaults to 1e-6.
|
||||
offset (float): offset add to embed when do the normalization.
|
||||
Defaults to 0.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_feats, temperature=10000, normalize=False, scale=2 * math.pi, eps=1e-6, offset=0.0, init_cfg=None
|
||||
):
|
||||
super(SinePositionalEncoding, self).__init__(init_cfg)
|
||||
if normalize:
|
||||
assert isinstance(scale, (float, int)), (
|
||||
"when normalize is set," "scale should be provided and in float or int type, " f"found {type(scale)}"
|
||||
)
|
||||
self.num_feats = num_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
self.scale = scale
|
||||
self.eps = eps
|
||||
self.offset = offset
|
||||
|
||||
def forward(self, mask):
|
||||
"""Forward function for `SinePositionalEncoding`.
|
||||
|
||||
Args:
|
||||
mask (Tensor): ByteTensor mask. Non-zero values representing
|
||||
ignored positions, while zero values means valid positions
|
||||
for this image. Shape [bs, h, w].
|
||||
|
||||
Returns:
|
||||
pos (Tensor): Returned position embedding with shape
|
||||
[bs, num_feats*2, h, w].
|
||||
"""
|
||||
# For convenience of exporting to ONNX, it's required to convert
|
||||
# `masks` from bool to int.
|
||||
mask = mask.to(torch.int)
|
||||
not_mask = 1 - mask # logical_not
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
y_embed = (y_embed + self.offset) / (y_embed[:, -1:, :] + self.eps) * self.scale
|
||||
x_embed = (x_embed + self.offset) / (x_embed[:, :, -1:] + self.eps) * self.scale
|
||||
dim_t = torch.arange(self.num_feats, dtype=torch.float32, device=mask.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_feats)
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
# use `view` instead of `flatten` for dynamically exporting to ONNX
|
||||
B, H, W = mask.size()
|
||||
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1)
|
||||
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).view(B, H, W, -1)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
def __repr__(self):
|
||||
"""str: a string that describes the module"""
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f"(num_feats={self.num_feats}, "
|
||||
repr_str += f"temperature={self.temperature}, "
|
||||
repr_str += f"normalize={self.normalize}, "
|
||||
repr_str += f"scale={self.scale}, "
|
||||
repr_str += f"eps={self.eps})"
|
||||
return repr_str
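
# The shapes produced by SinePositionalEncoding above can be checked in isolation;
# assumed sizes, not part of the diff.
import torch

num_feats, temperature = 128, 10000
mask = torch.zeros(1, 32, 32, dtype=torch.int)          # all positions valid
not_mask = 1 - mask
x_embed = not_mask.cumsum(2, dtype=torch.float32)       # running column index per row
dim_t = temperature ** (2 * (torch.arange(num_feats) // 2) / num_feats)
pos_x = x_embed[:, :, :, None] / dim_t
print(pos_x.shape)   # torch.Size([1, 32, 32, 128]); sin/cos interleaving of x and y gives 256 channels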
|
||||
|
||||
|
||||
@POSITIONAL_ENCODING.register_module()
|
||||
class LearnedPositionalEncoding(BaseModule):
|
||||
"""Position embedding with learnable embedding weights.
|
||||
|
||||
Args:
|
||||
num_feats (int): The feature dimension for each position
|
||||
along x-axis or y-axis. The final returned dimension for
|
||||
each position is 2 times of this value.
|
||||
row_num_embed (int, optional): The dictionary size of row embeddings.
|
||||
Default 50.
|
||||
col_num_embed (int, optional): The dictionary size of col embeddings.
|
||||
Default 50.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
"""
|
||||
|
||||
def __init__(self, num_feats, row_num_embed=50, col_num_embed=50, init_cfg=dict(type="Uniform", layer="Embedding")):
|
||||
super(LearnedPositionalEncoding, self).__init__(init_cfg)
|
||||
self.row_embed = nn.Embedding(row_num_embed, num_feats)
|
||||
self.col_embed = nn.Embedding(col_num_embed, num_feats)
|
||||
self.num_feats = num_feats
|
||||
self.row_num_embed = row_num_embed
|
||||
self.col_num_embed = col_num_embed
|
||||
|
||||
def forward(self, mask):
|
||||
"""Forward function for `LearnedPositionalEncoding`.
|
||||
|
||||
Args:
|
||||
mask (Tensor): ByteTensor mask. Non-zero values representing
|
||||
ignored positions, while zero values means valid positions
|
||||
for this image. Shape [bs, h, w].
|
||||
|
||||
Returns:
|
||||
pos (Tensor): Returned position embedding with shape
|
||||
[bs, num_feats*2, h, w].
|
||||
"""
|
||||
h, w = mask.shape[-2:]
|
||||
x = torch.arange(w, device=mask.device)
|
||||
y = torch.arange(h, device=mask.device)
|
||||
x_embed = self.col_embed(x)
|
||||
y_embed = self.row_embed(y)
|
||||
pos = (
|
||||
torch.cat((x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(1, w, 1)), dim=-1)
|
||||
.permute(2, 0, 1)
|
||||
.unsqueeze(0)
|
||||
.repeat(mask.shape[0], 1, 1, 1)
|
||||
)
|
||||
return pos
|
||||
|
||||
def __repr__(self):
|
||||
"""str: a string that describes the module"""
|
||||
repr_str = self.__class__.__name__
|
||||
repr_str += f"(num_feats={self.num_feats}, "
|
||||
repr_str += f"row_num_embed={self.row_num_embed}, "
|
||||
repr_str += f"col_num_embed={self.col_num_embed})"
|
||||
return repr_str
|
|
@@ -0,0 +1,989 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import Linear, build_activation_layer, build_norm_layer, xavier_init
|
||||
from mmcv.cnn.bricks.drop import build_dropout
|
||||
from mmcv.cnn.bricks.registry import FEEDFORWARD_NETWORK, TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE
|
||||
from mmcv.cnn.bricks.transformer import BaseTransformerLayer, TransformerLayerSequence, build_transformer_layer_sequence
|
||||
from mmcv.runner.base_module import BaseModule, Sequential
|
||||
from mmcv.utils import deprecated_api_warning, to_2tuple
|
||||
from torch.nn.init import normal_
|
||||
|
||||
from ..builder import TRANSFORMER
|
||||
|
||||
try:
|
||||
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
|
||||
|
||||
except ImportError:
|
||||
warnings.warn(
|
||||
"`MultiScaleDeformableAttention` in MMCV has been moved to "
|
||||
"`mmcv.ops.multi_scale_deform_attn`, please update your MMCV"
|
||||
)
|
||||
from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
|
||||
|
||||
|
||||
class AdaptivePadding(nn.Module):
|
||||
"""Applies padding to input (if needed) so that input can get fully covered
|
||||
    by the filter you specified. It supports two modes, "same" and "corner". The
|
||||
"same" mode is same with "SAME" padding mode in TensorFlow, pad zero around
|
||||
input. The "corner" mode would pad zero to bottom right.
|
||||
|
||||
Args:
|
||||
        kernel_size (int | tuple): Size of the kernel.
|
||||
        stride (int | tuple): Stride of the filter. Default: 1.
|
||||
dilation (int | tuple): Spacing between kernel elements.
|
||||
Default: 1
|
||||
padding (str): Support "same" and "corner", "corner" mode
|
||||
would pad zero to bottom right, and "same" mode would
|
||||
pad zero around input. Default: "corner".
|
||||
Example:
|
||||
>>> kernel_size = 16
|
||||
>>> stride = 16
|
||||
>>> dilation = 1
|
||||
>>> input = torch.rand(1, 1, 15, 17)
|
||||
>>> adap_pad = AdaptivePadding(
|
||||
>>> kernel_size=kernel_size,
|
||||
>>> stride=stride,
|
||||
>>> dilation=dilation,
|
||||
>>> padding="corner")
|
||||
>>> out = adap_pad(input)
|
||||
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||
>>> input = torch.rand(1, 1, 16, 17)
|
||||
>>> out = adap_pad(input)
|
||||
>>> assert (out.shape[2], out.shape[3]) == (16, 32)
|
||||
"""
|
||||
|
||||
def __init__(self, kernel_size=1, stride=1, dilation=1, padding="corner"):
|
||||
|
||||
super(AdaptivePadding, self).__init__()
|
||||
|
||||
assert padding in ("same", "corner")
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
padding = to_2tuple(padding)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
self.padding = padding
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.dilation = dilation
|
||||
|
||||
def get_pad_shape(self, input_shape):
|
||||
input_h, input_w = input_shape
|
||||
kernel_h, kernel_w = self.kernel_size
|
||||
stride_h, stride_w = self.stride
|
||||
output_h = math.ceil(input_h / stride_h)
|
||||
output_w = math.ceil(input_w / stride_w)
|
||||
pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
|
||||
pad_w = max((output_w - 1) * stride_w + (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
|
||||
return pad_h, pad_w
|
||||
|
||||
def forward(self, x):
|
||||
pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
if self.padding == "corner":
|
||||
x = F.pad(x, [0, pad_w, 0, pad_h])
|
||||
elif self.padding == "same":
|
||||
x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
|
||||
return x
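
# The padding amount computed by get_pad_shape above can be verified by hand;
# toy sizes matching the docstring example, not part of the diff.
import math

input_h, kernel_h, stride_h, dilation_h = 15, 16, 16, 1
output_h = math.ceil(input_h / stride_h)
pad_h = max((output_h - 1) * stride_h + (kernel_h - 1) * dilation_h + 1 - input_h, 0)
print(output_h, pad_h)   # 1 1 -> the 15-pixel side is padded by 1 to fit one 16x16 window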
|
||||
|
||||
|
||||
class PatchMerging(BaseModule):
|
||||
"""Merge patch feature map.
|
||||
|
||||
This layer groups feature map by kernel_size, and applies norm and linear
|
||||
layers to the grouped feature map. Our implementation uses `nn.Unfold` to
|
||||
merge patch, which is about 25% faster than original implementation.
|
||||
    However, pretrained models need to be modified for compatibility.
|
||||
|
||||
Args:
|
||||
in_channels (int): The num of input channels.
|
||||
|
||||
out_channels (int): The num of output channels.
|
||||
kernel_size (int | tuple, optional): the kernel size in the unfold
|
||||
layer. Defaults to 2.
|
||||
stride (int | tuple, optional): the stride of the sliding blocks in the
|
||||
unfold layer. Default: None. (Would be set as `kernel_size`)
|
||||
padding (int | tuple | string ): The padding length of
|
||||
embedding conv. When it is a string, it means the mode
|
||||
of adaptive padding, support "same" and "corner" now.
|
||||
Default: "corner".
|
||||
dilation (int | tuple, optional): dilation parameter in the unfold
|
||||
layer. Default: 1.
|
||||
bias (bool, optional): Whether to add bias in linear layer or not.
|
||||
Defaults: False.
|
||||
norm_cfg (dict, optional): Config dict for normalization layer.
|
||||
Default: dict(type='LN').
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=2,
|
||||
stride=None,
|
||||
padding="corner",
|
||||
dilation=1,
|
||||
bias=False,
|
||||
norm_cfg=dict(type="LN"),
|
||||
init_cfg=None,
|
||||
):
|
||||
super().__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
if stride:
|
||||
stride = stride
|
||||
else:
|
||||
stride = kernel_size
|
||||
|
||||
kernel_size = to_2tuple(kernel_size)
|
||||
stride = to_2tuple(stride)
|
||||
dilation = to_2tuple(dilation)
|
||||
|
||||
if isinstance(padding, str):
|
||||
self.adap_padding = AdaptivePadding(
|
||||
kernel_size=kernel_size, stride=stride, dilation=dilation, padding=padding
|
||||
)
|
||||
# disable the padding of unfold
|
||||
padding = 0
|
||||
else:
|
||||
self.adap_padding = None
|
||||
|
||||
padding = to_2tuple(padding)
|
||||
self.sampler = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
|
||||
|
||||
sample_dim = kernel_size[0] * kernel_size[1] * in_channels
|
||||
|
||||
if norm_cfg is not None:
|
||||
self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
|
||||
|
||||
def forward(self, x, input_size):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): Has shape (B, H*W, C_in).
|
||||
input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
|
||||
Default: None.
|
||||
|
||||
Returns:
|
||||
tuple: Contains merged results and its spatial shape.
|
||||
|
||||
- x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
|
||||
- out_size (tuple[int]): Spatial shape of x, arrange as
|
||||
(Merged_H, Merged_W).
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
assert isinstance(input_size, Sequence), f"Expect " f"input_size is " f"`Sequence` " f"but get {input_size}"
|
||||
|
||||
H, W = input_size
|
||||
assert L == H * W, "input feature has wrong size"
|
||||
|
||||
x = x.view(B, H, W, C).permute([0, 3, 1, 2]) # B, C, H, W
|
||||
# Use nn.Unfold to merge patch. About 25% faster than original method,
|
||||
# but need to modify pretrained model for compatibility
|
||||
|
||||
if self.adap_padding:
|
||||
x = self.adap_padding(x)
|
||||
H, W = x.shape[-2:]
|
||||
|
||||
x = self.sampler(x)
|
||||
        # if kernel_size=2 and stride=2, x should have shape (B, 4*C, H/2*W/2)
|
||||
|
||||
out_h = (
|
||||
H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] * (self.sampler.kernel_size[0] - 1) - 1
|
||||
) // self.sampler.stride[0] + 1
|
||||
out_w = (
|
||||
W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] * (self.sampler.kernel_size[1] - 1) - 1
|
||||
) // self.sampler.stride[1] + 1
|
||||
|
||||
output_size = (out_h, out_w)
|
||||
x = x.transpose(1, 2) # B, H/2*W/2, 4*C
|
||||
x = self.norm(x) if self.norm else x
|
||||
x = self.reduction(x)
|
||||
return x, output_size
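
# The unfold output-size arithmetic above follows the standard convolution formula;
# a toy check assuming a 2x2 merge, not part of the diff.
H, W = 14, 14
kernel, stride, padding, dilation = 2, 2, 0, 1
out_h = (H + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
out_w = (W + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
print(out_h, out_w)   # 7 7 -> a 14x14 map becomes 7x7 tokens with 4*C channels each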
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-5):
|
||||
"""Inverse function of sigmoid.
|
||||
|
||||
Args:
|
||||
x (Tensor): The tensor to do the
|
||||
inverse.
|
||||
eps (float): EPS avoid numerical
|
||||
overflow. Defaults 1e-5.
|
||||
Returns:
|
||||
Tensor: The x has passed the inverse
|
||||
function of sigmoid, has same
|
||||
shape with input.
|
||||
"""
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
x2 = (1 - x).clamp(min=eps)
|
||||
return torch.log(x1 / x2)
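
# A quick standalone check (not part of the diff) that inverse_sigmoid above undoes
# torch.sigmoid up to clamping.
import torch

x = torch.tensor([0.1, 0.5, 0.9])
logits = torch.log(x.clamp(min=1e-5) / (1 - x).clamp(min=1e-5))
print(torch.sigmoid(logits))   # tensor([0.1000, 0.5000, 0.9000])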
|
||||
|
||||
|
||||


@FEEDFORWARD_NETWORK.register_module(force=True)
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Defaults: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Defaults: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    @deprecated_api_warning({"dropout": "ffn_drop", "add_residual": "add_identity"}, cls_name="FFN")
    def __init__(
        self,
        embed_dims=256,
        feedforward_channels=1024,
        num_fcs=2,
        act_cfg=dict(type="ReLU", inplace=True),
        ffn_drop=0.0,
        dropout_layer=None,
        add_identity=True,
        init_cfg=None,
        with_cp=False,
        **kwargs,
    ):
        super().__init__(init_cfg)
        assert num_fcs >= 2, f"num_fcs should be no less than 2. got {num_fcs}."
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)
        self.with_cp = with_cp
        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(Sequential(Linear(in_channels, feedforward_channels), self.activate, nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({"residual": "identity"}, cls_name="FFN")
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function adds x to the output tensor if `identity` is None.
        """
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(self.layers, x)
        else:
            out = self.layers(x)

        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)
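
# Usage sketch for `FFN` (illustrative only; the shapes and config values below are
# assumptions, not part of the original diff):
#   >>> ffn = FFN(embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)
#   >>> tokens = torch.rand(2, 100, 256)       # (batch, num_tokens, embed_dims)
#   >>> out = ffn(tokens)                      # identity connection added by default
#   >>> out.shape
#   torch.Size([2, 100, 256])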


@TRANSFORMER_LAYER.register_module()
class DetrTransformerDecoderLayer(BaseTransformerLayer):
    """Implements decoder layer in DETR transformer.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self_attention or cross_attention, the order
            should be consistent with it in `operation_order`. If it is
            a dict, it will be expanded to the number of attentions in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in ffn. Default 0.0.
        operation_order (tuple[str]): The execution order of operations
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: `LN`.
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(
        self,
        attn_cfgs,
        feedforward_channels,
        ffn_dropout=0.0,
        operation_order=None,
        act_cfg=dict(type="ReLU", inplace=True),
        norm_cfg=dict(type="LN"),
        ffn_num_fcs=2,
        **kwargs,
    ):
        super(DetrTransformerDecoderLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs,
        )
        assert len(operation_order) == 6
        assert set(operation_order) == set(["self_attn", "norm", "cross_attn", "ffn"])
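
# Example config for a single decoder layer (a sketch under assumed values; the dims and
# dropout below are illustrative, not taken from the original diff):
#   dict(
#       type="DetrTransformerDecoderLayer",
#       attn_cfgs=dict(type="MultiheadAttention", embed_dims=256, num_heads=8, dropout=0.1),
#       feedforward_channels=2048,
#       ffn_dropout=0.1,
#       operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
#   )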


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerEncoder(TransformerLayerSequence):
    """TransformerEncoder of DETR.

    Args:
        post_norm_cfg (dict): Config of the last normalization layer. Default:
            `LN`. Only used when `self.pre_norm` is `True`.
    """

    def __init__(self, *args, post_norm_cfg=dict(type="LN"), **kwargs):
        super(DetrTransformerEncoder, self).__init__(*args, **kwargs)
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None
        else:
            assert not self.pre_norm, f"Use prenorm in {self.__class__.__name__}, please specify post_norm_cfg"
            self.post_norm = None

    def forward(self, *args, **kwargs):
        """Forward function for `DetrTransformerEncoder`.

        Returns:
            Tensor: forwarded results with shape [num_query, bs, embed_dims].
        """
        x = super(DetrTransformerEncoder, self).forward(*args, **kwargs)
        if self.post_norm is not None:
            x = self.post_norm(x)
        return x


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        post_norm_cfg (dict): Config of the last normalization layer. Default:
            `LN`.
    """

    def __init__(self, *args, post_norm_cfg=dict(type="LN"), return_intermediate=False, **kwargs):
        super(DetrTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate
        if post_norm_cfg is not None:
            self.post_norm = build_norm_layer(post_norm_cfg, self.embed_dims)[1]
        else:
            self.post_norm = None

    def forward(self, query, *args, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        if not self.return_intermediate:
            x = super().forward(query, *args, **kwargs)
            if self.post_norm:
                x = self.post_norm(x)[None]
            return x

        intermediate = []
        for layer in self.layers:
            query = layer(query, *args, **kwargs)
            if self.return_intermediate:
                if self.post_norm is not None:
                    intermediate.append(self.post_norm(query))
                else:
                    intermediate.append(query)
        return torch.stack(intermediate)
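
# Shape note (illustrative; the sizes below are assumptions): with `return_intermediate=True`
# and, say, num_layers=6, num_query=100, bs=2, embed_dims=256, the stacked decoder output
# would have shape (6, 100, 2, 256); with `return_intermediate=False` it is (1, 100, 2, 256).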


@TRANSFORMER.register_module()
class Transformer(BaseModule):
    """Implements the DETR transformer.

    Following the official DETR implementation, this module is copy-pasted
    from torch.nn.Transformer with modifications:

        * positional encodings are passed in MultiheadAttention
        * extra LN at the end of encoder is removed
        * decoder returns a stack of activations from all decoding layers

    See `paper: End-to-End Object Detection with Transformers
    <https://arxiv.org/pdf/2005.12872>`_ for details.

    Args:
        encoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerEncoder. Defaults to None.
        decoder (`mmcv.ConfigDict` | Dict): Config of
            TransformerDecoder. Defaults to None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Defaults to None.
    """

    def __init__(self, encoder=None, decoder=None, init_cfg=None):
        super(Transformer, self).__init__(init_cfg=init_cfg)
        self.encoder = build_transformer_layer_sequence(encoder)
        self.decoder = build_transformer_layer_sequence(decoder)
        self.embed_dims = self.encoder.embed_dims

    def init_weights(self):
        # follow the official DETR to init parameters
        for m in self.modules():
            if hasattr(m, "weight") and m.weight.dim() > 1:
                xavier_init(m, distribution="uniform")
        self._is_init = True

    def forward(self, x, mask, query_embed, pos_embed):
        """Forward function for `Transformer`.

        Args:
            x (Tensor): Input query with shape [bs, c, h, w] where
                c = embed_dims.
            mask (Tensor): The key_padding_mask used for encoder and decoder,
                with shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder, with shape
                [num_query, c].
            pos_embed (Tensor): The positional encoding for encoder and
                decoder, with the same shape as `x`.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - out_dec: Output from decoder. If return_intermediate_dec
                  is True output has shape [num_dec_layers, bs, num_query,
                  embed_dims], else has shape [1, bs, num_query, embed_dims].
                - memory: Output results from encoder, with shape
                  [bs, embed_dims, h, w].
        """
        bs, c, h, w = x.shape
        # use `view` instead of `flatten` for dynamically exporting to ONNX
        x = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
        pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
        mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]
        memory = self.encoder(query=x, key=None, value=None, query_pos=pos_embed, query_key_padding_mask=mask)
        target = torch.zeros_like(query_embed)
        # out_dec: [num_layers, num_query, bs, dim]
        out_dec = self.decoder(
            query=target, key=memory, value=memory, key_pos=pos_embed, query_pos=query_embed, key_padding_mask=mask
        )
        out_dec = out_dec.transpose(1, 2)
        memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
        return out_dec, memory
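
# Usage sketch for `Transformer.forward` (illustrative; the encoder/decoder configs are
# omitted and the shapes below are assumptions, not part of the original diff):
#   >>> bs, c, h, w, num_query = 2, 256, 32, 32, 100
#   >>> x = torch.rand(bs, c, h, w)
#   >>> mask = torch.zeros(bs, h, w, dtype=torch.bool)      # no padded pixels
#   >>> query_embed = torch.rand(num_query, c)
#   >>> pos_embed = torch.rand(bs, c, h, w)
#   >>> out_dec, memory = transformer(x, mask, query_embed, pos_embed)
#   >>> memory.shape                                         # (bs, c, h, w)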


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class DeformableDetrTransformerDecoder(TransformerLayerSequence):
    """Implements the decoder in Deformable DETR transformer.

    Args:
        return_intermediate (bool): Whether to return intermediate outputs.
        coder_norm_cfg (dict): Config of the last normalization layer. Default:
            `LN`.
    """

    def __init__(self, *args, return_intermediate=False, **kwargs):
        super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs)
        self.return_intermediate = return_intermediate

    def forward(self, query, *args, reference_points=None, valid_ratios=None, reg_branches=None, **kwargs):
        """Forward function for `TransformerDecoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_query, bs, embed_dims)`.
            reference_points (Tensor): The reference points used for the
                offsets, with shape (bs, num_query, 4) when as_two_stage,
                otherwise with shape (bs, num_query, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2).
            reg_branches (obj:`nn.ModuleList`): Used for
                refining the regression results. Only passed
                when with_box_refine is True, otherwise `None`.

        Returns:
            Tensor: Results with shape [1, num_query, bs, embed_dims] when
                return_intermediate is `False`, otherwise it has shape
                [num_layers, num_query, bs, embed_dims].
        """
        output = query
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                )
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
            output = layer(output, *args, reference_points=reference_points_input, **kwargs)
            output = output.permute(1, 0, 2)

            if reg_branches is not None:
                tmp = reg_branches[lid](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            output = output.permute(1, 0, 2)
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points
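
# Iterative box refinement (illustrative note): each decoder layer predicts an offset in
# logit space, so the update is roughly `new_ref = sigmoid(offset + inverse_sigmoid(ref))`,
# which keeps refined reference points inside [0, 1]. For example (values assumed):
#   >>> ref = torch.tensor([[0.25, 0.75]])
#   >>> offset = torch.tensor([[0.1, -0.1]])
#   >>> (offset + inverse_sigmoid(ref)).sigmoid()   # slightly shifted, still in [0, 1]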


@TRANSFORMER.register_module()
class DeformableDetrTransformer(Transformer):
    """Implements the DeformableDETR transformer.

    Args:
        as_two_stage (bool): Generate query from encoder features.
            Default: False.
        num_feature_levels (int): Number of feature maps from FPN.
            Default: 4.
        two_stage_num_proposals (int): Number of proposals when set
            `as_two_stage` as True. Default: 300.
    """

    def __init__(self, as_two_stage=False, num_feature_levels=4, two_stage_num_proposals=300, **kwargs):
        super(DeformableDetrTransformer, self).__init__(**kwargs)
        self.as_two_stage = as_two_stage
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        self.embed_dims = self.encoder.embed_dims
        self.init_layers()

    def init_layers(self):
        """Initialize layers of the DeformableDetrTransformer."""
        self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dims))

        if self.as_two_stage:
            self.enc_output = nn.Linear(self.embed_dims, self.embed_dims)
            self.enc_output_norm = nn.LayerNorm(self.embed_dims)
            self.pos_trans = nn.Linear(self.embed_dims * 2, self.embed_dims * 2)
            self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2)
        else:
            self.reference_points = nn.Linear(self.embed_dims, 2)

    def init_weights(self):
        """Initialize the transformer weights."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        if not self.as_two_stage:
            xavier_init(self.reference_points, distribution="uniform", bias=0.0)
        normal_(self.level_embeds)

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        """Generate proposals from encoded memory.

        Args:
            memory (Tensor): The output of the encoder,
                has shape (bs, num_key, embed_dim). num_key is
                equal to the number of points on the feature maps from
                all levels.
            memory_padding_mask (Tensor): Padding mask for memory,
                has shape (bs, num_key).
            spatial_shapes (Tensor): The shape of all feature maps,
                has shape (num_level, 2).

        Returns:
            tuple: A tuple of feature map and bbox prediction.

                - output_memory (Tensor): The input of the decoder,
                  has shape (bs, num_key, embed_dim). num_key is
                  equal to the number of points on the feature maps from
                  all levels.
                - output_proposals (Tensor): The normalized proposals
                  after an inverse sigmoid, has shape
                  (bs, num_keys, 4).
        """
        N, S, C = memory.shape
        proposals = []
        _cur = 0
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
            )
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
            proposals.append(proposal)
            _cur += H * W
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float("inf"))
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in the decoder.

        Args:
            spatial_shapes (Tensor): The shape of all
                feature maps, has shape (num_level, 2).
            valid_ratios (Tensor): The ratios of valid
                points on the feature map, has shape
                (bs, num_levels, 2).
            device (obj:`device`): The device where
                reference_points should be.

        Returns:
            Tensor: reference points used in the decoder, has
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points
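
    # Shape sketch (illustrative; the sizes below are assumptions): with two levels of
    # spatial shapes (32, 32) and (16, 16) and bs=2, num_keys = 32*32 + 16*16 = 1280, so
    # the returned reference points have shape (2, 1280, 2, 2), normalized to [0, 1].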

    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps at all levels."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def get_proposal_pos_embed(self, proposals, num_pos_feats=128, temperature=10000):
        """Get the position embedding of proposals."""
        scale = 2 * math.pi
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals.sigmoid() * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
        return pos
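
    # Shape sketch (illustrative, values assumed): proposals of shape (N, L, 4) with
    # num_pos_feats=128 yield a sine/cosine embedding of shape (N, L, 4 * 128) = (N, L, 512),
    # which matches the `embed_dims * 2` input expected by `self.pos_trans` above.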

    def forward(
        self, mlvl_feats, mlvl_masks, query_embed, mlvl_pos_embeds, reg_branches=None, cls_branches=None, **kwargs
    ):
        """Forward function for `Transformer`.

        Args:
            mlvl_feats (list(Tensor)): Input queries from
                different levels. Each element has shape
                [bs, embed_dims, h, w].
            mlvl_masks (list(Tensor)): The key_padding_mask from
                different levels used for encoder and decoder,
                each element has shape [bs, h, w].
            query_embed (Tensor): The query embedding for decoder,
                with shape [num_query, c].
            mlvl_pos_embeds (list(Tensor)): The positional encoding
                of feats from different levels, has the shape
                [bs, embed_dims, h, w].
            reg_branches (obj:`nn.ModuleList`): Regression heads for
                feature maps from each decoder layer. Only passed
                when `with_box_refine` is True. Default to None.
            cls_branches (obj:`nn.ModuleList`): Classification heads
                for feature maps from each decoder layer. Only passed
                when `as_two_stage` is True. Default to None.

        Returns:
            tuple[Tensor]: results of decoder containing the following tensor.

                - inter_states: Outputs from decoder. If
                  return_intermediate_dec is True output has shape
                  (num_dec_layers, bs, num_query, embed_dims), else has
                  shape (1, bs, num_query, embed_dims).
                - init_reference_out: The initial value of reference
                  points, has shape (bs, num_queries, 4).
                - inter_references_out: The internal value of reference
                  points in decoder, has shape
                  (num_dec_layers, bs, num_query, embed_dims).
                - enc_outputs_class: The classification score of
                  proposals generated from the encoder's feature maps,
                  has shape (batch, h*w, num_classes).
                  Only returned when `as_two_stage` is True,
                  otherwise None.
                - enc_outputs_coord_unact: The regression results
                  generated from the encoder's feature maps, has shape
                  (batch, h*w, 4). Only returned when `as_two_stage`
                  is True, otherwise None.
        """
        assert self.as_two_stage or query_embed is not None

        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            feat = feat.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=feat_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in mlvl_masks], 1)

        reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=feat.device)

        feat_flatten = feat_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute(1, 0, 2)  # (H*W, bs, embed_dims)
        memory = self.encoder(
            query=feat_flatten,
            key=None,
            value=None,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            **kwargs,
        )

        memory = memory.permute(1, 0, 2)
        bs, _, c = memory.shape
        if self.as_two_stage:
            output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)
            enc_outputs_class = cls_branches[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = reg_branches[self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()
            init_reference_out = reference_points
            pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
            query_pos, query = torch.split(pos_trans_out, c, dim=2)
        else:
            query_pos, query = torch.split(query_embed, c, dim=1)
            query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
            query = query.unsqueeze(0).expand(bs, -1, -1)
            reference_points = self.reference_points(query_pos).sigmoid()
            init_reference_out = reference_points

        # decoder
        query = query.permute(1, 0, 2)
        memory = memory.permute(1, 0, 2)
        query_pos = query_pos.permute(1, 0, 2)
        inter_states, inter_references = self.decoder(
            query=query,
            key=None,
            value=memory,
            query_pos=query_pos,
            key_padding_mask=mask_flatten,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=reg_branches,
            **kwargs,
        )

        inter_references_out = inter_references
        if self.as_two_stage:
            return inter_states, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
        return inter_states, init_reference_out, inter_references_out, None, None


@TRANSFORMER.register_module()
class DynamicConv(BaseModule):
    """Implements Dynamic Convolution.

    This module generates parameters for each sample and
    uses bmm to implement 1*1 convolution. Code is modified
    from the `official github repo <https://github.com/PeizeSun/
    SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py#L258>`_ .

    Args:
        in_channels (int): The input feature channel.
            Defaults to 256.
        feat_channels (int): The inner feature channel.
            Defaults to 64.
        out_channels (int, optional): The output feature channel.
            When not specified, it will be set to `in_channels`
            by default.
        input_feat_shape (int): The shape of input feature.
            Defaults to 7.
        with_proj (bool): Project the two-dimensional feature to
            a one-dimensional feature. Default to True.
        act_cfg (dict): The activation config for DynamicConv.
        norm_cfg (dict): Config dict for normalization layer. Default
            layer normalization.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(
        self,
        in_channels=256,
        feat_channels=64,
        out_channels=None,
        input_feat_shape=7,
        with_proj=True,
        act_cfg=dict(type="ReLU", inplace=True),
        norm_cfg=dict(type="LN"),
        init_cfg=None,
    ):
        super(DynamicConv, self).__init__(init_cfg)
        self.in_channels = in_channels
        self.feat_channels = feat_channels
        self.out_channels_raw = out_channels
        self.input_feat_shape = input_feat_shape
        self.with_proj = with_proj
        self.act_cfg = act_cfg
        self.norm_cfg = norm_cfg
        self.out_channels = out_channels if out_channels else in_channels

        self.num_params_in = self.in_channels * self.feat_channels
        self.num_params_out = self.out_channels * self.feat_channels
        self.dynamic_layer = nn.Linear(self.in_channels, self.num_params_in + self.num_params_out)

        self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
        self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]

        self.activation = build_activation_layer(act_cfg)

        num_output = self.out_channels * input_feat_shape**2
        if self.with_proj:
            self.fc_layer = nn.Linear(num_output, self.out_channels)
            self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]

    def forward(self, param_feature, input_feature):
        """Forward function for `DynamicConv`.

        Args:
            param_feature (Tensor): The feature used
                to generate the parameters, has shape
                (num_all_proposals, in_channels).
            input_feature (Tensor): Feature that
                interacts with the parameters, has shape
                (num_all_proposals, in_channels, H, W).

        Returns:
            Tensor: The output feature has shape
                (num_all_proposals, out_channels).
        """
        input_feature = input_feature.flatten(2).permute(2, 0, 1)

        input_feature = input_feature.permute(1, 0, 2)
        parameters = self.dynamic_layer(param_feature)

        param_in = parameters[:, : self.num_params_in].view(-1, self.in_channels, self.feat_channels)
        param_out = parameters[:, -self.num_params_out :].view(-1, self.feat_channels, self.out_channels)

        # input_feature has shape (num_all_proposals, H*W, in_channels)
        # param_in has shape (num_all_proposals, in_channels, feat_channels)
        # features has shape (num_all_proposals, H*W, feat_channels)
        features = torch.bmm(input_feature, param_in)
        features = self.norm_in(features)
        features = self.activation(features)

        # param_out has shape (batch_size, feat_channels, out_channels)
        features = torch.bmm(features, param_out)
        features = self.norm_out(features)
        features = self.activation(features)

        if self.with_proj:
            features = features.flatten(1)
            features = self.fc_layer(features)
            features = self.fc_norm(features)
            features = self.activation(features)

        return features
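
# Usage sketch for `DynamicConv` (illustrative; shapes below are assumptions, not part of
# the original diff):
#   >>> dyn_conv = DynamicConv(in_channels=256, feat_channels=64, input_feat_shape=7)
#   >>> param_feat = torch.rand(100, 256)            # (num_all_proposals, in_channels)
#   >>> roi_feat = torch.rand(100, 256, 7, 7)        # (num_all_proposals, in_channels, H, W)
#   >>> out = dyn_conv(param_feat, roi_feat)
#   >>> out.shape
#   torch.Size([100, 256])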
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/fundamentalvision/Deformable-DETR/tree/main/models/ops/modules
# https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0

from .ms_deform_attn import MSDeformAttn
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import math
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Function
from torch.cuda.amp import custom_fwd
from torch.nn.init import constant_, xavier_uniform_


class MSDeformAttnFunction(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(
        ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step
    ):
        output = ms_deform_attn_core_pytorch(
            value,
            value_spatial_shapes,
            # value_level_start_index,
            sampling_locations,
            attention_weights,
        )
        return output


def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
    # for debug and test only,
    # need to use cuda version instead
    N_, S_, M_, D_ = value.shape
    _, Lq_, M_, L_, P_, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for lid_, (H_, W_) in enumerate(value_spatial_shapes):
        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
        value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_)
        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
        sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
        # N_*M_, D_, Lq_, P_
        sampling_value_l_ = F.grid_sample(
            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampling_value_list.append(sampling_value_l_)
    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
    attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_)
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_ * D_, Lq_)
    return output.transpose(1, 2).contiguous()
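
# Shape sketch for the reference PyTorch path (illustrative; all sizes below are assumed):
#   >>> N, M, D, L, P, Lq = 2, 8, 32, 2, 4, 100
#   >>> shapes = [(32, 32), (16, 16)]
#   >>> S = sum(h * w for h, w in shapes)                      # 1280 flattened keys
#   >>> value = torch.rand(N, S, M, D)
#   >>> locs = torch.rand(N, Lq, M, L, P, 2)                   # normalized to [0, 1]
#   >>> weights = torch.softmax(torch.rand(N, Lq, M, L * P), -1).view(N, Lq, M, L, P)
#   >>> out = ms_deform_attn_core_pytorch(value, shapes, locs, weights)
#   >>> out.shape                                              # (N, Lq, M * D)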


def _is_power_of_2(n):
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
    return (n & (n - 1) == 0) and n != 0


class MSDeformAttn(nn.Module):
    def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4, ratio=1.0):
        """Multi-Scale Deformable Attention Module.

        :param d_model      hidden dimension
        :param n_levels     number of feature levels
        :param n_heads      number of attention heads
        :param n_points     number of sampling points per attention head per feature level
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads))
        _d_per_head = d_model // n_heads
        # you'd better set _d_per_head to a power of 2
        # which is more efficient in our CUDA implementation
        if not _is_power_of_2(_d_per_head):
            warnings.warn(
                "You'd better set d_model in MSDeformAttn to make "
                "the dimension of each attention head a power of 2 "
                "which is more efficient in our CUDA implementation."
            )

        self.im2col_step = 64

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points
        self.ratio = ratio
        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, int(d_model * ratio))
        self.output_proj = nn.Linear(int(d_model * ratio), d_model)

        self._reset_parameters()

    def _reset_parameters(self):
        constant_(self.sampling_offsets.weight.data, 0.0)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (
            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
            .view(self.n_heads, 1, 1, 2)
            .repeat(1, self.n_levels, self.n_points, 1)
        )
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1

        with torch.no_grad():
            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
        constant_(self.attention_weights.weight.data, 0.0)
        constant_(self.attention_weights.bias.data, 0.0)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.0)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.0)

    def forward(
        self,
        query,
        reference_points,
        input_flatten,
        input_spatial_shapes,
        input_level_start_index,
        input_padding_mask=None,
    ):
        """
        :param query                    (N, Length_{query}, C)
        :param reference_points        (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0, 0), bottom-right (1, 1), including padding area
                                        or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
        :param input_flatten            (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l, C)
        :param input_spatial_shapes     (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
        :param input_level_start_index  (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, ..., H_0*W_0+H_1*W_1+...+H_{L-2}*W_{L-2}]
        :param input_padding_mask       (N, \\sum_{l=0}^{L-1} H_l \\cdot W_l), True for padding elements, False for non-padding elements

        :return output                  (N, Length_{query}, C)
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))

        value = value.view(N, Len_in, self.n_heads, int(self.ratio * self.d_model) // self.n_heads)
        sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
        attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
            sampling_locations = (
                reference_points[:, :, None, :, None, :]
                + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
            )
        elif reference_points.shape[-1] == 4:
            sampling_locations = (
                reference_points[:, :, None, :, None, :2]
                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
            )
        else:
            raise ValueError(
                "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1])
            )
        output = MSDeformAttnFunction.apply(
            value,
            input_spatial_shapes,
            input_level_start_index,
            sampling_locations,
            attention_weights,
            self.im2col_step,
        )
        output = self.output_proj(output)
        return output
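
# Usage sketch for `MSDeformAttn` (illustrative; the shapes and sizes below are assumptions,
# not part of the original diff, and the pure-PyTorch fallback above is what gets called):
#   >>> attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4)
#   >>> shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long)
#   >>> start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
#   >>> src = torch.rand(2, int(shapes.prod(1).sum()), 256)   # (N, sum(H_l*W_l), C)
#   >>> query = torch.rand(2, 100, 256)
#   >>> ref = torch.rand(2, 100, 2, 2)                        # (N, Len_q, n_levels, 2)
#   >>> out = attn(query, ref, src, shapes, start_index)
#   >>> out.shape                                             # (2, 100, 256)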