mirror of https://github.com/open-mmlab/mmcv.git
[Feature] Add nondeterministic voxelization op from mmdet3d (#1783)
* add nondeterministic voxelization op * fix lint * fix lint * resolve comments * fix lintpull/1806/head^2
parent
33e14deaea
commit
b5d550f090
|
@ -166,4 +166,51 @@ __global__ void determin_voxel_num(
|
|||
}
|
||||
}
|
||||
|
||||
__global__ void nondeterministic_get_assign_pos(
|
||||
const int nthreads, const int32_t* coors_map, int32_t* pts_id,
|
||||
int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
|
||||
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
|
||||
int coors_idx = coors_map[thread_idx];
|
||||
if (coors_idx > -1) {
|
||||
int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1);
|
||||
pts_id[thread_idx] = coors_pts_pos;
|
||||
if (coors_pts_pos == 0) {
|
||||
coors_order[coors_idx] = atomicAdd(coors_count, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void nondeterministic_assign_point_voxel(
|
||||
const int nthreads, const T* points, const int32_t* coors_map,
|
||||
const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
|
||||
const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
|
||||
const int max_voxels, const int max_points, const int num_features,
|
||||
const int NDim) {
|
||||
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
|
||||
int coors_idx = coors_map[thread_idx];
|
||||
int coors_pts_pos = pts_id[thread_idx];
|
||||
if (coors_idx > -1 && coors_pts_pos < max_points) {
|
||||
int coors_pos = coors_order[coors_idx];
|
||||
if (coors_pos < max_voxels) {
|
||||
auto voxels_offset =
|
||||
voxels + (coors_pos * max_points + coors_pts_pos) * num_features;
|
||||
auto points_offset = points + thread_idx * num_features;
|
||||
for (int k = 0; k < num_features; k++) {
|
||||
voxels_offset[k] = points_offset[k];
|
||||
}
|
||||
if (coors_pts_pos == 0) {
|
||||
pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
|
||||
auto coors_offset = coors + coors_pos * NDim;
|
||||
auto coors_in_offset = coors_in + coors_idx * NDim;
|
||||
for (int k = 0; k < NDim; k++) {
|
||||
coors_offset[k] = coors_in_offset[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // VOXELIZATION_CUDA_KERNEL_CUH
|
||||
|
|
|
@ -1396,6 +1396,12 @@ int HardVoxelizeForwardCUDAKernelLauncher(
|
|||
const std::vector<float> coors_range, const int max_points,
|
||||
const int max_voxels, const int NDim = 3);
|
||||
|
||||
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
|
||||
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
|
||||
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range, const int max_points,
|
||||
const int max_voxels, const int NDim = 3);
|
||||
|
||||
void DynamicVoxelizeForwardCUDAKernelLauncher(
|
||||
const at::Tensor& points, at::Tensor& coors,
|
||||
const std::vector<float> voxel_size, const std::vector<float> coors_range,
|
||||
|
@ -1413,6 +1419,16 @@ int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels,
|
|||
max_points, max_voxels, NDim);
|
||||
};
|
||||
|
||||
int nondeterministic_hard_voxelize_forward_cuda(
|
||||
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
|
||||
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range, const int max_points,
|
||||
const int max_voxels, const int NDim) {
|
||||
return NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
|
||||
points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
|
||||
max_points, max_voxels, NDim);
|
||||
};
|
||||
|
||||
void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors,
|
||||
const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range,
|
||||
|
@ -1429,6 +1445,12 @@ int hard_voxelize_forward_impl(const at::Tensor& points, at::Tensor& voxels,
|
|||
const int max_points, const int max_voxels,
|
||||
const int NDim);
|
||||
|
||||
int nondeterministic_hard_voxelize_forward_impl(
|
||||
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
|
||||
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range, const int max_points,
|
||||
const int max_voxels, const int NDim);
|
||||
|
||||
void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
|
||||
const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range,
|
||||
|
@ -1436,6 +1458,8 @@ void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
|
|||
|
||||
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA,
|
||||
hard_voxelize_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, CUDA,
|
||||
nondeterministic_hard_voxelize_forward_cuda);
|
||||
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA,
|
||||
dynamic_voxelize_forward_cuda);
|
||||
|
||||
|
|
|
@ -145,6 +145,104 @@ int HardVoxelizeForwardCUDAKernelLauncher(
|
|||
return voxel_num_int;
|
||||
}
|
||||
|
||||
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
|
||||
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
|
||||
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range, const int max_points,
|
||||
const int max_voxels, const int NDim = 3) {
|
||||
at::cuda::CUDAGuard device_guard(points.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
const int num_points = points.size(0);
|
||||
const int num_features = points.size(1);
|
||||
|
||||
if (num_points == 0) return 0;
|
||||
|
||||
dim3 blocks(
|
||||
std::min(at::cuda::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096));
|
||||
dim3 threads(THREADS_PER_BLOCK);
|
||||
|
||||
const float voxel_x = voxel_size[0];
|
||||
const float voxel_y = voxel_size[1];
|
||||
const float voxel_z = voxel_size[2];
|
||||
const float coors_x_min = coors_range[0];
|
||||
const float coors_y_min = coors_range[1];
|
||||
const float coors_z_min = coors_range[2];
|
||||
const float coors_x_max = coors_range[3];
|
||||
const float coors_y_max = coors_range[4];
|
||||
const float coors_z_max = coors_range[5];
|
||||
|
||||
const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
|
||||
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
|
||||
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
|
||||
|
||||
// map points to voxel coors
|
||||
at::Tensor temp_coors =
|
||||
at::zeros({num_points, NDim}, points.options().dtype(at::kInt));
|
||||
|
||||
// 1. link point to corresponding voxel coors
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
points.scalar_type(), "hard_voxelize_kernel", ([&] {
|
||||
dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
|
||||
points.contiguous().data_ptr<scalar_t>(),
|
||||
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
|
||||
coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
|
||||
coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
|
||||
NDim);
|
||||
}));
|
||||
|
||||
at::Tensor coors_map;
|
||||
at::Tensor reduce_count;
|
||||
|
||||
auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);
|
||||
|
||||
std::tie(temp_coors, coors_map, reduce_count) =
|
||||
at::unique_dim(coors_clean, 0, true, true, false);
|
||||
|
||||
if (temp_coors[0][0].lt(0).item<bool>()) {
|
||||
// the first element of temp_coors is (-1,-1,-1) and should be removed
|
||||
temp_coors = temp_coors.slice(0, 1);
|
||||
coors_map = coors_map - 1;
|
||||
}
|
||||
|
||||
int num_coors = temp_coors.size(0);
|
||||
temp_coors = temp_coors.to(at::kInt);
|
||||
coors_map = coors_map.to(at::kInt);
|
||||
|
||||
at::Tensor coors_count = at::zeros({1}, coors_map.options());
|
||||
at::Tensor coors_order = at::empty({num_coors}, coors_map.options());
|
||||
at::Tensor pts_id = at::zeros({num_points}, coors_map.options());
|
||||
reduce_count = at::zeros({num_coors}, coors_map.options());
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
points.scalar_type(), "get_assign_pos", ([&] {
|
||||
nondeterministic_get_assign_pos<<<blocks, threads, 0, stream>>>(
|
||||
num_points, coors_map.contiguous().data_ptr<int32_t>(),
|
||||
pts_id.contiguous().data_ptr<int32_t>(),
|
||||
coors_count.contiguous().data_ptr<int32_t>(),
|
||||
reduce_count.contiguous().data_ptr<int32_t>(),
|
||||
coors_order.contiguous().data_ptr<int32_t>());
|
||||
}));
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(
|
||||
points.scalar_type(), "assign_point_to_voxel", ([&] {
|
||||
nondeterministic_assign_point_voxel<scalar_t>
|
||||
<<<blocks, threads, 0, stream>>>(
|
||||
num_points, points.contiguous().data_ptr<scalar_t>(),
|
||||
coors_map.contiguous().data_ptr<int32_t>(),
|
||||
pts_id.contiguous().data_ptr<int32_t>(),
|
||||
temp_coors.contiguous().data_ptr<int32_t>(),
|
||||
reduce_count.contiguous().data_ptr<int32_t>(),
|
||||
coors_order.contiguous().data_ptr<int32_t>(),
|
||||
voxels.contiguous().data_ptr<scalar_t>(),
|
||||
coors.contiguous().data_ptr<int32_t>(),
|
||||
num_points_per_voxel.contiguous().data_ptr<int32_t>(),
|
||||
max_voxels, max_points, num_features, NDim);
|
||||
}));
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return max_voxels < num_coors ? max_voxels : num_coors;
|
||||
}
|
||||
|
||||
void DynamicVoxelizeForwardCUDAKernelLauncher(
|
||||
const at::Tensor &points, at::Tensor &coors,
|
||||
const std::vector<float> voxel_size, const std::vector<float> coors_range,
|
||||
|
|
|
@ -338,7 +338,8 @@ void hard_voxelize_forward(const at::Tensor &points,
|
|||
const at::Tensor &coors_range, at::Tensor &voxels,
|
||||
at::Tensor &coors, at::Tensor &num_points_per_voxel,
|
||||
at::Tensor &voxel_num, const int max_points,
|
||||
const int max_voxels, const int NDim);
|
||||
const int max_voxels, const int NDim,
|
||||
const bool deterministic);
|
||||
|
||||
void dynamic_voxelize_forward(const at::Tensor &points,
|
||||
const at::Tensor &voxel_size,
|
||||
|
@ -756,7 +757,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||
"hard_voxelize_forward", py::arg("points"), py::arg("voxel_size"),
|
||||
py::arg("coors_range"), py::arg("voxels"), py::arg("coors"),
|
||||
py::arg("num_points_per_voxel"), py::arg("voxel_num"),
|
||||
py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"));
|
||||
py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"),
|
||||
py::arg("deterministic"));
|
||||
m.def("dynamic_voxelize_forward", &dynamic_voxelize_forward,
|
||||
"dynamic_voxelize_forward", py::arg("points"), py::arg("voxel_size"),
|
||||
py::arg("coors_range"), py::arg("coors"), py::arg("NDim"));
|
||||
|
|
|
@ -14,6 +14,17 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
|
|||
max_points, max_voxels, NDim);
|
||||
}
|
||||
|
||||
int nondeterministic_hard_voxelize_forward_impl(
|
||||
const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
|
||||
at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range, const int max_points,
|
||||
const int max_voxels, const int NDim = 3) {
|
||||
return DISPATCH_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl,
|
||||
points, voxels, coors, num_points_per_voxel,
|
||||
voxel_size, coors_range, max_points, max_voxels,
|
||||
NDim);
|
||||
}
|
||||
|
||||
void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
|
||||
const std::vector<float> voxel_size,
|
||||
const std::vector<float> coors_range,
|
||||
|
@ -27,7 +38,8 @@ void hard_voxelize_forward(const at::Tensor &points,
|
|||
const at::Tensor &coors_range, at::Tensor &voxels,
|
||||
at::Tensor &coors, at::Tensor &num_points_per_voxel,
|
||||
at::Tensor &voxel_num, const int max_points,
|
||||
const int max_voxels, const int NDim = 3) {
|
||||
const int max_voxels, const int NDim = 3,
|
||||
const bool deterministic = true) {
|
||||
int64_t *voxel_num_data = voxel_num.data_ptr<int64_t>();
|
||||
std::vector<float> voxel_size_v(
|
||||
voxel_size.data_ptr<float>(),
|
||||
|
@ -36,9 +48,15 @@ void hard_voxelize_forward(const at::Tensor &points,
|
|||
coors_range.data_ptr<float>(),
|
||||
coors_range.data_ptr<float>() + coors_range.numel());
|
||||
|
||||
*voxel_num_data = hard_voxelize_forward_impl(
|
||||
points, voxels, coors, num_points_per_voxel, voxel_size_v, coors_range_v,
|
||||
max_points, max_voxels, NDim);
|
||||
if (deterministic) {
|
||||
*voxel_num_data = hard_voxelize_forward_impl(
|
||||
points, voxels, coors, num_points_per_voxel, voxel_size_v,
|
||||
coors_range_v, max_points, max_voxels, NDim);
|
||||
} else {
|
||||
*voxel_num_data = nondeterministic_hard_voxelize_forward_impl(
|
||||
points, voxels, coors, num_points_per_voxel, voxel_size_v,
|
||||
coors_range_v, max_points, max_voxels, NDim);
|
||||
}
|
||||
}
|
||||
|
||||
void dynamic_voxelize_forward(const at::Tensor &points,
|
||||
|
|
|
@ -18,7 +18,8 @@ class _Voxelization(Function):
|
|||
voxel_size,
|
||||
coors_range,
|
||||
max_points=35,
|
||||
max_voxels=20000):
|
||||
max_voxels=20000,
|
||||
deterministic=True):
|
||||
"""Convert kitti points(N, >=3) to voxels.
|
||||
|
||||
Args:
|
||||
|
@ -34,6 +35,16 @@ class _Voxelization(Function):
|
|||
for second, 20000 is a good choice. Users should shuffle points
|
||||
before call this function because max_voxels may drop points.
|
||||
Default: 20000.
|
||||
deterministic: bool. whether to invoke the non-deterministic
|
||||
version of hard-voxelization implementations. non-deterministic
|
||||
version is considerablly fast but is not deterministic. only
|
||||
affects hard voxelization. default True. for more information
|
||||
of this argument and the implementation insights, please refer
|
||||
to the following links:
|
||||
https://github.com/open-mmlab/mmdetection3d/issues/894
|
||||
https://github.com/open-mmlab/mmdetection3d/pull/904
|
||||
it is an experimental feature and we will appreciate it if
|
||||
you could share with us the failing cases.
|
||||
|
||||
Returns:
|
||||
tuple[torch.Tensor]: tuple[torch.Tensor]: A tuple contains three
|
||||
|
@ -69,7 +80,8 @@ class _Voxelization(Function):
|
|||
voxel_num,
|
||||
max_points=max_points,
|
||||
max_voxels=max_voxels,
|
||||
NDim=3)
|
||||
NDim=3,
|
||||
deterministic=deterministic)
|
||||
# select the valid voxels
|
||||
voxels_out = voxels[:voxel_num]
|
||||
coors_out = coors[:voxel_num]
|
||||
|
@ -102,7 +114,27 @@ class Voxelization(nn.Module):
|
|||
voxel_size,
|
||||
point_cloud_range,
|
||||
max_num_points,
|
||||
max_voxels=20000):
|
||||
max_voxels=20000,
|
||||
deterministic=True):
|
||||
"""
|
||||
Args:
|
||||
voxel_size (list): list [x, y, z] size of three dimension
|
||||
point_cloud_range (list):
|
||||
[x_min, y_min, z_min, x_max, y_max, z_max]
|
||||
max_num_points (int): max number of points per voxel
|
||||
max_voxels (tuple or int): max number of voxels in
|
||||
(training, testing) time
|
||||
deterministic: bool. whether to invoke the non-deterministic
|
||||
version of hard-voxelization implementations. non-deterministic
|
||||
version is considerablly fast but is not deterministic. only
|
||||
affects hard voxelization. default True. for more information
|
||||
of this argument and the implementation insights, please refer
|
||||
to the following links:
|
||||
https://github.com/open-mmlab/mmdetection3d/issues/894
|
||||
https://github.com/open-mmlab/mmdetection3d/pull/904
|
||||
it is an experimental feature and we will appreciate it if
|
||||
you could share with us the failing cases.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.voxel_size = voxel_size
|
||||
|
@ -112,6 +144,7 @@ class Voxelization(nn.Module):
|
|||
self.max_voxels = max_voxels
|
||||
else:
|
||||
self.max_voxels = _pair(max_voxels)
|
||||
self.deterministic = deterministic
|
||||
|
||||
point_cloud_range = torch.tensor(
|
||||
point_cloud_range, dtype=torch.float32)
|
||||
|
@ -132,7 +165,8 @@ class Voxelization(nn.Module):
|
|||
max_voxels = self.max_voxels[1]
|
||||
|
||||
return voxelization(input, self.voxel_size, self.point_cloud_range,
|
||||
self.max_num_points, max_voxels)
|
||||
self.max_num_points, max_voxels,
|
||||
self.deterministic)
|
||||
|
||||
def __repr__(self):
|
||||
s = self.__class__.__name__ + '('
|
||||
|
@ -140,5 +174,6 @@ class Voxelization(nn.Module):
|
|||
s += ', point_cloud_range=' + str(self.point_cloud_range)
|
||||
s += ', max_num_points=' + str(self.max_num_points)
|
||||
s += ', max_voxels=' + str(self.max_voxels)
|
||||
s += ', deterministic=' + str(self.deterministic)
|
||||
s += ')'
|
||||
return s
|
||||
|
|
|
@ -60,3 +60,80 @@ def test_voxelization(device_type):
|
|||
assert np.all(
|
||||
points[indices] == expected_coors[i][:num_points_current_voxel])
|
||||
assert num_points_current_voxel == expected_num_points_per_voxel[i]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not torch.cuda.is_available(), reason='requires CUDA support')
|
||||
def test_voxelization_nondeterministic():
|
||||
voxel_size = [0.5, 0.5, 0.5]
|
||||
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
|
||||
|
||||
voxel_dict = np.load(
|
||||
'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item()
|
||||
points = voxel_dict['points']
|
||||
|
||||
points = torch.tensor(points)
|
||||
max_num_points = -1
|
||||
dynamic_voxelization = Voxelization(voxel_size, point_cloud_range,
|
||||
max_num_points)
|
||||
|
||||
max_num_points = 10
|
||||
max_voxels = 50
|
||||
hard_voxelization = Voxelization(
|
||||
voxel_size,
|
||||
point_cloud_range,
|
||||
max_num_points,
|
||||
max_voxels,
|
||||
deterministic=False)
|
||||
|
||||
# test hard_voxelization (non-deterministic version) on gpu
|
||||
points = torch.tensor(points).contiguous().to(device='cuda:0')
|
||||
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
|
||||
coors = coors.cpu().detach().numpy().tolist()
|
||||
voxels = voxels.cpu().detach().numpy().tolist()
|
||||
num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy().tolist()
|
||||
|
||||
coors_all = dynamic_voxelization.forward(points)
|
||||
coors_all = coors_all.cpu().detach().numpy().tolist()
|
||||
|
||||
coors_set = set([tuple(c) for c in coors])
|
||||
coors_all_set = set([tuple(c) for c in coors_all])
|
||||
|
||||
assert len(coors_set) == len(coors)
|
||||
assert len(coors_set - coors_all_set) == 0
|
||||
|
||||
points = points.cpu().detach().numpy().tolist()
|
||||
|
||||
coors_points_dict = {}
|
||||
for c, ps in zip(coors_all, points):
|
||||
if tuple(c) not in coors_points_dict:
|
||||
coors_points_dict[tuple(c)] = set()
|
||||
coors_points_dict[tuple(c)].add(tuple(ps))
|
||||
|
||||
for c, ps, n in zip(coors, voxels, num_points_per_voxel):
|
||||
ideal_voxel_points_set = coors_points_dict[tuple(c)]
|
||||
voxel_points_set = set([tuple(p) for p in ps[:n]])
|
||||
assert len(voxel_points_set) == n
|
||||
if n < max_num_points:
|
||||
assert voxel_points_set == ideal_voxel_points_set
|
||||
for p in ps[n:]:
|
||||
assert max(p) == min(p) == 0
|
||||
else:
|
||||
assert len(voxel_points_set - ideal_voxel_points_set) == 0
|
||||
|
||||
# test hard_voxelization (non-deterministic version) on gpu
|
||||
# with all input point in range
|
||||
points = torch.tensor(points).contiguous().to(device='cuda:0')[:max_voxels]
|
||||
coors_all = dynamic_voxelization.forward(points)
|
||||
valid_mask = coors_all.ge(0).all(-1)
|
||||
points = points[valid_mask]
|
||||
coors_all = coors_all[valid_mask]
|
||||
coors_all = coors_all.cpu().detach().numpy().tolist()
|
||||
|
||||
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
|
||||
coors = coors.cpu().detach().numpy().tolist()
|
||||
|
||||
coors_set = set([tuple(c) for c in coors])
|
||||
coors_all_set = set([tuple(c) for c in coors_all])
|
||||
|
||||
assert len(coors_set) == len(coors) == len(coors_all_set)
|
||||
|
|
Loading…
Reference in New Issue