EasyCV/easycv/models/detection/yolox/yolox.py

168 lines
5.5 KiB
Python

# Copyright (c) 2014-2021 Megvii Inc And Alibaba PAI-Teams. All rights reserved.
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.models.detection.utils import postprocess
from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN
def init_yolo(M):
    """Override the batch-norm hyper-parameters of every layer under ``M``.

    Every ``nn.BatchNorm2d`` yielded by ``M.modules()`` gets ``eps`` set to
    1e-3 and ``momentum`` set to 0.03. Modules of any other type are left
    untouched. The update happens in place and nothing is returned.

    Args:
        M (nn.Module): module whose sub-modules (including itself) are walked.
    """
    batch_norms = (layer for layer in M.modules()
                   if isinstance(layer, nn.BatchNorm2d))
    for bn in batch_norms:
        bn.momentum = 0.03
        bn.eps = 1e-3
@MODELS.register_module
class YOLOX(BaseModel):
    """YOLOX detector assembled from a ``YOLOPAFPN`` backbone and a ``YOLOXHead``.

    ``forward_train`` returns a dict of loss values, ``forward_test`` returns
    post-processed detections as numpy arrays, and ``forward_compression``
    returns the raw head outputs.
    """

    # model_type -> [depth, width]; the two values are passed to the
    # backbone (depth, width) and to the head (width) in ``__init__``.
    param_map = {
        'nano': [0.33, 0.25],
        'tiny': [0.33, 0.375],
        's': [0.33, 0.5],
        'm': [0.67, 0.75],
        'l': [1.0, 1.0],
        'x': [1.33, 1.25]
    }

    # TODO configs support more params
    # backbone(Darknet), neck(YOLOXPAFPN), head(YOLOXHead)
    def __init__(self,
                 model_type: str = 's',
                 num_classes: int = 80,
                 test_size: tuple = (640, 640),
                 test_conf: float = 0.01,
                 nms_thre: float = 0.65,
                 pretrained: str = None):
        """Build the backbone and head for the requested model size.

        Args:
            model_type (str): key of ``param_map`` selecting depth/width.
            num_classes (int): number of classes, passed to the head and to
                ``postprocess`` at test time.
            test_size (tuple): stored on the instance as ``self.test_size``;
                not otherwise used in this class.
            test_conf (float): confidence threshold passed to ``postprocess``.
            nms_thre (float): NMS threshold passed to ``postprocess``.
            pretrained (str): accepted for interface compatibility; it is not
                read anywhere in this constructor.

        Raises:
            AssertionError: if ``model_type`` is not a key of ``param_map``.
        """
        super(YOLOX, self).__init__()
        assert model_type in self.param_map, (
            f'invalid model_type for yolox {model_type}, '
            f'valid ones are {list(self.param_map.keys())}')

        in_channels = [256, 512, 1024]
        depth, width = self.param_map[model_type]

        self.backbone = YOLOPAFPN(depth, width, in_channels=in_channels)
        self.head = YOLOXHead(num_classes, width, in_channels=in_channels)

        # ``apply`` calls init_yolo on every sub-module and init_yolo walks
        # ``modules()`` itself, so batch-norm layers are visited repeatedly;
        # the assigned values are constant, so the end state is the same as
        # a single ``init_yolo(self)``.
        self.apply(init_yolo)  # init_yolo(self)
        self.head.initialize_biases(1e-2)

        self.num_classes = num_classes
        self.test_conf = test_conf
        self.nms_thre = nms_thre
        self.test_size = test_size

    def forward_train(self,
                      img: Tensor,
                      gt_bboxes: Tensor,
                      gt_labels: Tensor,
                      img_metas=None,
                      scale=None) -> Dict[str, Tensor]:
        """Run the training forward pass and collect the losses.

        Args:
            img (Tensor): image tensor, NxCxHxW.
            gt_bboxes (Tensor): ground-truth boxes; concatenated after
                ``gt_labels`` along dim 2 to form the head targets.
                NOTE(review): the previous docstring described each target
                row as [class, x_c, y_c, w, h] -- confirm against the head.
            gt_labels (Tensor): ground-truth labels; must match
                ``gt_bboxes`` on every dimension except dim 2.
            img_metas (list[dict]): per-image meta info. Despite the ``None``
                default it is required here: ``img_metas[0]['img_shape']`` is
                read to report ``img_h`` / ``img_w``.
            scale: unused in this method.

        Returns:
            Dict[str, Tensor]: ``total_loss``, ``iou_l``, ``conf_l``,
            ``cls_l`` from the head, plus ``img_h`` and ``img_w`` taken from
            the first entry of ``img_metas`` as float tensors on the loss
            device.
        """
        # gt_bboxes = gt_bboxes.to(torch.float16)
        # gt_labels = gt_labels.to(torch.float16)

        fpn_outs = self.backbone(img)

        targets = torch.cat([gt_labels, gt_bboxes], dim=2)

        # The head also returns an L1 loss and the foreground count; neither
        # is reported in the output dict.
        loss, iou_loss, conf_loss, cls_loss, _l1_loss, _num_fg = self.head(
            fpn_outs, targets, img)

        img_shape = img_metas[0]['img_shape']
        outputs = {
            'total_loss':
            loss,
            'iou_l':
            iou_loss,
            'conf_l':
            conf_loss,
            'cls_l':
            cls_loss,
            'img_h':
            torch.tensor(img_shape[0], device=loss.device).float(),
            'img_w':
            torch.tensor(img_shape[1], device=loss.device).float()
        }

        return outputs

    def forward_test(self, img: Tensor, img_metas=None) -> Dict[str, list]:
        """Run inference and return post-processed detections.

        Args:
            img (Tensor): image tensor, NxCxHxW.
            img_metas (list[dict]): optional per-image meta info. When given,
                boxes of image ``i`` are divided by
                ``img_metas[i]['scale_factor'][0]`` and the meta dicts are
                echoed back in the result.

        Returns:
            Dict[str, list]: one entry per image under each of
            ``detection_boxes``, ``detection_scores`` and
            ``detection_classes`` (a numpy array, or ``None`` when
            ``postprocess`` returned ``None`` for that image), plus
            ``img_metas`` (empty list when no metas were passed).
        """
        with torch.no_grad():
            fpn_outs = self.backbone(img)
            outputs = self.head(fpn_outs)

            outputs = postprocess(outputs, self.num_classes, self.test_conf,
                                  self.nms_thre)

            detection_boxes = []
            detection_scores = []
            detection_classes = []
            img_metas_list = []

            for i, output in enumerate(outputs):
                if img_metas:
                    img_metas_list.append(img_metas[i])

                if output is None:
                    # No detections for this image.
                    detection_boxes.append(None)
                    detection_scores.append(None)
                    detection_classes.append(None)
                    continue

                # Columns 0-3 are the box, 4 and 5 are multiplied into the
                # final score, 6 is the class index.
                bboxes = output[:, 0:4]
                if img_metas:
                    # In-place on a view of ``output``; only the box columns
                    # are affected.
                    bboxes /= img_metas[i]['scale_factor'][0]
                detection_boxes.append(bboxes.cpu().numpy())
                detection_scores.append(
                    (output[:, 4] * output[:, 5]).cpu().numpy())
                detection_classes.append(
                    output[:, 6].cpu().numpy().astype(np.int32))

            test_outputs = {
                'detection_boxes': detection_boxes,
                'detection_scores': detection_scores,
                'detection_classes': detection_classes,
                'img_metas': img_metas_list
            }

        return test_outputs

    def forward(self, img, mode='compression', **kwargs):
        """Dispatch to the forward variant selected by ``mode``.

        Args:
            img (Tensor): image tensor forwarded to the selected method.
            mode (str): one of ``'train'``, ``'test'`` or ``'compression'``.
            **kwargs: forwarded unchanged to the selected method.

        Returns:
            Whatever the selected ``forward_*`` method returns.

        Raises:
            ValueError: if ``mode`` is not one of the supported values
                (previously this case silently returned ``None``).
        """
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        if mode == 'test':
            return self.forward_test(img, **kwargs)
        if mode == 'compression':
            return self.forward_compression(img, **kwargs)
        raise ValueError(
            f"unsupported forward mode {mode!r}, "
            "expected 'train', 'test' or 'compression'")

    def forward_compression(self, x):
        """Return the raw head outputs for ``x`` without post-processing.

        Args:
            x (Tensor): image tensor fed to the backbone.

        Returns:
            The output of ``self.head`` called with the backbone features
            only (no targets).
        """
        # fpn output content features of [dark3, dark4, dark5]
        fpn_outs = self.backbone(x)
        outputs = self.head(fpn_outs)

        return outputs