RE-OWOD/projects/PointRend/point_rend/coarse_mask_head.py

93 lines
3.6 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import fvcore.nn.weight_init as weight_init
import torch
from torch import nn
from torch.nn import functional as F
from detectron2.layers import Conv2d, ShapeSpec
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
@ROI_MASK_HEAD_REGISTRY.register()
class CoarseMaskHead(nn.Module):
    """
    A mask head with fully connected layers. Given pooled features it first reduces channels and
    spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
    to the standard box head.
    """

    def __init__(self, cfg, input_shape: ShapeSpec):
        """
        The following attributes are parsed from config:
            conv_dim: the output dimension of the conv layers
            fc_dim: the feature dimension of the FC layers
            num_fc: the number of FC layers
            output_side_resolution: side resolution of the output square mask prediction

        Args:
            cfg: config node; reads ``MODEL.ROI_HEADS.NUM_CLASSES`` and
                ``MODEL.ROI_MASK_HEAD.{CONV_DIM, FC_DIM, NUM_FC, OUTPUT_SIDE_RESOLUTION}``.
            input_shape (ShapeSpec): provides ``channels``, ``height`` and ``width`` of the
                pooled input features.
        """
        super().__init__()

        # fmt: off
        self.num_classes            = cfg.MODEL.ROI_HEADS.NUM_CLASSES
        conv_dim                    = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM
        self.fc_dim                 = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
        num_fc                      = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
        self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
        self.input_channels         = input_shape.channels
        self.input_h                = input_shape.height
        self.input_w                = input_shape.width
        # fmt: on

        self.conv_layers = []
        # Number of channels that actually reach the spatial-reduction conv. It is only
        # ``conv_dim`` when the 1x1 channel-reduction conv below is created.
        spatial_conv_in_channels = self.input_channels
        if self.input_channels > conv_dim:
            self.reduce_channel_dim_conv = Conv2d(
                self.input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)
            spatial_conv_in_channels = conv_dim

        self.reduce_spatial_dim_conv = Conv2d(
            spatial_conv_in_channels,
            conv_dim,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            activation=F.relu,
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        # A kernel-2 / stride-2 / padding-0 conv yields floor(H / 2) x floor(W / 2) outputs.
        # Halving each side separately (instead of ``H * W // 4``) keeps the FC input size
        # correct when the pooled height or width is odd.
        input_dim = conv_dim * (self.input_h // 2) * (self.input_w // 2)

        self.fcs = []
        for k in range(num_fc):
            fc = nn.Linear(input_dim, self.fc_dim)
            # Register under a stable name so parameters show up in the state dict.
            self.add_module("coarse_mask_fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = self.fc_dim

        output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution
        # ``input_dim`` equals ``fc_dim`` whenever at least one FC layer exists; with
        # ``num_fc == 0`` it is the flattened conv feature size, which is what the
        # prediction layer then receives in ``forward``.
        self.prediction = nn.Linear(input_dim, output_dim)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    def forward(self, x):
        """
        Predict coarse masks from pooled features.

        Args:
            x (Tensor): pooled features that can be viewed as
                ``(N, input_channels, input_h, input_w)``.

        Returns:
            Tensor: predictions of shape
                ``(N, num_classes, output_side_resolution, output_side_resolution)``.
        """
        # unlike BaseMaskRCNNHead, this head only outputs intermediate
        # features, because the features will be used later by PointHead.
        N = x.shape[0]
        x = x.view(N, self.input_channels, self.input_h, self.input_w)
        for layer in self.conv_layers:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        for layer in self.fcs:
            x = F.relu(layer(x))
        return self.prediction(x).view(
            N, self.num_classes, self.output_side_resolution, self.output_side_resolution
        )