mirror of https://github.com/alibaba/EasyCV.git
# Copyright (c) 2014-2021 Megvii Inc And Alibaba PAI-Teams. All rights reserved.
from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

from easycv.models.base import BaseModel
from easycv.models.builder import MODELS
from easycv.models.detection.utils import postprocess
from .yolo_head import YOLOXHead
from .yolo_pafpn import YOLOPAFPN

def init_yolo(M):
    """Reset every BatchNorm2d in `M` to YOLOX's eps/momentum settings."""
    for m in M.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eps = 1e-3
            m.momentum = 0.03
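
# Example (illustrative): nn.Module.apply walks every submodule, so applying
# init_yolo to a freshly built model updates all of its BatchNorm2d layers:
#
#   m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
#   m.apply(init_yolo)
#   assert m[1].eps == 1e-3 and m[1].momentum == 0.03
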
@MODELS.register_module
class YOLOX(BaseModel):
    """YOLOX detection model, built from a YOLOPAFPN backbone/neck and a
    YOLOXHead. The network returns loss values during training and
    post-processed detection results during test.
    """
    # model_type -> [depth_multiplier, width_multiplier]
    param_map = {
        'nano': [0.33, 0.25],
        'tiny': [0.33, 0.375],
        's': [0.33, 0.5],
        'm': [0.67, 0.75],
        'l': [1.0, 1.0],
        'x': [1.33, 1.25]
    }
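
    # Worked example (illustrative): for model_type 's', depth=0.33 and
    # width=0.5, so with the usual YOLOX rounding a base of 1024 channels
    # becomes int(1024 * 0.5) = 512 and a base of 3 repeated blocks becomes
    # max(round(3 * 0.33), 1) = 1 inside YOLOPAFPN.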

    # TODO: support more params in configs
    # backbone (Darknet) -> neck (YOLOPAFPN) -> head (YOLOXHead)
    def __init__(self,
                 model_type: str = 's',
                 num_classes: int = 80,
                 test_size: tuple = (640, 640),
                 test_conf: float = 0.01,
                 nms_thre: float = 0.65,
                 pretrained: Optional[str] = None):
        super(YOLOX, self).__init__()
        assert model_type in self.param_map, \
            f'invalid model_type for yolox: {model_type}, ' \
            f'valid ones are {list(self.param_map.keys())}'

        in_channels = [256, 512, 1024]
        depth = self.param_map[model_type][0]
        width = self.param_map[model_type][1]

        self.backbone = YOLOPAFPN(depth, width, in_channels=in_channels)
        self.head = YOLOXHead(num_classes, width, in_channels=in_channels)

        self.apply(init_yolo)  # set BatchNorm eps/momentum everywhere
        self.head.initialize_biases(1e-2)

        self.num_classes = num_classes
        self.test_conf = test_conf
        self.nms_thre = nms_thre
        self.test_size = test_size
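
    # Example (illustrative): instances are usually created through EasyCV's
    # config system; direct construction works the same way:
    #
    #   model = YOLOX(model_type='s', num_classes=80, test_size=(640, 640))
    #   # or via the registry, assuming an mmcv-style .build() helper:
    #   # model = MODELS.build(dict(type='YOLOX', model_type='s'))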
    def forward_train(self,
                      img: Tensor,
                      gt_bboxes: Tensor,
                      gt_labels: Tensor,
                      img_metas=None,
                      scale=None) -> Dict[str, Tensor]:
        """Model forward in training.

        Args:
            img (Tensor): image tensor, NxCxHxW
            gt_bboxes (Tensor): ground-truth boxes, NxTx4 [x_c, y_c, w, h]
            gt_labels (Tensor): ground-truth class indices, NxTx1
            img_metas (list[dict]): per-image meta info; 'img_shape' is
                required here to log img_h/img_w
            scale: unused, kept for interface compatibility
        """
        fpn_outs = self.backbone(img)

        # pack targets as NxTx5 rows of [class, x_c, y_c, w, h]
        targets = torch.cat([gt_labels, gt_bboxes], dim=2)

        loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
            fpn_outs, targets, img)

        outputs = {
            'total_loss': loss,
            'iou_l': iou_loss,
            'conf_l': conf_loss,
            'cls_l': cls_loss,
            'img_h': torch.tensor(
                img_metas[0]['img_shape'][0], device=loss.device).float(),
            'img_w': torch.tensor(
                img_metas[0]['img_shape'][1], device=loss.device).float()
        }
        return outputs
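
    # Example (illustrative): with a batch of N=2 images and T=3 padded boxes
    # per image, gt_labels is (2, 3, 1) and gt_bboxes is (2, 3, 4), so the
    # concatenation above yields a (2, 3, 5) target tensor:
    #
    #   gt_labels = torch.zeros(2, 3, 1)
    #   gt_bboxes = torch.rand(2, 3, 4)
    #   targets = torch.cat([gt_labels, gt_bboxes], dim=2)  # -> (2, 3, 5)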
    def forward_test(self, img: Tensor, img_metas=None) -> Dict:
        """Model forward in test.

        Args:
            img (Tensor): image tensor, NxCxHxW
            img_metas (list[dict]): per-image meta info; when given, boxes are
                rescaled back to the original image via 'scale_factor'
        """
        with torch.no_grad():
            fpn_outs = self.backbone(img)
            outputs = self.head(fpn_outs)

            # confidence thresholding + NMS
            outputs = postprocess(outputs, self.num_classes, self.test_conf,
                                  self.nms_thre)

            detection_boxes = []
            detection_scores = []
            detection_classes = []
            img_metas_list = []

            for i in range(len(outputs)):
                if img_metas:
                    img_metas_list.append(img_metas[i])
                if outputs[i] is not None:
                    bboxes = outputs[i][:, 0:4]
                    if img_metas:
                        # undo the test-time resize
                        bboxes /= img_metas[i]['scale_factor'][0]
                    detection_boxes.append(bboxes.cpu().numpy())
                    # final score = objectness * class confidence
                    detection_scores.append(
                        (outputs[i][:, 4] * outputs[i][:, 5]).cpu().numpy())
                    detection_classes.append(
                        outputs[i][:, 6].cpu().numpy().astype(np.int32))
                else:
                    detection_boxes.append(None)
                    detection_scores.append(None)
                    detection_classes.append(None)

        test_outputs = {
            'detection_boxes': detection_boxes,
            'detection_scores': detection_scores,
            'detection_classes': detection_classes,
            'img_metas': img_metas_list
        }

        return test_outputs
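
    # Example (illustrative): consuming the per-image results, where boxes are
    # [x1, y1, x2, y2] after postprocess (or None if nothing survived NMS):
    #
    #   results = model(inputs, mode='test', img_metas=img_metas)
    #   boxes = results['detection_boxes'][0]    # (K, 4) ndarray or None
    #   scores = results['detection_scores'][0]  # (K,) obj * cls confidence
    #   labels = results['detection_classes'][0] # (K,) int32 class indices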

    def forward(self, img, mode='compression', **kwargs):
        if mode == 'train':
            return self.forward_train(img, **kwargs)
        elif mode == 'test':
            return self.forward_test(img, **kwargs)
        elif mode == 'compression':
            return self.forward_compression(img, **kwargs)
        else:
            raise ValueError(f'unsupported forward mode: {mode}')
    def forward_compression(self, x):
        # raw head outputs (no postprocess/NMS), e.g. for export or pruning;
        # the FPN outputs are the features of [dark3, dark4, dark5]
        fpn_outs = self.backbone(x)
        outputs = self.head(fpn_outs)

        return outputs
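
# Example (illustrative) smoke test of the forward modes, assuming the easycv
# package and its YOLOX components are importable:
#
#   model = YOLOX(model_type='s', num_classes=80).eval()
#   img = torch.randn(1, 3, 640, 640)
#   raw = model(img, mode='compression')  # head outputs, no NMS
#   dets = model(img, mode='test')        # dict of per-image detections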
|