rm svtrlabeldecode resize
parent
ac703e56f2
commit
d2c11969c2
|
@ -62,7 +62,7 @@ python3 tools/eval.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model
|
|||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.load_static_weights=false
|
||||
python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||
|
@ -72,11 +72,11 @@ python3 tools/infer_rec.py -c configs/rec/rec_mtb_nrtr.yml -o Global.infer_img='
|
|||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
|
||||
首先将训练得到best模型,转换成inference model。这里以训练完成的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/ Global.load_static_weights=False
|
||||
python3 tools/export_model.py -c configs/rec/rec_mtb_nrtr.yml -o Global.pretrained_model=./rec_mtb_nrtr_train/best_accuracy Global.save_inference_dir=./inference/rec_mtb_nrtr/
|
||||
```
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
|
|
@ -26,10 +26,13 @@
|
|||
1. 首次发现单视觉模型可以达到与视觉语言模型相媲美甚至更高的准确率,并且其具有效率高和适应多语言的优点,在实际应用中很有前景。
|
||||
2. SVTR从字符组件的角度出发,逐渐的合并字符组件,自下而上地完成字符的识别。
|
||||
3. SVTR引入了局部和全局Mixing,分别用于提取字符组件特征和字符间依赖关系,与多尺度的特征一起,形成多粒度特征描述。
|
||||
4. SVTR-L在识别英文和中文场景文本方面实现了最先进的性能。SVTR-T平衡精确度和效率,在一个NVIDIA 1080Ti GPU中,每个英文图像文本平均消耗4.5ms。
|
||||
|
||||
|
||||
<a name="model"></a>
|
||||
`SVTR`在场景文本识别公开数据集上的精度(%)和模型文件如下:
|
||||
SVTR在场景文本识别公开数据集上的精度(%)和模型文件如下:
|
||||
|
||||
* 中文数据集来自于[Chinese Benckmark](https://arxiv.org/abs/2112.15093) ,SVTR的中文训练评估策略遵循该论文。
|
||||
|
||||
| SVTR |IC13<br/>857 | SVT |IIIT5k<br/>3000 |IC15<br/>1811| SVTP |CUTE80 | Avg_6 |IC15<br/>2077 |IC13<br/>1015 |IC03<br/>867|IC03<br/>860|Avg_10 |Chinese| 英文<br/>链接 | 中文<br/>链接 |
|
||||
|:-----:|:------:|:-----:|:---------:|:------:|:-----:|:-----:|:-----:|:-------:|:-------:|:-----:|:-----:|:-----:|:-----:|:----------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------------:|
|
||||
|
@ -56,10 +59,6 @@
|
|||
[英文数据集下载](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here)
|
||||
[中文数据集下载](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)
|
||||
|
||||
**注意:**
|
||||
1. 训练`SVTR`时,需将将配置文件中的测试数据集路径设置为本地的评估数据集路径,例如将中文的`scene_test`数据集修改为`scene_val`。
|
||||
2. 训练`SVTR`时,需将配置文件中的`SVTRLableDecode`修改为`CTCLabelDecode`,将`SVTRRecResizeImg`修改为`RecResizeImg`。
|
||||
|
||||
#### 启动训练
|
||||
|
||||
请参考[文本识别训练教程](./recognition.md)。PaddleOCR对代码进行了模块化,训练`SVTR`识别模型时需要**更换配置文件**为`SVTR`的[配置文件](../../configs/rec/rec_svtrnet.yml)。
|
||||
|
@ -67,7 +66,7 @@
|
|||
<a name="3-2"></a>
|
||||
### 3.2 评估
|
||||
|
||||
可下载`SVTR`提供模型文件和配置文件[模型下载](#model),以`SVTR-T`为例,使用如下命令进行评估:
|
||||
可下载`SVTR`提供模型文件和配置文件:[下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ,以`SVTR-T`为例,使用如下命令进行评估:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
|
@ -81,7 +80,7 @@ python3 tools/eval.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_s
|
|||
使用如下命令进行单张图片预测:
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.load_static_weights=false
|
||||
python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.infer_img='./doc/imgs_words_en/word_10.png' Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy
|
||||
# 预测文件夹下所有图像时,可修改infer_img为文件夹,如 Global.infer_img='./doc/imgs_words_en/'。
|
||||
```
|
||||
|
||||
|
@ -91,11 +90,11 @@ python3 tools/infer_rec.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6glo
|
|||
|
||||
<a name="4-1"></a>
|
||||
### 4.1 Python推理
|
||||
首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](#model)),可以使用如下命令进行转换:
|
||||
首先将训练得到best模型,转换成inference model。下面以基于`SVTR-T`,在英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/rec_svtr_tiny_none_ctc_en_train.tar) ),可以使用如下命令进行转换:
|
||||
|
||||
```shell
|
||||
# 注意将pretrained_model的路径设置为本地路径。
|
||||
python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en/ Global.load_static_weights=False
|
||||
python3 tools/export_model.py -c ./rec_svtr_tiny_en_train/rec_svtr_tiny_6local_6global_stn_en.yml -o Global.pretrained_model=./rec_svtr_tiny_none_ctc_en_train/best_accuracy Global.save_inference_dir=./inference/rec_svtr_tiny_stn_en
|
||||
```
|
||||
|
||||
执行如下命令进行模型推理:
|
||||
|
|
|
@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
|||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
|
||||
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg
|
||||
from .ssl_img_aug import SSLRotateResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
|
|
|
@ -207,25 +207,6 @@ class PRENResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class SVTRRecResizeImg(object):
|
||||
def __init__(self,
|
||||
image_shape,
|
||||
infer_mode=False,
|
||||
character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
|
||||
padding=True,
|
||||
**kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
self.character_dict_path = character_dict_path
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding)
|
||||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
|
@ -344,58 +325,6 @@ def resize_norm_img_srn(img, image_shape):
|
|||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
|
||||
def resize_norm_img_svtr(img, image_shape, padding=False):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
if not padding:
|
||||
if h > 2.0 * w:
|
||||
image = Image.fromarray(img)
|
||||
image1 = image.rotate(90, expand=True)
|
||||
image2 = image.rotate(-90, expand=True)
|
||||
img1 = np.array(image1)
|
||||
img2 = np.array(image2)
|
||||
else:
|
||||
img1 = copy.deepcopy(img)
|
||||
img2 = copy.deepcopy(img)
|
||||
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image1 = cv2.resize(
|
||||
img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image2 = cv2.resize(
|
||||
img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_w = imgW
|
||||
else:
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image1 = resized_image1.astype('float32')
|
||||
resized_image2 = resized_image2.astype('float32')
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image1 = resized_image1.transpose((2, 0, 1)) / 255
|
||||
resized_image2 = resized_image2.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resized_image1 -= 0.5
|
||||
resized_image1 /= 0.5
|
||||
resized_image2 -= 0.5
|
||||
resized_image2 /= 0.5
|
||||
padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[0, :, :, 0:resized_w] = resized_image
|
||||
padding_im[1, :, :, 0:resized_w] = resized_image1
|
||||
padding_im[2, :, :, 0:resized_w] = resized_image2
|
||||
return padding_im
|
||||
|
||||
|
||||
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
|
|
@ -128,8 +128,6 @@ class STN_ON(nn.Layer):
|
|||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, image):
|
||||
if len(image.shape)==5:
|
||||
image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]])
|
||||
stn_input = paddle.nn.functional.interpolate(
|
||||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||
|
|
|
@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess
|
|||
from .fce_postprocess import FCEPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
|
||||
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
|
||||
SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode
|
||||
SEEDLabelDecode, PRENLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
|
||||
|
@ -42,7 +42,7 @@ def build_post_process(config, global_config=None):
|
|||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
|
||||
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
|
||||
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
|
||||
'DistillationSARLabelDecode', 'SVTRLabelDecode'
|
||||
'DistillationSARLabelDecode'
|
||||
]
|
||||
|
||||
if config['name'] == 'PSEPostProcess':
|
||||
|
|
|
@ -752,40 +752,3 @@ class PRENLabelDecode(BaseRecLabelDecode):
|
|||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
|
||||
class SVTRLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(SVTRLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple):
|
||||
preds = preds[-1]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=-1)
|
||||
preds_prob = preds.max(axis=-1)
|
||||
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
return_text = []
|
||||
for i in range(0, len(text), 3):
|
||||
text0 = text[i]
|
||||
text1 = text[i + 1]
|
||||
text2 = text[i + 2]
|
||||
|
||||
text_pred = [text0[0], text1[0], text2[0]]
|
||||
text_prob = [text0[1], text1[1], text2[1]]
|
||||
id_max = text_prob.index(max(text_prob))
|
||||
return_text.append((text_pred[id_max], text_prob[id_max]))
|
||||
if label is None:
|
||||
return return_text
|
||||
label = self.decode(label)
|
||||
return return_text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank'] + dict_character
|
||||
return dict_character
|
Loading…
Reference in New Issue