mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* add ext ops, support parrots * fix lint * fix lint * update op from mmdetection * support non-pytorch env * fix import bug * test not import mmcv.op * rename mmcv.op to mmcv.ops * fix compile warning * 1. fix syncbn warning in pytorch 1.5 2. support only cpu compile 3. add point_sample from mmdet * fix text bug * update docstrings * fix line endings * minor updates * remove non_local from ops * bug fix for nonlocal2d * rename ops_ext to _ext and _ext to _flow_warp_ext * update the doc * try clang-format github action * fix github action * add ops to api.rst * fix cpp format * fix clang format issues * remove .clang-format Co-authored-by: Kai Chen <chenkaidev@gmail.com>
15 lines
448 B
Python
15 lines
448 B
Python
import torch
|
|
|
|
|
|
class TestMaskedConv2d(object):
    """Tests for the ``MaskedConv2d`` op exported by ``mmcv.ops``."""

    def test_masked_conv2d(self):
        """Run a ``MaskedConv2d`` forward pass on a random feature map and mask.

        All tensors and the module are placed on ``'cuda'``, so the test is
        reported as *skipped* (instead of silently passing) when no CUDA
        device is available.

        Raises:
            unittest.SkipTest: if ``torch.cuda.is_available()`` is False.
        """
        # Raising unittest.SkipTest is reported as a skip by both unittest
        # and pytest, unlike a bare ``return`` which reports a pass.
        import unittest

        if not torch.cuda.is_available():
            raise unittest.SkipTest('MaskedConv2d test requires CUDA')

        # Deferred until after the CUDA check so GPU-less environments never
        # import mmcv.ops.
        from mmcv.ops import MaskedConv2d

        feature = torch.randn(1, 3, 16, 16, requires_grad=True, device='cuda')
        # The mask only selects positions, so it needs no gradient. Its
        # spatial size (16x16) matches the feature map's.
        mask = torch.randn(1, 16, 16, device='cuda')

        # padding=1 with a 3x3 kernel and stride 1 keeps the output at 16x16.
        # With the default padding=0 the output would be 14x14, and masked
        # positions on the 16x16 mask border would fall outside it.
        conv = MaskedConv2d(3, 3, 3, padding=1).cuda()
        output = conv(feature, mask)

        assert output.shape == (1, 3, 16, 16)
|