Merge branch 'PaddlePaddle:release/2.3' into release/2.3

pull/4332/head
天涯古巷 2021-11-04 15:33:00 +08:00 committed by GitHub
commit afe8ed19dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 5 additions and 13 deletions

View File

@ -20,7 +20,7 @@
# 下载超轻量中文检测模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
tar xf ch_PP-OCRv2_det_infer.tar
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer.tar/"
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer/"
```

View File

@ -33,7 +33,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
```
<a name="准备数据集"></a>
<a name="自定义数据集"></a>
### 1.1 自定义数据集
下面以通用数据集为例, 介绍如何准备数据集:

View File

@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
log_prb = F.log_softmax(pred, axis=1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype='int64'))
tgt.shape, dtype=tgt.dtype))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()
else:

View File

@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if preds.dtype == paddle.int64:
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
if preds[0][0]==2:
preds_idx = preds[:,1:]
else:
preds_idx = preds
if len(preds) == 2:
preds_id = preds[0]
preds_prob = preds[1]

View File

@ -51,7 +51,7 @@ def get_image_file_list(img_file):
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF','webp','ppm'}
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
@ -77,4 +77,4 @@ def check_and_read_gif(img_path):
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
imgvalue = frame[:, :, ::-1]
return imgvalue, True
return None, False
return None, False