diff --git a/itext2kg/utils/data_handling.py b/itext2kg/utils/data_handling.py
index 21a1ddf..b9fa508 100644
--- a/itext2kg/utils/data_handling.py
+++ b/itext2kg/utils/data_handling.py
@@ -1,4 +1,5 @@
 from typing import Literal, List, Tuple
+import re
 
 class DataHandler:
     """
@@ -9,6 +10,45 @@ class DataHandler:
     def __init__(self):
         """Initialize the DataHandler instance."""
         pass
+
+    def add_embeddings_as_property_batch(self, embeddings_function, items: List[dict], property_name="properties", embeddings_name="embeddings", item_name_key="name", embeddings=True):
+        """
+        Add embeddings as a property to a list of dictionaries (items), such as entities or relations.
+
+        Args:
+            embeddings_function (function): A function to calculate embeddings for a list of item names.
+            items (list): A list of dictionaries representing items (entities or relations) to which embeddings will be added.
+            property_name (str): The key under which embeddings will be stored.
+            embeddings_name (str): The name of the embeddings key.
+            item_name_key (str): The key name for the item's name.
+            embeddings (bool): A flag to determine whether to calculate embeddings.
+
+        Returns:
+            list: A list of dictionaries with added embeddings.
+        """
+        # Copy the items list and preprocess item names using a list comprehension
+        items = [
+            {
+                **item,
+                item_name_key: item[item_name_key].lower().replace("_", " ").replace("-", " "),
+                property_name: {}
+            } for item in items
+        ]
+
+        if embeddings:
+            # Prepare a list of all item names to calculate embeddings in one shot
+            item_names = [item[item_name_key] for item in items]
+            # Calculate embeddings for all item names at once
+            all_embeddings = embeddings_function(item_names)
+
+            # Use zip to efficiently assign embeddings to each item
+            items = [
+                {**item, property_name: {embeddings_name: embedding}}
+                for item, embedding in zip(items, all_embeddings)
+            ]
+
+        return items
+
 
     def process(self, data: dict, data_type: Literal['entity', 'relation']) -> dict:
         """
@@ -27,10 +67,10 @@ def process(self, data: dict, data_type: Literal['entity', 'relation']) -> dict:
             data["startNode"] = data["startNode"].lower()
             data["endNode"] = data["endNode"].lower()
-            # Replace spaces, dashes, periods, and '&' in names with underscores or 'and'.
-            data["name"] = data["name"].replace(" ", "_").replace("-", "_").replace(".", "_").replace("&", "and")
+            # Spell out '&' as 'and' first (the regex below would otherwise turn it into '_'), then replace every remaining non-alphanumeric character with '_'.
+            data["name"] = re.sub(r'[^a-zA-Z0-9]', '_', data["name"].replace("&", "and"))
         elif data_type == 'entity':
-            # Replace spaces, dashes, periods, and '&' in labels with underscores or 'and'.
-            data["label"] = data["label"].replace(" ", "_").replace("-", "_").replace(".", "_").replace("&", "and")
+            # Spell out '&' as 'and' first (the regex below would otherwise turn it into '_'), then replace every remaining non-alphanumeric character with '_'.
+            data["label"] = re.sub(r'[^a-zA-Z0-9]', '_', data["label"].replace("&", "and"))
 
         return data
 
@@ -139,4 +179,6 @@ def find_isolated_entities(self, global_entities: List[dict], relations: List[di
         """
         relation_nodes = set(rel["startNode"] for rel in relations) | set(rel["endNode"] for rel in relations)
         isolated_entities = [ent for ent in global_entities if ent["name"] not in relation_nodes]
-        return isolated_entities
\ No newline at end of file
+        return isolated_entities
+
+    
\ No newline at end of file