Added token level embeddings for documents
Dicklesworthstone committed May 23, 2024
1 parent da5e74f commit 6aa3129
Showing 6 changed files with 426 additions and 150 deletions.
4 changes: 3 additions & 1 deletion database_functions.py
@@ -1,4 +1,4 @@
from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript
from embeddings_data_models import Base, TextEmbedding, DocumentEmbedding, DocumentTokenLevelEmbedding, Document, TokenLevelEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector, AudioTranscript
from logger_config import setup_logger
import traceback
import asyncio
@@ -36,6 +36,7 @@ def _get_hash_from_operation(self, operation):
attr_name = {
TextEmbedding: 'text_hash',
DocumentEmbedding: 'file_hash',
DocumentTokenLevelEmbedding: 'file_hash',
Document: 'document_hash',
TokenLevelEmbedding: 'word_hash',
TokenLevelEmbeddingBundle: 'input_text_hash',
@@ -81,6 +82,7 @@ async def _handle_integrity_error(self, e, write_operation, session):
unique_constraint_msg = {
TextEmbedding: "token_embeddings.text_hash, token_embeddings.llm_model_name",
DocumentEmbedding: "document_embeddings.file_hash, document_embeddings.llm_model_name",
DocumentTokenLevelEmbedding: "document_token_level_embeddings.file_hash, document_token_level_embeddings.llm_model_name",
Document: "documents.document_hash, documents.llm_model_name",
TokenLevelEmbedding: "token_level_embeddings.word_hash, token_level_embeddings.llm_model_name",
TokenLevelEmbeddingBundle: "token_level_embedding_bundles.input_text_hash, token_level_embedding_bundles.llm_model_name",
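
The body of _get_hash_from_operation is not shown in this hunk; a minimal sketch of how the class-to-attribute mapping is presumably consumed, where the getattr lookup is an assumption rather than code from the diff:

# Hypothetical continuation of _get_hash_from_operation (not part of this diff):
# resolve the hash attribute for the operation's ORM class, then read it off the instance.
attr = attr_name.get(type(operation))
return getattr(operation, attr) if attr else None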
83 changes: 57 additions & 26 deletions embeddings_data_models.py
@@ -2,8 +2,9 @@
from sqlalchemy.dialects.sqlite import JSON
from sqlalchemy.orm import declarative_base, relationship, validates
from hashlib import sha3_256
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, Field, field_validator
from typing import List, Optional, Union, Dict
from typing_extensions import Annotated
from decouple import config
from sqlalchemy import event
from sqlalchemy.ext.hybrid import hybrid_property
@@ -36,6 +37,29 @@ def update_text_hash(self, key, text):
self.text_hash = sha3_256(text.encode('utf-8')).hexdigest()
return text


class TokenLevelEmbedding(Base):
__tablename__ = "token_level_embeddings"
id = Column(Integer, primary_key=True, index=True)
word = Column(String, index=True)
word_hash = Column(String, index=True)
llm_model_name = Column(String, index=True)
token_level_embedding_json = Column(String)
ip_address = Column(String)
request_time = Column(DateTime)
response_time = Column(DateTime)
total_time = Column(Float)
document_file_hash = Column(String, ForeignKey('document_token_level_embeddings.file_hash'))
corpus_identifier_string = Column(String, index=True)
document = relationship("DocumentTokenLevelEmbedding", back_populates="token_level_embeddings", foreign_keys=[document_file_hash, corpus_identifier_string])
token_level_embedding_bundle_id = Column(Integer, ForeignKey('token_level_embedding_bundles.id'))
token_level_embedding_bundle = relationship("TokenLevelEmbeddingBundle", back_populates="token_level_embeddings")
__table_args__ = (UniqueConstraint('word_hash', 'llm_model_name', name='_word_hash_model_uc'),)
@validates('word')
def update_word_hash(self, key, word):
self.word_hash = sha3_256(word.encode('utf-8')).hexdigest()
return word

class DocumentEmbedding(Base):
__tablename__ = "document_embeddings"
id = Column(Integer, primary_key=True, index=True)
@@ -46,6 +70,7 @@ class DocumentEmbedding(Base):
llm_model_name = Column(String, index=True)
corpus_identifier_string = Column(String, index=True)
file_data = Column(LargeBinary) # To store the original file
sentences = Column(String)
document_embedding_results_json = Column(JSON) # To store the embedding results JSON
ip_address = Column(String)
request_time = Column(DateTime)
@@ -55,46 +80,51 @@ class DocumentEmbedding(Base):
__table_args__ = (UniqueConstraint('file_hash', 'llm_model_name', 'corpus_identifier_string', name='_file_hash_model_corpus_uc'),)
document = relationship("Document", back_populates="document_embeddings", foreign_keys=[document_hash, corpus_identifier_string])

class DocumentTokenLevelEmbedding(Base):
__tablename__ = "document_token_level_embeddings"
id = Column(Integer, primary_key=True, index=True)
document_hash = Column(String, ForeignKey('documents.document_hash'))
filename = Column(String)
mimetype = Column(String)
file_hash = Column(String, index=True)
llm_model_name = Column(String, index=True)
corpus_identifier_string = Column(String, index=True)
file_data = Column(LargeBinary) # To store the original file
sentences = Column(String)
document_embedding_results_json = Column(JSON) # To store the embedding results JSON
ip_address = Column(String)
request_time = Column(DateTime)
response_time = Column(DateTime)
total_time = Column(Float)
token_level_embeddings = relationship("TokenLevelEmbedding", back_populates="document", foreign_keys=[TokenLevelEmbedding.document_file_hash])
__table_args__ = (UniqueConstraint('file_hash', 'llm_model_name', 'corpus_identifier_string', name='_file_hash_model_corpus_uc'),)
document = relationship("Document", back_populates="document_token_level_embeddings", foreign_keys=[document_hash, corpus_identifier_string])

class Document(Base):
__tablename__ = "documents"
id = Column(Integer, primary_key=True, index=True)
llm_model_name = Column(String, index=True)
corpus_identifier_string = Column(String, index=True)
document_hash = Column(String, index=True)
document_embeddings = relationship("DocumentEmbedding", back_populates="document", foreign_keys=[DocumentEmbedding.document_hash])
document_token_level_embeddings = relationship("DocumentTokenLevelEmbedding", back_populates="document", foreign_keys=[DocumentTokenLevelEmbedding.document_hash])
def update_hash(self): # Concatenate specific attributes from the document_embeddings relationship
hash_data = "".join([emb.filename + emb.mimetype for emb in self.document_embeddings])
self.document_hash = sha3_256(hash_data.encode('utf-8')).hexdigest()

@event.listens_for(Document.document_embeddings, 'append')
def update_document_hash_on_append(target, value, initiator):
target.update_hash()

@event.listens_for(Document.document_embeddings, 'remove')
def update_document_hash_on_remove(target, value, initiator):
target.update_hash()
@event.listens_for(Document.document_token_level_embeddings, 'append')
def update_document_token_level_hash_on_append(target, value, initiator):
target.update_hash()
@event.listens_for(Document.document_token_level_embeddings, 'remove')
def update_document_token_level_hash_on_remove(target, value, initiator):
target.update_hash()
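
For orientation, a hedged sketch of how the new DocumentTokenLevelEmbedding model and the append listener above are expected to interact; the field values and the standalone usage are illustrative assumptions, not code from this commit:

from hashlib import sha3_256

doc = Document(llm_model_name="example-model", corpus_identifier_string="demo_corpus")
doc_tle = DocumentTokenLevelEmbedding(
    filename="report.txt",
    mimetype="text/plain",
    file_hash=sha3_256(b"raw file bytes").hexdigest(),
    llm_model_name="example-model",
    corpus_identifier_string="demo_corpus",
)
# Appending fires update_document_token_level_hash_on_append, which calls update_hash();
# note that update_hash() still iterates only document_embeddings, so token-level rows
# influence when the hash is recomputed but not what goes into it.
doc.document_token_level_embeddings.append(doc_tle)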

class TokenLevelEmbedding(Base):
__tablename__ = "token_level_embeddings"
id = Column(Integer, primary_key=True, index=True)
word = Column(String, index=True)
word_hash = Column(String, index=True)
llm_model_name = Column(String, index=True)
corpus_identifier_string = Column(String, index=True)
token_level_embedding_json = Column(String)
ip_address = Column(String)
request_time = Column(DateTime)
response_time = Column(DateTime)
total_time = Column(Float)
token_level_embedding_bundle_id = Column(Integer, ForeignKey('token_level_embedding_bundles.id'))
token_level_embedding_bundle = relationship("TokenLevelEmbeddingBundle", back_populates="token_level_embeddings")
__table_args__ = (UniqueConstraint('word_hash', 'llm_model_name', name='_word_hash_model_uc'),)
@validates('word')
def update_word_hash(self, key, word):
self.word_hash = sha3_256(word.encode('utf-8')).hexdigest()
return word

class TokenLevelEmbeddingBundle(Base):
__tablename__ = "token_level_embedding_bundles"
id = Column(Integer, primary_key=True, index=True)
@@ -154,9 +184,10 @@ def validate_similarity_measure(cls, value):

class SemanticSearchRequest(BaseModel):
query_text: str
number_of_most_similar_strings_to_return: Optional[int] = 10
llm_model_name: Optional[str] = DEFAULT_MODEL_NAME
corpus_identifier_string: Optional[str] = ""
number_of_most_similar_strings_to_return: int = 10
llm_model_name: str = DEFAULT_MODEL_NAME
corpus_identifier_string: str = ""
use_token_level_embeddings: Annotated[int, Field(ge=0, le=1)] = 0

class SemanticSearchResponse(BaseModel):
query_text: str
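
A hedged example of the updated request model in use; the query text and chosen values are illustrative, and the serving endpoint is not part of this diff:

request = SemanticSearchRequest(
    query_text="how are token level embeddings stored?",
    number_of_most_similar_strings_to_return=5,
    use_token_level_embeddings=1,  # 0/1 switch, constrained by Field(ge=0, le=1)
)
# llm_model_name and corpus_identifier_string fall back to DEFAULT_MODEL_NAME and ""
# when omitted, matching the defaults declared above.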
17 changes: 11 additions & 6 deletions misc_utility_functions.py
@@ -1,5 +1,5 @@
from logger_config import setup_logger
from embeddings_data_models import TextEmbedding, TokenLevelEmbeddingBundleCombinedFeatureVector
from embeddings_data_models import TextEmbedding, TokenLevelEmbeddingBundle, TokenLevelEmbeddingBundleCombinedFeatureVector
import socket
import os
import re
@@ -12,7 +12,6 @@
import faiss
from typing import Any
from database_functions import AsyncSessionLocal
from sqlalchemy import text as sql_text
from sqlalchemy import select
from collections import defaultdict
logger = setup_logger()
@@ -134,11 +133,16 @@ async def build_faiss_indexes(force_rebuild=False):
faiss_indexes = {}
token_faiss_indexes = {} # Separate FAISS indexes for token-level embeddings
associated_texts_by_model = defaultdict(list) # Create a dictionary to store associated texts by model name
associated_token_level_embeddings_by_model = defaultdict(list) # Create a dictionary to store associated token-level embeddings by model name
async with AsyncSessionLocal() as session:
# result = await session.execute(sql_text("SELECT llm_model_name, text, embedding_json FROM embeddings")) # Query regular embeddings
# token_result = await session.execute(sql_text("SELECT llm_model_name, input_text, combined_feature_vector_json FROM token_level_embedding_bundle_combined_feature_vectors")) # Query token-level embeddings
result = await session.execute(select(TextEmbedding.llm_model_name, TextEmbedding.text, TextEmbedding.embedding_json))
token_result = await session.execute(select(TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name, TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_json, TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle))
token_result = await session.execute(
select(
TokenLevelEmbeddingBundleCombinedFeatureVector.llm_model_name,
TokenLevelEmbeddingBundleCombinedFeatureVector.combined_feature_vector_json,
TokenLevelEmbeddingBundleCombinedFeatureVector.token_level_embedding_bundle,
).join(TokenLevelEmbeddingBundle)
)
embeddings_by_model = defaultdict(list)
token_embeddings_by_model = defaultdict(list)
for row in result.fetchall(): # Process regular embeddings
@@ -147,6 +151,7 @@ async def build_faiss_indexes(force_rebuild=False):
embeddings_by_model[llm_model_name].append((row[1], json.loads(row[2])))
for row in token_result.fetchall(): # Process token-level embeddings
llm_model_name = row[0]
associated_token_level_embeddings_by_model[llm_model_name].append(row[1]) # Store the associated token-level embeddings by model name
token_embeddings_by_model[llm_model_name].append(json.loads(row[2]))
for llm_model_name, embeddings in embeddings_by_model.items():
logger.info(f"Building Faiss index over embeddings for model {llm_model_name}...")
@@ -168,7 +173,7 @@ async def build_faiss_indexes(force_rebuild=False):
token_faiss_index.add(token_embeddings_combined_feature_vector)
token_faiss_indexes[llm_model_name] = token_faiss_index # Store the token-level index by model name
os.environ["FAISS_SETUP_DONE"] = "1"
return faiss_indexes, token_faiss_indexes, associated_texts_by_model
return faiss_indexes, token_faiss_indexes, associated_texts_by_model, associated_token_level_embeddings_by_model

def normalize_logprobs(avg_logprob, min_logprob, max_logprob):
range_logprob = max_logprob - min_logprob
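
Because build_faiss_indexes now returns four values instead of three, existing call sites that unpack a three-tuple will raise a ValueError; a hedged sketch of an updated caller (caller code is not shown in this diff):

(
    faiss_indexes,
    token_faiss_indexes,
    associated_texts_by_model,
    associated_token_level_embeddings_by_model,  # new: per-model lists built from the token-level query above
) = await build_faiss_indexes()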