commit
57f0125398
tools
|
@ -227,10 +227,6 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
if model_type == "kie":
|
||||
preds = model(batch)
|
||||
|
||||
train_start = time.time()
|
||||
# use amp
|
||||
|
@ -243,6 +239,8 @@ def train(config,
|
|||
else:
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
elif model_type == "kie":
|
||||
preds = model(batch)
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
|
|
Loading…
Reference in New Issue