commit 19b0563 (parent 9303416)
refactor dependencies

george1459 committed Sep 28, 2024

Showing 3 changed files with 42 additions and 24 deletions.
setup.py: 12 changes (9 additions, 3 deletions)

@@ -2,7 +2,7 @@
 
 # Package metadata
 name = "suql"
-version = "1.1.7a7"
+version = "1.1.7a8"
 description = "Structured and Unstructured Query Language (SUQL) Python API"
 author = "Shicheng Liu"
 author_email = "[email protected]"
@@ -18,15 +18,18 @@
     'Flask-Cors==4.0.0',
     'Flask-RESTful==0.3.10',
     'requests==2.31.0',
-    'spacy==3.6.0',
     'tiktoken==0.4.0',
     'psycopg2-binary==2.9.7',
     'pglast==5.3',
-    'FlagEmbedding~=1.2.5',
     'litellm==1.34.34',
     'platformdirs>=4.0.0'
 ]
 
+install_dev_requires = [
+    'spacy==3.6.0',
+    'FlagEmbedding~=1.2.5',
+]
+
 # Additional package information
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
@@ -49,6 +52,9 @@
     packages=packages,
     package_dir={"": "src"},
     install_requires=install_requires,
+    extras_require={
+        "dev": install_dev_requires
+    },
     url=url,
     classifiers=classifiers,
     package_data={
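With spacy and FlagEmbedding moved out of install_requires into a "dev" extra, a plain pip install suql stays lightweight; users who need spacy-based chunking or local FlagEmbedding models would opt in with pip install "suql[dev]" (assuming the distribution is published on PyPI under the name suql). The version bump to 1.1.7a8 in the same file is consistent with publishing this change as a new pre-release.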
src/suql/faiss_embedding.py: 48 changes (30 additions, 18 deletions)

@@ -3,10 +3,8 @@
 from collections import OrderedDict
 
 import os
-import faiss
 import hashlib
 import pickle
-from FlagEmbedding import FlagModel
 from flask import Flask, request
 from tqdm import tqdm
 from platformdirs import user_cache_dir
@@ -21,20 +19,21 @@
 # number of rows to consider for multi-column operations
 MULTIPLE_COLUMN_SEL = 1000
 
-# currently using https://huggingface.co/BAAI/bge-large-en-v1.5
-# change this line for custom embedding model
-model = FlagModel(
-    "BAAI/bge-large-en-v1.5",
-    query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
-    use_fp16=True,
-)  # Setting use_fp16 to True speeds up computation with a slight performance degradation
 
 
 def embed_query(query):
     """
     Embed a query for dot product matching
     """
-    # change this line for custom embedding model
+    # currently using https://huggingface.co/BAAI/bge-large-en-v1.5
+    # change this line for custom embedding model
+    from FlagEmbedding import FlagModel
+
+    model = FlagModel(
+        "BAAI/bge-large-en-v1.5",
+        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
+        use_fp16=True,
+    )  # Setting use_fp16 to True speeds up computation with a slight performance degradation
     q_embedding = model.encode_queries([query])
     return q_embedding
 
@@ -44,6 +43,15 @@ def embed_documents(documents):
     Embed a list of documents to store in vector store
     """
     # change this line for custom embedding model
+    # currently using https://huggingface.co/BAAI/bge-large-en-v1.5
+    # change this line for custom embedding model
+    from FlagEmbedding import FlagModel
+
+    model = FlagModel(
+        "BAAI/bge-large-en-v1.5",
+        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
+        use_fp16=True,
+    )  # Setting use_fp16 to True speeds up computation with a slight performance degradation
     embeddings = model.encode(documents)
     return embeddings
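Deferring the FlagEmbedding import keeps the heavy model dependency optional, but as committed both embed_query and embed_documents construct a fresh FlagModel on every call, reloading the weights each time. A minimal sketch of one way to amortize that cost while keeping the import lazy (illustrative only, not part of this commit; the loader name is hypothetical):

from functools import lru_cache


@lru_cache(maxsize=1)
def _load_flag_model():
    # Import lazily so FlagEmbedding stays an optional dependency;
    # cache the instance so the weights load once per process.
    from FlagEmbedding import FlagModel

    return FlagModel(
        "BAAI/bge-large-en-v1.5",
        query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
        use_fp16=True,
    )


def embed_query(query):
    return _load_flag_model().encode_queries([query])

The same cached loader would serve embed_documents via _load_flag_model().encode(documents).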

Expand Down Expand Up @@ -90,6 +98,8 @@ def __len__(self):


def compute_top_similarity_documents(documents, query, chunking_param=0, top=3):
import faiss

"""
Directly call the model to compute the top documents based on
dot product with query
@@ -140,6 +150,8 @@ def __init__(
         cache_embedding=True,
         force_recompute=False
     ) -> None:
+        import faiss
+        self.faiss = faiss
         # stores three lists:
         # 1. PSQL primary key for each row
         # 2. list of strings in this field
@@ -257,18 +269,18 @@ def initialize_embedding(self):
         if (os.path.exists(faiss_cache_location) and not self.force_recompute):
             try:
                 print(f"initializing from existing faiss embedding index at {faiss_cache_location}")
-                self.embeddings = faiss.read_index(faiss_cache_location)
+                self.embeddings = self.faiss.read_index(faiss_cache_location)
                 return
             except Exception:
                 print(f"reading {faiss_cache_location} failed. Re-computing embeddings")
 
-        self.embeddings = faiss.IndexFlatIP(EMBEDDING_DIMENSION)
+        self.embeddings = self.faiss.IndexFlatIP(EMBEDDING_DIMENSION)
         indexs = embed_documents(self.chunked_text)
         self.embeddings.add(indexs)
 
         print(f"writing computed faiss embedding to {faiss_cache_location}")
         os.makedirs(_user_cache_dir, exist_ok=True)
-        faiss.write_index(self.embeddings, faiss_cache_location)
+        self.faiss.write_index(self.embeddings, faiss_cache_location)
 
     def dot_product(self, id_list, query, top, individual_id_list=[]):
         # given a list of id and a particular query, return the top ids and documents according to similarity score ranking
@@ -294,18 +306,18 @@ def dot_product(self, id_list, query, top, individual_id_list=[]):
 
         query_embedding = embed_query(query)
 
-        sel = faiss.IDSelectorBatch(embedding_indices)
+        sel = self.faiss.IDSelectorBatch(embedding_indices)
         if top < 0:
             D, I = self.embeddings.search(
                 query_embedding,
                 len(embedding_indices),
-                params=faiss.SearchParametersIVF(sel=sel),
+                params=self.faiss.SearchParametersIVF(sel=sel),
             )
         else:
             if top > min(self.embeddings.ntotal, len(embedding_indices)):
                 top = min(self.embeddings.ntotal, len(embedding_indices))
             D, I = self.embeddings.search(
-                query_embedding, top, params=faiss.SearchParametersIVF(sel=sel)
+                query_embedding, top, params=self.faiss.SearchParametersIVF(sel=sel)
             )
 
         embeddings_indices_max = I[0]
@@ -364,11 +376,11 @@ def dot_product_with_value(self, id_list, query, individual_id_list=[]):
         # this is actually a 2-D array, matching what faiss expects
         query_embedding = embed_query(query)
 
-        sel = faiss.IDSelectorBatch(embedding_indices)
+        sel = self.faiss.IDSelectorBatch(embedding_indices)
         D, I = self.embeddings.search(
             query_embedding,
             MULTIPLE_COLUMN_SEL,
-            params=faiss.SearchParametersIVF(sel=sel),
+            params=self.faiss.SearchParametersIVF(sel=sel),
         )
         embedding_indices = I[0]
         dot_products = D[0]
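Unlike the module-level embedding functions, the index class imports faiss once in __init__ and stores the module on the instance, so the remaining methods only need the mechanical faiss. to self.faiss. rename rather than repeating a local import in initialize_embedding, dot_product, and dot_product_with_value; the import cost is paid the first time an index object is created instead of at module import.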
src/suql/utils.py: 6 changes (3 additions, 3 deletions)

@@ -1,6 +1,3 @@
-import spacy
-
-nlp = spacy.load("en_core_web_sm")
 import hashlib
 
 import tiktoken
@@ -52,6 +49,9 @@ def chunk_text(text, k, use_spacy=True):
     if text == "":
         return [""]
     # in case of using spacy, k is the minimum number of words per chunk
+    import spacy
+    nlp = spacy.load("en_core_web_sm")
+
     chunks = [i.text for i in nlp(text).sents]
     res = []
     carryover = ""
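One caveat with the utils.py change: as committed, chunk_text re-imports spacy and reloads the en_core_web_sm pipeline on every spacy-based call, and spacy.load is expensive. A minimal sketch of a module-level cache that keeps the import lazy but loads the pipeline only once (illustrative, not part of this commit; the helper name is hypothetical):

_nlp = None


def _get_nlp():
    # Import and load lazily on first use, then reuse the cached pipeline.
    global _nlp
    if _nlp is None:
        import spacy
        _nlp = spacy.load("en_core_web_sm")
    return _nlp

chunk_text would then compute sentences with _get_nlp()(text).sents in place of the per-call spacy.load.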
