mmcv/tests/test_ops/test_nms.py
Jiazhen Wang 362a90f8bf
[Feature] Add several MLU ops (#1563)
* [Feature] Add roiaware pool3d ops from mmdet3d (#1382)

* add ops (roiaware pool3d) in mmdet3d

* refactor code

* fix typo

Co-authored-by: zhouzaida <zhouzaida@163.com>

* [Feature] Add iou3d op from mmdet3d (#1356)

* add ops (iou3d) in mmdet3d

* add unit test

* refactor code

* refactor code

* refactor code

* refactor code

* refactor code

Co-authored-by: zhouzaida <zhouzaida@163.com>

* [Fix] Update test data for test_iou3d (#1427)

* Update test data for test_iou3d

* delete blank lines

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Feature] Add group points ops from mmdet3d (#1415)

* add op (group points) and its related ops (ball query and knn) in mmdet3d

* refactor code

* fix typo

* refactor code

* fix typo

* refactor code

* make input contiguous

Co-authored-by: zhouzaida <zhouzaida@163.com>

* add mmdet3d op (#1425)

Co-authored-by: zhouzaida <zhouzaida@163.com>

* [Feature] Loading objects from different backends and dumping objects to different backends (#1330)

* [Feature] Choose storage backend by the prefix of filepath

* refactor FileClient and add unittest

* support loading from different backends

* polish docstring

* fix unittet

* rename attribute str_like_obj to is_str_like_obj

* add infer_client method

* add check_exist method

* rename var client to file_client

* polish docstring

* add join_paths method

* remove join_paths and add _format_path

* enhance unittest

* refactor unittest

* singleton pattern

* fix test_clientio.py

* deprecate CephBackend

* enhance docstring

* refactor unittest for petrel

* refactor unittest for disk backend

* update io.md

* add concat_paths method

* improve docstring

* improve docstring

* add isdir and copyfile for file backend

* delete copyfile and add get_local_path

* remove isdir method of petrel

* fix typo

* add comment and polish docstring

* polish docstring

* rename _path_mapping to _map_path

* polish docstring and fix typo

* refactor get_local_path

* add list_dir_or_file for FileClient

* add list_dir_or_file for PetrelBackend

* fix windows ci

* Add return docstring

* polish docstring

* fix typo

* fix typo

* deprecate the conversion from Path to str

* add docs for loading checkpoints with FileClient

* refactor map_path

* add _ensure_methods to ensure methods have been implemented

* fix list_dir_or_file

* rename _ensure_method_implemented to has_method

* Add CI for pytorch 1.10 (#1431)

* [Feature] Upload checkpoints and logs to ceph (#1375)

* [Feature] Choose storage backend by the prefix of filepath

* refactor FileClient and add unittest

* support loading from different backends

* polish docstring

* fix unittet

* rename attribute str_like_obj to is_str_like_obj

* [Docs] Upload checkpoint to petrel oss

* add infer_client method

* Support uploading checkpoint to petrel oss

* add check_exist method

* refactor CheckpointHook

* support uploading logs to ceph

* rename var client to file_client

* polish docstring

* enhance load_from_ceph

* refactor load_from_ceph

* refactor TextLoggerHook

* change the meaning of out_dir argument

* fix test_checkpoint_hook.py

* add join_paths method

* remove join_paths and add _format_path

* enhance unittest

* refactor unittest

* add a unittest for EvalHook when file backend is petrel

* singleton pattern

* fix test_clientio.py

* deprecate CephBackend

* add warning in load_from_ceph

* fix type of out_suffix

* enhance docstring

* refactor unittest for petrel

* refactor unittest for disk backend

* update io.md

* add concat_paths method

* fix CI

* mock check_exist

* improve docstring

* improve docstring

* improve docstring

* improve docstring

* add isdir and copyfile for file backend

* delete copyfile and add get_local_path

* remove isdir method of petrel

* fix typo

* rename check_exists to exists

* refactor code and polish docstring

* fix windows ci

* add comment and polish docstring

* polish docstring

* polish docstring

* rename _path_mapping to _map_path

* polish docstring and fix typo

* refactor get_local_path

* add list_dir_or_file for FileClient

* add list_dir_or_file for PetrelBackend

* fix windows ci

* Add return docstring

* polish docstring

* fix typo

* fix typo

* fix typo

* fix error when mocking PetrelBackend

* deprecate the conversion from Path to str

* add docs for loading checkpoints with FileClient

* rename keep_log to keep_local

* refactor map_path

* add _ensure_methods to ensure methods have been implemented

* fix list_dir_or_file

* rename _ensure_method_implemented to has_method

* refactor

* polish information

* format information

* bump version to v1.3.16 (#1430)

* [Fix]: Update test data of test_tin_shift (#1426)

* Update test data of test_tin_shift

* Delete tmp.engine

* add pytest raises asserterror test

* raise valueerror, update test log

* add more comment

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix the wrong function reference bug in BaseTransformerLayer when batch_first is True (#1418)

* [Docs] Add mmcv itself in the docs list (#1441)

* Add mmcv itself in the docs list

* modify link of docs

* [Improve] improve checkpoint loading log (#1446)

* [Feature] Support SigmoidFocalLoss with Cambricon MLU backend (#1346)

* [Feature] Support SigmoidFocalLoss with Cambricon MLU backend

* refactor MMCV_WITH_MLU macro define

* refactor NFU_ALIGN_SIZE, PAD_DOWN and split_pipeline_num

* delete extra fool proofing in cpp

* [Feature] Support SigmoidFocalLossBackward with Cambricon MLU backend

* fix macro definition in SigmoidFocalLoss

* refactor mlu files into clang-format

* refactor sigmoid focal loss test

* refactor Sigmoid Focal Loss file structure.

* fix python lint error

* fix import torch_mlu error type

* fix lint

* refactor clang format style to google

Co-authored-by: zhouzaida <zhouzaida@163.com>

* [Feature] Support RoiAlign With Cambricon MLU Backend (#1429)

* [Feature] Support NMS with cambricon MLU backend (#1467)

* [Feature] Support BBoxOverlaps with cambricon MLU backend (#1507)

* [Refactor] Format C++ code

* [Refactor] include common_mlu_helper in pytorch_mlu_helper and refactor build condition

* [Improve] Improve the performance of roialign, nms and focalloss with MLU backend (#1572)

* [Improve] Improve the performance of roialign with MLU backend

* replace CHECK_MLU with CHECK_MLU_INPUT

* [Improve] Improve the perf of nms and focallosssigmoid with MLU backend

* [Improve] Improve the performance of roialign with MLU backend (#1741)

* [Feature] Support tin_shift with cambricon MLU backend (#1696)

* [Feature] Support tin_shift with cambricon MLU backend

* [fix] Add the assertion of batch_size in tin_shift.py

* [fix] fix the param check of tin_shift in cambricon code

* [fix] Fix lint failure.

* [fix] Fix source file lint failure.

* Update mmcv/ops/tin_shift.py

[Refactor] Modify the code in mmcv/ops/tin_shift.py.

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: budefei <budefei@cambricon.com>
Co-authored-by: budefei <budefei@cambricom.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* resolve conflicts and fix lint

* fix mmcv.utils.__init__

* fix mmcv.utils.__init__

* Fix lints and change FLAG

* fix setup and refine

* remove a redundant line

* remove an unnecessary 'f'

* fix compilation error

Co-authored-by: dingchang <hudingchang.vendor@sensetime.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
Co-authored-by: q.yao <yaoqian@sensetime.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: pc <luopeichao@sensetime.com>
Co-authored-by: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com>
Co-authored-by: q.yao <streetyao@live.com>
Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
Co-authored-by: Yuxin Liu <liuyuxin@cambricon.com>
Co-authored-by: zihanchang11 <92860914+zihanchang11@users.noreply.github.com>
Co-authored-by: shlrao <shenglong.rao@gmail.com>
Co-authored-by: zhouchenyang <zcy19950525@gmail.com>
Co-authored-by: Mrxiaofei <36697723+Mrxiaofei@users.noreply.github.com>
Co-authored-by: budefei <budefei@cambricon.com>
Co-authored-by: budefei <budefei@cambricom.com>
2022-04-16 15:45:00 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import pytest
import torch

from mmcv.device.mlu import IS_MLU_AVAILABLE
from mmcv.utils import IS_CUDA_AVAILABLE


class Testnms(object):
    @pytest.mark.parametrize('device', [
        pytest.param(
            'cuda',
            marks=pytest.mark.skipif(
                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
        pytest.param(
            'mlu',
            marks=pytest.mark.skipif(
                not IS_MLU_AVAILABLE, reason='requires MLU support'))
    ])
    def test_nms_allclose(self, device):
        from mmcv.ops import nms
        np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
                             [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
                            dtype=np.float32)
        np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)
        np_inds = np.array([1, 0, 3])
        np_dets = np.array([[3.0, 6.0, 9.0, 11.0, 0.9],
                            [6.0, 3.0, 8.0, 7.0, 0.6],
                            [1.0, 4.0, 13.0, 7.0, 0.2]])
        boxes = torch.from_numpy(np_boxes)
        scores = torch.from_numpy(np_scores)

        # CPU implementation
        dets, inds = nms(boxes, scores, iou_threshold=0.3, offset=0)
        assert np.allclose(dets.numpy(), np_dets)
        assert np.allclose(inds.numpy(), np_inds)

        # device implementation (CUDA or MLU, depending on `device`)
        dets, inds = nms(
            boxes.to(device), scores.to(device), iou_threshold=0.3, offset=0)
        assert np.allclose(dets.cpu().numpy(), np_dets)
        assert np.allclose(inds.cpu().numpy(), np_inds)

    def test_softnms_allclose(self):
        from mmcv.ops import soft_nms
        np_boxes = np.array([[6.0, 3.0, 8.0, 7.0], [3.0, 6.0, 9.0, 11.0],
                             [3.0, 7.0, 10.0, 12.0], [1.0, 4.0, 13.0, 7.0]],
                            dtype=np.float32)
        np_scores = np.array([0.6, 0.9, 0.7, 0.2], dtype=np.float32)

        np_output = {
            'linear': {
                'dets':
                np.array(
                    [[3., 6., 9., 11., 0.9], [6., 3., 8., 7., 0.6],
                     [3., 7., 10., 12., 0.29024392], [1., 4., 13., 7., 0.2]],
                    dtype=np.float32),
                'inds':
                np.array([1, 0, 2, 3], dtype=np.int64)
            },
            'gaussian': {
                'dets':
                np.array([[3., 6., 9., 11., 0.9], [6., 3., 8., 7., 0.59630775],
                          [3., 7., 10., 12., 0.35275510],
                          [1., 4., 13., 7., 0.18650459]],
                         dtype=np.float32),
                'inds':
                np.array([1, 0, 2, 3], dtype=np.int64)
            },
            'naive': {
                'dets':
                np.array([[3., 6., 9., 11., 0.9], [6., 3., 8., 7., 0.6],
                          [1., 4., 13., 7., 0.2]],
                         dtype=np.float32),
                'inds':
                np.array([1, 0, 3], dtype=np.int64)
            }
        }

        boxes = torch.from_numpy(np_boxes)
        scores = torch.from_numpy(np_scores)

        configs = [[0.3, 0.5, 0.01, 'linear'], [0.3, 0.5, 0.01, 'gaussian'],
                   [0.3, 0.5, 0.01, 'naive']]

        # CPU inputs (always run, so the CPU path is not skipped without CUDA)
        for iou, sig, mscore, m in configs:
            dets, inds = soft_nms(
                boxes,
                scores,
                iou_threshold=iou,
                sigma=sig,
                min_score=mscore,
                method=m)
            assert np.allclose(dets.cpu().numpy(), np_output[m]['dets'])
            assert np.allclose(inds.cpu().numpy(), np_output[m]['inds'])

        # CUDA inputs (skipped when CUDA is unavailable or under parrots)
        if torch.cuda.is_available() and torch.__version__ != 'parrots':
            boxes = boxes.cuda()
            scores = scores.cuda()
            for iou, sig, mscore, m in configs:
                dets, inds = soft_nms(
                    boxes,
                    scores,
                    iou_threshold=iou,
                    sigma=sig,
                    min_score=mscore,
                    method=m)
                assert np.allclose(dets.cpu().numpy(), np_output[m]['dets'])
                assert np.allclose(inds.cpu().numpy(), np_output[m]['inds'])

    def test_nms_match(self):
        # nms_match and the reference nms call below only use CPU tensors,
        # so this test does not need a CUDA guard
        from mmcv.ops import nms, nms_match
        iou_thr = 0.6

        # empty input
        empty_dets = np.array([])
        assert len(nms_match(empty_dets, iou_thr)) == 0

        # non-empty ndarray input
        np_dets = np.array(
            [[49.1, 32.4, 51.0, 35.9, 0.9], [49.3, 32.9, 51.0, 35.3, 0.9],
             [35.3, 11.5, 39.9, 14.5, 0.4], [35.2, 11.7, 39.7, 15.7, 0.3]],
            dtype=np.float32)
        np_groups = nms_match(np_dets, iou_thr)
        assert isinstance(np_groups[0], np.ndarray)
        assert len(np_groups) == 2
        # the group leaders (first index of each group) should be exactly
        # the boxes that plain NMS keeps
        tensor_dets = torch.from_numpy(np_dets)
        boxes = tensor_dets[:, :4]
        scores = tensor_dets[:, 4]
        nms_keep_inds = nms(boxes.contiguous(), scores.contiguous(),
                            iou_thr)[1]
        assert {g[0].item() for g in np_groups} == set(nms_keep_inds.tolist())

        # non-empty tensor input
        tensor_dets = torch.from_numpy(np_dets)
        tensor_groups = nms_match(tensor_dets, iou_thr)
        assert isinstance(tensor_groups[0], torch.Tensor)
        for i in range(len(tensor_groups)):
            assert np.equal(tensor_groups[i].numpy(), np_groups[i]).all()

        # input of wrong shape
        wrong_dets = np.zeros((2, 3))
        with pytest.raises(AssertionError):
            nms_match(wrong_dets, iou_thr)

    def test_batched_nms(self):
        import os.path as osp

        import mmcv
        from mmcv.ops import batched_nms

        # resolve the fixture relative to this file, not the working directory
        data_path = osp.join(
            osp.dirname(__file__), '..', 'data', 'batched_nms_data.pkl')
        results = mmcv.load(data_path)

        nms_max_num = 100
        nms_cfg = dict(
            type='nms',
            iou_threshold=0.7,
            score_threshold=0.5,
            max_num=nms_max_num)
        boxes, keep = batched_nms(
            torch.from_numpy(results['boxes']),
            torch.from_numpy(results['scores']),
            torch.from_numpy(results['idxs']),
            nms_cfg,
            class_agnostic=False)

        # `split_thr` makes inputs with at least that many boxes run NMS
        # class by class; both code paths should produce the same result
        nms_cfg.update(split_thr=100)
        seq_boxes, seq_keep = batched_nms(
            torch.from_numpy(results['boxes']),
            torch.from_numpy(results['scores']),
            torch.from_numpy(results['idxs']),
            nms_cfg,
            class_agnostic=False)
        assert torch.equal(keep, seq_keep)
        assert torch.equal(boxes, seq_boxes)
        assert torch.equal(keep,
                           torch.from_numpy(results['keep'][:nms_max_num]))

        # the same consistency check with soft-NMS
        nms_cfg = dict(type='soft_nms', iou_threshold=0.7)
        boxes, keep = batched_nms(
            torch.from_numpy(results['boxes']),
            torch.from_numpy(results['scores']),
            torch.from_numpy(results['idxs']),
            nms_cfg,
            class_agnostic=False)
        nms_cfg.update(split_thr=100)
        seq_boxes, seq_keep = batched_nms(
            torch.from_numpy(results['boxes']),
            torch.from_numpy(results['scores']),
            torch.from_numpy(results['idxs']),
            nms_cfg,
            class_agnostic=False)
        assert torch.equal(keep, seq_keep)
        assert torch.equal(boxes, seq_boxes)

        # test skip nms when `nms_cfg` is None
        seq_boxes, seq_keep = batched_nms(
            torch.from_numpy(results['boxes']),
            torch.from_numpy(results['scores']),
            torch.from_numpy(results['idxs']),
            None,
            class_agnostic=False)
        assert len(seq_keep) == len(results['boxes'])
        # scores (last column) should be in descending order
        assert ((seq_boxes[:, -1][1:] - seq_boxes[:, -1][:-1]) < 0).all()
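
The expected values hard-coded in `test_nms_allclose` and `test_softnms_allclose` can be reproduced without the compiled mmcv ops. The sketch below is a plain NumPy reference, not mmcv's kernel code, assuming axis-aligned `(x1, y1, x2, y2)` boxes, `offset=0`, and no score ties. `iou_one_to_many`, `nms_reference` and `soft_nms_reference` are hypothetical helpers, not mmcv API, and the boundary conventions (suppress on IoU `>` threshold for hard NMS, `>=` for the soft variants) are assumptions. The Gaussian variant decays every remaining box regardless of `iou_threshold`; that is what reproduces the `'gaussian'` scores in the test (box 0 is decayed even though its IoU with box 1 is only about 0.056).

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU of one (x1, y1, x2, y2) box against an (N, 4) array of boxes.

    Uses the offset=0 convention (width = x2 - x1), as in the tests above.
    """
    xx1 = np.maximum(box[0], boxes[:, 0])
    yy1 = np.maximum(box[1], boxes[:, 1])
    xx2 = np.minimum(box[2], boxes[:, 2])
    yy2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)


def nms_reference(boxes, scores, iou_threshold):
    """Greedy hard NMS; returns kept indices in descending score order."""
    order = np.argsort(-scores, kind='stable')
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(best)
        # Drop every remaining box that overlaps the kept box too much.
        ious = iou_one_to_many(boxes[best], boxes[rest])
        order = rest[ious <= iou_threshold]
    return np.array(keep, dtype=np.int64)


def soft_nms_reference(boxes, scores, iou_threshold, sigma, min_score,
                       method):
    """Soft-NMS; returns (kept indices, decayed scores) in selection order."""
    scores = np.array(scores, dtype=np.float64)  # private, mutable copy
    remaining = list(range(len(scores)))
    keep, kept_scores = [], []
    while remaining:
        # Select the highest-scoring box that is still in play.
        best = max(remaining, key=lambda i: scores[i])
        remaining.remove(best)
        keep.append(best)
        kept_scores.append(scores[best])
        if not remaining:
            break
        rest = np.array(remaining)
        ious = iou_one_to_many(boxes[best], boxes[rest])
        if method == 'naive':  # hard suppression
            weights = np.where(ious >= iou_threshold, 0.0, 1.0)
        elif method == 'linear':  # decay by (1 - IoU) above the threshold
            weights = np.where(ious >= iou_threshold, 1.0 - ious, 1.0)
        elif method == 'gaussian':  # decay every box, threshold ignored
            weights = np.exp(-(ious ** 2) / sigma)
        else:
            raise ValueError(f'unknown method: {method!r}')
        scores[rest] *= weights
        # Boxes whose decayed score falls below min_score are discarded.
        remaining = [i for i in remaining if scores[i] >= min_score]
    return np.array(keep, dtype=np.int64), np.array(kept_scores)
```

On the test inputs, `nms_reference(np_boxes, np_scores, 0.3)` returns `[1, 0, 3]`, and `soft_nms_reference(np_boxes, np_scores, 0.3, 0.5, 0.01, m)` returns the indices and scores listed under `np_output[m]` for `m` in `'linear'`, `'gaussian'` and `'naive'`. The Python-level loops favour readability over speed, so this is only meant for checking small fixtures.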