EasyCV/easycv/models/detection/detectors/yolox/tood_head.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, normal_init

from easycv.models.builder import HEADS
from .yolo_head_template import YOLOXHead_Template


class TaskDecomposition(nn.Module):
"""Task decomposition module in task-aligned predictor of TOOD.
Args:
feat_channels (int): Number of feature channels in TOOD head.
stacked_convs (int): Number of conv layers in TOOD head.
la_down_rate (int): Downsample rate of layer attention.
conv_cfg (dict): Config dict for convolution layer.
norm_cfg (dict): Config dict for normalization layer.
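
    Example:
        A minimal sketch with made-up shapes (the channel numbers below are
        illustrative and not taken from a particular TOOD/YOLOX config):

        >>> import torch
        >>> decomp = TaskDecomposition(feat_channels=64, stacked_convs=3)
        >>> # input is the concatenation of 3 stacked 64-channel feature maps
        >>> feat = torch.randn(2, 64 * 3, 20, 20)
        >>> out = decomp(feat)
        >>> out.shape
        torch.Size([2, 64, 20, 20])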
"""
def __init__(self,
feat_channels,
stacked_convs=6,
la_down_rate=8,
conv_cfg=None,
norm_cfg=None):
super(TaskDecomposition, self).__init__()
self.feat_channels = feat_channels
self.stacked_convs = stacked_convs
self.in_channels = self.feat_channels * self.stacked_convs
self.norm_cfg = norm_cfg
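
        # Layer attention: from the globally pooled concatenation of the
        # stacked inter features, predict one sigmoid weight per stacked conv
        # layer (i.e. `stacked_convs` scalars per image).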
        self.layer_attention = nn.Sequential(
            nn.Conv2d(self.in_channels, self.in_channels // la_down_rate, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                self.in_channels // la_down_rate,
                self.stacked_convs,
                1,
                padding=0), nn.Sigmoid())
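
        # 1x1 conv that projects the concatenated features back to
        # `feat_channels`; in forward() its weight is modulated by the layer
        # attention instead of being applied directly.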
        self.reduction_conv = ConvModule(
            self.in_channels,
            self.feat_channels,
            1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            bias=norm_cfg is None)

    def init_weights(self):
        for m in self.layer_attention.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
        normal_init(self.reduction_conv.conv, std=0.01)

    def forward(self, feat, avg_feat=None):
        b, c, h, w = feat.shape
        if avg_feat is None:
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
        weight = self.layer_attention(avg_feat)

        # here we first compute the product between layer attention weight and
        # conv weight, and then compute the convolution between new conv weight
        # and feature map, in order to save memory and FLOPs.
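        # Shapes (B = batch size): `weight` is (B, stacked_convs, 1, 1) and the
        # reduction conv weight is (feat_channels, in_channels, 1, 1), viewed
        # here as (1, feat_channels, stacked_convs, feat_channels). Broadcasting
        # the two gives a per-image weight of shape
        # (B, feat_channels, in_channels), which is applied to the flattened
        # feature map with a batched matmul.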
        conv_weight = weight.reshape(
            b, 1, self.stacked_convs,
            1) * self.reduction_conv.conv.weight.reshape(
                1, self.feat_channels, self.stacked_convs, self.feat_channels)
        conv_weight = conv_weight.reshape(b, self.feat_channels,
                                          self.in_channels)
        feat = feat.reshape(b, self.in_channels, h * w)
        feat = torch.bmm(conv_weight, feat).reshape(b, self.feat_channels, h,
                                                    w)
        if self.norm_cfg is not None:
            feat = self.reduction_conv.norm(feat)
        feat = self.reduction_conv.activate(feat)

        return feat


@HEADS.register_module
class TOODHead(YOLOXHead_Template):

    def __init__(
        self,
        num_classes,
        model_type='s',
        strides=[8, 16, 32],
        in_channels=[256, 512, 1024],
        act='silu',
        conv_type='conv',
        stage='CLOUD',
        obj_loss_type='BCE',
        reg_loss_type='giou',
        stacked_convs=3,
        la_down_rate=32,
        decode_in_inference=True,
        conv_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
    ):
"""
Args:
num_classes (int): detection class numbers.
width (float): model width. Default value: 1.0.
strides (list): expanded strides. Default value: [8, 16, 32].
in_channels (list): model conv channels set. Default value: [256, 512, 1024].
act (str): activation type of conv. Defalut value: "silu".
depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
stage (str): model stage, distinguish edge head to cloud head. Default value: CLOUD.
obj_loss_type (str): the loss function of the obj conf. Default value: l1.
reg_loss_type (str): the loss function of the box prediction. Default value: l1.
"""
        super(TOODHead, self).__init__(
            num_classes=num_classes,
            model_type=model_type,
            strides=strides,
            in_channels=in_channels,
            act=act,
            conv_type=conv_type,
            stage=stage,
            obj_loss_type=obj_loss_type,
            reg_loss_type=reg_loss_type,
            decode_in_inference=decode_in_inference)

        self.stacked_convs = stacked_convs
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.feat_channels = int(256 * self.width)

        self.cls_decomps = nn.ModuleList()
        self.reg_decomps = nn.ModuleList()
        self.inter_convs = nn.ModuleList()
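
        # One classification / regression TaskDecomposition is created per FPN
        # level, while the stacked inter convs below are shared across levels.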
        for i in range(len(in_channels)):
            self.cls_decomps.append(
                TaskDecomposition(self.feat_channels, self.stacked_convs,
                                  self.stacked_convs * la_down_rate,
                                  self.conv_cfg, self.norm_cfg))
            self.reg_decomps.append(
                TaskDecomposition(self.feat_channels, self.stacked_convs,
                                  self.stacked_convs * la_down_rate,
                                  self.conv_cfg, self.norm_cfg))

        for i in range(self.stacked_convs):
            conv_cfg = self.conv_cfg
            chn = self.feat_channels
            self.inter_convs.append(
                ConvModule(
                    chn,
                    self.feat_channels,
                    3,
                    stride=1,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=self.norm_cfg))

    def forward(self, xin, labels=None, imgs=None):
        outputs = []
        origin_preds = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

        for k, (cls_decomp, reg_decomp, cls_conv, reg_conv, stride_this_level,
                x) in enumerate(
                    zip(self.cls_decomps, self.reg_decomps, self.cls_convs,
                        self.reg_convs, self.strides, xin)):
            x = self.stems[k](x)

            inter_feats = []
            for inter_conv in self.inter_convs:
                x = inter_conv(x)
                inter_feats.append(x)
            feat = torch.cat(inter_feats, 1)
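
            # feat: (B, stacked_convs * feat_channels, H_k, W_k), the input
            # expected by TaskDecomposition for this pyramid level.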
            avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
            cls_x = cls_decomp(feat, avg_feat)
            reg_x = reg_decomp(feat, avg_feat)

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)
            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)
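
            # Per-anchor channel layout: reg_output carries 4 box channels,
            # obj_output 1 objectness channel and cls_output num_classes
            # channels; they are concatenated below in that order.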
            if self.training:
                output = torch.cat([reg_output, obj_output, cls_output], 1)
                output, grid = self.get_output_and_grid(
                    output, k, stride_this_level, xin[0].type())
                x_shifts.append(grid[:, :, 0])
                y_shifts.append(grid[:, :, 1])
                expanded_strides.append(
                    torch.zeros(
                        1, grid.shape[1]).fill_(stride_this_level).type_as(
                            xin[0]))
                if self.use_l1:
                    batch_size = reg_output.shape[0]
                    hsize, wsize = reg_output.shape[-2:]
                    reg_output = reg_output.view(batch_size, self.n_anchors, 4,
                                                 hsize, wsize)
                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                        batch_size, -1, 4)
                    origin_preds.append(reg_output.clone())
            else:
                if self.stage == 'EDGE':
                    m = nn.Hardsigmoid()
                    output = torch.cat(
                        [reg_output, m(obj_output),
                         m(cls_output)], 1)
                else:
                    output = torch.cat([
                        reg_output,
                        obj_output.sigmoid(),
                        cls_output.sigmoid()
                    ], 1)

            outputs.append(output)

        if self.training:
            return self.get_losses(
                imgs,
                x_shifts,
                y_shifts,
                expanded_strides,
                labels,
                torch.cat(outputs, 1),
                origin_preds,
                dtype=xin[0].dtype,
            )
        else:
            self.hw = [x.shape[-2:] for x in outputs]
            # [batch, n_anchors_all, 4 + 1 + num_classes] (85 for COCO)
            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs],
                                dim=2).permute(0, 2, 1)
            if self.decode_in_inference:
                return self.decode_outputs(outputs, dtype=xin[0].type())
            else:
                return outputs
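

# Usage sketch (hypothetical, illustrative values only): because the class is
# added to the HEADS registry above, an EasyCV detector config would typically
# select this head with a dict such as
#
#     head=dict(
#         type='TOODHead',
#         num_classes=80,
#         model_type='s',
#         obj_loss_type='BCE',
#         reg_loss_type='giou',
#     )
#
# with the remaining keys left at the defaults defined in __init__ above.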