226 lines
7.7 KiB
Python
Raw Normal View History

2022-04-02 20:01:06 +08:00
# Copyright (c) 2014-2021 Megvii Inc And Alibaba PAI-Teams. All rights reserved.
import logging
2022-04-02 20:01:06 +08:00
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS, build_head
2022-04-02 20:01:06 +08:00
from easycv.models.detection.utils import postprocess
from .yolo_pafpn import YOLOPAFPN
def init_yolo(M):
    """Apply the YOLOX batch-norm hyper-parameters to a module tree.

    Every ``nn.BatchNorm2d`` reachable through ``M.modules()`` (which
    includes ``M`` itself) gets ``eps = 1e-3`` and ``momentum = 0.03``.
    Other layer types are left untouched.

    Args:
        M (nn.Module): root module to walk; modified in place.
    """
    batch_norms = (layer for layer in M.modules()
                   if isinstance(layer, nn.BatchNorm2d))
    for bn in batch_norms:
        bn.eps = 1e-3
        bn.momentum = 0.03
@MODELS.register_module
class YOLOX(BaseModel):
    """YOLOX detector: a ``YOLOPAFPN`` backbone/neck followed by a head.

    ``forward_train`` returns a dict of loss tensors, ``forward_test``
    returns post-processed detections as numpy arrays, and
    ``forward_compression`` / ``forward_export`` return the head output for
    compression and export pipelines.
    """

    # model_type -> [depth, width]; the two values are passed positionally
    # to YOLOPAFPN as its first two arguments.
    param_map = {
        'nano': [0.33, 0.25],
        'tiny': [0.33, 0.375],
        's': [0.33, 0.5],
        'm': [0.67, 0.75],
        'l': [1.0, 1.0],
        'x': [1.33, 1.25]
    }

    def __init__(self,
                 model_type='s',
                 test_conf=0.01,
                 nms_thre=0.65,
                 backbone='CSPDarknet',
                 use_att=None,
                 asff_channel=2,
                 neck_type='yolo',
                 neck_mode='all',
                 num_classes=None,
                 head=None,
                 pretrained=True):
        """Build the backbone and, when configured, the detection head.

        Args:
            model_type (str): key of ``param_map`` selecting depth/width.
            test_conf (float): stored and forwarded to ``postprocess``
                (presumably the confidence threshold).
            nms_thre (float): stored and forwarded to ``postprocess``
                (presumably the NMS IoU threshold).
            backbone, use_att, asff_channel, neck_type, neck_mode:
                forwarded unchanged to ``YOLOPAFPN``.
            num_classes (int, optional): legacy argument of models exported
                before easycv0.6.0. When given, it REPLACES ``head`` with a
                default ``YOLOXHead`` config using this class count.
            head (dict, optional): config passed to ``build_head``. When it
                is ``None`` (and ``num_classes`` is ``None``) no head is
                built and ``self.head`` / ``self.num_classes`` are not set
                here.
            pretrained: stored on ``self.pretrained``; not otherwise used
                in this class.
        """
        super(YOLOX, self).__init__()

        assert model_type in self.param_map, f'invalid model_type for yolox {model_type}, valid ones are {list(self.param_map.keys())}'

        self.pretrained = pretrained

        in_channels = [256, 512, 1024]
        depth, width = self.param_map[model_type]

        self.backbone = YOLOPAFPN(
            depth,
            width,
            backbone=backbone,
            neck_type=neck_type,
            neck_mode=neck_mode,
            in_channels=in_channels,
            asff_channel=asff_channel,
            use_att=use_att)

        if num_classes is not None:
            # adapt to previous export model (before easycv0.6.0)
            logging.warning(
                'Warning: You are now attend to use an old YOLOX model before easycv0.6.0 with key num_classes'
            )
            head = dict(
                type='YOLOXHead',
                model_type=model_type,
                num_classes=num_classes,
            )

        if head is not None:
            # head is None for YOLOX-edge to define a special head
            self.head = build_head(head)
            self.num_classes = self.head.num_classes

        # Module.apply visits every sub-module; init_yolo itself also walks
        # the sub-tree, so the batch-norm settings are applied idempotently.
        self.apply(init_yolo)  # init_yolo(self)
        self.test_conf = test_conf
        self.nms_thre = nms_thre
        # The TensorRT NMS fields and the export type are placeholders that
        # the export code is expected to overwrite.
        self.use_trt_efficientnms = False
        self.trt_efficientnms = None
        self.export_type = 'raw'

    def get_nmsboxes_num(self, img_scale=(640, 640)):
        """Return the NMS box count reported by the neck or the head.

        Args:
            img_scale (tuple): image scale forwarded to the delegate's
                ``get_nmsboxes_num``.

        Returns:
            Whatever ``self.neck.get_nmsboxes_num`` returns when a ``neck``
            attribute exists and is not ``None``; otherwise the result of
            ``self.head.get_nmsboxes_num``.
        """
        provider = getattr(self, 'neck', None)
        if provider is None:
            provider = self.head
        # Forward the caller's scale rather than a hard-coded (640, 640).
        return provider.get_nmsboxes_num(img_scale=img_scale)

    def forward_train(self,
                      img: Tensor,
                      gt_bboxes: Tensor,
                      gt_labels: Tensor,
                      img_metas=None,
                      scale=None) -> Dict[str, Tensor]:
        """Run one training forward pass and collect the losses.

        Args:
            img (Tensor): image tensor, NxCxHxW.
            gt_bboxes (Tensor): ground-truth boxes. Concatenated after
                ``gt_labels`` along dim 2, so both must be 3-D and agree on
                the other dims (presumably N x max_boxes x 4 -- confirm
                against the data pipeline).
            gt_labels (Tensor): ground-truth labels (presumably
                N x max_boxes x 1).
            img_metas (list[dict]): per-image meta data. Although the
                default is ``None``, ``img_metas[0]['img_shape']`` is read
                unconditionally, so it must be supplied.
            scale: unused here; kept for interface compatibility.

        Returns:
            Dict[str, Tensor]: ``total_loss``, ``iou_l``, ``conf_l``,
            ``cls_l`` from the head, plus ``img_h`` / ``img_w`` taken from
            the first image's ``img_shape`` as float tensors on the loss
            device.
        """
        fpn_outs = self.backbone(img)

        # The head receives labels and boxes as a single target tensor with
        # the label column(s) first.
        targets = torch.cat([gt_labels, gt_bboxes], dim=2)

        # The head also returns an L1 loss and a foreground count; neither
        # is reported by this method.
        loss, iou_loss, conf_loss, cls_loss, _l1_loss, _num_fg = self.head(
            fpn_outs, targets, img)

        img_shape = img_metas[0]['img_shape']
        outputs = {
            'total_loss': loss,
            'iou_l': iou_loss,
            'conf_l': conf_loss,
            'cls_l': cls_loss,
            'img_h': torch.tensor(img_shape[0], device=loss.device).float(),
            'img_w': torch.tensor(img_shape[1], device=loss.device).float(),
        }

        return outputs

    def forward_test(self, img: Tensor, img_metas=None) -> Dict[str, list]:
        """Run inference and return per-image detections as numpy arrays.

        Args:
            img (Tensor): image tensor, NxCxHxW.
            img_metas (list[dict], optional): per-image meta data. When
                given, boxes are divided by ``scale_factor[0]`` of the
                matching entry and the entries are echoed in the result.

        Returns:
            dict: ``detection_boxes``, ``detection_scores`` and
            ``detection_classes`` are lists with one entry per image. Each
            entry is a numpy array, or ``None`` when ``postprocess``
            produced no result for that image. ``img_metas`` is the list of
            meta dicts (empty when ``img_metas`` was not supplied).
        """
        with torch.no_grad():
            fpn_outs = self.backbone(img)
            outputs = self.head(fpn_outs)

            outputs = postprocess(outputs, self.num_classes, self.test_conf,
                                  self.nms_thre)

            detection_boxes = []
            detection_scores = []
            detection_classes = []
            img_metas_list = []

            for i, det in enumerate(outputs):
                if img_metas:
                    img_metas_list.append(img_metas[i])

                if det is None:
                    detection_boxes.append(None)
                    detection_scores.append(None)
                    detection_classes.append(None)
                    continue

                # Columns 0:4 are the box; the score is column 4 times
                # column 5 (presumably objectness x class confidence);
                # column 6 is the class index.
                bboxes = det[:, 0:4]
                if img_metas:
                    # In-place on a view of ``det``: only the box columns
                    # change, so the score/class reads below are unaffected.
                    bboxes /= img_metas[i]['scale_factor'][0]

                detection_boxes.append(bboxes.cpu().numpy())
                detection_scores.append((det[:, 4] * det[:, 5]).cpu().numpy())
                detection_classes.append(
                    det[:, 6].cpu().numpy().astype(np.int32))

            test_outputs = {
                'detection_boxes': detection_boxes,
                'detection_scores': detection_scores,
                'detection_classes': detection_classes,
                'img_metas': img_metas_list
            }

        return test_outputs

    def forward(self, img, mode='compression', **kwargs):
        """Dispatch to the forward variant selected by ``mode``.

        Args:
            img: input forwarded as the first argument of the variant.
            mode (str): ``'train'``, ``'test'`` or ``'compression'``.
            **kwargs: forwarded unchanged to the selected variant.

        Raises:
            ValueError: if ``mode`` is not one of the supported values
                (instead of silently returning ``None``).
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'compression':
            return self.forward_compression(img, **kwargs)
        raise ValueError(
            f"unsupported forward mode {mode!r}, expected one of "
            "'train', 'test', 'compression'")

    def forward_compression(self, x):
        """Return the head output for ``x`` with no post-processing.

        Unlike ``forward_test`` / ``forward_export`` this does not disable
        gradient tracking.
        """
        # fpn output content features of [dark3, dark4, dark5]
        fpn_outs = self.backbone(x)
        outputs = self.head(fpn_outs)

        return outputs

    def forward_export(self, img):
        """Forward pass used when exporting the model.

        Runs backbone and head under ``torch.no_grad()``. If the head has
        ``decode_in_inference`` set, the output is additionally passed
        through one of:

        * ``self.trt_efficientnms`` when ``use_trt_efficientnms`` is true
          (an error is logged and the head output is returned unchanged if
          the module is missing);
        * ``postprocess`` when ``export_type == 'jit'``;
        * nothing otherwise (a warning is logged).

        Returns:
            The head output, post-processed as described above.
        """
        with torch.no_grad():
            fpn_outs = self.backbone(img)
            outputs = self.head(fpn_outs)

            if self.head.decode_in_inference:
                if self.use_trt_efficientnms:
                    if self.trt_efficientnms is not None:
                        outputs = self.trt_efficientnms.forward(outputs)
                    else:
                        logging.error(
                            'PAI-YOLOX : using trt_efficientnms set to be True, but model has not attr(trt_efficientnms)'
                        )
                elif self.export_type == 'jit':
                    outputs = postprocess(outputs, self.num_classes,
                                          self.test_conf, self.nms_thre)
                else:
                    logging.warning(
                        'PAI-YOLOX : export Blade model is not allowed to wrap the postprocess into jit script model'
                    )

        return outputs