mirror of https://github.com/YifanXu74/MQ-Det.git (Python, 125 lines, 3.7 KiB)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
Miscellaneous utility functions
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from maskrcnn_benchmark.config import cfg
|
|
from maskrcnn_benchmark.layers import Conv2d, DYReLU
|
|
from maskrcnn_benchmark.modeling.poolers import Pooler
|
|
|
|
|
|
def get_group_gn(dim, dim_per_gp, num_groups):
    """Return the number of GroupNorm groups for ``dim`` channels.

    Exactly one of ``dim_per_gp`` (channels per group) or ``num_groups``
    (group count) may be specified; the other must be -1. ``dim`` must be
    evenly divisible by whichever option is given.
    """
    assert dim_per_gp == -1 or num_groups == -1, \
        "GroupNorm: can only specify G or C/G."

    if dim_per_gp > 0:
        # Group count derived from the requested channels-per-group.
        assert dim % dim_per_gp == 0, \
            "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp)
        return dim // dim_per_gp

    # Group count given directly.
    assert dim % num_groups == 0, \
        "dim: {}, num_groups: {}".format(dim, num_groups)
    return num_groups
|
|
|
|
|
|
def group_norm(out_channels, affine=True, divisor=1):
    """Build a ``torch.nn.GroupNorm`` over ``out_channels`` channels.

    Group count, channels-per-group and epsilon come from
    ``cfg.MODEL.GROUP_NORM``; ``divisor`` scales down both the channel
    count and the configured group parameters.
    """
    channels = out_channels // divisor
    return torch.nn.GroupNorm(
        get_group_gn(
            channels,
            cfg.MODEL.GROUP_NORM.DIM_PER_GP // divisor,
            cfg.MODEL.GROUP_NORM.NUM_GROUPS // divisor,
        ),
        channels,
        cfg.MODEL.GROUP_NORM.EPSILON,  # default: 1e-5
        affine,
    )
|
|
|
|
|
|
def make_conv3x3(
    in_channels,
    out_channels,
    dilation=1,
    stride=1,
    use_gn=False,
    use_relu=False,
    kaiming_init=True
):
    """Create a 3x3 convolution, optionally followed by GroupNorm and ReLU.

    The conv bias is dropped when GroupNorm is used, since the norm layer
    carries its own affine offset. Weights are initialized Kaiming-normal
    (fan_out, relu) by default, or drawn from N(0, 0.01) otherwise.
    Returns a bare ``Conv2d`` when no extra layers are requested, else an
    ``nn.Sequential``.
    """
    conv = Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        # padding == dilation keeps the spatial size for a 3x3 kernel.
        padding=dilation,
        dilation=dilation,
        bias=not use_gn,
    )
    if kaiming_init:
        nn.init.kaiming_normal_(
            conv.weight, mode="fan_out", nonlinearity="relu"
        )
    else:
        torch.nn.init.normal_(conv.weight, std=0.01)
    if not use_gn:
        nn.init.constant_(conv.bias, 0)

    layers = [conv]
    if use_gn:
        layers.append(group_norm(out_channels))
    if use_relu:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers) if len(layers) > 1 else conv
|
|
|
|
|
|
def make_fc(dim_in, hidden_dim, use_gn=False):
    """Create a fully-connected layer, optionally wrapped with GroupNorm.

    Caffe2's XavierFill initializer corresponds to ``kaiming_uniform_``
    with ``a=1`` in PyTorch, which is what the weight uses here.
    """
    # When GroupNorm follows, its affine offset makes the linear bias redundant.
    fc = nn.Linear(dim_in, hidden_dim, bias=not use_gn)
    nn.init.kaiming_uniform_(fc.weight, a=1)
    if use_gn:
        return nn.Sequential(fc, group_norm(hidden_dim))
    nn.init.constant_(fc.bias, 0)
    return fc
|
|
|
|
|
|
def conv_with_kaiming_uniform(use_gn=False, use_relu=False, use_dyrelu=False):
    """Return a conv-layer factory with Caffe2-style (XavierFill) init.

    The returned ``make_conv`` builds a ``Conv2d`` whose weights are set
    with ``kaiming_uniform_(a=1)`` — the PyTorch equivalent of Caffe2's
    XavierFill — optionally followed by GroupNorm, ReLU and/or DYReLU.
    A bare conv is returned when no extra layers are requested, else an
    ``nn.Sequential``.
    """
    def make_conv(
        in_channels, out_channels, kernel_size, stride=1, dilation=1
    ):
        conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            # "Same"-style padding for the (possibly dilated) kernel.
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation,
            # GroupNorm supplies its own offset, so drop the conv bias.
            bias=not use_gn,
        )
        # Caffe2 implementation uses XavierFill, which in fact
        # corresponds to kaiming_uniform_ in PyTorch
        nn.init.kaiming_uniform_(conv.weight, a=1)
        if not use_gn:
            nn.init.constant_(conv.bias, 0)

        layers = [conv]
        if use_gn:
            layers.append(group_norm(out_channels))
        if use_relu:
            layers.append(nn.ReLU(inplace=True))
        if use_dyrelu:
            layers.append(DYReLU(out_channels, out_channels, use_spatial=True))
        return nn.Sequential(*layers) if len(layers) > 1 else conv

    return make_conv
|