from distutils.version import LooseVersion

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from easycv.models.detection.utils import box_cxcywh_to_xyxy


class DetrPostProcess(nn.Module):
    """This module converts the model's output into the format expected by the coco api."""
    def __init__(self,
                 num_select=None,
                 use_centerness=False,
                 use_iouaware=False) -> None:
        super().__init__()
        self.num_select = num_select
        self.use_centerness = use_centerness
        self.use_iouaware = use_iouaware
    @torch.no_grad()
    def forward(self, outputs, target_sizes, img_metas):
        """Perform the computation.

        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image in the batch.
                For evaluation, this must be the original image size (before any data augmentation).
                For visualization, this should be the image size after data augmentation, but before padding.
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        if self.num_select is None:
            prob = F.softmax(out_logits, -1)
            scores, labels = prob[..., :-1].max(-1)
            boxes = box_cxcywh_to_xyxy(out_bbox)
            # rescale boxes from relative [0, 1] to absolute image coordinates
            img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                    dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]
        else:
            if self.use_centerness and self.use_iouaware:
                prob = out_logits.sigmoid(
                )**0.45 * outputs['pred_centers'].sigmoid(
                )**0.05 * outputs['pred_ious'].sigmoid()**0.5
            elif self.use_centerness:
                prob = out_logits.sigmoid() * outputs['pred_centers'].sigmoid()
            elif self.use_iouaware:
                prob = out_logits.sigmoid() * outputs['pred_ious'].sigmoid()
            else:
                prob = out_logits.sigmoid()
            topk_values, topk_indexes = torch.topk(
                prob.view(out_logits.shape[0], -1), self.num_select, dim=1)
            scores = topk_values
            topk_boxes = topk_indexes // out_logits.shape[2]
            labels = topk_indexes % out_logits.shape[2]
            boxes = box_cxcywh_to_xyxy(out_bbox)
            boxes = torch.gather(boxes, 1,
                                 topk_boxes.unsqueeze(-1).repeat(1, 1, 4))
            # rescale boxes from relative [0, 1] to absolute image coordinates
            img_h, img_w = target_sizes.unbind(1)
            scale_fct = torch.stack([img_w, img_h, img_w, img_h],
                                    dim=1).to(boxes.device)
            boxes = boxes * scale_fct[:, None, :]
        results = {
            'detection_boxes': [boxes[0].cpu().numpy()],
            'detection_scores': [scores[0].cpu().numpy()],
            'detection_classes': [labels[0].cpu().numpy().astype(np.int32)],
            'img_metas': img_metas
        }
        return results
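

# A minimal usage sketch, not part of the original module: it shows how
# DetrPostProcess is expected to be called. The tensor shapes below follow the
# usual DETR convention of [batch_size, num_queries, ...] outputs and are
# assumptions chosen only for illustration.
def _demo_detr_postprocess():
    post_process = DetrPostProcess(num_select=100)
    outputs = {
        'pred_logits': torch.randn(1, 300, 81),  # [batch, queries, classes]
        'pred_boxes': torch.rand(1, 300, 4),  # cxcywh boxes in [0, 1]
    }
    # (height, width) of the original image, one row per image in the batch
    target_sizes = torch.tensor([[480, 640]])
    results = post_process(outputs, target_sizes, img_metas=[{}])
    # results['detection_boxes'][0] is a (num_select, 4) numpy array in xyxy pixels
    return results
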
def output_postprocess(outputs, img_metas=None):
    detection_boxes = []
    detection_scores = []
    detection_classes = []
    img_metas_list = []

    for i in range(len(outputs)):
        if img_metas:
            img_metas_list.append(img_metas[i])
        if outputs[i] is not None:
            bboxes = outputs[i][:, 0:4]
            if img_metas:
                bboxes /= img_metas[i]['scale_factor'][0]
            detection_boxes.append(bboxes.cpu().numpy())
            detection_scores.append(
                (outputs[i][:, 4] * outputs[i][:, 5]).cpu().numpy())
            detection_classes.append(outputs[i][:, 6].cpu().numpy().astype(
                np.int32))
        else:
            detection_boxes.append(None)
            detection_scores.append(None)
            detection_classes.append(None)

    test_outputs = {
        'detection_boxes': detection_boxes,
        'detection_scores': detection_scores,
        'detection_classes': detection_classes,
        'img_metas': img_metas_list
    }
    return test_outputs
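

# A minimal input/output sketch for output_postprocess, for illustration only.
# The column layout (x1, y1, x2, y2, obj_conf, cls_conf, cls_pred) matches what
# postprocess() below produces; the scale_factor values are assumptions.
def _demo_output_postprocess():
    # first image: two detections; second image: no detections
    outputs = [
        torch.tensor([[10., 10., 50., 60., 0.9, 0.8, 1.],
                      [20., 30., 80., 90., 0.7, 0.6, 3.]]),
        None,
    ]
    img_metas = [{'scale_factor': [2.0, 2.0, 2.0, 2.0]},
                 {'scale_factor': [1.0]}]
    # boxes of the first image are rescaled back to the original resolution
    return output_postprocess(outputs, img_metas)
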
# Refer to easycv/models/detection/detectors/yolox/postprocess.py and test.py
# to rebuild a torch-blade-trtplugin NMS, which is checked by zhoulou in test.py.
# The inference docker image is: registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easycv_blade_181_export
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    # convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) corner format
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]
    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.numel():
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        conf_mask = (image_pred[:, 4] * class_conf.squeeze() >=
                     conf_thre).squeeze()
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.numel():
            continue
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                detections[:, 6], nms_thre)
        else:
            nms_out_index = torchvision.ops.nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                nms_thre)

        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output
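

# A minimal usage sketch for postprocess(), for illustration only: raw
# predictions are laid out per box as (cx, cy, w, h, obj_conf, cls_conf_0, ...,
# cls_conf_{C-1}), which is the YOLOX-style output this function assumes. The
# random inputs and thresholds below are placeholders.
def _demo_postprocess():
    num_classes = 80
    # [batch_size, num_boxes, 5 + num_classes]
    prediction = torch.rand(2, 1000, 5 + num_classes)
    outputs = postprocess(prediction, num_classes, conf_thre=0.3, nms_thre=0.45)
    # each entry of outputs is either None or a tensor of shape [num_kept, 7]:
    # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
    return outputs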