mmcv/tests/test_ops/test_bilinear_grid_sample.py
Dmitry Sidnev a88d1d28c1
[Feature] enable exporting to onnx for PointRend (#953)
* Fix export to onnx for PointRend

* Fix codestyle

* Fix codestyle

* Fix type in docstring

* Minor fix

* Fix export with custom ops

* Fix codestyle

* Add tests for bilinear_grid_sample function

* Remove redundant operation and rename variables

* Fix bug in bilinear_grid_sample and update test

* Fix getting batch size

* skip torch==1.3.1

* remove unused import

* fix lint

* support export with batch

* fix dynamic clip

* skip test for torch<1.5.0

* Add docstrings and comments

* Minor fix

* Recover clipping code

* Fix clamping in pytorch 1.7.0

* Fix bilinear_grid_sampler

* Minor fix

Co-authored-by: maningsheng <maningsheng@sensetime.com>
2021-06-11 13:49:19 +08:00

41 lines
1.8 KiB
Python

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class TestBilinearGridSample(object):
    """Compare ``mmcv.ops.point_sample.bilinear_grid_sample`` with the
    PyTorch reference ``torch.nn.functional.grid_sample``."""

    def _test_bilinear_grid_sample(self,
                                   dtype=torch.float,
                                   align_corners=False,
                                   multiplier=1,
                                   precision=1e-3):
        """Check a single configuration against ``F.grid_sample``.

        Args:
            dtype (torch.dtype): dtype of the input feature map and grid.
            align_corners (bool): forwarded to both implementations.
            multiplier (float): scale applied to the identity sampling grid;
                ``abs(multiplier) > 1`` moves sample points outside the
                ``[-1, 1]`` normalized range.
            precision (float): relative tolerance (``rtol``) for
                ``np.allclose``.
        """
        from mmcv.ops.point_sample import bilinear_grid_sample

        # Local seeded generator: reproducible input without touching the
        # global RNG state.
        generator = torch.Generator().manual_seed(0)
        inp = torch.rand(1, 1, 20, 20, dtype=dtype, generator=generator)
        # Identity affine transform -> regular 15x15 sampling grid.
        theta = torch.Tensor([[[1, 0, 0], [0, 1, 0]]])
        # align_corners=False matches affine_grid's default; passing it
        # explicitly avoids PyTorch's warning about the changed default.
        grid = nn.functional.affine_grid(
            theta, (1, 1, 15, 15), align_corners=False).type_as(inp)
        grid *= multiplier

        out = bilinear_grid_sample(inp, grid, align_corners=align_corners)
        ref_out = F.grid_sample(inp, grid, align_corners=align_corners)

        # np.allclose broadcasts, so compare shapes explicitly first.
        assert out.shape == ref_out.shape
        assert np.allclose(out.detach().cpu().numpy(),
                           ref_out.detach().cpu().numpy(),
                           rtol=precision)

    def test_bilinear_grid_sample(self):
        """Run the comparison over dtypes, corner modes and grid scales."""
        # (align_corners, multiplier) pairs; |multiplier| > 1 samples
        # partially outside the input.
        cases = [
            (False, 1),
            (True, 1),
            (True, 5),
            (False, 10),
            (True, -6),
            (False, -10),
        ]
        for dtype in (torch.double, torch.float):
            for align_corners, multiplier in cases:
                self._test_bilinear_grid_sample(
                    dtype, align_corners, multiplier)