# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import fvcore.nn.weight_init as weight_init import torch from torch import nn from torch.nn import functional as F from detectron2.layers import Conv2d, ShapeSpec from detectron2.modeling import ROI_MASK_HEAD_REGISTRY @ROI_MASK_HEAD_REGISTRY.register() class CoarseMaskHead(nn.Module): """ A mask head with fully connected layers. Given pooled features it first reduces channels and spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously to the standard box head. """ def __init__(self, cfg, input_shape: ShapeSpec): """ The following attributes are parsed from config: conv_dim: the output dimension of the conv layers fc_dim: the feature dimenstion of the FC layers num_fc: the number of FC layers output_side_resolution: side resolution of the output square mask prediction """ super(CoarseMaskHead, self).__init__() # fmt: off self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM self.fc_dim = cfg.MODEL.ROI_MASK_HEAD.FC_DIM num_fc = cfg.MODEL.ROI_MASK_HEAD.NUM_FC self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION self.input_channels = input_shape.channels self.input_h = input_shape.height self.input_w = input_shape.width # fmt: on self.conv_layers = [] if self.input_channels > conv_dim: self.reduce_channel_dim_conv = Conv2d( self.input_channels, conv_dim, kernel_size=1, stride=1, padding=0, bias=True, activation=F.relu, ) self.conv_layers.append(self.reduce_channel_dim_conv) self.reduce_spatial_dim_conv = Conv2d( conv_dim, conv_dim, kernel_size=2, stride=2, padding=0, bias=True, activation=F.relu ) self.conv_layers.append(self.reduce_spatial_dim_conv) input_dim = conv_dim * self.input_h * self.input_w input_dim //= 4 self.fcs = [] for k in range(num_fc): fc = nn.Linear(input_dim, self.fc_dim) self.add_module("coarse_mask_fc{}".format(k + 1), fc) self.fcs.append(fc) input_dim = self.fc_dim output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution self.prediction = nn.Linear(self.fc_dim, output_dim) # use normal distribution initialization for mask prediction layer nn.init.normal_(self.prediction.weight, std=0.001) nn.init.constant_(self.prediction.bias, 0) for layer in self.conv_layers: weight_init.c2_msra_fill(layer) for layer in self.fcs: weight_init.c2_xavier_fill(layer) def forward(self, x): # unlike BaseMaskRCNNHead, this head only outputs intermediate # features, because the features will be used later by PointHead. N = x.shape[0] x = x.view(N, self.input_channels, self.input_h, self.input_w) for layer in self.conv_layers: x = layer(x) x = torch.flatten(x, start_dim=1) for layer in self.fcs: x = F.relu(layer(x)) return self.prediction(x).view( N, self.num_classes, self.output_side_resolution, self.output_side_resolution )