Skip to content

Commit

Permalink
ability to switch
Browse files Browse the repository at this point in the history
  • Loading branch information
wjayesh committed Nov 14, 2024
1 parent 62d3e01 commit 4fce90c
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 64 deletions.
4 changes: 3 additions & 1 deletion llm-complete-guide/configs/dev/rag.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
enable_cache: False
enable_cache: True

# environment configuration
settings:
Expand Down Expand Up @@ -29,3 +29,5 @@ steps:
parameters:
docs_url: https://docs.zenml.io/
use_dev_set: true
index_generator:
enable_cache: False
1 change: 1 addition & 0 deletions llm-complete-guide/configs/dev/rag_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ settings:
- psycopg2-binary
- tiktoken
- pygithub
- elasticsearch
python_package_installer: "uv"
1 change: 1 addition & 0 deletions llm-complete-guide/configs/production/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ settings:
- matplotlib
- pillow
- pygithub
- elasticsearch
environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
ZENML_ENABLE_RICH_TRACEBACK: FALSE
Expand Down
1 change: 1 addition & 0 deletions llm-complete-guide/configs/staging/eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ settings:
- matplotlib
- pillow
- pygithub
- elasticsearch
environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
ZENML_ENABLE_RICH_TRACEBACK: FALSE
Expand Down
2 changes: 1 addition & 1 deletion llm-complete-guide/configs/staging/rag.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ settings:
- rerankers[flashrank]
- matplotlib
- elasticsearch

environment:
ZENML_PROJECT_SECRET_NAME: llm_complete
ZENML_ENABLE_RICH_TRACEBACK: FALSE
Expand Down
3 changes: 3 additions & 0 deletions llm-complete-guide/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
384 # Update this to match the dimensionality of the new model
)

# ZenML constants
ZENML_CHATBOT_MODEL = "zenml-docs-qa-chatbot"

# Scraping constants
RATE_LIMIT = 5 # Maximum number of requests per second

Expand Down
2 changes: 1 addition & 1 deletion llm-complete-guide/pipelines/llm_basic_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from zenml import pipeline


@pipeline
@pipeline(enable_cache=True)
def llm_basic_rag() -> None:
"""Executes the pipeline to train a basic RAG model.
Expand Down
17 changes: 15 additions & 2 deletions llm-complete-guide/steps/eval_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from datasets import load_dataset
from utils.llm_utils import (
find_vectorstore_name,
get_db_conn,
get_embeddings,
get_es_client,
Expand Down Expand Up @@ -77,11 +78,23 @@ def query_similar_docs(
Tuple containing the question, URL ending, and retrieved URLs.
"""
embedded_question = get_embeddings(question)
es_client = get_es_client()
conn = None
es_client = None

vector_store_name = find_vectorstore_name()
if vector_store_name == "pgvector":
conn = get_db_conn()
else:
es_client = get_es_client()

num_docs = 20 if use_reranking else returned_sample_size
# get (content, url) tuples for the top n similar documents
top_similar_docs = get_topn_similar_docs(
embedded_question, es_client, n=num_docs, include_metadata=True
embedded_question,
conn=conn,
es_client=es_client,
n=num_docs,
include_metadata=True
)

if use_reranking:
Expand Down
196 changes: 152 additions & 44 deletions llm-complete-guide/steps/populate_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
# https://www.timescale.com/blog/postgresql-as-a-vector-database-create-store-and-query-openai-embeddings-with-pgvector/
# for providing the base implementation for this indexing functionality

import hashlib
import json
import logging
import math
from typing import Annotated, Any, Dict, List, Tuple
from enum import Enum

from constants import (
CHUNK_OVERLAP,
CHUNK_SIZE,
EMBEDDING_DIMENSIONALITY,
EMBEDDINGS_MODEL,
SECRET_NAME_ELASTICSEARCH,
ZENML_CHATBOT_MODEL,
)
from pgvector.psycopg2 import register_vector
from PIL import Image, ImageDraw, ImageFont
Expand Down Expand Up @@ -593,9 +596,14 @@ def generate_embeddings(
raise


class IndexType(Enum):
    """Vector store backends that the index generation step can target."""

    # Selects the Elasticsearch indexing path.
    ELASTICSEARCH = "elasticsearch"
    # Selects the PostgreSQL indexing path.
    POSTGRES = "postgres"

@step(enable_cache=False)
def index_generator(
documents: str,
index_type: IndexType = IndexType.ELASTICSEARCH,
) -> None:
"""Generates an index for the given documents.
Expand All @@ -606,14 +614,23 @@ def index_generator(
Args:
documents (str): A JSON string containing the Document objects with generated embeddings.
index_type (IndexType): The type of index to use. Defaults to Elasticsearch.
Raises:
Exception: If an error occurs during the index generation.
"""
from elasticsearch import Elasticsearch
from elasticsearch.helpers import bulk
import hashlib

try:
if index_type == IndexType.ELASTICSEARCH:
_index_generator_elastic(documents)
else:
_index_generator_postgres(documents)

except Exception as e:
logger.error(f"Error in index_generator: {e}")
raise

def _index_generator_elastic(documents: str) -> None:
"""Generates an Elasticsearch index for the given documents."""
try:
es = get_es_client()
index_name = "zenml_docs"
Expand Down Expand Up @@ -643,16 +660,13 @@ def index_generator(

# Parse the JSON string into a list of Document objects
document_list = [Document(**doc) for doc in json.loads(documents)]

# Prepare bulk operations
operations = []

for doc in document_list:
# Create a unique identifier based on content and metadata
content_hash = hashlib.md5(
f"{doc.page_content}{doc.filename}{doc.parent_section}{doc.url}".encode()
).hexdigest()

# Check if document exists
exists_query = {
"query": {
"term": {
Expand Down Expand Up @@ -694,45 +708,139 @@ def index_generator(
else:
logger.info("No new documents to index")

# Log the model metadata
prompt = """
You are a friendly chatbot. \
You can answer questions about ZenML, its features and its use cases. \
You respond in a concise, technically credible tone. \
You ONLY use the context from the ZenML documentation to provide relevant
answers. \
You do not make up answers or provide opinions that you don't have
information to support. \
If you are unsure or don't know, just say so. \
"""

client = Client()
_log_metadata(index_type=IndexType.ELASTICSEARCH)

except Exception as e:
logger.error(f"Error in Elasticsearch indexing: {e}")
raise

def _index_generator_postgres(documents: str) -> None:
    """Generates a PostgreSQL index for the given documents.

    Ensures the ``vector`` extension and the ``embeddings`` table exist,
    inserts every document whose content is not already stored, creates an
    IVFFlat cosine-distance index over the embeddings, and finally logs the
    model metadata for the pgvector store.

    Args:
        documents (str): A JSON string containing the Document objects with
            generated embeddings.

    Raises:
        Exception: Any error raised while connecting, inserting or indexing
            is logged and re-raised.
    """
    # Bind the name before the ``try`` so that the ``finally`` clause can
    # safely inspect it even when ``get_db_conn()`` itself fails.
    conn = None
    try:
        conn = get_db_conn()

        with conn.cursor() as cur:
            # Install pgvector if not already installed
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
            conn.commit()

            # Create the embeddings table if it doesn't exist
            table_create_command = f"""
            CREATE TABLE IF NOT EXISTS embeddings (
                        id SERIAL PRIMARY KEY,
                        content TEXT,
                        token_count INTEGER,
                        embedding VECTOR({EMBEDDING_DIMENSIONALITY}),
                        filename TEXT,
                        parent_section TEXT,
                        url TEXT
                        );
                        """
            cur.execute(table_create_command)
            conn.commit()

            register_vector(conn)

            # Parse the JSON string into a list of Document objects
            document_list = [Document(**doc) for doc in json.loads(documents)]

            # Insert data only if it doesn't already exist
            for doc in document_list:
                content = doc.page_content

                # A document counts as "already indexed" when a row with the
                # exact same content is present.
                cur.execute(
                    "SELECT COUNT(*) FROM embeddings WHERE content = %s",
                    (content,),
                )
                count = cur.fetchone()[0]
                if count == 0:
                    cur.execute(
                        "INSERT INTO embeddings (content, token_count, embedding, filename, parent_section, url) VALUES (%s, %s, %s, %s, %s, %s)",
                        (
                            content,
                            doc.token_count,
                            doc.embedding,
                            doc.filename,
                            doc.parent_section,
                            doc.url,
                        ),
                    )
                    conn.commit()

            cur.execute("SELECT COUNT(*) as cnt FROM embeddings;")
            num_records = cur.fetchone()[0]
            logger.info(f"Number of vector records in table: {num_records}")

            # calculate the index parameters according to best practices:
            # rows / 1000 (at least 10) up to 1M rows, sqrt(rows) above that.
            # ``lists`` is an integer storage parameter, so the value must be
            # an int; a float such as ``15.3`` would be interpolated into the
            # statement verbatim.
            if num_records > 1000000:
                num_lists = int(math.sqrt(num_records))
            else:
                num_lists = max(num_records // 1000, 10)

            # use the cosine distance measure, which is what we'll later use for querying
            cur.execute(
                f"CREATE INDEX IF NOT EXISTS embeddings_idx ON embeddings USING ivfflat (embedding vector_cosine_ops) WITH (lists = {num_lists});"
            )
            conn.commit()

        _log_metadata(index_type=IndexType.POSTGRES)

    except Exception as e:
        logger.error(f"Error in PostgreSQL indexing: {e}")
        raise
    finally:
        # ``conn`` is still None when the connection could not be opened.
        if conn is not None:
            conn.close()

def _log_metadata(index_type: IndexType) -> None:
"""Log metadata about the indexing process."""
prompt = """
You are a friendly chatbot. \
You can answer questions about ZenML, its features and its use cases. \
You respond in a concise, technically credible tone. \
You ONLY use the context from the ZenML documentation to provide relevant answers. \
You do not make up answers or provide opinions that you don't have information to support. \
If you are unsure or don't know, just say so. \
"""

client = Client()

if index_type == IndexType.ELASTICSEARCH:
es_host = client.get_secret(SECRET_NAME_ELASTICSEARCH).secret_values["elasticsearch_host"]
CONNECTION_DETAILS = {
connection_details = {
"host": es_host,
"api_key": "*********",
}
store_name = "elasticsearch"
else:
store_name = "pgvector"

connection_details = {
"user": client.get_secret(SECRET_NAME).secret_values["supabase_user"],
"password": "**********",
"host": client.get_secret(SECRET_NAME).secret_values["supabase_host"],
"port": client.get_secret(SECRET_NAME).secret_values["supabase_port"],
"dbname": "postgres",
}

log_model_metadata(
metadata={
"embeddings": {
"model": EMBEDDINGS_MODEL,
"dimensionality": EMBEDDING_DIMENSIONALITY,
"model_url": Uri(
f"https://huggingface.co/{EMBEDDINGS_MODEL}"
),
},
"prompt": {
"content": prompt,
},
"vector_store": {
"name": "elasticsearch",
"connection_details": CONNECTION_DETAILS,
"index_name": index_name
},
log_model_metadata(
metadata={
"embeddings": {
"model": EMBEDDINGS_MODEL,
"dimensionality": EMBEDDING_DIMENSIONALITY,
"model_url": Uri(f"https://huggingface.co/{EMBEDDINGS_MODEL}"),
},
)

except Exception as e:
logger.error(f"Error in index_generator: {e}")
raise
"prompt": {
"content": prompt,
},
"vector_store": {
"name": store_name,
"connection_details": connection_details,
},
},
)
Loading

0 comments on commit 4fce90c

Please sign in to comment.