[bug] fix issues about sort_function and DB Head (#8580)
* support min_area_rect crop * add check_install * fix requirement.txt * fix check_install * add lanms-neo for drrg * fix * fix doc * fixpull/8616/head
parent
161e0ebfa6
commit
5cac747656
|
@ -308,7 +308,7 @@ void Utility::sorted_boxes(std::vector<OCRPredictResult> &ocr_result) {
|
|||
std::sort(ocr_result.begin(), ocr_result.end(), Utility::comparison_box);
|
||||
if (ocr_result.size() > 0) {
|
||||
for (int i = 0; i < ocr_result.size() - 1; i++) {
|
||||
for (int j = i; j > 0; j--) {
|
||||
for (int j = i; j >= 0; j--) {
|
||||
if (abs(ocr_result[j + 1].box[0][1] - ocr_result[j].box[0][1]) < 10 &&
|
||||
(ocr_result[j + 1].box[0][0] < ocr_result[j].box[0][0])) {
|
||||
std::swap(ocr_result[i], ocr_result[i + 1]);
|
||||
|
|
|
@ -31,7 +31,7 @@ def get_bias_attr(k):
|
|||
|
||||
|
||||
class Head(nn.Layer):
|
||||
def __init__(self, in_channels, name_list, kernel_list=[3, 2, 2], **kwargs):
|
||||
def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
|
||||
super(Head, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(
|
||||
|
@ -93,16 +93,8 @@ class DBHead(nn.Layer):
|
|||
def __init__(self, in_channels, k=50, **kwargs):
|
||||
super(DBHead, self).__init__()
|
||||
self.k = k
|
||||
binarize_name_list = [
|
||||
'conv2d_56', 'batch_norm_47', 'conv2d_transpose_0', 'batch_norm_48',
|
||||
'conv2d_transpose_1', 'binarize'
|
||||
]
|
||||
thresh_name_list = [
|
||||
'conv2d_57', 'batch_norm_49', 'conv2d_transpose_2', 'batch_norm_50',
|
||||
'conv2d_transpose_3', 'thresh'
|
||||
]
|
||||
self.binarize = Head(in_channels, binarize_name_list, **kwargs)
|
||||
self.thresh = Head(in_channels, thresh_name_list, **kwargs)
|
||||
self.binarize = Head(in_channels, **kwargs)
|
||||
self.thresh = Head(in_channels, **kwargs)
|
||||
|
||||
def step_function(self, x, y):
|
||||
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|
||||
|
|
|
@ -14,4 +14,4 @@ lxml
|
|||
premailer
|
||||
openpyxl
|
||||
attrdict
|
||||
PyMuPDF<1.21.0
|
||||
PyMuPDF<1.21.0
|
||||
|
|
|
@ -123,7 +123,7 @@ def sorted_boxes(dt_boxes):
|
|||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, 0, -1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
|
|
Loading…
Reference in New Issue