Skip to content

Commit

Permalink
Merge pull request #155 from zenml-io/feature/elasticsearch-llm-complete
Browse files Browse the repository at this point in the history
Add elastic
  • Loading branch information
htahir1 authored Nov 14, 2024
2 parents 99f97c3 + 2f59199 commit cf029d1
Show file tree
Hide file tree
Showing 12 changed files with 328 additions and 60 deletions.
2 changes: 1 addition & 1 deletion llm-complete-guide/ZENML_VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v0.68.1
v0.70.0
1 change: 1 addition & 0 deletions llm-complete-guide/configs/dev/rag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ settings:
- pygithub
- rerankers[flashrank]
- matplotlib
- elasticsearch

environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
Expand Down
1 change: 1 addition & 0 deletions llm-complete-guide/configs/dev/rag_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ settings:
- psycopg2-binary
- tiktoken
- pygithub
- elasticsearch
python_package_installer: "uv"
1 change: 1 addition & 0 deletions llm-complete-guide/configs/production/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ settings:
- matplotlib
- pillow
- pygithub
- elasticsearch
environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
ZENML_ENABLE_RICH_TRACEBACK: FALSE
Expand Down
2 changes: 2 additions & 0 deletions llm-complete-guide/configs/production/rag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ settings:
- pygithub
- rerankers[flashrank]
- matplotlib
- elasticsearch

environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
ZENML_ENABLE_RICH_TRACEBACK: FALSE
Expand Down
1 change: 1 addition & 0 deletions llm-complete-guide/configs/staging/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ settings:
- matplotlib
- pillow
- pygithub
- elasticsearch
environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
ZENML_ENABLE_RICH_TRACEBACK: FALSE
Expand Down
1 change: 1 addition & 0 deletions llm-complete-guide/configs/staging/rag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ settings:
- pygithub
- rerankers[flashrank]
- matplotlib
- elasticsearch

environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
Expand Down
4 changes: 4 additions & 0 deletions llm-complete-guide/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
384 # Update this to match the dimensionality of the new model
)

# ZenML constants
ZENML_CHATBOT_MODEL = "zenml-docs-qa-chatbot"

# Scraping constants
RATE_LIMIT = 5 # Maximum number of requests per second

Expand Down Expand Up @@ -78,3 +81,4 @@
USE_ARGILLA_ANNOTATIONS = False

SECRET_NAME = os.getenv("ZENML_PROJECT_SECRET_NAME", "llm-complete")
SECRET_NAME_ELASTICSEARCH = "elasticsearch-zenml"
3 changes: 2 additions & 1 deletion llm-complete-guide/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
zenml[server]>=0.68.1
zenml[server]==0.68.1
ratelimit
pgvector
psycopg2-binary
Expand All @@ -20,6 +20,7 @@ datasets
torch
gradio
huggingface-hub
elasticsearch

# optional requirements for S3 artifact store
# s3fs>2022.3.0
Expand Down
18 changes: 16 additions & 2 deletions llm-complete-guide/steps/eval_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

from datasets import load_dataset
from utils.llm_utils import (
find_vectorstore_name,
get_db_conn,
get_embeddings,
get_es_client,
get_topn_similar_docs,
rerank_documents,
)
Expand Down Expand Up @@ -76,11 +78,23 @@ def query_similar_docs(
Tuple containing the question, URL ending, and retrieved URLs.
"""
embedded_question = get_embeddings(question)
db_conn = get_db_conn()
conn = None
es_client = None

vector_store_name = find_vectorstore_name()
if vector_store_name == "pgvector":
conn = get_db_conn()
else:
es_client = get_es_client()

num_docs = 20 if use_reranking else returned_sample_size
# get (content, url) tuples for the top n similar documents
top_similar_docs = get_topn_similar_docs(
embedded_question, db_conn, n=num_docs, include_metadata=True
embedded_question,
conn=conn,
es_client=es_client,
n=num_docs,
include_metadata=True
)

if use_reranking:
Expand Down
194 changes: 154 additions & 40 deletions llm-complete-guide/steps/populate_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,26 @@
# https://www.timescale.com/blog/postgresql-as-a-vector-database-create-store-and-query-openai-embeddings-with-pgvector/
# for providing the base implementation for this indexing functionality

import hashlib
import json
import logging
import math
from typing import Annotated, Any, Dict, List, Tuple
from enum import Enum

from constants import (
CHUNK_OVERLAP,
CHUNK_SIZE,
EMBEDDING_DIMENSIONALITY,
EMBEDDINGS_MODEL,
SECRET_NAME_ELASTICSEARCH,
ZENML_CHATBOT_MODEL,
)
from pgvector.psycopg2 import register_vector
from PIL import Image, ImageDraw, ImageFont
from sentence_transformers import SentenceTransformer
from structures import Document
from utils.llm_utils import get_db_conn, split_documents
from utils.llm_utils import get_db_conn, get_es_client, split_documents
from zenml import ArtifactConfig, log_artifact_metadata, step, log_model_metadata
from zenml.metadata.metadata_types import Uri
from zenml.client import Client
Expand Down Expand Up @@ -592,9 +596,14 @@ def generate_embeddings(
raise


@step
class IndexType(Enum):
ELASTICSEARCH = "elasticsearch"
POSTGRES = "postgres"

@step(enable_cache=False)
def index_generator(
documents: str,
index_type: IndexType = IndexType.ELASTICSEARCH,
) -> None:
"""Generates an index for the given documents.
Expand All @@ -605,13 +614,111 @@ def index_generator(
Args:
documents (str): A JSON string containing the Document objects with generated embeddings.
index_type (IndexType): The type of index to use. Defaults to Elasticsearch.
Raises:
Exception: If an error occurs during the index generation.
"""
conn = None
try:
if index_type == IndexType.ELASTICSEARCH:
_index_generator_elastic(documents)
else:
_index_generator_postgres(documents)

except Exception as e:
logger.error(f"Error in index_generator: {e}")
raise

def _index_generator_elastic(documents: str) -> None:
"""Generates an Elasticsearch index for the given documents."""
try:
es = get_es_client()
index_name = "zenml_docs"

# Create index with mappings if it doesn't exist
if not es.indices.exists(index=index_name):
mappings = {
"mappings": {
"properties": {
"doc_id": {"type": "keyword"},
"content": {"type": "text"},
"token_count": {"type": "integer"},
"embedding": {
"type": "dense_vector",
"dims": EMBEDDING_DIMENSIONALITY,
"index": True,
"similarity": "cosine"
},
"filename": {"type": "text"},
"parent_section": {"type": "text"},
"url": {"type": "text"}
}
}
}
# TODO move to using mappings param directly
es.indices.create(index=index_name, body=mappings)

# Parse the JSON string into a list of Document objects
document_list = [Document(**doc) for doc in json.loads(documents)]
operations = []

for doc in document_list:
content_hash = hashlib.md5(
f"{doc.page_content}{doc.filename}{doc.parent_section}{doc.url}".encode()
).hexdigest()

exists_query = {
"query": {
"term": {
"doc_id": content_hash
}
}
}

if not es.count(index=index_name, body=exists_query)["count"]:
operations.append({
"index": {
"_index": index_name,
"_id": content_hash
}
})

operations.append({
"doc_id": content_hash,
"content": doc.page_content,
"token_count": doc.token_count,
"embedding": doc.embedding,
"filename": doc.filename,
"parent_section": doc.parent_section,
"url": doc.url
})

if operations:
response = es.bulk(operations=operations, timeout="10m")

success_count = sum(1 for item in response['items'] if 'index' in item and item['index']['status'] == 201)
failed_count = len(response['items']) - success_count

logger.info(f"Successfully indexed {success_count} documents")
if failed_count > 0:
logger.warning(f"Failed to index {failed_count} documents")
for item in response['items']:
if 'index' in item and item['index']['status'] != 201:
logger.warning(f"Failed to index document: {item['index']['error']}")
else:
logger.info("No new documents to index")

_log_metadata(index_type=IndexType.ELASTICSEARCH)

except Exception as e:
logger.error(f"Error in Elasticsearch indexing: {e}")
raise

def _index_generator_postgres(documents: str) -> None:
"""Generates a PostgreSQL index for the given documents."""
try:
conn = get_db_conn()

with conn.cursor() as cur:
# Install pgvector if not already installed
cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
Expand All @@ -633,7 +740,7 @@ def index_generator(
conn.commit()

register_vector(conn)

# Parse the JSON string into a list of Document objects
document_list = [Document(**doc) for doc in json.loads(documents)]

Expand Down Expand Up @@ -665,6 +772,7 @@ def index_generator(
)
conn.commit()


cur.execute("SELECT COUNT(*) as cnt FROM embeddings;")
num_records = cur.fetchone()[0]
logger.info(f"Number of vector records in table: {num_records}")
Expand All @@ -680,53 +788,59 @@ def index_generator(
)
conn.commit()

_log_metadata(index_type=IndexType.POSTGRES)

except Exception as e:
logger.error(f"Error in index_generator: {e}")
logger.error(f"Error in PostgreSQL indexing: {e}")
raise
finally:
if conn:
conn.close()

# Log the model metadata
prompt = """
You are a friendly chatbot. \
You can answer questions about ZenML, its features and its use cases. \
You respond in a concise, technically credible tone. \
You ONLY use the context from the ZenML documentation to provide relevant
answers. \
You do not make up answers or provide opinions that you don't have
information to support. \
If you are unsure or don't know, just say so. \
"""

client = Client()
CONNECTION_DETAILS = {
def _log_metadata(index_type: IndexType) -> None:
"""Log metadata about the indexing process."""
prompt = """
You are a friendly chatbot. \
You can answer questions about ZenML, its features and its use cases. \
You respond in a concise, technically credible tone. \
You ONLY use the context from the ZenML documentation to provide relevant answers. \
You do not make up answers or provide opinions that you don't have information to support. \
If you are unsure or don't know, just say so. \
"""

client = Client()

if index_type == IndexType.ELASTICSEARCH:
es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
connection_details = {
"host": es_host,
"api_key": "*********",
}
store_name = "elasticsearch"
else:
store_name = "pgvector"

connection_details = {
"user": client.get_secret(SECRET_NAME).secret_values["supabase_user"],
"password": "**********",
"host": client.get_secret(SECRET_NAME).secret_values["supabase_host"],
"port": client.get_secret(SECRET_NAME).secret_values["supabase_port"],
"dbname": "postgres",
}

log_model_metadata(
metadata={
"embeddings": {
"model": EMBEDDINGS_MODEL,
"dimensionality": EMBEDDING_DIMENSIONALITY,
"model_url": Uri(
f"https://huggingface.co/{EMBEDDINGS_MODEL}"
),
},
"prompt": {
"content": prompt,
},
"vector_store": {
"name": "pgvector",
"connection_details": CONNECTION_DETAILS,
# TODO: Hard-coded for now
"database_url": Uri(
"https://supabase.com/dashboard/project/rkoiacgkeiwpwceahtlp/editor/29505?schema=public"
),
},
log_model_metadata(
metadata={
"embeddings": {
"model": EMBEDDINGS_MODEL,
"dimensionality": EMBEDDING_DIMENSIONALITY,
"model_url": Uri(f"https://huggingface.co/{EMBEDDINGS_MODEL}"),
},
)
"prompt": {
"content": prompt,
},
"vector_store": {
"name": store_name,
"connection_details": connection_details,
},
},
)
Loading

0 comments on commit cf029d1

Please sign in to comment.