Merge pull request from LDOUBLEV/sdmgr

fix train
pull/4988/head
MissPenguin 2021-12-18 17:49:47 +08:00 committed by GitHub
commit 57f0125398
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 4 deletions

View File

@ -227,10 +227,6 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
if model_type == "kie":
preds = model(batch)
train_start = time.time()
# use amp
@ -243,6 +239,8 @@ def train(config,
else:
if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
elif model_type == "kie":
preds = model(batch)
else:
preds = model(images)
loss = loss_class(preds, batch)