Remove usage of \ ()

* Remove usage of \

Signed-off-by: lizz <lizz@sensetime.com>

* rebase

Signed-off-by: lizz <lizz@sensetime.com>

* typos

Signed-off-by: lizz <lizz@sensetime.com>

* Remove test dependency on tools/

Signed-off-by: lizz <lizz@sensetime.com>

* typo

Signed-off-by: lizz <lizz@sensetime.com>

* KIE in keywords

Signed-off-by: lizz <lizz@sensetime.com>

* some renames

Signed-off-by: lizz <lizz@sensetime.com>

* kill isort skip

Signed-off-by: lizz <lizz@sensetime.com>

* aggregation discrimination

Signed-off-by: lizz <lizz@sensetime.com>

* tiny

Signed-off-by: lizz <lizz@sensetime.com>

* fix bug: model infer on cpu

Co-authored-by: Hongbin Sun <hongbin306@gmail.com>
lizz 2021-04-06 20:16:46 +08:00 committed by GitHub
commit 44ca9c2a61 (parent cbb4ec349b)
55 changed files with 360 additions and 447 deletions

@@ -54,13 +54,13 @@ data = dict(
pipeline=train_pipeline),
val=dict(
type=dataset_type,
# select_firstk=1,
# select_first_k=1,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
# select_firstk=1,
# select_first_k=1,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))

@@ -53,13 +53,13 @@ data = dict(
pipeline=train_pipeline),
val=dict(
type=dataset_type,
# select_firstk=1,
# select_first_k=1,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
# select_firstk=1,
# select_first_k=1,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))

@@ -54,13 +54,13 @@ data = dict(
pipeline=train_pipeline),
val=dict(
type=dataset_type,
# select_firstk=1,
# select_first_k=1,
ann_file=data_root + '/instances_val.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
# select_firstk=1,
# select_first_k=1,
ann_file=data_root + '/instances_val.json',
img_prefix=data_root + '/imgs',
pipeline=test_pipeline))

@@ -84,19 +84,19 @@ data = dict(
type=dataset_type,
ann_file=data_root + '/instances_training.json',
# for debugging top k imgs
# select_firstk=200,
# select_first_k=200,
img_prefix=data_root + '/imgs',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
# select_firstk=100,
# select_first_k=100,
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + '/instances_test.json',
img_prefix=data_root + '/imgs',
# select_firstk=100,
# select_first_k=100,
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='hmean-iou')

@@ -25,8 +25,8 @@ def main():
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
# test a single image
result = model_inference(model, args.img)

@@ -30,8 +30,8 @@ def main():
model = init_detector(args.config, args.checkpoint, device=device)
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
camera = cv2.VideoCapture(args.camera_id)

@@ -71,7 +71,7 @@ master_doc = 'index'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = []
def builder_inited_handler(app):

@@ -25,6 +25,9 @@ def model_inference(model, img):
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
# process img_metas
data['img_metas'] = data['img_metas'][0].data
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
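The CPU-inference fix above works because `scatter` is only needed to move the collated batch onto a GPU; a CPU model consumes the collated data directly. A minimal sketch of the guard, assuming the mmcv (<2.0) `scatter` API:

    from mmcv.parallel import scatter

    def move_to_model_device(model, data):
        # scatter only when the model sits on a GPU; otherwise the
        # collated DataContainer batch is already usable on CPU
        if next(model.parameters()).is_cuda:
            device = next(model.parameters()).device
            data = scatter(data, [device])[0]
        return data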

@@ -48,7 +48,7 @@ def get_gt_masks(ann_infos):
infos of one image, containing following keys:
masks, masks_ignore.
Returns:
gt_masks (list[list[list[int]]]): Ground thruth masks.
gt_masks (list[list[list[int]]]): Ground truth masks.
gt_masks_ignore (list[list[list[int]]]): Ignored masks.
"""
assert utils.is_type_list(ann_infos, dict)

@@ -130,9 +130,9 @@ def eval_hmean_ic13(det_boxes,
# match one gt to one pred box.
for gt_id in range(gt_num):
for pred_id in range(pred_num):
if gt_hit[gt_id] != 0 or pred_hit[
pred_id] != 0 or gt_id in gt_ignored_index \
or pred_id in pred_ignored_index:
if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0
or gt_id in gt_ignored_index
or pred_id in pred_ignored_index):
continue
match = eval_utils.one2one_match_ic13(
gt_id, pred_id, recall_mat, precision_mat, recall_thr,

@@ -80,9 +80,9 @@ def eval_hmean_iou(pred_boxes,
for gt_id in range(gt_num):
for pred_id in range(pred_num):
if gt_hit[gt_id] != 0 or pred_hit[
pred_id] != 0 or gt_id in gt_ignored_index \
or pred_id in pred_ignored_index:
if (gt_hit[gt_id] != 0 or pred_hit[pred_id] != 0
or gt_id in gt_ignored_index
or pred_id in pred_ignored_index):
continue
if iou_mat[gt_id, pred_id] > iou_thr:
gt_hit[gt_id] = 1

@@ -29,24 +29,23 @@ def ignore_pred(pred_boxes, gt_ignored_index, gt_polys, precision_thr):
pred_points = []
pred_ignored_index = []
gt_dont_care_num = len(gt_ignored_index)
gt_ignored_num = len(gt_ignored_index)
# get detection polygons
for box_id, box in enumerate(pred_boxes):
poly = points2polygon(box)
pred_polys.append(poly)
pred_points.append(box)
if gt_dont_care_num < 1:
if gt_ignored_num < 1:
continue
# ignore the current detection box
# if its overlap with any ignored gt > precision_thr
for dont_care_box_id in gt_ignored_index:
dont_care_box = gt_polys[dont_care_box_id]
inter_area, _ = poly_intersection(poly, dont_care_box)
for ignored_box_id in gt_ignored_index:
ignored_box = gt_polys[ignored_box_id]
inter_area, _ = poly_intersection(poly, ignored_box)
area = poly.area()
precision = 0 if area == 0 \
else inter_area / area
precision = 0 if area == 0 else inter_area / area
if precision > precision_thr:
pred_ignored_index.append(box_id)
break
@@ -216,10 +215,10 @@ def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,
Args:
gt_id (int): The ground truth id index.
det_id (int): The detection result id index.
recall_mat (ndarray): gt_numxdet_num matrix with element (i,j) being
the recall ratio of gt i to det j.
precision_mat (ndarray): gt_numxdet_num matrix with element (i,j) being
the precision ratio of gt i to det j.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
Returns:
@@ -258,21 +257,21 @@ def one2one_match_ic13(gt_id, det_id, recall_mat, precision_mat, recall_thr,
def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_match_flag, det_match_flag,
det_dont_care_index):
det_ignored_index):
"""One-to-Many match gt and detections with icdar2013 standards.
Args:
gt_id (int): gt index.
recall_mat (ndarray): gt_numxdet_num matrix with element (i,j) being
the recall ratio of gt i to det j.
precision_mat (ndarray): gt_numxdet_num matrix with element (i,j) being
the precision ratio of gt i to det j.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
gt_match_flag (ndarray): An array indicates each gt matched already.
det_match_flag (ndarray): An array indicates each box has been
matched already or not.
det_dont_care_index (list): A list indicates each detection box can be
det_ignored_index (list): A list indicates each detection box can be
ignored or not.
Returns:
@@ -287,13 +286,13 @@ def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr,
assert isinstance(gt_match_flag, list)
assert isinstance(det_match_flag, list)
assert isinstance(det_dont_care_index, list)
assert isinstance(det_ignored_index, list)
many_sum = 0.
det_ids = []
for det_id in range(recall_mat.shape[1]):
if gt_match_flag[gt_id] == 0 and det_match_flag[
det_id] == 0 and det_id not in det_dont_care_index:
det_id] == 0 and det_id not in det_ignored_index:
if precision_mat[gt_id, det_id] >= precision_thr:
many_sum += recall_mat[gt_id, det_id]
det_ids.append(det_id)
@@ -304,22 +303,22 @@ def one2many_match_ic13(gt_id, recall_mat, precision_mat, recall_thr,
def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr,
precision_thr, gt_match_flag, det_match_flag,
gt_dont_care_index):
gt_ignored_index):
"""Many-to-One match gt and detections with icdar2013 standards.
Args:
det_id (int): Detection index.
recall_mat (ndarray): gt_numxdet_num matrix with element (i,j) being
the recall ratio of gt i to det j.
precision_mat (ndarray): gt_numxdet_num matrix with element (i,j) being
the precision ratio of gt i to det j.
recall_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the recall ratio of gt i to det j.
precision_mat (ndarray): `gt_num x det_num` matrix with element (i,j)
being the precision ratio of gt i to det j.
recall_thr (float): The recall threshold.
precision_thr (float): The precision threshold.
gt_match_flag (ndarray): An array indicates each gt has been matched
already.
det_match_flag (ndarray): An array indicates each detection box has
been matched already or not.
gt_dont_care_index (list): A list indicates each gt box can be ignored
gt_ignored_index (list): A list indicates each gt box can be ignored
or not.
Returns:
@@ -334,12 +333,12 @@ def many2one_match_ic13(det_id, recall_mat, precision_mat, recall_thr,
assert isinstance(gt_match_flag, list)
assert isinstance(det_match_flag, list)
assert isinstance(gt_dont_care_index, list)
assert isinstance(gt_ignored_index, list)
many_sum = 0.
gt_ids = []
for gt_id in range(recall_mat.shape[0]):
if gt_match_flag[gt_id] == 0 and det_match_flag[
det_id] == 0 and gt_id not in gt_dont_care_index:
det_id] == 0 and gt_id not in gt_ignored_index:
if recall_mat[gt_id, det_id] >= recall_thr:
many_sum += precision_mat[gt_id, det_id]
gt_ids.append(gt_id)
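For orientation, all three matchers above read the same two matrices described in their docstrings. A simplified sketch of the core pair test (the real one-to-one rule additionally requires the pair to be unique in its row and column):

    import numpy as np

    def pair_matches(recall_mat, precision_mat, gt_id, det_id,
                     recall_thr=0.8, precision_thr=0.4):
        # a (gt, det) pair qualifies when both its recall and its
        # precision ratio clear the corresponding threshold
        return bool(recall_mat[gt_id, det_id] >= recall_thr
                    and precision_mat[gt_id, det_id] >= precision_thr)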

@@ -89,7 +89,7 @@ class BaseDataset(Dataset):
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_info = self.data_infos[index]
@@ -104,7 +104,7 @@ class BaseDataset(Dataset):
idx (int): Index of data.
Returns:
dict: Testing data after pipeline with new keys introduced by \
dict: Testing data after pipeline with new keys introduced by
pipeline.
"""
return self.prepare_train_img(img_info)

@@ -21,9 +21,9 @@ class IcdarDataset(CocoDataset):
proposal_file=None,
test_mode=False,
filter_empty_gt=True,
select_firstk=-1):
select_first_k=-1):
# select first k images for fast debugging.
self.select_firstk = select_firstk
self.select_first_k = select_first_k
super().__init__(ann_file, pipeline, classes, data_root, img_prefix,
seg_prefix, proposal_file, test_mode, filter_empty_gt)
@@ -50,7 +50,7 @@ class IcdarDataset(CocoDataset):
info['filename'] = info['file_name']
data_infos.append(info)
count = count + 1
if count > self.select_firstk and self.select_firstk > 0:
if count > self.select_first_k and self.select_first_k > 0:
break
return data_infos
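The renamed `select_first_k` option is purely a debugging aid: after loading, the annotation list is cut down to the first k records. The pattern in isolation, as a hypothetical standalone helper:

    def truncate_for_debug(data_infos, select_first_k=-1):
        # keep only the first k records when k > 0; -1 disables truncation
        if select_first_k > 0:
            return data_infos[:select_first_k]
        return data_infos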

@@ -117,7 +117,7 @@ class KIEDataset(BaseDataset):
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]

@@ -28,8 +28,7 @@ class OCRSegDataset(OCRDataset):
assert utils.is_type_list(annotations, dict)
assert 'char_box' in annotations[0]
assert 'char_text' in annotations[0]
assert len(annotations[0]['char_box']) == 4 or \
len(annotations[0]['char_box']) == 8
assert len(annotations[0]['char_box']) in [4, 8]
chars, char_rects, char_quads = [], [], []
for ann in annotations:
@@ -75,7 +74,7 @@ class OCRSegDataset(OCRDataset):
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]

@@ -143,10 +143,10 @@ class EastRandomCrop:
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
padimg = np.zeros(
padded_img = np.zeros(
(self.target_size[1], self.target_size[0], img.shape[2]),
img.dtype)
padimg[:h, :w] = cv2.resize(
padded_img[:h, :w] = cv2.resize(
img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
# for bboxes
@@ -172,8 +172,8 @@ class EastRandomCrop:
if key == 'gt_masks':
results['gt_labels'] = polys_label
results['img'] = padimg
results['img_shape'] = padimg.shape
results['img'] = padded_img
results['img_shape'] = padded_img.shape
return results
@@ -229,12 +229,12 @@ class EastRandomCrop:
for points in polys:
points = np.round(
points, decimals=0).astype(np.int32).reshape(-1, 2)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
min_x = np.min(points[:, 0])
max_x = np.max(points[:, 0])
w_array[min_x:max_x] = 1
min_y = np.min(points[:, 1])
max_y = np.max(points[:, 1])
h_array[min_y:max_y] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
@@ -255,8 +255,8 @@ class EastRandomCrop:
else:
ymin, ymax = self.random_select(h_axis, h)
if xmax - xmin < self.min_crop_side_ratio * w or \
ymax - ymin < self.min_crop_side_ratio * h:
if (xmax - xmin < self.min_crop_side_ratio * w
or ymax - ymin < self.min_crop_side_ratio * h):
# area too small
continue
num_poly_in_rect = 0
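The renamed `min_x`/`max_x` bookkeeping implements EAST-style crop sampling: project every text polygon onto each axis, mark the covered rows and columns, and draw crop borders only from unmarked positions so no text instance is cut. A sketch of the marking step, assuming integer (N, 2) point arrays:

    import numpy as np

    def free_axis_positions(length, polys, axis=0):
        # mark the span each polygon covers along one axis and return
        # the uncovered positions (safe places for a crop border)
        covered = np.zeros(length, dtype=np.int32)
        for points in polys:
            lo, hi = points[:, axis].min(), points[:, axis].max()
            covered[lo:hi] = 1
        return np.where(covered == 0)[0]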

@@ -20,7 +20,7 @@ class KIEFormatBundle(DefaultFormatBundle):
- gt_bboxes_ignore: (1) to tensor, (2) to DataContainer
- gt_labels: (1) to tensor, (2) to DataContainer
- gt_masks: (1) to tensor, (2) to DataContainer (cpu_only=True)
- gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor, \
- gt_semantic_seg: (1) unsqueeze dim-0 (2) to tensor,
(3) to DataContainer (stack=True)
- relations: (1) scale, (2) to tensor, (3) to DataContainer
- texts: (1) to tensor, (2) to DataContainer
@@ -33,7 +33,7 @@ class KIEFormatBundle(DefaultFormatBundle):
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data that is formatted with \
dict: The result dict contains the data that is formatted with
default bundle.
"""
super().__call__(results)

@@ -44,9 +44,8 @@ class ResizeOCR:
assert utils.is_none_or_type(min_width, (int, tuple))
assert utils.is_none_or_type(max_width, (int, tuple))
if not keep_aspect_ratio:
assert max_width is not None, \
'"max_width" must assigned ' + \
'if "keep_aspect_ratio" is False'
assert max_width is not None, ('"max_width" must assigned '
'if "keep_aspect_ratio" is False')
assert isinstance(img_pad_value, int)
if isinstance(height, tuple):
assert isinstance(min_width, tuple)

@@ -30,28 +30,25 @@ class BaseTextDetTargets:
"""
# suppose a triangle with three edge abc with c=point_1 point_2
# a^2
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys -
point_1[1])
a_square = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
# b^2
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys -
point_2[1])
b_square = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
# c^2
square_distance = np.square(point_1[0] -
point_2[0]) + np.square(point_1[1] -
point_2[1])
# cosC=(c^2-a^2-b^2)/2(ab)
cosin = (square_distance - square_distance_1 - square_distance_2) / \
(np.finfo(np.float32).eps +
2 * np.sqrt(square_distance_1 * square_distance_2))
c_square = np.square(point_1[0] - point_2[0]) + np.square(point_1[1] -
point_2[1])
# -cosC=(c^2-a^2-b^2)/2(ab)
neg_cos_c = (
(c_square - a_square - b_square) /
(np.finfo(np.float32).eps + 2 * np.sqrt(a_square * b_square)))
# sinC^2=1-cosC^2
square_sin = 1 - np.square(cosin)
square_sin = 1 - np.square(neg_cos_c)
square_sin = np.nan_to_num(square_sin)
# distance=a*b*sinC/c=a*h/c=2*area/c
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
(np.finfo(np.float32).eps + square_distance))
# set result to minimum edge if C>pi/2
result[cosin < 0] = np.sqrt(
np.fmin(square_distance_1, square_distance_2))[cosin < 0]
result = np.sqrt(a_square * b_square * square_sin /
(np.finfo(np.float32).eps + c_square))
# set result to minimum edge if C<pi/2
result[neg_cos_c < 0] = np.sqrt(np.fmin(a_square,
b_square))[neg_cos_c < 0]
return result
def polygon_area(self, polygon):

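The renames make the geometry legible: with a and b the distances from a grid point to the two segment endpoints and c the segment length, the law of cosines gives c^2 = a^2 + b^2 - 2ab*cos(C), so the computed ratio is exactly -cos(C), and the point-to-segment distance follows as a*b*sin(C)/c = 2*area/c. A quick numeric check of that formulation:

    import numpy as np

    # distance from point (0, 1) to the segment (-1, 0)-(1, 0) should be 1
    p1, p2 = np.array([-1.0, 0.0]), np.array([1.0, 0.0])
    xs, ys = np.array([0.0]), np.array([1.0])
    a_square = np.square(xs - p1[0]) + np.square(ys - p1[1])
    b_square = np.square(xs - p2[0]) + np.square(ys - p2[1])
    c_square = np.square(p1[0] - p2[0]) + np.square(p1[1] - p2[1])
    neg_cos_c = (c_square - a_square - b_square) / (
        2 * np.sqrt(a_square * b_square))
    square_sin = 1 - np.square(neg_cos_c)
    result = np.sqrt(a_square * b_square * square_sin / c_square)
    assert np.isclose(result[0], 1.0)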
@@ -136,8 +136,9 @@ class DBNetTargets(BaseTextDetTargets):
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
distance = polygon_shape.area * \
(1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
distance = (
polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) /
polygon_shape.length)
subject = [tuple(p) for p in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND,

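The reflowed expression is the standard DB shrink offset d = A(1 - r^2)/L, with polygon area A, perimeter L, and shrink ratio r. A worked example for a 10x10 square with r = 0.4:

    area, length, shrink_ratio = 100.0, 40.0, 0.4
    distance = area * (1 - shrink_ratio**2) / length
    assert abs(distance - 2.1) < 1e-9  # each edge moves inward by 2.1 px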
@@ -90,8 +90,8 @@ class TextSnakeTargets(BaseTextDetTargets):
theta_sum = np.array(theta_sum)
head_start, tail_start = np.argsort(theta_sum)[::-1][0:2]
if abs(head_start - tail_start) < 2 \
or abs(head_start - tail_start) > 12:
if (abs(head_start - tail_start) < 2
or abs(head_start - tail_start) > 12):
tail_start = (head_start + len(points) // 2) % len(points)
head_end = (head_start + 1) % len(points)
tail_end = (tail_start + 1) % len(points)

@@ -44,8 +44,9 @@ class RandomCropInstances:
# target size is bigger than origin size
t_h = t_h if t_h < h else h
t_w = t_w if t_w < w else w
if img_gt is not None and np.random.random_sample() < \
self.positive_sample_ratio and np.max(img_gt) > 0:
if (img_gt is not None
and np.random.random_sample() < self.positive_sample_ratio
and np.max(img_gt) > 0):
# make sure to crop the positive region

@@ -58,7 +58,7 @@ class TextDetDataset(BaseDataset):
index (int): Index of data.
Returns:
dict: Training data and annotation after pipeline with new keys \
dict: Training data and annotation after pipeline with new keys
introduced by pipeline.
"""
img_ann_info = self.data_infos[index]

@@ -212,12 +212,13 @@ class DeconvModule(nn.Module):
scale_factor=2):
super().__init__()
assert (kernel_size - scale_factor >= 0) and\
(kernel_size - scale_factor) % 2 == 0,\
f'kernel_size should be greater than or equal to scale_factor '\
f'and (kernel_size - scale_factor) should be even numbers, '\
f'while the kernel size is {kernel_size} and scale_factor is '\
f'{scale_factor}.'
assert (
kernel_size - scale_factor >= 0
and (kernel_size - scale_factor) % 2 == 0), (
f'kernel_size should be greater than or equal to scale_factor '
f'and (kernel_size - scale_factor) should be even numbers, '
f'while the kernel size is {kernel_size} and scale_factor is '
f'{scale_factor}.')
stride = scale_factor
padding = (kernel_size - scale_factor) // 2
@@ -394,36 +395,36 @@ class UNet(nn.Module):
super().__init__()
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
assert len(strides) == num_stages, \
'The length of strides should be equal to num_stages, '\
f'while the strides is {strides}, the length of '\
f'strides is {len(strides)}, and the num_stages is '\
f'{num_stages}.'
assert len(enc_num_convs) == num_stages, \
'The length of enc_num_convs should be equal to num_stages, '\
f'while the enc_num_convs is {enc_num_convs}, the length of '\
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '\
f'{num_stages}.'
assert len(dec_num_convs) == (num_stages-1), \
'The length of dec_num_convs should be equal to (num_stages-1), '\
f'while the dec_num_convs is {dec_num_convs}, the length of '\
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '\
f'{num_stages}.'
assert len(downsamples) == (num_stages-1), \
'The length of downsamples should be equal to (num_stages-1), '\
f'while the downsamples is {downsamples}, the length of '\
f'downsamples is {len(downsamples)}, and the num_stages is '\
f'{num_stages}.'
assert len(enc_dilations) == num_stages, \
'The length of enc_dilations should be equal to num_stages, '\
f'while the enc_dilations is {enc_dilations}, the length of '\
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '\
f'{num_stages}.'
assert len(dec_dilations) == (num_stages-1), \
'The length of dec_dilations should be equal to (num_stages-1), '\
f'while the dec_dilations is {dec_dilations}, the length of '\
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '\
f'{num_stages}.'
assert len(strides) == num_stages, (
'The length of strides should be equal to num_stages, '
f'while the strides is {strides}, the length of '
f'strides is {len(strides)}, and the num_stages is '
f'{num_stages}.')
assert len(enc_num_convs) == num_stages, (
'The length of enc_num_convs should be equal to num_stages, '
f'while the enc_num_convs is {enc_num_convs}, the length of '
f'enc_num_convs is {len(enc_num_convs)}, and the num_stages is '
f'{num_stages}.')
assert len(dec_num_convs) == (num_stages - 1), (
'The length of dec_num_convs should be equal to (num_stages-1), '
f'while the dec_num_convs is {dec_num_convs}, the length of '
f'dec_num_convs is {len(dec_num_convs)}, and the num_stages is '
f'{num_stages}.')
assert len(downsamples) == (num_stages - 1), (
'The length of downsamples should be equal to (num_stages-1), '
f'while the downsamples is {downsamples}, the length of '
f'downsamples is {len(downsamples)}, and the num_stages is '
f'{num_stages}.')
assert len(enc_dilations) == num_stages, (
'The length of enc_dilations should be equal to num_stages, '
f'while the enc_dilations is {enc_dilations}, the length of '
f'enc_dilations is {len(enc_dilations)}, and the num_stages is '
f'{num_stages}.')
assert len(dec_dilations) == (num_stages - 1), (
'The length of dec_dilations should be equal to (num_stages-1), '
f'while the dec_dilations is {dec_dilations}, the length of '
f'dec_dilations is {len(dec_dilations)}, and the num_stages is '
f'{num_stages}.')
self.num_stages = num_stages
self.strides = strides
self.downsamples = downsamples
@@ -501,12 +502,12 @@ class UNet(nn.Module):
for i in range(1, self.num_stages):
if self.strides[i] == 2 or self.downsamples[i - 1]:
whole_downsample_rate *= 2
assert (h % whole_downsample_rate == 0) \
and (w % whole_downsample_rate == 0),\
f'The input image size {(h, w)} should be divisible by the whole '\
f'downsample rate {whole_downsample_rate}, when num_stages is '\
f'{self.num_stages}, strides is {self.strides}, and downsamples '\
f'is {self.downsamples}.'
assert (
h % whole_downsample_rate == 0 and w % whole_downsample_rate == 0
), (f'The input image size {(h, w)} should be divisible by the whole '
f'downsample rate {whole_downsample_rate}, when num_stages is '
f'{self.num_stages}, strides is {self.strides}, and downsamples '
f'is {self.downsamples}.')
def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

@@ -141,10 +141,10 @@ class Block(nn.Module):
self.pos_norm = pos_norm
# Modules
self.linear0 = nn.Linear(input_dims[0], mm_dim)
self.linear1 = self.linear0 if shared \
else nn.Linear(input_dims[1], mm_dim)
self.merge_linears0, self.merge_linears1 =\
nn.ModuleList(), nn.ModuleList()
self.linear1 = (
self.linear0 if shared else nn.Linear(input_dims[1], mm_dim))
self.merge_linears0 = nn.ModuleList()
self.merge_linears1 = nn.ModuleList()
self.chunks = self.chunk_sizes(mm_dim, chunks)
for size in self.chunks:
ml0 = nn.Linear(size, size * rank)

@@ -1,10 +1,10 @@
from .single_stage_text_detector import SingleStageTextDetector # isort:skip
from .text_detector_mixin import TextDetectorMixin # isort:skip
from .dbnet import DBNet # isort:skip
from .ocr_mask_rcnn import OCRMaskRCNN # isort:skip
from .panet import PANet # isort:skip
from .psenet import PSENet # isort:skip
from .textsnake import TextSnake # isort:skip
from .dbnet import DBNet
from .ocr_mask_rcnn import OCRMaskRCNN
from .panet import PANet
from .psenet import PSENet
from .single_stage_text_detector import SingleStageTextDetector
from .text_detector_mixin import TextDetectorMixin
from .textsnake import TextSnake
__all__ = [
'TextDetectorMixin', 'SingleStageTextDetector', 'OCRMaskRCNN', 'DBNet',

@@ -1,5 +1,8 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
@DETECTORS.register_module()

@@ -1,6 +1,7 @@
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors import MaskRCNN
from . import TextDetectorMixin
from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
@DETECTORS.register_module()

@@ -1,5 +1,8 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
@DETECTORS.register_module()

@@ -1,5 +1,8 @@
from mmdet.models.builder import DETECTORS
from . import SingleStageTextDetector, TextDetectorMixin
from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageTextDetector
from mmocr.models.textdet.detectors.text_detector_mixin import \
TextDetectorMixin
@DETECTORS.register_module()

@@ -23,8 +23,8 @@ class PANLoss(nn.Module):
def __init__(self,
alpha=0.5,
beta=0.25,
delta_aggr=0.5,
delta_discr=3,
delta_aggregation=0.5,
delta_discrimination=3,
ohem_ratio=3,
reduction='mean',
speedup_bbox_thr=-1):
@@ -33,20 +33,19 @@
Args:
alpha (float): The kernel loss coef.
beta (float): The aggregation and discriminative loss coef.
delta_aggr (float): The constant for aggregation loss.
delta_discr (float): The constant for discriminative loss.
delta_aggregation (float): The constant for aggregation loss.
delta_discrimination (float): The constant for discriminative loss.
ohem_ratio (float): The negative/positive ratio in ohem.
reduction (str): The way to reduce the loss.
speedup_bbox_thr (int): Speed up if speedup_bbox_thr >0
and <bbox num.
"""
super().__init__()
assert reduction in ['mean',
'sum'], " reduction must in ['mean','sum']"
assert reduction in ['mean', 'sum'], "reduction must in ['mean','sum']"
self.alpha = alpha
self.beta = beta
self.delta_aggr = delta_aggr
self.delta_discr = delta_discr
self.delta_aggregation = delta_aggregation
self.delta_discrimination = delta_discrimination
self.ohem_ratio = ohem_ratio
self.reduction = reduction
self.speedup_bbox_thr = speedup_bbox_thr
@@ -120,9 +119,8 @@ class PANLoss(nn.Module):
gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
gt[key] = [item.to(preds.device) for item in gt[key]]
loss_aggrs, loss_discrs = self.aggr_discr_loss(gt['gt_kernels'][0],
gt['gt_kernels'][1],
inst_embed)
loss_aggrs, loss_discrs = self.aggregation_discrimination_loss(
gt['gt_kernels'][0], gt['gt_kernels'][1], inst_embed)
# compute text loss
sampled_mask = self.ohem_batch(pred_texts.detach(),
gt['gt_kernels'][0], gt['gt_mask'][0])
@@ -152,18 +150,19 @@ class PANLoss(nn.Module):
results.update(
loss_text=losses[0],
loss_kernel=losses[1],
loss_aggr=losses[2],
loss_discr=losses[3])
loss_aggregation=losses[2],
loss_discrimination=losses[3])
return results
def aggr_discr_loss(self, gt_texts, gt_kernels, inst_embeds):
def aggregation_discrimination_loss(self, gt_texts, gt_kernels,
inst_embeds):
"""Compute the aggregation and discrimnative losses.
Args:
gt_texts (tensor): The ground truth text mask of size Nx1xHxW.
gt_kernels (tensor): The ground truth text kernel mask of
size Nx1xHxW.
inst_embeds(tensor): The text instance emebdding tensor
inst_embeds(tensor): The text instance embedding tensor
of size Nx4xHxW.
Returns:
@@ -205,9 +204,9 @@ class PANLoss(nn.Module):
kernel_avgs.append(avg)
embed_i = embed[:, text == i] # 0.6ms
# ||F(p) - G(K_i)|| - delta_aggr, shape: nums
# ||F(p) - G(K_i)|| - delta_aggregation, shape: nums
distance = (embed_i - avg.reshape(4, 1)).norm( # 0.5ms
2, dim=0) - self.delta_aggr
2, dim=0) - self.delta_aggregation
# compute D(p,K_i) in Eq (2)
hinge = torch.max(
distance,
@@ -227,8 +226,9 @@ class PANLoss(nn.Module):
loss_discr_img = 0
for avg_i, avg_j in itertools.combinations(kernel_avgs, 2):
# delta_discr - ||G(K_i) - G(K_j)||
distance_ij = self.delta_discr - (avg_i - avg_j).norm(2)
# delta_discrimination - ||G(K_i) - G(K_j)||
distance_ij = self.delta_discrimination - (avg_i -
avg_j).norm(2)
# D(K_i,K_j)
D_ij = torch.max(
distance_ij,

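Spelled out, the two renamed margins act in opposite directions: delta_aggregation caps how far a pixel embedding F(p) may drift from its kernel mean G(K_i) before being penalized, while delta_discrimination is the minimum separation demanded between kernel means. A minimal sketch of the two hinge terms (the full loss further applies a log transform and averaging, omitted here):

    import torch

    def aggregation_hinge(pixel_embeds, kernel_avg, delta_aggregation=0.5):
        # pixel_embeds: (C, n) embeddings F(p) of pixels in text i
        # kernel_avg: (C,) kernel mean G(K_i)
        dist = (pixel_embeds - kernel_avg.reshape(-1, 1)).norm(2, dim=0)
        return torch.clamp(dist - delta_aggregation, min=0) ** 2

    def discrimination_hinge(avg_i, avg_j, delta_discrimination=3.0):
        # delta_disc - || G(K_i) - G(K_j) ||, clamped at zero, then squared
        gap = delta_discrimination - (avg_i - avg_j).norm(2)
        return torch.clamp(gap, min=0) ** 2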
@@ -57,7 +57,7 @@ def pan_decode(preds,
text_score = preds[0].astype(np.float32)
text = preds[0] > min_text_confidence
kernel = (preds[1] > min_kernel_confidence) * text
embeddings = preds[2:].transpose((1, 2, 0)) # hxwx4
embeddings = preds[2:].transpose((1, 2, 0)) # (h, w, 4)
region_num, labels = cv2.connectedComponents(
kernel.astype(np.uint8), connectivity=4)

@@ -152,12 +152,15 @@ class ParallelSARDecoder(BaseDecoder):
return y
def forward_train(self, feat, out_enc, targets_dict, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
if img_metas is not None:
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
targets = targets_dict['padded_targets'].to(feat.device)
tgt_embedding = self.embedding(targets)
@@ -173,12 +176,15 @@ class ParallelSARDecoder(BaseDecoder):
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
def forward_test(self, feat, out_enc, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
if img_metas is not None:
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
seq_len = self.max_seq_len
@@ -348,12 +354,15 @@ class SequentialSARDecoder(BaseDecoder):
return y, hx1, hx1, hx2, hx2
def forward_train(self, feat, out_enc, targets_dict, img_metas=None):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
if img_metas is not None:
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
if self.train_mode:
targets = targets_dict['padded_targets'].to(feat.device)
@@ -401,7 +410,8 @@ class SequentialSARDecoder(BaseDecoder):
return outputs
def forward_test(self, feat, out_enc, img_metas):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
if img_metas is not None:
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
return self.forward_train(feat, out_enc, None, img_metas)
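All four SAR forward paths adopt the same guard so the modules still run when no metas are supplied (plain-tensor inference); the repeated snippet condenses to one helper, shown here only as a hypothetical refactoring:

    def get_valid_ratios(img_metas, mask=True):
        # no metas (or masking disabled) -> no per-image valid ratios
        if img_metas is None or not mask:
            return None
        return [meta.get('valid_ratio', 1.0) for meta in img_metas]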

@@ -72,12 +72,15 @@ class SAREncoder(BaseEncoder):
uniform_init(m)
def forward(self, feat, img_metas=None):
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
if img_metas is not None:
assert utils.is_type_list(img_metas, dict)
assert len(img_metas) == feat.size(0)
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
valid_ratios = None
if img_metas is not None:
valid_ratios = [
img_meta.get('valid_ratio', 1.0) for img_meta in img_metas
] if self.mask else None
h_feat = feat.size(2)
feat_v = F.max_pool2d(

@@ -77,9 +77,9 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
if isinstance(imgs, list):
assert len(imgs) == len(img_metas)
assert len(imgs) > 0
assert imgs[0].size(0) == 1, 'aug test does not support ' \
'inference with batch size ' \
f'{imgs[0].size(0)}'
assert imgs[0].size(0) == 1, ('aug test does not support '
f'inference with batch size '
f'{imgs[0].size(0)}')
return self.aug_test(imgs, img_metas, **kwargs)
return self.simple_test(imgs, img_metas, **kwargs)
@@ -105,8 +105,8 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
losses and other necessary infomation.
Returns:
tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor \
which may be a weighted sum of all losses, log_vars contains \
tuple[tensor, dict]: (loss, log_vars), loss is the loss tensor
which may be a weighted sum of all losses, log_vars contains
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
@@ -148,15 +148,15 @@ class BaseRecognizer(nn.Module, metaclass=ABCMeta):
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
``num_samples``.
- ``loss`` is a tensor for back propagation, which is a \
- ``loss`` is a tensor for back propagation, which is a
weighted sum of multiple losses.
- ``log_vars`` contains all the variables to be sent to the
logger.
- ``num_samples`` indicates the batch size used for \
averaging the logs (Note: for the \
- ``num_samples`` indicates the batch size used for
averaging the logs (Note: for the
DDP model, num_samples refers to the batch size for each GPU).
"""
losses = self(**data)

@@ -135,8 +135,8 @@ class EncodeDecodeRecognizer(BaseRecognizer):
out_dec = self.decoder(
feat, out_enc, None, img_metas, train_mode=False)
label_indexes, label_scores = \
self.label_convertor.tensor2idx(out_dec, img_metas)
label_indexes, label_scores = self.label_convertor.tensor2idx(
out_dec, img_metas)
label_strings = self.label_convertor.idx2str(label_indexes)
# flatten batch results

@@ -4,10 +4,12 @@ from mmdet.utils import get_root_logger
from .check_argument import (equal_len, is_2dlist, is_3dlist, is_ndarray_list,
is_none_or_type, is_type_list, valid_boundary)
from .collect_env import collect_env
from .img_util import drop_orientation
from .lmdb_util import lmdb_converter
__all__ = [
'Registry', 'build_from_cfg', 'get_root_logger', 'collect_env',
'is_3dlist', 'is_ndarray_list', 'is_type_list', 'is_none_or_type',
'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter'
'equal_len', 'is_2dlist', 'valid_boundary', 'lmdb_converter',
'drop_orientation'
]

@@ -0,0 +1,34 @@
import os
import mmcv
def drop_orientation(img_file):
"""Check if the image has orientation information. If yes, ignore it by
converting the image format to png, and return new filename, otherwise
return the original filename.
Args:
img_file(str): The image path
Returns:
The converted image filename with proper postfix
"""
assert isinstance(img_file, str)
assert img_file
# read imgs with ignoring orientations
img = mmcv.imread(img_file, 'unchanged')
# read imgs with orientations as dataloader does when training and testing
img_color = mmcv.imread(img_file, 'color')
# make sure imgs have no orientation info, or annotation gt is wrong.
if img.shape[:2] == img_color.shape[:2]:
return img_file
target_file = os.path.splitext(img_file)[0] + '.png'
# read img with ignoring orientation information
img = mmcv.imread(img_file, 'unchanged')
mmcv.imwrite(img, target_file)
os.remove(img_file)
print(f'{img_file} has orientation info. Ignore it by converting to png')
return target_file
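A usage sketch for the relocated helper (the image paths are hypothetical):

    from mmocr.utils import drop_orientation

    # normalize EXIF-rotated JPEGs to PNG before generating annotations
    img_files = ['data/imgs/0001.jpg', 'data/imgs/0002.jpg']
    img_files = [drop_orientation(f) for f in img_files]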

@@ -6,9 +6,9 @@ from pathlib import Path
import lmdb
def lmdb_converter(imglist, output, batch_size=1000, coding='utf-8'):
# read imglist
with open(imglist) as f:
def lmdb_converter(img_list, output, batch_size=1000, coding='utf-8'):
# read img_list
with open(img_list) as f:
lines = f.readlines()
# create lmdb database
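The renamed converter commits records in batches of `batch_size` to keep LMDB transactions bounded; a generic sketch of that batched-write pattern with the `lmdb` package (an illustration, not the converter's exact body):

    import lmdb

    def write_in_batches(pairs, output, batch_size=1000):
        # pairs: iterable of (key, value) strings; flush every batch_size puts
        env = lmdb.open(output, map_size=1099511627776)
        cache = {}
        for i, (key, value) in enumerate(pairs, 1):
            cache[key.encode('utf-8')] = value.encode('utf-8')
            if i % batch_size == 0:
                with env.begin(write=True) as txn:
                    for k, v in cache.items():
                        txn.put(k, v)
                cache = {}
        with env.begin(write=True) as txn:
            for k, v in cache.items():
                txn.put(k, v)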

@@ -148,7 +148,7 @@ if __name__ == '__main__':
version=get_version(),
description='Text Detection, OCR, and NLP Toolbox',
long_description=readme(),
keywords='Text Detection, OCR, NLP',
keywords='Text Detection, OCR, KIE, NLP',
url='https://github.com/jeffreykuang/mmocr',
packages=find_packages(exclude=('configs', 'tools', 'demo')),
package_data={'mmocr.ops': ['*/*.so']},

@@ -31,12 +31,12 @@ def test_model_inference():
else:
print(f'Using existing checkpoint {checkpoint_file}')
device = 'cuda:0'
device = 'cpu'
model = init_detector(
config_file, checkpoint=checkpoint_file, device=device)
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
img = os.path.join(project_dir, '../demo/demo_text_recog.jpg')

@@ -97,7 +97,7 @@ def test_icdar_dataset():
dataset = IcdarDataset(ann_file=fake_json_file, pipeline=[])
assert dataset.CLASSES == ('text')
assert dataset.img_ids == [0, 1]
assert dataset.select_firstk == -1
assert dataset.select_first_k == -1
# test _parse_ann_info
ann = dataset.get_ann_info(0)

@@ -21,8 +21,10 @@ def test_line_str_parser():
# test get_item
parser = LineStrParser(keys, keys_idx, separator)
assert parser.get_item(data_ret, 0) == \
{'filename': 'sample1.jpg', 'text': 'hello'}
assert parser.get_item(data_ret, 0) == {
'filename': 'sample1.jpg',
'text': 'hello'
}
with pytest.raises(Exception):
parser = LineStrParser(['filename', 'text', 'ignore'], [0, 1, 2],
@@ -51,8 +53,10 @@ def test_line_dict_parser():
# test get_item
parser = LineJsonParser(keys)
assert parser.get_item(data_ret, 0) == \
{'filename': 'sample1.jpg', 'text': 'hello'}
assert parser.get_item(data_ret, 0) == {
'filename': 'sample1.jpg',
'text': 'hello'
}
with pytest.raises(Exception):
parser = LineJsonParser(['img_name', 'text'])

@@ -23,8 +23,6 @@ def test_sar_encoder():
encoder.train()
feat = torch.randn(1, 512, 4, 40)
with pytest.raises(AssertionError):
encoder(feat)
img_metas = [{'valid_ratio': 1.0}]
with pytest.raises(AssertionError):
encoder(feat, img_metas * 2)

@@ -3,16 +3,16 @@
import shutil
import tempfile
from mmocr.utils import drop_orientation
def test_check_ignore_orientation():
from tools.data.utils.common \
import check_ignore_orientation
def test_drop_orientation():
img_file = 'tests/data/test_img2.jpg'
output_file = check_ignore_orientation(img_file)
output_file = drop_orientation(img_file)
assert output_file is img_file
img_file = 'tests/data/test_img1.jpg'
tmp_dir = tempfile.TemporaryDirectory()
dst_file = shutil.copy(img_file, tmp_dir.name)
output_file = check_ignore_orientation(dst_file)
output_file = drop_orientation(dst_file)
assert output_file[-3:] == 'png'

@@ -1,6 +1,5 @@
import argparse
import glob
import os
import os.path as osp
import xml.etree.ElementTree as ET
from functools import partial
@@ -8,56 +7,9 @@ from functools import partial
import mmcv
import numpy as np
from shapely.geometry import Polygon
from tools.data.utils.common import convert_annotations, is_not_png
def check_ignore_orientation(img_file):
"""Check if the image has orientation information.
If yes, ignore it by converting the image format to png, otherwise return
the original filename.
Args:
img_file(str): The image path
Returns:
The converted image filename with proper postfix
"""
assert isinstance(img_file, str)
assert img_file
# read imgs with ignoring orientations
img = mmcv.imread(img_file, 'unchanged')
# read imgs with orientations as dataloader does when training and testing
img_color = mmcv.imread(img_file, 'color')
# make sure imgs have no orientations info, or annotation gt is wrong.
if img.shape[:2] == img_color.shape[:2]:
return img_file
else:
target_file = osp.splitext(img_file)[0] + '.png'
# read img with ignoring orientation information
img = mmcv.imread(img_file, 'unchanged')
mmcv.imwrite(img, target_file)
os.remove(img_file)
print(
f'{img_file} has orientation info. Ingore it by converting to png')
return target_file
def is_not_png(img_file):
"""Check img_file is not png image.
Args:
img_file(str): The input image file name
Returns:
The bool flag indicating whether it is not png
"""
assert isinstance(img_file, str)
assert img_file
suffix = osp.splitext(img_file)[1]
return (suffix not in ['.PNG', '.png'])
from mmocr.utils import drop_orientation
def collect_files(img_dir, gt_dir, split):
@@ -79,14 +31,13 @@ def collect_files(img_dir, gt_dir, split):
# note that we handle png and jpg only. Pls convert others such as gif to
# jpg or png offline
suffixes = ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG']
# suffixes = ['.png']
imgs_list = []
for suffix in suffixes:
imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix)))
imgs_list = [
check_ignore_orientation(f) if is_not_png(f) else f for f in imgs_list
drop_orientation(f) if is_not_png(f) else f for f in imgs_list
]
files = []
@@ -152,8 +103,8 @@ def load_txt_info(gt_file, img_info):
iscrowd = 0
area = polygon.area
# convert to COCO style XYWH format
minx, miny, maxx, maxy = polygon.bounds
bbox = [minx, miny, maxx - minx, maxy - miny]
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
anno = dict(
iscrowd=iscrowd,
@@ -245,52 +196,6 @@ def load_img_info(files, split):
return img_info
def convert_annotations(image_infos, out_json_name):
"""Convert the annotation into coco style.
Args:
image_infos(list): The list of image information dicts
out_json_name(str): The output json filename
Returns:
out_json(dict): The coco style dict
"""
assert isinstance(image_infos, list)
assert isinstance(out_json_name, str)
assert out_json_name
out_json = dict()
img_id = 0
ann_id = 0
out_json['images'] = []
out_json['categories'] = []
out_json['annotations'] = []
for image_info in image_infos:
image_info['id'] = img_id
anno_infos = image_info.pop('anno_info')
out_json['images'].append(image_info)
for anno_info in anno_infos:
anno_info['image_id'] = img_id
anno_info['id'] = ann_id
out_json['annotations'].append(anno_info)
ann_id += 1
# if image_info['file_name'].find('png'):
# img = mmcv.imread('data/ctw1500/imgs/'+
# image_info['file_name'], 'color')
# show_img_boundary(img, anno_info['segmentation'] )
img_id += 1
print(img_id)
cat = dict(id=1, name='text')
out_json['categories'].append(cat)
if len(out_json['annotations']) == 0:
out_json.pop('annotations')
mmcv.dump(out_json, out_json_name)
return out_json
def parse_args():
parser = argparse.ArgumentParser(
description='Convert ctw1500 annotations to COCO format')

@@ -6,8 +6,9 @@ from functools import partial
import mmcv
import numpy as np
from shapely.geometry import Polygon
from tools.data.utils.common import (check_ignore_orientation,
convert_annotations, is_not_png)
from tools.data.utils.common import convert_annotations, is_not_png
from mmocr.utils import drop_orientation
def collect_files(img_dir, gt_dir):
@@ -33,7 +34,7 @@ def collect_files(img_dir, gt_dir):
imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix)))
imgs_list = [
check_ignore_orientation(f) if is_not_png(f) else f for f in imgs_list
drop_orientation(f) if is_not_png(f) else f for f in imgs_list
]
files = []
@@ -124,8 +125,8 @@ def load_img_info(files, dataset):
area = polygon.area
# convert to COCO style XYWH format
minx, miny, maxx, maxy = polygon.bounds
bbox = [minx, miny, maxx - minx, maxy - miny]
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
anno = dict(
iscrowd=iscrowd,

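The min_x/min_y renames touch the same idiom in every converter: Shapely's `bounds` returns (min_x, min_y, max_x, max_y), which maps directly to COCO's [x, y, width, height]. For example:

    from shapely.geometry import Polygon

    poly = Polygon([(1, 2), (5, 2), (5, 8), (1, 8)])
    min_x, min_y, max_x, max_y = poly.bounds
    bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
    assert bbox == [1.0, 2.0, 4.0, 6.0]  # COCO XYWH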
@@ -11,8 +11,6 @@ from shapely.geometry import Polygon
from mmocr.utils import check_argument
# from mmocr.core.mask import imshow_text_char_boundary
def trace_boundary(char_boxes):
"""Trace the boundary point of text.
@@ -41,23 +39,23 @@ def trace_boundary(char_boxes):
return boundary
def match_bbox_char_str(bboxes, charbboxes, strs):
def match_bbox_char_str(bboxes, char_bboxes, strs):
"""match the bboxes, char bboxes, and strs.
Args:
bboxes (ndarray): The text boxes of size 2x4xnum_box.
charbboxes (ndarray): The char boxes of size 2x4xnum_char_box.
bboxes (ndarray): The text boxes of size (2, 4, num_box).
char_bboxes (ndarray): The char boxes of size (2, 4, num_char_box).
strs (ndarray): The string of size (num_strs,)
"""
assert isinstance(bboxes, np.ndarray)
assert isinstance(charbboxes, np.ndarray)
assert isinstance(char_bboxes, np.ndarray)
assert isinstance(strs, np.ndarray)
bboxes = bboxes.astype(np.int32)
charbboxes = charbboxes.astype(np.int32)
char_bboxes = char_bboxes.astype(np.int32)
if len(charbboxes.shape) == 2:
charbboxes = np.expand_dims(charbboxes, axis=2)
charbboxes = np.transpose(charbboxes, (2, 1, 0))
if len(char_bboxes.shape) == 2:
char_bboxes = np.expand_dims(char_bboxes, axis=2)
char_bboxes = np.transpose(char_bboxes, (2, 1, 0))
if len(bboxes.shape) == 2:
bboxes = np.expand_dims(bboxes, axis=2)
bboxes = np.transpose(bboxes, (2, 1, 0))
@@ -81,7 +79,7 @@ def match_bbox_char_str(bboxes, charbboxes, strs):
for char_inx in range(start_inx, end_inx):
poly_char_idx_list[word_inx].append(char_inx)
poly_char_list[word_inx].append(chars[char_inx])
poly_charbox_list[word_inx].append(charbboxes[char_inx])
poly_charbox_list[word_inx].append(char_bboxes[char_inx])
start_inx = end_inx
for box_inx in range(num_boxes):
@@ -119,18 +117,10 @@ def convert_annotations(root_path, gt_name, lmdb_name):
total_time_sec = time.time() - start_time
avg_time_sec = total_time_sec / img_id
eta_mins = (avg_time_sec * (img_num - img_id)) / 60
print(f'\ncurrent_img/total_imgs {img_id}/{img_num}\
| eta: {eta_mins:.3f} mins')
print(f'\ncurrent_img/total_imgs {img_id}/{img_num} | '
f'eta: {eta_mins:.3f} mins')
# for each img
img_file = osp.join(root_path, 'imgs', gt['imnames'][0][img_id][0])
# read imgs with ignoring orientations
# img = mmcv.imread(img_file, 'unchanged')
# read imgs with orientations as dataloader does when training and
# test
# img_color = mmcv.imread(img_file, 'color')
# make sure imgs have no orientations info, or annotation gt
# is wrong.
# assert img.shape[0:2] == img_color.shape[0:2]
img = mmcv.imread(img_file, 'unchanged')
height, width = img.shape[0:2]
img_json = {}
@@ -141,17 +131,13 @@ def convert_annotations(root_path, gt_name, lmdb_name):
wordBB = gt['wordBB'][0][img_id]
charBB = gt['charBB'][0][img_id]
txt = gt['txt'][0][img_id]
poly_list, poly_box_list, poly_boundary_list, poly_charbox_list,\
poly_char_idx_list, poly_char_list = match_bbox_char_str(
wordBB, charBB, txt)
# imshow_text_char_boundary(img_file, poly_box_list, \
# poly_boundary_list,\
# poly_charbox_list, poly_char_list, out_file='tmp.jpg')
poly_list, _, poly_boundary_list, _, _, _ = match_bbox_char_str(
wordBB, charBB, txt)
for poly_inx in range(len(poly_list)):
polygon = poly_list[poly_inx]
minx, miny, maxx, maxy = polygon.bounds
bbox = [minx, miny, maxx - minx, maxy - miny]
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
anno_info = dict()
anno_info['iscrowd'] = 0
anno_info['category_id'] = 1
@@ -159,13 +145,9 @@ def convert_annotations(root_path, gt_name, lmdb_name):
anno_info['segmentation'] = [
poly_boundary_list[poly_inx].flatten().tolist()
]
# anno_info['text'] = ''.join(poly_char_list[poly_inx])
# anno_info['char_boxes'] =
# np.concatenate(poly_charbox_list[poly_inx]).flatten().tolist()
img_json['annotations'].append(anno_info)
string = json.dumps(img_json)
# print(len(string))
txn.put(str(img_id).encode('utf8'), string.encode('utf8'))
key = 'total_number'.encode('utf8')
value = str(img_num).encode('utf8')

@@ -8,8 +8,9 @@ import mmcv
import numpy as np
import scipy.io as scio
from shapely.geometry import Polygon
from tools.data_converter.common import (check_ignore_orientation,
convert_annotations, is_not_png)
from tools.data_converter.common import convert_annotations, is_not_png
from mmocr.utils import drop_orientation
def collect_files(img_dir, gt_dir, split):
@@ -38,7 +39,7 @@ def collect_files(img_dir, gt_dir, split):
imgs_list.extend(glob.glob(osp.join(img_dir, '*' + suffix)))
imgs_list = [
check_ignore_orientation(f) if is_not_png(f) else f for f in imgs_list
drop_orientation(f) if is_not_png(f) else f for f in imgs_list
]
files = []
@@ -166,8 +167,8 @@ def load_mat_info(img_info, gt_file, split):
area = polygon.area
# convert to COCO style XYWH format
minx, miny, maxx, maxy = polygon.bounds
bbox = [minx, miny, maxx - minx, maxy - miny]
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
anno = dict(
iscrowd=iscrowd,
@@ -211,8 +212,8 @@ def load_png_info(gt_file, img_info):
area = polygon.area
# convert to COCO style XYWH format
minx, miny, maxx, maxy = polygon.bounds
bbox = [minx, miny, maxx - minx, maxy - miny]
min_x, min_y, max_x, max_y = polygon.bounds
bbox = [min_x, min_y, max_x - min_x, max_y - min_y]
anno = dict(
iscrowd=iscrowd,
@@ -241,11 +242,11 @@ def load_img_info(files, split):
assert isinstance(split, str)
img_file, gt_file = files
# read imgs with ignoring oritations
# read imgs with ignoring orientations
img = mmcv.imread(img_file, 'unchanged')
# read imgs with oritations as dataloader does when training and testing
# read imgs with orientations as dataloader does when training and testing
img_color = mmcv.imread(img_file, 'color')
# make sure imgs have no oritation info, or annotation gt is wrong.
# make sure imgs have no orientation info, or annotation gt is wrong.
assert img.shape[0:2] == img_color.shape[0:2]
split_name = osp.basename(osp.dirname(img_file))

@@ -1,42 +1,8 @@
import os
import os.path as osp
import mmcv
def check_ignore_orientation(img_file):
"""Check if the image has orientation information.
If yes, ignore it by converting the image format to png, otherwise return
the original filename.
Args:
img_file(str): The image path
Returns:
The converted image filename with proper postfix
"""
assert isinstance(img_file, str)
assert img_file
# read imgs with ignoring oritations
img = mmcv.imread(img_file, 'unchanged')
# read imgs with oritations as dataloader does when training and testing
img_color = mmcv.imread(img_file, 'color')
# make sure imgs have no oritation info, or annotation gt is wrong.
if img.shape[:2] == img_color.shape[:2]:
return img_file
else:
target_file = osp.splitext(img_file)[0] + '.png'
# read img with ignoring orientation information
img = mmcv.imread(img_file, 'unchanged')
mmcv.imwrite(img, target_file)
os.remove(img_file)
print(
f'{img_file} has orientation info. Ingore it by converting to png')
return target_file
def is_not_png(img_file):
"""Check img_file is not png image.
@@ -55,7 +21,7 @@ def is_not_png(img_file):
def convert_annotations(image_infos, out_json_name):
"""Convert the annotion into coco style.
"""Convert the annotation into coco style.
Args:
image_infos(list): The list of image information dicts
@@ -83,12 +49,7 @@ def convert_annotations(image_infos, out_json_name):
anno_info['id'] = ann_id
out_json['annotations'].append(anno_info)
ann_id += 1
# if image_info['file_name'].find('png'):
# img = mmcv.imread('data/ctw1500/imgs/'+
# image_info['file_name'], 'color')
# show_img_boundary(img, anno_info['segmentation'] )
img_id += 1
# print(img_id)
cat = dict(id=1, name='text')
out_json['categories'].append(cat)

@@ -69,9 +69,10 @@ def parse_args():
def main():
args = parse_args()
assert args.show or args.show_dir, \
('Please specify at least one operation (show the results'
' / save the results) with the argument "--show" or "--show-dir".')
assert args.show or args.show_dir, ('Please specify at least one '
'operation (show the results / save '
'the results) with the argument '
'"--show" or "--show-dir".')
cfg = Config.fromfile(args.config)
# import modules from string list.

@@ -65,8 +65,8 @@ def main():
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
# Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
@@ -120,8 +120,8 @@
# eval
eval_results = eval_ocr_metric(pred_labels, gt_labels)
logger.info('\n' + '-' * 100)
info = 'eval on testset with img_root_path ' + \
f'{args.img_root_path} and img_list {args.img_list}\n'
info = ('eval on testset with img_root_path '
f'{args.img_root_path} and img_list {args.img_list}\n')
logger.info(info)
logger.info(eval_results)

@@ -102,11 +102,12 @@ def parse_args():
def main():
args = parse_args()
assert args.out or args.eval or args.format_only or args.show \
or args.show_dir, \
('Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir".')
assert (
args.out or args.eval or args.format_only or args.show
or args.show_dir), (
'Please specify at least one operation (save/eval/format/show the '
'results / save the results) with the argument "--out", "--eval"'
', "--format-only", "--show" or "--show-dir".')
if args.eval and args.format_only:
raise ValueError('--eval and --format_only cannot be both specified.')

@@ -42,35 +42,35 @@ def save_2darray(mat, file_name):
fw.write(row_str + '\n')
def save_bboxes_quadrangels(bboxes_with_scores,
qudrangels_with_scores,
def save_bboxes_quadrangles(bboxes_with_scores,
quadrangles_with_scores,
img_name,
out_bbox_txt_dir,
out_quadrangel_txt_dir,
out_quadrangle_txt_dir,
score_thr=0.3,
save_score=True):
"""Save results of detected bounding boxes and quadrangels to txt file.
"""Save results of detected bounding boxes and quadrangles to txt file.
Args:
bboxes_with_scores (ndarray): Detected bboxes of shape (n,5).
qudrangels_with_scores (ndarray): Detected quadrangels of shape (n,9).
quadrangles_with_scores (ndarray): Detected quadrangles of shape (n,9).
img_name (str): Image file name.
out_bbox_txt_dir (str): Dir of txt files to save detected bboxes
results.
out_quadrangel_txt_dir (str): Dir of txt files to save
quadrangel results.
out_quadrangle_txt_dir (str): Dir of txt files to save
quadrangle results.
score_thr (float, optional): Score threshold for bboxes.
save_score (bool, optional): Whether to save score at each line end
to search best threshold when evaluating.
"""
assert bboxes_with_scores.ndim == 2
assert bboxes_with_scores.shape[1] == 5 or bboxes_with_scores.shape[1] == 9
assert qudrangels_with_scores.ndim == 2
assert qudrangels_with_scores.shape[1] == 9
assert bboxes_with_scores.shape[0] >= qudrangels_with_scores.shape[0]
assert quadrangles_with_scores.ndim == 2
assert quadrangles_with_scores.shape[1] == 9
assert bboxes_with_scores.shape[0] >= quadrangles_with_scores.shape[0]
assert isinstance(img_name, str)
assert isinstance(out_bbox_txt_dir, str)
assert isinstance(out_quadrangel_txt_dir, str)
assert isinstance(out_quadrangle_txt_dir, str)
assert isinstance(score_thr, float)
assert score_thr >= 0 and score_thr < 1
@@ -87,25 +87,25 @@ def save_bboxes_quadrangels(bboxes_with_scores,
elif initial_valid_bboxes.shape[1] == 8:
valid_bboxes = initial_valid_bboxes
valid_quadrangels, valid_quadrangel_scores = filter_result(
qudrangels_with_scores[:, :-1], qudrangels_with_scores[:, -1],
valid_quadrangles, valid_quadrangle_scores = filter_result(
quadrangles_with_scores[:, :-1], quadrangles_with_scores[:, -1],
score_thr)
# gen target file path
bbox_txt_file = gen_target_path(out_bbox_txt_dir, img_name, '.txt')
quadrangel_txt_file = gen_target_path(out_quadrangel_txt_dir, img_name,
quadrangle_txt_file = gen_target_path(out_quadrangle_txt_dir, img_name,
'.txt')
# save txt
if save_score:
valid_bboxes = np.concatenate(
(valid_bboxes, valid_bbox_scores.reshape(-1, 1)), axis=1)
valid_quadrangels = np.concatenate(
(valid_quadrangels, valid_quadrangel_scores.reshape(-1, 1)),
valid_quadrangles = np.concatenate(
(valid_quadrangles, valid_quadrangle_scores.reshape(-1, 1)),
axis=1)
save_2darray(valid_bboxes, bbox_txt_file)
save_2darray(valid_quadrangels, quadrangel_txt_file)
save_2darray(valid_quadrangles, quadrangle_txt_file)
def main():
@@ -134,8 +134,8 @@
if hasattr(model, 'module'):
model = model.module
if model.cfg.data.test['type'] == 'ConcatDataset':
model.cfg.data.test.pipeline = \
model.cfg.data.test['datasets'][0].pipeline
model.cfg.data.test.pipeline = model.cfg.data.test['datasets'][
0].pipeline
# Start Inference
out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')