mirror of https://github.com/RE-OWOD/RE-OWOD
108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
|
import torch
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
from torch.nn.modules.utils import _pair
|
|
|
|
from detectron2.layers.wrappers import _NewEmptyTensorOp
|
|
|
|
|
|
class TridentConv(nn.Module):
    """Convolution whose single weight (and optional bias) is shared by several branches.

    Each branch applies the same kernel to its own input tensor, but with its own
    padding and dilation. In training mode, or when ``test_branch_idx == -1``, all
    ``num_branch`` branches run; otherwise (eval mode with a chosen
    ``test_branch_idx``) only that one branch runs, on ``inputs[0]``.

    Args:
        in_channels (int): number of channels of each input.
        out_channels (int): number of channels of each output.
        kernel_size (int or pair): kernel size, normalized with ``_pair``.
        stride (int or pair): stride shared by all branches, normalized with ``_pair``.
        paddings (int or sequence): per-branch padding. A single int is replicated
            for every branch; each entry is normalized with ``_pair``.
        dilations (int or sequence): per-branch dilation, same conventions as ``paddings``.
        groups (int): number of convolution groups.
        num_branch (int): number of branches; must equal ``len(paddings)`` and ``len(dilations)``.
        test_branch_idx (int): branch used in eval mode; ``-1`` runs every branch.
        bias (bool): whether to create a learnable bias (initialized to zero).
        norm (callable or None): applied to every branch output when given.
        activation (callable or None): applied to every branch output after ``norm``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        paddings=0,
        dilations=1,
        groups=1,
        num_branch=1,
        test_branch_idx=-1,
        bias=False,
        norm=None,
        activation=None,
    ):
        super(TridentConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.num_branch = num_branch
        self.stride = _pair(stride)
        self.groups = groups
        self.with_bias = bias
        # A scalar padding/dilation means "the same value for every branch".
        if isinstance(paddings, int):
            paddings = [paddings] * self.num_branch
        if isinstance(dilations, int):
            dilations = [dilations] * self.num_branch
        self.paddings = [_pair(padding) for padding in paddings]
        self.dilations = [_pair(dilation) for dilation in dilations]
        self.test_branch_idx = test_branch_idx
        self.norm = norm
        self.activation = activation

        assert len({self.num_branch, len(self.paddings), len(self.dilations)}) == 1, (
            "num_branch, len(paddings) and len(dilations) must all be equal, got "
            + str((self.num_branch, len(self.paddings), len(self.dilations)))
        )

        # One weight tensor shared by all branches; filled by kaiming init below.
        self.weight = nn.Parameter(
            torch.empty(out_channels, in_channels // groups, *self.kernel_size)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.bias = None

        nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def _active_branch_params(self):
        """Return the ``(padding, dilation)`` pairs of the branches that run in the current mode."""
        if self.training or self.test_branch_idx == -1:
            return list(zip(self.paddings, self.dilations))
        return [(self.paddings[self.test_branch_idx], self.dilations[self.test_branch_idx])]

    def _empty_output_shape(self, x, padding, dilation):
        """Return the ``[batch, out_channels, H_out, W_out]`` shape conv2d would give for ``x``.

        Uses the standard conv output-size formula with this layer's kernel size and
        stride and the given per-branch ``padding`` and ``dilation`` pairs.
        """
        spatial = [
            (size + 2 * p - (d * (k - 1) + 1)) // s + 1
            for size, p, d, k, s in zip(
                x.shape[-2:], padding, dilation, self.kernel_size, self.stride
            )
        ]
        return [x.shape[0], self.weight.shape[0]] + spatial

    def forward(self, inputs):
        """Run the shared convolution on every active branch.

        Args:
            inputs (list[Tensor]): one tensor per active branch. That is
                ``num_branch`` tensors in training mode or when ``test_branch_idx == -1``,
                otherwise exactly one tensor.

        Returns:
            list[Tensor]: one output per active branch. When the inputs are empty
            (``inputs[0].numel() == 0``), empty tensors of the correct per-branch shape
            are returned and ``norm``/``activation`` are not applied.
        """
        branch_params = self._active_branch_params()
        assert len(inputs) == len(branch_params), (
            "expected " + str(len(branch_params)) + " inputs, got " + str(len(inputs))
        )

        if inputs[0].numel() == 0:
            # Branches differ in padding/dilation, so each gets its own output shape.
            return [
                _NewEmptyTensorOp.apply(x, self._empty_output_shape(x, padding, dilation))
                for x, (padding, dilation) in zip(inputs, branch_params)
            ]

        outputs = [
            F.conv2d(x, self.weight, self.bias, self.stride, padding, dilation, self.groups)
            for x, (padding, dilation) in zip(inputs, branch_params)
        ]

        if self.norm is not None:
            outputs = [self.norm(x) for x in outputs]
        if self.activation is not None:
            outputs = [self.activation(x) for x in outputs]
        return outputs

    def extra_repr(self):
        """Summarize the layer's configuration for ``repr()``."""
        fields = [
            ("in_channels", self.in_channels),
            ("out_channels", self.out_channels),
            ("kernel_size", self.kernel_size),
            ("num_branch", self.num_branch),
            ("test_branch_idx", self.test_branch_idx),
            ("stride", self.stride),
            ("paddings", self.paddings),
            ("dilations", self.dilations),
            ("groups", self.groups),
            ("bias", self.with_bias),
        ]
        return ", ".join(name + "=" + str(value) for name, value in fields)
|