Merge branch 'PaddlePaddle:release/2.3' into release/2.3
commit
afe8ed19dc
|
@ -20,7 +20,7 @@
|
|||
# 下载超轻量中文检测模型:
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar
|
||||
tar xf ch_PP-OCRv2_det_infer.tar
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer.tar/"
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./ch_PP-OCRv2_det_infer/"
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
|
|||
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
|
||||
```
|
||||
|
||||
<a name="准备数据集"></a>
|
||||
<a name="自定义数据集"></a>
|
||||
### 1.1 自定义数据集
|
||||
下面以通用数据集为例, 介绍如何准备数据集:
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ class NRTRLoss(nn.Layer):
|
|||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(
|
||||
tgt, paddle.zeros(
|
||||
tgt.shape, dtype='int64'))
|
||||
tgt.shape, dtype=tgt.dtype))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
|
|
|
@ -168,14 +168,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if preds.dtype == paddle.int64:
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
if preds[0][0]==2:
|
||||
preds_idx = preds[:,1:]
|
||||
else:
|
||||
preds_idx = preds
|
||||
|
||||
if len(preds) == 2:
|
||||
preds_id = preds[0]
|
||||
preds_prob = preds[1]
|
||||
|
|
|
@ -51,7 +51,7 @@ def get_image_file_list(img_file):
|
|||
if img_file is None or not os.path.exists(img_file):
|
||||
raise Exception("not found any img file in {}".format(img_file))
|
||||
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF'}
|
||||
img_end = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'GIF','webp','ppm'}
|
||||
if os.path.isfile(img_file) and imghdr.what(img_file) in img_end:
|
||||
imgs_lists.append(img_file)
|
||||
elif os.path.isdir(img_file):
|
||||
|
@ -77,4 +77,4 @@ def check_and_read_gif(img_path):
|
|||
frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
|
||||
imgvalue = frame[:, :, ::-1]
|
||||
return imgvalue, True
|
||||
return None, False
|
||||
return None, False
|
||||
|
|
Loading…
Reference in New Issue