diff --git a/tools/program.py b/tools/program.py index 333e8ed97..1dfd06af8 100755 --- a/tools/program.py +++ b/tools/program.py @@ -266,7 +266,7 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - if cal_metric_during_train: # only rec and cls need + if cal_metric_during_train and model_type is not "det": # only rec and cls need batch = [item.numpy() for item in batch] if model_type in ['table', 'kie']: eval_class(preds, batch)