# extract_kg_embeddings.py
import gc

import jsonlines
import pandas as pd

# dataset paths
linked_wikitext_2 = "linked-wikitext-2/"
train = linked_wikitext_2 + "train.jsonl"
valid = linked_wikitext_2 + "valid.jsonl"
test = linked_wikitext_2 + "test.jsonl"
synthetic = "sythetic_dataset_w_negative_samples.jsonl"  # note: "sythetic" matches the actual filename
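# Each JSONL record is expected to look like this (illustrative; only the
# fields this script reads are shown):
#   {"annotations": [{"id": "Q42", ...}, ...], ...}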
def get_qid_embeds(input_filename, output_filename):
    """Stream the raw BigGraph TSV and append every Wikidata entity (QID) row
    to output_filename, trimming the entity URI down to the bare QID."""
    df_size = 0
    for chunk in pd.read_csv(input_filename,
                             delimiter='\t',
                             header=None,
                             chunksize=10000,
                             skiprows=1,
                             encoding='unicode-escape',
                             names=['id'] + [f"embedding_{num}" for num in range(1, 201)]):
        chunk = chunk.dropna()
        # keep only entity rows; .copy() avoids pandas' SettingWithCopyWarning below
        qid_df = chunk[chunk['id'].str.startswith("<http://www.wikidata.org/entity/")].copy()
        df_size += len(qid_df)
        # "<http://www.wikidata.org/entity/Q42>" -> "Q42" (drop URI prefix and trailing ">")
        qid_df["id"] = qid_df["id"].apply(lambda x: x.split("/")[-1][:-1])
        qid_df.to_csv(output_filename, index=False, header=False, mode="a")
        gc.collect()
    print("Number of total QID embeddings in the PyTorch BigGraph dataset:", df_size)
    return df_size
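# The raw TSV is the 200-dimensional Wikidata embedding dump published with
# PyTorch BigGraph; after its header line (hence skiprows=1), each row is a
# URI followed by 200 tab-separated floats, e.g. (values invented):
#   <http://www.wikidata.org/entity/Q42>    -0.013    0.241    ...    0.087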
def get_relevant_qids(input_filename, output_filename):
    """Collect every QID annotated in Linked WikiText-2 and the synthetic
    dataset, then record which rows of input_filename carry one of them."""
    qid_set = set()
    for dataset in [train, valid, test, synthetic]:
        with jsonlines.open(dataset) as f:
            for line in f.iter():
                for annot in line['annotations']:
                    qid_set.add(annot['id'])
        print("Number of unique QIDs so far:", len(qid_set))
    qids = pd.read_csv(input_filename, header=None, names=['id'], usecols=[0])
    relevant_qids = qids[qids['id'].isin(qid_set)]
    relevant_qids.to_csv(output_filename)  # keeps the same row index as the input file
def get_relevant_qid_embedding(input_embedding_file, input_qid_file, output_filename, tot_qids):
    """Re-read the embedding CSV, skipping every row whose index is not listed
    in input_qid_file, and write the surviving rows to output_filename."""
    relevant_qids = pd.read_csv(input_qid_file, index_col=0)
    keep = set(relevant_qids.index.values)  # set lookup is O(1); scanning the array per row is not
    to_exclude = [i for i in range(tot_qids) if i not in keep]
    embeds_wktxt = pd.read_csv(input_embedding_file,
                               header=None,
                               skiprows=to_exclude,
                               names=['id'] + [f"embedding_{num}" for num in range(1, 201)])
    embeds_wktxt.to_csv(output_filename, index=False)
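# An alternative sketch (not what main() below uses): stream qid_embedding.csv
# in chunks and filter by QID with isin(), which avoids building the huge
# to_exclude list. The function name and chunk size are illustrative choices.
def get_relevant_qid_embedding_chunked(input_embedding_file, input_qid_file, output_filename):
    keep_ids = set(pd.read_csv(input_qid_file, index_col=0)['id'])
    cols = ['id'] + [f"embedding_{num}" for num in range(1, 201)]
    first = True
    for chunk in pd.read_csv(input_embedding_file, header=None, names=cols, chunksize=100000):
        hits = chunk[chunk['id'].isin(keep_ids)]
        hits.to_csv(output_filename, index=False, header=first, mode="w" if first else "a")
        first = False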
def main():
    # step 1: extract all QID embeddings from the raw BigGraph TSV
    # tot_qids = get_qid_embeds(input_filename="wikidata_translation_v1.tsv", output_filename="qid_embedding.csv")  # takes *really* long
    tot_qids = 55032670  # row count recorded from a previous run of get_qid_embeds
    # step 2: select only the QIDs that occur in Linked WikiText-2 and the synthetic dataset
    get_relevant_qids(input_filename='qid_embedding.csv', output_filename='relevant_qids.csv')
    # step 3: pull the embeddings for exactly those QIDs
    get_relevant_qid_embedding(input_embedding_file='qid_embedding.csv', input_qid_file='relevant_qids.csv',
                               output_filename="relevant_qid_embedding.csv", tot_qids=tot_qids)


if __name__ == "__main__":
    main()
## RUN ##
# python3 -W ignore extract_kg_embeddings.py
## OUTPUTS ##
# Number of total QID embeddings in the PyTorch BigGraph dataset: 55032670
# Number of unique QIDs so far: 41058
# Number of unique QIDs so far: 44413
# Number of unique QIDs so far: 47932
# Number of unique QIDs so far: 49141
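## LOADING THE OUTPUT ##
# A minimal sketch ("Q42" is just an illustrative QID):
# embeds = pd.read_csv("relevant_qid_embedding.csv", index_col="id")
# vec = embeds.loc["Q42"].to_numpy()  # 200-dim vector for one QID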