rename SLANetLoss to SLALoss
parent
731688c2dd
commit
c2c43bb1bc
|
@ -54,7 +54,7 @@ Architecture:
|
|||
loc_reg_num: &loc_reg_num 4
|
||||
|
||||
Loss:
|
||||
name: SLANetLoss
|
||||
name: SLALoss
|
||||
structure_weight: 1.0
|
||||
loc_weight: 2.0
|
||||
loc_loss: smooth_l1
|
||||
|
|
|
@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
|
|||
from .combined_loss import CombinedLoss
|
||||
|
||||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss, SLANetLoss
|
||||
from .table_att_loss import TableAttentionLoss, SLALoss
|
||||
from .table_master_loss import TableMasterLoss
|
||||
# vqa token loss
|
||||
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
|
||||
|
@ -64,7 +64,7 @@ def build_loss(config):
|
|||
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
|
||||
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
|
||||
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
|
||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss'
|
||||
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer):
|
|||
}
|
||||
|
||||
|
||||
class SLANetLoss(nn.Layer):
|
||||
class SLALoss(nn.Layer):
|
||||
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
|
||||
super(SLANetLoss, self).__init__()
|
||||
super(SLALoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
|
|
Loading…
Reference in New Issue