add spin
parent
f56a7e9c45
commit
46e3442e2e
|
@ -75,7 +75,7 @@ Train:
|
|||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
|
@ -98,7 +98,7 @@ Eval:
|
|||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
|
|
|
@ -274,6 +274,7 @@ class SPINRecResizeImg(object):
|
|||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# different interpolation type corresponding the OpenCV
|
||||
if self.interpolation == 0:
|
||||
interpolation = cv2.INTER_NEAREST
|
||||
|
@ -294,12 +295,9 @@ class SPINRecResizeImg(object):
|
|||
img = np.expand_dims(img, -1)
|
||||
img = img.transpose((2, 0, 1))
|
||||
# normalize the image
|
||||
to_rgb = False
|
||||
img = img.copy().astype(np.float32)
|
||||
mean = np.float64(self.mean.reshape(1, -1))
|
||||
stdinv = 1 / np.float64(self.std.reshape(1, -1))
|
||||
if to_rgb:
|
||||
cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img -= mean
|
||||
img *= stdinv
|
||||
data['image'] = img
|
||||
|
|
|
@ -76,7 +76,7 @@ Train:
|
|||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
|
@ -99,7 +99,7 @@ Eval:
|
|||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- NRTRDecodeImage: # load image
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SPINAttnLabelEncode: # Class handling label
|
||||
|
|
|
@ -91,7 +91,7 @@ def export_single_model(model,
|
|||
]
|
||||
# print([None, 3, 32, 128])
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
elif arch_config["algorithm"] == "NRTR":
|
||||
elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 1, 32, 100], dtype="float32"),
|
||||
|
|
|
@ -81,7 +81,6 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
|
||||
elif self.rec_algorithm == "SPIN":
|
||||
postprocess_params = {
|
||||
'name': 'SPINAttnLabelDecode',
|
||||
|
@ -362,6 +361,8 @@ class TextRecognizer(object):
|
|||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == 'SPIN':
|
||||
norm_img = self.resize_norm_img_spin(img_list[indices[ino]])
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
elif self.rec_algorithm == "ABINet":
|
||||
norm_img = self.resize_norm_img_abinet(
|
||||
img_list[indices[ino]], self.rec_image_shape)
|
||||
|
|
Loading…
Reference in New Issue