diff --git a/paper/network_creation.py b/paper/network_creation.py
new file mode 100644
index 0000000..3d52cd7
--- /dev/null
+++ b/paper/network_creation.py
@@ -0,0 +1,440 @@
+import networkx as nx
+import json
+from collections import Counter
+import matplotlib.pyplot as plt
+import os
+import numpy as np
+import typing
+from typing import Dict, Tuple, Optional, List
+from networkx.drawing.layout import (
+    fruchterman_reingold_layout,
+    spring_layout,
+    kamada_kawai_layout,
+)
+
+
+def load_json(file_path: str) -> dict:
+    """Loads a json file.
+
+    Args:
+        file_path (str): path to the json file
+
+    Returns:
+        data (dict): the data from the file as a dictionary
+    """
+    with open(file_path, mode="r", encoding="utf8") as f:
+        data = json.load(f)
+    return data
+
+
+def most_frequent_tuples(
+    node_dict: Dict[int, tuple],
+    n: int,
+    hard_filter: bool = False,
+) -> Dict[int, tuple]:
+    """Find the tuples that contain one of the n most frequent elements.
+
+    Args:
+        node_dict (dict): dictionary with node ids as keys and node pairs as values
+        n (int): The number of most frequent elements to keep
+        hard_filter (bool, optional): If True, only keep pairs whose combined element
+            frequency is at least as high as that of the single most frequent element.
+            Defaults to False.
+
+    Returns:
+        result (dict): filtered dictionary with node ids as keys and node pairs as values
+    """
+    element_counter: typing.Counter = Counter()
+    # Count the frequency of all elements and identify the n most frequent
+    for i, tup in node_dict.items():
+        for string in tup:
+            element_counter[string] += 1
+    most_common_strings = set(string for string, _ in element_counter.most_common(n))
+    highest_freq = element_counter.most_common(1)[0][1]
+
+    # Only keep the nodes that contain one of the most frequent elements
+    result = {}
+    for i, tup in node_dict.items():
+        if any(string in most_common_strings for string in tup):
+            fr = sum([element_counter[tup[0]], element_counter[tup[1]]])
+            # If hard_filter, keep only pairs of nodes at least as frequent
+            # as the single most frequent node
+            if hard_filter:
+                if fr >= highest_freq:
+                    result[i] = tup
+            else:
+                result[i] = tup
+    return result
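To make the `hard_filter` rule above concrete, here is a small illustrative sketch. The node pairs are invented, and it assumes the file is importable as `network_creation` (e.g. when run from inside `paper/`):

```python
from network_creation import most_frequent_tuples

# Invented subject–object pairs keyed by triplet id
nodes = {
    0: ("mette", "pressemøde"),
    1: ("mette", "nedlukning"),
    2: ("mette", "grænsen"),
    3: ("restriktioner", "forsamlingsforbud"),
    4: ("restriktioner", "mette"),
}

# "mette" (freq. 4) and "restriktioner" (freq. 2) are the two most frequent elements,
# so all five pairs survive when only filtering on the top-2 elements
print(most_frequent_tuples(nodes, n=2))

# hard_filter additionally drops pair 3: its combined element frequency (2 + 1 = 3)
# is below the frequency of the single most frequent element ("mette", 4)
print(most_frequent_tuples(nodes, n=2, hard_filter=True))
```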
+
+
+def get_nodes_edges(
+    event: str,
+    file: str,
+    remove_self_edges: bool = True,
+    remove_custom_nodes: Optional[List[str]] = None,
+    n_most_frequent: int = 10,
+    hard_filter: bool = False,
+    save: Optional[str] = None,
+) -> Tuple[Dict[int, tuple], Dict[int, tuple]]:
+    """Loads nodes and edges from a json file, removes self edges and only
+    keeps the n most frequent nodes.
+
+    Args:
+        event (str): the name of the event
+        file (str): the specific file to load. Must be placed in a folder with the event name
+        remove_self_edges (bool, optional): Whether or not to remove self edges. Defaults to True.
+        remove_custom_nodes (list, optional): A list of nodes to remove from the graph. Defaults to None.
+        n_most_frequent (int, optional): The number of most frequent elements to include. Defaults to 10.
+        hard_filter (bool, optional): Passed to most_frequent_tuples. Whether or not to filter away
+            pairs of nodes that are less frequent than the single most frequent node. Defaults to False.
+        save (str, optional): If provided, the filtered nodes and edges are saved to this path.
+
+    Returns:
+        Tuple[Dict[int, tuple], Dict[int, tuple]]: a dictionary with the most frequent nodes and
+            a dictionary with the associated edges
+    """
+    nodes_edges = load_json(
+        os.path.join(event, file),
+    )
+    nodes = {i: tuple(edge) for i, edge in enumerate(nodes_edges["nodes"])}
+    edges = {i: edge for i, edge in enumerate(nodes_edges["edges"])}
+
+    # Remove self edges and only keep the most frequent nodes
+    if remove_self_edges:
+        nodes = {i: (e1, e2) for i, (e1, e2) in nodes.items() if e1 != e2}
+    if remove_custom_nodes:
+        nodes = {
+            i: (e1, e2)
+            for i, (e1, e2) in nodes.items()
+            if e1 not in remove_custom_nodes and e2 not in remove_custom_nodes
+        }
+    most_frequent_nodes = most_frequent_tuples(nodes, n_most_frequent, hard_filter)
+    associated_edges = {
+        i: edge for i, edge in edges.items() if i in most_frequent_nodes.keys()
+    }
+    if save:
+        # Save in the same json format that umap_hdb.py writes and load_json reads
+        with open(save, "w") as f:
+            json.dump(
+                {
+                    "nodes": list(most_frequent_nodes.values()),
+                    "edges": list(associated_edges.values()),
+                },
+                f,
+            )
+    return most_frequent_nodes, associated_edges
+
+
+def quantile_min_value(lst, quantile):
+    """Returns the smallest value in lst that is at least as large as the given quantile."""
+    q = np.quantile(lst, quantile)
+    return min(filter(lambda x: x >= q, lst))
+
+
+def min_max_normalize(list_to_normalize: list, min_constant=0.5) -> list:
+    """Normalizes a list with min-max normalization and adds a constant offset.
+
+    Args:
+        list_to_normalize (list): The list to normalize
+        min_constant (float, optional): Constant added to the scaled values so the smallest
+            element does not collapse to 0. Defaults to 0.5.
+
+    Returns:
+        list: The normalized list, with values between min_constant and 1 + min_constant
+    """
+    min_value = min(list_to_normalize)
+    max_value = max(list_to_normalize)
+    if min_value == max_value:
+        return list_to_normalize
+    scaled_list = [(x - min_value) / (max_value - min_value) for x in list_to_normalize]
+    return [x + min_constant for x in scaled_list]
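As a sanity check on the two helpers above, this standalone sketch restates what they compute, using made-up degree values (the real inputs are node degrees and edge weights from the graphs drawn below):

```python
import numpy as np

degrees = [1, 1, 2, 4, 9]  # made-up node degrees

# min_max_normalize: scale to [0, 1], then shift by min_constant so nothing becomes 0
min_v, max_v = min(degrees), max(degrees)
scaled = [(d - min_v) / (max_v - min_v) + 0.5 for d in degrees]
print(scaled)  # [0.5, 0.5, 0.625, 0.875, 1.5]

# quantile_min_value: the smallest observed value at or above the requested quantile,
# later used as the cutoff for which edge labels get drawn
q = np.quantile(degrees, 0.9)             # 7.0
print(min(d for d in degrees if d >= q))  # 9
```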
+
+
+def create_network_graph(
+    node_list,
+    edge_list,
+    title: Optional[str] = None,
+    layout=fruchterman_reingold_layout,
+    k: float = 0.3,
+    node_size_mult: float = 3000,
+    node_size_min: float = 0.001,
+    edge_weight_mult: float = 5,
+    fontsize: int = 12,
+    edge_quantile_value: float = 0.90,
+    node_quantile_value: Optional[float] = None,
+    node_color: str = "#146D25",
+    edge_color: str = "#54A463",
+    fig_size: int = 10,
+    plot_coordinates: bool = False,
+    seed: Optional[int] = None,
+    draw_labels: bool = True,
+    save=False,
+):
+    """Draws the network of node pairs with frequency-based node sizes and edge widths,
+    optionally saves it as a png, and returns the networkx Graph."""
+    G = nx.Graph()
+    G.add_edges_from(list(node_list.values()))
+    c = Counter(list(node_list.values()))  # edge weights = frequency of the edge
+    for u, v, d in G.edges(data=True):
+        # The graph is undirected, but the tuples sometimes appear reversed in the
+        # node list, so both orientations count towards the weight
+        d["weight"] = c[(u, v)] + c[(v, u)]
+
+    edge_label_weight_cutoff = quantile_min_value(list(c.values()), edge_quantile_value)
+    edges_to_draw = {}
+    for n, nodes in node_list.items():
+        if c[nodes] >= edge_label_weight_cutoff:
+            edges_to_draw[nodes] = edge_list[n]
+
+    if layout == kamada_kawai_layout:
+        pos = layout(G, scale=2)
+    else:
+        pos = layout(G, k=k, seed=seed)
+
+    degrees = nx.degree(G)
+    normalized_degrees = min_max_normalize(
+        [d[1] for d in degrees],
+        min_constant=node_size_min,
+    )
+
+    plt.figure(figsize=(fig_size, fig_size))
+    if title:
+        plt.title(
+            title,
+            color="k",
+            fontsize=fontsize + 8,
+        )
+
+    edge_weights = min_max_normalize([d["weight"] for _, _, d in G.edges(data=True)])
+
+    nx.draw(
+        G,
+        pos,
+        node_size=[d * node_size_mult for d in normalized_degrees],
+        node_color=node_color,
+        edge_color=edge_color + "80",
+        width=[e * edge_weight_mult for e in edge_weights],
+    )
+    if draw_labels:
+        nx.draw_networkx_edge_labels(
+            G,
+            pos,
+            edge_labels=edges_to_draw,
+            font_size=fontsize - 1,
+            label_pos=0.5,
+            bbox=dict(
+                facecolor="white",
+                edgecolor=edge_color,
+                alpha=0.8,
+                boxstyle="round,pad=0.2",
+            ),
+        )
+
+    # Node labels are nudged away from the centre and aligned away from their node
+    offset = 0.015
+    for node, (x, y) in pos.items():
+        h_align = "center"
+        v_align = "center"
+        if x < 0:
+            x -= offset
+            h_align = "right"
+        if x > 0:
+            x += offset
+            h_align = "left"
+        if y < 0:
+            y -= offset
+            v_align = "top"
+        if y > 0:
+            y += offset
+            v_align = "bottom"
+        if plot_coordinates:
+            label = f"{node} ({x:.2f}, {y:.2f})"
+        else:
+            label = node
+        if node_quantile_value:
+            # Only label nodes whose degree is above the given quantile
+            node_label_draw_cutoff = quantile_min_value(
+                [value for key, value in degrees],
+                node_quantile_value,
+            )
+            if degrees[node] >= node_label_draw_cutoff:
+                plt.text(
+                    x,
+                    y,
+                    label,
+                    fontsize=fontsize,
+                    color="k",
+                    ha=h_align,
+                    va=v_align,
+                    bbox=dict(
+                        facecolor="white",
+                        edgecolor=node_color,
+                        alpha=0.8,
+                        boxstyle="round,pad=0.1",
+                    ),
+                )
+        else:
+            plt.text(
+                x,
+                y,
+                label,
+                fontsize=fontsize,
+                color="k",
+                ha=h_align,
+                va=v_align,
+                bbox=dict(
+                    facecolor="white",
+                    edgecolor=node_color,
+                    alpha=0.8,
+                    boxstyle="round,pad=0.1",
+                ),
+            )
+    if save:
+        plt.savefig(
+            f"{save}.png",
+            format="PNG",
+            bbox_inches="tight",
+        )
+    return G
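The least obvious step in `create_network_graph` is the edge weighting: node pairs can occur in either orientation, so both are counted towards the weight of the same undirected edge. A minimal, self-contained sketch with invented pairs:

```python
import networkx as nx
from collections import Counter

# (a, b) and (b, a) should count as the same undirected edge
pairs = [("mette", "pressemøde"), ("pressemøde", "mette"), ("mette", "nedlukning")]

G = nx.Graph()
G.add_edges_from(pairs)
c = Counter(pairs)
for u, v, d in G.edges(data=True):
    d["weight"] = c[(u, v)] + c[(v, u)]  # both orientations count

print(G["mette"]["pressemøde"]["weight"])  # 2
print(G["mette"]["nedlukning"]["weight"])  # 1
```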
+
+
+# Twitter
+twitter_week_1_nodes, twitter_week_1_edges = get_nodes_edges(
+    "extracted_triplets_tweets/covid_week_1",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+    n_most_frequent=3,
+)
+
+# Green colour scheme (default)
+twitter_week_1_graph = create_network_graph(
+    twitter_week_1_nodes,
+    twitter_week_1_edges,
+    title="Twitter (GPT-3): First week of the lockdown",
+    k=2.5,
+    edge_quantile_value=0.9,
+    save="fig/twitter_week_1",
+)
+
+
+# # Without the node "få"
+# twitter_week_1_nodes_rm_få, twitter_week_1_edges_rm_få = get_nodes_edges(
+#     "extracted_triplets_tweets/covid_week_1",
+#     "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+#     remove_custom_nodes=["få"],
+#     hard_filter=True,
+# )
+# twitter_week_1_rm_få = create_network_graph(
+#     twitter_week_1_nodes_rm_få,
+#     twitter_week_1_edges_rm_få,
+#     title="Covid-19 lockdown week 1 - Twitter",
+#     k=2.5,
+#     node_color="#A82800",
+#     fontsize=11,
+#     save="fig/twitter_week_1_graph_rm_få",
+# )
+
+
+# With the old triplet extraction (Multi2OIE) instead of GPT
+twitter_week_1_nodes_multi, twitter_week_1_edges_multi = get_nodes_edges(
+    "extracted_triplets_tweets/covid_week_1_multi",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+)
+
+twitter_week_1_graph_multi = create_network_graph(
+    twitter_week_1_nodes_multi,
+    twitter_week_1_edges_multi,
+    title="Twitter (Multi2OIE): First week of the lockdown",
+    k=3.5,
+    edge_quantile_value=0.8,
+    save="fig/twitter_week_1_multi_no_labels",
+    draw_labels=False,
+)
+
+# Mink start
+twitter_mink_start_nodes, twitter_mink_start_edges = get_nodes_edges(
+    "extracted_triplets_tweets/mink_start",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+    hard_filter=True,
+)
+
+twitter_mink_start_graph = create_network_graph(
+    twitter_mink_start_nodes,
+    twitter_mink_start_edges,
+    title="Twitter (GPT-3): First week of the mink case",
+    layout=spring_layout,
+    k=4,
+    edge_quantile_value=0.83,
+    save="fig/twitter_mink_start",
+)
+
+twitter_mink_start_nodes_multi, twitter_mink_start_edges_multi = get_nodes_edges(
+    "extracted_triplets_tweets/mink_start_multi",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+)
+
+twitter_mink_start_graph_multi = create_network_graph(
+    twitter_mink_start_nodes_multi,
+    twitter_mink_start_edges_multi,
+    title="Twitter (Multi2OIE): First week of the mink case",
+    k=3,
+    save="fig/twitter_mink_start_multi",
+)
+
+### Newspapers
+
+# Mink start
+news_mink_start_nodes, news_mink_start_edges = get_nodes_edges(
+    "extracted_triplets_papers/mink_start",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+    hard_filter=True,
+)
+
+news_mink_start_graph = create_network_graph(
+    news_mink_start_nodes,
+    news_mink_start_edges,
+    title="Newspapers: First week of the mink case",
+    layout=spring_layout,
+    k=3,
+    edge_quantile_value=0.83,
+    save="fig/news_mink_start",
+)
+
+# Mink - Mogens Jensen resigning
+news_mink_mj_nodes, news_mink_mj_edges = get_nodes_edges(
+    "extracted_triplets_papers/mink_mogens_jensen",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+    hard_filter=True,
+)
+
+news_mink_mj_graph = create_network_graph(
+    news_mink_mj_nodes,
+    news_mink_mj_edges,
+    title="Newspapers: Mogens Jensen's resignation",
+    layout=spring_layout,
+    k=2,
+    save="fig/news_mink_mj",
+)
+
+# Covid weeks 1 and 2
+news_week_1_nodes, news_week_1_edges = get_nodes_edges(
+    "extracted_triplets_papers/covid_week_1",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+)
+
+news_week_2_nodes, news_week_2_edges = get_nodes_edges(
+    "extracted_triplets_papers/covid_week_2",
+    "paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json",
+)
+
+
+news_week_1_graph = create_network_graph(
+    news_week_1_nodes,
+    news_week_1_edges,
+    title="Newspapers: First week of the lockdown",
+    save="fig/news_week_1",
+    k=2.5,
+)
+
+news_week_2_graph = create_network_graph(
+    news_week_2_nodes,
+    news_week_2_edges,
+    title="Newspapers: Second week of the lockdown",
+    k=2.5,
+    save="fig/news_week_2",
+)
diff --git a/paper/umap_hdb.py b/paper/umap_hdb.py
new file mode 100644
index 0000000..912d176
--- /dev/null
+++ b/paper/umap_hdb.py
@@ -0,0 +1,538 @@
+import json
+from typing import Tuple, List, Dict, Optional, Union
+import os
+import spacy
+from umap import UMAP
+from hdbscan import HDBSCAN
+from sentence_transformers import SentenceTransformer
+from stop_words import get_stop_words
+from sklearn.preprocessing import StandardScaler
+from numpy import ndarray
+from collections import Counter
+import random
+import argparse
+
+
+def read_txt(path: str):
+    """Reads a txt file and returns its lines as a list of strings."""
+    with open(path, mode="r", encoding="utf8") as f:
+        lines = f.read().splitlines()
+    return lines
+
+
+def triplet_from_line(line: str) -> Union[Tuple[str, str, str], None]:
+    """Converts a line from a txt file to a triplet.
+
+    Lines that do not contain exactly three elements are ignored.
+    All elements in the triplet are stripped of whitespace and lowercased.
+    Args:
+        line (str): Line from a txt file
+    Returns:
+        triplet (tuple): Triplet, or None if the line is not a valid triplet
+    """
+    as_list = line.split(", ")
+    if len(as_list) != 3:
+        return None
+    if line in ["Subject, Predicate, Object", "---, ---, ---"]:
+        return None
+    return tuple(map(str.strip, map(str.lower, as_list)))  # type: ignore
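A few illustrative calls showing what `triplet_from_line` accepts and rejects (the example lines are invented; the import assumes the file can be imported as `umap_hdb`):

```python
from umap_hdb import triplet_from_line

print(triplet_from_line("Mette Frederiksen, lukker, landet"))
# ('mette frederiksen', 'lukker', 'landet')

print(triplet_from_line("Subject, Predicate, Object"))  # header line -> None
print(triplet_from_line("---, ---, ---"))               # placeholder line -> None
print(triplet_from_line("kun, to elementer"))           # not three elements -> None
```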
+
+
+def filter_triplets_with_stopwords(
+    triplets: List[Tuple[str, str, str]],
+    stopwords: List[str],
+    soft: bool = True,
+) -> List[Tuple[str, str, str]]:
+    """Filters out triplets that contain a stopword.
+
+    Args:
+        triplets (List[Tuple[str, str, str]]): List of triplets. A triplet is a tuple of three strings.
+        stopwords (List[str]): List of stopwords
+        soft (bool): If True, only the subject and object are checked for stopwords.
+            If False, the whole triplet is checked.
+    Returns:
+        filtered_triplets (List[Tuple[str, str, str]]): List of triplets without stopwords
+    """
+    filtered_triplets = []
+    if soft:
+        for triplet in triplets:
+            subject, predicate, obj = triplet
+            if subject not in stopwords and obj not in stopwords:
+                filtered_triplets.append(triplet)
+    else:
+        for triplet in triplets:
+            if not any(stopword in triplet for stopword in stopwords):
+                filtered_triplets.append(triplet)
+    return filtered_triplets
+
+
+def load_triplets(
+    file_path: str,
+    soft_filtering: bool = True,
+    shuffle: bool = True,
+) -> Tuple[list, list, list, list]:
+    """Loads triplets from a file and filters them.
+
+    Args:
+        file_path (str): Path of the file to load triplets from
+        soft_filtering (bool): Whether to use soft filtering or not
+        shuffle (bool): Whether to shuffle the triplets or not
+    Returns:
+        subjects (list): List of subjects
+        predicates (list): List of predicates
+        objects (list): List of objects
+        filtered_triplets (list): List of filtered triplets
+    """
+    data = read_txt(file_path)
+    # Parse each line once and drop lines that are not valid triplets
+    triplets_list: List[Tuple[str, str, str]] = [
+        triplet for line in data if (triplet := triplet_from_line(line)) is not None
+    ]
+    filtered_triplets = filter_triplets_with_stopwords(
+        triplets_list,
+        get_stop_words("danish"),
+        soft=soft_filtering,
+    )
+
+    if shuffle:
+        random.shuffle(filtered_triplets)
+
+    subjects = [
+        triplet[0]
+        for triplet in filtered_triplets
+        if triplet[0] not in ["Subject", "---"]
+    ]
+    predicates = [
+        triplet[1]
+        for triplet in filtered_triplets
+        if triplet[0] not in ["Predicate", "---"]
+    ]
+    objects = [
+        triplet[2]
+        for triplet in filtered_triplets
+        if triplet[0] not in ["Object", "---"]
+    ]
+    return subjects, predicates, objects, filtered_triplets
+
+
+def freq_of_most_frequent(list_of_strings: List[str]) -> Tuple[str, float]:
+    """Finds the most frequent element in a list of strings and its relative frequency.
+
+    Frequency is measured as the share of the list that the most frequent element takes up.
+    Args:
+        list_of_strings (List[str]): List of strings to find the most frequent element in
+    Returns:
+        most_common_string (str): The most frequent element
+        percentage (float): Its share of the list
+    """
+    most_common = Counter(list_of_strings).most_common(1)[0]
+    most_common_string = most_common[0]
+    percentage = most_common[1] / len(list_of_strings)
+    return most_common_string, percentage
+
+
+def most_frequent_token(list_of_strings: List[str], nlp) -> Tuple[str, float]:
+    """Finds the most frequent token in a list of strings.
+
+    Args:
+        list_of_strings (List[str]): List of strings to find the most frequent token in
+        nlp (spacy.lang): Spacy language model
+    Returns:
+        most_common_token (str): Most frequent token
+        percentage (float): Its share of all tokens
+    """
+    token_list = []
+    for string in list_of_strings:
+        doc = nlp(string)
+        for token in doc:
+            token_list.append(token.text)
+    most_common_token, percentage = freq_of_most_frequent(token_list)
+    return most_common_token, percentage
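The difference between soft and hard stopword filtering above is easy to miss; this standalone restatement uses the same `stop_words` dependency and two invented triplets ("er" and "han" are Danish stopwords):

```python
from stop_words import get_stop_words

stopwords = get_stop_words("danish")
triplets = [
    ("mette frederiksen", "er", "statsminister"),  # stopword only in the predicate
    ("han", "lukker", "landet"),                   # stopword as the subject
]

# Soft filtering: only subject and object are checked, so the first triplet survives
print([t for t in triplets if t[0] not in stopwords and t[2] not in stopwords])

# Hard filtering: a stopword anywhere removes the triplet, so nothing survives
print([t for t in triplets if not any(sw in t for sw in stopwords)])
```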
+
+
+def get_cluster_label(
+    cluster: List[Tuple[str, int]],
+    nlp,
+    first_cutoff: float = 0.8,
+    min_cluster_length: int = 10,
+    second_cutoff: float = 0.3,
+) -> Union[str, None]:
+    """Finds a label for a cluster based on its most frequent element or token.
+
+    Args:
+        cluster (List[Tuple[str, int]]): Cluster of (element, index) tuples
+        nlp (spacy.lang): Spacy language model
+        first_cutoff (float): Minimum share of the most frequent element. Clusters whose most
+            frequent element is at least this frequent get that element as label.
+        min_cluster_length (int): Minimum size of less clearly defined clusters; smaller ones are dropped.
+        second_cutoff (float): Lower share threshold applied to clusters that are large
+            but less clearly defined.
+    Returns:
+        label (str): The cluster label according to the rules, or None if the cluster should be removed
+    """
+    cluster_strings = [element[0] for element in cluster]
+    most_common, percentage = freq_of_most_frequent(cluster_strings)
+
+    # If the most frequent element is frequent enough, use it as the label
+    if percentage >= first_cutoff:
+        return most_common
+
+    else:
+        # If the cluster is not clearly defined and it is short,
+        # return None to indicate it should be removed
+        if len(cluster) < min_cluster_length:
+            return None
+
+        # Clusters that are large and relatively clearly defined
+        if percentage >= second_cutoff:
+            return most_common
+
+        # If the cluster is not clearly defined, find the most frequent token
+        most_common_token, percentage = most_frequent_token(cluster_strings, nlp)
+
+        # The most frequent token must be frequent enough to be used as the label,
+        # otherwise return None to indicate the cluster should be removed
+        if percentage >= 0.2:  # TODO: Make this a parameter
+            return most_common_token
+        else:
+            return None
+
+
+def cluster_dict(
+    topic_labels: ndarray,
+    input_list: List[str],
+) -> Dict[int, List[Tuple[str, int]]]:
+    """Creates a dictionary mapping each cluster id to its (element, index) tuples.
+
+    Elements labelled -1 (noise) are left out.
+    Args:
+        topic_labels (ndarray): Cluster labels, one per element of input_list
+        input_list (List[str]): The list of strings that was clustered
+    Returns:
+        clusters (Dict[int, List[Tuple[str, int]]]): Dictionary of clusters
+    """
+    assert len(topic_labels) == len(
+        input_list,
+    ), "Length of topic labels and input list must be equal"
+    clusters: Dict[int, list] = {i: [] for i in range(max(topic_labels) + 1)}
+    topic_tuples = zip(topic_labels, input_list)
+    for index, (topic_n, element) in enumerate(topic_tuples):
+        if topic_n != -1:
+            clusters[topic_n].append((element, index))
+    return clusters
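`cluster_dict` keeps the position of every element alongside the element itself; those positions are what later allow `create_nodes_and_edges` to map cluster labels back onto individual triplets. A standalone restatement with invented HDBSCAN output:

```python
# Invented cluster ids per element; -1 marks noise points that are dropped
labels = [0, 0, 1, -1, 1, 0]
elements = ["mette", "mette frederiksen", "landet", "øh", "danmark", "statsministeren"]

clusters = {i: [] for i in range(max(labels) + 1)}
for index, (label, element) in enumerate(zip(labels, elements)):
    if label != -1:
        clusters[label].append((element, index))

print(clusters)
# {0: [('mette', 0), ('mette frederiksen', 1), ('statsministeren', 5)],
#  1: [('landet', 2), ('danmark', 4)]}
```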
+
+
+def label_clusters(
+    cluster_dict,
+    nlp,
+    first_cutoff: float = 0.8,
+    min_cluster_length: int = 10,
+    second_cutoff: float = 0.3,
+    predicates: bool = False,
+):
+    """Labels clusters according to the rules in `get_cluster_label`.
+
+    Args:
+        cluster_dict (Dict[int, List[Tuple[str, int]]]): Dictionary of clusters to label
+        nlp (spacy.lang): Spacy language model to use for tokenization of less defined clusters
+        first_cutoff (float): Minimum share of the most frequent element. Clusters whose most
+            frequent element is at least this frequent get that element as label.
+        min_cluster_length (int): Minimum number of elements in a cluster
+        second_cutoff (float): Lower share threshold, only used if the cluster has a less clearly
+            defined label but is still large enough.
+        predicates (bool): If True, keep all clusters (no minimum length and no second cutoff).
+
+    Returns:
+        dict_with_labels (Dict[str, Dict[str, List[str]]]): Dictionary of clusters with labels
+            The keys are the labels, the values are dictionaries with the keys
+            "cluster" (final elements in the cluster), and
+            "n_elements" (number of elements in the final cluster)
+    """
+    if predicates:
+        min_cluster_length = 1
+        second_cutoff = 0.0
+    dict_with_labels: Dict[str, dict] = {}
+    # Get the label for each cluster
+    for cluster in cluster_dict.values():
+        cluster_label = get_cluster_label(
+            cluster,
+            nlp,
+            first_cutoff,
+            min_cluster_length,
+            second_cutoff,
+        )
+        if cluster_label:  # If the cluster label is not None
+            if (
+                cluster_label in dict_with_labels.keys()
+            ):  # If the label already exists, merge
+                dict_with_labels[cluster_label]["cluster"].extend(cluster)
+                dict_with_labels[cluster_label]["n_elements"] += len(cluster)
+            else:  # If the label does not exist, create a new entry
+                dict_with_labels[cluster_label] = {}
+                dict_with_labels[cluster_label]["cluster"] = cluster
+                dict_with_labels[cluster_label]["n_elements"] = len(cluster)
+
+    # Remove clusters that are too small even after merging identically labelled clusters
+    clusters_to_keep = {
+        label: content
+        for label, content in dict_with_labels.items()
+        if content["n_elements"] > min_cluster_length
+    }
+    return clusters_to_keep
+
+
+def embed_and_cluster(
+    list_to_embed: List[str],
+    embedding_model: str = "vesteinn/DanskBERT",
+    n_dimensions: int = 40,
+    n_neighbors: int = 15,
+    min_cluster_size: int = 5,
+    min_samples: int = 3,
+    min_topic_size: int = 10,
+    predicates: bool = False,
+):
+    """Embeds and clusters a list of strings.
+
+    Args:
+        list_to_embed (List[str]): List of strings to embed and cluster
+        embedding_model (str): Name of the sentence-transformers model to use
+        n_dimensions (int): Number of dimensions to reduce the embedding space to
+        n_neighbors (int): Number of neighbors to use for UMAP
+        min_cluster_size (int): Minimum cluster size for HDBSCAN
+        min_samples (int): Minimum number of samples for HDBSCAN
+        min_topic_size (int): Minimum number of elements in a cluster
+        predicates (bool): Passed to label_clusters; if True, all clusters are kept
+    Returns:
+        clusters (Dict[str, Dict[str, List[str]]]): Dictionary of clusters with labels
+            The keys are the labels, the values are dictionaries with the keys
+            "cluster" (final elements in the cluster), and
+            "n_elements" (number of elements in the final cluster)
+    """
+    sentence_model = SentenceTransformer(embedding_model)
+
+    # Embed and reduce the embedding space
+    print("Embedding and reducing embedding space")
+    embeddings = sentence_model.encode(list_to_embed)
+    scaled_embeddings = StandardScaler().fit_transform(embeddings)
+    reducer = UMAP(n_components=n_dimensions, n_neighbors=n_neighbors)
+    reduced_embeddings = reducer.fit_transform(scaled_embeddings)
+
+    # Cluster with HDBSCAN
+    print("Clustering")
+    hdbscan_model = HDBSCAN(min_cluster_size=min_cluster_size, min_samples=min_samples)
+    hdbscan_model.fit(reduced_embeddings)
+    hdbscan_labels = hdbscan_model.labels_
+    assert len(hdbscan_labels) == len(
+        list_to_embed,
+    ), "Length of hdbscan labels and input list must be equal"
+    clusters = cluster_dict(hdbscan_labels, list_to_embed)
+
+    # Label and prune clusters
+    print("Labeling clusters")
+    nlp = spacy.load("da_core_news_sm")
+    labeled_clusters = label_clusters(
+        clusters,
+        nlp,
+        min_cluster_length=min_topic_size,
+        predicates=predicates,
+    )
+
+    return labeled_clusters
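For reference, a hedged usage sketch of `embed_and_cluster`. The phrases are invented and the reduced parameters are chosen only to keep the toy run small; the paraphrase model name is the one selected in `main` further below, and the spacy model `da_core_news_sm` must be installed for the labeling step:

```python
from umap_hdb import embed_and_cluster

phrases = ["mette frederiksen", "statsministeren", "mette", "landet", "danmark"] * 20

clusters = embed_and_cluster(
    list_to_embed=phrases,
    embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
    n_dimensions=10,   # far smaller than the paper's 40 dimensions, just for the sketch
    n_neighbors=5,
    min_cluster_size=5,
    min_samples=3,
    min_topic_size=5,
)
for label, content in clusters.items():
    print(label, content["n_elements"])
```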
+
+
+def create_nodes_and_edges(
+    subj_obj_clusters: Dict[str, Dict[str, List[str]]],
+    predicate_clusters: Dict[str, Dict[str, List[str]]],
+    n_elements: int,
+    no_predicate_filler: str = "",
+    save: Optional[Union[bool, str]] = False,
+):
+    """Creates nodes and edges from clusters of subjects, objects and predicates.
+
+    Args:
+        subj_obj_clusters (Dict[str, Dict[str, List[str]]]): Dictionary of clusters with labels
+            The keys are the labels, the values are dictionaries with the keys
+            "cluster" (final elements in the cluster), and
+            "n_elements" (number of elements in the final cluster)
+        predicate_clusters (Dict[str, Dict[str, List[str]]]): Dictionary of predicate clusters
+            with labels, in the same format as subj_obj_clusters
+        n_elements (int): Number of triplets. In the concatenated subject + object list,
+            indices below n_elements are subjects and the rest are objects.
+        no_predicate_filler (str): String to use as filler for predicates that do not have a cluster
+        save (Optional[Union[bool, str]]): If a string, nodes and edges are saved to a json file
+            at that path. If False, nothing is saved.
+
+    Returns:
+        nodes, edges (List[Tuple[str, str]], List[str]): nodes and edges for the graph
+    """
+    labelled_subjects = {i: "" for i in range(0, n_elements)}
+    labelled_objects = {i: "" for i in range(0, n_elements)}
+
+    for label, content in subj_obj_clusters.items():
+        cluster = content["cluster"]
+        for _, index in cluster:  # type: ignore
+            if index < n_elements:  # type: ignore
+                labelled_subjects[index] = label  # type: ignore
+            else:
+                labelled_objects[index - n_elements] = label  # type: ignore
+    labelled_subjects = {  # type: ignore
+        i: label for i, label in labelled_subjects.items() if label != ""  # type: ignore
+    }
+    labelled_objects = {
+        i: label for i, label in labelled_objects.items() if label != ""
+    }
+
+    labelled_predicates = {i: "" for i in range(0, n_elements)}
+    for label, content in predicate_clusters.items():
+        cluster = content["cluster"]
+        for _, index in cluster:  # type: ignore
+            labelled_predicates[index] = label  # type: ignore
+    labelled_predicates = {
+        i: label for i, label in labelled_predicates.items() if label != ""
+    }
+
+    # Only keep pairs where both the subject and the object received a cluster label;
+    # nodes and edges stay index-aligned so they can be matched up again downstream
+    nodes = []
+    edges = []
+    for s_index, subject in labelled_subjects.items():
+        if s_index in labelled_objects.keys():
+            nodes.append((subject, labelled_objects[s_index]))
+            if s_index in labelled_predicates.keys():
+                edges.append(labelled_predicates[s_index])
+            else:
+                edges.append(no_predicate_filler)
+
+    if save:
+        with open(save, "w") as f:
+            json.dump({"nodes": nodes, "edges": edges}, f)
+    return nodes, edges
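The index arithmetic above is the glue between the clustering and the graph: `main` embeds subjects and objects as one concatenated list, so a cluster index below `n_elements` refers to a subject, while a larger index points back to the object of triplet `index - n_elements`. A standalone restatement with three invented triplets:

```python
subjects = ["mette", "regeringen", "mette"]
objects = ["landet", "grænsen", "pressemøde"]
subj_obj = subjects + objects  # indices 0-2 are subjects, 3-5 are objects
n_elements = len(subjects)

# A cluster over subj_obj refers back to the triplets via these indices
for element, index in [("mette", 0), ("mette", 2), ("landet", 3)]:
    if index < n_elements:
        print(f"triplet {index}: subject '{element}'")
    else:
        print(f"triplet {index - n_elements}: object '{element}'")
```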
+
+
+def main(
+    path: str,
+    embedding_model: str,
+    dim=40,
+    n_neighbors=15,
+    min_cluster_size=5,
+    min_samples=3,
+    min_topic_size=20,
+    save: bool = False,
+):
+    # Load triplets
+    print("Loading triplets")
+    subjects, predicates, objects, filtered_triplets = load_triplets(
+        path,
+        soft_filtering=True,
+        shuffle=True,
+    )
+
+    if save:
+        save = path.replace(
+            "triplets.txt",
+            f"{embedding_model}_dim={dim}_neigh={n_neighbors}_clust={min_cluster_size}_samp={min_samples}_nodes_edges.json",
+        )  # type: ignore
+
+    model = (
+        "vesteinn/DanskBERT"
+        if embedding_model == "danskBERT"
+        else "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+
+    print(
+        f"Dimensions: {dim}, neighbors: {n_neighbors}, min cluster size: {min_cluster_size}, samples: {min_samples}, min topic size: {min_topic_size}",
+    )
+    print("\n_________________\n")
+    print("Embedding and clustering predicates")
+    # For predicates, we want to keep all clusters -> min_topic_size=1
+    predicate_clusters = embed_and_cluster(
+        list_to_embed=predicates,
+        embedding_model=model,
+        n_dimensions=dim,
+        n_neighbors=n_neighbors,
+        min_cluster_size=min_cluster_size,
+        min_samples=min_samples,
+        min_topic_size=1,
+        predicates=True,
+    )
+
+    print("\n_________________\n")
+    print("Embedding and clustering subjects and objects together")
+    subj_obj = subjects + objects
+    subj_obj_clusters = embed_and_cluster(
+        list_to_embed=subj_obj,
+        embedding_model=model,
+        n_dimensions=dim,
+        n_neighbors=n_neighbors,
+        min_cluster_size=min_cluster_size,
+        min_samples=min_samples,
+        min_topic_size=min_topic_size,
+    )
+
+    # Create nodes and edges
+    print("Creating nodes and edges")
+
+    assert (
+        len(subjects) == len(objects) == len(predicates)
+    ), "Subjects, objects and predicates must have the same length"
+    nodes, edges = create_nodes_and_edges(
+        subj_obj_clusters,
+        predicate_clusters,
+        len(subjects),
+        save=save,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-e",
+        "--event",
+        type=str,
+        required=True,
+        help="Event to cluster. Must include name of source folder (newspapers or twitter) and event",
+    )
+    parser.add_argument(
+        "-emb",
+        "--embedding_model",
+        type=str,
+        default="paraphrase",
+        help="""Which embedding model to use, default is paraphrase.
+        The other option is danskBERT""",
+    )
+    parser.add_argument(
+        "-dim",
+        "--n_dimensions",
+        type=int,
+        default=40,
+        help="Number of dimensions to reduce the embedding space to",
+    )
+    parser.add_argument(
+        "-neigh",
+        "--n_neighbors",
+        type=int,
+        default=15,
+        help="Number of neighbors to use for UMAP",
+    )
+    parser.add_argument(
+        "-min_clust",
+        "--min_cluster_size",
+        type=int,
+        default=5,
+        help="Minimum cluster size for HDBSCAN",
+    )
+    parser.add_argument(
+        "-min_samp",
+        "--min_samples",
+        type=int,
+        default=3,
+        help="Minimum number of samples for HDBSCAN",
+    )
+    parser.add_argument(
+        "-save",
+        "--save",
+        # store_true so that passing the flag enables saving (type=bool would treat any string as True)
+        action="store_true",
+        help="Whether or not to save nodes and edges to a json file",
+    )
+
+    args = parser.parse_args()
+    path = os.path.join(args.event, "triplets.txt")
+    main(
+        path,
+        embedding_model=args.embedding_model,
+        dim=args.n_dimensions,
+        n_neighbors=args.n_neighbors,
+        min_cluster_size=args.min_cluster_size,
+        min_samples=args.min_samples,
+        save=args.save,
+    )
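Finally, the file naming is what ties the two scripts together: with the argument parser above, a run such as `python paper/umap_hdb.py -e extracted_triplets_tweets/covid_week_1 -save` (illustrative invocation) writes its nodes and edges next to the triplets file, and `network_creation.py` loads exactly these filenames. A standalone restatement of how the output path is derived in `main`:

```python
path = "extracted_triplets_tweets/covid_week_1/triplets.txt"
embedding_model, dim, n_neighbors, min_cluster_size, min_samples = "paraphrase", 40, 15, 5, 3

save = path.replace(
    "triplets.txt",
    f"{embedding_model}_dim={dim}_neigh={n_neighbors}_clust={min_cluster_size}_samp={min_samples}_nodes_edges.json",
)
print(save)
# extracted_triplets_tweets/covid_week_1/paraphrase_dim=40_neigh=15_clust=5_samp=3_nodes_edges.json
```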