# Copyright (c) 2014-2021 Megvii Inc And Alibaba PAI-Teams. All rights reserved.
import torch
import torch.nn as nn

from easycv.models.builder import HEADS
from .yolo_head_template import YOLOXHead_Template


@HEADS.register_module
class YOLOXHead(YOLOXHead_Template):

    def __init__(
            self,
            num_classes=80,
            model_type='s',
            strides=[8, 16, 32],
            in_channels=[256, 512, 1024],
            act='silu',
            conv_type='conv',
            stage='CLOUD',
            obj_loss_type='BCE',
            reg_loss_type='giou',
            decode_in_inference=True,
            width=None,
    ):
"""
|
|
Args:
|
|
num_classes (int): detection class numbers.
|
|
width (float): model width. Default value: 1.0.
|
|
strides (list): expanded strides. Default value: [8, 16, 32].
|
|
in_channels (list): model conv channels set. Default value: [256, 512, 1024].
|
|
act (str): activation type of conv. Defalut value: "silu".
|
|
depthwise (bool): whether apply depthwise conv in conv branch. Default value: False.
|
|
stage (str): model stage, distinguish edge head to cloud head. Default value: CLOUD.
|
|
obj_loss_type (str): the loss function of the obj conf. Default value: l1.
|
|
reg_loss_type (str): the loss function of the box prediction. Default value: l1.
|
|
"""
|
|
        super(YOLOXHead, self).__init__(
            width=width,
            num_classes=num_classes,
            model_type=model_type,
            strides=strides,
            in_channels=in_channels,
            act=act,
            conv_type=conv_type,
            stage=stage,
            obj_loss_type=obj_loss_type,
            reg_loss_type=reg_loss_type,
            decode_in_inference=decode_in_inference,
        )

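    # `xin` holds one feature map per FPN level (matching `self.strides`).
    # Training mode returns the losses from `get_losses`; eval mode returns
    # per-anchor predictions, decoded to boxes when `decode_in_inference` is True.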
    def forward(self, xin, labels=None, imgs=None):
        outputs = []
        origin_preds = []
        x_shifts = []
        y_shifts = []
        expanded_strides = []

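        # Decoupled head: per level, a shared stem feeds separate
        # classification and regression branches; objectness is predicted
        # from the regression features.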
        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
                zip(self.cls_convs, self.reg_convs, self.strides, xin)):
            x = self.stems[k](x)
            cls_x = x
            reg_x = x

            cls_feat = cls_conv(cls_x)
            cls_output = self.cls_preds[k](cls_feat)

            reg_feat = reg_conv(reg_x)
            reg_output = self.reg_preds[k](reg_feat)
            obj_output = self.obj_preds[k](reg_feat)

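            # Training keeps the raw logits and records the grid offsets and
            # strides that the loss needs for target assignment.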
            if self.training:
                output = torch.cat([reg_output, obj_output, cls_output], 1)
                output, grid = self.get_output_and_grid(
                    output, k, stride_this_level, xin[0].type())
                x_shifts.append(grid[:, :, 0])
                y_shifts.append(grid[:, :, 1])
                expanded_strides.append(
                    torch.zeros(
                        1, grid.shape[1]).fill_(stride_this_level).type_as(
                            xin[0]))
                if self.use_l1:
                    batch_size = reg_output.shape[0]
                    hsize, wsize = reg_output.shape[-2:]
                    reg_output = reg_output.view(batch_size, self.n_anchors, 4,
                                                 hsize, wsize)
                    reg_output = reg_output.permute(0, 1, 3, 4, 2).reshape(
                        batch_size, -1, 4)
                    origin_preds.append(reg_output.clone())

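            # Inference applies sigmoid to the obj/cls logits. The EDGE stage
            # uses Hardsigmoid, a piecewise-linear approximation that is
            # cheaper on edge hardware.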
            else:
                if self.stage == 'EDGE':
                    m = nn.Hardsigmoid()
                    output = torch.cat(
                        [reg_output, m(obj_output),
                         m(cls_output)], 1)
                else:
                    output = torch.cat([
                        reg_output,
                        obj_output.sigmoid(),
                        cls_output.sigmoid()
                    ], 1)

            outputs.append(output)

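        # The collected shifts and strides let `get_losses` map each flattened
        # prediction back to its anchor location when assigning targets.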
        if self.training:
            return self.get_losses(
                imgs,
                x_shifts,
                y_shifts,
                expanded_strides,
                labels,
                torch.cat(outputs, 1),
                origin_preds,
                dtype=xin[0].dtype,
            )
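        # Eval: flatten every level to [batch, n_anchors_all, 5 + num_classes]
        # and optionally decode grid-relative predictions into absolute boxes.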
        else:
            self.hw = [x.shape[-2:] for x in outputs]
            # [batch, n_anchors_all, 85]
            outputs = torch.cat([x.flatten(start_dim=2) for x in outputs],
                                dim=2).permute(0, 2, 1)
            if self.decode_in_inference:
                return self.decode_outputs(outputs, dtype=xin[0].type())
            else:
                return outputs
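
# A minimal usage sketch (not part of the upstream file), assuming the default
# YOLOX-s configuration where `model_type='s'` maps to a width multiplier of
# 0.5, so the head expects FPN features with [128, 256, 512] channels for a
# 640x640 input.
if __name__ == '__main__':
    head = YOLOXHead(num_classes=80, model_type='s')
    head.eval()
    feats = [
        torch.randn(1, 128, 80, 80),  # stride 8
        torch.randn(1, 256, 40, 40),  # stride 16
        torch.randn(1, 512, 20, 20),  # stride 32
    ]
    with torch.no_grad():
        preds = head(feats)
    # 80*80 + 40*40 + 20*20 = 8400 anchor locations, each with 4 box terms,
    # 1 objectness score and 80 class scores.
    print(preds.shape)  # expected: torch.Size([1, 8400, 85])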