68 lines
2.4 KiB
Python
68 lines
2.4 KiB
Python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
class VQAReTokenRelation(object):
    """Turn raw entity-id links into question -> answer relation records.

    The operator is callable on a sample dict. It consumes the helper keys
    ``id2label``, ``empty_entity`` and ``entity_id_to_index_map`` (they are
    removed from the dict) and overwrites ``data['relations']`` with a list
    of dicts holding ``head``, ``tail``, ``start_index`` and ``end_index``.
    """

    def __init__(self, **kwargs):
        # No configuration is used; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """
        build relations

        Args:
            data (dict): sample dict that must contain:
                - ``entities``: sequence indexed by the values of
                  ``entity_id_to_index_map``; each item exposes ``"start"``
                  and ``"end"``.
                - ``relations``: iterable of entity-id pairs (tuples or
                  lists); only positions 0 and 1 of each pair are read.
                - ``id2label``: mapping from entity id to a label string.
                - ``empty_entity``: container of entity ids to be skipped.
                - ``entity_id_to_index_map``: mapping from entity id to an
                  index into ``entities``.

        Returns:
            dict: the same ``data`` object. ``id2label``, ``empty_entity``
            and ``entity_id_to_index_map`` are popped, and ``relations`` is
            replaced by a list of ``{"head", "tail", "start_index",
            "end_index"}`` dicts sorted by ``head``. ``head`` is always the
            index of the "question" entity and ``tail`` the index of the
            "answer" entity.

        Raises:
            KeyError: if a required key is missing from ``data`` or an id of
                a kept relation is missing from ``id2label`` /
                ``entity_id_to_index_map``.
        """
        entities = data['entities']
        relations = data['relations']
        id2label = data.pop('id2label')
        empty_entity = data.pop('empty_entity')
        entity_id_to_index_map = data.pop('entity_id_to_index_map')

        # De-duplicate while keeping first-occurrence order. Converting each
        # pair to a tuple lets list-typed pairs (e.g. loaded from JSON) be
        # hashed, and the ordered dict keeps the result reproducible instead
        # of depending on set iteration order.
        unique_relations = dict.fromkeys(tuple(rel) for rel in relations)

        kv_relations = []
        for rel in unique_relations:
            first_id, second_id = rel[0], rel[1]
            # Links touching an empty entity are dropped before any lookup.
            if first_id in empty_entity or second_id in empty_entity:
                continue
            pair = [id2label[first_id], id2label[second_id]]
            if pair == ["question", "answer"]:
                head_id, tail_id = first_id, second_id
            elif pair == ["answer", "question"]:
                # Normalise direction so that the question is always the head.
                head_id, tail_id = second_id, first_id
            else:
                # Any other label combination is not a key-value relation.
                continue
            kv_relations.append({
                "head": entity_id_to_index_map[head_id],
                "tail": entity_id_to_index_map[tail_id]
            })

        built_relations = []
        for rel in kv_relations:
            # Compute the span once and reuse both bounds.
            start_index, end_index = self.get_relation_span(rel, entities)
            built_relations.append({
                "head": rel["head"],
                "tail": rel["tail"],
                "start_index": start_index,
                "end_index": end_index,
            })
        # list.sort is stable, so relations sharing a head keep input order.
        built_relations.sort(key=lambda x: x["head"])

        data['relations'] = built_relations
        return data

    def get_relation_span(self, rel, entities):
        """Return the smallest and largest boundary covered by a relation.

        Args:
            rel (dict): relation with ``"head"`` and ``"tail"`` indices into
                ``entities``.
            entities: sequence whose items expose ``"start"`` and ``"end"``.

        Returns:
            tuple: ``(min_bound, max_bound)`` over the ``start`` and ``end``
            values of the head and tail entities.
        """
        bound = []
        for entity_index in (rel["head"], rel["tail"]):
            entity = entities[entity_index]
            bound.append(entity["start"])
            bound.append(entity["end"])
        return min(bound), max(bound)
|