Add pixel group and contour expand ops (#993)

* add pixel group ops

* reformatting

* formatting&rm auto

* Add citation

* Add contour expand

* c++ linting

* Add unit tests with Tensor

* rm model.pth

* rename

* c++ linting

* c++ linting

* Rename variables
pull/1020/head
jeffreykuang 2021-05-12 10:41:22 +08:00 committed by GitHub
parent c77e95a65f
commit 2623fbf21c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 482 additions and 1 deletions

View File

@ -2,6 +2,7 @@ from .bbox import bbox_overlaps
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .contour_expand import contour_expand
from .corner_pool import CornerPool
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
@ -20,6 +21,7 @@ from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .psa_mask import PSAMask
@ -46,5 +48,5 @@ __all__ = [
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated'
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand'
]

View File

@ -0,0 +1,37 @@
import numpy as np
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['contour_expand'])
def contour_expand(kernel_mask, internal_kernel_label, min_kernel_area,
                   kernel_num):
    """Grow instance kernels outward so that foreground pixels are assigned
    to instances.

    Arguments:
        kernel_mask (np.array or Tensor): The instance kernel mask(s). The
            compiled op asserts a 3-dimensional input whose last two
            dimensions match ``internal_kernel_label``.
        internal_kernel_label (np.array or Tensor): The instance internal
            kernel label with size hxw.
        min_kernel_area (int): The minimum kernel area.
        kernel_num (int): The instance kernel number.

    Returns:
        label: Whatever the compiled ``contour_expand`` op returns. Its C++
            signature returns a nested vector of ints, one inner list per
            image row, i.e. the instance index map of size hxw.
    """
    array_like = (torch.Tensor, np.ndarray)
    assert isinstance(kernel_mask, array_like)
    assert isinstance(internal_kernel_label, array_like)
    assert isinstance(min_kernel_area, int)
    assert isinstance(kernel_num, int)

    def _to_tensor(data):
        # numpy inputs are wrapped without copying; tensors pass through.
        if isinstance(data, np.ndarray):
            return torch.from_numpy(data)
        return data

    return ext_module.contour_expand(
        _to_tensor(kernel_mask), _to_tensor(internal_kernel_label),
        min_kernel_area, kernel_num)

View File

@ -0,0 +1,111 @@
// It is modified from https://github.com/whai362/PSENet
#include <iostream>
#include <queue>
#include "pytorch_cpp_helper.hpp"
using namespace std;
// Plain pair of integer coordinates; used as the queue element while the
// kernels are being expanded.
class Point2d {
 public:
  int x = 0;
  int y = 0;

  // The default point is the origin (0, 0).
  Point2d() = default;
  // Builds a point from explicit coordinates.
  Point2d(int px, int py) : x(px), y(py) {}
};
void kernel_dilate(const uint8_t *data, IntArrayRef data_shape,
const int *label_map, int &label_num, int &min_area,
vector<vector<int>> &text_line) {
std::vector<int> area(label_num + 1);
int kernel_num = data_shape[0];
int height = data_shape[1];
int width = data_shape[2];
for (int x = 0; x < height; ++x) {
for (int y = 0; y < width; ++y) {
int label = label_map[x * width + y];
if (label == 0) continue;
area[label] += 1;
}
}
queue<Point2d> queue, next_queue;
for (int x = 0; x < height; ++x) {
vector<int> row(width);
for (int y = 0; y < width; ++y) {
int label = label_map[x * width + y];
if (label == 0) continue;
if (area[label] < min_area) continue;
Point2d point(x, y);
queue.push(point);
row[y] = label;
}
text_line.emplace_back(row);
}
int dx[] = {-1, 1, 0, 0};
int dy[] = {0, 0, -1, 1};
vector<int> kernel_step(kernel_num);
std::for_each(kernel_step.begin(), kernel_step.end(),
[=](int &k) { return k * height * width; });
for (int kernel_id = kernel_num - 2; kernel_id >= 0; --kernel_id) {
while (!queue.empty()) {
Point2d point = queue.front();
queue.pop();
int x = point.x;
int y = point.y;
int label = text_line[x][y];
bool is_edge = true;
for (int d = 0; d < 4; ++d) {
int tmp_x = x + dx[d];
int tmp_y = y + dy[d];
if (tmp_x < 0 || tmp_x >= height) continue;
if (tmp_y < 0 || tmp_y >= width) continue;
int kernel_value = data[kernel_step[kernel_id] + tmp_x * width + tmp_y];
if (kernel_value == 0) continue;
if (text_line[tmp_x][tmp_y] > 0) continue;
Point2d point(tmp_x, tmp_y);
queue.push(point);
text_line[tmp_x][tmp_y] = label;
is_edge = false;
}
if (is_edge) {
next_queue.push(point);
}
}
swap(queue, next_queue);
}
}
// Entry point exposed to Python: checks the two tensors and hands their raw
// buffers to kernel_dilate.
//
// kernel_mask            3-D uint8 tensor (read via data_ptr<uint8_t>).
// internal_kernel_label  2-D int32 tensor whose extents equal the last two
//                        dimensions of kernel_mask.
// min_kernel_area        Forwarded as kernel_dilate's min_area.
// kernel_num             Forwarded as kernel_dilate's label_num.
// Returns the row-by-row label map filled in by kernel_dilate.
std::vector<std::vector<int>> contour_expand(Tensor kernel_mask,
                                             Tensor internal_kernel_label,
                                             int min_kernel_area,
                                             int kernel_num) {
  // Raw pointer access below needs dense memory.
  kernel_mask = kernel_mask.contiguous();
  internal_kernel_label = internal_kernel_label.contiguous();

  // Shape contract (active only when asserts are compiled in).
  assert(kernel_mask.dim() == 3);
  assert(internal_kernel_label.dim() == 2);
  assert(kernel_mask.size(1) == internal_kernel_label.size(0));
  assert(kernel_mask.size(2) == internal_kernel_label.size(1));

  // CHECK_CPU_INPUT comes from pytorch_cpp_helper.hpp (not shown here);
  // presumably it rejects tensors that are not usable on the CPU path.
  CHECK_CPU_INPUT(kernel_mask);
  CHECK_CPU_INPUT(internal_kernel_label);

  auto mask_ptr = kernel_mask.data_ptr<uint8_t>();
  IntArrayRef mask_shape = kernel_mask.sizes();
  auto label_ptr = internal_kernel_label.data_ptr<int32_t>();

  std::vector<std::vector<int>> expanded;
  kernel_dilate(mask_ptr, mask_shape, label_ptr, kernel_num, min_kernel_area,
                expanded);
  return expanded;
}

View File

@ -0,0 +1,136 @@
// It is modified from https://github.com/WenmuZhou/PAN.pytorch
#include "pytorch_cpp_helper.hpp"
// Collects, for every label index, the mean score, the pixel count and the
// coordinates of its pixels.
//
// label      Row-major height x width label map, read at [y * width + x].
// score      Row-major height x width scores, same layout as `label`.
// label_num  Number of entries in the result. Pixels whose label is <= 0 or
//            >= label_num are ignored (they have no slot to be written to).
// height     Number of rows.
// width      Number of columns.
//
// Returns label_num vectors. Entry `l` is
//   [mean score, pixel count, x_0, y_0, x_1, y_1, ...]
// in row-major visiting order; an entry with no pixels stays [0, 0].
std::vector<std::vector<float>> estimate_confidence(int32_t* label,
                                                    float* score, int label_num,
                                                    int height, int width) {
  // Slot 0 accumulates the score sum, slot 1 the pixel count.
  std::vector<std::vector<float>> point_vector(
      label_num > 0 ? label_num : 0, std::vector<float>(2, 0.0f));

  for (int y = 0; y < height; y++) {
    const int32_t* label_row = label + y * width;
    const float* score_row = score + y * width;
    for (int x = 0; x < width; x++) {
      const int32_t l = label_row[x];
      // Background, or a label without a slot: writing it would run past the
      // end of point_vector.
      if (l <= 0 || l >= label_num) continue;
      std::vector<float>& entry = point_vector[l];
      entry.push_back(static_cast<float>(x));
      entry.push_back(static_cast<float>(y));
      entry[0] += score_row[x];
      entry[1] += 1;
    }
  }

  // Turn the score sums into means; empty entries are left at zero.
  for (auto& entry : point_vector) {
    if (entry[1] > 0) {
      entry[0] /= entry[1];
    }
  }
  return point_vector;
}
// Groups foreground pixels into instances by growing each kernel region over
// neighbouring pixels whose embedding is close to the kernel's mean embedding.
//
// score              h x w float scores.
// mask               h x w bool foreground mask; only masked pixels can be
//                    claimed.
// embedding          h x w x c float embeddings (channel-last, as indexed
//                    below).
// kernel_label       h x w int32 kernel labels; values <= 0 are background.
//                    Labels >= kernel_region_num have no slot and are treated
//                    as background as well.
// kernel_contour     h x w uint8; non-zero marks the kernel pixels from which
//                    growing starts.
// kernel_region_num  Number of label slots (slot 0 is the background).
// dis_threshold      A pixel joins a kernel when the Euclidean distance
//                    between its embedding and the kernel mean is below this.
//
// All tensors are read through raw pointers, so they are expected to be
// contiguous (the pixel_group wrapper takes care of that).
// Returns the per-label summary produced by estimate_confidence.
std::vector<std::vector<float>> pixel_group_cpu(
    Tensor score, Tensor mask, Tensor embedding, Tensor kernel_label,
    Tensor kernel_contour, int kernel_region_num, float dis_threshold) {
  assert(score.dim() == 2);
  assert(mask.dim() == 2);
  assert(embedding.dim() == 3);
  assert(kernel_label.dim() == 2);
  assert(kernel_contour.dim() == 2);
  const int height = score.size(0);
  const int width = score.size(1);
  assert(mask.size(0) == height && mask.size(1) == width);
  assert(embedding.size(0) == height && embedding.size(1) == width);
  assert(kernel_label.size(0) == height && kernel_label.size(1) == width);
  assert(kernel_contour.size(0) == height && kernel_contour.size(1) == width);

  const float threshold_square = dis_threshold * dis_threshold;
  auto ptr_score = score.data_ptr<float>();
  auto ptr_mask = mask.data_ptr<bool>();
  auto ptr_kernel_contour = kernel_contour.data_ptr<uint8_t>();
  auto ptr_embedding = embedding.data_ptr<float>();
  auto ptr_kernel_label = kernel_label.data_ptr<int32_t>();
  // Frontier of (row, col, label) pixels still to be expanded.
  std::queue<std::tuple<int, int, int32_t>> contour_pixels;
  const auto embedding_dim = embedding.size(2);
  // Per label: embedding sum in [0, embedding_dim), pixel count in the last
  // slot.
  std::vector<std::vector<float>> kernel_vector(
      kernel_region_num, std::vector<float>(embedding_dim + 1, 0.0f));

  // The result map starts as a copy of the kernel labels so the input tensor
  // is never modified.
  Tensor text_label;
  text_label = kernel_label.clone();
  auto ptr_text_label = text_label.data_ptr<int32_t>();

  for (int i = 0; i < height; i++) {
    auto ptr_embedding_row = ptr_embedding + i * width * embedding_dim;
    auto ptr_kernel_label_row = ptr_kernel_label + i * width;
    auto ptr_kernel_contour_row = ptr_kernel_contour + i * width;
    auto ptr_text_label_row = ptr_text_label + i * width;
    for (int j = 0; j < width; j++) {
      const int32_t label = ptr_kernel_label_row[j];
      if (label <= 0) continue;
      if (label >= kernel_region_num) {
        // No slot for this label: clear it so that neither the statistics
        // below nor the final summary index out of range.
        ptr_text_label_row[j] = 0;
        continue;
      }
      auto ptr_pixel_embedding = ptr_embedding_row + j * embedding_dim;
      std::vector<float>& stats = kernel_vector[label];
      for (int64_t d = 0; d < embedding_dim; d++) {
        stats[d] += ptr_pixel_embedding[d];
      }
      stats[embedding_dim] += 1;  // kernel pixel number
      if (ptr_kernel_contour_row[j]) {
        contour_pixels.push(std::make_tuple(i, j, label));
      }
    }
  }

  // Convert embedding sums into means. Regions without pixels keep zeros
  // instead of becoming 0/0; they are never looked up because they
  // contributed no contour pixel.
  for (auto& stats : kernel_vector) {
    const float pixel_num = stats[embedding_dim];
    if (pixel_num <= 0) continue;
    for (int64_t d = 0; d < embedding_dim; d++) {
      stats[d] /= pixel_num;
    }
  }

  const int dx[4] = {-1, 1, 0, 0};
  const int dy[4] = {0, 0, -1, 1};
  while (!contour_pixels.empty()) {
    const auto query_pixel = contour_pixels.front();
    contour_pixels.pop();
    const int y = std::get<0>(query_pixel);
    const int x = std::get<1>(query_pixel);
    const int32_t l = std::get<2>(query_pixel);
    const std::vector<float>& kernel_cv = kernel_vector[l];
    for (int idx = 0; idx < 4; idx++) {
      const int tmpy = y + dy[idx];
      const int tmpx = x + dx[idx];
      if (tmpy < 0 || tmpy >= height || tmpx < 0 || tmpx >= width) continue;
      auto ptr_text_label_row = ptr_text_label + tmpy * width;
      // Only unassigned foreground pixels can be claimed.
      if (!ptr_mask[tmpy * width + tmpx] || ptr_text_label_row[tmpx] > 0)
        continue;
      float dis = 0;
      auto ptr_pixel_embedding =
          ptr_embedding + (tmpy * width + tmpx) * embedding_dim;
      for (int64_t i = 0; i < embedding_dim; i++) {
        // Square in double, which is what pow(float, 2) promoted to, without
        // the cost of a general pow call.
        const double diff = kernel_cv[i] - ptr_pixel_embedding[i];
        dis += diff * diff;
        // ignore further computing if dis is big enough
        if (dis >= threshold_square) break;
      }
      if (dis >= threshold_square) continue;
      contour_pixels.push(std::make_tuple(tmpy, tmpx, l));
      ptr_text_label_row[tmpx] = l;
    }
  }

  return estimate_confidence(ptr_text_label, ptr_score, kernel_region_num,
                             height, width);
}
// Entry point exposed to Python: makes every tensor contiguous, runs the
// CPU-input check on it and delegates to pixel_group_cpu.
//
// CHECK_CPU_INPUT comes from pytorch_cpp_helper.hpp (not shown here);
// presumably it rejects tensors that are not usable on the CPU path.
std::vector<std::vector<float>> pixel_group(
    Tensor score, Tensor mask, Tensor embedding, Tensor kernel_label,
    Tensor kernel_contour, int kernel_region_num, float distance_threshold) {
  score = score.contiguous();
  CHECK_CPU_INPUT(score);

  mask = mask.contiguous();
  CHECK_CPU_INPUT(mask);

  embedding = embedding.contiguous();
  CHECK_CPU_INPUT(embedding);

  kernel_label = kernel_label.contiguous();
  CHECK_CPU_INPUT(kernel_label);

  kernel_contour = kernel_contour.contiguous();
  CHECK_CPU_INPUT(kernel_contour);

  return pixel_group_cpu(score, mask, embedding, kernel_label, kernel_contour,
                         kernel_region_num, distance_threshold);
}

View File

@ -112,6 +112,15 @@ Tensor softnms(Tensor boxes, Tensor scores, Tensor dets, float iou_threshold,
std::vector<std::vector<int> > nms_match(Tensor dets, float iou_threshold);
// Pixel grouping op. Declared here so it can be registered in the
// PYBIND11_MODULE block of this file; returns one float vector per label.
std::vector<std::vector<float> > pixel_group(
Tensor score, Tensor mask, Tensor embedding, Tensor kernel_label,
Tensor kernel_contour, int kernel_region_num, float distance_threshold);
// Contour expansion op. Declared here for the same registration; returns the
// expanded label map as nested int vectors.
std::vector<std::vector<int> > contour_expand(Tensor kernel_mask,
Tensor internal_kernel_label,
int min_kernel_area,
int kernel_num);
void roi_align_forward(Tensor input, Tensor rois, Tensor output,
Tensor argmax_y, Tensor argmax_x, int aligned_height,
int aligned_width, float spatial_scale,
@ -325,6 +334,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("offset"));
m.def("nms_match", &nms_match, "nms_match (CPU) ", py::arg("dets"),
py::arg("iou_threshold"));
m.def("pixel_group", &pixel_group, "pixel group (CPU) ", py::arg("score"),
py::arg("mask"), py::arg("embedding"), py::arg("kernel_label"),
py::arg("kernel_contour"), py::arg("kernel_region_label"),
py::arg("distance_threshold"));
m.def("contour_expand", &contour_expand, "contour exapnd (CPU) ",
py::arg("kernel_mask"), py::arg("internal_kernel_label"),
py::arg("min_kernel_area"), py::arg("kernel_num"));
m.def("roi_align_forward", &roi_align_forward, "roi_align forward",
py::arg("input"), py::arg("rois"), py::arg("output"),
py::arg("argmax_y"), py::arg("argmax_x"), py::arg("aligned_height"),

View File

@ -0,0 +1,54 @@
import numpy as np
import torch
from ..utils import ext_loader
ext_module = ext_loader.load_ext('_ext', ['pixel_group'])
def pixel_group(score, mask, embedding, kernel_label, kernel_contour,
                kernel_region_num, distance_threshold):
    """Group pixels into text instances, an operation widely used in text
    detection methods.

    Arguments:
        score (np.array or Tensor): The foreground score with size hxw.
        mask (np.array or Tensor): The foreground mask with size hxw.
        embedding (np.array or Tensor): The embedding with size hxwxc used
            to distinguish instances.
        kernel_label (np.array or Tensor): The instance kernel index with
            size hxw.
        kernel_contour (np.array or Tensor): The kernel contour with size
            hxw.
        kernel_region_num (int): The instance kernel region number.
        distance_threshold (float): The embedding distance threshold between
            kernel and pixel in one instance.

    Returns:
        pixel_assignment (List[List[float]]): The instance coordinate list.
            Each element consists of averaged confidence, pixel number, and
            coordinates (x_i, y_i for all pixels) in order.
    """
    array_inputs = [score, mask, embedding, kernel_label, kernel_contour]
    for data in array_inputs:
        assert isinstance(data, (torch.Tensor, np.ndarray))
    assert isinstance(kernel_region_num, int)
    assert isinstance(distance_threshold, float)

    # numpy inputs are wrapped as tensors; tensors are passed through as-is.
    tensors = [
        torch.from_numpy(data) if isinstance(data, np.ndarray) else data
        for data in array_inputs
    ]
    return ext_module.pixel_group(*tensors, kernel_region_num,
                                  distance_threshold)

View File

@ -0,0 +1,48 @@
import numpy as np
import torch
def test_contour_expand():
    """contour_expand yields the expected label map for both numpy and
    Tensor inputs."""
    from mmcv.ops import contour_expand

    rows = slice(3, 7)  # the four image rows that contain any foreground

    # Two seed kernels: label 1 on columns 2-3, label 2 on column 8.
    np_internal_kernel_label = np.zeros((10, 10), dtype=np.int32)
    np_internal_kernel_label[rows, 2:4] = 1
    np_internal_kernel_label[rows, 8] = 2

    # Larger mask joining both kernels (columns 2-8) plus the seed mask.
    np_kernel_mask1 = np.zeros((10, 10), dtype=np.uint8)
    np_kernel_mask1[rows, 2:9] = 1
    np_kernel_mask2 = (np_internal_kernel_label > 0).astype(np.uint8)
    np_kernel_mask = np.stack([np_kernel_mask1, np_kernel_mask2])

    min_area = 1
    kernel_region_num = 3

    # Expected: label 1 grows to columns 2-5, label 2 to columns 6-8.
    gt = np.zeros((10, 10), dtype=np.int32)
    gt[rows, 2:6] = 1
    gt[rows, 6:9] = 2

    # First with numpy arrays, then with the same data wrapped as Tensors.
    for convert in (lambda array: array, torch.from_numpy):
        result = contour_expand(
            convert(np_kernel_mask), convert(np_internal_kernel_label),
            min_area, kernel_region_num)
        assert np.allclose(result, gt)

View File

@ -0,0 +1,77 @@
import numpy as np
import torch
def test_pixel_group():
    """pixel_group yields the expected instances for both numpy and Tensor
    inputs."""
    from mmcv.ops import pixel_group

    rows = slice(3, 7)  # the four image rows that contain any foreground

    # Foreground band on columns 1-8 with score 0.9.
    np_score = np.zeros((10, 10), dtype=np.float32)
    np_score[rows, 1:9] = 0.9
    np_mask = (np_score > 0.5)

    # Columns 0-6 share one embedding value, columns 7-9 another.
    np_embedding = np.zeros((10, 10, 8)).astype(np.float32)
    np_embedding[:, :7] = 0.9
    np_embedding[:, 7:] = 10.0

    # Kernel 1 on columns 2-4, kernel 2 on column 8.
    np_kernel_label = np.zeros((10, 10), dtype=np.int32)
    np_kernel_label[rows, 2:5] = 1
    np_kernel_label[rows, 8] = 2

    # Contours: the border of kernel 1 (interior removed) and all of kernel 2.
    np_kernel_contour = np.zeros((10, 10), dtype=np.uint8)
    np_kernel_contour[rows, 2:5] = 1
    np_kernel_contour[4:6, 3] = 0
    np_kernel_contour[rows, 8] = 1

    kernel_region_num = 3
    distance_threshold = float(0.8)

    # Expected entries: [mean score, pixel count, x0, y0, x1, y1, ...].
    gt_1 = [0.8999997973442078, 24.0]
    gt_2 = [0.9000000357627869, 8.0]
    for y in range(3, 7):
        for x in range(1, 7):
            gt_1 += [float(x), float(y)]
        for x in range(7, 9):
            gt_2 += [float(x), float(y)]

    np_inputs = (np_score, np_mask, np_embedding, np_kernel_label,
                 np_kernel_contour)

    # First with numpy arrays, then with the same data wrapped as Tensors.
    for convert in (lambda array: array, torch.from_numpy):
        inputs = [convert(array) for array in np_inputs]
        result = pixel_group(*inputs, kernel_region_num, distance_threshold)
        assert np.allclose(result[0], [0, 0])
        assert np.allclose(result[1], gt_1)
        assert np.allclose(result[2], gt_2)