From f40fdcf9a1aa02b59f8866d9157498cbf23e9ca6 Mon Sep 17 00:00:00 2001
From: haidfs <2330530604@qq.com>
Date: Fri, 20 Mar 2020 18:58:54 +0800
Subject: [PATCH] add generate embedding res func

---
 TestDatasetTF.py        |  5 ++--
 TestMainTF.py           |  1 +
 TestModelTF.py          |  7 +++--
 TestTransEMpQueue.py    | 29 ++++++++++---------
 TrainMain.py            |  7 +++--
 TrainTransEMpManager.py | 28 +++++++++++-------
 TrainTransEMpQueue.py   | 20 +++++++------
 TrainTransESimple.py    | 64 +++++++++++++++++++++--------------------
 8 files changed, 89 insertions(+), 72 deletions(-)

diff --git a/TestDatasetTF.py b/TestDatasetTF.py
index b464818..e51ba5b 100644
--- a/TestDatasetTF.py
+++ b/TestDatasetTF.py
@@ -1,11 +1,12 @@
+# -*- coding: UTF-8 -*-
 import os
 import pandas as pd
 
 
 class KnowledgeGraph:
     def __init__(self, data_dir):
-        # Considering how the TF APIs are used: Python cannot convert a Tensor directly into a string, but TF types can be converted into numpy types,
-        # so the training triples, test triples and so on here are all id triples rather than string triples
+        # Considering how the TF APIs are used: Python cannot convert a Tensor directly into a string, but TF types can be converted into numpy types,
+        # so the training triples, test triples and so on here are all id triples rather than string triples
         self.data_dir = data_dir
         self.entity_dict = {}
         self.entities = []
diff --git a/TestMainTF.py b/TestMainTF.py
index 0b3a199..0d8ce32 100644
--- a/TestMainTF.py
+++ b/TestMainTF.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import logging
 import tensorflow as tf
 
diff --git a/TestModelTF.py b/TestModelTF.py
index 185c68c..db9fc02 100644
--- a/TestModelTF.py
+++ b/TestModelTF.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import timeit
 import numpy as np
 import tensorflow as tf
@@ -45,9 +46,9 @@ def evaluate(self, eval_triple):
             relation = tf.nn.embedding_lookup(
                 self.relation_embedding, eval_triple[2])
         with tf.name_scope('link'):
-            # The purpose here is not entirely clear: h, r and t should each be [1, dim] vectors while self.entity_embedding should be [n, dim]; what type does the addition/subtraction produce?
-            # Plain lists of different sizes cannot be added or subtracted directly, but for an np.array or a TF embedding the subtraction broadcasts, which is equivalent to every row of self.entity_embedding
-            # being combined with h, r and t
+            # The purpose here is not entirely clear: h, r and t should each be [1, dim] vectors while self.entity_embedding should be [n, dim]; what type does the addition/subtraction produce?
+            # Plain lists of different sizes cannot be added or subtracted directly, but for an np.array or a TF embedding the subtraction broadcasts, which is equivalent to every row of self.entity_embedding
+            # being combined with h, r and t
             distance_head_prediction = self.entity_embedding + relation - tail
             distance_tail_prediction = head + relation - self.entity_embedding
         with tf.name_scope('rank'):
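
Note on the broadcasting question raised in the TestModelTF.py comment above: the same shape behaviour can be checked outside TF with a minimal NumPy sketch (shapes and variable names below are illustrative, not taken from the repo).

import numpy as np

n, dim = 5, 4
entity_embedding = np.random.rand(n, dim)   # all candidate entities, shape [n, dim]
head = np.random.rand(dim)                  # h
relation = np.random.rand(dim)              # r
tail = np.random.rand(dim)                  # t

# Adding/subtracting a [dim] vector and an [n, dim] matrix broadcasts row-wise,
# so every candidate entity is scored in one vectorised expression.
distance_head_prediction = entity_embedding + relation - tail   # shape [n, dim]
distance_tail_prediction = head + relation - entity_embedding   # shape [n, dim]

# One L1 energy per candidate, then an ascending ranking as in the 'rank' scope.
scores = np.abs(distance_tail_prediction).sum(axis=1)           # shape [n]
ranking = np.argsort(scores)
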
diff --git a/TestTransEMpQueue.py b/TestTransEMpQueue.py
index 52ac077..b784271 100644
--- a/TestTransEMpQueue.py
+++ b/TestTransEMpQueue.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 from numpy import *
 import operator
 import logging
@@ -10,17 +11,17 @@ class Test:
-    '''Basic evaluation procedure
-Assume the whole knowledge base contains n entities; the evaluation then goes as follows:
-For each test triple a, replace its head entity (or its tail entity) in turn with every other entity in the knowledge base, which produces n triples.
-Compute the energy (the dist value) of each of these n triples; in TransE this means computing h+r-t. This yields n energy values, one for each of the n triples.
-Sort these n energy values in ascending order.
-Record the position of the original triple a after the energy values are sorted.
-Repeat the above process for every test triple in the test set.
-Average the sorted positions of all correct triples; the resulting value is called Mean Rank.
-Compute the proportion of correct triples whose sorted position is below 10; the resulting value is called Hits@10.
-That is the whole evaluation procedure, with two metrics: Mean Rank and Hits@10. A smaller Mean Rank is better and a larger Hits@10 is better. This code does not compute Hits@10, and Python is very slow for this amount of computation.
-Readers are advised to move on to the Fast_TransX code from the Tsinghua repository, which is written in C++, performs well and produces training and test results quickly.
+    '''Basic evaluation procedure
+Assume the whole knowledge base contains n entities; the evaluation then goes as follows:
+For each test triple a, replace its head entity (or its tail entity) in turn with every other entity in the knowledge base, which produces n triples.
+Compute the energy (the dist value) of each of these n triples; in TransE this means computing h+r-t. This yields n energy values, one for each of the n triples.
+Sort these n energy values in ascending order.
+Record the position of the original triple a after the energy values are sorted.
+Repeat the above process for every test triple in the test set.
+Average the sorted positions of all correct triples; the resulting value is called Mean Rank.
+Compute the proportion of correct triples whose sorted position is below 10; the resulting value is called Hits@10.
+That is the whole evaluation procedure, with two metrics: Mean Rank and Hits@10. A smaller Mean Rank is better and a larger Hits@10 is better. This code does not compute Hits@10, and Python is very slow for this amount of computation.
+Readers are advised to move on to the Fast_TransX code from the Tsinghua repository, which is written in C++, performs well and produces training and test results quickly.
 '''
 
     def __init__(self, entity_dyct, relation_dyct, train_triple_list,
@@ -57,7 +58,7 @@ def get_rank_part(self, triplet):
                 continue
             rank_dyct[ent] = distance(self.entity_dyct[ent],
                                       self.entity_dyct[triplet[1]],
                                       self.relation_dyct[triplet[2]])
-        else:  # Replace the head entity or the tail entity according to the label and compute the distance
+        else:  # Replace the head entity or the tail entity according to the label and compute the distance
             corrupted_triplet = (triplet[0], ent, triplet[2])
             if self.is_fit and (
                     corrupted_triplet in self.train_triple_list):
@@ -65,7 +66,7 @@ def get_rank_part(self, triplet):
                 rank_dyct[ent] = distance(self.entity_dyct[triplet[0]],
                                           self.entity_dyct[ent],
                                           self.relation_dyct[triplet[2]])
         sorted_rank = sorted(rank_dyct.items(),
-                             key=operator.itemgetter(1))  # sort the items in ascending order by their value (index 1)
+                             key=operator.itemgetter(1))  # sort the items in ascending order by their value (index 1)
         if self.label == 'head':
             num_tri = 0
         else:
@@ -167,7 +168,7 @@ def get_dict_from_vector_file(file_path):
         dyct = {}
         for line in file.readlines():
             name_vector = line.strip().split("\t")
-            # vector uses [1:-1] because vector is a str such as '[0.11,0.22,..]'; [1:-1] strips the square brackets of the list
+            # vector uses [1:-1] because vector is a str such as '[0.11,0.22,..]'; [1:-1] strips the square brackets of the list
             vector = [float(s) for s in name_vector[1][1:-1].split(", ")]
             name = name_vector[0]
             dyct[name] = vector
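
The Mean Rank and Hits@10 definitions in the docstring above reduce to a few lines once the rank of each correct test triple has been collected. The helper below is an illustrative sketch (it is not part of this patch, and it assumes the usual rank <= 10 convention for Hits@10).

def mean_rank_and_hits_at_10(ranks):
    """ranks: 1-based positions of the correct triples after sorting by energy."""
    mean_rank = sum(ranks) / len(ranks)
    hits_at_10 = sum(1 for r in ranks if r <= 10) / len(ranks)
    return mean_rank, hits_at_10

# Example: three test triples ranked 1st, 4th and 120th
print(mean_rank_and_hits_at_10([1, 4, 120]))   # -> (41.666..., 0.666...)
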
diff --git a/TrainMain.py b/TrainMain.py
index a7f3a4a..ead1212 100644
--- a/TrainMain.py
+++ b/TrainMain.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import timeit
 from TrainTransESimple import prepare_fb15k_train_data
 from TrainTransESimple import TransE
@@ -72,9 +73,9 @@ def main():
         normal_form=args.normal_form)
     logging.info("TransE is initializing...")
     start = timeit.default_timer()
-    for i in range(args.max_epoch):  # number of epochs
+    for i in range(args.max_epoch):  # number of epochs
        lock = Lock()
-        proces = [Process(target=func1, args=(transE, lock)) for j in range(10)]  # 10 worker processes; run with care, the machine will slow down noticeably
+        proces = [Process(target=func1, args=(transE, lock)) for j in range(10)]  # 10 worker processes; run with care, the machine will slow down noticeably
        for p in proces:
            p.start()
        for p in proces:
            p.join()
@@ -86,7 +87,7 @@ def main():
         start = end
         transE.clear_loss()
     logging.info("********** End TransE training ***********\n")
-    # The number of training batches is not necessarily a multiple of 100, so write the vectors from the last update to file
+    # The number of training batches is not necessarily a multiple of 100, so write the vectors from the last update to file
     transE.write_vector("data/entityVector.txt", "entity")
     transE.write_vector("data/relationVector.txt", "relationship")
diff --git a/TrainTransEMpManager.py b/TrainTransEMpManager.py
index ecad8f2..f3c88e6 100644
--- a/TrainTransEMpManager.py
+++ b/TrainTransEMpManager.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 from multiprocessing import Process, Lock
 from multiprocessing.managers import BaseManager
 import logging
@@ -13,22 +14,24 @@
 
 class TransE(TransESimple):
 
     def get_loss(self):
-        # See the C++ code of Tsinghua's Fast-TransX; it really is fast: training that takes Python close to 10 hours finishes in C++ in roughly ten-odd minutes. After a rough look at that code,
-        # it modifies the Sbatch of the original paper and directly runs stochastic gradient descent (the total count is the number of training triples, one epoch is split into several batches, and every triple in each batch is sampled at random). It is multithreaded, with n batches mapped to n threads.
-        # Because of historical baggage Python uses the GIL (global interpreter lock), which makes Python threads nearly useless for saturating a multi-core CPU, so multiprocessing is used for optimisation instead.
-        # To use multiprocessing, a Manager is used to wrap TransE as a Proxy object. Since a Proxy object cannot access the attributes of the wrapped TransE class, a get function is needed to pass the loss out.
-        # Also note that Python multiprocessing does not necessarily beat a plain for loop: the base overhead includes process creation and teardown plus context switching (processes need RPC-style remote communication to share class variables).
-        # At least comparing trainTransE with trainTransE_MultiProcess, the for loop in trainTransE takes 8-9 s for a batch of 10, while one epoch (one batch) of trainTransE_MultiProcess takes 12-13 s.
-        # Further optimisations: a process pool for process reuse? A framework such as TF??
+        # See the C++ code of Tsinghua's Fast-TransX; it really is fast: training that takes Python close to 10 hours finishes in C++ in roughly ten-odd minutes. After a rough look at that code,
+        # it modifies the Sbatch of the original paper and directly runs stochastic gradient descent (the total count is the number of training triples, one epoch is split into several batches, and every triple in each batch is sampled at random). It is multithreaded, with n batches mapped to n threads.
+        # Because of historical baggage Python uses the GIL (global interpreter lock), which makes Python threads nearly useless for saturating a multi-core CPU, so multiprocessing is used for optimisation instead.
+        # To use multiprocessing, a Manager is used to wrap TransE as a Proxy object. Since a Proxy object cannot access the attributes of the wrapped TransE class, a get function is needed to pass the loss out.
+        # Also note that Python multiprocessing does not necessarily beat a plain for loop: the base overhead includes process creation and teardown plus context switching (processes need RPC-style remote communication to share class variables).
+        # At least comparing trainTransE with trainTransE_MultiProcess, the for loop in trainTransE takes 8-9 s for a batch of 10, while one epoch (one batch) of trainTransE_MultiProcess takes 12-13 s.
+        # Further optimisations: a process pool for process reuse? A framework such as TF??
         return self.loss
 
     def clear_loss(self):
-        # This function likewise exists so that code outside the Proxy object can reset the loss to 0
+        # This function likewise exists so that code outside the Proxy object can reset the loss to 0
         self.loss = 0
 
     def transE(self):
+        # The difference from the parent class's transE is that here the TransE instance is shared directly between processes, so the current
+        # training epoch is not known inside the method; the original file-writing call was therefore removed from this place
         Sbatch = self.sample(self.batch_size // 10)
-        Tbatch = []  # list of pairs (original triple, corrupted triple): {((h,r,t),(h',r,t'))}
+        Tbatch = []  # list of pairs (original triple, corrupted triple): {((h,r,t),(h',r,t'))}
         for sbatch in Sbatch:
             pos_neg_triplets = (sbatch, self.get_corrupted_triplets(sbatch))
             if pos_neg_triplets not in Tbatch:
@@ -66,10 +69,10 @@ def main():
         margin=1,
         dim=50)
     logging.info("TransE is initializing...")
-    for i in range(20000):  # number of epochs
+    for i in range(2000):  # number of epochs
         lock = Lock()
         proces = [Process(target=func1, args=(transE, lock))
-                  for j in range(10)]  # 10 worker processes; run with care, the machine will slow down noticeably
+                  for j in range(10)]  # 10 worker processes; run with care, the machine will slow down noticeably
         for p in proces:
             p.start()
         for p in proces:
             p.join()
         logging.info(
             "After %d training epoch(s), loss on batch data is %g" %
             (i * 10, transE.get_loss()))
+        if i % 100 == 0:
+            transE.write_vector("data/entityVectorMpManager.txt", "entity")
+            transE.write_vector("data/relationVectorMpManager.txt", "rels")
         transE.clear_loss()
         # transE.transE(100000)
     logging.info("********** End TransE training ***********\n")
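
The Manager/Proxy point made in the get_loss() comments can be reproduced with a self-contained toy sketch (class names below are stand-ins, not this repo's API): a BaseManager proxy forwards method calls to the owning process over IPC but does not expose plain attributes, which is why the wrapped TransE class needs get_loss()/clear_loss() accessors.

from multiprocessing.managers import BaseManager


class Trainer:                        # stand-in for the TransE class above
    def __init__(self):
        self.loss = 0.0

    def step(self):
        self.loss += 1.0

    def get_loss(self):               # accessor needed: proxy.loss would fail
        return self.loss


class MyManager(BaseManager):
    pass


MyManager.register('Trainer', Trainer)

if __name__ == '__main__':
    manager = MyManager()
    manager.start()
    trainer = manager.Trainer()       # returns a proxy living in the manager process
    trainer.step()                    # method calls are forwarded over IPC
    print(trainer.get_loss())         # 1.0
    # print(trainer.loss)             # AttributeError: attributes are not proxied
    manager.shutdown()
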
diff --git a/TrainTransEMpQueue.py b/TrainTransEMpQueue.py
index d4ce1c8..a1a2bcd 100644
--- a/TrainTransEMpQueue.py
+++ b/TrainTransEMpQueue.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import numpy as np
 from multiprocessing import Process, Queue
 import logging
@@ -72,16 +73,16 @@ def update_part(self, pos_triplet, neg_triplet):
         entity_vector_copy = self.entity_vector_dict
         rels_vector_copy = self.rels_vector_dict
 
-        # Here h, t, r are the head entity, tail entity and relation vectors; h2 and t2 are h' and t' from the paper, i.e. the head and tail entity vectors of the negative triple
-        # Tbatch is a list of pairs (original triple, corrupted triple)
-        # :[((h,r,t),(h',r,t'))...]; because of the data file format the order here is (h,t,r)
+        # Here h, t, r are the head entity, tail entity and relation vectors; h2 and t2 are h' and t' from the paper, i.e. the head and tail entity vectors of the negative triple
+        # Tbatch is a list of pairs (original triple, corrupted triple)
+        # :[((h,r,t),(h',r,t'))...]; because of the data file format the order here is (h,t,r)
         h = entity_vector_copy[pos_triplet[0]]
         t = entity_vector_copy[pos_triplet[1]]
         r = rels_vector_copy[pos_triplet[2]]
-        # Head and tail entity vectors of the corrupted triple
+        # Head and tail entity vectors of the corrupted triple
         h2 = entity_vector_copy[neg_triplet[0]]
         t2 = entity_vector_copy[neg_triplet[1]]
-        # A beforebatch was originally defined here, but in my view it is unnecessary: we are already inside the batch and triples are processed one by one
+        # A beforebatch was originally defined here, but in my view it is unnecessary: we are already inside the batch and triples are processed one by one
         if self.normal_form == "L1":
             dist_triplets = dist_L1(h, t, r)
             dist_corrupted_triplets = dist_L1(h2, t2, r)
@@ -89,7 +90,7 @@
             dist_triplets = dist_L2(h, t, r)
             dist_corrupted_triplets = dist_L2(h2, t2, r)
         eg = self.margin + dist_triplets - dist_corrupted_triplets
-        if eg > 0:  # keep the value if it is positive, clamp to 0 otherwise, i.e. the hinge loss (margin-based ranking criterion)
+        if eg > 0:  # keep the value if it is positive, clamp to 0 otherwise, i.e. the hinge loss (margin-based ranking criterion)
             self.loss += eg
             temp_positive = 2 * self.learning_rate * (t - h - r)
             temp_negative = 2 * self.learning_rate * (t2 - h2 - r)
@@ -99,14 +100,14 @@
                 temp_positive = np.array(temp_positive_L1) * self.learning_rate
                 temp_negative = np.array(temp_negative_L1) * self.learning_rate
 
-            # Gradient descent on the 5 parameters of the loss; the stochastic part comes from the sample function
+            # Gradient descent on the 5 parameters of the loss; the stochastic part comes from the sample function
             h += temp_positive
             t -= temp_positive
             r = r + temp_positive - temp_negative
             h2 -= temp_negative
             t2 += temp_negative
 
-            # Normalize the vectors that were just updated, to reduce computation time
+            # Normalize the vectors that were just updated, to reduce computation time
             entity_vector_copy[pos_triplet[0]] = norm(h)
             entity_vector_copy[pos_triplet[1]] = norm(t)
             rels_vector_copy[pos_triplet[2]] = norm(r)
@@ -131,6 +132,9 @@ def main():
     for epoch in range(2000):
         print("Mp Queue TransE, After %d training epoch(s):\n" % epoch)
         transE.launch_training()
+        if epoch % 100 == 0:
+            transE.write_vector("data/entityVectorMpQueue.txt", "entity")
+            transE.write_vector("data/relationVectorMpQueue.txt", "rels")
     logging.info("********** End TransE training ***********\n")
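
For reference, the margin-based ranking (hinge) update that the comments in update_part() describe can be condensed into a short NumPy sketch. This is an illustrative restatement of the L2 branch using the same 2 * lr * (t - h - r) step, not a drop-in replacement for the code in this patch.

import numpy as np

def hinge_update(h, t, r, h2, t2, margin=1.0, lr=0.01):
    d_pos = ((h + r - t) ** 2).sum()          # dist_L2 of the positive triple
    d_neg = ((h2 + r - t2) ** 2).sum()        # dist_L2 of the corrupted triple
    eg = margin + d_pos - d_neg
    if eg <= 0:                               # hinge: only violating pairs are updated
        return h, t, r, h2, t2, 0.0
    temp_positive = 2 * lr * (t - h - r)      # same step as temp_positive above
    temp_negative = 2 * lr * (t2 - h2 - r)
    h, t = h + temp_positive, t - temp_positive
    r = r + temp_positive - temp_negative
    h2, t2 = h2 - temp_negative, t2 + temp_negative

    def unit(v):                              # re-normalise, as norm() does
        return v / np.linalg.norm(v)

    return unit(h), unit(t), unit(r), unit(h2), unit(t2), eg
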
diff --git a/TrainTransESimple.py b/TrainTransESimple.py
index 83f4fc3..47a98c3 100644
--- a/TrainTransESimple.py
+++ b/TrainTransESimple.py
@@ -1,3 +1,4 @@
+# -*- coding: UTF-8 -*-
 import timeit
 from random import uniform, sample, choice
 import numpy as np
@@ -14,8 +15,8 @@ def get_details_of_entityOrRels_list(file_path, split_delimeter="\t"):
     num_of_file = 0
     lyst = []
     with open(file_path) as file:
-        # readlines really is used directly here; the low-memory mode belongs to the read_csv API: csv_data = pd.read_csv(csv_file, low_memory=False)
-        # The mode argument can be omitted when reading, because it defaults to 'r'
+        # readlines really is used directly here; the low-memory mode belongs to the read_csv API: csv_data = pd.read_csv(csv_file, low_memory=False)
+        # The mode argument can be omitted when reading, because it defaults to 'r'
         lines = file.readlines()
         for line in lines:
             details_and_id = line.strip().split(split_delimeter)
@@ -39,25 +40,25 @@ def get_details_of_triplets_list(file_path, split_delimeter="\t"):
 
 
 def norm(lyst):
-    # Normalize to a unit vector
+    # Normalize to a unit vector
     var = np.linalg.norm(lyst)
     i = 0
     while i < len(lyst):
         lyst[i] = lyst[i] / var
         i += 1
-    # Must return an array, because lists do not support subtraction
+    # Must return an array, because lists do not support subtraction
     return np.array(lyst)
 
 
 def dist_L1(h, t, l):
     s = h + l - t
-    # Manhattan / taxicab distance: |x-xi|+|y-yi|, i.e. sum the absolute values over each dimension of the vector
+    # Manhattan / taxicab distance: |x-xi|+|y-yi|, i.e. sum the absolute values over each dimension of the vector
     return np.fabs(s).sum()
 
 
 def dist_L2(h, t, l):
     s = h + l - t
-    # Euclidean distance: the sum of squares, without taking the square root. Be careful: writing the normalization or distance formula incorrectly will make convergence fail
+    # Euclidean distance: the sum of squares, without taking the square root. Be careful: writing the normalization or distance formula incorrectly will make convergence fail
     return (s * s).sum()
 
 
@@ -88,12 +89,12 @@ def __init__(
 
     def initialize(self):
         '''
-        A slight modification of the initialization in the paper
-        Initialize l and e: string identifiers such as /m/06rf7 in the original l and e files are mapped to dim-dimensional vectors; each dim-dimensional vector is filled uniformly and then normalized
+        A slight modification of the initialization in the paper
+        Initialize l and e: string identifiers such as /m/06rf7 in the original l and e files are mapped to dim-dimensional vectors; each dim-dimensional vector is filled uniformly and then normalized
         :return:
         '''
         entity_vector_dict, rels_vector_dict = {}, {}
-        # 'component' means one component of the vector; once the vector dimension is reached the vector is normalized, which completes the initialization part of the pseudocode.
+        # 'component' means one component of the vector; once the vector dimension is reached the vector is normalized, which completes the initialization part of the pseudocode.
         entity_vector_compo_list, rels_vector_compo_list = [], []
         for item, dyct, compo_list, name in zip([self.entity_list, self.rels_list],
                                                 [entity_vector_dict, rels_vector_dict],
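
The initialisation described in the initialize() docstring above (map each string identifier such as /m/06rf7 to a dim-dimensional vector, fill it uniformly, then normalise) can be sketched as follows. The (-6/sqrt(dim), 6/sqrt(dim)) range is the one suggested in the TransE paper and is assumed here rather than read from this file; the second identifier is a placeholder.

from random import uniform
import numpy as np

def init_vectors(names, dim=50):
    bound = 6 / np.sqrt(dim)
    vectors = {}
    for name in names:                              # e.g. '/m/06rf7'
        v = np.array([uniform(-bound, bound) for _ in range(dim)])
        vectors[name] = v / np.linalg.norm(v)       # unit-normalise, like norm()
    return vectors

entity_vector_dict = init_vectors(['/m/06rf7', '/m/placeholder'], dim=50)
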
@@ -116,25 +117,26 @@ def transE(self, cycle_index=20):
         for i in range(cycle_index):
             start = timeit.default_timer()
             Sbatch = self.sample(self.batch_size)
-            Tbatch = []  # list of pairs (original triple, corrupted triple): {((h,r,t),(h',r,t'))}
+            Tbatch = []  # list of pairs (original triple, corrupted triple): {((h,r,t),(h',r,t'))}
             for sbatch in Sbatch:
-                # pos_neg_triplets stands for the pair of positive and negative triples
+                # pos_neg_triplets stands for the pair of positive and negative triples
                 pos_neg_triplets = (
                     sbatch, self.get_corrupted_triplets(sbatch))
                 if pos_neg_triplets not in Tbatch:
                     Tbatch.append(pos_neg_triplets)
             self.update(Tbatch)
             if i % 1 == 0:
-                # The i value can be changed; consider using EMA (exponential moving average) for the learning rate
+                # The i value can be changed; consider using EMA (exponential moving average) for the learning rate
                 # self.learning_rate = INITIAL_LEARNING_RATE * (pow(0.96, i / 100))
                 end = timeit.default_timer()
                 logging.info(
                     "Simple TransE, After %d training epoch(s):\nbatch size is %d, cost time is %g s, loss on batch data is %g" %
                     (i, self.batch_size, end - start, self.loss))
-                # Track how the loss converges at the end
+                # Track how the loss converges at the end
                 self.loss_list.append(self.loss)
-            # self.write_vector("data/entityVector.txt", "entity")
-            # self.write_vector("data/relationVector.txt", "rels")
+            if i % 100 == 0:
+                self.write_vector("data/entityVector.txt", "entity")
+                self.write_vector("data/relationVector.txt", "rels")
             self.loss = 0
 
     def sample(self, size):
@@ -142,18 +144,18 @@ def sample(self, size):
         return sample(self.train_triplets_list, size)
 
     def get_corrupted_triplets(self, triplets):
         '''training triplets with either the head or tail replaced by a random entity (but not both at the same time)
-        :param triplet: a single (h,t,l)
+        :param triplet: a single (h,t,l)
         :return corruptedTriplet:'''
         coin = choice([True, False])
-        # Since this (h,t,l) was drawn from the train file, corrupting it just means randomly finding an entity that differs from the head entity
-        if coin:  # coin flip: if true, corrupt the head entity, i.e. the first item
+        # Since this (h,t,l) was drawn from the train file, corrupting it just means randomly finding an entity that differs from the head entity
+        if coin:  # coin flip: if true, corrupt the head entity, i.e. the first item
             while True:
-                # Take the first element because sample returns a list
+                # Take the first element because sample returns a list
                 searching_entity = sample(self.entity_vector_dict.keys(), 1)[0]
                 if searching_entity != triplets[0]:
                     break
             corrupted_triplets = (searching_entity, triplets[1], triplets[2])
-        else:  # otherwise, corrupt the tail entity, i.e. the second item
+        else:  # otherwise, corrupt the tail entity, i.e. the second item
             while True:
                 searching_entity = sample(self.entity_vector_dict.keys(), 1)[0]
                 if searching_entity != triplets[1]:
@@ -165,17 +167,17 @@ def update(self, Tbatch):
         entity_vector_copy = self.entity_vector_dict
         rels_vector_copy = self.rels_vector_dict
 
-        # Here h, t, r are the head entity, tail entity and relation vectors; h2 and t2 are h' and t' from the paper, i.e. the head and tail entity vectors of the negative triple
-        # Tbatch is a list of pairs (original triple, corrupted triple)
-        # :[((h,r,t),(h',r,t'))...]; because of the data file format the order here is (h,t,r)
+        # Here h, t, r are the head entity, tail entity and relation vectors; h2 and t2 are h' and t' from the paper, i.e. the head and tail entity vectors of the negative triple
+        # Tbatch is a list of pairs (original triple, corrupted triple)
+        # :[((h,r,t),(h',r,t'))...]; because of the data file format the order here is (h,t,r)
         for pos_neg_triplets in Tbatch:
             h = entity_vector_copy[pos_neg_triplets[0][0]]
             t = entity_vector_copy[pos_neg_triplets[0][1]]
             r = rels_vector_copy[pos_neg_triplets[0][2]]
-            # Head and tail entity vectors of the corrupted triple
+            # Head and tail entity vectors of the corrupted triple
             h2 = entity_vector_copy[pos_neg_triplets[1][0]]
             t2 = entity_vector_copy[pos_neg_triplets[1][1]]
-            # A beforebatch was originally defined here, but in my view it is unnecessary: we are already inside the batch and triples are processed one by one
+            # A beforebatch was originally defined here, but in my view it is unnecessary: we are already inside the batch and triples are processed one by one
             if self.normal_form == "L1":
                 dist_triplets = dist_L1(h, t, r)
                 dist_corrupted_triplets = dist_L1(h2, t2, r)
@@ -183,7 +185,7 @@ def update(self, Tbatch):
                 dist_triplets = dist_L2(h, t, r)
                 dist_corrupted_triplets = dist_L2(h2, t2, r)
             eg = self.margin + dist_triplets - dist_corrupted_triplets
-            if eg > 0:  # keep the value if it is positive, clamp to 0 otherwise, i.e. the hinge loss (margin-based ranking criterion)
+            if eg > 0:  # keep the value if it is positive, clamp to 0 otherwise, i.e. the hinge loss (margin-based ranking criterion)
                 self.loss += eg
                 temp_positive = 2 * self.learning_rate * (t - h - r)
                 temp_negative = 2 * self.learning_rate * (t2 - h2 - r)
@@ -193,14 +195,14 @@ def update(self, Tbatch):
                     temp_positive = np.array(temp_positive_L1) * self.learning_rate
                     temp_negative = np.array(temp_negative_L1) * self.learning_rate
 
-                # Gradient descent on the 5 parameters of the loss; the stochastic part comes from the sample function
+                # Gradient descent on the 5 parameters of the loss; the stochastic part comes from the sample function
                 h += temp_positive
                 t -= temp_positive
                 r = r + temp_positive - temp_negative
                 h2 -= temp_negative
                 t2 += temp_negative
 
-                # Normalize the vectors that were just updated, to reduce computation time
+                # Normalize the vectors that were just updated, to reduce computation time
                 entity_vector_copy[pos_neg_triplets[0][0]] = norm(h)
                 entity_vector_copy[pos_neg_triplets[0][1]] = norm(t)
                 rels_vector_copy[pos_neg_triplets[0][2]] = norm(r)
@@ -221,7 +223,7 @@ def write_vector(self, file_path, option):
                 "Write relationships vector into file: {}".format(file_path))
             # dyct = deepcopy(self.rels_vector_dict)
             dyct = self.rels_vector_dict
-        with open(file_path, 'w') as file:  # Write the file, overwriting it each time; with closes it automatically
+        with open(file_path, 'w') as file:  # Write the file, overwriting it each time; with closes it automatically
            for dyct_key in dyct.keys():
                file.write(dyct_key + "\t")
                file.write(str(dyct[dyct_key].tolist()))
@@ -232,7 +234,7 @@ def write_loss(self, file_path, num_of_col):
         lyst = deepcopy(self.loss_list)
         for i in range(len(lyst)):
             if num_of_col == 1:
-                # Keep 4 decimal places
+                # Keep 4 decimal places
                 file.write(str(int(lyst[i] * 10000) / 10000) + "\n")
             else:
                 file.write(str(int(lyst[i] * 10000) / 10000) + " ")
@@ -255,7 +257,7 @@ def prepare_fb15k_train_data():
 
 
 def main():
-    # Test code corresponding to --multi_process "None" in TrainMain
+    # Test code corresponding to --multi_process "None" in TrainMain
     entity_list, rels_list, train_triplets_list = prepare_fb15k_train_data()
     transE = TransE(