# Mirror of https://github.com/alibaba/EasyCV.git

import torch
import torch.nn as nn


class NonLocalModule(nn.Module):
    """
    Builds Non-local Neural Networks as a generic family of building
    blocks for capturing long-range dependencies. A Non-local block
    computes the response at a position as a weighted sum of the
    features at all positions, so the block can be plugged into many
    computer vision architectures.
    More details in the paper: https://arxiv.org/pdf/1711.07971.pdf
    """

    def __init__(
        self,
        dim,
        dim_inner,
        pool_size=None,
        instantiation='softmax',
        zero_init_final_conv=False,
        zero_init_final_norm=True,
        norm_eps=1e-5,
        norm_momentum=0.1,
        norm_module=nn.BatchNorm3d,
    ):
        """
        Args:
            dim (int): number of channels of the input.
            dim_inner (int): number of channels inside the Non-local block.
            pool_size (list): kernel size of the spatio-temporal pooling,
                given as [temporal kernel size, spatial kernel size,
                spatial kernel size]. If pool_size is None (default), no
                pooling is used.
            instantiation (string): supports two instantiation methods:
                "dot_product": normalizing correlation matrix with L2.
                "softmax": normalizing correlation matrix with Softmax.
            zero_init_final_conv (bool): if True, zero-initialize the final
                convolution of the Non-local block.
            zero_init_final_norm (bool): if True, zero-initialize the final
                batch norm of the Non-local block.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
        """
        super(NonLocalModule, self).__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.pool_size = pool_size
        self.instantiation = instantiation
        self.use_pool = pool_size is not None and any(
            size > 1 for size in pool_size)
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self._construct_nonlocal(zero_init_final_conv, zero_init_final_norm,
                                 norm_module)

    def _construct_nonlocal(self, zero_init_final_conv, zero_init_final_norm,
                            norm_module):
        # Three convolution heads: theta, phi, and g.
        self.conv_theta = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = nn.Conv3d(
            self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)

        # Final convolution output.
        self.conv_out = nn.Conv3d(
            self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)
        # Zero initializing the final convolution output.
        self.conv_out.zero_init = zero_init_final_conv

        # TODO: change the name to `norm`
        self.bn = norm_module(
            num_features=self.dim,
            eps=self.norm_eps,
            momentum=self.norm_momentum,
        )
        # Zero initializing the final bn.
        self.bn.transform_final_bn = zero_init_final_norm

        # Optionally add the spatio-temporal pooling.
        if self.use_pool:
            self.pool = nn.MaxPool3d(
                kernel_size=self.pool_size,
                stride=self.pool_size,
                padding=[0, 0, 0],
            )

    def forward(self, x):
        x_identity = x
        N, C, T, H, W = x.size()

        theta = self.conv_theta(x)

        # Perform spatio-temporal pooling to reduce the computation.
        if self.use_pool:
            x = self.pool(x)

        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, TxHxW) * (N, C, TxHxW) => (N, TxHxW, TxHxW).
        theta_phi = torch.einsum('nct,ncp->ntp', (theta, phi))
        # In the original Non-local paper, there are two main ways to
        # normalize the affinity tensor:
        # 1) Softmax normalization (norm on exp).
        # 2) dot_product normalization.
        if self.instantiation == 'softmax':
            # Scale the affinity tensor theta_phi before the softmax.
            theta_phi = theta_phi * (self.dim_inner**-0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == 'dot_product':
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError('Unknown norm type {}'.format(
                self.instantiation))

        # (N, TxHxW, TxHxW) * (N, C, TxHxW) => (N, C, TxHxW).
        theta_phi_g = torch.einsum('ntg,ncg->nct', (theta_phi, g))

        # (N, C, TxHxW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)

        p = self.conv_out(theta_phi_g)
        p = self.bn(p)
        return x_identity + p
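
# Example (a minimal sketch; shapes are illustrative assumptions):
#
#   nl = NonLocalModule(dim=64, dim_inner=32, pool_size=[1, 2, 2])
#   x = torch.randn(2, 64, 4, 16, 16)  # (N, C, T, H, W)
#   out = nl(x)                        # residual output, same shape as x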


class Swish(nn.Module):
    """Swish activation function: x * sigmoid(x)."""

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, x):
        return SwishEfficient.apply(x)


class SwishEfficient(torch.autograd.Function):
    """Memory-efficient Swish (x * sigmoid(x)): only the input is saved
    for backward, and the sigmoid is recomputed there."""

    @staticmethod
    def forward(ctx, x):
        result = x * torch.sigmoid(x)
        ctx.save_for_backward(x)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        sigmoid_x = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
        return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
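
# Example (a minimal sketch): the custom backward should agree with
# autograd's numerical gradient of the plain expression x * sigmoid(x).
#
#   x = torch.randn(8, dtype=torch.double, requires_grad=True)
#   torch.autograd.gradcheck(SwishEfficient.apply, (x,))  # should be True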


class SE(nn.Module):
    """Squeeze-and-Excitation (SE) block w/ Swish: AvgPool, FC, Swish, FC,
    Sigmoid."""

    def _round_width(self, width, multiplier, min_width=8, divisor=8):
        """
        Round the width of filters based on the width multiplier.
        Args:
            width (int): the channel dimension of the input.
            multiplier (float): the multiplication factor.
            min_width (int): the minimum width after multiplication.
            divisor (int): the new width should be divisible by divisor.
        """
        if not multiplier:
            return width

        width *= multiplier
        min_width = min_width or divisor
        width_out = max(min_width,
                        int(width + divisor / 2) // divisor * divisor)
        if width_out < 0.9 * width:
            width_out += divisor
        return int(width_out)

    def __init__(self, dim_in, ratio, relu_act=True):
        """
        Args:
            dim_in (int): the channel dimension of the input.
            ratio (float): the channel reduction ratio for the squeeze.
            relu_act (bool): if True (default), use ReLU as the activation
                between the two FC layers; otherwise use Swish.
        """
        super(SE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        dim_fc = self._round_width(dim_in, ratio)
        self.fc1 = nn.Conv3d(dim_in, dim_fc, 1, bias=True)
        self.fc1_act = nn.ReLU() if relu_act else Swish()
        self.fc2 = nn.Conv3d(dim_fc, dim_in, 1, bias=True)

        self.fc2_sig = nn.Sigmoid()

    def forward(self, x):
        x_in = x
        # Relies on the registration order of the submodules above:
        # avg_pool -> fc1 -> fc1_act -> fc2 -> fc2_sig.
        for module in self.children():
            x = module(x)
        return x_in * x
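
# Example (a minimal sketch; numbers are illustrative): with dim_in=48 and
# ratio=0.0625, _round_width computes 48 * 0.0625 = 3, which is then
# rounded up to min_width=8, so the squeeze path uses 8 channels.
#
#   se = SE(dim_in=48, ratio=0.0625)
#   x = torch.randn(2, 48, 4, 8, 8)
#   out = se(x)  # channel-wise reweighted, same shape as x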


class X3DTransform(nn.Module):
    """
    X3D transformation: 1x1x1, Tx3x3 (channelwise, num_groups=dim_in),
    1x1x1, augmented with (optional) SE (squeeze-excitation) on the Tx3x3
    output. T is the temporal kernel size (defaulting to 3).
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        dim_inner,
        num_groups,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        se_ratio=0.0625,
        swish_inner=True,
        block_idx=0,
    ):
        """
        Args:
            dim_in (int): the channel dimension of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the middle
                convolution in the bottleneck.
            stride (int): the stride of the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution.
                num_groups=1 is for standard ResNet-like networks, and
                num_groups>1 is for ResNeXt-like networks.
            stride_1x1 (bool): if True, apply the stride to the 1x1x1 conv,
                otherwise apply it to the Tx3x3 conv.
            inplace_relu (bool): if True, calculate the relu on the
                original input without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum
                in PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
            se_ratio (float): if > 0, apply SE to the Tx3x3 conv, with the
                SE channel dimensionality being se_ratio times the Tx3x3
                conv dim.
            swish_inner (bool): if True, apply Swish to the Tx3x3 conv,
                otherwise apply ReLU to the Tx3x3 conv.
            block_idx (int): index of the block within the current stage;
                SE is only attached to every other block.
        """
        super(X3DTransform, self).__init__()
        self.temp_kernel_size = temp_kernel_size
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._se_ratio = se_ratio
        self._swish_inner = swish_inner
        self._stride_1x1 = stride_1x1
        self._block_idx = block_idx
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        dilation,
        norm_module,
    ):
        (str1x1, str3x3) = (stride, 1) if self._stride_1x1 else (1, stride)

        # 1x1x1, BN, ReLU.
        self.a = nn.Conv3d(
            dim_in,
            dim_inner,
            kernel_size=[1, 1, 1],
            stride=[1, str1x1, str1x1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.a_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt)
        self.a_relu = nn.ReLU(inplace=self._inplace_relu)

        # Tx3x3 (channelwise when groups=dim_inner), BN, ReLU.
        self.b = nn.Conv3d(
            dim_inner,
            dim_inner,
            [self.temp_kernel_size, 3, 3],
            stride=[1, str3x3, str3x3],
            padding=[int(self.temp_kernel_size // 2), dilation, dilation],
            groups=num_groups,
            bias=False,
            dilation=[1, dilation, dilation],
        )

        # from easycv.thirdparty.depthwise_conv3d.depthwise_conv3d import DepthwiseConv3d
        # self.b = DepthwiseConv3d(
        #     dim_inner,
        #     dim_inner,
        #     [self.temp_kernel_size, 3, 3],
        #     stride=[1, str3x3, str3x3],
        #     padding=[int(self.temp_kernel_size // 2), dilation, dilation],
        #     dilation=[1, dilation, dilation],
        #     groups=num_groups,
        #     bias=False,
        # )
        self.b_bn = norm_module(
            num_features=dim_inner, eps=self._eps, momentum=self._bn_mmt)

        # Attach SE to every other block (even block indices).
        use_se = (self._block_idx + 1) % 2 == 1
        if self._se_ratio > 0.0 and use_se:
            self.se = SE(dim_inner, self._se_ratio)

        if self._swish_inner:
            self.b_relu = Swish()
        else:
            self.b_relu = nn.ReLU(inplace=self._inplace_relu)

        # 1x1x1, BN.
        self.c = nn.Conv3d(
            dim_inner,
            dim_out,
            kernel_size=[1, 1, 1],
            stride=[1, 1, 1],
            padding=[0, 0, 0],
            bias=False,
        )
        self.c_bn = norm_module(
            num_features=dim_out, eps=self._eps, momentum=self._bn_mmt)
        self.c_bn.transform_final_bn = True

    def forward(self, x):
        # Relies on the registration order of the submodules above:
        # a -> a_bn -> a_relu -> b -> b_bn (-> se) -> b_relu -> c -> c_bn.
        for block in self.children():
            x = block(x)
        return x
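
# Example (a minimal sketch; parameter values are illustrative):
#
#   trans = X3DTransform(
#       dim_in=24, dim_out=24, temp_kernel_size=3, stride=1,
#       dim_inner=54, num_groups=54)  # channelwise: num_groups=dim_inner
#   x = torch.randn(2, 24, 4, 14, 14)
#   out = trans(x)  # (2, 24, 4, 14, 14)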


def get_trans_func(name):
    """
    Retrieves the transformation module by name.
    """
    trans_funcs = {
        'x3d_transform': X3DTransform,
    }
    assert (name in trans_funcs.keys()
            ), "Transformation function '{}' not supported".format(name)
    return trans_funcs[name]
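
# Example: get_trans_func('x3d_transform') returns the X3DTransform class
# itself, which ResBlock below instantiates as its bottleneck branch.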


class ResBlock(nn.Module):
    """
    Residual block.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups=1,
        stride_1x1=False,
        inplace_relu=True,
        eps=1e-5,
        bn_mmt=0.1,
        dilation=1,
        norm_module=nn.BatchNorm3d,
        block_idx=0,
        drop_connect_rate=0.0,
    ):
        """
        ResBlock class constructs residual blocks. More details can be
        found in:
        Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun.
        "Deep residual learning for image recognition."
        https://arxiv.org/abs/1512.03385
        Args:
            dim_in (int): the channel dimension of the input.
            dim_out (int): the channel dimension of the output.
            temp_kernel_size (int): the temporal kernel size of the middle
                convolution in the bottleneck.
            stride (int): the stride of the bottleneck.
            trans_func (callable): the transformation class (e.g.
                X3DTransform) used to construct the bottleneck.
            dim_inner (int): the inner dimension of the block.
            num_groups (int): number of groups for the convolution.
                num_groups=1 is for standard ResNet-like networks, and
                num_groups>1 is for ResNeXt-like networks.
            stride_1x1 (bool): if True, apply the stride to the 1x1x1 conv,
                otherwise apply it to the Tx3x3 conv.
            inplace_relu (bool): calculate the relu on the original input
                without allocating new memory.
            eps (float): epsilon for batch norm.
            bn_mmt (float): momentum for batch norm. Note that BN momentum
                in PyTorch = 1 - BN momentum in Caffe2.
            dilation (int): size of dilation.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
            drop_connect_rate (float): basic rate at which blocks are
                dropped, linearly increases from input to output blocks.
        """
        super(ResBlock, self).__init__()
        self._inplace_relu = inplace_relu
        self._eps = eps
        self._bn_mmt = bn_mmt
        self._drop_connect_rate = drop_connect_rate

        self._construct(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            trans_func,
            dim_inner,
            num_groups,
            stride_1x1,
            inplace_relu,
            dilation,
            norm_module,
            block_idx,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        temp_kernel_size,
        stride,
        trans_func,
        dim_inner,
        num_groups,
        stride_1x1,
        inplace_relu,
        dilation,
        norm_module,
        block_idx,
    ):
        # Use a skip connection with projection if the dim or res changes.
        if (dim_in != dim_out) or (stride != 1):
            self.branch1 = nn.Conv3d(
                dim_in,
                dim_out,
                kernel_size=1,
                stride=[1, stride, stride],
                padding=0,
                bias=False,
                dilation=1,
            )
            self.branch1_bn = norm_module(
                num_features=dim_out, eps=self._eps, momentum=self._bn_mmt)
        self.branch2 = trans_func(
            dim_in,
            dim_out,
            temp_kernel_size,
            stride,
            dim_inner,
            num_groups,
            stride_1x1=stride_1x1,
            inplace_relu=inplace_relu,
            dilation=dilation,
            norm_module=norm_module,
            block_idx=block_idx,
        )
        self.relu = nn.ReLU(inplace=self._inplace_relu)

    def _drop_connect(self, x, drop_ratio):
        """Apply dropconnect (stochastic depth) to x in place."""
        keep_ratio = 1.0 - drop_ratio
        mask = torch.empty([x.shape[0], 1, 1, 1, 1],
                           dtype=x.dtype,
                           device=x.device)
        mask.bernoulli_(keep_ratio)
        # Scale the survivors by 1 / keep_ratio so that E[x] is unchanged.
        x.div_(keep_ratio)
        x.mul_(mask)
        return x

    def forward(self, x):
        f_x = self.branch2(x)
        if self.training and self._drop_connect_rate > 0.0:
            f_x = self._drop_connect(f_x, self._drop_connect_rate)
        if hasattr(self, 'branch1'):
            x = self.branch1_bn(self.branch1(x)) + f_x
        else:
            x = x + f_x
        x = self.relu(x)
        return x
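
# Example (a minimal sketch; parameter values are illustrative):
#
#   block = ResBlock(
#       dim_in=24, dim_out=48, temp_kernel_size=3, stride=2,
#       trans_func=X3DTransform, dim_inner=54, num_groups=54)
#   x = torch.randn(2, 24, 4, 28, 28)
#   out = block(x)  # projection shortcut is used: (2, 48, 4, 14, 14)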


class ResStage(nn.Module):
    """
    Stage of 3D ResNet. It expects one or more tensors as input, covering
    the single-pathway (C2D, I3D, Slow) and multi-pathway (SlowFast)
    cases. More details can be found here:

    Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
    "SlowFast networks for video recognition."
    https://arxiv.org/pdf/1812.03982.pdf
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        stride,
        temp_kernel_sizes,
        num_blocks,
        dim_inner,
        num_groups,
        num_block_temp_kernel,
        nonlocal_inds,
        nonlocal_group,
        nonlocal_pool,
        dilation,
        instantiation='softmax',
        trans_func_name='bottleneck_transform',
        stride_1x1=False,
        inplace_relu=True,
        norm_module=nn.BatchNorm3d,
        drop_connect_rate=0.0,
    ):
        """
        The `__init__` method of any subclass should also contain these
        arguments. ResStage builds p streams, where p can be greater than
        or equal to one.
        Args:
            dim_in (list): list of the p channel dimensions of the input.
                Different channel dimensions control the input dimensions
                of different pathways.
            dim_out (list): list of the p channel dimensions of the output.
                Different channel dimensions control the output dimensions
                of different pathways.
            temp_kernel_sizes (list): list of the p temporal kernel sizes
                of the convolution in the bottleneck. Different
                temp_kernel_sizes control different pathways.
            stride (list): list of the p strides of the bottleneck.
                Different strides control different pathways.
            num_blocks (list): list of the p numbers of blocks for each of
                the pathways.
            dim_inner (list): list of the p inner channel dimensions of the
                input. Different channel dimensions control the input
                dimensions of different pathways.
            num_groups (list): list of the p numbers of groups for the
                convolution. num_groups=1 is for standard ResNet-like
                networks, and num_groups>1 is for ResNeXt-like networks.
            num_block_temp_kernel (list): extend the temp_kernel_sizes to
                num_block_temp_kernel blocks, then fill in a temporal
                kernel size of 1 for the rest of the layers.
            nonlocal_inds (list): if the list is empty, no nonlocal layer
                will be added. If the list is not empty, add a nonlocal
                layer after the index-th block.
            dilation (list): size of dilation for each pathway.
            nonlocal_group (list): list of the p numbers of nonlocal
                groups. Each number controls how to fold the temporal
                dimension into the batch dimension before applying the
                nonlocal transformation.
                https://github.com/facebookresearch/video-nonlocal-net.
            nonlocal_pool (list): list of the p pooling sizes for the
                nonlocal layers.
            instantiation (string): instantiation for the nonlocal layer.
                Supports two instantiation methods:
                "dot_product": normalizing correlation matrix with L2.
                "softmax": normalizing correlation matrix with Softmax.
            trans_func_name (string): name of the transformation function
                to apply on the network.
            norm_module (nn.Module): nn.Module for the normalization layer.
                The default is nn.BatchNorm3d.
            drop_connect_rate (float): basic rate at which blocks are
                dropped, linearly increases from input to output blocks.
        """
        super(ResStage, self).__init__()
        assert all((num_block_temp_kernel[i] <= num_blocks[i]
                    for i in range(len(temp_kernel_sizes))))
        self.num_blocks = num_blocks
        self.nonlocal_group = nonlocal_group
        self._drop_connect_rate = drop_connect_rate
        # Repeat each pathway's temporal kernel size for its first
        # num_block_temp_kernel blocks, then use kernel size 1 for the
        # remaining blocks.
        self.temp_kernel_sizes = [
            (temp_kernel_sizes[i] * num_blocks[i])[:num_block_temp_kernel[i]]
            + [1] * (num_blocks[i] - num_block_temp_kernel[i])
            for i in range(len(temp_kernel_sizes))
        ]
        assert (len({
            len(dim_in),
            len(dim_out),
            len(temp_kernel_sizes),
            len(stride),
            len(num_blocks),
            len(dim_inner),
            len(num_groups),
            len(num_block_temp_kernel),
            len(nonlocal_inds),
            len(nonlocal_group),
        }) == 1)
        self.num_pathways = len(self.num_blocks)
        self._construct(
            dim_in,
            dim_out,
            stride,
            dim_inner,
            num_groups,
            trans_func_name,
            stride_1x1,
            inplace_relu,
            nonlocal_inds,
            nonlocal_pool,
            instantiation,
            dilation,
            norm_module,
        )

    def _construct(
        self,
        dim_in,
        dim_out,
        stride,
        dim_inner,
        num_groups,
        trans_func_name,
        stride_1x1,
        inplace_relu,
        nonlocal_inds,
        nonlocal_pool,
        instantiation,
        dilation,
        norm_module,
    ):
        for pathway in range(self.num_pathways):
            for i in range(self.num_blocks[pathway]):
                # Retrieve the transformation function.
                trans_func = get_trans_func(trans_func_name)
                # Construct the block.
                res_block = ResBlock(
                    dim_in[pathway] if i == 0 else dim_out[pathway],
                    dim_out[pathway],
                    self.temp_kernel_sizes[pathway][i],
                    stride[pathway] if i == 0 else 1,
                    trans_func,
                    dim_inner[pathway],
                    num_groups[pathway],
                    stride_1x1=stride_1x1,
                    inplace_relu=inplace_relu,
                    dilation=dilation[pathway],
                    norm_module=norm_module,
                    block_idx=i,
                    drop_connect_rate=self._drop_connect_rate,
                )
                self.add_module('pathway{}_res{}'.format(pathway, i),
                                res_block)
                if i in nonlocal_inds[pathway]:
                    nln = NonLocalModule(
                        dim_out[pathway],
                        dim_out[pathway] // 2,
                        nonlocal_pool[pathway],
                        instantiation=instantiation,
                        norm_module=norm_module,
                    )
                    self.add_module('pathway{}_nonlocal{}'.format(pathway, i),
                                    nln)

    def forward(self, inputs):
        output = []
        for pathway in range(self.num_pathways):
            x = inputs
            for i in range(self.num_blocks[pathway]):
                m = getattr(self, 'pathway{}_res{}'.format(pathway, i))
                x = m(x)
                if hasattr(self, 'pathway{}_nonlocal{}'.format(pathway, i)):
                    nln = getattr(self,
                                  'pathway{}_nonlocal{}'.format(pathway, i))
                    b, c, t, h, w = x.shape
                    if self.nonlocal_group[pathway] > 1:
                        # Fold the temporal dimension into the batch
                        # dimension.
                        x = x.permute(0, 2, 1, 3, 4)
                        x = x.reshape(
                            b * self.nonlocal_group[pathway],
                            t // self.nonlocal_group[pathway],
                            c,
                            h,
                            w,
                        )
                        x = x.permute(0, 2, 1, 3, 4)
                    x = nln(x)
                    if self.nonlocal_group[pathway] > 1:
                        # Fold back to the temporal dimension.
                        x = x.permute(0, 2, 1, 3, 4)
                        x = x.reshape(b, t, c, h, w)
                        x = x.permute(0, 2, 1, 3, 4)
            output.append(x)

        # Only the first pathway's output is returned (single-pathway use).
        return output[0]
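
# Example (a minimal sketch of a single-pathway X3D stage; values are
# illustrative):
#
#   stage = ResStage(
#       dim_in=[24], dim_out=[24], stride=[2], temp_kernel_sizes=[[3]],
#       num_blocks=[3], dim_inner=[54], num_groups=[54],
#       num_block_temp_kernel=[3], nonlocal_inds=[[]], nonlocal_group=[1],
#       nonlocal_pool=[[1, 2, 2]], dilation=[1],
#       trans_func_name='x3d_transform')
#   x = torch.randn(2, 24, 4, 56, 56)
#   out = stage(x)  # (2, 24, 4, 28, 28) after the stride-2 first block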