mirror of https://github.com/RE-OWOD/RE-OWOD
93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
import fvcore.nn.weight_init as weight_init
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from detectron2.layers import Conv2d, ShapeSpec
|
|
from detectron2.modeling import ROI_MASK_HEAD_REGISTRY
|
|
|
|
|
|
@ROI_MASK_HEAD_REGISTRY.register()
class CoarseMaskHead(nn.Module):
    """
    A mask head with fully connected layers. Given pooled features it first reduces channels and
    spatial dimensions with conv layers and then uses FC layers to predict coarse masks analogously
    to the standard box head.
    """

    def __init__(self, cfg, input_shape: ShapeSpec):
        """
        The following attributes are parsed from config:
            conv_dim: the output dimension of the conv layers
            fc_dim: the feature dimension of the FC layers
            num_fc: the number of FC layers
            output_side_resolution: side resolution of the output square mask prediction

        Args:
            cfg: config node; reads MODEL.ROI_HEADS.NUM_CLASSES and
                MODEL.ROI_MASK_HEAD.{CONV_DIM, FC_DIM, NUM_FC, OUTPUT_SIDE_RESOLUTION}.
            input_shape (ShapeSpec): channels, height and width of the per-ROI features
                that :meth:`forward` will receive.
        """
        super().__init__()

        # fmt: off
        self.num_classes            = cfg.MODEL.ROI_HEADS.NUM_CLASSES
        conv_dim                    = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM
        self.fc_dim                 = cfg.MODEL.ROI_MASK_HEAD.FC_DIM
        num_fc                      = cfg.MODEL.ROI_MASK_HEAD.NUM_FC
        self.output_side_resolution = cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION
        self.input_channels         = input_shape.channels
        self.input_h                = input_shape.height
        self.input_w                = input_shape.width
        # fmt: on

        # Plain list kept only to fix the order of application in forward(); each conv is
        # also assigned as an attribute, which is what registers it as a submodule.
        self.conv_layers = []
        if self.input_channels > conv_dim:
            # 1x1 conv that only shrinks the channel count down to conv_dim.
            self.reduce_channel_dim_conv = Conv2d(
                self.input_channels,
                conv_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                activation=F.relu,
            )
            self.conv_layers.append(self.reduce_channel_dim_conv)
            spatial_in_channels = conv_dim
        else:
            # No channel reduction was added, so the spatial conv receives the raw input
            # channels. Hard-coding conv_dim here broke whenever input_channels < conv_dim.
            spatial_in_channels = self.input_channels

        # 2x2 / stride-2 conv halves both spatial sides and outputs conv_dim channels.
        self.reduce_spatial_dim_conv = Conv2d(
            spatial_in_channels,
            conv_dim,
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            activation=F.relu,
        )
        self.conv_layers.append(self.reduce_spatial_dim_conv)

        # An unpadded 2x2 / stride-2 conv floors odd sides, so size each axis separately.
        # This equals conv_dim * H * W // 4 for even H and W, and stays correct for odd ones.
        input_dim = conv_dim * (self.input_h // 2) * (self.input_w // 2)

        self.fcs = []
        for k in range(num_fc):
            fc = nn.Linear(input_dim, self.fc_dim)
            # Explicit names ("coarse_mask_fc1", ...) register the layers as submodules.
            self.add_module("coarse_mask_fc{}".format(k + 1), fc)
            self.fcs.append(fc)
            input_dim = self.fc_dim

        output_dim = self.num_classes * self.output_side_resolution * self.output_side_resolution

        # input_dim is now the width of the last hidden representation: fc_dim when
        # num_fc > 0, otherwise the flattened conv output.
        self.prediction = nn.Linear(input_dim, output_dim)
        # use normal distribution initialization for mask prediction layer
        nn.init.normal_(self.prediction.weight, std=0.001)
        nn.init.constant_(self.prediction.bias, 0)

        for layer in self.conv_layers:
            weight_init.c2_msra_fill(layer)
        for layer in self.fcs:
            weight_init.c2_xavier_fill(layer)

    def forward(self, x):
        """
        Predict coarse per-class mask scores for a batch of ROI features.

        Args:
            x (Tensor): per-ROI features; it must hold
                N * input_channels * input_h * input_w elements, where N = x.shape[0].

        Returns:
            Tensor of shape (N, num_classes, output_side_resolution, output_side_resolution)
            holding the raw (pre-activation) outputs of the prediction layer.
        """
        # unlike BaseMaskRCNNHead, this head only outputs intermediate
        # features, because the features will be used later by PointHead.
        N = x.shape[0]
        x = x.view(N, self.input_channels, self.input_h, self.input_w)
        for layer in self.conv_layers:
            x = layer(x)
        x = torch.flatten(x, start_dim=1)
        for layer in self.fcs:
            x = F.relu(layer(x))
        return self.prediction(x).view(
            N, self.num_classes, self.output_side_resolution, self.output_side_resolution
        )
|