commit
57f0125398
|
@ -227,10 +227,6 @@ def train(config,
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
if use_srn:
|
if use_srn:
|
||||||
model_average = True
|
model_average = True
|
||||||
if model_type == 'table' or extra_input:
|
|
||||||
preds = model(images, data=batch[1:])
|
|
||||||
if model_type == "kie":
|
|
||||||
preds = model(batch)
|
|
||||||
|
|
||||||
train_start = time.time()
|
train_start = time.time()
|
||||||
# use amp
|
# use amp
|
||||||
|
@ -243,6 +239,8 @@ def train(config,
|
||||||
else:
|
else:
|
||||||
if model_type == 'table' or extra_input:
|
if model_type == 'table' or extra_input:
|
||||||
preds = model(images, data=batch[1:])
|
preds = model(images, data=batch[1:])
|
||||||
|
elif model_type == "kie":
|
||||||
|
preds = model(batch)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
|
|
Loading…
Reference in New Issue