fix kie infer and eval bug
parent 30e8dd8eef
commit 66029dd8c6
@@ -167,10 +167,10 @@ class Kie_backbone(nn.Layer):
                     gt_bboxes[i, :num, ...], dtype='float32'))
         return img, temp_relations, temp_texts, temp_gt_bboxes
 
-    def forward(self, images, inputs):
-        img = images
-        relations, texts, gt_bboxes, tag, img_size = inputs[0], inputs[
-            1], inputs[2], inputs[4], inputs[-1]
+    def forward(self, inputs):
+        img = inputs[0]
+        relations, texts, gt_bboxes, tag, img_size = inputs[1], inputs[
+            2], inputs[3], inputs[5], inputs[-1]
         img, relations, texts, gt_bboxes = self.pre_process(
             img, relations, texts, gt_bboxes, tag, img_size)
         x = self.img_feat(img)
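
The forward signature drops the separate images argument: the whole dataloader batch is now handed in as one list, with the image at index 0, so every other field shifts up by one position. A minimal sketch of the resulting index mapping, using a hypothetical helper and assuming the batch layout implied by the indices above (index 4 is left untouched because this hunk never reads it):

    def unpack_kie_batch(batch):
        # Mirrors the indexing used by the new Kie_backbone.forward:
        # batch is the full list produced by the KIE dataloader, image first.
        img = batch[0]
        relations, texts, gt_bboxes, tag, img_size = (
            batch[1], batch[2], batch[3], batch[5], batch[-1])
        return img, relations, texts, gt_bboxes, tag, img_size
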
@@ -49,7 +49,7 @@ class SDMGRHead(nn.Layer):
         self.node_cls = nn.Linear(node_embed, num_classes)
         self.edge_cls = nn.Linear(edge_embed, 2)
 
-    def forward(self, input):
+    def forward(self, input, targets):
         relations, texts, x = input
         node_nums, char_nums = [], []
         for text in texts:
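
Only the signature changes here: forward gains a targets parameter that the body never uses. A plausible reading (an assumption, not stated in the diff) is that this lets the framework call every head with input plus ground-truth tensors in one uniform way, whether or not a particular head needs them. A toy illustration of that calling pattern, not PaddleOCR code:

    class DummyHead:
        # Accepts and ignores `targets`, so a single call site can serve every head.
        def __call__(self, input, targets=None):
            relations, texts, x = input  # same unpacking as SDMGRHead.forward
            return x                     # stand-in for the (node, edge) logits

    head = DummyHead()
    out = head(("relations", "texts", "features"), targets=["gt"])
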
@@ -54,7 +54,7 @@ def main():
     config['Architecture']["Head"]['out_channels'] = char_num
 
     model = build_model(config['Architecture'])
-    extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
+    extra_input = config['Architecture']['algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
     if "model_type" in config['Architecture'].keys():
         model_type = config['Architecture']['model_type']
     else:
@@ -68,7 +68,7 @@ def main():
 
     # build metric
     eval_class = build_metric(config['Metric'])
-
+    logger.info(f"extra_inputs: {extra_input}")
     # start eval
     metric = program.eval(model, valid_dataloader, post_process_class,
                           eval_class, model_type, extra_input)
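
For SDMGR, the practical effect of the new list is that extra_input evaluates to False, so evaluation is steered by model_type alone (which the KIE config presumably supplies as "kie"); the added logger line just surfaces that flag. A small illustrative check of how the two flags come out, assuming a config dict shaped like PaddleOCR's YAML configs:

    config = {"Architecture": {"algorithm": "SDMGR", "model_type": "kie"}}
    extra_input = config["Architecture"]["algorithm"] in ["SRN", "NRTR", "SAR", "SEED"]
    model_type = config["Architecture"].get("model_type")
    print(extra_input, model_type)  # False kie -> program.eval takes the kie branches below
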
@@ -80,8 +80,7 @@ def draw_kie_result(batch, node, idx_to_cls, count):
     vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
     vis_img[:, :w] = img
     vis_img[:, w:] = pred_img
-    save_kie_path = os.path.dirname(config['Global'][
-        'save_res_path']) + "/kie_results/"
+    save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/"
     if not os.path.exists(save_kie_path):
         os.makedirs(save_kie_path)
     save_path = os.path.join(save_kie_path, str(count) + ".png")
@@ -129,8 +128,7 @@ def main():
             batch_pred[i] = paddle.to_tensor(
                 np.expand_dims(
                     batch[i], axis=0))
-
-        node, edge = model(batch[0], batch[1:])
+        node, edge = model(batch_pred)
         node = F.softmax(node, -1)
         draw_kie_result(batch, node, idx_to_cls, index)
     logger.info("success!")
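
This is the inference-side fix named in the commit title: the model now receives the tensor-converted, batch-dimension-expanded batch_pred instead of the original batch split into two positional arguments. A minimal sketch of that conversion step, assuming batch is the per-sample list of numpy arrays produced by the KIE transforms:

    import numpy as np
    import paddle

    def to_model_input(batch):
        # Add a leading batch dimension and convert each field to a paddle Tensor,
        # as the loop in the hunk above does.
        batch_pred = [None] * len(batch)
        for i in range(len(batch)):
            batch_pred[i] = paddle.to_tensor(np.expand_dims(batch[i], axis=0))
        return batch_pred

    # node, edge = model(to_model_input(batch))  # matches the fixed call above
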
@@ -196,7 +196,7 @@ def train(config,
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
     extra_input = config['Architecture'][
-        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED", "SDMGR"]
+        'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
     try:
         model_type = config['Architecture']['model_type']
     except:
@@ -228,6 +228,8 @@ def train(config,
                 model_average = True
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
+            if model_type == "kie":
+                preds = model(batch)
             else:
                 preds = model(images)
             loss = loss_class(preds, batch)
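
Taken together with the earlier removal of "SDMGR" from extra_input, an SDMGR training step now skips the model(images, data=batch[1:]) branch and feeds the whole batch to the model, matching the new Kie_backbone.forward(inputs) above. A short illustration of how the flags resolve for that case (values are for algorithm == "SDMGR", model_type == "kie"):

    extra_input = "SDMGR" in ["SRN", "NRTR", "SAR", "SEED"]  # False after this commit
    model_type = "kie"
    # First branch:  model_type == 'table' or extra_input -> False, skipped
    # Second branch: model_type == "kie"                   -> True, preds = model(batch)
    print(model_type == 'table' or extra_input, model_type == "kie")  # False True
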
@@ -249,7 +251,7 @@ def train(config,
 
             if cal_metric_during_train:  # only rec and cls need
                 batch = [item.numpy() for item in batch]
-                if model_type == 'table':
+                if model_type in ['table', 'kie']:
                     eval_class(preds, batch)
                 else:
                     post_result = post_process_class(preds, batch[1])
@@ -377,13 +379,15 @@ def eval(model,
             start = time.time()
             if model_type == 'table' or extra_input:
                 preds = model(images, data=batch[1:])
+            if model_type == "kie":
+                preds = model(batch)
             else:
                 preds = model(images)
             batch = [item.numpy() for item in batch]
             # Obtain usable results from post-processing methods
             total_time += time.time() - start
             # Evaluate the results of the current batch
-            if model_type == 'table':
+            if model_type in ['table', 'kie']:
                 eval_class(preds, batch)
             else:
                 post_result = post_process_class(preds, batch[1])
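
On the metric side, KIE results are now scored the same way table results are: the raw predictions and the numpy-converted batch go straight to the metric object, bypassing the post-processing path used by recognition-style models. A compact, illustrative restatement of that dispatch, assuming eval_class and post_process_class are the objects built earlier:

    def update_metric(model_type, preds, batch, eval_class, post_process_class):
        # batch has already been converted to numpy, as in the hunk above.
        if model_type in ['table', 'kie']:
            eval_class(preds, batch)  # metric consumes raw predictions plus the batch
        else:
            post_result = post_process_class(preds, batch[1])
            # ...post_result then presumably feeds eval_class for the other model
            # types (that part is outside the hunk shown here).
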