rename SLANetLoss to SLALoss

pull/7137/head
WenmuZhou 2022-08-10 14:58:08 +00:00
parent 731688c2dd
commit c2c43bb1bc
3 changed files with 5 additions and 5 deletions

View File

@ -54,7 +54,7 @@ Architecture:
loc_reg_num: &loc_reg_num 4
Loss:
name: SLANetLoss
name: SLALoss
structure_weight: 1.0
loc_weight: 2.0
loc_loss: smooth_l1

View File

@ -52,7 +52,7 @@ from .basic_loss import DistanceLoss
from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss, SLANetLoss
from .table_att_loss import TableAttentionLoss, SLALoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
@ -64,7 +64,7 @@ def build_loss(config):
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLANetLoss'
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'SLALoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')

View File

@ -55,9 +55,9 @@ class TableAttentionLoss(nn.Layer):
}
class SLANetLoss(nn.Layer):
class SLALoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, loc_loss='mse', **kwargs):
super(SLANetLoss, self).__init__()
super(SLALoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='mean')
self.structure_weight = structure_weight
self.loc_weight = loc_weight