mirror of https://github.com/open-mmlab/mmcv.git
fix test unit of nms and batched_nms for tensorrt (#872)
parent
ce10425d2f
commit
3dd98daaa2
|
@ -123,8 +123,8 @@ def test_nms():
|
|||
fp16_mode = False
|
||||
max_workspace_size = 1 << 30
|
||||
data = mmcv.load('./tests/data/batched_nms_data.pkl')
|
||||
boxes = data['boxes'].cuda()
|
||||
scores = data['scores'].cuda()
|
||||
boxes = torch.from_numpy(data['boxes']).cuda()
|
||||
scores = torch.from_numpy(data['scores']).cuda()
|
||||
nms = partial(nms, iou_threshold=0.7, offset=0)
|
||||
wrapped_model = WrapFunction(nms)
|
||||
wrapped_model.cpu().eval()
|
||||
|
@ -195,9 +195,9 @@ def test_batched_nms():
|
|||
max_workspace_size = 1 << 30
|
||||
data = mmcv.load('./tests/data/batched_nms_data.pkl')
|
||||
nms_cfg = dict(type='nms', iou_threshold=0.7)
|
||||
boxes = data['boxes'].cuda()
|
||||
scores = data['scores'].cuda()
|
||||
idxs = data['idxs'].cuda()
|
||||
boxes = torch.from_numpy(data['boxes']).cuda()
|
||||
scores = torch.from_numpy(data['scores']).cuda()
|
||||
idxs = torch.from_numpy(data['idxs']).cuda()
|
||||
class_agnostic = False
|
||||
|
||||
nms = partial(batched_nms, nms_cfg=nms_cfg, class_agnostic=class_agnostic)
|
||||
|
|
Loading…
Reference in New Issue