Support windows (#89)

* Fix compile error for Chinese comments in windows

* Fix pan on windows

* Fix types

* Update setup.py

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
pull/98/head
lizz 2021-04-19 19:26:00 +08:00 committed by GitHub
parent 997a89250e
commit bba33a3f4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 44 deletions

View File

@ -5,6 +5,7 @@
#include <map>
#include <algorithm>
#include <vector>
#include <array>
#include "include/pybind11/pybind11.h"
#include "include/pybind11/numpy.h"
#include "include/pybind11/stl.h"
@ -42,7 +43,7 @@ namespace panet{
std::queue<std::tuple<int, int, int32_t>> q;
// 计算各个kernel的similarity_vectors
float kernel_vector[label_num][5] = {0};
std::vector<std::array<float, 5>> kernel_vector(label_num);
// 文本像素入队列
for (int i = 0; i<h; i++)

View File

@ -27,14 +27,13 @@ namespace pse_adaptor {
};
void growing_text_line(const int *data,
vector<long int> &data_shape,
vector<pybind11::ssize_t> &data_shape,
const int *label_map,
vector<long int> &label_shape,
vector<pybind11::ssize_t> &label_shape,
int &label_num,
float &min_area,
vector<vector<int>> &text_line) {
int area[label_num + 1];
memset(area, 0, sizeof(area));
std::vector<int> area(label_num + 1);
for (int x = 0; x < label_shape[0]; ++x) {
for (int y = 0; y < label_shape[1]; ++y) {
int label = label_map[x * label_shape[1] + y];
@ -100,11 +99,11 @@ namespace pse_adaptor {
int label_num) {
auto buf = quad_n9.request();
auto data = static_cast<int *>(buf.ptr);
vector<long int> data_shape = buf.shape;
vector<pybind11::ssize_t> data_shape = buf.shape;
auto buf_label_map = label_map.request();
auto data_label_map = static_cast<int32_t *>(buf_label_map.ptr);
vector<long int> label_map_shape = buf_label_map.shape;
vector<pybind11::ssize_t> label_map_shape = buf_label_map.shape;
vector<vector<int>> text_line;

View File

@ -97,8 +97,8 @@ __global__ void RoIAlignForward(const int nthreads, const T* bottom_data,
const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
@ -215,8 +215,8 @@ __global__ void RoIAlignBackwardFeature(const int nthreads, const T* top_diff,
const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw];
// We use roi_bin_grid to sample the grid and mimic integral
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width);
int roi_bin_grid_h = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_height / pooled_height); // e.g., = 2
int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);
// We do average (integral) pooling inside a bin
const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4
@ -272,7 +272,7 @@ at::Tensor ROIAlign_forward_cuda(const at::Tensor& input,
auto output_size = num_rois * pooled_height * pooled_width * channels;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
dim3 block(512);
if (output.numel() == 0) {
@ -317,7 +317,7 @@ at::Tensor ROIAlign_backward_cuda(const at::Tensor& grad,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients

View File

@ -27,10 +27,10 @@ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
const T* offset_bottom_rois = bottom_rois + n * 5;
int roi_batch_ind = offset_bottom_rois[0];
int roi_start_w = round(offset_bottom_rois[1] * spatial_scale);
int roi_start_h = round(offset_bottom_rois[2] * spatial_scale);
int roi_end_w = round(offset_bottom_rois[3] * spatial_scale);
int roi_end_h = round(offset_bottom_rois[4] * spatial_scale);
int roi_start_w = roundf(offset_bottom_rois[1] * spatial_scale);
int roi_start_h = roundf(offset_bottom_rois[2] * spatial_scale);
int roi_end_w = roundf(offset_bottom_rois[3] * spatial_scale);
int roi_end_h = roundf(offset_bottom_rois[4] * spatial_scale);
// Force malformed ROIs to be 1x1
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
@ -40,13 +40,13 @@ __global__ void RoIPoolFForward(const int nthreads, const T* bottom_data,
T bin_size_w = static_cast<T>(roi_width)
/ static_cast<T>(pooled_width);
int hstart = static_cast<int>(floor(static_cast<T>(ph)
int hstart = static_cast<int>(floorf(static_cast<T>(ph)
* bin_size_h));
int wstart = static_cast<int>(floor(static_cast<T>(pw)
int wstart = static_cast<int>(floorf(static_cast<T>(pw)
* bin_size_w));
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1)
int hend = static_cast<int>(ceilf(static_cast<T>(ph + 1)
* bin_size_h));
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1)
int wend = static_cast<int>(ceilf(static_cast<T>(pw + 1)
* bin_size_w));
// Add roi offsets and clip to input boundaries
@ -126,7 +126,7 @@ std::tuple<at::Tensor, at::Tensor> ROIPool_forward_cuda(const at::Tensor& input,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
dim3 block(512);
if (output.numel() == 0) {
@ -173,7 +173,7 @@ at::Tensor ROIPool_backward_cuda(const at::Tensor& grad,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients

View File

@ -85,10 +85,10 @@ __global__ void RROIAlignForward(
P[6] = M[0][0]*(pw+1)+M[0][1]*(ph+1)+M[0][2];
P[7] = M[1][0]*(pw+1)+M[1][1]*(ph+1)+M[1][2];
T leftMost = (max(round(min(min(P[0],P[2]),min(P[4],P[6]))),0.0));
T rightMost= (min(round(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0));
T topMost= (max(round(min(min(P[1],P[3]),min(P[5],P[7]))),0.0));
T bottomMost= (min(round(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0));
T leftMost = (max(roundf(min(min(P[0],P[2]),min(P[4],P[6]))),0.0));
T rightMost= (min(roundf(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0));
T topMost= (max(roundf(min(min(P[1],P[3]),min(P[5],P[7]))),0.0));
T bottomMost= (min(roundf(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0));
//float maxval = 0;
//int maxidx = -1;
@ -106,10 +106,10 @@ __global__ void RROIAlignForward(
float bin_cx = (leftMost + rightMost) / 2.0; // shift
float bin_cy = (topMost + bottomMost) / 2.0;
int bin_l = (int)floor(bin_cx);
int bin_r = (int)ceil(bin_cx);
int bin_t = (int)floor(bin_cy);
int bin_b = (int)ceil(bin_cy);
int bin_l = (int)floorf(bin_cx);
int bin_r = (int)ceilf(bin_cx);
int bin_t = (int)floorf(bin_cy);
int bin_b = (int)ceilf(bin_cy);
T lt_value = 0.0;
if (bin_t > 0 && bin_l > 0 && bin_t < height && bin_l < width)
@ -124,8 +124,8 @@ __global__ void RROIAlignForward(
if (bin_b > 0 && bin_r > 0 && bin_b < height && bin_r < width)
rb_value = offset_bottom_data[bin_b * width + bin_r];
T rx = bin_cx - floor(bin_cx);
T ry = bin_cy - floor(bin_cy);
T rx = bin_cx - floorf(bin_cx);
T ry = bin_cy - floorf(bin_cy);
T wlt = (1.0 - rx) * (1.0 - ry);
T wrt = rx * (1.0 - ry);
@ -206,7 +206,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> RROIAlign_forward_cuda(
auto con_idx_x = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat));
auto con_idx_y = at::zeros({num_rois, channels, pooled_height, pooled_width}, input.options().dtype(at::kFloat));
dim3 grid(std::min(THCCeilDiv(output_size, 512L), 4096L));
dim3 grid(std::min(THCCeilDiv((long)output_size, 512L), 4096L));
dim3 block(512);
//const int kThreadsPerBlock = 1024;
@ -276,8 +276,8 @@ __global__ void RROIAlignBackward(
float bh = con_idx_y[index];
//if (bh > 0.00001 && bw > 0.00001 && bw < height-1 && bw < width-1){
int bin_xs = int(floor(bw));
int bin_ys = int(floor(bh));
int bin_xs = int(floorf(bw));
int bin_ys = int(floorf(bh));
float rx = bw - float(bin_xs);
float ry = bh - float(bin_ys);
@ -295,10 +295,10 @@ __global__ void RROIAlignBackward(
//int max_x = max(min(bin_xs + 1, width - 1), 0);
//int max_y = max(min(bin_ys + 1, height - 1), 0);
int min_x = (int)floor(bw);
int max_x = (int)ceil(bw);
int min_y = (int)floor(bh);
int max_y = (int)ceil(bh);
int min_x = (int)floorf(bw);
int max_x = (int)ceilf(bw);
int min_y = (int)floorf(bh);
int max_y = (int)ceilf(bh);
T top_diff_of_bin = top_diff[index];
@ -345,7 +345,7 @@ at::Tensor RROIAlign_backward_cuda(const at::Tensor& grad,
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 grid(std::min(THCCeilDiv(grad.numel(), 512L), 4096L));
dim3 grid(std::min(THCCeilDiv((long)grad.numel(), 512L), 4096L));
dim3 block(512);
// handle possibly empty gradients

View File

@ -1,5 +1,6 @@
import glob
import os
import sys
from setuptools import find_packages, setup
import torch
@ -14,6 +15,7 @@ def readme():
version_file = 'mmocr/version.py'
is_windows = sys.platform == 'win32'
def get_version():
@ -108,7 +110,7 @@ def get_rroi_align_extensions():
source_cuda = glob.glob(os.path.join(extensions_dir, 'cuda', '*.cu'))
sources = main_file + source_cpu
extension = CppExtension
extra_compile_args = {'cxx': []}
extra_compile_args = {'cxx': ['/utf-8']} if is_windows else {'cxx': []}
define_macros = []
if torch.cuda.is_available() and CUDA_HOME is not None:
@ -176,10 +178,12 @@ if __name__ == '__main__':
ext_modules=[
CppExtension(
name='mmocr.models.textdet.postprocess.pan',
sources=[cpp_root + 'pan.cpp']),
sources=[cpp_root + 'pan.cpp'],
extra_compile_args=(['/utf-8'] if is_windows else [])),
CppExtension(
name='mmocr.models.textdet.postprocess.pse',
sources=[cpp_root + 'pse.cpp']),
sources=[cpp_root + 'pse.cpp'],
extra_compile_args=(['/utf-8'] if is_windows else [])),
get_rroi_align_extensions()
],
cmdclass={'build_ext': BuildExtension},