From e2cd859e66016f301ace49776874b2d3ee2e4fe5 Mon Sep 17 00:00:00 2001 From: chenky9106 Date: Sat, 4 Jul 2020 10:52:46 +0800 Subject: [PATCH] Add files via upload --- clustering.py | 49 ++++++++++++++++++++++++++++++ confidence_propagation.py | 64 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+) create mode 100644 clustering.py create mode 100644 confidence_propagation.py diff --git a/clustering.py b/clustering.py new file mode 100644 index 0000000..7cda170 --- /dev/null +++ b/clustering.py @@ -0,0 +1,49 @@ +import config +from word_clustering.kmeans import * +from preprocess import json_loader, json_dumper + +def sort_key(i, centroids, seed_centroids): + mean_vec = centroids[i] + max_sim = -1 + for c in seed_centroids: + sim = cos_sim(c, mean_vec) + if sim > max_sim: + max_sim = sim + return max_sim + +def sort_concept(concept, seed_centroids): + global word_dict + mean_vec = concept_rep(concept, word_dict)[0] + max_sim = -1 + for c in seed_centroids: + sim = cos_sim(c, mean_vec) + if sim > max_sim: + max_sim = sim + return max_sim + +def clustering_main(): + json_list = json_loader(config.cluster_concept_path) + concept_dict = {js['name']: js for js in json_list} + concept_list = [js['name'] for js in json_list] + + centroids, cluster_concepts = K_means(concept_list, config.num_clusters, word_dict) + with open(config.input_seed) as f: + seeds = [word.strip() for word in f.readlines()] + seed_centroids, seed_cluster_concepts = K_means(seeds, config.num_seed_clusters, word_dict) + sorted_cluster_concepts_tuple = [(cluster,sort_key(i, centroids, seed_centroids)) for i, cluster in enumerate(cluster_concepts)] + sorted_cluster_concepts_tuple = sorted(sorted_cluster_concepts_tuple, key = lambda x: x[1], reverse=True) + sorted_cluster_concepts_tuple = [(sorted(concept_list, key = lambda x: sort_concept(x, seed_centroids), reverse=True), score)for concept_list, score in sorted_cluster_concepts_tuple] + index = 1 + js_list = [] + for cluster, score in sorted_cluster_concepts_tuple: + for concept in cluster: + temp_js = concept_dict[concept] + temp_js['cluster'] = index + js_list.append(temp_js) + index += 1 + json_dumper(config.cluster_save_path, js_list) + print([(cluster[:5], score) for cluster, score in sorted_cluster_concepts_tuple]) + +if __name__=='__main__': + word_dict = load_word_dict(config.cluster_concept_path ,config.wordvector_path) + clustering_main() \ No newline at end of file diff --git a/confidence_propagation.py b/confidence_propagation.py new file mode 100644 index 0000000..304ed1c --- /dev/null +++ b/confidence_propagation.py @@ -0,0 +1,64 @@ +import config +import argparse +import os +import confidence_propagation.preprocess as preprocess +import confidence_propagation.graph_propagation as graph_propagation +import confidence_propagation.average_distance as average_distance +import confidence_propagation.tf_idf as tf_idf +import confidence_propagation.pagerank as pagerank +import crawler.snippet_crawler as crawler +def main(): + parser = argparse.ArgumentParser(description='process some parameters, the whole parameters are in config.py') + parser.add_argument('-task', type=str, default='expand', choices=['extract', 'expand'], help='extract | expand', required=True) + parser.add_argument('--input_text', '-it', default=config.input_text,type=str, help='the text file for concept extraction task') + parser.add_argument('--input_seed', '-is',default=config.input_seed, type=str, help='the seed file for concept extraction | expansion task') + parser.add_argument('--language', '-l', default='zh', type=str, choices=['zh', 'en'], help='zh | en', required=True) + parser.add_argument('--snippet_source', '-ss', default='baidu', type=str, choices=['baidu', 'google', 'bing'], help='baidu | google | bing') + parser.add_argument('--times', '-t', default=10, type=int, help='iteration times for graph propagation algorithm') + parser.add_argument('--max_num', '-m', default=-1, type=int, help='maximun number for outgoing edges of each node, "-1" means unlimited') + parser.add_argument('--decay', '-d', default=0.8, type=float, help='decay for graph propagation algorithm') + parser.add_argument('--threshold', '-th', default=0.7, type=float, help='similarity threshold for graph edges') + parser.add_argument('--no_seed', '-ns', action='store_true', help='every candidate in text will be a seed') + parser.add_argument('--noun_filter', '-nf', action='store_true', help='remove non noun candidates') + parser.add_argument('--result', '-r', default=config.result_path, type=str, help='result file path') + parser.add_argument('--algorithm', '-a', type=str, default='graph_propagation', choices=['graph_propagation', 'average_distance', 'tf_idf', 'pagerank'], help='graph_propagation | average_distance | tf_idf | pagerank') + args = parser.parse_args() + if not args.input_text and args.task == 'extract': + raise Exception('concept extraction task need input_text') + if not args.input_seed and args.task == 'expand': + raise Exception('concept extraction task need input_text') + if not args.no_seed and not args.input_seed: + raise Exception('seed config error') + config.input_text = args.input_text + config.input_seed = args.input_seed + config.language = args.language + config.snippet_source = args.snippet_source + config.times = args.times + config.max_num = args.max_num + config.decay = args.decay + config.threshold = args.threshold + config.no_seed = True if args.no_seed else False + config.noun_filter = True if args.noun_filter else False + config.result_path = args.result + + if args.task == 'expand': + text = [] + with open(config.input_seed, 'r', encoding='utf-8') as f: + for line in f.read().split('\n'): + if line != '': + text.append(crawler.get_snippet(line)) + config.input_text = config.tmp_input_text + with open(config.input_text, 'w', encoding='utf-8') as f: + f.write('\n'.join(text)) + preprocess.get_candidates() + if args.algorithm == 'graph_propagation': + graph_propagation.get_result() + if args.algorithm == 'average_distance': + average_distance.get_result() + if args.algorithm == 'tf_idf': + tf_idf.get_result() + if args.algorithm == 'pagerank': + pagerank.get_result() + +if __name__ == '__main__': + main() \ No newline at end of file