mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-06-03 21:53:39 +08:00
rename SLANetLoss to SLALoss
This commit is contained in:
parent
731688c2dd
commit
c2c43bb1bc
@ -54,7 +54,7 @@ Architecture:
|
|||||||
loc_reg_num: &loc_reg_num 4
|
loc_reg_num: &loc_reg_num 4
|
||||||
|
|
||||||
Loss:
|
Loss:
|
||||||
name: SLANetLoss
|
name: SLALoss
|
||||||
structure_weight: 1.0
|
structure_weight: 1.0
|
||||||
loc_weight: 2.0
|
loc_weight: 2.0
|
||||||
loc_loss: smooth_l1
|
loc_loss: smooth_l1
|
||||||
|
@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
|
|||||||
from .combined_loss import CombinedLoss
|
from .combined_loss import CombinedLoss
|
||||||
|
|
||||||
# table loss
|
# table loss
|
||||||
from .table_att_loss import TableAttentionLoss, SLANetLoss
|
from .table_att_loss import TableAttentionLoss, SLALoss
|
||||||
from .table_master_loss import TableMasterLoss
|
from .table_master_loss import TableMasterLoss
|
||||||
# vqa token loss
|
# vqa token loss
|
||||||
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
|
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
|
||||||
@ -64,7 +64,7 @@ def build_loss(config):
|
|||||||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss'
|
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss'
|
||||||
]
|
]
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SLANetLoss(nn.Layer):
|
class SLALoss(nn.Layer):
|
||||||
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
|
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
|
||||||
super(SLANetLoss, self).__init__()
|
super(SLALoss, self).__init__()
|
||||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
|
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
|
||||||
self.structure_weight = structure_weight
|
self.structure_weight = structure_weight
|
||||||
self.loc_weight = loc_weight
|
self.loc_weight = loc_weight
|
||||||
|
Loading…
x
Reference in New Issue
Block a user