Skip to content

Commit

Permalink
fix: accelerate the node/rel embeddings by passing the matrix of node…
Browse files Browse the repository at this point in the history
…s/rel at once
  • Loading branch information
lairgiyassir committed Sep 16, 2024
1 parent 8135c4f commit bde7d2c
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions itext2kg/utils/data_handling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Literal, List, Tuple
import re

class DataHandler:
"""
Expand All @@ -9,6 +10,45 @@ class DataHandler:
def __init__(self):
"""Initialize the DataHandler instance."""
pass

def add_embeddings_as_property_batch(self, embeddings_function, items: List[dict], property_name="properties", embeddings_name="embeddings", item_name_key="name", embeddings=True):
    """
    Normalize item names and attach embeddings computed in one batch call.

    Every input dictionary is shallow-copied, so the dictionaries passed in
    are not modified. In each copy the name is lower-cased with underscores
    and dashes replaced by spaces, and the value stored under
    ``property_name`` is replaced by a fresh dictionary (anything previously
    stored under that key is not carried over).

    Args:
        embeddings_function (function): Called once with the list of all
            normalized names; must return one embedding per name, in order.
        items (list): Dictionaries (entities or relations) to process. Each
            must contain ``item_name_key`` mapped to a string.
        property_name (str): Key under which the embeddings dictionary is stored.
        embeddings_name (str): Key of the embedding inside that dictionary.
        item_name_key (str): Key holding the item's name.
        embeddings (bool): When False, names are normalized and
            ``property_name`` is set to ``{}`` but no embeddings are computed.

    Returns:
        list: New dictionaries, in input order, with normalized names and
        (when requested) ``{embeddings_name: embedding}`` under ``property_name``.

    Raises:
        ValueError: If ``embeddings_function`` returns a different number of
            embeddings than the number of items.
    """
    normalized = [
        {
            **item,
            item_name_key: item[item_name_key].lower().replace("_", " ").replace("-", " "),
            property_name: {},
        }
        for item in items
    ]

    # Nothing to embed: avoid calling the embedding backend with an empty batch.
    if not embeddings or not normalized:
        return normalized

    names = [item[item_name_key] for item in normalized]
    # Materialize so the result can be length-checked whatever sequence type
    # the backend returns (list, array rows, ...).
    vectors = list(embeddings_function(names))

    # zip() would silently drop trailing items on a short result; fail loudly instead.
    if len(vectors) != len(normalized):
        raise ValueError(
            f"embeddings_function returned {len(vectors)} embeddings "
            f"for {len(normalized)} items"
        )

    return [
        {**item, property_name: {embeddings_name: vector}}
        for item, vector in zip(normalized, vectors)
    ]


def process(self, data: dict, data_type: Literal['entity', 'relation']) -> dict:
"""
Expand All @@ -27,10 +67,10 @@ def process(self, data: dict, data_type: Literal['entity', 'relation']) -> dict:
data["startNode"] = data["startNode"].lower()
data["endNode"] = data["endNode"].lower()
# Replace spaces, dashes, periods, and '&' in names with underscores or 'and'.
data["name"] = data["name"].replace(" ", "_").replace("-", "_").replace(".", "_").replace("&", "and")
data["name"] = re.sub(r'[^a-zA-Z0-9]', '_', data["name"]).replace("&", "and")
elif data_type == 'entity':
# Replace spaces, dashes, periods, and '&' in labels with underscores or 'and'.
data["label"] = data["label"].replace(" ", "_").replace("-", "_").replace(".", "_").replace("&", "and")
data["label"] = re.sub(r'[^a-zA-Z0-9]', '_', data["label"]).replace("&", "and")

return data

Expand Down Expand Up @@ -139,4 +179,6 @@ def find_isolated_entities(self, global_entities: List[dict], relations: List[di
"""
relation_nodes = set(rel["startNode"] for rel in relations) | set(rel["endNode"] for rel in relations)
isolated_entities = [ent for ent in global_entities if ent["name"] not in relation_nodes]
return isolated_entities
return isolated_entities


0 comments on commit bde7d2c

Please sign in to comment.