diff --git a/README.md b/README.md
index acea009..1de8ca0 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,32 @@
-This project implements classic knowledge-graph algorithms from papers I have read:
-1. TransE is the classic knowledge-representation algorithm for knowledge graphs; the project implements training code (multi-process communication version) and test code.
-Code for further papers will be added as the reading continues.
-2. TransE paper: https://www.utc.fr/~bordesan/dokuwiki/_media/en/transe_nips13.pdf
-3. This project is based on wuxiyu's TransE code, annotated and modified; thanks for his work: https://github.com/wuxiyu/transE
-4. TransE SGD explanation: https://blog.csdn.net/weixin_42348333/article/details/89598144
\ No newline at end of file
+This project implements classic knowledge-graph algorithms from papers I have read:
+1. TransE is the classic knowledge-representation algorithm for knowledge graphs; the project implements training code (single-process and multi-process versions) and test code.
+Code for further papers will be added as the reading continues.
+2. The data files are too large to upload; download data.zip from https://github.com/thunlp/KB2E and unpack it into the project's data/ directory.
+3. TransE paper: https://www.utc.fr/~bordesan/dokuwiki/_media/en/transe_nips13.pdf
+### Training
+#### Simple version
+./train_fb15k.sh 0
+Runs the training in a single Python process.
+#### Manager version
+./train_fb15k.sh 1
+Shares one TransE instance across several processes through a multiprocessing manager.
+#### Queue version
+./train_fb15k.sh 2
+Feeds training batches to several worker processes through queues, which speeds up training.
+
+Run the tests after training has finished.
+### Testing
+#### TestTransEMpQueue
+python TestTransEMpQueue.py
+Multi-process queue test; the speed-up is modest (about 0.5 s per test triple, roughly 5 h in total).
+#### TestMainTF
+ python TestMainTF.py
+TensorFlow-based multi-process test; clearly faster, about 8 min in total.
+### Final test results
+FB15k, epochs: 2000
+
+|         | MeanRank (raw) | MeanRank (filter) | Hits@10 (raw) | Hits@10 (filter) |
+|---------|----------------|-------------------|---------------|------------------|
+| head    | 320.743        | 192.152           | 29.7          | 41.2             |
+| tail    | 236.984        | 153.431           | 36.1          | 46.2             |
+| average | 278.863        | 172.792           | 32.9          | 43.7             |
+| paper   | 243            | 125               | 34.9          | 47.1             |
\ No newline at end of file
diff --git a/TestDatasetTF.py b/TestDatasetTF.py
new file mode 100644
index 0000000..b464818
--- /dev/null
+++ b/TestDatasetTF.py
@@ -0,0 +1,90 @@
+import os
+import pandas as pd
+
+
+class KnowledgeGraph:
+    def __init__(self, data_dir):
+        # Mind the TF APIs here: a Python string cannot be fed into a Tensor
+        # directly, but TF tensors convert to numpy, so the training,
+        # validation and test triples are all kept as id triples rather than
+        # string triples.
+        self.data_dir = data_dir
+        self.entity_dict = {}
+        self.entities = []
+        self.relation_dict = {}
+        self.n_entity = 0
+        self.n_relation = 0
+        self.training_triples = []  # list of triples in the form of (h, t, r)
+        self.validation_triples = []
+        self.test_triples = []
+        self.n_training_triple = 0
+        self.n_validation_triple = 0
+        self.n_test_triple = 0
+        '''load dicts and triples'''
+        self.load_dicts()
+        self.load_triples()
+        '''construct pools after loading'''
+        self.training_triple_pool = set(self.training_triples)
+        self.golden_triple_pool = set(
+            self.training_triples) | set(
+            self.validation_triples) | set(
+            self.test_triples)
+
+    def load_dicts(self):
+        entity_dict_file = 'entity2id.txt'
+        relation_dict_file = 'relation2id.txt'
+        print('-----Loading entity dict-----')
+        entity_df = pd.read_table(
+            os.path.join(
+                self.data_dir,
+                entity_dict_file),
+            header=None)
+        self.entity_dict = dict(zip(entity_df[0], entity_df[1]))
+        self.n_entity = len(self.entity_dict)
+        self.entities = list(self.entity_dict.values())
+        print('#entity: {}'.format(self.n_entity))
+        print('-----Loading relation dict-----')
+        relation_df = pd.read_table(
+            os.path.join(
+                self.data_dir,
+                relation_dict_file),
+            header=None)
+        self.relation_dict = dict(zip(relation_df[0], relation_df[1]))
+        self.n_relation = len(self.relation_dict)
+        print('#relation: {}'.format(self.n_relation))
+
+    def load_triples(self):
+        training_file = 'train.txt'
+        validation_file = 'valid.txt'
+        test_file = 'test.txt'
+        print('-----Loading training triples-----')
+        training_df = pd.read_table(
+            os.path.join(
+                self.data_dir,
+                training_file),
+            header=None)
+        self.training_triples = list(zip([self.entity_dict[h] for h in training_df[0]],
+                                         [self.entity_dict[t] for t in training_df[1]],
+                                         [self.relation_dict[r] for r in training_df[2]]))
+        self.n_training_triple = len(self.training_triples)
+        print('#training triple: {}'.format(self.n_training_triple))
+        print('-----Loading validation triples-----')
+        validation_df = pd.read_table(
+            os.path.join(
+                self.data_dir,
+                validation_file),
+            header=None)
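+        # A hedged note on the expected file layout (inferred from this
+        # loader, not from the FB15k release itself): entity2id.txt and
+        # relation2id.txt are tab-separated "name<TAB>id" rows, while
+        # train/valid/test.txt are tab-separated "head<TAB>tail<TAB>relation"
+        # rows such as the (illustrative) "/m/06rf7<TAB>/m/0jgd<TAB>...";
+        # pd.read_table(header=None) then exposes them as columns 0, 1 and 2.
+        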
self.validation_triples = list(zip([self.entity_dict[h] for h in validation_df[0]], + [self.entity_dict[t] for t in validation_df[1]], + [self.relation_dict[r] for r in validation_df[2]])) + self.n_validation_triple = len(self.validation_triples) + print('#validation triple: {}'.format(self.n_validation_triple)) + print('-----Loading test triples------') + test_df = pd.read_table( + os.path.join( + self.data_dir, + test_file), + header=None) + self.test_triples = list(zip([self.entity_dict[h] for h in test_df[0]], + [self.entity_dict[t] for t in test_df[1]], + [self.relation_dict[r] for r in test_df[2]])) + self.n_test_triple = len(self.test_triples) + print('#test triple: {}'.format(self.n_test_triple)) \ No newline at end of file diff --git a/TestMainTF.py b/TestMainTF.py new file mode 100644 index 0000000..0b3a199 --- /dev/null +++ b/TestMainTF.py @@ -0,0 +1,43 @@ +import logging + +import tensorflow as tf +import argparse +from TestDatasetTF import KnowledgeGraph +from TestModelTF import TransE +from TestTransEMpQueue import get_dict_from_vector_file + + +def main(): + parser = argparse.ArgumentParser(description='TransE') + parser.add_argument('--data_dir', type=str, default=r'./data/FB15k/') + parser.add_argument('--score_func', type=str, default='L1') + parser.add_argument('--n_rank_calculator', type=int, default=24) + args = parser.parse_args() + print(args) + kg = KnowledgeGraph(data_dir=args.data_dir) + + entity_vector_file = "data/entityVector.txt" + entity_vector_dyct = get_dict_from_vector_file(entity_vector_file) + relation_vector_file = "data/relationVector.txt" + relation_vector_dyct = get_dict_from_vector_file(relation_vector_file) + logging.info("********** Start Test **********") + + kge_model = TransE( + kg=kg, + score_func=args.score_func, + n_rank_calculator=args.n_rank_calculator, + entity_vector_dict=entity_vector_dyct, + rels_vector_dict=relation_vector_dyct) + + gpu_config = tf.GPUOptions(allow_growth=True) + sess_config = tf.ConfigProto(gpu_options=gpu_config) + with tf.Session(config=sess_config) as sess: + print('-----Initializing tf graph-----') + tf.global_variables_initializer().run() + print('-----Initialization accomplished-----') + kge_model.check_norm() + kge_model.launch_evaluation(session=sess) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/TestModelTF.py b/TestModelTF.py new file mode 100644 index 0000000..185c68c --- /dev/null +++ b/TestModelTF.py @@ -0,0 +1,207 @@ +import timeit +import numpy as np +import tensorflow as tf +import multiprocessing as mp +from TestDatasetTF import KnowledgeGraph + + +class TransE: + def __init__(self, kg: KnowledgeGraph, + score_func, + n_rank_calculator, entity_vector_dict, rels_vector_dict): + self.kg = kg + self.score_func = score_func + self.n_rank_calculator = n_rank_calculator + + self.entity_vector_dict = entity_vector_dict + self.rels_vector_dict = rels_vector_dict + self.entity_embedding = None + self.relation_embedding = None + + '''ops for evaluation''' + self.eval_triple = tf.placeholder(dtype=tf.int32, shape=[3]) + self.idx_head_prediction = None + self.idx_tail_prediction = None + self.build_entity_embedding() + self.build_eval_graph() + + def build_entity_embedding(self): + self.entity_embedding = np.array( + list(self.entity_vector_dict.values())) + self.relation_embedding = np.array( + list(self.rels_vector_dict.values())) + + def build_eval_graph(self): + with tf.name_scope('evaluation'): + self.idx_head_prediction, self.idx_tail_prediction = self.evaluate( + 
self.eval_triple)
+
+    def evaluate(self, eval_triple):
+        with tf.name_scope('lookup'):
+            head = tf.nn.embedding_lookup(
+                self.entity_embedding, eval_triple[0])
+            tail = tf.nn.embedding_lookup(
+                self.entity_embedding, eval_triple[1])
+            relation = tf.nn.embedding_lookup(
+                self.relation_embedding, eval_triple[2])
+        with tf.name_scope('link'):
+            # h, r and t are [dim] vectors while self.entity_embedding is an
+            # [n_entity, dim] matrix. Plain Python lists of different shapes
+            # cannot be added or subtracted, but numpy arrays and TF tensors
+            # broadcast, so each expression below combines the vectors with
+            # every row of the embedding matrix, scoring every candidate
+            # entity at once.
+            distance_head_prediction = self.entity_embedding + relation - tail
+            distance_tail_prediction = head + relation - self.entity_embedding
+        with tf.name_scope('rank'):
+            if self.score_func == 'L1':  # L1 score
+                _, idx_head_prediction = tf.nn.top_k(tf.reduce_sum(
+                    tf.abs(distance_head_prediction), axis=1), k=self.kg.n_entity)
+                _, idx_tail_prediction = tf.nn.top_k(tf.reduce_sum(
+                    tf.abs(distance_tail_prediction), axis=1), k=self.kg.n_entity)
+            else:  # L2 score
+                _, idx_head_prediction = tf.nn.top_k(tf.reduce_sum(
+                    tf.square(distance_head_prediction), axis=1), k=self.kg.n_entity)
+                _, idx_tail_prediction = tf.nn.top_k(tf.reduce_sum(
+                    tf.square(distance_tail_prediction), axis=1), k=self.kg.n_entity)
+        return idx_head_prediction, idx_tail_prediction
+
+    def launch_evaluation(self, session):
+        eval_result_queue = mp.JoinableQueue()
+        rank_result_queue = mp.Queue()
+        print('-----Start evaluation-----')
+        start = timeit.default_timer()
+        for _ in range(self.n_rank_calculator):
+            mp.Process(
+                target=self.calculate_rank,
+                kwargs={
+                    'in_queue': eval_result_queue,
+                    'out_queue': rank_result_queue}).start()
+        n_used_eval_triple = 0
+        for eval_triple in self.kg.test_triples:
+            idx_head_prediction, idx_tail_prediction = session.run(
+                fetches=[
+                    self.idx_head_prediction, self.idx_tail_prediction], feed_dict={
+                    self.eval_triple: eval_triple})
+            eval_result_queue.put((eval_triple, idx_head_prediction, idx_tail_prediction))
+            n_used_eval_triple += 1
+            print(
+                '[{:.3f}s] #evaluation triple: {}/{}'.format(
+                    timeit.default_timer() - start,
+                    n_used_eval_triple,
+                    self.kg.n_test_triple),
+                end='\r')
+        print()
+        for _ in range(self.n_rank_calculator):
+            eval_result_queue.put(None)
+        print('-----Joining all rank calculator-----')
+        eval_result_queue.join()
+        print('-----All rank calculation accomplished-----')
+        print('-----Obtaining evaluation results-----')
+        '''Raw'''
+        head_meanrank_raw = 0
+        head_hits10_raw = 0
+        tail_meanrank_raw = 0
+        tail_hits10_raw = 0
+        '''Filter'''
+        head_meanrank_filter = 0
+        head_hits10_filter = 0
+        tail_meanrank_filter = 0
+        tail_hits10_filter = 0
+        for _ in range(n_used_eval_triple):
+            head_rank_raw, tail_rank_raw, head_rank_filter, tail_rank_filter = rank_result_queue.get()
+            head_meanrank_raw += head_rank_raw
+            if head_rank_raw < 10:
+                head_hits10_raw += 1
+            tail_meanrank_raw += tail_rank_raw
+            if tail_rank_raw < 10:
+                tail_hits10_raw += 1
+            head_meanrank_filter += head_rank_filter
+            if head_rank_filter < 10:
+                head_hits10_filter += 1
+            tail_meanrank_filter += tail_rank_filter
+            if tail_rank_filter < 10:
+                tail_hits10_filter += 1
+        print('-----Raw-----')
+        head_meanrank_raw /= n_used_eval_triple
+        head_hits10_raw /= n_used_eval_triple
+        tail_meanrank_raw /= n_used_eval_triple
+        tail_hits10_raw /= n_used_eval_triple
+        print('-----Head prediction-----')
+        print(
+            'MeanRank: {:.3f}, Hits@10: {:.3f}'.format(
+                head_meanrank_raw,
+                head_hits10_raw))
+        print('-----Tail prediction-----')
+        print(
+            'MeanRank: {:.3f}, Hits@10: {:.3f}'.format(
+                tail_meanrank_raw,
+                tail_hits10_raw))
+        
print('------Average------') + print( + 'MeanRank: {:.3f}, Hits@10: {:.3f}'.format( + (head_meanrank_raw + tail_meanrank_raw) / 2, + (head_hits10_raw + tail_hits10_raw) / 2)) + print('-----Filter-----') + head_meanrank_filter /= n_used_eval_triple + head_hits10_filter /= n_used_eval_triple + tail_meanrank_filter /= n_used_eval_triple + tail_hits10_filter /= n_used_eval_triple + print('-----Head prediction-----') + print('MeanRank: {:.3f}, Hits@10: {:.3f}'.format( + head_meanrank_filter, head_hits10_filter)) + print('-----Tail prediction-----') + print('MeanRank: {:.3f}, Hits@10: {:.3f}'.format( + tail_meanrank_filter, tail_hits10_filter)) + print('-----Average-----') + print( + 'MeanRank: {:.3f}, Hits@10: {:.3f}'.format( + (head_meanrank_filter + tail_meanrank_filter) / 2, + (head_hits10_filter + tail_hits10_filter) / 2)) + print('cost time: {:.3f}s'.format(timeit.default_timer() - start)) + print('-----Finish evaluation-----') + + def calculate_rank(self, in_queue, out_queue): + while True: + idx_predictions = in_queue.get() + if idx_predictions is None: + in_queue.task_done() + return + else: + eval_triple, idx_head_prediction, idx_tail_prediction = idx_predictions + head, tail, relation = eval_triple + head_rank_raw = 0 + tail_rank_raw = 0 + head_rank_filter = 0 + tail_rank_filter = 0 + for candidate in idx_head_prediction[::-1]: + if candidate == head: + break + else: + head_rank_raw += 1 + if (candidate, tail, + relation) in self.kg.golden_triple_pool: + continue + else: + head_rank_filter += 1 + for candidate in idx_tail_prediction[::-1]: + if candidate == tail: + break + else: + tail_rank_raw += 1 + if (head, candidate, + relation) in self.kg.golden_triple_pool: + continue + else: + tail_rank_filter += 1 + out_queue.put( + (head_rank_raw, + tail_rank_raw, + head_rank_filter, + tail_rank_filter)) + in_queue.task_done() + + def check_norm(self): + print('-----Check norm-----') + entity_embedding = self.entity_embedding + relation_embedding = self.relation_embedding + entity_norm = np.linalg.norm(entity_embedding, ord=2, axis=1) + relation_norm = np.linalg.norm(relation_embedding, ord=2, axis=1) + # print('entity norm: {} relation norm: {}'.format(entity_norm, relation_norm)) \ No newline at end of file diff --git a/TestTransE.py b/TestTransE.py deleted file mode 100644 index b5fdbf4..0000000 --- a/TestTransE.py +++ /dev/null @@ -1,180 +0,0 @@ -from numpy import * -import operator - - -class Test: - '''۹ -֪ʶһnʵ壬ô۹£ -ÿһԵԪaеͷʵβʵ壬滻Ϊ֪ʶеʵ壬ҲǻnԪ顣 -ֱnԪֵtransEУǼh+r-tֵԵõnֱֵӦnԪ顣 -nֵ -¼ԭԪaֵš -дڲԼеIJԪظ̡ -ÿȷԪֵƽõֵdzΪMean Rank -ȷԪС10ıõֵdzΪHits@10 -۵Ḷָ́꣺Mean RankHits@10Mean RankԽСԽãHits@10ԽԽáôδHits@10PythonִٶȺ -ߺʹ廪ѧFast_TransX룬ʹC++дܸߣܹٵóѵͲԽ -''' - def __init__(self, entity_list, entity_vector_list, relation_list, relation_vector_list, triple_list_train, - triple_list_test, - label="head", is_fit=False): - self.entity_list = {} - self.relation_list = {} - for name, vec in zip(entity_list, entity_vector_list): - self.entity_list[name] = vec - for name, vec in zip(relation_list, relation_vector_list): - self.relation_list[name] = vec - self.triple_list_train = triple_list_train - self.triple_list_test = triple_list_test - self.rank = [] - self.label = label - self.is_fit = is_fit - - def write_rank(self, dir): - print("д") - file = open(dir, 'w') - for r in self.rank: - file.write(str(r[0]) + "\t") - file.write(str(r[1]) + "\t") - file.write(str(r[2]) + "\t") - file.write(str(r[3]) + "\n") - file.close() - - def get_rank(self): - cou = 0 - for triplet in self.triple_list_test: - rank_list = {} - for 
entity_temp in self.entity_list.keys(): - if self.label == "head": - corrupted_triplet = (entity_temp, triplet[1], triplet[2]) - if self.is_fit and (corrupted_triplet in self.triple_list_train): - continue - rank_list[entity_temp] = distance(self.entity_list[entity_temp], self.entity_list[triplet[1]], - self.relation_list[triplet[2]]) - else: # ݱǩ滻ͷʵ滻βʵ - corrupted_triplet = (triplet[0], entity_temp, triplet[2]) - if self.is_fit and (corrupted_triplet in self.triple_list_train): - continue - rank_list[entity_temp] = distance(self.entity_list[triplet[0]], self.entity_list[entity_temp], - self.relation_list[triplet[2]]) - name_rank = sorted(rank_list.items(), key=operator.itemgetter(1)) # Ԫصĵһ - if self.label == 'head': - num_tri = 0 - else: - num_tri = 1 - x = 1 - for i in name_rank: - if i[0] == triplet[num_tri]: - break - x += 1 - self.rank.append((triplet, triplet[num_tri], name_rank[0][0], x)) - print(x) - cou += 1 - if cou % 10000 == 0: - print(cou) - - def get_relation_rank(self): - cou = 0 - self.rank = [] - for triplet in self.triple_list_test: - rank_list = {} - for relation_temp in self.relation_list.keys(): - corrupted_triplet = (triplet[0], triplet[1], relation_temp) - if self.is_fit and (corrupted_triplet in self.triple_list_train): - continue - rank_list[relation_temp] = distance(self.entity_list[triplet[0]], self.entity_list[triplet[1]], - self.relation_list[relation_temp]) - name_rank = sorted(rank_list.items(), key=operator.itemgetter(1)) - x = 1 - for i in name_rank: - if i[0] == triplet[2]: - break - x += 1 - self.rank.append((triplet, triplet[2], name_rank[0][0], x)) - print(x) - cou += 1 - if cou % 10000 == 0: - print(cou) - - def get_mean_rank(self): - num = 0 - for r in self.rank: - num += r[3] - return num / len(self.rank) - - -def distance(h, t, r): - h = array(h) - t = array(t) - r = array(r) - s = h + r - t - return linalg.norm(s) - - -def openD(dir, sp="\t"): - # triple = (head, tail, relation) - num = 0 - list = [] - with open(dir) as file: - lines = file.readlines() - for line in lines: - triple = line.strip().split(sp) - if len(triple) < 3: - continue - list.append(tuple(triple)) - num += 1 - print(num) - return num, list - - -def load_data(str): - fr = open(str) - s_arr = [line.strip().split("\t") for line in fr.readlines()] - dat_arr = [[float(s) for s in line[1][1:-1].split(", ")] for line in s_arr] - name_arr = [line[0] for line in s_arr] - return dat_arr, name_arr - - -if __name__ == '__main__': - dir_train = "data/FB15k/train.txt" - triple_num_train, triple_list_train = openD(dir_train) - dir_test = "data/FB15k/test.txt" - triple_num_test, triple_list_test = openD(dir_test) - dir_entity_vector = "data/entityVector.txt" - entity_vector_list, entity_list = load_data(dir_entity_vector) - dir_relation_vector = "data/relationVector.txt" - relation_vector_list, relation_list = load_data(dir_relation_vector) - print("********** Start test... 
**********") - - test_head_raw = Test(entity_list, entity_vector_list, relation_list, relation_vector_list, triple_list_train, - triple_list_test) - test_head_raw.get_rank() - print(test_head_raw.get_mean_rank()) - test_head_raw.write_rank("data/test/" + "test_head_raw" + ".txt") - test_head_raw.get_relation_rank() - print(test_head_raw.get_mean_rank()) - test_head_raw.write_rank("data/test" + "testRelationRaw" + ".txt") - - test_tail_raw = Test(entity_list, entity_vector_list, relation_list, relation_vector_list, triple_list_train, - triple_list_test, - label="tail") - test_tail_raw.get_rank() - print(test_tail_raw.get_mean_rank()) - test_tail_raw.write_rank("data/test/" + "test_tail_raw" + ".txt") - - test_head_fit = Test(entity_list, entity_vector_list, relation_list, relation_vector_list, triple_list_train, - triple_list_test, - is_fit=True) - test_head_fit.get_rank() - print(test_head_fit.get_mean_rank()) - test_head_fit.write_rank("data/test/" + "test_head_fit" + ".txt") - test_head_fit.get_relation_rank() - print(test_head_fit.get_mean_rank()) - test_head_fit.write_rank("data/test/" + "testRelationFit" + ".txt") - - test_tail_fit = Test(entity_list, entity_vector_list, relation_list, relation_vector_list, triple_list_train, - triple_list_test, - is_fit=True, label="tail") - test_tail_fit.get_rank() - print(test_tail_fit.get_mean_rank()) - test_tail_fit.write_rank("data/test/" + "test_tail_fit" + ".txt") \ No newline at end of file diff --git a/TestTransEMpQueue.py b/TestTransEMpQueue.py new file mode 100644 index 0000000..52ac077 --- /dev/null +++ b/TestTransEMpQueue.py @@ -0,0 +1,206 @@ +from numpy import * +import operator +import logging +from TrainTransESimple import get_details_of_triplets_list +from multiprocessing import Queue, JoinableQueue, Process +import timeit + +LOG_FORMAT = "%(asctime)s - %(name)s - %(message)s" +logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT) + + +class Test: + '''۹ +֪ʶһnʵ壬ô۹£ +ÿһԵԪaеͷʵβʵ壬滻Ϊ֪ʶеʵ壬ҲǻnԪ顣 +ֱnԪֵ(distֵ)transEУǼh+r-tֵԵõnֱֵӦnԪ顣 +nֵ +¼ԭԪaֵš +дڲԼеIJԪظ̡ +ÿȷԪֵƽõֵdzΪMean Rank +ȷԪС10ıõֵdzΪHits@10 +۵Ḷָ́꣺Mean RankHits@10Mean RankԽСԽãHits@10ԽԽáôδHits@10PythonִٶȺ +ߺʹ廪ѧFast_TransX룬ʹC++дܸߣܹٵóѵͲԽ +''' + + def __init__(self, entity_dyct, relation_dyct, train_triple_list, + test_triple_list, + label="head", is_fit=False, n_rank_calculator=24): + self.entity_dyct = entity_dyct + self.relation_dyct = relation_dyct + self.train_triple_list = train_triple_list + self.test_triple_list = test_triple_list + self.rank = [] + self.label = label + self.is_fit = is_fit + self.hit_at_10 = 0 + self.count = 0 + self.n_rank_calculator = n_rank_calculator + + def write_rank(self, file_path): + logging.info("Write int to %s" % file_path) + file = open(file_path, 'w') + for r in self.rank: + file.write(str(r[0]) + "\t") + file.write(str(r[1]) + "\t") + file.write(str(r[2]) + "\t") + file.write(str(r[3]) + "\n") + file.close() + + def get_rank_part(self, triplet): + rank_dyct = {} + for ent in self.entity_dyct.keys(): + if self.label == "head": + corrupted_triplet = (ent, triplet[1], triplet[2]) + if self.is_fit and ( + corrupted_triplet in self.train_triple_list): + continue + rank_dyct[ent] = distance(self.entity_dyct[ent], self.entity_dyct[triplet[1]], + self.relation_dyct[triplet[2]]) + else: # ݱǩ滻ͷʵ滻βʵ + corrupted_triplet = (triplet[0], ent, triplet[2]) + if self.is_fit and ( + corrupted_triplet in self.train_triple_list): + continue + rank_dyct[ent] = distance(self.entity_dyct[triplet[0]], self.entity_dyct[ent], + self.relation_dyct[triplet[2]]) + 
+        sorted_rank = sorted(rank_dyct.items(),
+                             key=operator.itemgetter(1))  # sort by the second element, the distance
+        if self.label == 'head':
+            num_tri = 0
+        else:
+            num_tri = 1
+        ranking = 1
+        for i in sorted_rank:
+            if i[0] == triplet[num_tri]:
+                break
+            ranking += 1
+        if ranking <= 10:  # ranks start at 1, so Hits@10 means ranking <= 10
+            self.hit_at_10 += 1
+        self.rank.append(
+            (triplet, triplet[num_tri], sorted_rank[0][0], ranking))
+        logging.info(
+            "Count:{} triplet {} {} ranks {}".format(
+                self.count, triplet, self.label, ranking))
+        self.count += 1
+
+    def get_relation_rank(self):
+        count = 0
+        self.rank = []
+        self.hit_at_10 = 0
+        for triplet in self.test_triple_list:
+            rank_dyct = {}
+            for rel in self.relation_dyct.keys():
+                corrupted_triplet = (triplet[0], triplet[1], rel)
+                if self.is_fit and (
+                        corrupted_triplet in self.train_triple_list):
+                    continue
+                rank_dyct[rel] = distance(self.entity_dyct[triplet[0]], self.entity_dyct[triplet[1]],
+                                          self.relation_dyct[rel])
+            sorted_rank = sorted(rank_dyct.items(), key=operator.itemgetter(1))
+            ranking = 1
+            for i in sorted_rank:
+                if i[0] == triplet[2]:
+                    break
+                ranking += 1
+            if ranking <= 10:
+                self.hit_at_10 += 1
+            self.rank.append((triplet, triplet[2], sorted_rank[0][0], ranking))
+            logging.info(
+                "Count:{} triplet {} relation ranks {}".format(
+                    count, triplet, ranking))
+            count += 1
+
+    def get_mean_rank_and_hit(self):
+        total_rank = 0
+        for r in self.rank:
+            total_rank += r[3]
+        num = len(self.rank)
+        return total_rank / num, self.hit_at_10 / num
+
+    def calculate_rank(self, in_queue, out_queue):
+        while True:
+            test_triplet = in_queue.get()
+            if test_triplet is None:
+                in_queue.task_done()
+                return
+            else:
+                out_queue.put(test_triplet)
+                in_queue.task_done()
+
+    def launch_test(self):
+        eval_result_queue = JoinableQueue()
+        rank_result_queue = Queue()
+        print('-----Start evaluation-----')
+        start = timeit.default_timer()
+        for _ in range(self.n_rank_calculator):
+            Process(
+                target=self.calculate_rank,
+                kwargs={
+                    'in_queue': eval_result_queue,
+                    'out_queue': rank_result_queue}).start()
+        n_used_eval_triple = 0
+        for test_triplet in self.test_triple_list:
+            eval_result_queue.put(test_triplet)
+            n_used_eval_triple += 1
+        for _ in range(self.n_rank_calculator):
+            eval_result_queue.put(None)
+        print('-----Joining all rank calculator-----')
+        # eval_result_queue.join()
+        for i in range(n_used_eval_triple):
+            test_triplet = rank_result_queue.get()
+            self.get_rank_part(test_triplet)
+        print('-----All rank calculation accomplished-----')
+        print('-----Obtaining evaluation results-----')
+
+
+def distance(h, t, r):
+    h = array(h)
+    t = array(t)
+    r = array(r)
+    s = h + r - t
+    return linalg.norm(s)
+
+
+def get_dict_from_vector_file(file_path):
+    file = open(file_path)
+    dyct = {}
+    for line in file.readlines():
+        name_vector = line.strip().split("\t")
+        # the vector is stored as the string '[0.11, 0.22, ...]', so [1:-1]
+        # strips the surrounding brackets before splitting on ", "
+        vector = [float(s) for s in name_vector[1][1:-1].split(", ")]
+        name = name_vector[0]
+        dyct[name] = vector
+    return dyct
+
+
+def main():
+    train_file = "data/FB15k/train.txt"
+    num_train_triple, train_triple_list = get_details_of_triplets_list(
+        train_file)
+    logging.info("Num of Train:%d" % num_train_triple)
+    test_file = "data/FB15k/test.txt"
+    num_test_triple, test_triple_list = get_details_of_triplets_list(test_file)
+    logging.info("Num of Test:%d" % num_test_triple)
+    entity_vector_file = "data/entityVector.txt"
+    entity_vector_dyct = get_dict_from_vector_file(entity_vector_file)
+    relation_vector_file = "data/relationVector.txt"
+    relation_vector_dyct = get_dict_from_vector_file(relation_vector_file)
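+    # The two vector files are assumed to follow TrainTransESimple.write_vector's
+    # format: one "name<TAB>[v0, v1, ...]" line per element, e.g. the
+    # (illustrative) line "/m/06rf7\t[0.1215, -0.0391, ...]".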
logging.info("********** Start Test **********") + + test_head_raw = Test( + entity_vector_dyct, + relation_vector_dyct, + train_triple_list, + test_triple_list) + test_head_raw.launch_test() + + logging.info( + "=========== Test Head Raw MeanRank: %g Hits@10: %g ===========" % + test_head_raw.get_mean_rank_and_hit()) + test_head_raw.write_rank("data/test/" + "test_head_raw" + ".txt") + logging.info("********** End Test **********") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/TrainMain.py b/TrainMain.py new file mode 100644 index 0000000..a7f3a4a --- /dev/null +++ b/TrainMain.py @@ -0,0 +1,95 @@ +import timeit +from TrainTransESimple import prepare_fb15k_train_data +from TrainTransESimple import TransE +from TrainTransEMpQueue import TransE as fastTransE +from TrainTransEMpManager import TransE as managerTransE +import argparse +from TrainTransEMpManager import Manager2, func1, MyManager + +import logging +from multiprocessing import Process, Lock + +LOG_FORMAT = "%(asctime)s - %(name)s - %(message)s" +logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT) + + +def main(): + parser = argparse.ArgumentParser(description='TransE') + parser.add_argument('--embedding_dim', type=int, default=100) + parser.add_argument('--margin_value', type=float, default=1.0) + parser.add_argument('--normal_form', type=str, default='L1') + parser.add_argument('--batch_size', type=int, default=10000) + parser.add_argument('--learning_rate', type=float, default=0.003) + parser.add_argument('--n_generator', type=int, default=24) + parser.add_argument('--max_epoch', type=int, default=2000) + parser.add_argument('--multi_process', type=str, default="MpQueue") + args = parser.parse_args() + print(args) + entity_list, rels_list, train_triplets_list = prepare_fb15k_train_data() + logging.info("********** Start TransE training ***********") + + if args.multi_process == "Simple": + transE = TransE( + entity_list, + rels_list, + train_triplets_list, + margin=args.margin_value, + dim=args.embedding_dim, + learing_rate=args.learning_rate, + normal_form=args.normal_form, + batch_size=args.batch_size) + logging.info("TransE is initializing...") + transE.transE(args.max_epoch) + elif args.multi_process == "MpQueue": + transE = fastTransE( + entity_list, + rels_list, + train_triplets_list, + margin=args.margin_value, + dim=args.embedding_dim, + learing_rate=args.learning_rate, + normal_form=args.normal_form, + batch_size=args.batch_size, + n_generator=args.n_generator) + logging.info("TransE is initializing...") + for epoch in range(args.max_epoch): + logging.info( + "Mp Queue TransE: After %d training epoch(s):" % + epoch) + transE.launch_training() + else: + MyManager.register('managerTransE', managerTransE) + manager = Manager2() + + transE = manager.managerTransE( + entity_list, + rels_list, + train_triplets_list, + batch_size=args.batch_size, + learing_rate=args.learning_rate, + margin=1, + dim=50, + normal_form=args.normal_form) + logging.info("TransE is initializing...") + start = timeit.default_timer() + for i in range(args.max_epoch): # epochĴ + lock = Lock() + proces = [Process(target=func1, args=(transE, lock)) for j in range(10)] # 10̣УԻܿ + for p in proces: + p.start() + for p in proces: + p.join() + end = timeit.default_timer() + logging.info( + "Mp Manager TransE: After %d training epoch(s):\nbatch size %d, cost time %g s, loss on batch data is %g" + % (i, 10000, end - start, transE.get_loss())) + start = end + transE.clear_loss() + logging.info("********** End TransE 
training ***********\n") + # ѵβһ100µдļ + transE.write_vector("data/entityVector.txt", "entity") + transE.write_vector("data/relationVector.txt", "relationship") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/TrainTransE.py b/TrainTransE.py deleted file mode 100644 index 8e4eb1c..0000000 --- a/TrainTransE.py +++ /dev/null @@ -1,273 +0,0 @@ -from random import uniform, sample, choice -import numpy as np -from copy import deepcopy -import time -from multiprocessing import Pool -from multiprocessing import Process, Value, Lock -from multiprocessing.managers import BaseManager -import multiprocessing -from numba import jit - - -def get_details_of_entityOrRels_list(file_path, split_delimeter="\t"): - num_of_file = 0 - lyst = [] - with open(file_path) as file: - lines = file.readlines() - for line in lines: - details_and_id = line.strip().split(split_delimeter) - lyst.append(details_and_id[0]) - num_of_file += 1 - return num_of_file, lyst - - -def get_details_of_triplets_list(file_path, split_delimeter="\t"): - num_of_file = 0 - lyst = [] - with open(file_path) as file: - lines = file.readlines() - for line in lines: - triple = line.strip().split(split_delimeter) - if len(triple) < 3: - continue - lyst.append(tuple(triple)) - num_of_file += 1 - return num_of_file, lyst - - -def norm(lyst): - # 归一化 单位向量 - var = np.linalg.norm(lyst) - i = 0 - while i < len(lyst): - lyst[i] = lyst[i] / var - i += 1 - # 需要返回array值 因为list不支持减法 - # return list - return np.array(lyst) - - -def dist_L1(h, t, l): - s = h + l - t - # 曼哈顿距离/出租车距离, |x-xi|+|y-yi|直接对向量的各个维度取绝对值相加 - # dist = np.fabs(s).sum() - return np.fabs(s).sum() - - -def dist_L2(h, t, l): - s = h + l - t - # 欧氏距离,是向量的平方和未开方。一定要注意,归一化公式和距离公式的错误书写,会引起收敛的失败 - # dist = (s * s).sum() - return (s * s).sum() - - -class TransE(object): - def __init__(self, entity_list, rels_list, triplets_list, margin=1, learing_rate=0.01, dim=50, normal_form="L1"): - self.learning_rate = learing_rate - self.loss = 0 - self.entity_list = entity_list # entityList是entity的list;初始化后,变为字典,key是entity,values是其向量(使用narray)。 - self.rels_list = rels_list - self.triplets_list = triplets_list - self.margin = margin - self.dim = dim - self.normal_form = normal_form - self.entity_vector_dict = {} - self.rels_vector_dict = {} - self.loss_list = [] - - def initialize(self): - """对论文中的初始化稍加改动 - 初始化l和e,对于原本的l和e的文件中的/m/06rf7字符串标识转化为定义的dim维向量,对dim维向量进行uniform和norm归一化操作 - """ - entity_vector_dict, rels_vector_dict = {}, {} - entity_vector_compo_list, rels_vector_compo_list = [], [] - for item, dict, compo_list, name in zip( - [self.entity_list, self.rels_list], [entity_vector_dict, rels_vector_dict], - [entity_vector_compo_list, rels_vector_compo_list], ["entity_vector_dict", "rels_vector_dict"]): - for entity_or_rel in item: - n = 0 - compo_list = [] - while n < self.dim: - random = uniform(-6 / (self.dim ** 0.5), 6 / (self.dim ** 0.5)) - compo_list.append(random) - n += 1 - compo_list = norm(compo_list) - dict[entity_or_rel] = compo_list - print("The " + name + "'s initialization is over. It's number is %d." 
% len(dict)) - self.entity_vector_dict = entity_vector_dict - self.rels_vector_dict = rels_vector_dict - - def transE(self, cycle_index=20): - count = 0 - print("\n********** Start TransE training **********") - for i in range(cycle_index): - - if count == 0: - start_time = time.time() - count += 1 - - if i % 10 == 0 and i != 0: - print("----------------The {} batches----------------".format(i)) - print("The loss is: %.4f" % self.loss) - # 查看最后的结果收敛情况 - self.loss_list.append(self.loss) - # self.write_vector("data/entityVector.txt", "entity") - # self.write_vector("data/relationVector.txt", "rels") - self.loss = 0 - count = 0 - end_time = time.time() - print("One epoch takes %.2f ms." % ((end_time - start_time) * 1000)) - start_time = end_time - - Sbatch = self.sample(1500) - Tbatch = [] # 元组对(原三元组,打碎的三元组)的列表 :{((h,r,t),(h',r,t'))} - for sbatch in Sbatch: - triplets_with_corrupted_triplets = (sbatch, self.get_corrupted_triplets(sbatch)) - if triplets_with_corrupted_triplets not in Tbatch: - Tbatch.append(triplets_with_corrupted_triplets) - self.update(Tbatch) - - def sample(self, size): - return sample(self.triplets_list, size) - - def get_corrupted_triplets(self, triplets): - '''training triplets with either the head or tail replaced by a random entity (but not both at the same time) - :param triplet:单个(h,t,l) - :return corruptedTriplet:''' - # i = uniform(-1, 1) if i - coin = choice([True, False]) - # 由于这个时候的(h,t,l)是从train文件里面抽出来的,要打坏的话直接随机寻找一个和头实体不等的实体即可 - if coin: # 抛硬币 为真 打破头实体,即第一项 - while True: - searching_entity = sample(self.entity_vector_dict.keys(), 1)[0] # 取第一个元素是因为sample返回的是一个列表类型 - if searching_entity != triplets[0]: - break - corrupted_triplets = (searching_entity, triplets[1], triplets[2]) - else: # 反之,打破尾实体,即第二项 - while True: - searching_entity = sample(self.entity_vector_dict.keys(), 1)[0] - if searching_entity != triplets[1]: - break - corrupted_triplets = (triplets[0], searching_entity, triplets[2]) - return corrupted_triplets - - def update(self, Tbatch): - entity_vector_copy = deepcopy(self.entity_vector_dict) - rels_vector_copy = deepcopy(self.rels_vector_dict) - - for triplets_with_corrupted_triplets in Tbatch: - head_entity_vector = entity_vector_copy[triplets_with_corrupted_triplets[0][0]] - tail_entity_vector = entity_vector_copy[triplets_with_corrupted_triplets[0][1]] - relation_vector = rels_vector_copy[triplets_with_corrupted_triplets[0][2]] - - head_entity_vector_with_corrupted_triplets = entity_vector_copy[triplets_with_corrupted_triplets[1][0]] - tail_entity_vector_with_corrupted_triplets = entity_vector_copy[triplets_with_corrupted_triplets[1][1]] - - head_entity_vector_before_batch = self.entity_vector_dict[triplets_with_corrupted_triplets[0][0]] - tail_entity_vector_before_batch = self.entity_vector_dict[triplets_with_corrupted_triplets[0][1]] - relation_vector_before_batch = self.rels_vector_dict[triplets_with_corrupted_triplets[0][2]] - - head_entity_vector_with_corrupted_triplets_before_batch = self.entity_vector_dict[ - triplets_with_corrupted_triplets[1][0]] - tail_entity_vector_with_corrupted_triplets_before_batch = self.entity_vector_dict[ - triplets_with_corrupted_triplets[1][1]] - - if self.normal_form == "L1": - dist_triplets = dist_L1(head_entity_vector_before_batch, tail_entity_vector_before_batch, - relation_vector_before_batch) - dist_corrupted_triplets = dist_L1(head_entity_vector_with_corrupted_triplets_before_batch, - tail_entity_vector_with_corrupted_triplets_before_batch, - relation_vector_before_batch) - else: - dist_triplets = 
dist_L2(head_entity_vector_before_batch, tail_entity_vector_before_batch, - relation_vector_before_batch) - dist_corrupted_triplets = dist_L2(head_entity_vector_with_corrupted_triplets_before_batch, - tail_entity_vector_with_corrupted_triplets_before_batch, - relation_vector_before_batch) - eg = self.margin + dist_triplets - dist_corrupted_triplets - if eg > 0: # 大于0取原值,小于0则置0.即合页损失函数margin-based ranking criterion - self.loss += eg - temp_positive = 2 * self.learning_rate * ( - tail_entity_vector_before_batch - head_entity_vector_before_batch - relation_vector_before_batch) - temp_negative = 2 * self.learning_rate * ( - tail_entity_vector_with_corrupted_triplets_before_batch - head_entity_vector_with_corrupted_triplets_before_batch - relation_vector_before_batch) - if self.normal_form == "L1": - temp_positive_L1 = [1 if temp_positive[i] >= 0 else -1 for i in range(self.dim)] - temp_negative_L1 = [1 if temp_negative[i] >= 0 else -1 for i in range(self.dim)] - # temp_positive = norm(temp_positive_L1) * self.learning_rate - # temp_negative = norm(temp_negative_L1) * self.learning_rate - temp_positive = np.array(temp_positive_L1) * self.learning_rate - temp_negative = np.array(temp_negative_L1) * self.learning_rate - - # 对损失函数的5个参数进行梯度下降, 随机体现在sample函数上 - head_entity_vector += temp_positive - tail_entity_vector -= temp_positive - relation_vector = relation_vector + temp_positive - temp_negative - head_entity_vector_with_corrupted_triplets -= temp_negative - tail_entity_vector_with_corrupted_triplets += temp_negative - - # 归一化刚才更新的向量,减少计算时间 - entity_vector_copy[triplets_with_corrupted_triplets[0][0]] = norm(head_entity_vector) - entity_vector_copy[triplets_with_corrupted_triplets[0][1]] = norm(tail_entity_vector) - rels_vector_copy[triplets_with_corrupted_triplets[0][2]] = norm(relation_vector) - entity_vector_copy[triplets_with_corrupted_triplets[1][0]] = norm( - head_entity_vector_with_corrupted_triplets) - entity_vector_copy[triplets_with_corrupted_triplets[1][1]] = norm( - tail_entity_vector_with_corrupted_triplets) - - # self.entity_vector_dict = deepcopy(entity_vector_copy) - # self.rels_vector_dict = deepcopy(rels_vector_copy) - self.entity_vector_dict = entity_vector_copy - self.rels_vector_dict = rels_vector_copy - - def write_vector(self, file_path, option): - if option.strip().startswith("entit"): - print("Write entities vetor into file : {}".format(file_path)) - # dyct = deepcopy(self.entity_vector_dict) - dyct = self.entity_vector_dict - if option.strip().startswith("rel"): - print("Write relationships vector into file: {}".format(file_path)) - # dyct = deepcopy(self.rels_vector_dict) - dyct = self.rels_vector_dict - with open(file_path, 'w') as file: # 写文件,每次覆盖写 用with自动调用close - for dyct_key in dyct.keys(): - file.write(dyct_key + "\t") - file.write(str(dyct[dyct_key].tolist())) - file.write("\n") - - def write_loss(self, file_path, num_of_col): - with open(file_path, 'w') as file: - lyst = deepcopy(self.loss_list) - for i in range(len(lyst)): - if num_of_col == 1: - # 保留4位小数 - file.write(str(int(lyst[i] * 10000) / 10000) + "\n") - # file.write(str(lyst[i]).split('.')[0] + '.' + str(lyst[i]).split('.')[1][:4] + "\n") - else: - # file.write(str(lyst[i]).split('.')[0] + '.' 
+ str(lyst[i]).split('.')[1][:4] + "\t")
-                    file.write(str(int(lyst[i] * 10000) / 10000) + " ")
-                if (i + 1) % num_of_col == 0 and i != 0:
-                    file.write("\n")
-
-
-if __name__ == "__main__":
-    entity_file_path = "data/FB15k/entity2id.txt"
-    num_of_entity, entity_list = get_details_of_entityOrRels_list(entity_file_path)
-    rels_file_path = "data/FB15k/relation2id.txt"
-    num_of_rels, rels_list = get_details_of_entityOrRels_list(rels_file_path)
-    train_file_path = "data/FB15k/train.txt"
-    num_of_triplets, triplets_list = get_details_of_triplets_list(train_file_path)
-
-    transE = TransE(entity_list, rels_list, triplets_list, margin=1, dim=50)
-    print("\nTransE is initializing...")
-    transE.initialize()
-    transE.transE(20000)
-
-    # transE.transE2(num_of_epochs=1000, epoch_triplets=15000, num_of_batches=10)
-    print("********** End TransE training ***********\n")
-    # the number of batches is not necessarily a multiple of 100, so write the last updated vectors to file
-    transE.write_vector("data/entityVector.txt", "entity")
-    transE.write_vector("data/relationVector.txt", "relationship")
-    transE.write_loss("data/lossList_25cols.txt", 25)
-    transE.write_loss("data/lossList_1cols.txt", 1)
-
diff --git a/TrainTransEMpManager.py b/TrainTransEMpManager.py
new file mode 100644
index 0000000..ecad8f2
--- /dev/null
+++ b/TrainTransEMpManager.py
@@ -0,0 +1,87 @@
+from multiprocessing import Process, Lock
+from multiprocessing.managers import BaseManager
+import logging
+from TrainTransESimple import TransE as TransESimple
+from TrainTransESimple import prepare_fb15k_train_data
+
+LOG_FORMAT = "%(asctime)s - %(name)s - %(message)s"
+logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
+
+INITIAL_LEARNING_RATE = 0.01
+
+
+class TransE(TransESimple):
+
+    def get_loss(self):
+        # Based on a comparison with Tsinghua's Fast-TransX C++ code, which is
+        # indeed very fast: training that takes Python close to 10 hours takes
+        # the C++ code only a few dozen minutes. Compared with the original
+        # code, the Sbatch logic is modified: the sample is broken up
+        # directly, because there are roughly 480k training triples, one
+        # epoch is split into batches, and every triple in a batch gets a
+        # gradient-descent step; in the multi-process part, n batches map to
+        # n processes. Because of Python's legacy GIL (global interpreter
+        # lock), Python threads are close to useless here and cannot exploit
+        # multiple CPUs/cores. To use multiple processes instead, a manager
+        # wraps TransE into a Proxy object, but a Proxy cannot expose the
+        # wrapped TransE's attributes, hence this hand-written get_loss.
+        # Note that Python multi-processing is not necessarily faster than a
+        # plain for loop: processes must be created, destroyed and switched,
+        # and they communicate via RPC. Comparing trainTransE with
+        # trainTransE_MultiProcess: the for loop runs 10 batches in 8-9 s,
+        # while one epoch here (ten processes at a time) costs 12-13 s.
+        # Further optimisations: a process pool, or shared memory between
+        # processes (hard to do by hand; TF can be used for that).
+        return self.loss
+
+    def clear_loss(self):
+        # also needed because of the Proxy: lets external code reset the loss to 0
+        self.loss = 0
+
+    def transE(self):
+        Sbatch = self.sample(self.batch_size // 10)
+        Tbatch = []  # list of (original, corrupted) triple pairs: {((h,r,t),(h',r,t'))}
+        for sbatch in Sbatch:
+            pos_neg_triplets = (sbatch, self.get_corrupted_triplets(sbatch))
+            if pos_neg_triplets not in Tbatch:
+                Tbatch.append(pos_neg_triplets)
+        self.update(Tbatch)
+
+
+class MyManager(BaseManager):
+    pass
+
+
+def Manager2():
+    m = MyManager()
+    m.start()
+    return m
+
+
+MyManager.register('TransE', TransE)
+
+
+def func1(em, lock):
+    with lock:
+        em.transE()
+
+
+def main():
+    manager = Manager2()
+    entity_list, rels_list, train_triplets_list = prepare_fb15k_train_data()
+
+    transE = manager.TransE(
+        entity_list,
+        rels_list,
+        train_triplets_list,
+        batch_size=10000,
+        margin=1,
+        dim=50)
+    logging.info("TransE is initializing...")
+    for i in range(20000):  # number of epochs
+        lock = Lock()
+        proces = [Process(target=func1, args=(transE, lock))
+                  for j in range(10)]  # 10 processes, nominally concurrent, so it can run faster
+        for p in proces:
+            p.start()
+        for p in proces:
+            p.join()
+        if i != 0:
+            logging.info(
+                "After %d training epoch(s), loss on batch data is %g" %
+                (i * 10, transE.get_loss()))
+            transE.clear_loss()
+    # transE.transE(100000)
+    logging.info("********** End TransE training ***********\n")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/TrainTransEMpQueue.py b/TrainTransEMpQueue.py
new file mode 100644
index 0000000..d4ce1c8
--- /dev/null
+++ b/TrainTransEMpQueue.py
@@ -0,0 +1,138 @@
+import numpy as np
+from multiprocessing import Process, Queue
+import logging
+import timeit
+from TrainTransESimple import TransE as TransESimple
+from TrainTransESimple import norm, dist_L1, dist_L2
+from TrainTransESimple import prepare_fb15k_train_data
+
+LOG_FORMAT = "%(asctime)s - %(name)s - %(message)s"
+logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
+
+INITIAL_LEARNING_RATE = 0.01
+
+
+class TransE(TransESimple):
+    def __init__(
+            self,
+            entity_list,
+            rels_list,
+            triplets_list,
+            margin=1,
+            learing_rate=INITIAL_LEARNING_RATE,
+            dim=100,
+            normal_form="L1",
+            batch_size=10000,
+            n_generator=24):
+        TransESimple.__init__(self, entity_list, rels_list, triplets_list, margin=margin, learing_rate=learing_rate,
+                              dim=dim, normal_form=normal_form, batch_size=batch_size)
+        self.n_generator = n_generator
+
+    def generate_training_batch(
+            self,
+            sbatch_queue: Queue,
+            tbatch_queue: Queue):
+        while True:
+            raw_batch = sbatch_queue.get()
+            if raw_batch is None:
+                return
+            else:
+                pos_triplet = raw_batch
+                neg_triplet = self.get_corrupted_triplets(pos_triplet)
+                pos_neg_triplets = (pos_triplet, neg_triplet)
+                tbatch_queue.put(pos_neg_triplets)
+
+    def launch_training(self):
+        raw_batch_queue = Queue()
+        training_batch_queue = Queue()
+        for _ in range(self.n_generator):
+            Process(
+                target=self.generate_training_batch,
+                kwargs={
+                    'sbatch_queue': raw_batch_queue,
+                    'tbatch_queue': training_batch_queue}).start()
+        start = timeit.default_timer()
+        Sbatch = self.sample(self.batch_size)
+        n_batch = 0
+        for raw_batch in Sbatch:
+            raw_batch_queue.put(raw_batch)
+            n_batch += 1
+        for _ in range(self.n_generator):
+            raw_batch_queue.put(None)
+        epoch_loss = 0
+        self.loss = 0
+        for i in range(n_batch):
+            batch_pos, batch_neg = training_batch_queue.get()
+            self.update_part(batch_pos, batch_neg)
+        epoch_loss += self.loss
+        print("batch size %d, cost time %g s, loss on batch data is %g" % (
+            n_batch, timeit.default_timer() - start, epoch_loss))
+
+    def update_part(self, pos_triplet, neg_triplet):
+        entity_vector_copy = self.entity_vector_dict
+        rels_vector_copy = self.rels_vector_dict
+
+        # h, t, r are the vectors of the head entity, tail entity and the
+        # relation; h2 and t2 are the h' and t' of the corrupted triple, i.e.
+        # its head and tail entities. Tbatch used to be a list of
+        # (original, corrupted) triple pairs [((h,r,t),(h',r,t')), ...];
+        # note the raw data files store triples as (h, t, r).
+        h = entity_vector_copy[pos_triplet[0]]
+        t = entity_vector_copy[pos_triplet[1]]
+        r = rels_vector_copy[pos_triplet[2]]
+        # head and tail entity of the corrupted triple
+        h2 = entity_vector_copy[neg_triplet[0]]
+        t2 = entity_vector_copy[neg_triplet[1]]
+        # the former 'before batch' copies were dropped: they are unnecessary
+        # now that the triples are already split into batches; the results
+        # are equivalent
+        if self.normal_form == "L1":
+            dist_triplets = dist_L1(h, t, r)
+            dist_corrupted_triplets = dist_L1(h2, t2, r)
+        else:
+            dist_triplets = dist_L2(h, t, r)
+            dist_corrupted_triplets = dist_L2(h2, t2, r)
+        eg = self.margin + dist_triplets - dist_corrupted_triplets
+        if eg > 0:  # hinge: keep the value if positive, clamp to 0 otherwise, i.e. the margin-based ranking criterion
+            self.loss += eg
+            temp_positive = 2 * self.learning_rate * (t - h - r)
+            temp_negative = 2 * self.learning_rate * (t2 - h2 - r)
+            if self.normal_form == "L1":
+                temp_positive_L1 = [1 if temp_positive[i] >= 0 else -1 for i in range(self.dim)]
+                temp_negative_L1 = [1 if temp_negative[i] >= 0 else -1 for i in range(self.dim)]
+                temp_positive = np.array(temp_positive_L1) * self.learning_rate
+                temp_negative = np.array(temp_negative_L1) * self.learning_rate
+
+            # gradient descent on the five vectors of the loss; the stochastic
+            # part is the sample() call
+            h += temp_positive
+            t -= temp_positive
+            r = r + temp_positive - temp_negative
+            h2 -= temp_negative
+            t2 += temp_negative
+
+            # normalise the vectors that were just updated, reducing computation time
+            entity_vector_copy[pos_triplet[0]] = norm(h)
+            entity_vector_copy[pos_triplet[1]] = norm(t)
+            rels_vector_copy[pos_triplet[2]] = norm(r)
+            entity_vector_copy[neg_triplet[0]] = norm(h2)
+            entity_vector_copy[neg_triplet[1]] = norm(t2)
+
+        self.entity_vector_dict = entity_vector_copy
+        self.rels_vector_dict = rels_vector_copy
+
+
+def main():
+    entity_list, rels_list, train_triplets_list = prepare_fb15k_train_data()
+
+    transE = TransE(
+        entity_list,
+        rels_list,
+        train_triplets_list,
+        margin=1,
+        dim=100,
+        learing_rate=0.003)
+    logging.info("TransE is initializing...")
+    for epoch in range(2000):
+        print("Mp Queue TransE, After %d training epoch(s):\n" % epoch)
+        transE.launch_training()
+    logging.info("********** End TransE training ***********\n")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/TrainTransESimple.py b/TrainTransESimple.py
new file mode 100644
index 0000000..83f4fc3
--- /dev/null
+++ b/TrainTransESimple.py
@@ -0,0 +1,276 @@
+import timeit
+from random import uniform, sample, choice
+import numpy as np
+from copy import deepcopy
+import logging
+
+LOG_FORMAT = "%(asctime)s - %(name)s - %(message)s"
+logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT)
+
+INITIAL_LEARNING_RATE = 0.01
+
+
+def get_details_of_entityOrRels_list(file_path, split_delimeter="\t"):
+    num_of_file = 0
+    lyst = []
+    with open(file_path) as file:
+        # readlines does load the whole file into memory; for large files the
+        # pandas route csv_data = pd.read_csv(csv_file, low_memory=False) is
+        # an alternative. The 'r' mode argument can be omitted when reading,
+        # since mode defaults to 'r'.
+        lines = file.readlines()
+        for line in lines:
+            details_and_id = line.strip().split(split_delimeter)
+            lyst.append(details_and_id[0])
+            num_of_file += 1
+    return num_of_file, lyst
+
+
+def get_details_of_triplets_list(file_path, split_delimeter="\t"):
+    num_of_file = 0
+    lyst = []
+    with open(file_path) as file:
+        lines = file.readlines()
+        for line in lines:
+            triple = line.strip().split(split_delimeter)
+            if len(triple) < 3:
+                continue
+            lyst.append(tuple(triple))
+            num_of_file += 1
+    return num_of_file, lyst
+
+
+def norm(lyst):
+    # normalise to a unit vector
+    var = np.linalg.norm(lyst)
+    i = 0
+    while i < len(lyst):
+        lyst[i] = lyst[i] / var
+        i += 1
+    # return an np.array, since plain lists do not support subtraction
+    return np.array(lyst)
+
+
+def dist_L1(h, t, l):
+    s = h + l - t
+    # Manhattan / taxicab distance |x-xi|+|y-yi|: sum the absolute values of
+    # the vector components directly
+    return np.fabs(s).sum()
+
+
+def dist_L2(h, t, l):
+    s = h + l - t
+    # squared Euclidean distance (sum of squares, without the square root).
+    # Beware: mis-writing the normalisation or distance formulas makes
+    # convergence fail.
+    return (s * s).sum()
+
+
+class TransE(object):
+    def __init__(
+            self,
+            entity_list,
+            rels_list,
+            triplets_list,
+            margin=1,
+            learing_rate=INITIAL_LEARNING_RATE,
+            dim=100,
+            normal_form="L1",
+            batch_size=10000):
+        self.learning_rate = learing_rate
+        self.loss = 0
+        self.entity_list = entity_list
+        self.rels_list = rels_list
+        self.triplets_list = triplets_list
+        self.margin = margin
+        self.dim = dim
+        self.normal_form = normal_form
+        self.batch_size = batch_size
+        self.entity_vector_dict = {}
+        self.rels_vector_dict = {}
+        self.loss_list = []
+        self.initialize()
+
+    def initialize(self):
+        '''
+        Slightly modified version of the paper's initialisation:
+        every string identifier such as /m/06rf7 in the original l and e
+        files is turned into a dim-dimensional vector, which is filled with
+        uniform() samples and then normalised with norm().
+        '''
+        entity_vector_dict, rels_vector_dict = {}, {}
+        # 'compo' is short for component: once a vector has dim components it
+        # is normalised; this is the per-vector part of the initialisation.
+        entity_vector_compo_list, rels_vector_compo_list = [], []
+        for item, dyct, compo_list, name in zip([self.entity_list, self.rels_list],
+                                                [entity_vector_dict, rels_vector_dict],
+                                                [entity_vector_compo_list, rels_vector_compo_list],
+                                                ["entity_vector_dict", "rels_vector_dict"]):
+            for entity_or_rel in item:
+                n = 0
+                compo_list = []
+                while n < self.dim:
+                    random = uniform(-6 / (self.dim ** 0.5),
+                                     6 / (self.dim ** 0.5))
+                    compo_list.append(random)
+                    n += 1
+                compo_list = norm(compo_list)
+                dyct[entity_or_rel] = compo_list
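+        # This matches the initialisation in the TransE paper: each component
+        # is drawn from uniform(-6/sqrt(dim), 6/sqrt(dim)) and the finished
+        # vector is then rescaled to unit L2 norm by norm().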
+        self.entity_vector_dict = entity_vector_dict
+        self.rels_vector_dict = rels_vector_dict
+
+    def transE(self, cycle_index=20):
+        for i in range(cycle_index):
+            start = timeit.default_timer()
+            Sbatch = self.sample(self.batch_size)
+            Tbatch = []  # list of (original, corrupted) triple pairs: {((h,r,t),(h',r,t'))}
+            for sbatch in Sbatch:
+                # pos_neg_triplets is a (positive, negative) pair of triples
+                pos_neg_triplets = (
+                    sbatch, self.get_corrupted_triplets(sbatch))
+                if pos_neg_triplets not in Tbatch:
+                    Tbatch.append(pos_neg_triplets)
+            self.update(Tbatch)
+            if i % 1 == 0:
+                # the learning rate could be decayed with i, e.g. exponential smoothing:
+                # self.learning_rate = INITIAL_LEARNING_RATE * (pow(0.96, i / 100))
+                end = timeit.default_timer()
+                logging.info(
+                    "Simple TransE, After %d training epoch(s):\nbatch size is %d, cost time is %g s, loss on batch data is %g" %
+                    (i, self.batch_size, end - start, self.loss))
+                # keep the losses so convergence can be inspected at the end
+                self.loss_list.append(self.loss)
+                # self.write_vector("data/entityVector.txt", "entity")
+                # self.write_vector("data/relationVector.txt", "rels")
+                self.loss = 0
+
+    def sample(self, size):
+        return sample(self.triplets_list, size)
+
+    def get_corrupted_triplets(self, triplets):
+        '''training triplets with either the head or tail replaced by a random entity (but not both at the same time)
+        :param triplet: a single (h, t, l)
+        :return corruptedTriplet:'''
+        coin = choice([True, False])
+        # since the (h,t,l) here is drawn from the train file, corrupting it
+        # just means randomly picking an entity different from the one replaced
+        if coin:  # coin flip: if true, corrupt the head entity, the first element
+            while True:
+                # take element [0] because sample returns a list
+                searching_entity = sample(self.entity_vector_dict.keys(), 1)[0]
+                if searching_entity != triplets[0]:
+                    break
+            corrupted_triplets = (searching_entity, triplets[1], triplets[2])
+        else:  # otherwise corrupt the tail entity, the second element
+            while True:
+                searching_entity = sample(self.entity_vector_dict.keys(), 1)[0]
+                if searching_entity != triplets[1]:
+                    break
+            corrupted_triplets = (triplets[0], searching_entity, triplets[2])
+        return corrupted_triplets
+
+    def update(self, Tbatch):
+        entity_vector_copy = self.entity_vector_dict
+        rels_vector_copy = self.rels_vector_dict
+
+        # h, t, r are the vectors of the head entity, tail entity and the
+        # relation; h2 and t2 are the h' and t' of the corrupted triple.
+        # Tbatch is a list of (original, corrupted) triple pairs
+        # [((h,r,t),(h',r,t')), ...]; note the raw data files store (h, t, r).
+        for pos_neg_triplets in Tbatch:
+            h = entity_vector_copy[pos_neg_triplets[0][0]]
+            t = entity_vector_copy[pos_neg_triplets[0][1]]
+            r = rels_vector_copy[pos_neg_triplets[0][2]]
+            # head and tail entity of the corrupted triple
+            h2 = entity_vector_copy[pos_neg_triplets[1][0]]
+            t2 = entity_vector_copy[pos_neg_triplets[1][1]]
+            # the former 'before batch' copies were dropped: they are
+            # unnecessary now that the work is split into batches; the
+            # results are equivalent
+            if self.normal_form == "L1":
+                dist_triplets = dist_L1(h, t, r)
+                dist_corrupted_triplets = dist_L1(h2, t2, r)
+            else:
+                dist_triplets = dist_L2(h, t, r)
+                dist_corrupted_triplets = dist_L2(h2, t2, r)
+            eg = self.margin + dist_triplets - dist_corrupted_triplets
+            if eg > 0:  # hinge: keep the value if positive, clamp to 0 otherwise, i.e. the margin-based ranking criterion
+                self.loss += eg
+                temp_positive = 2 * self.learning_rate * (t - h - r)
+                temp_negative = 2 * self.learning_rate * (t2 - h2 - r)
+                if self.normal_form == "L1":
+                    temp_positive_L1 = [1 if temp_positive[i] >= 0 else -1 for i in range(self.dim)]
+                    temp_negative_L1 = [1 if temp_negative[i] >= 0 else -1 for i in range(self.dim)]
+                    temp_positive = np.array(temp_positive_L1) * self.learning_rate
+                    temp_negative = np.array(temp_negative_L1) * self.learning_rate
+
+                # gradient descent on the five vectors of the loss; the
+                # stochastic part is the sample() call
+                h += temp_positive
+                t -= temp_positive
+                r = r + temp_positive - temp_negative
+                h2 -= temp_negative
+                t2 += temp_negative
+
+                # normalise the vectors that were just updated, reducing
+                # computation time
+                entity_vector_copy[pos_neg_triplets[0][0]] = norm(h)
+                entity_vector_copy[pos_neg_triplets[0][1]] = norm(t)
+                rels_vector_copy[pos_neg_triplets[0][2]] = norm(r)
+                entity_vector_copy[pos_neg_triplets[1][0]] = norm(h2)
+                entity_vector_copy[pos_neg_triplets[1][1]] = norm(t2)
+
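+        # The block above implements the paper's margin-based ranking loss
+        # L = max(0, margin + d(h + r, t) - d(h' + r, t')). For the L2 score
+        # d = ||h + r - t||^2 the gradient w.r.t. h is -2(t - h - r), hence
+        # the 2 * learning_rate * (t - h - r) update steps; for L1 only the
+        # sign vector of (t - h - r) is kept.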
+        self.entity_vector_dict = entity_vector_copy
+        self.rels_vector_dict = rels_vector_copy
+
+    def write_vector(self, file_path, option):
+        if option.strip().startswith("entit"):
+            logging.info(
+                "Write entities vector into file : {}".format(file_path))
+            # dyct = deepcopy(self.entity_vector_dict)
+            dyct = self.entity_vector_dict
+        if option.strip().startswith("rel"):
+            logging.info(
+                "Write relationships vector into file: {}".format(file_path))
+            # dyct = deepcopy(self.rels_vector_dict)
+            dyct = self.rels_vector_dict
+        with open(file_path, 'w') as file:  # overwrite on every call; with closes the file automatically
+            for dyct_key in dyct.keys():
+                file.write(dyct_key + "\t")
+                file.write(str(dyct[dyct_key].tolist()))
+                file.write("\n")
+
+    def write_loss(self, file_path, num_of_col):
+        with open(file_path, 'w') as file:
+            lyst = deepcopy(self.loss_list)
+            for i in range(len(lyst)):
+                if num_of_col == 1:
+                    # keep 4 decimal places
+                    file.write(str(int(lyst[i] * 10000) / 10000) + "\n")
+                else:
+                    file.write(str(int(lyst[i] * 10000) / 10000) + " ")
+                    if (i + 1) % num_of_col == 0 and i != 0:
+                        file.write("\n")
+
+
+def prepare_fb15k_train_data():
+    entity_file = "data/FB15k/entity2id.txt"
+    num_entity, entity_list = get_details_of_entityOrRels_list(entity_file)
+    logging.info("The number of entity_list is %d." % num_entity)
+    rels_file = "data/FB15k/relation2id.txt"
+    num_rels, rels_list = get_details_of_entityOrRels_list(rels_file)
+    logging.info("The num of rels_list is %d." % num_rels)
+    train_file = "data/FB15k/train.txt"
+    num_triplets, train_triplets_list = get_details_of_triplets_list(
+        train_file)
+    logging.info("The num of train_triplets_list is %d." % num_triplets)
+    return entity_list, rels_list, train_triplets_list
+
+
+def main():
+    # test code corresponding to --multi_process "None" in TrainMain
+    entity_list, rels_list, train_triplets_list = prepare_fb15k_train_data()
+
+    transE = TransE(
+        entity_list,
+        rels_list,
+        train_triplets_list,
+        margin=1,
+        dim=100,
+        learing_rate=0.003)
+    logging.info("TransE is initializing...")
+    transE.transE(5000)
+
+    # transE.transE2(num_of_epochs=1000, epoch_triplets=15000, num_of_batches=10)
+    logging.info("********** End TransE training ***********\n")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/TrainTransE_MultiProcess.py b/TrainTransE_MultiProcess.py
deleted file mode 100644
index b8fedbb..0000000
--- a/TrainTransE_MultiProcess.py
+++ /dev/null
@@ -1,296 +0,0 @@
-from random import uniform, sample, choice
-import numpy as np
-from copy import deepcopy
-from multiprocessing import Process, Value, Lock
-from multiprocessing.managers import BaseManager
-import time
-
-
-def get_details_of_entityOrRels_list(file_path, split_delimeter="\t"):
-    num_of_file = 0
-    lyst = []
-    with open(file_path) as file:
-        lines = file.readlines()
-        for line in lines:
-            details_and_id = line.strip().split(split_delimeter)
-            lyst.append(details_and_id[0])
-            num_of_file += 1
-    return num_of_file, lyst
-
-
-def get_details_of_triplets_list(file_path, split_delimeter="\t"):
-    num_of_file = 0
-    lyst = []
-    with open(file_path) as file:
-        lines = file.readlines()
-        for line in lines:
-            triple = line.strip().split(split_delimeter)
-            if len(triple) < 3:
-                continue
-            lyst.append(tuple(triple))
-            num_of_file += 1
-    return num_of_file, lyst
-
-
-def norm(lyst):
-    # normalise to a unit vector
-    var = np.linalg.norm(lyst)
-    i = 0
-    while i < len(lyst):
-        lyst[i] = lyst[i] / var
-        i += 1
-    # return an np.array, since plain lists do not support subtraction
-    # return list
-    return np.array(lyst)
-
-
-def dist_L1(h, t, l):
-    s = h + l - t
-    # Manhattan / taxicab distance |x-xi|+|y-yi|: sum the absolute values of
-    # the vector components directly
-    # dist = np.fabs(s).sum()
-    
return np.fabs(s).sum() - - -def dist_L2(h, t, l): - s = h + l - t - # ŷϾ,ƽδһҪע⣬һʽ;빫ʽĴдʧ - # dist = (s * s).sum() - return (s * s).sum() - - -class TransE(object): - def __init__(self, entity_list, rels_list, triplets_list, margin=1, learing_rate=0.01, dim=50, normal_form="L1"): - self.learning_rate = learing_rate - self.loss = 0 - self.entity_list = entity_list # entityListentitylistʼ󣬱Ϊֵ䣬keyentityvaluesʹnarray - self.rels_list = rels_list - self.triplets_list = triplets_list - self.margin = margin - self.dim = dim - self.normal_form = normal_form - self.entity_vector_dict = {} - self.rels_vector_dict = {} - self.loss_list = [] - - def get_loss(self): - # ο廪Fast-TransXC++룬ȷʵٶȺܿ죬Pythonӽ5СʱѵC++ʮӼɡԵĿһ´룬 - # ԭеSbatch޸ģֱӽˣΪѵԪһepochΪbatchɣÿbatchÿһԪ鶼ݶ½̲߳nbatchӦn߳ - # Pythonʷ⣬ʹGILȫֽʹPythonĶ߳̽Ƽߣ޷cpuԿʹöŻ - # Ϊʹọ̈ʹmanagertransEװΪProxyProxy޷ȡװTransEԣҪдgetloss - # ֵעǣPythonĶܲһforѭͰ˽̵Ĵ١л̼ҪRPCԶͨ - # trainTransEtrainTransE_MultiProcessԱtrainTransEforѭһ10ʱ8s-9strainTransE_MultiProcessһepochһʱ12-13s - # һŻ̳أʵ̸ֽãܣtf - return self.loss - - def clear_loss(self): - # úҲΪProxyⲿʧ0 - self.loss = 0 - - def initialize(self): - """еijʼԼӸĶ - ʼleԭleļе/m/06rf7ַʶתΪdimάdimάuniformnormһ - """ - entity_vector_dict, rels_vector_dict = {}, {} - entity_vector_compo_list, rels_vector_compo_list = [], [] - for item, dict, compo_list, name in zip( - [self.entity_list, self.rels_list], [entity_vector_dict, rels_vector_dict], - [entity_vector_compo_list, rels_vector_compo_list], ["entity_vector_dict", "rels_vector_dict"]): - for entity_or_rel in item: - n = 0 - compo_list = [] - while n < self.dim: - random = uniform(-6 / (self.dim ** 0.5), 6 / (self.dim ** 0.5)) - compo_list.append(random) - n += 1 - compo_list = norm(compo_list) - dict[entity_or_rel] = compo_list - print("The " + name + "'s initialization is over. It's number is %d." 
-                  % len(dict))
-        self.entity_vector_dict = entity_vector_dict
-        self.rels_vector_dict = rels_vector_dict
-
-    def transE(self, cycle_index=1, num=1500):
-        Sbatch = self.sample(num)
-        Tbatch = []  # list of (original triplet, corrupted triplet) pairs: {((h, t, r), (h', t', r))}
-        for sbatch in Sbatch:
-            triplets_with_corrupted_triplets = (sbatch, self.get_corrupted_triplets(sbatch))
-            if triplets_with_corrupted_triplets not in Tbatch:
-                Tbatch.append(triplets_with_corrupted_triplets)
-        self.update(Tbatch)
-
-    def sample(self, size):
-        return sample(self.triplets_list, size)
-
-    def get_corrupted_triplets(self, triplets):
-        '''training triplets with either the head or tail replaced by a random entity (but not both at the same time)
-        :param triplets: (h, t, r)
-        :return corrupted_triplets:'''
-        # i = uniform(-1, 1) and branching on the sign of i would work as well
-        coin = choice([True, False])
-        # the (h, t, r) triplet comes from the train file, so corrupting it only requires
-        # picking an entity different from the one being replaced
-        if coin:  # heads: corrupt the head entity, element 0
-            while True:
-                searching_entity = sample(self.entity_vector_dict.keys(), 1)[0]  # sample returns a list, take its first element
-                if searching_entity != triplets[0]:
-                    break
-            corrupted_triplets = (searching_entity, triplets[1], triplets[2])
-        else:  # tails: corrupt the tail entity, element 1
-            while True:
-                searching_entity = sample(self.entity_vector_dict.keys(), 1)[0]
-                if searching_entity != triplets[1]:
-                    break
-            corrupted_triplets = (triplets[0], searching_entity, triplets[2])
-        return corrupted_triplets
-
-    def update(self, Tbatch):
-        entity_vector_copy = deepcopy(self.entity_vector_dict)
-        rels_vector_copy = deepcopy(self.rels_vector_dict)
-
-        for triplets_with_corrupted_triplets in Tbatch:
-            head_entity_vector = entity_vector_copy[triplets_with_corrupted_triplets[0][0]]
-            tail_entity_vector = entity_vector_copy[triplets_with_corrupted_triplets[0][1]]
-            relation_vector = rels_vector_copy[triplets_with_corrupted_triplets[0][2]]
-
-            head_entity_vector_with_corrupted_triplets = entity_vector_copy[triplets_with_corrupted_triplets[1][0]]
-            tail_entity_vector_with_corrupted_triplets = entity_vector_copy[triplets_with_corrupted_triplets[1][1]]
-
-            head_entity_vector_before_batch = self.entity_vector_dict[triplets_with_corrupted_triplets[0][0]]
-            tail_entity_vector_before_batch = self.entity_vector_dict[triplets_with_corrupted_triplets[0][1]]
-            relation_vector_before_batch = self.rels_vector_dict[triplets_with_corrupted_triplets[0][2]]
-
-            head_entity_vector_with_corrupted_triplets_before_batch = self.entity_vector_dict[
-                triplets_with_corrupted_triplets[1][0]]
-            tail_entity_vector_with_corrupted_triplets_before_batch = self.entity_vector_dict[
-                triplets_with_corrupted_triplets[1][1]]
-
-            if self.normal_form == "L1":
-                dist_triplets = dist_L1(head_entity_vector_before_batch, tail_entity_vector_before_batch,
-                                        relation_vector_before_batch)
-                dist_corrupted_triplets = dist_L1(head_entity_vector_with_corrupted_triplets_before_batch,
-                                                  tail_entity_vector_with_corrupted_triplets_before_batch,
-                                                  relation_vector_before_batch)
-            else:
-                dist_triplets = dist_L2(head_entity_vector_before_batch, tail_entity_vector_before_batch,
-                                        relation_vector_before_batch)
-                dist_corrupted_triplets = dist_L2(head_entity_vector_with_corrupted_triplets_before_batch,
-                                                  tail_entity_vector_with_corrupted_triplets_before_batch,
-                                                  relation_vector_before_batch)
-            eg = self.margin + dist_triplets - dist_corrupted_triplets
-            if eg > 0:  # [x]+ keeps the value when positive and clips to 0 otherwise: the margin-based ranking criterion
-                self.loss += eg
-                temp_positive = 2 * self.learning_rate * (
-                        tail_entity_vector_before_batch - head_entity_vector_before_batch - relation_vector_before_batch)
-                temp_negative = 2 * self.learning_rate * (
-                        tail_entity_vector_with_corrupted_triplets_before_batch -
-                        head_entity_vector_with_corrupted_triplets_before_batch -
-                        relation_vector_before_batch)
-                if self.normal_form == "L1":
-                    temp_positive_L1 = [1 if temp_positive[i] >= 0 else -1 for i in range(self.dim)]
-                    temp_negative_L1 = [1 if temp_negative[i] >= 0 else -1 for i in range(self.dim)]
-                    temp_positive_L1 = [float(f) for f in temp_positive_L1]
-                    temp_negative_L1 = [float(f) for f in temp_negative_L1]
-                    temp_positive = np.array(temp_positive_L1) * self.learning_rate
-                    temp_negative = np.array(temp_negative_L1) * self.learning_rate
-                    # temp_positive = norm(temp_positive_L1) * self.learning_rate
-                    # temp_negative = norm(temp_negative_L1) * self.learning_rate
-
-                # gradient-descent step on all five vectors that appear in the loss for this sample
-                head_entity_vector += temp_positive
-                tail_entity_vector -= temp_positive
-                relation_vector = relation_vector + temp_positive - temp_negative
-                head_entity_vector_with_corrupted_triplets -= temp_negative
-                tail_entity_vector_with_corrupted_triplets += temp_negative
-
-                # re-normalize only the vectors that were just updated, to save time
-                entity_vector_copy[triplets_with_corrupted_triplets[0][0]] = norm(head_entity_vector)
-                entity_vector_copy[triplets_with_corrupted_triplets[0][1]] = norm(tail_entity_vector)
-                rels_vector_copy[triplets_with_corrupted_triplets[0][2]] = norm(relation_vector)
-                entity_vector_copy[triplets_with_corrupted_triplets[1][0]] = norm(
-                    head_entity_vector_with_corrupted_triplets)
-                entity_vector_copy[triplets_with_corrupted_triplets[1][1]] = norm(
-                    tail_entity_vector_with_corrupted_triplets)
-
-        # self.entity_vector_dict = deepcopy(entity_vector_copy)
-        # self.rels_vector_dict = deepcopy(rels_vector_copy)
-        self.entity_vector_dict = entity_vector_copy
-        self.rels_vector_dict = rels_vector_copy
-
-    def write_vector(self, file_path, option):
-        if option.strip().startswith("entit"):
-            print("Write entity vectors into file: {}".format(file_path))
-            # dyct = deepcopy(self.entity_vector_dict)
-            dyct = self.entity_vector_dict
-        if option.strip().startswith("rel"):
-            print("Write relation vectors into file: {}".format(file_path))
-            # dyct = deepcopy(self.rels_vector_dict)
-            dyct = self.rels_vector_dict
-        with open(file_path, 'w') as file:  # rewrite the file on every call; "with" closes it automatically
-            for dyct_key in dyct.keys():
-                file.write(dyct_key + "\t")
-                file.write(str(dyct[dyct_key].tolist()))
-                file.write("\n")
-
-    def write_loss(self, file_path, num_of_col):
-        with open(file_path, 'w') as file:
-            lyst = deepcopy(self.loss_list)
-            for i in range(len(lyst)):
-                if num_of_col == 1:
-                    # keep 4 decimal places
-                    file.write(str(int(lyst[i] * 10000) / 10000) + "\n")
-                    # file.write(str(lyst[i]).split('.')[0] + '.' + str(lyst[i]).split('.')[1][:4] + "\n")
-                else:
-                    # file.write(str(lyst[i]).split('.')[0] + '.'
-                    #            + str(lyst[i]).split('.')[1][:4] + "\t")
-                    file.write(str(int(lyst[i] * 10000) / 10000) + " ")
-                    if (i + 1) % num_of_col == 0 and i != 0:
-                        file.write("\n")
-
-
-class MyManager(BaseManager):
-    pass
-
-
-def Manager2():
-    m = MyManager()
-    m.start()
-    return m
-
-
-MyManager.register('TransE', TransE)
-
-
-def func1(em, lock, num):
-    with lock:
-        em.transE(num=num)
-
-
-if __name__ == "__main__":
-    entity_file_path = "data/FB15k/entity2id.txt"
-    num_of_entity, entity_list = get_details_of_entityOrRels_list(entity_file_path)
-    rels_file_path = "data/FB15k/relation2id.txt"
-    num_of_rels, rels_list = get_details_of_entityOrRels_list(rels_file_path)
-    train_file_path = "data/FB15k/train.txt"
-    num_of_triplets, triplets_list = get_details_of_triplets_list(train_file_path)
-
-    manager = Manager2()
-
-    transE = manager.TransE(entity_list, rels_list, triplets_list, margin=1, dim=50)
-    print("\nTransE is initializing...")
-    transE.initialize()
-    print("\n********** Start TransE training **********")
-
-    for i in range(20000):  # number of epochs
-        start_time = time.time()
-        lock = Lock()
-        proces = [Process(target=func1, args=(transE, lock, 1500)) for j in range(10)]
-        for p in proces:
-            p.start()
-        for p in proces:
-            p.join()
-        print("The loss is %.4f" % transE.get_loss())
-        transE.clear_loss()
-        end_time = time.time()
-        print("The %d epoch (10 batches) takes %.2f ms.\n" % (i, (end_time - start_time) * 1000))
-        # transE.transE(100000)
-    print("********** End TransE training ***********\n")
-    # write the vectors once after training (alternatively, write every 100 epochs)
-    transE.write_vector("data/entityVector.txt", "entity")
-    transE.write_vector("data/relationVector.txt", "relationship")
-    transE.write_loss("data/lossList_25cols.txt", 25)
-    transE.write_loss("data/lossList_1cols.txt", 1)
\ No newline at end of file
diff --git a/train_fb15k.sh b/train_fb15k.sh
new file mode 100644
index 0000000..b9b56a0
--- /dev/null
+++ b/train_fb15k.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+if [ $# != 1 ]; then
+    echo "USAGE: $0 model: 0-simple 1-manager 2-queue"
+    echo " e.g.: $0 0"
+    exit 1
+fi
+if [ "$1" == "0" ]; then
+    mode="Simple"
+elif [ "$1" == "1" ]; then
+    mode="Manager"
+elif [ "$1" == "2" ]; then
+    mode="MpQueue"
+else
+    echo "Illegal argument!"
+    exit 1
+fi
+CUDA_VISIBLE_DEVICES=0 \
+python TrainMain.py --embedding_dim 100 --margin_value 1 --normal_form "L1" --batch_size 10000 --learning_rate 0.003 --n_generator 24 --max_epoch 5000 --multi_process $mode
\ No newline at end of file
diff --git a/train_wn18.sh b/train_wn18.sh
new file mode 100644
index 0000000..b1c33d2
--- /dev/null
+++ b/train_wn18.sh
@@ -0,0 +1,12 @@
+#!/usr/bin/env bash
+CUDA_VISIBLE_DEVICES=0 \
+python TrainMain.py \
+--data_dir ../data/WN18/ \
+--embedding_dim 50 \
+--margin_value 4 \
+--batch_size 3000 \
+--learning_rate 0.01 \
+--n_generator 24 \
+--n_rank_calculator 24 \
+--eval_freq 100 \
+--max_epoch 5000
\ No newline at end of file
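
For readers tracing the update rule in TrainTransE.py and the removed TrainTransE_MultiProcess.py: the following is a minimal, self-contained sketch of one margin-based ranking step with the L1 sign-gradient trick used in update(). It is not part of the repository; the names l1_score and sgd_step and the toy data are illustrative assumptions.

import numpy as np


def l1_score(h, t, r):
    # d(h + r, t) under the L1 norm, as in dist_L1
    return np.fabs(h + r - t).sum()


def sgd_step(h, t, r, h_c, t_c, margin=1.0, lr=0.003):
    """One SGD step on a (positive, corrupted) triplet pair; returns updated vectors and the hinge loss."""
    loss = margin + l1_score(h, t, r) - l1_score(h_c, t_c, r)
    if loss > 0:
        # sign of the L1 gradient, scaled by the learning rate (the temp_*_L1 trick maps 0 to +1)
        g_pos = np.where(t - h - r >= 0, 1.0, -1.0) * lr
        g_neg = np.where(t_c - h_c - r >= 0, 1.0, -1.0) * lr
        h, t = h + g_pos, t - g_pos
        r = r + g_pos - g_neg
        h_c, t_c = h_c - g_neg, t_c + g_neg
        # re-normalize only the vectors that changed, as update() does
        h, t, r, h_c, t_c = (v / np.linalg.norm(v) for v in (h, t, r, h_c, t_c))
    return h, t, r, h_c, t_c, max(loss, 0.0)


if __name__ == "__main__":
    # toy usage: five random unit vectors in 50 dimensions
    rng = np.random.default_rng(0)
    vecs = [v / np.linalg.norm(v) for v in rng.uniform(-1.0, 1.0, size=(5, 50))]
    print("hinge loss:", sgd_step(*vecs)[-1])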