mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
* Fix export to onnx for PointRend * Fix codestyle * Fix codestyle * Fix type in docstring * Minor fix * Fix export with custom ops * Fix codestyle * Add tests for bilinear_grid_sample function * Remove redundant operation and rename variables * Fix bug in bilinear_grid_sample and update test * Fix getting batch size * skip torch==1.3.1 * remove unused import * fix lint * support export with batch * fix dynamic clip * skip test for torch<1.5.0 * Add docstrings and comments * Minor fix * Recover clipping code * Fix clamping in pytorch 1.7.0 * Fix bilinear_grid_sampler * Minor fix Co-authored-by: maningsheng <maningsheng@sensetime.com>
41 lines
1.8 KiB
Python
41 lines
1.8 KiB
Python
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class TestBilinearGridSample(object):
    """Compare ``mmcv.ops.point_sample.bilinear_grid_sample`` with PyTorch.

    ``torch.nn.functional.grid_sample`` is the reference implementation.
    Each case samples a random 1x1x20x20 image with a (scaled) identity grid
    of 15x15 sampling locations and checks the two outputs agree.
    """

    def _test_bilinear_grid_sample(self,
                                   dtype=torch.float,
                                   align_corners=False,
                                   multiplier=1,
                                   precision=1e-3):
        """Check one configuration of ``bilinear_grid_sample``.

        Args:
            dtype (torch.dtype): dtype of the sampled image and of the grid.
            align_corners (bool): forwarded unchanged to both samplers.
            multiplier (int | float): factor applied to the identity grid.
                The unscaled grid lies inside [-1, 1]; ``|multiplier| > 1``
                moves sampling points outside the image, so the two
                implementations must also agree on out-of-range locations.
            precision (float): relative tolerance (``rtol``) passed to
                :func:`numpy.allclose`.
        """
        # Imported inside the test rather than at module level, so importing
        # this module does not itself import mmcv.
        from mmcv.ops.point_sample import bilinear_grid_sample

        feats = torch.rand(1, 1, 20, 20, dtype=dtype)
        # Identity affine transform, built directly in the requested dtype.
        theta = torch.tensor([[[1, 0, 0], [0, 1, 0]]], dtype=dtype)
        # align_corners=False is PyTorch's default for affine_grid; passing
        # it explicitly keeps the grid unchanged and avoids the
        # "default behavior has changed" UserWarning.
        grid = nn.functional.affine_grid(
            theta, (1, 1, 15, 15), align_corners=False)
        grid *= multiplier

        out = bilinear_grid_sample(feats, grid, align_corners=align_corners)
        ref_out = F.grid_sample(feats, grid, align_corners=align_corners)

        assert np.allclose(
            out.detach().cpu().numpy(),
            ref_out.detach().cpu().numpy(),
            rtol=precision), (
                f'bilinear_grid_sample mismatch: dtype={dtype}, '
                f'align_corners={align_corners}, multiplier={multiplier}')

    def test_bilinear_grid_sample(self):
        """Run each (dtype, align_corners, multiplier) case exactly once."""
        cases = [
            # (dtype, align_corners, multiplier)
            (torch.double, False, 1),
            (torch.double, True, 1),
            (torch.float, False, 1),
            (torch.float, True, 1),
            (torch.float, True, 5),
            (torch.float, False, 10),
            (torch.float, True, -6),
            (torch.float, False, -10),
            (torch.double, True, 5),
            (torch.double, False, 10),
            (torch.double, True, -6),
            (torch.double, False, -10),
        ]
        for dtype, align_corners, multiplier in cases:
            self._test_bilinear_grid_sample(dtype, align_corners, multiplier)