2022-04-02 20:01:06 +08:00
# Copyright (c) Alibaba, Inc. and its affiliates.
2022-07-11 14:52:49 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
2022-04-02 20:01:06 +08:00
from distutils . version import LooseVersion
2022-07-11 14:52:49 +08:00
import numpy as np
2022-04-02 20:01:06 +08:00
import torch
import torchvision
2022-07-12 18:07:02 +08:00
from torchvision . ops . boxes import box_area , nms
2022-04-02 20:01:06 +08:00
2022-07-11 14:52:49 +08:00
from easycv . models . detection . utils . misc import fp16_clamp
2022-04-02 20:01:06 +08:00
def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """Compute the pairwise IoU matrix between two sets of boxes.

    Args:
        bboxes_a (torch.Tensor): first set of boxes, shape (N, 4).
        bboxes_b (torch.Tensor): second set of boxes, shape (M, 4).
        xyxy (bool): when True boxes are (x1, y1, x2, y2); otherwise they
            are (center_x, center_y, width, height).

    Returns:
        torch.Tensor: IoU matrix of shape (N, M).

    Raises:
        IndexError: if the second dimension of either input is not 4.
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError

    # (N, 1, 4) view so every box of `a` broadcasts against all M boxes of `b`.
    expanded_a = bboxes_a[:, None, :]
    if xyxy:
        inter_tl = torch.max(expanded_a[..., :2], bboxes_b[:, :2])
        inter_br = torch.min(expanded_a[..., 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        # Convert centers +/- half sizes into corners on the fly.
        half_a = expanded_a[..., 2:] / 2
        half_b = bboxes_b[:, 2:] / 2
        inter_tl = torch.max(expanded_a[..., :2] - half_a,
                             bboxes_b[:, :2] - half_b)
        inter_br = torch.min(expanded_a[..., :2] + half_a,
                             bboxes_b[:, :2] + half_b)
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    # A pair only intersects when the overlap is positive along both axes;
    # the mask zeroes out products of two negative extents.
    overlapping = (inter_tl < inter_br).all(dim=2).to(inter_tl.dtype)
    area_inter = torch.prod(inter_br - inter_tl, 2) * overlapping
    return area_inter / (area_a[:, None] + area_b - area_inter)
2022-08-24 18:11:15 +08:00
# Refer to easycv/models/detection/detectors/yolox/postprocess.py and test.py
# for how to rebuild this NMS as a torch-blade-trtplugin NMS (checked by
# zhoulou in test.py).
# Inference docker image: registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:easycv_blade_181_export
2022-04-02 20:01:06 +08:00
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Filter raw per-image predictions by confidence and apply NMS.

    Args:
        prediction (torch.Tensor): predictions of shape (B, N, 5 + C) whose
            rows are (cx, cy, w, h, obj_conf, cls_score_0, ...). NOTE: the
            first four channels are overwritten IN PLACE with
            (x1, y1, x2, y2).
        num_classes (int): number of class-score columns read after column 5.
        conf_thre (float): minimum ``obj_conf * best_class_conf`` to keep a
            candidate.
        nms_thre (float): IoU threshold passed to torchvision NMS.

    Returns:
        list: one entry per image. ``None`` when nothing survives; otherwise
            a tensor of shape (K, 7) with rows
            (x1, y1, x2, y2, obj_conf, class_conf, class_pred), where
            class_pred is stored as float.
    """
    # (cx, cy, w, h) -> (x1, y1, x2, y2). torch.stack builds a new tensor
    # before the write-back, so reading views of `prediction` here is safe.
    # The in-place write is kept because callers may observe it.
    cx, cy = prediction[:, :, 0], prediction[:, :, 1]
    half_w, half_h = prediction[:, :, 2] / 2, prediction[:, :, 3] / 2
    prediction[:, :, :4] = torch.stack(
        (cx - half_w, cy - half_h, cx + half_w, cy + half_h), dim=-1)

    output = [None] * len(prediction)
    for i, image_pred in enumerate(prediction):
        # If none are remaining => process next image
        if not image_pred.numel():
            continue
        # Best class score and its index for every candidate, shape (N, 1).
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)
        # Take column 0 rather than squeeze(): squeeze() collapses a
        # single-candidate image to a 0-d mask, and indexing with a 0-d bool
        # tensor adds a dimension instead of filtering rows.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)[conf_mask]
        if not detections.numel():
            continue

        nms_scores = detections[:, 4] * detections[:, 5]
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            # Class-aware NMS: boxes of different classes never suppress
            # each other.
            keep = torchvision.ops.batched_nms(detections[:, :4], nms_scores,
                                               detections[:, 6], nms_thre)
        else:
            # Older torchvision: class-agnostic NMS fallback.
            keep = torchvision.ops.nms(detections[:, :4], nms_scores,
                                       nms_thre)
        output[i] = detections[keep]
    return output
2022-07-11 14:52:49 +08:00
def bbox2result(bboxes, labels, num_classes):
    """Split detections into one numpy array per class.

    Args:
        bboxes (torch.Tensor | np.ndarray): detections of shape (n, 5).
        labels (torch.Tensor | np.ndarray): class index of each detection,
            shape (n, ).
        num_classes (int): number of output entries; entry ``i`` collects
            the detections labelled ``i`` for ``i`` in ``range(num_classes)``.

    Returns:
        list(ndarray): per-class bbox results. When there are no detections,
            every entry is an empty float32 array of shape (0, 5).
    """
    if bboxes.shape[0] == 0:
        return [np.zeros((0, 5), dtype=np.float32) for _ in range(num_classes)]

    # Move tensors to host memory so boolean-mask indexing yields numpy arrays.
    if isinstance(bboxes, torch.Tensor):
        bboxes = bboxes.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()
    return [bboxes[labels == cls_idx, :] for cls_idx in range(num_classes)]
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2).

    Args:
        x (torch.Tensor): boxes whose last dimension has exactly 4 entries.

    Returns:
        torch.Tensor: boxes with the same shape as ``x`` in corner format.
    """
    cx, cy, width, height = x.unbind(-1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=-1)
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x1, y1, x2, y2) to (cx, cy, w, h).

    Args:
        x (torch.Tensor): boxes whose last dimension has exactly 4 entries.

    Returns:
        torch.Tensor: boxes with the same shape as ``x`` in center format.
    """
    left, top, right, bottom = x.unbind(-1)
    center_x = (left + right) / 2
    center_y = (top + bottom) / 2
    return torch.stack((center_x, center_y, right - left, bottom - top), dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    """Pairwise IoU between two box sets, also returning the union areas.

    Args:
        boxes1 (torch.Tensor): (N, 4) boxes in (x1, y1, x2, y2) format.
        boxes2 (torch.Tensor): (M, 4) boxes in (x1, y1, x2, y2) format.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(iou, union)``, each of shape
            (N, M). Pairs with zero union produce nan/inf IoU.
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    top_left = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
    # Negative extents mean no overlap along that axis.
    inter_wh = (bottom_right - top_left).clamp(min=0)  # [N,M,2]
    intersection = inter_wh[:, :, 0] * inter_wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - intersection
    return intersection / union, union
def generalized_box_iou(boxes1, boxes2):
    """Generalized IoU from https://giou.stanford.edu/.

    The boxes should be in [x0, y0, x1, y1] format.

    Args:
        boxes1 (torch.Tensor): (N, 4) boxes.
        boxes2 (torch.Tensor): (M, 4) boxes.

    Returns:
        torch.Tensor: a [N, M] pairwise GIoU matrix, where N = len(boxes1)
            and M = len(boxes2).

    Raises:
        AssertionError: (when assertions are enabled) if any box has its
            second corner smaller than its first corner.
    """
    # Degenerate boxes give inf / nan results, so check them up front.
    for boxes in (boxes1, boxes2):
        assert (boxes[:, 2:] >= boxes[:, :2]).all()

    iou, union = box_iou(boxes1, boxes2)

    # Smallest axis-aligned box enclosing each pair.
    enclose_tl = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclose_br = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclose_wh = (enclose_br - enclose_tl).clamp(min=0)  # [N,M,2]
    enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]

    return iou - (enclose_area - union) / enclose_area
def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
    """Calculate overlap between two sets of bboxes.

    FP16 support contributed by
    https://github.com/open-mmlab/mmdetection/pull/4889

    If ``is_aligned`` is ``False``, the overlap is computed between every
    bbox of ``bboxes1`` and every bbox of ``bboxes2``. Otherwise it is
    computed for each aligned pair ``(bboxes1[..., i, :], bboxes2[..., i, :])``.

    Note:
        Memory: with ``is_aligned=False`` the intermediates (lt, rb, wh,
        overlap, union, ious) are of size M x N. That is roughly
        ``S = (9 x N x M + N + M) * 4`` bytes in FP32. With
        ``is_aligned=True`` it is roughly ``S = 11 x N * 4`` bytes. Running
        in FP16 halves this. For example, with M = 40 ground-truth boxes and
        N = 400000 anchors, about 275 MB is saved. With M = 512 (dense
        detection), about 3.43 GB is saved, multiplied by the batch size B.
        'giou' needs more memory than 'iou'. Time-wise, FP16 is generally
        faster than FP32. When ``gpu_assign_thr`` is not -1, the computation
        takes more time on the CPU but does not reduce memory.

        Experiments on GeForce RTX 2080 Ti (11019 MiB):

        | dtype |  M  |   N    |   Use    |   Real   |  Ideal   |
        |:-----:|:---:|:------:|:--------:|:--------:|:--------:|
        | FP32  | 512 | 400000 | 8020 MiB |    --    |    --    |
        | FP16  | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB |
        | FP32  |  40 | 400000 | 1540 MiB |    --    |    --    |
        | FP16  |  40 | 400000 | 1264 MiB |  276 MiB |  275 MiB |

    Args:
        bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
        bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
            B indicates the batch dim, in shape (B1, B2, ..., Bn), and must
            match between the two inputs.
            If ``is_aligned`` is ``True``, then m and n must be equal.
        mode (str): "iou" (intersection over union), "iof" (intersection over
            foreground) or "giou" (generalized intersection over union).
            Default "iou".
        is_aligned (bool, optional): If True, then m and n must be equal.
            Default False.
        eps (float, optional): Lower bound applied to the denominators
            (union and, for "giou", the enclosing area) for numerical
            stability. Default 1e-6.

    Returns:
        Tensor: shape (B, m, n) if ``is_aligned`` is False else shape (B, m).

    Example:
        >>> bboxes1 = torch.FloatTensor([
        >>>     [0, 0, 10, 10],
        >>>     [10, 10, 20, 20],
        >>>     [32, 32, 38, 42],
        >>> ])
        >>> bboxes2 = torch.FloatTensor([
        >>>     [0, 0, 10, 20],
        >>>     [0, 10, 10, 19],
        >>>     [10, 10, 20, 20],
        >>> ])
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
        >>> assert overlaps.shape == (3, 3)
        >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
        >>> assert overlaps.shape == (3, )

    Example:
        >>> empty = torch.empty(0, 4)
        >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
        >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
        >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
        >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
    """
    assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
    # Either the boxes are empty or the length of boxes' last dimension is 4
    assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
    assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)

    # Batch dims (B1, B2, ... Bn) must be identical.
    assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
    batch_shape = bboxes1.shape[:-2]

    rows = bboxes1.size(-2)
    cols = bboxes2.size(-2)
    if is_aligned:
        assert rows == cols

    if rows * cols == 0:
        out_tail = (rows, ) if is_aligned else (rows, cols)
        return bboxes1.new(batch_shape + out_tail)

    area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
        bboxes1[..., 3] - bboxes1[..., 1])
    area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
        bboxes2[..., 3] - bboxes2[..., 1])

    # Aligned mode pairs box i with box i. Otherwise singleton axes are
    # inserted so bboxes1 broadcasts over cols and bboxes2 over rows.
    if is_aligned:
        boxes_a, boxes_b = bboxes1, bboxes2  # [B, rows, 4]
        area_a, area_b = area1, area2  # [B, rows]
    else:
        boxes_a = bboxes1[..., :, None, :]  # [B, rows, 1, 4]
        boxes_b = bboxes2[..., None, :, :]  # [B, 1, cols, 4]
        area_a = area1[..., None]  # [B, rows, 1]
        area_b = area2[..., None, :]  # [B, 1, cols]

    lt = torch.max(boxes_a[..., :2], boxes_b[..., :2])
    rb = torch.min(boxes_a[..., 2:], boxes_b[..., 2:])
    wh = fp16_clamp(rb - lt, min=0)
    overlap = wh[..., 0] * wh[..., 1]

    if mode in ['iou', 'giou']:
        union = area_a + area_b - overlap
    else:
        union = area_a
    if mode == 'giou':
        # Smallest box enclosing each pair.
        enclosed_lt = torch.min(boxes_a[..., :2], boxes_b[..., :2])
        enclosed_rb = torch.max(boxes_a[..., 2:], boxes_b[..., 2:])

    eps = union.new_tensor([eps])
    union = torch.max(union, eps)
    ious = overlap / union
    if mode in ['iou', 'iof']:
        return ious

    # calculate gious
    enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0)
    enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
    enclose_area = torch.max(enclose_area, eps)
    gious = ious - (enclose_area - union) / enclose_area
    return gious
def bbox2distance(points, bbox, max_dis=None, eps=0.1):
    """Encode bounding boxes as distances from points to the four box sides.

    This is the inverse of ``distance2bbox`` (ignoring clamping). Any number
    of leading (batch) dimensions is supported, matching ``distance2bbox``.

    Args:
        points (Tensor): Shape (..., 2), [x, y].
        bbox (Tensor): Shape (..., 4), "xyxy" format.
        max_dis (float, optional): Upper bound of the distance. When given,
            every distance is clamped to ``[0, max_dis - eps]``.
        eps (float): a small value to ensure target < max_dis, instead of <=.

    Returns:
        Tensor: Shape (..., 4), distances (left, top, right, bottom).
    """
    left = points[..., 0] - bbox[..., 0]
    top = points[..., 1] - bbox[..., 1]
    right = bbox[..., 2] - points[..., 0]
    bottom = bbox[..., 3] - points[..., 1]
    distances = torch.stack([left, top, right, bottom], -1)
    if max_dis is not None:
        # Clamping is element-wise, so one clamp on the stacked tensor equals
        # clamping each side separately.
        distances = distances.clamp(min=0, max=max_dis - eps)
    return distances
def distance2bbox(points, distance, max_shape=None):
    """Decode distance predictions into bounding boxes.

    Args:
        points (Tensor): Shape (B, N, 2) or (N, 2).
        distance (Tensor): Distance from the given point to 4
            boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4).
        max_shape (Sequence[int] or torch.Tensor or Sequence[
            Sequence[int]], optional): Maximum bounds for boxes, specifies
            (H, W, C) or (H, W). If priors shape is (B, N, 4), then
            the max_shape should be a Sequence[Sequence[int]]
            and the length of max_shape should also be B.

    Returns:
        Tensor: Boxes with shape (N, 4) or (B, N, 4), in (x1, y1, x2, y2)
            order, clipped to ``[0, W]`` / ``[0, H]`` when ``max_shape`` is
            given.
    """
    x1 = points[..., 0] - distance[..., 0]
    y1 = points[..., 1] - distance[..., 1]
    x2 = points[..., 0] + distance[..., 2]
    y2 = points[..., 1] + distance[..., 3]
    bboxes = torch.stack([x1, y1, x2, y2], -1)

    if max_shape is None:
        return bboxes

    if bboxes.dim() == 2 and not torch.onnx.is_in_onnx_export():
        # Fast path: clamp x coordinates to the width and y coordinates to
        # the height, in place.
        height, width = max_shape[0], max_shape[1]
        bboxes[:, 0::2].clamp_(min=0, max=width)
        bboxes[:, 1::2].clamp_(min=0, max=height)
        return bboxes

    # General path (also taken during ONNX export); supports one (H, W) per
    # batch element.
    limits = (max_shape if isinstance(max_shape, torch.Tensor) else
              x1.new_tensor(max_shape))
    limits = limits[..., :2].type_as(x1)  # keep only (H, W)
    if limits.ndim == 2:
        assert bboxes.ndim == 3
        assert limits.size(0) == bboxes.size(0)

    lower = x1.new_tensor(0)
    # (H, W, H, W) flipped to (W, H, W, H): x is bounded by W, y by H.
    upper = torch.cat([limits, limits], dim=-1).flip(-1).unsqueeze(-2)
    bboxes = torch.where(bboxes < lower, lower, bboxes)
    bboxes = torch.where(bboxes > upper, upper, bboxes)
    return bboxes
2022-07-12 18:07:02 +08:00
def batched_nms(boxes, scores, idxs, nms_cfg, class_agnostic=False):
    r"""Performs non-maximum suppression in a batched fashion.

    Modified from `torchvision/ops/boxes.py#L39
    <https://github.com/pytorch/vision/blob/
    505cd6957711af790211896d32b40291bea1bc21/torchvision/ops/boxes.py#L39>`_.
    In order to perform NMS independently per class, we add an offset to all
    the boxes. The offset is dependent only on the class idx, and is large
    enough so that boxes from different classes do not overlap.

    Note:
        ``batched_nms`` supports skipping the NMS and returns sorted raw
        results when ``nms_cfg`` is None.
        Suppression always uses the module-level ``nms`` (imported from
        ``torchvision.ops.boxes``). The ``type`` key of ``nms_cfg`` is
        consumed but does not select a different op.

    Args:
        boxes (torch.Tensor): boxes in shape (N, 4).
        scores (torch.Tensor): scores in shape (N, ).
        idxs (torch.Tensor): each index value correspond to a bbox cluster,
            and NMS will not be applied between elements of different idxs,
            shape (N, ).
        nms_cfg (dict | None): Supports skipping the nms when `nms_cfg`
            is None. Otherwise the following keys are consumed:

            - type (str): accepted for config compatibility; ignored.
            - class_agnostic (bool): overrides the ``class_agnostic``
              argument.
            - split_thr (float): threshold number of boxes. In some cases
              the number of boxes is large (e.g., 200k). To avoid OOM during
              training, the users could set `split_thr` to a small value.
              If the number of boxes is greater than the threshold, NMS is
              performed on each group of boxes (each unique ``idxs`` value)
              separately and sequentially. Defaults to 10000.
            - max_num (int): on the split path only, keep at most this many
              boxes (<= 0 keeps all).

            All remaining keys are forwarded as keyword arguments to ``nms``.
            torchvision's ``nms`` expects ``iou_threshold``.
            NOTE(review): the mmdet-style ``iou_thr`` key would be rejected;
            confirm which key callers pass.
        class_agnostic (bool): if true, nms is class agnostic,
            i.e. IoU thresholding happens over all boxes,
            regardless of the predicted class.

    Returns:
        tuple: kept dets and indice.

        - boxes (Tensor): Bboxes with score after nms, has shape
          (num_bboxes, 5). last dimension 5 arrange as
          (x1, y1, x2, y2, score)
        - keep (Tensor): The indices of remaining boxes in input
          boxes.
    """
    # skip nms when nms_cfg is None
    if nms_cfg is None:
        scores, inds = scores.sort(descending=True)
        boxes = boxes[inds]
        return torch.cat([boxes, scores[:, None]], -1), inds

    # Nothing to suppress. This also avoids `boxes.max()` raising on an empty
    # tensor when computing the per-class offsets below.
    if boxes.numel() == 0:
        keep = torch.zeros(0, dtype=torch.long, device=boxes.device)
        return torch.cat([boxes, scores[:, None]], -1), keep

    nms_cfg_ = nms_cfg.copy()
    class_agnostic = nms_cfg_.pop('class_agnostic', class_agnostic)
    if class_agnostic:
        boxes_for_nms = boxes
    else:
        # Shift each class by a multiple of (max coordinate + 1) so boxes of
        # different classes can never overlap.
        max_coordinate = boxes.max()
        offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes))
        boxes_for_nms = boxes + offsets[:, None]

    # Pop `type` so it is not forwarded to `nms`. The op used to be resolved
    # with eval() on this config string, but the result was never used, and
    # eval-ing configuration is unsafe.
    nms_cfg_.pop('type', 'nms')
    split_thr = nms_cfg_.pop('split_thr', 10000)
    # Won't split to multiple nms nodes when exporting to onnx
    if boxes_for_nms.shape[0] < split_thr or torch.onnx.is_in_onnx_export():
        keep = nms(boxes_for_nms, scores, **nms_cfg_)
        boxes = boxes[keep]
        scores = scores[keep]
    else:
        max_num = nms_cfg_.pop('max_num', -1)
        total_mask = scores.new_zeros(scores.size(), dtype=torch.bool)
        # Scores of kept boxes at their global positions (zero elsewhere).
        scores_after_nms = scores.new_zeros(scores.size())
        for cls_id in torch.unique(idxs):
            mask = (idxs == cls_id).nonzero(as_tuple=False).view(-1)
            keep = nms(boxes_for_nms[mask], scores[mask], **nms_cfg_)
            # `keep` indexes the per-class subset; map it back to global
            # indices before touching full-size tensors.
            kept_global = mask[keep]
            total_mask[kept_global] = True
            scores_after_nms[kept_global] = scores[kept_global]
        keep = total_mask.nonzero(as_tuple=False).view(-1)
        scores, inds = scores_after_nms[keep].sort(descending=True)
        keep = keep[inds]
        boxes = boxes[keep]

        if max_num > 0:
            keep = keep[:max_num]
            boxes = boxes[:max_num]
            scores = scores[:max_num]

    boxes = torch.cat([boxes, scores[:, None]], -1)
    return boxes, keep