import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, normal_init

from easycv.models.builder import HEADS

from .yolo_head_template import YOLOXHead_Template

class TaskDecomposition(nn.Module):
    """Task decomposition module in the task-aligned predictor of TOOD.

    Args:
        feat_channels (int): Number of feature channels in TOOD head.
        stacked_convs (int): Number of conv layers in TOOD head.
        la_down_rate (int): Downsample rate of layer attention.
        conv_cfg (dict): Config dict for convolution layer.
        norm_cfg (dict): Config dict for normalization layer.
    """

    def __init__(self,
                 feat_channels,
                 stacked_convs=6,
                 la_down_rate=8,
                 conv_cfg=None,
                 norm_cfg=None):
        super(TaskDecomposition, self).__init__()
        self.feat_channels = feat_channels
        self.stacked_convs = stacked_convs
        self.in_channels = self.feat_channels * self.stacked_convs
        self.norm_cfg = norm_cfg
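        # Layer attention: from the globally pooled stack of inter features,
        # predict one sigmoid weight per stacked conv layer; ``forward``
        # broadcasts each weight over that layer's channels.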
        self.layer_attention = nn.Sequential(
            nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                self.in_channels // la_down_rate,
                self.stacked_convs,
                1,
                padding=0), nn.Sigmoid())

        self.reduction_conv = ConvModule(
            self.in_channels,
            self.feat_channels,
            1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=norm_cfg is None)

    def init_weights(self):
        for m in self.layer_attention.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
        normal_init(self.reduction_conv.conv, std=0.01)

    def forward(self, feat, avg_feat=None):
        b, c, h, w = feat.shape
        if avg_feat is None:
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        weight = self.layer_attention(avg_feat)

        # Here we first compute the product of the layer attention weight and
        # the conv weight, and then convolve the feature map with the fused
        # weight, in order to save memory and FLOPs.
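        # Note: this fused path is equivalent to scaling each stacked layer's
        # channels by its attention weight and then applying the 1x1
        # reduction conv, except that the conv bias (present only when
        # ``norm_cfg`` is None) is not added here.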
        conv_weight = weight.reshape(
            b, 1, self.stacked_convs,
            1) * self.reduction_conv.conv.weight.reshape(
                1, self.feat_channels, self.stacked_convs, self.feat_channels)
        conv_weight = conv_weight.reshape(b, self.feat_channels,
                                          self.in_channels)
        feat = feat.reshape(b, self.in_channels, h * w)
        feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h,
                                                    w)
        if self.norm_cfg is not None:
            feat = self.reduction_conv.norm(feat)
        feat = self.reduction_conv.activate(feat)

        return feat

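# A minimal shape sanity check for TaskDecomposition (illustrative sketch,
# not part of the module; the values are arbitrary):
#
#     decomp = TaskDecomposition(feat_channels=64, stacked_convs=6)
#     decomp.init_weights()
#     feat = torch.randn(2, 64 * 6, 20, 20)  # concatenated inter features
#     out = decomp(feat)
#     assert out.shape == (2, 64, 20, 20)
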
@HEADS.register_module
class TOODHead(YOLOXHead_Template):

    def __init__(
        self,
        num_classes,
        model_type='s',
        strides=[8, 16, 32],
        in_channels=[256, 512, 1024],
        act='silu',
        conv_type='conv',
        stage='CLOUD',
        obj_loss_type='BCE',
        reg_loss_type='giou',
        stacked_convs=3,
        la_down_rate=32,
        decode_in_inference=True,
        conv_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
    ):
"""
|
|
Args:
|
|
num_classes (int): detection class numbers.
|
|
width (float): model width. Default value: 1.0.
|
|
strides (list): expanded strides. Default value: [8, 16, 32].
|
|
in_channels (list): model conv channels set. Default value: [256, 512, 1024].
|
|
act (str): activation type of conv. Defalut value: "silu".
|
|
depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
|
|
stage (str): model stage, distinguish edge head to cloud head. Default value: CLOUD.
|
|
obj_loss_type (str): the loss function of the obj conf. Default value: l1.
|
|
reg_loss_type (str): the loss function of the box prediction. Default value: l1.
|
|
"""
|
|
        super(TOODHead, self).__init__(
            num_classes=num_classes,
            model_type=model_type,
            strides=strides,
            in_channels=in_channels,
            act=act,
            conv_type=conv_type,
            stage=stage,
            obj_loss_type=obj_loss_type,
            reg_loss_type=reg_loss_type,
            decode_in_inference=decode_in_inference)

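        # The base YOLOXHead_Template is assumed to build the per-level
        # stems, cls/reg conv branches and prediction layers used in
        # ``forward`` below, and to set ``self.width`` from ``model_type``.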
        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.feat_channels = int(256 * self.width)

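        # One TaskDecomposition per FPN level for each of the cls/reg
        # branches; the ``inter_convs`` stack below is shared across levels.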
        self.cls_decomps = nn.ModuleList()
        self.reg_decomps = nn.ModuleList()
        self.inter_convs = nn.ModuleList()

        for _ in range(len(in_channels)):
            self.cls_decomps.append(
                TaskDecomposition(self.feat_channels, self.stacked_convs,
                                  self.stacked_convs * la_down_rate,
                                  self.conv_cfg, self.norm_cfg))
            self.reg_decomps.append(
                TaskDecomposition(self.feat_channels, self.stacked_convs,
                                  self.stacked_convs * la_down_rate,
                                  self.conv_cfg, self.norm_cfg))

        for _ in range(self.stacked_convs):
            self.inter_convs.append(
                ConvModule(
                    self.feat_channels,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg))

    def forward(self, xin, labels=None, imgs=None):
        outputs = []
        origin_preds = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

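        # Iterate over FPN levels: each level runs the shared inter convs,
        # then task decomposition splits the stacked features into
        # classification and regression branches.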
        for k, (cls_decomp, reg_decomp, cls_conv, reg_conv, stride_this_level,
                x) in enumerate(
                    zip(self.cls_decomps, self.reg_decomps, self.cls_convs,
                        self.reg_convs, self.strides, xin)):
            x = self.stems[k](x)

            inter_feats = []
            for inter_conv in self.inter_convs:
                x = inter_conv(x)
                inter_feats.append(x)
            feat = torch.cat(inter_feats, 1)

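            # ``feat`` stacks all inter conv outputs along the channel dim:
            # (b, stacked_convs * feat_channels, h, w).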
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_x = cls_decomp(feat, avg_feat)
            reg_x = reg_decomp(feat, avg_feat)

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            if self.training:
                output = torch.cat([reg_output, obj_output, cls_output], 1)
                output, grid = self.get_output_and_grid(
                    output, k, stride_this_level, xin[0].type())
                x_shifts.append(grid[:, :, 0])
                y_shifts.append(grid[:, :, 1])
                expanded_strides.append(
                    torch.zeros(
                        1, grid.shape[1]).fill_(stride_this_level).type_as(
                            xin[0]))
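                # Keep the raw (undecoded) box predictions; ``get_losses``
                # uses them for the optional L1 loss.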
                if self.use_l1:
                    batch_size = reg_output.shape[0]
                    hsize, wsize = reg_output.shape[-2:]
                    reg_output = reg_output.view(batch_size, self.n_anchors,
                                                 4, hsize, wsize)
                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                        batch_size, -1, 4)
                    origin_preds.append(reg_output.clone())

            else:
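                # At inference time the EDGE stage swaps sigmoid for
                # Hardsigmoid (assumption: cheaper to run and export on
                # edge devices).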
                if self.stage == 'EDGE':
                    m = nn.Hardsigmoid()
                    output = torch.cat(
                        [reg_output, m(obj_output),
                         m(cls_output)], 1)
                else:
                    output = torch.cat([
                        reg_output,
                        obj_output.sigmoid(),
                        cls_output.sigmoid()
                    ], 1)

            outputs.append(output)

        if self.training:
            return self.get_losses(
                imgs,
                x_shifts,
                y_shifts,
                expanded_strides,
                labels,
                torch.cat(outputs, 1),
                origin_preds,
                dtype=xin[0].dtype,
            )
        else:
            self.hw = [x.shape[-2:] for x in outputs]
            # [batch, n_anchors_all, 4 + 1 + num_classes]
            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs],
                                dim=2).permute(0, 2, 1)
            if self.decode_in_inference:
                return self.decode_outputs(outputs, dtype=xin[0].type())
            else:
                return outputs
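
# A minimal sketch of how this head could be referenced from a model config
# once registered via ``@HEADS.register_module`` (illustrative; all values
# besides ``type`` are assumptions):
#
#     head=dict(
#         type='TOODHead',
#         num_classes=80,
#         model_type='s',
#         obj_loss_type='BCE',
#         reg_loss_type='giou',
#     )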