diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index 4d9c52740..8babac655 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -1031,7 +1031,7 @@ class MainWindow(QMainWindow, WindowMixin):
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
- if trans_dic["label"] is "" and mode == 'Auto':
+ if trans_dic["label"] == "" and mode == 'Auto':
continue
shapes.append(trans_dic)
@@ -1450,7 +1450,7 @@ class MainWindow(QMainWindow, WindowMixin):
item = QListWidgetItem(closeicon, filename)
self.fileListWidget.addItem(item)
- print('dirPath in importDirImages is', dirpath)
+ print('DirPath in importDirImages is', dirpath)
self.iconlist.clear()
self.additems5(dirpath)
self.changeFileFolder = True
@@ -1459,7 +1459,6 @@ class MainWindow(QMainWindow, WindowMixin):
self.reRecogButton.setEnabled(True)
self.actions.AutoRec.setEnabled(True)
self.actions.reRec.setEnabled(True)
- self.actions.saveLabel.setEnabled(True)
def openPrevImg(self, _value=False):
@@ -1764,7 +1763,7 @@ class MainWindow(QMainWindow, WindowMixin):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
- if result[0][0] is not '':
+ if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
self.result_dic.append(result)
@@ -1795,7 +1794,7 @@ class MainWindow(QMainWindow, WindowMixin):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
- if result[0][0] is not '':
+ if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
if result[1][0] == shape.label:
@@ -1862,6 +1861,8 @@ class MainWindow(QMainWindow, WindowMixin):
for each in states:
file, state = each.split('\t')
self.fileStatedict[file] = 1
+ self.actions.saveLabel.setEnabled(True)
+ self.actions.saveRec.setEnabled(True)
def saveFilestate(self):
@@ -1919,22 +1920,29 @@ class MainWindow(QMainWindow, WindowMixin):
rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/'
+ ques_img = []
if not os.path.exists(crop_img_dir):
os.mkdir(crop_img_dir)
with open(rec_gt_dir, 'w', encoding='utf-8') as f:
for key in self.fileStatedict:
idx = self.getImglabelidx(key)
- for i, label in enumerate(self.PPlabel[idx]):
- if label['difficult']: continue
+ try:
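+                    # read and crop inside try/except so one unreadable image
+                    # or malformed label is reported at the end instead of
+                    # aborting the whole export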
img = cv2.imread(key)
- img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
- img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
- cv2.imwrite(crop_img_dir+img_name, img_crop)
- f.write('crop_img/'+ img_name + '\t')
- f.write(label['transcription'] + '\n')
-
- QMessageBox.information(self, "Information", "Cropped images has been saved in "+str(crop_img_dir))
+ for i, label in enumerate(self.PPlabel[idx]):
+ if label['difficult']: continue
+ img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
+ img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
+ cv2.imwrite(crop_img_dir+img_name, img_crop)
+ f.write('crop_img/'+ img_name + '\t')
+ f.write(label['transcription'] + '\n')
+ except Exception as e:
+ ques_img.append(key)
+ print("Can not read image ",e)
+ if ques_img:
+ QMessageBox.information(self, "Information", "The following images can not be saved, "
+ "please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img))
+ QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir))
def speedChoose(self):
if self.labelDialogOption.isChecked():
@@ -1991,7 +1999,7 @@ if __name__ == '__main__':
resource_file = './libs/resources.py'
if not os.path.exists(resource_file):
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
- assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
+ assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
"directory resources.py "
import libs.resources
sys.exit(main())
diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml
index 38f1e8691..00c1db885 100644
--- a/configs/rec/rec_mv3_none_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml
@@ -1,5 +1,5 @@
Global:
- use_gpu: true
+ use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
@@ -59,7 +59,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -78,7 +78,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml
index 33079ad48..6711b1d23 100644
--- a/configs/rec/rec_mv3_none_none_ctc.yml
+++ b/configs/rec/rec_mv3_none_none_ctc.yml
@@ -58,7 +58,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -77,7 +77,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
index 08f68939d..1b9fb0a08 100644
--- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
@@ -63,7 +63,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
index 4ad2ff89e..e4d301a6a 100644
--- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
@@ -58,7 +58,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -77,7 +77,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml
index 9c1eeb304..4a17a0042 100644
--- a/configs/rec/rec_r34_vd_none_none_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_none_ctc.yml
@@ -56,7 +56,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -75,7 +75,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
index aeded4926..62edf8437 100644
--- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
@@ -62,7 +62,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -81,7 +81,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml
new file mode 100644
index 000000000..ec7f17056
--- /dev/null
+++ b/configs/rec/rec_r50_fpn_srn.yml
@@ -0,0 +1,107 @@
+Global:
+ use_gpu: True
+ epoch_num: 72
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/rec/srn_new
+ save_epoch_step: 3
+  # evaluation is run every 5000 iterations, starting from iteration 0
+ eval_batch_step: [0, 5000]
+  # if pretrained_model is saved in static mode, load_static_weights must be set to True
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path:
+ character_type: en
+ max_text_length: 25
+ num_heads: 8
+ infer_mode: False
+ use_space_char: False
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 10.0
+ lr:
+ learning_rate: 0.0001
+
+Architecture:
+ model_type: rec
+ algorithm: SRN
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: ResNetFPN
+ Head:
+ name: SRNHead
+ max_text_length: 25
+ num_heads: 8
+ num_encoder_TUs: 2
+ num_decoder_TUs: 4
+ hidden_dims: 512
+
+Loss:
+ name: SRNLoss
+
+PostProcess:
+ name: SRNLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/srn_train_data_duiqi
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - SRNRecResizeImg:
+ image_shape: [1, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image',
+ 'label',
+ 'length',
+ 'encoder_word_pos',
+ 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1',
+ 'gsrm_slf_attn_bias2'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ batch_size_per_card: 64
+ drop_last: False
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - SRNRecResizeImg:
+ image_shape: [1, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image',
+ 'label',
+ 'length',
+ 'encoder_word_pos',
+ 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1',
+ 'gsrm_slf_attn_bias2']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 32
+ num_workers: 4
diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md
index 59d1bc8c4..abbc5da4c 100755
--- a/doc/doc_ch/algorithm_overview.md
+++ b/doc/doc_ch/algorithm_overview.md
@@ -41,7 +41,7 @@ Text recognition algorithms open-sourced in PaddleOCR (dynamic graph):
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
Following the [DTRB][3](https://arxiv.org/abs/1904.01906) text recognition training and evaluation protocol, the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. The results are as follows:
@@ -53,5 +53,6 @@ Text recognition algorithms open-sourced in PaddleOCR (dynamic graph):
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
+|SRN|Resnet50_vd_fpn|88.52%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
For the training and usage of PaddleOCR text recognition algorithms, please refer to the [text recognition section of model training/evaluation](./recognition.md) in the documentation tutorials.
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index c4601e152..0daddd9bb 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -22,8 +22,9 @@ inference model (saved with `paddle.jit.save`)
- [3. Text Recognition Model Inference](#文本识别模型推理)
    - [1. Lightweight Chinese Recognition Model Inference](#超轻量中文识别模型推理)
    - [2. CTC-Loss-Based Recognition Model Inference](#基于CTC损失的识别模型推理)
-  - [3. Inference with a Custom Text Recognition Dictionary](#自定义文本识别字典的推理)
-  - [4. Multilingual Model Inference](#多语言模型的推理)
+  - [3. SRN-Loss-Based Recognition Model Inference](#基于SRN损失的识别模型推理)
+  - [4. Inference with a Custom Text Recognition Dictionary](#自定义文本识别字典的推理)
+  - [5. Multilingual Model Inference](#多语言模型的推理)
- [4. Direction Classification Model Inference](#方向识别模型推理)
    - [1. Direction Classification Model Inference](#方向分类模型推理)
@@ -295,8 +296,20 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
+
+### 3. SRN-Loss-Based Recognition Model Inference
+Recognition models based on the SRN loss require the extra recognition-algorithm argument --rec_algorithm="SRN".
+The inference shape must also be kept consistent with training, e.g.: --rec_image_shape="1, 64, 256"
-### 3. Inference with a Custom Text Recognition Dictionary
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
+ --rec_model_dir="./inference/srn/" \
+ --rec_image_shape="1, 64, 256" \
+ --rec_char_type="en" \
+ --rec_algorithm="SRN"
+```
+
+### 4. Inference with a Custom Text Recognition Dictionary
If the text dictionary was modified during training, you need to specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model, and set `rec_char_type=ch`
```
@@ -304,7 +317,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
-### 4. Multilingual Model Inference
+### 5. Multilingual Model Inference
If you need to predict with a model for another language, specify the dictionary path via `--rec_char_dict_path` when using the inference model. To get correct visualization results,
you also need to specify the font path via `--vis_font_path`; fonts for several languages are provided by default under `doc/fonts/`. For example, for Korean recognition:
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index c5f459bdb..bc877ab78 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -36,6 +36,7 @@ ln -sf /train_data/dataset
* Data download
If you do not have a dataset locally, you can download [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) from the official site for a quick try, or refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) to download the lmdb-format datasets used by the benchmark.
+To reproduce the SRN paper metrics, you also need the offline [augmented data](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA) (extraction code: y3ry), generated from MJSynth and SynthText by rotation and perturbation. After downloading, please extract it to {your_path}/PaddleOCR/train_data/data_lmdb_release/training/.
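+
+Before training, you can sanity-check a downloaded LMDB shard. The sketch below assumes the DTRB directory layout (e.g. `MJ/MJ_train`) and the standard `num-samples` key read by `LMDBDataSet`; adjust the path to your data.
+
+```python
+import lmdb
+
+# hypothetical shard path: point this at any leaf lmdb directory
+env = lmdb.open("./train_data/data_lmdb_release/training/MJ/MJ_train",
+                readonly=True, lock=False, readahead=False, meminit=False)
+with env.begin(write=False) as txn:
+    print("samples:", int(txn.get("num-samples".encode())))
+```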
* Using your own dataset
@@ -200,6 +201,7 @@ PaddleOCR supports alternating training and evaluation; in `configs/rec/rec_icdar15_t
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
+| rec_r50_fpn_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
For training on Chinese data, [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) is recommended. If you would like to try other algorithms on a Chinese dataset, please modify the configuration file as described below:
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index 68bfd5299..7d7896e71 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -43,7 +43,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
Refer to [DTRB](https://arxiv.org/abs/1904.01906); the training and evaluation results of the above text recognition algorithms (trained on MJSynth and SynthText, evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) are as follows:
@@ -55,5 +55,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
+|SRN|Resnet50_vd_fpn|88.52%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index ccbb71847..c8ce1424f 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -25,6 +25,7 @@ Next, we first introduce how to convert a trained model into an inference model,
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
+ - [3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE](#SRN-BASED_RECOGNITION)
-  - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
-  - [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
+  - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
+  - [5. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
@@ -304,8 +305,23 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
+
+### 3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE
+
+Recognition models based on SRN require the additional recognition-algorithm
+argument --rec_algorithm="SRN". The inference shape must also be kept
+consistent with training, e.g.: --rec_image_shape="1, 64, 256"
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
+ --rec_model_dir="./inference/srn/" \
+ --rec_image_shape="1, 64, 256" \
+ --rec_char_type="en" \
+ --rec_algorithm="SRN"
+```
+
-### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
If the text dictionary was modified during training, you need to specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model, and set `rec_char_type=ch`
```
@@ -313,7 +329,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
```
-### 4. MULTILINGAUL MODEL INFERENCE
+### 5. MULTILINGUAL MODEL INFERENCE
If you need to predict with a model for another language, specify the dictionary path via `--rec_char_dict_path` when running inference. To get correct visualization results,
you also need to specify the font path via `--vis_font_path`; fonts for several languages are provided by default under the `doc/fonts` path. For example, for Korean recognition:
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 22f89cdef..f29703d14 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -195,6 +195,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
+| rec_r50_fpn_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
For training Chinese data, it is recommended to use
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 4809886b7..7cb50d7a6 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
-from ppocr.data.lmdb_dataset import LMDBDateSet
+from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDateSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 6ea4dd8ed..250ac75e7 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment
from .operators import *
from .label_ops import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 6d9ea1902..191bda92c 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -102,6 +102,8 @@ class BaseRecLabelEncode(object):
support_character_type, character_type)
self.max_text_len = max_text_length
+ self.beg_str = "sos"
+ self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@@ -231,3 +233,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length=25,
+ character_dict_path=None,
+ character_type='en',
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ character_type, use_space_char)
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ char_num = len(self.character_str)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
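+        # pad the encoded label to a fixed max_text_len; char_num (the first
+        # index after the real character set) is used as the padding id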
+ data['length'] = np.array(len(text))
+ text = text + [char_num] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 2ccb2d1d2..28e6bd0bc 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
import math
import cv2
import numpy as np
@@ -77,6 +63,26 @@ class RecResizeImg(object):
return data
+class SRNRecResizeImg(object):
+ def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
+ self.image_shape = image_shape
+ self.num_heads = num_heads
+ self.max_text_length = max_text_length
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img = resize_norm_img_srn(img, self.image_shape)
+ data['image'] = norm_img
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
+
+ data['encoder_word_pos'] = encoder_word_pos
+ data['gsrm_word_pos'] = gsrm_word_pos
+ data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
+ data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
+ return data
+
+
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
- max_wh_ratio = 0
+ max_wh_ratio = imgW * 1.0 / imgH
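+    # start from the configured aspect ratio so narrow crops are padded out to
+    # imgW rather than the canvas shrinking below the configured width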
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im
+def resize_norm_img_srn(img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
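+    # resize into the nearest width bucket (1x/2x/3x of imgH, else full imgW)
+    # to limit aspect distortion, then left-align on a black canvas of width imgW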
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+
+def srn_other_inputs(image_shape, num_heads, max_text_length):
+
+ imgC, imgH, imgW = image_shape
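+    # one visual position per cell of the feature map, assuming the FPN
+    # backbone downsamples the input by 8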
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
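+    # triangular masks for the two GSRM readers: bias1 hides future tokens
+    # (forward pass), bias2 hides past tokens (backward pass); the -1e9 scale
+    # drives masked attention logits to ~0 after softmax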
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
+ [num_heads, 1, 1]) * [-1e9]
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
+ [num_heads, 1, 1]) * [-1e9]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+
def flag():
"""
flag
diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index bd0630f63..e2d6dc932 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -20,9 +20,9 @@ import cv2
from .imaug import transform, create_operators
-class LMDBDateSet(Dataset):
+class LMDBDataSet(Dataset):
def __init__(self, config, mode, logger, seed=None):
- super(LMDBDateSet, self).__init__()
+ super(LMDBDataSet, self).__init__()
global_config = config['Global']
dataset_config = config[mode]['dataset']
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 94314235c..3881abf77 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -24,12 +24,14 @@ def build_loss(config):
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
+ from .rec_srn_loss import SRNLoss
# cls loss
from .cls_loss import ClsLoss
support_dict = [
- 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss'
+ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
+ 'SRNLoss'
]
config = copy.deepcopy(config)
diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py
new file mode 100644
index 000000000..7d5b65eba
--- /dev/null
+++ b/ppocr/losses/rec_srn_loss.py
@@ -0,0 +1,47 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class SRNLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(SRNLoss, self).__init__()
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
+
+ def forward(self, predicts, batch):
+ predict = predicts['predict']
+ word_predict = predicts['word_out']
+ gsrm_predict = predicts['gsrm_out']
+ label = batch[1]
+
+ casted_label = paddle.cast(x=label, dtype='int64')
+ casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
+
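+        # three cross-entropy terms against the same label: the PVAM word
+        # branch, the GSRM semantic branch, and the fused VSFD prediction;
+        # they are combined below with weights 3.0 / 0.15 / 1.0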
+ cost_word = self.loss_func(word_predict, label=casted_label)
+ cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
+ cost_vsfd = self.loss_func(predict, label=casted_label)
+
+ cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
+ cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
+ cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
+
+ sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
+
+ return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index a86fc8382..b3aa9f38f 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target:
correct_num += 1
all_num += 1
- # if all_num < 10 and kwargs.get('show_str', False):
- # print('{} -> {}'.format(pred, target))
self.correct_num += correct_num
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
@@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0,
}
"""
- acc = self.correct_num / self.all_num
+ acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index ab44b53a2..09b6e0346 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
- def forward(self, x):
+ def forward(self, x, data=None):
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
if self.use_neck:
x = self.neck(x)
- x = self.head(x)
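+        # heads that need extra inputs (e.g. SRNHead's word positions and
+        # attention biases) receive them through `data`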
+ if data is None:
+ x = self.head(x)
+ else:
+ x = self.head(x, data)
return x
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 43103e53d..03c15508a 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
- support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
+ from .rec_resnet_fpn import ResNetFPN
+ support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py
new file mode 100644
index 000000000..a7e876a2b
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_fpn.py
@@ -0,0 +1,307 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import paddle
+import numpy as np
+
+__all__ = ["ResNetFPN"]
+
+
+class ResNetFPN(nn.Layer):
+ def __init__(self, in_channels=1, layers=50, **kwargs):
+ super(ResNetFPN, self).__init__()
+ supported_layers = {
+ 18: {
+ 'depth': [2, 2, 2, 2],
+ 'block_class': BasicBlock
+ },
+ 34: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BasicBlock
+ },
+ 50: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BottleneckBlock
+ },
+ 101: {
+ 'depth': [3, 4, 23, 3],
+ 'block_class': BottleneckBlock
+ },
+ 152: {
+ 'depth': [3, 8, 36, 3],
+ 'block_class': BottleneckBlock
+ }
+ }
+ stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
+ num_filters = [64, 128, 256, 512]
+ self.depth = supported_layers[layers]['depth']
+ self.F = []
+ self.conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act="relu",
+ name="conv1")
+ self.block_list = []
+ in_ch = 64
+ if layers >= 50:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ block_list = self.add_sublayer(
+ "bottleneckBlock_{}_{}".format(block, i),
+ BottleneckBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ name=conv_name))
+ in_ch = num_filters[block] * 4
+ self.block_list.append(block_list)
+ self.F.append(block_list)
+ else:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+ basic_block = self.add_sublayer(
+ conv_name,
+ BasicBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ is_first=block == i == 0,
+ name=conv_name))
+ in_ch = basic_block.out_channels
+ self.block_list.append(basic_block)
+ out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
+ self.base_block = []
+ self.conv_trans = []
+ self.bn_block = []
+ for i in [-2, -3]:
+ in_channels = out_ch_list[i + 1] + out_ch_list[i]
+
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_0".format(i),
+ nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_ch_list[i],
+ kernel_size=1,
+ weight_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_1".format(i),
+ nn.Conv2D(
+ in_channels=out_ch_list[i],
+ out_channels=out_ch_list[i],
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_2".format(i),
+ nn.BatchNorm(
+ num_channels=out_ch_list[i],
+ act="relu",
+ param_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_3".format(i),
+ nn.Conv2D(
+ in_channels=out_ch_list[i],
+ out_channels=512,
+ kernel_size=1,
+ bias_attr=ParamAttr(trainable=True),
+ weight_attr=ParamAttr(trainable=True))))
+ self.out_channels = 512
+
+ def __call__(self, x):
+ x = self.conv(x)
+ fpn_list = []
+ F = []
+ for i in range(len(self.depth)):
+ fpn_list.append(np.sum(self.depth[:i + 1]))
+
+ for i, block in enumerate(self.block_list):
+ x = block(x)
+ for number in fpn_list:
+ if i + 1 == number:
+ F.append(x)
+ base = F[-1]
+
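+        # top-down fusion: every three sublayers, concatenate the next
+        # shallower feature map with `base`, then refine with 1x1/3x3 convs + BN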
+ j = 0
+ for i, block in enumerate(self.base_block):
+ if i % 3 == 0 and i < 6:
+ j = j + 1
+ b, c, w, h = F[-j - 1].shape
+ if [w, h] == list(base.shape[2:]):
+ base = base
+ else:
+ base = self.conv_trans[j - 1](base)
+ base = self.bn_block[j - 1](base)
+ base = paddle.concat([base, F[-j - 1]], axis=1)
+ base = block(base)
+ return base
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2 if stride == (1, 1) else kernel_size,
+ dilation=2 if stride == (1, 1) else 1,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
+ bias_attr=False, )
+
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name=name + '.output.1.w_0'),
+ bias_attr=ParamAttr(name=name + '.output.1.b_0'),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance")
+
+ def __call__(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class ShortCut(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name, is_first=False):
+ super(ShortCut, self).__init__()
+ self.use_conv = True
+
+ if in_channels != out_channels or stride != 1 or is_first == True:
+ if stride == (1, 1):
+ self.conv = ConvBNLayer(
+ in_channels, out_channels, 1, 1, name=name)
+ else: # stride==(2,2)
+ self.conv = ConvBNLayer(
+ in_channels, out_channels, 1, stride, name=name)
+ else:
+ self.use_conv = False
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name):
+ super(BottleneckBlock, self).__init__()
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ self.short = ShortCut(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ stride=stride,
+ is_first=False,
+ name=name + "_branch1")
+ self.out_channels = out_channels * 4
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = y + self.short(x)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name, is_first):
+ super(BasicBlock, self).__init__()
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act='relu',
+ stride=stride,
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+ self.short = ShortCut(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ is_first=is_first,
+ name=name + "_branch1")
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = y + self.short(x)
+ return F.relu(y)
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 29d0ba800..efe057185 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -24,11 +24,13 @@ def build_head(config):
# rec head
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
+ from .rec_srn_head import SRNHead
# cls head
from .cls_head import ClsHead
support_dict = [
- 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead'
+ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
+ 'SRNHead'
]
module_name = config.pop('name')
diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py
new file mode 100644
index 000000000..8aaf65e1a
--- /dev/null
+++ b/ppocr/modeling/heads/rec_srn_head.py
@@ -0,0 +1,279 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import numpy as np
+from .self_attention import WrapEncoderForFeature
+from .self_attention import WrapEncoder
+from paddle.static import Program
+from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
+import paddle.fluid.framework as framework
+
+from collections import OrderedDict
+gradient_clip = 10
+
+
+class PVAM(nn.Layer):
+ def __init__(self, in_channels, char_num, max_text_length, num_heads,
+ num_encoder_tus, hidden_dims):
+ super(PVAM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_tus
+ self.hidden_dims = hidden_dims
+ # Transformer encoder
+ t = 256
+ c = 512
+ self.wrap_encoder_for_feature = WrapEncoderForFeature(
+ src_vocab_size=1,
+ max_length=t,
+ n_layer=self.num_encoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ # PVAM
+ self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels,
+ out_features=in_channels, )
+ self.emb = paddle.nn.Embedding(
+ num_embeddings=self.max_length, embedding_dim=in_channels)
+ self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
+ self.fc1 = paddle.nn.Linear(
+ in_features=in_channels, out_features=1, bias_attr=False)
+
+ def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
+ b, c, h, w = inputs.shape
+ conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
+ conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])
+ # transformer encoder
+ b, t, c = conv_features.shape
+
+ enc_inputs = [conv_features, encoder_word_pos, None]
+ word_features = self.wrap_encoder_for_feature(enc_inputs)
+
+ # pvam
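+        # parallel visual attention: each of the max_length character slots
+        # attends over all t visual features to produce one aligned feature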
+ b, t, c = word_features.shape
+ word_features = self.fc0(word_features)
+ word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
+ word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
+ word_pos_feature = self.emb(gsrm_word_pos)
+ word_pos_feature_ = paddle.reshape(word_pos_feature,
+ [-1, self.max_length, 1, c])
+ word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
+ y = word_pos_feature_ + word_features_
+ y = F.tanh(y)
+ attention_weight = self.fc1(y)
+ attention_weight = paddle.reshape(
+ attention_weight, shape=[-1, self.max_length, t])
+ attention_weight = F.softmax(attention_weight, axis=-1)
+ pvam_features = paddle.matmul(attention_weight,
+ word_features) #[b, max_length, c]
+ return pvam_features
+
+
+class GSRM(nn.Layer):
+ def __init__(self, in_channels, char_num, max_text_length, num_heads,
+ num_encoder_tus, num_decoder_tus, hidden_dims):
+ super(GSRM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_tus
+ self.num_decoder_TUs = num_decoder_tus
+ self.hidden_dims = hidden_dims
+
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels, out_features=self.char_num)
+ self.wrap_encoder0 = WrapEncoder(
+ src_vocab_size=self.char_num + 1,
+ max_length=self.max_length,
+ n_layer=self.num_decoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ self.wrap_encoder1 = WrapEncoder(
+ src_vocab_size=self.char_num + 1,
+ max_length=self.max_length,
+ n_layer=self.num_decoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ self.mul = lambda x: paddle.matmul(x=x,
+ y=self.wrap_encoder0.prepare_decoder.emb0.weight,
+ transpose_y=True)
+
+ def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2):
+ # ===== GSRM Visual-to-semantic embedding block =====
+ b, t, c = inputs.shape
+ pvam_features = paddle.reshape(inputs, [-1, c])
+ word_out = self.fc0(pvam_features)
+ word_ids = paddle.argmax(F.softmax(word_out), axis=1)
+ word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
+
+ #===== GSRM Semantic reasoning block =====
+ """
+        This module is implemented with bi-directional transformers;
+        ngram_feature1 is the forward one, ngram_feature2 is the backward one
+ """
+ pad_idx = self.char_num
+
+ word1 = paddle.cast(word_ids, "float32")
+ word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
+ word1 = paddle.cast(word1, "int64")
+ word1 = word1[:, :-1, :]
+ word2 = word_ids
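+        # word1 is right-shifted (pad prepended, last token dropped) for the
+        # forward reader; the backward reader takes word2 and left-shifts its
+        # outputs below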
+
+ enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
+ enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
+
+ gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
+ gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
+
+ gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
+ value=0.,
+ data_format="NLC")
+ gsrm_feature2 = gsrm_feature2[:, 1:, ]
+ gsrm_features = gsrm_feature1 + gsrm_feature2
+
+ gsrm_out = self.mul(gsrm_features)
+
+ b, t, c = gsrm_out.shape
+ gsrm_out = paddle.reshape(gsrm_out, [-1, c])
+
+ return gsrm_features, word_out, gsrm_out
+
+
+class VSFD(nn.Layer):
+ def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
+ super(VSFD, self).__init__()
+ self.char_num = char_num
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels * 2, out_features=pvam_ch)
+ self.fc1 = paddle.nn.Linear(
+ in_features=pvam_ch, out_features=self.char_num)
+
+ def forward(self, pvam_feature, gsrm_feature):
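+        # gated fusion: a sigmoid gate computed from the concatenated features
+        # mixes the visual (PVAM) and semantic (GSRM) streams per position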
+ b, t, c1 = pvam_feature.shape
+ b, t, c2 = gsrm_feature.shape
+ combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
+ img_comb_feature_ = paddle.reshape(
+ combine_feature_, shape=[-1, c1 + c2])
+ img_comb_feature_map = self.fc0(img_comb_feature_)
+ img_comb_feature_map = F.sigmoid(img_comb_feature_map)
+ img_comb_feature_map = paddle.reshape(
+ img_comb_feature_map, shape=[-1, t, c1])
+ combine_feature = img_comb_feature_map * pvam_feature + (
+ 1.0 - img_comb_feature_map) * gsrm_feature
+ img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
+
+ out = self.fc1(img_comb_feature)
+ return out
+
+
+class SRNHead(nn.Layer):
+ def __init__(self, in_channels, out_channels, max_text_length, num_heads,
+ num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
+ super(SRNHead, self).__init__()
+ self.char_num = out_channels
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_TUs
+ self.num_decoder_TUs = num_decoder_TUs
+ self.hidden_dims = hidden_dims
+
+ self.pvam = PVAM(
+ in_channels=in_channels,
+ char_num=self.char_num,
+ max_text_length=self.max_length,
+ num_heads=self.num_heads,
+ num_encoder_tus=self.num_encoder_TUs,
+ hidden_dims=self.hidden_dims)
+
+ self.gsrm = GSRM(
+ in_channels=in_channels,
+ char_num=self.char_num,
+ max_text_length=self.max_length,
+ num_heads=self.num_heads,
+ num_encoder_tus=self.num_encoder_TUs,
+ num_decoder_tus=self.num_decoder_TUs,
+ hidden_dims=self.hidden_dims)
+ self.vsfd = VSFD(in_channels=in_channels)
+
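+        # share one character-embedding table between the two GSRM transformers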
+ self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
+
+ def forward(self, inputs, others):
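+        # `others` carries the extra tensors built by SRNRecResizeImg, in the
+        # KeepKeys order declared in the config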
+ encoder_word_pos = others[0]
+ gsrm_word_pos = others[1]
+ gsrm_slf_attn_bias1 = others[2]
+ gsrm_slf_attn_bias2 = others[3]
+
+ pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
+
+ gsrm_feature, word_out, gsrm_out = self.gsrm(
+ pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2)
+
+ final_out = self.vsfd(pvam_feature, gsrm_feature)
+ if not self.training:
+ final_out = F.softmax(final_out, axis=1)
+
+ _, decoded_out = paddle.topk(final_out, k=1)
+
+ predicts = OrderedDict([
+ ('predict', final_out),
+ ('pvam_feature', pvam_feature),
+ ('decoded_out', decoded_out),
+ ('word_out', word_out),
+ ('gsrm_out', gsrm_out),
+ ])
+
+ return predicts
diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py
new file mode 100644
index 000000000..51d5198f5
--- /dev/null
+++ b/ppocr/modeling/heads/self_attention.py
@@ -0,0 +1,409 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import paddle
+from paddle import ParamAttr, nn
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import numpy as np
+gradient_clip = 10
+
+
+class WrapEncoderForFeature(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0):
+ super(WrapEncoderForFeature, self).__init__()
+
+ self.prepare_encoder = PrepareEncoder(
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ bos_idx=bos_idx,
+ word_emb_param_name="src_word_emb_table")
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ conv_features, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self.prepare_encoder(conv_features, src_pos)
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class WrapEncoder(nn.Layer):
+ """
+ embedder + encoder
+ """
+
+ def __init__(self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0):
+ super(WrapEncoder, self).__init__()
+
+ self.prepare_decoder = PrepareDecoder(
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ bos_idx=bos_idx)
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ src_word, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self.prepare_decoder(src_word, src_pos)
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class Encoder(nn.Layer):
+ """
+ encoder
+ """
+
+ def __init__(self,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(Encoder, self).__init__()
+
+ self.encoder_layers = list()
+ for i in range(n_layer):
+ self.encoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd,
+ postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ for encoder_layer in self.encoder_layers:
+ enc_output = encoder_layer(enc_input, attn_bias)
+ enc_input = enc_output
+ enc_output = self.processer(enc_output)
+ return enc_output
+
+
+class EncoderLayer(nn.Layer):
+ """
+ EncoderLayer
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(EncoderLayer, self).__init__()
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ attn_output = self.self_attn(
+ self.preprocesser1(enc_input), None, None, attn_bias)
+ attn_output = self.postprocesser1(attn_output, enc_input)
+ ffn_output = self.ffn(self.preprocesser2(attn_output))
+ ffn_output = self.postprocesser2(ffn_output, attn_output)
+ return ffn_output
+
+
+class MultiHeadAttention(nn.Layer):
+ """
+ Multi-Head Attention
+ """
+
+ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_key = d_key
+ self.d_value = d_value
+ self.d_model = d_model
+ self.dropout_rate = dropout_rate
+ self.q_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ self.k_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ self.v_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_value * n_head, bias_attr=False)
+ self.proj_fc = paddle.nn.Linear(
+ in_features=d_value * n_head, out_features=d_model, bias_attr=False)
+
+ def _prepare_qkv(self, queries, keys, values, cache=None):
+ if keys is None: # self-attention
+ keys, values = queries, queries
+ static_kv = False
+ else: # cross-attention
+ static_kv = True
+
+ q = self.q_fc(queries)
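+        # a 0 in paddle.reshape keeps that dimension of the input (batch, seq)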
+ q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
+ q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
+
+ if cache is not None and static_kv and "static_k" in cache:
+ # for encoder-decoder attention in inference and has cached
+ k = cache["static_k"]
+ v = cache["static_v"]
+ else:
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
+ v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
+
+ if cache is not None:
+ if static_kv and not "static_k" in cache:
+ # for encoder-decoder attention in inference and has not cached
+ cache["static_k"], cache["static_v"] = k, v
+ elif not static_kv:
+ # for decoder self-attention in inference
+ cache_k, cache_v = cache["k"], cache["v"]
+ k = paddle.concat([cache_k, k], axis=2)
+ v = paddle.concat([cache_v, v], axis=2)
+ cache["k"], cache["v"] = k, v
+
+ return q, k, v
+
+ def forward(self, queries, keys, values, attn_bias, cache=None):
+ # compute q ,k ,v
+ keys = queries if keys is None else keys
+ values = keys if values is None else values
+ q, k, v = self._prepare_qkv(queries, keys, values, cache)
+
+ # scale dot product attention
+ product = paddle.matmul(x=q, y=k, transpose_y=True)
+ product = product * self.d_model**-0.5
+ if attn_bias is not None:
+ product += attn_bias
+ weights = F.softmax(product)
+ if self.dropout_rate:
+ weights = F.dropout(
+ weights, p=self.dropout_rate, mode="downscale_in_infer")
+ out = paddle.matmul(weights, v)
+
+ # combine heads
+ out = paddle.transpose(out, perm=[0, 2, 1, 3])
+ out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+ # project to output
+ out = self.proj_fc(out)
+
+ return out
+
+
+class PrePostProcessLayer(nn.Layer):
+ """
+ PrePostProcessLayer
+ """
+
+ def __init__(self, process_cmd, d_model, dropout_rate):
+ super(PrePostProcessLayer, self).__init__()
+ self.process_cmd = process_cmd
+ self.functors = []
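+ # process_cmd is a string of one-letter ops applied in order:
+ # "a" residual add, "n" layer norm, "d" dropout
+ # (e.g. "n" before a sublayer, "da" after it)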
+ for cmd in self.process_cmd:
+ if cmd == "a": # add residual connection
+ self.functors.append(lambda x, y: x + y if y is not None else x)
+ elif cmd == "n": # add layer normalization
+ self.functors.append(
+ self.add_sublayer(
+ "layer_norm_%d" % len(
+ self.sublayers(include_sublayers=False)),
+ paddle.nn.LayerNorm(
+ normalized_shape=d_model,
+ weight_attr=paddle.ParamAttr(
+ initializer=nn.initializer.Constant(1.)),
+ bias_attr=paddle.ParamAttr(
+ initializer=nn.initializer.Constant(0.)))))
+ elif cmd == "d": # add dropout
+ self.functors.append(lambda x: F.dropout(
+ x, p=dropout_rate, mode="downscale_in_infer")
+ if dropout_rate else x)
+
+ def forward(self, x, residual=None):
+ for i, cmd in enumerate(self.process_cmd):
+ if cmd == "a":
+ x = self.functors[i](x, residual)
+ else:
+ x = self.functors[i](x)
+ return x
+
+
+class PrepareEncoder(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareEncoder, self).__init__()
+ self.src_emb_dim = src_emb_dim
+ self.src_max_len = src_max_len
+ self.emb = paddle.nn.Embedding(
+ num_embeddings=self.src_max_len,
+ embedding_dim=self.src_emb_dim,
+ sparse=True)
+ self.dropout_rate = dropout_rate
+
+ def forward(self, src_word, src_pos):
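+ # src_word is already a feature map, not token ids; only the position
+ # embedding is looked up here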
+ src_word_emb = src_word
+ src_word_emb = paddle.cast(src_word_emb, 'float32')
+ src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
+ src_pos = paddle.squeeze(src_pos, axis=-1)
+ src_pos_enc = self.emb(src_pos)
+ src_pos_enc.stop_gradient = True
+ enc_input = src_word_emb + src_pos_enc
+ if self.dropout_rate:
+ out = F.dropout(
+ x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ else:
+ out = enc_input
+ return out
+
+
+class PrepareDecoder(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareDecoder, self).__init__()
+ self.src_emb_dim = src_emb_dim
+ """
+ self.emb0 = Embedding(num_embeddings=src_vocab_size,
+ embedding_dim=src_emb_dim)
+ """
+ self.emb0 = paddle.nn.Embedding(
+ num_embeddings=src_vocab_size,
+ embedding_dim=self.src_emb_dim,
+ padding_idx=bos_idx,
+ weight_attr=paddle.ParamAttr(
+ name=word_emb_param_name,
+ initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
+ self.emb1 = paddle.nn.Embedding(
+ num_embeddings=src_max_len,
+ embedding_dim=self.src_emb_dim,
+ weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
+ self.dropout_rate = dropout_rate
+
+ def forward(self, src_word, src_pos):
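+ # embed token ids, scale by sqrt(d_model), then add a position embedding
+ # that is kept out of the gradient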
+ src_word = paddle.cast(src_word, 'int64')
+ src_word = paddle.squeeze(src_word, axis=-1)
+ src_word_emb = self.emb0(src_word)
+ src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
+ src_pos = paddle.squeeze(src_pos, axis=-1)
+ src_pos_enc = self.emb1(src_pos)
+ src_pos_enc.stop_gradient = True
+ enc_input = src_word_emb + src_pos_enc
+ if self.dropout_rate:
+ out = F.dropout(
+ x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ else:
+ out = enc_input
+ return out
+
+
+class FFN(nn.Layer):
+ """
+ Feed-Forward Network
+ """
+
+ def __init__(self, d_inner_hid, d_model, dropout_rate):
+ super(FFN, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.fc1 = paddle.nn.Linear(
+ in_features=d_model, out_features=d_inner_hid)
+ self.fc2 = paddle.nn.Linear(
+ in_features=d_inner_hid, out_features=d_model)
+
+ def forward(self, x):
+ hidden = self.fc1(x)
+ hidden = F.relu(hidden)
+ if self.dropout_rate:
+ hidden = F.dropout(
+ hidden, p=self.dropout_rate, mode="downscale_in_infer")
+ out = self.fc2(hidden)
+ return out
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 2b8d00a9e..0156e438e 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -26,12 +26,12 @@ def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
- from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
+ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
- 'AttnLabelDecode', 'ClsPostProcess', 'AttnLabelDecode'
+ 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
]
config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 1ac352466..2b82750fc 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -33,6 +33,9 @@ class BaseRecLabelDecode(object):
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
+ self.beg_str = "sos"
+ self.end_str = "eos"
+
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@@ -109,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
-
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
@@ -194,3 +196,84 @@ class AttnLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class SRNLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ character_dict_path=None,
+ character_type='en',
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelDecode, self).__init__(character_dict_path,
+ character_type, use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ pred = preds['predict']
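+ # two extra classes account for the sos/eos tokens appended in add_special_char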
+ char_num = len(self.character_str) + 2
+ if isinstance(pred, paddle.Tensor):
+ pred = pred.numpy()
+ pred = np.reshape(pred, [-1, char_num])
+
+ preds_idx = np.argmax(pred, axis=1)
+ preds_prob = np.max(pred, axis=1)
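+ # reshape to [batch, max_text_length]; SRN always decodes 25 positions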
+ preds_idx = np.reshape(preds_idx, [-1, 25])
+ preds_prob = np.reshape(preds_prob, [-1, 25])
+
+ text = self.decode(preds_idx, preds_prob)
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list)))
+ return result_list
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
diff --git a/tools/export_model.py b/tools/export_model.py
index a9b9e7dd5..1e9526e03 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-c", "--config", help="configuration file to use")
+ parser.add_argument(
+ "-o", "--output_path", type=str, default='./output/infer/')
+ return parser.parse_args()
+
+
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
@@ -52,23 +60,39 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
- infer_shape = [3, -1, -1]
- if config['Architecture']['model_type'] == "rec":
- infer_shape = [3, 32, -1] # for rec model, H must be 32
- if 'Transform' in config['Architecture'] and config['Architecture'][
- 'Transform'] is not None and config['Architecture'][
- 'Transform']['name'] == 'TPS':
- logger.info(
- 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
- )
- infer_shape[-1] = 100
-
- model = to_static(
- model,
- input_spec=[
+ if config['Architecture']['algorithm'] == "SRN":
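+ # SRN cannot take variable-size input: the image is fixed at 1x64x256 and
+ # the auxiliary tensors (position ids plus the two 8-head 25x25 GSRM
+ # attention biases) must be pinned as well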
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 1, 64, 256], dtype='float32'),
+ [
+ paddle.static.InputSpec(shape=[None, 256, 1], dtype="int64"),
+ paddle.static.InputSpec(shape=[None, 25, 1], dtype="int64"),
+ paddle.static.InputSpec(shape=[None, 8, 25, 25], dtype="int64"),
+ paddle.static.InputSpec(shape=[None, 8, 25, 25], dtype="int64")
+ ]
+ ]
+ model = to_static(model, input_spec=other_shape)
+ else:
+ infer_shape = [3, -1, -1]
+ if config['Architecture']['model_type'] == "rec":
+ infer_shape = [3, 32, -1] # for rec model, H must be 32
+ if 'Transform' in config['Architecture'] and config['Architecture'][
+ 'Transform'] is not None and config['Architecture'][
+ 'Transform']['name'] == 'TPS':
+ logger.info(
+ 'The TPS transform does not support variable-length input; the input size must match the training size'
+ )
+ infer_shape[-1] = 100
+ model = to_static(
+ model,
+ input_spec=[
+ paddle.static.InputSpec(
+ shape=[None] + infer_shape, dtype='float32')
+ ])
+
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 974fdbb6c..fd895e507 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -25,6 +25,7 @@ import numpy as np
import math
import time
import traceback
+import paddle
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
@@ -46,6 +47,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ if self.rec_algorithm == "SRN":
+ postprocess_params = {
+ 'name': 'SRNLabelDecode',
+ "character_type": args.rec_char_type,
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'rec', logger)
@@ -70,6 +78,78 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
+ def resize_norm_img_srn(self, img, image_shape):
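+ # resize the width to a whole multiple of the height (capped at imgW),
+ # convert to grayscale, and left-align on a black imgH x imgW canvas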
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+ def srn_other_inputs(self, image_shape, num_heads, max_text_length):
+
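+ # position ids for the visual features and the GSRM word slots, plus two
+ # triangular masks: bias1 hides future positions, bias2 hides past ones
+ # (large negative values added to the logits before softmax)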
+ imgC, imgH, imgW = image_shape
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [-1, 1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(
+ gsrm_slf_attn_bias1,
+ [1, num_heads, 1, 1]).astype('float32') * -1e9
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [-1, 1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(
+ gsrm_slf_attn_bias2,
+ [1, num_heads, 1, 1]).astype('float32') * -1e9
+
+ encoder_word_pos = encoder_word_pos[np.newaxis, :]
+ gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+ def process_image_srn(self, img, image_shape, num_heads, max_text_length):
+ norm_img = self.resize_norm_img_srn(img, image_shape)
+ norm_img = norm_img[np.newaxis, :]
+
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ self.srn_other_inputs(image_shape, num_heads, max_text_length)
+
+ gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+ gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+ encoder_word_pos = encoder_word_pos.astype(np.int64)
+ gsrm_word_pos = gsrm_word_pos.astype(np.int64)
+
+ return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2)
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -93,21 +173,64 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
+ # accumulate SRN's auxiliary inputs across the whole batch
+ encoder_word_pos_list = []
+ gsrm_word_pos_list = []
+ gsrm_slf_attn_bias1_list = []
+ gsrm_slf_attn_bias2_list = []
for ino in range(beg_img_no, end_img_no):
- # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
- norm_img = norm_img[np.newaxis, :]
- norm_img_batch.append(norm_img)
+ if self.rec_algorithm != "SRN":
+ norm_img = self.resize_norm_img(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ else:
+ norm_img = self.process_image_srn(
+ img_list[indices[ino]], self.rec_image_shape, 8, 25)
+ encoder_word_pos_list.append(norm_img[1])
+ gsrm_word_pos_list.append(norm_img[2])
+ gsrm_slf_attn_bias1_list.append(norm_img[3])
+ gsrm_slf_attn_bias2_list.append(norm_img[4])
+ norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
- starttime = time.time()
- self.input_tensor.copy_from_cpu(norm_img_batch)
- self.predictor.run()
- outputs = []
- for output_tensor in self.output_tensors:
- output = output_tensor.copy_to_cpu()
- outputs.append(output)
- preds = outputs[0]
+
+ if self.rec_algorithm == "SRN":
+ starttime = time.time()
+ encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+ gsrm_slf_attn_bias1_list = np.concatenate(
+ gsrm_slf_attn_bias1_list)
+ gsrm_slf_attn_bias2_list = np.concatenate(
+ gsrm_slf_attn_bias2_list)
+
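+ # the order must line up with the predictor's input names, i.e. the
+ # InputSpec order used when the model was exported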
+ inputs = [
+ norm_img_batch,
+ encoder_word_pos_list,
+ gsrm_word_pos_list,
+ gsrm_slf_attn_bias1_list,
+ gsrm_slf_attn_bias2_list,
+ ]
+ input_names = self.predictor.get_input_names()
+ for i, name in enumerate(input_names):
+ input_tensor = self.predictor.get_input_handle(name)
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
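+ # the SRN head returns several tensors; outputs[2] is the one
+ # SRNLabelDecode consumes under the "predict" key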
+ preds = {"predict": outputs[2]}
+ else:
+ starttime = time.time()
+ self.input_tensor.copy_from_cpu(norm_img_batch)
+ self.predictor.run()
+
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = outputs[0]
+
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 7e4b08114..075ec261e 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -62,7 +62,13 @@ def main():
elif op_name in ['RecResizeImg']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image']
+ if config['Architecture']['algorithm'] == "SRN":
+ op[op_name]['keep_keys'] = [
+ 'image', 'encoder_word_pos', 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
+ ]
+ else:
+ op[op_name]['keep_keys'] = ['image']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
@@ -74,10 +80,25 @@ def main():
img = f.read()
data = {'image': img}
batch = transform(data, ops)
+ if config['Architecture']['algorithm'] == "SRN":
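+ # the SRN transforms keep the image plus four auxiliary tensors (see keep_keys)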
+ encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
+ gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
+ gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
+ gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
+
+ others = [
+ paddle.to_tensor(encoder_word_pos_list),
+ paddle.to_tensor(gsrm_word_pos_list),
+ paddle.to_tensor(gsrm_slf_attn_bias1_list),
+ paddle.to_tensor(gsrm_slf_attn_bias2_list)
+ ]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
- preds = model(images)
+ if config['Architecture']['algorithm'] == "SRN":
+ preds = model(images, others)
+ else:
+ preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))
diff --git a/tools/program.py b/tools/program.py
index fb9e3802a..f3ba49450 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -174,6 +174,7 @@ def train(config,
best_model_dict = {main_indicator: 0}
best_model_dict.update(pre_best_model_dict)
train_stats = TrainingStats(log_smooth_window, ['lr'])
+ model_average = False
model.train()
if 'start_epoch' in best_model_dict:
@@ -194,7 +195,12 @@ def train(config,
break
lr = optimizer.get_lr()
images = batch[0]
- preds = model(images)
+ if config['Architecture']['algorithm'] == "SRN":
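+ # the last four entries of the batch are SRN's auxiliary inputs; SRN
+ # training also turns on ModelAverage, applied before each evaluation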
+ others = batch[-4:]
+ preds = model(images, others)
+ model_average = True
+ else:
+ preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
@@ -238,7 +244,14 @@ def train(config,
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
- cur_metric = eval(model, valid_dataloader, post_process_class,
+ if model_average:
+ Model_Average = paddle.incubate.optimizer.ModelAverage(
+ 0.15,
+ parameters=model.parameters(),
+ min_average_window=10000,
+ max_average_window=15625)
+ Model_Average.apply()
+ cur_metric = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
@@ -273,6 +286,7 @@ def train(config,
best_model_dict[main_indicator],
global_step)
global_step += 1
+ optimizer.clear_grad()
batch_start = time.time()
if dist.get_rank() == 0:
save_model(
@@ -313,7 +327,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
break
images = batch[0]
start = time.time()
- preds = model(images)
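+ # detect SRN from the head's repr and pass its auxiliary inputs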
+ if "SRN" in str(model.head):
+ others = batch[-4:]
+ preds = model(images, others)
+ else:
+ preds = model(images)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods