forked from haidfs/TransE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestDatasetTF.py
91 lines (87 loc) · 3.88 KB
/
TestDatasetTF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# -*- coding: UTF-8 -*-
import csv
import os

import pandas as pd
class KnowledgeGraph:
    """Knowledge graph loaded from tab-separated files in ``data_dir``.

    Entity and relation names are mapped to integer ids through
    ``entity2id.txt`` and ``relation2id.txt``. The train / valid / test splits
    are then stored as id triples of the form ``(h, t, r)``.

    TensorFlow cannot convert a Tensor directly into a Python string, but it
    can convert it into a numpy value. That is why every triple kept here is
    an id triple rather than a string triple.
    """

    def __init__(self, data_dir):
        """Load the dictionaries and triples from ``data_dir`` and build pools.

        Args:
            data_dir: directory containing ``entity2id.txt``,
                ``relation2id.txt``, ``train.txt``, ``valid.txt`` and
                ``test.txt``.

        Raises:
            KeyError: if a triple file names an entity or relation that is
                absent from the corresponding dictionary file.
        """
        self.data_dir = data_dir
        self.entity_dict = {}
        self.entities = []
        self.relation_dict = {}
        self.n_entity = 0
        self.n_relation = 0
        self.training_triples = []  # list of triples in the form of (h, t, r)
        self.validation_triples = []
        self.test_triples = []
        self.n_training_triple = 0
        self.n_validation_triple = 0
        self.n_test_triple = 0
        # Load dicts first: load_triples() maps names through them.
        self.load_dicts()
        self.load_triples()
        # Construct the triple pools only after every split has been loaded.
        self.training_triple_pool = set(self.training_triples)
        self.golden_triple_pool = (
            set(self.training_triples)
            | set(self.validation_triples)
            | set(self.test_triples)
        )

    def _read_table(self, filename):
        """Read the headerless tab-separated file ``filename`` from ``data_dir``.

        Field text is kept verbatim:
        ``na_filter=False`` stops pandas from turning names such as ``NA``,
        ``null``, ``None`` or ``nan`` into NaN. Without it, distinct names
        collapse onto one key and lookups fail.
        ``csv.QUOTE_NONE`` stops a name that starts with ``"`` from being
        parsed as a quoted field.
        """
        return pd.read_table(
            os.path.join(self.data_dir, filename),
            header=None,
            na_filter=False,
            quoting=csv.QUOTE_NONE,
        )

    def _read_id_triples(self, filename):
        """Return the triples in ``filename`` as a list of ``(h, t, r)`` ids.

        Columns 0, 1 and 2 are read as head-entity, tail-entity and relation
        names, and are mapped through ``entity_dict`` / ``relation_dict``.
        Raises ``KeyError`` for a name missing from those dictionaries.
        """
        triples_df = self._read_table(filename)
        return [
            (self.entity_dict[h], self.entity_dict[t], self.relation_dict[r])
            for h, t, r in zip(triples_df[0], triples_df[1], triples_df[2])
        ]

    def load_dicts(self):
        """Load the entity and relation name -> id dictionaries.

        In each file, column 0 is the name and column 1 its id. Also sets
        ``n_entity``, ``entities`` (the list of entity ids) and ``n_relation``.
        """
        entity_dict_file = 'entity2id.txt'
        relation_dict_file = 'relation2id.txt'
        print('-----Loading entity dict-----')
        entity_df = self._read_table(entity_dict_file)
        self.entity_dict = dict(zip(entity_df[0], entity_df[1]))
        self.n_entity = len(self.entity_dict)
        self.entities = list(self.entity_dict.values())
        print('#entity: {}'.format(self.n_entity))
        print('-----Loading relation dict-----')
        relation_df = self._read_table(relation_dict_file)
        self.relation_dict = dict(zip(relation_df[0], relation_df[1]))
        self.n_relation = len(self.relation_dict)
        print('#relation: {}'.format(self.n_relation))

    def load_triples(self):
        """Load the train / valid / test splits as id triples and their counts.

        Must run after ``load_dicts``, since names are mapped through the
        dictionaries it fills.
        """
        training_file = 'train.txt'
        validation_file = 'valid.txt'
        test_file = 'test.txt'
        print('-----Loading training triples-----')
        self.training_triples = self._read_id_triples(training_file)
        self.n_training_triple = len(self.training_triples)
        print('#training triple: {}'.format(self.n_training_triple))
        print('-----Loading validation triples-----')
        self.validation_triples = self._read_id_triples(validation_file)
        self.n_validation_triple = len(self.validation_triples)
        print('#validation triple: {}'.format(self.n_validation_triple))
        print('-----Loading test triples------')
        self.test_triples = self._read_id_triples(test_file)
        self.n_test_triple = len(self.test_triples)
        print('#test triple: {}'.format(self.n_test_triple))