mirror of https://github.com/alibaba/EasyCV.git
68 lines
2.8 KiB
Python
68 lines
2.8 KiB
Python
from torch import nn
|
|
|
|
from easycv.models.builder import HEADS
|
|
from .pixel_decoder import MSDeformAttnPixelDecoder
|
|
from .transformer_decoder import MultiScaleMaskedTransformerDecoder
|
|
|
|
|
|
@HEADS.register_module()
|
|
class Mask2FormerHead(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
pixel_decoder,
|
|
transformer_decoder,
|
|
num_things_classes: int,
|
|
num_stuff_classes: int,
|
|
loss_weight: float = 1.0,
|
|
ignore_value: int = -1,
|
|
# extra parameters
|
|
transformer_in_feature: str = 'multi_scale_pixel_decoder',
|
|
):
|
|
"""
|
|
|
|
Args:
|
|
pixel_decoder (cfg): config to build pixel decoder
|
|
transformer_decoder (cfg): config to build transformer decoder
|
|
num_things_classes (int): number of things classes
|
|
num_stuff_classes (int): number of stuff classes
|
|
loss_weight (float, optional): loss weight. Defaults to 1.0.
|
|
ignore_value (int, optional): category id to be ignored during training. Defaults to -1.
|
|
transformer_in_feature (str, optional): nput feature name to the transformer_predictor, only support multi_scale_pixel_decoder now. Defaults to 'multi_scale_pixel_decoder'.
|
|
"""
|
|
super().__init__()
|
|
|
|
self.ignore_value = ignore_value
|
|
self.common_stride = 4
|
|
self.loss_weight = loss_weight
|
|
self.pixel_decoder = MSDeformAttnPixelDecoder(**pixel_decoder)
|
|
self.predictor = MultiScaleMaskedTransformerDecoder(
|
|
**transformer_decoder)
|
|
self.transformer_in_feature = transformer_in_feature
|
|
|
|
self.num_classes = num_things_classes + num_stuff_classes
|
|
|
|
def forward(self, features, mask=None):
|
|
|
|
return self.layers(features, mask)
|
|
|
|
def layers(self, features, mask=None):
|
|
mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(
|
|
features)
|
|
if self.transformer_in_feature == 'multi_scale_pixel_decoder':
|
|
predictions = self.predictor(multi_scale_features, mask_features,
|
|
mask)
|
|
else:
|
|
if self.transformer_in_feature == 'transformer_encoder':
|
|
assert (transformer_encoder_features is not None
|
|
), 'Please use the TransformerEncoderPixelDecoder.'
|
|
predictions = self.predictor(transformer_encoder_features,
|
|
mask_features, mask)
|
|
elif self.transformer_in_feature == 'pixel_embedding':
|
|
predictions = self.predictor(mask_features, mask_features,
|
|
mask)
|
|
else:
|
|
predictions = self.predictor(
|
|
features[self.transformer_in_feature], mask_features, mask)
|
|
return predictions
|