mirror of https://github.com/open-mmlab/mmocr.git
Support Windows (#89)
* Fix compile error for Chinese comments on Windows
* Fix pan on Windows
* Fix types
* Update setup.py

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
pull/98/head
parent
997a89250e
commit
bba33a3f4f
|
@ -5,6 +5,7 @@
|
|||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/numpy.h"
|
||||
#include "include/pybind11/stl.h"
|
||||
|
@ -42,7 +43,7 @@ namespace panet{
|
|||
|
||||
std::queue<std::tuple<int, int, int32_t>> q;
|
||||
// 计算各个kernel的similarity_vectors
|
||||
float kernel_vector[label_num][5] = {0};
|
||||
std::vector<std::array<float, 5>> kernel_vector(label_num);
|
||||
|
||||
// 文本像素入队列
|
||||
for (int i = 0; i<h; i++)
|
||||
|
|
|
@ -27,14 +27,13 @@ namespace pse_adaptor {
|
|||
};
|
||||
|
||||
void growing_text_line(const int *data,
|
||||
vector<long int> &data_shape,
|
||||
vector<pybind11::ssize_t> &data_shape,
|
||||
const int *label_map,
|
||||
vector<long int> &label_shape,
|
||||
vector<pybind11::ssize_t> &label_shape,
|
||||
int &label_num,
|
||||
float &min_area,
|
||||
vector<vector<int>> &text_line) {
|
||||
int area[label_num + 1];
|
||||
memset(area, 0, sizeof(area));
|
||||
std::vector<int> area(label_num + 1);
|
||||
for (int x = 0; x < label_shape[0]; ++x) {
|
||||
for (int y = 0; y < label_shape[1]; ++y) {
|
||||
int label = label_map[x * label_shape[1] + y];
|
||||
|
@ -100,11 +99,11 @@ namespace pse_adaptor {
|
|||
int label_num) {
|
||||
auto buf = quad_n9.request();
|
||||
auto data = static_cast<int *>(buf.ptr);
|
||||
vector<long int> data_shape = buf.shape;
|
||||
vector<pybind11::ssize_t> data_shape = buf.shape;
|
||||
|
||||
auto buf_label_map = label_map.request();
|
||||
auto data_label_map = static_cast<int32_t *>(buf_label_map.ptr);
|
||||
vector<long int> label_map_shape = buf_label_map.shape;
|
||||
vector<pybind11::ssize_t> label_map_shape = buf_label_map.shape;
|
||||
|
||||
vector<vector<int>> text_line;
|
||||
|
||||
|
|
|
@ -97,8 +97,8 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
|
|||
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
|
||||
|
||||
// We use roi_bin_grid to sample the grid and mimic integral
|
||||
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
||||
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
||||
|
||||
// We do average (integral) pooling inside a bin
|
||||
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
||||
|
@ -215,8 +215,8 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
|
|||
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
|
||||
|
||||
// We use roi_bin_grid to sample the grid and mimic integral
|
||||
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
|
||||
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_height / pooled_height); // e.g., = 2
|
||||
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
|
||||
|
||||
// We do average (integral) pooling inside a bin
|
||||
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
|
||||
|
@ -272,7 +272,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
|
|||
auto output_size = num_rois * pooled_height * pooled_width * channels;
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
|
||||
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
if (output.numel() == 0) {
|
||||
|
@ -317,7 +317,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
|
|||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
|
||||
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
// handle possibly empty gradients
|
||||
|
|
|
@ -27,10 +27,10 @@ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
|
|||
|
||||
const T* offset_bottom_rois = bottom_rois + n * 5;
|
||||
int roi_batch_ind = offset_bottom_rois[0];
|
||||
int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
|
||||
int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
|
||||
int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
|
||||
int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
|
||||
int roi_start_w = roundf(offset_bottom_rois[1] * spatial_scale);
|
||||
int roi_start_h = roundf(offset_bottom_rois[2] * spatial_scale);
|
||||
int roi_end_w = roundf(offset_bottom_rois[3] * spatial_scale);
|
||||
int roi_end_h = roundf(offset_bottom_rois[4] * spatial_scale);
|
||||
|
||||
// Force malformed ROIs to be 1x1
|
||||
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
|
||||
|
@ -40,13 +40,13 @@ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
|
|||
T bin_size_w = static_cast<T>(roi_width)
|
||||
/ static_cast<T>(pooled_width);
|
||||
|
||||
int hstart = static_cast<int>(floor(static_cast<T>(ph)
|
||||
int hstart = static_cast<int>(floorf(static_cast<T>(ph)
|
||||
* bin_size_h));
|
||||
int wstart = static_cast<int>(floor(static_cast<T>(pw)
|
||||
int wstart = static_cast<int>(floorf(static_cast<T>(pw)
|
||||
* bin_size_w));
|
||||
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1)
|
||||
int hend = static_cast<int>(ceilf(static_cast<T>(ph + 1)
|
||||
* bin_size_h));
|
||||
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1)
|
||||
int wend = static_cast<int>(ceilf(static_cast<T>(pw + 1)
|
||||
* bin_size_w));
|
||||
|
||||
// Add roi offsets and clip to input boundaries
|
||||
|
@ -126,7 +126,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
|
|||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
|
||||
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
if (output.numel() == 0) {
|
||||
|
@ -173,7 +173,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
|
|||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
|
||||
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
// handle possibly empty gradients
|
||||
|
|
|
@ -85,10 +85,10 @@ __global__ void RROIAlignForward(
|
|||
P[6] = M[0][0]*(pw+1)+M[0][1]*(ph+1)+M[0][2];
|
||||
P[7] = M[1][0]*(pw+1)+M[1][1]*(ph+1)+M[1][2];
|
||||
|
||||
T leftMost = (max(round(min(min(P[0],P[2]),min(P[4],P[6]))),0.0));
|
||||
T rightMost= (min(round(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0));
|
||||
T topMost= (max(round(min(min(P[1],P[3]),min(P[5],P[7]))),0.0));
|
||||
T bottomMost= (min(round(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0));
|
||||
T leftMost = (max(roundf(min(min(P[0],P[2]),min(P[4],P[6]))),0.0));
|
||||
T rightMost= (min(roundf(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0));
|
||||
T topMost= (max(roundf(min(min(P[1],P[3]),min(P[5],P[7]))),0.0));
|
||||
T bottomMost= (min(roundf(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0));
|
||||
|
||||
//float maxval = 0;
|
||||
//int maxidx = -1;
|
||||
|
@ -106,10 +106,10 @@ __global__ void RROIAlignForward(
|
|||
float bin_cx = (leftMost + rightMost) / 2.0; // shift
|
||||
float bin_cy = (topMost + bottomMost) / 2.0;
|
||||
|
||||
int bin_l = (int)floor(bin_cx);
|
||||
int bin_r = (int)ceil(bin_cx);
|
||||
int bin_t = (int)floor(bin_cy);
|
||||
int bin_b = (int)ceil(bin_cy);
|
||||
int bin_l = (int)floorf(bin_cx);
|
||||
int bin_r = (int)ceilf(bin_cx);
|
||||
int bin_t = (int)floorf(bin_cy);
|
||||
int bin_b = (int)ceilf(bin_cy);
|
||||
|
||||
T lt_value = 0.0;
|
||||
if (bin_t > 0 && bin_l > 0 && bin_t < height && bin_l < width)
|
||||
|
@ -124,8 +124,8 @@ __global__ void RROIAlignForward(
|
|||
if (bin_b > 0 && bin_r > 0 && bin_b < height && bin_r < width)
|
||||
rb_value = offset_bottom_data[bin_b * width + bin_r];
|
||||
|
||||
T rx = bin_cx - floor(bin_cx);
|
||||
T ry = bin_cy - floor(bin_cy);
|
||||
T rx = bin_cx - floorf(bin_cx);
|
||||
T ry = bin_cy - floorf(bin_cy);
|
||||
|
||||
T wlt = (1.0 - rx) * (1.0 - ry);
|
||||
T wrt = rx * (1.0 - ry);
|
||||
|
@ -206,7 +206,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RROIAlign_forward_cuda(
|
|||
auto con_idx_x = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat));
|
||||
auto con_idx_y = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat));
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
|
||||
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
//const int kThreadsPerBlock = 1024;
|
||||
|
@ -276,8 +276,8 @@ __global__ void RROIAlignBackward(
|
|||
float bh = con_idx_y[index];
|
||||
//if (bh > 0.00001 && bw > 0.00001 && bw < height-1 && bw < width-1){
|
||||
|
||||
int bin_xs = int(floor(bw));
|
||||
int bin_ys = int(floor(bh));
|
||||
int bin_xs = int(floorf(bw));
|
||||
int bin_ys = int(floorf(bh));
|
||||
|
||||
float rx = bw - float(bin_xs);
|
||||
float ry = bh - float(bin_ys);
|
||||
|
@ -295,10 +295,10 @@ __global__ void RROIAlignBackward(
|
|||
//int max_x = max(min(bin_xs + 1, width - 1), 0);
|
||||
//int max_y = max(min(bin_ys + 1, height - 1), 0);
|
||||
|
||||
int min_x = (int)floor(bw);
|
||||
int max_x = (int)ceil(bw);
|
||||
int min_y = (int)floor(bh);
|
||||
int max_y = (int)ceil(bh);
|
||||
int min_x = (int)floorf(bw);
|
||||
int max_x = (int)ceilf(bw);
|
||||
int min_y = (int)floorf(bh);
|
||||
int max_y = (int)ceilf(bh);
|
||||
|
||||
T top_diff_of_bin = top_diff[index];
|
||||
|
||||
|
@ -345,7 +345,7 @@ at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad,
|
|||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
|
||||
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
|
||||
dim3 block(512);
|
||||
|
||||
// handle possibly empty gradients
|
||||
|
|
10
setup.py
10
setup.py
|
@ -1,5 +1,6 @@
|
|||
import glob
|
||||
import os
|
||||
import sys
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
import torch
|
||||
|
@ -14,6 +15,7 @@ def readme():
|
|||
|
||||
|
||||
version_file = 'mmocr/version.py'
|
||||
is_windows = sys.platform == 'win32'
|
||||
|
||||
|
||||
def get_version():
|
||||
|
@ -108,7 +110,7 @@ def get_rroi_align_extensions():
|
|||
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
|
||||
sources = main_file + source_cpu
|
||||
extension = CppExtension
|
||||
extra_compile_args = {'cxx': []}
|
||||
extra_compile_args = {'cxx': ['/utf-8']} if is_windows else {'cxx': []}
|
||||
define_macros = []
|
||||
|
||||
if torch.cuda.is_available() and CUDA_HOME is not None:
|
||||
|
@ -176,10 +178,12 @@ if __name__ == '__main__':
|
|||
ext_modules=[
|
||||
CppExtension(
|
||||
name='mmocr.models.textdet.postprocess.pan',
|
||||
sources=[cpp_root + 'pan.cpp']),
|
||||
sources=[cpp_root + 'pan.cpp'],
|
||||
extra_compile_args=(['/utf-8'] if is_windows else [])),
|
||||
CppExtension(
|
||||
name='mmocr.models.textdet.postprocess.pse',
|
||||
sources=[cpp_root + 'pse.cpp']),
|
||||
sources=[cpp_root + 'pse.cpp'],
|
||||
extra_compile_args=(['/utf-8'] if is_windows else [])),
|
||||
get_rroi_align_extensions()
|
||||
],
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
|
|
Loading…
Reference in New Issue