PaddleOCR/ppocr/losses/rec_parseq_loss.py

53 lines
1.7 KiB
Python

Add new recognition method "ParseQ" (#10836) * Update PP-OCRv4_introduction.md * Update PP-OCRv4_introduction.md (#10616) * Update PP-OCRv4_introduction.md * Update PP-OCRv4_introduction.md * Update PP-OCRv4_introduction.md * Update README.md * Cherrypicking GH-10217 and GH-10216 to PaddlePaddle:Release/2.7 (#10655) * Don't break overall processing on a bad image * Add preprocessing common to OCR tasks Add preprocessing to options * Update requirements.txt (#10656) added missing pyyaml library * [TIPC]update xpu tipc script (#10658) * fix-typo (#10642) Co-authored-by: Dennis <dvorst@users.noreply.github.com> Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com> * 修改数据增强导致的DSR报错 (#10662) (#10681) * 修改数据增强导致的DSR报错 * 错误修改回滚 * Update algorithm_overview_en.md (#10670) Fixed simple spelling errors. * Implement recoginition method ParseQ * Document update for new recognition method ParseQ * add prediction for parseq * Update rec_vit_parseq.yml * Update rec_r31_sar.yml * Update rec_r31_sar.yml * Update rec_r50_fpn_srn.yml * Update rec_vit_parseq.py * Update rec_vit_parseq.yml * Update rec_parseq_head.py * Update rec_img_aug.py * Update rec_vit_parseq.yml * Update __init__.py * Update predict_rec.py * Update paddleocr.py * Update requirements.txt * Update utility.py * Update utility.py --------- Co-authored-by: xiaoting <31891223+tink2123@users.noreply.github.com> Co-authored-by: topduke <784990967@qq.com> Co-authored-by: dyning <dyning.2003@163.com> Co-authored-by: UserUnknownFactor <63057995+UserUnknownFactor@users.noreply.github.com> Co-authored-by: itasli <ilyas.tasli@outlook.fr> Co-authored-by: Kai Song <50285351+USTCKAY@users.noreply.github.com> Co-authored-by: dvorst <87502756+dvorst@users.noreply.github.com> Co-authored-by: Dennis <dvorst@users.noreply.github.com> Co-authored-by: shiyutang <34859558+shiyutang@users.noreply.github.com> Co-authored-by: Dec20B <1192152456@qq.com> Co-authored-by: ncoffman <51147417+ncoffman@users.noreply.github.com>
2023-09-07 16:36:47 +08:00
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn


class ParseQLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(ParseQLoss, self).__init__()

    def forward(self, predicts, targets):
        label = targets[1]  # label
        label_len = targets[2]
        # Longest label in the batch plus 2 (presumably for the start and end
        # tokens). .item() also works when paddle.max returns a 0-D tensor.
        max_step = int(paddle.max(label_len).item()) + 2
        tgt = label[:, :max_step]
        logits_list = predicts["logits_list"]
        pad_id = predicts["pad_id"]
        eos_id = predicts["eos_id"]
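
        # Drop the first (start) token so targets align with the predictions.
        tgt_out = tgt[:, 1:]
        loss = 0
        loss_numel = 0
        # Number of non-padding target tokens.
        n = (tgt_out != pad_id).sum().item()

        for i, logits in enumerate(logits_list):
            # cross_entropy returns a mean over non-ignored tokens; scaling by n
            # turns it back into a sum so passes are weighted by token count.
            loss += n * paddle.nn.functional.cross_entropy(
                input=logits, label=tgt_out.flatten(), ignore_index=pad_id
            )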
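            loss_numel += n
            if i == 1:
                # After the first two passes, mask the end token as padding so
                # the remaining passes are scored without it.
                tgt_out = paddle.where(condition=tgt_out == eos_id, x=pad_id, y=tgt_out)
                n = (tgt_out != pad_id).sum().item()
        # Token-weighted average over all passes.
        loss /= loss_numel

        return {"loss": loss}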
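
The averaging scheme in `forward` can be checked without Paddle. What follows is a minimal NumPy sketch, not part of the PaddleOCR source: `masked_mean_ce` and `parseq_style_loss` are hypothetical helper names, and each entry of `logits_list` is assumed to be a 2-D array of shape `(batch * steps, num_classes)`. It mirrors the loop above: each pass contributes its mean cross-entropy scaled by the number of non-padding tokens, and the end token is masked as padding after the second pass.

```python
import numpy as np


def masked_mean_ce(logits, labels, pad_id):
    """Mean cross-entropy over positions whose label is not pad_id."""
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    keep = labels != pad_id
    # Ignored labels may lie outside the class range, so index with 0 instead;
    # those positions are dropped by the mask below.
    safe = np.where(keep, labels, 0)
    nll = -log_probs[np.arange(labels.size), safe]
    return nll[keep].mean()


def parseq_style_loss(logits_list, tgt_out, pad_id, eos_id):
    """Token-weighted average of the per-pass losses (sketch of the loop above)."""
    tgt_out = tgt_out.copy()
    total, count = 0.0, 0
    n = int((tgt_out != pad_id).sum())
    for i, logits in enumerate(logits_list):
        total += n * masked_mean_ce(logits, tgt_out.ravel(), pad_id)
        count += n
        if i == 1:
            # From the third pass on, the end token no longer counts.
            tgt_out = np.where(tgt_out == eos_id, pad_id, tgt_out)
            n = int((tgt_out != pad_id).sum())
    return total / count
```

With three passes over a target such as `[[1, 2, eos, pad]]`, the first two passes each count 3 tokens and the third counts 2, so the final division is by 8 rather than by the number of passes.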