Mirror of https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* add ext ops, support parrots * fix lint * fix lint * update op from mmdetection * support non-pytorch env * fix import bug * test not import mmcv.op * rename mmcv.op to mmcv.ops * fix compile warning * 1. fix syncbn warning in pytorch 1.5 2. support only cpu compile 3. add point_sample from mmdet * fix text bug * update docstrings * fix line endings * minor updates * remove non_local from ops * bug fix for nonlocal2d * rename ops_ext to _ext and _ext to _flow_warp_ext * update the doc * try clang-format github action * fix github action * add ops to api.rst * fix cpp format * fix clang format issues * remove .clang-format Co-authored-by: Kai Chen <chenkaidev@gmail.com>
28 lines · 938 B · Python
import torch
|
|
from torch.autograd import gradcheck
|
|
|
|
|
|
class TestCarafe(object):
    """Numerical gradient checks for the CARAFE ops in ``mmcv.ops``.

    Both tests need a CUDA device and are reported as skipped (not passed)
    when none is available.
    """

    @staticmethod
    def _skip_without_cuda():
        """Raise ``unittest.SkipTest`` when no CUDA device is available.

        Raising, rather than returning early, makes the test runner report
        the test as skipped instead of silently counting it as passed.
        """
        from unittest import SkipTest
        if not torch.cuda.is_available():
            raise SkipTest('CARAFE gradcheck requires a CUDA device')

    @staticmethod
    def _make_inputs(device='cuda'):
        """Build the ``(feat, mask)`` inputs shared by both gradchecks.

        Args:
            device (str | torch.device): Device to allocate the inputs on.
                Defaults to ``'cuda'``, as used by the tests below.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``feat`` of shape
            (2, 64, 3, 3) and ``mask`` of shape (2, 100, 6, 6). Both are
            float64 leaf tensors with ``requires_grad=True``. ``mask`` is
            passed through a sigmoid, so its values lie in (0, 1).
        """
        # torch.autograd.gradcheck's default tolerances are designed for
        # double precision. Allocating directly in float64, instead of
        # casting afterwards, also keeps both inputs as leaf tensors.
        feat = torch.randn(
            2, 64, 3, 3, dtype=torch.double, device=device,
            requires_grad=True)
        # NOTE(review): presumably 100 = kernel_size**2 * group_size
        # (5 * 5 * 4) and 6 = 3 * scale_factor for CARAFE(5, 4, 2) --
        # confirm against the mmcv.ops CARAFE signature.
        mask = torch.randn(
            2, 100, 6, 6, dtype=torch.double,
            device=device).sigmoid().requires_grad_()
        return feat, mask

    def test_carafe_naive_gradcheck(self):
        """Compare CARAFENaive's analytic gradients to finite differences."""
        self._skip_without_cuda()
        from mmcv.ops import CARAFENaive
        feat, mask = self._make_inputs()
        gradcheck(CARAFENaive(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)

    def test_carafe_gradcheck(self):
        """Compare CARAFE's analytic gradients to finite differences."""
        self._skip_without_cuda()
        from mmcv.ops import CARAFE
        feat, mask = self._make_inputs()
        gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
|