fix re inference error (#8475)
parent
b9c17d6990
commit
7a9cfaad9d
|
@ -81,7 +81,7 @@ def make_input(ser_inputs, ser_results):
|
||||||
end.append(entity['end'])
|
end.append(entity['end'])
|
||||||
label.append(entities_labels[res['pred']])
|
label.append(entities_labels[res['pred']])
|
||||||
|
|
||||||
entities = np.full([max_seq_len + 1, 3], fill_value=-1)
|
entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
|
||||||
entities[0, 0] = len(start)
|
entities[0, 0] = len(start)
|
||||||
entities[1:len(start) + 1, 0] = start
|
entities[1:len(start) + 1, 0] = start
|
||||||
entities[0, 1] = len(end)
|
entities[0, 1] = len(end)
|
||||||
|
@ -98,7 +98,7 @@ def make_input(ser_inputs, ser_results):
|
||||||
head.append(i)
|
head.append(i)
|
||||||
tail.append(j)
|
tail.append(j)
|
||||||
|
|
||||||
relations = np.full([len(head) + 1, 2], fill_value=-1)
|
relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
|
||||||
relations[0, 0] = len(head)
|
relations[0, 0] = len(head)
|
||||||
relations[1:len(head) + 1, 0] = head
|
relations[1:len(head) + 1, 0] = head
|
||||||
relations[0, 1] = len(tail)
|
relations[0, 1] = len(tail)
|
||||||
|
|
Loading…
Reference in New Issue