Merge pull request #3717 from littletomatodonkey/dyg/fix_cls_type

fix cls type
pull/3779/head^2
Double_V 2021-08-19 14:16:51 +08:00 committed by GitHub
commit 4938218139
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -25,6 +25,6 @@ class ClsLoss(nn.Layer):
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
def forward(self, predicts, batch):
label = batch[1]
label = batch[1].astype("int64")
loss = self.loss_func(input=predicts, label=label)
return {'loss': loss}