wuziheng 9aaa600f79
Improve YOLOX with RepConv/ASFF/TOOD (#154)
* add attention layers and various loss functions

* add SIoU loss

* add TAH, various attention layers, and different loss functions

* add ASFF-Sim, GSConv

* make blade utils fit faster

* blade optimization for YOLOX static shape & fp16

* make YOLOX output decoding controllable via cfg

* add reparameterize_models for export

* e2e trt_nms plugin export support and numeric test

* split preprocessing from end2end+blade, speeding up inference from 17 ms to 7.2 ms

Co-authored-by: zouxinyi0625 <zouxinyi.zxy@alibaba-inc.com>
2022-08-24 18:11:15 +08:00


# Copyright (c) 2014-2021 Megvii Inc And Alibaba PAI-Teams. All rights reserved.
import torch
import torch.nn as nn

from easycv.models.builder import HEADS
from .yolo_head_template import YOLOXHead_Template


@HEADS.register_module
class YOLOXHead(YOLOXHead_Template):

    def __init__(
        self,
        num_classes=80,
        model_type='s',
        strides=[8, 16, 32],
        in_channels=[256, 512, 1024],
        act='silu',
        conv_type='conv',
        stage='CLOUD',
        obj_loss_type='BCE',
        reg_loss_type='giou',
        decode_in_inference=True,
        width=None,
    ):
"""
Args:
num_classes (int): detection class numbers.
width (float): model width. Default value: 1.0.
strides (list): expanded strides. Default value: [8, 16, 32].
in_channels (list): model conv channels set. Default value: [256, 512, 1024].
act (str): activation type of conv. Defalut value: "silu".
depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
stage (str): model stage, distinguish edge head to cloud head. Default value: CLOUD.
obj_loss_type (str): the loss function of the obj conf. Default value: l1.
reg_loss_type (str): the loss function of the box prediction. Default value: l1.
"""
        super(YOLOXHead, self).__init__(
            width=width,
            num_classes=num_classes,
            model_type=model_type,
            strides=strides,
            in_channels=in_channels,
            act=act,
            conv_type=conv_type,
            stage=stage,
            obj_loss_type=obj_loss_type,
            reg_loss_type=reg_loss_type,
            decode_in_inference=decode_in_inference,
        )
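
    # Illustrative note (not part of the original file): thanks to the
    # ``@HEADS.register_module`` decorator above, an EasyCV detector config
    # would typically instantiate this head with something like
    #     head=dict(type='YOLOXHead', model_type='s', num_classes=80)
    # The exact keys depend on the surrounding model config.
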
    def forward(self, xin, labels=None, imgs=None):
        outputs = []
        origin_preds = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

        # Iterate over the FPN levels; each level has its own stem,
        # classification/regression branches and stride.
        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)
            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

            if self.training:
                output = torch.cat([reg_output, obj_output, cls_output], 1)
                output, grid = self.get_output_and_grid(
                    output, k, stride_this_level, xin[0].type())
                x_shifts.append(grid[:, :, 0])
                y_shifts.append(grid[:, :, 1])
                expanded_strides.append(
                    torch.zeros(1, grid.shape[1]).fill_(
                        stride_this_level).type_as(xin[0]))
                if self.use_l1:
                    batch_size = reg_output.shape[0]
                    hsize, wsize = reg_output.shape[-2:]
                    reg_output = reg_output.view(batch_size, self.n_anchors,
                                                 4, hsize, wsize)
                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                        batch_size, -1, 4)
                    origin_preds.append(reg_output.clone())
            else:
                if self.stage == 'EDGE':
                    # Edge deployment: Hardsigmoid is cheaper than sigmoid
                    # on many edge runtimes.
                    m = nn.Hardsigmoid()
                    output = torch.cat(
                        [reg_output, m(obj_output),
                         m(cls_output)], 1)
                else:
                    output = torch.cat([
                        reg_output,
                        obj_output.sigmoid(),
                        cls_output.sigmoid()
                    ], 1)

            outputs.append(output)

        if self.training:
            return self.get_losses(
                imgs,
                x_shifts,
                y_shifts,
                expanded_strides,
                labels,
                torch.cat(outputs, 1),
                origin_preds,
                dtype=xin[0].dtype,
            )
        else:
            self.hw = [x.shape[-2:] for x in outputs]
            # [batch, n_anchors_all, 85]
            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs],
                                dim=2).permute(0, 2, 1)
            if self.decode_in_inference:
                return self.decode_outputs(outputs, dtype=xin[0].type())
            else:
                return outputs
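

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original file). Assumes
    # easycv is importable and that model_type='s' maps to width=0.5 in
    # YOLOXHead_Template, so each FPN input carries int(c * 0.5) channels
    # for c in the default in_channels=[256, 512, 1024].
    head = YOLOXHead(num_classes=80, model_type='s')
    head.eval()
    feats = [
        torch.randn(1, 128, 80, 80),  # stride-8 level of a 640x640 input
        torch.randn(1, 256, 40, 40),  # stride-16 level
        torch.randn(1, 512, 20, 20),  # stride-32 level
    ]
    with torch.no_grad():
        preds = head(feats)
    # With decode_in_inference=True the head returns decoded predictions of
    # shape [batch, n_anchors_all, 4 + 1 + num_classes], here [1, 8400, 85].
    print(preds.shape)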