diff --git a/llm-complete-guide/.assets/rag-pipeline-zenml-cloud.png b/llm-complete-guide/.assets/rag-pipeline-zenml-cloud.png new file mode 100644 index 00000000..8d008e75 Binary files /dev/null and b/llm-complete-guide/.assets/rag-pipeline-zenml-cloud.png differ diff --git a/llm-complete-guide/.assets/supabase-connection-string.png b/llm-complete-guide/.assets/supabase-connection-string.png new file mode 100644 index 00000000..d5759d02 Binary files /dev/null and b/llm-complete-guide/.assets/supabase-connection-string.png differ diff --git a/llm-complete-guide/.assets/supabase-create-project.png b/llm-complete-guide/.assets/supabase-create-project.png new file mode 100644 index 00000000..11d39b2b Binary files /dev/null and b/llm-complete-guide/.assets/supabase-create-project.png differ diff --git a/llm-complete-guide/.assets/tsne.png b/llm-complete-guide/.assets/tsne.png new file mode 100644 index 00000000..ed6df64a Binary files /dev/null and b/llm-complete-guide/.assets/tsne.png differ diff --git a/llm-complete-guide/.assets/umap.png b/llm-complete-guide/.assets/umap.png new file mode 100644 index 00000000..adbc43d1 Binary files /dev/null and b/llm-complete-guide/.assets/umap.png differ diff --git a/llm-complete-guide/.dockerignore b/llm-complete-guide/.dockerignore new file mode 100644 index 00000000..496552c8 --- /dev/null +++ b/llm-complete-guide/.dockerignore @@ -0,0 +1,9 @@ +* +!/pipelines/** +!/steps/** +!/materializers/** +!/evaluate/** +!/finetune/** +!/generate/** +!/lit_gpt/** +!/scripts/** diff --git a/llm-complete-guide/LICENSE b/llm-complete-guide/LICENSE new file mode 100644 index 00000000..75d01fb4 --- /dev/null +++ b/llm-complete-guide/LICENSE @@ -0,0 +1,15 @@ +Apache Software License 2.0 + +Copyright (c) ZenML GmbH 2024. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/llm-complete-guide/README.md b/llm-complete-guide/README.md new file mode 100644 index 00000000..9c7e08d4 --- /dev/null +++ b/llm-complete-guide/README.md @@ -0,0 +1,144 @@ +# 🦜 Production-ready RAG pipelines for chat applications + +This project showcases how you can work up from a simple RAG pipeline to a more complex setup that +involves finetuning embeddings, reranking retrieved documents, and even finetuning the +LLM itself. We'll do this all for a use case relevant to ZenML: a question +answering system that can provide answers to common questions about ZenML. This +will help you understand how to apply the concepts covered in this guide to your +own projects. + +![](.assets/rag-pipeline-zenml-cloud.png) + +Contained within this project is all the code needed to run the full pipelines. +You can follow along [in our guide](https://docs.zenml.io/user-guide/llmops-guide/) to understand the decisions and tradeoffs +behind the pipeline and step code contained here. You'll build a solid understanding of how to leverage +LLMs in your MLOps workflows using ZenML, enabling you to build powerful, +scalable, and maintainable LLM-powered applications. 
+
+This project contains all the pipeline and step code necessary to follow along
+with the guide. You'll need a PostgreSQL database to store the embeddings; full
+instructions are provided below for how to set that up.
+
+## 🙏🏻 Inspiration and Credit
+
+The RAG pipeline relies on code from [this Timescale
+blog](https://www.timescale.com/blog/postgresql-as-a-vector-database-create-store-and-query-openai-embeddings-with-pgvector/)
+that showcased using PostgreSQL as a vector database. We adapted it for our use
+case and extended it to work with Supabase.
+
+## 🏃 How to run
+
+This project showcases production-ready pipelines, so we use some cloud
+infrastructure to manage the assets. You can run the pipelines locally using a
+local PostgreSQL database, but we encourage you to use a cloud database for
+production use cases.
+
+### Connecting to ZenML Cloud
+
+If you run the pipeline using ZenML Cloud, you'll have access to the managed
+dashboard, which will allow you to get started quickly. We offer a free trial so
+you can try out the platform without any cost. Visit the [ZenML Cloud
+dashboard](https://cloud.zenml.io/) to get started.
+
+### Setting up Supabase
+
+[Supabase](https://supabase.com/) is a cloud platform that provides a managed PostgreSQL database. It's simple to
+use and has a free tier that should be sufficient for this project. Once you've
+created a Supabase account and organisation, you'll need to create a new
+project.
+
+![](.assets/supabase-create-project.png)
+
+You'll then want to connect to this database instance by getting the connection
+string from the Supabase dashboard.
+
+![](.assets/supabase-connection-string.png)
+
+You'll then use these details to populate some environment variables where the
+pipeline code expects them:
+
+```shell
+export ZENML_SUPABASE_USER=
+export ZENML_SUPABASE_HOST=
+export ZENML_SUPABASE_PORT=
+```
+
+You'll want to save the Supabase database password as a ZenML secret so that it
+isn't stored in plaintext. You can do this by running the following command:
+
+```shell
+zenml secret create supabase_postgres_db --password="YOUR_PASSWORD"
+```
+
+### Running the RAG pipeline
+
+To run the pipelines, use the `run.py` script, which runs them in the correct
+order. You can run the script with the following command:
+
+```shell
+python run.py --basic-rag
+```
+
+This will run the basic RAG pipeline, which scrapes the ZenML documentation and
+stores the embeddings in the Supabase database.
+
+### Querying your RAG pipeline assets
+
+Once the pipeline has run successfully, you can query the assets in the Supabase
+database using the `--rag-query` flag as well as passing in the model you'd like
+to use for the LLM.
+
+To use the default LLM for this query, you'll need an account and an API key
+from OpenAI, specified as another environment variable:
+
+```shell
+export OPENAI_API_KEY=
+```
+
+When you're ready to make the query, run the following command:
+
+```shell
+python run.py --rag-query "how do I use a custom materializer inside my own zenml steps? i.e. how do I set it? inside the @step decorator?" --model=gpt4
+```
+
+Alternative options for LLMs to use include:
+
+- `gpt4`
+- `gpt35`
+- `claude3`
+- `claudehaiku`
+
+Note that the Claude models require a separate API key from Anthropic. See [the
+`litellm` docs](https://docs.litellm.ai/docs/providers/anthropic) on how to set
+this up.
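+
+For example, `litellm` picks up the Anthropic key from an environment variable.
+A minimal sketch (the `ANTHROPIC_API_KEY` name follows the `litellm` convention;
+double-check the docs linked above if your setup differs):
+
+```shell
+export ANTHROPIC_API_KEY=
+python run.py --rag-query "how do I get started with ZenML?" --model=claude3
+```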
+ +## ☁️ Running with a remote stack + +The basic RAG pipeline will run using a local stack, but if you want to improve +the speed of the embeddings step you might want to consider using a cloud +orchestrator. Please follow the instructions in [our basic cloud setup guides](https://docs.zenml.io/user-guide/cloud-guide) +(currently available for [AWS](https://docs.zenml.io/user-guide/cloud-guide/aws-guide) and [GCP](https://docs.zenml.io/user-guide/cloud-guide/gcp-guide)) to learn how you can run the pipelines on +a remote stack. + +## 📜 Project Structure + +The project loosely follows [the recommended ZenML project structure](https://docs.zenml.io/user-guide/starter-guide/follow-best-practices): + +``` +. +├── LICENSE # License file +├── README.md # This file +├── constants.py # Constants for the project +├── pipelines +│   ├── __init__.py +│   └── llm_basic_rag.py # Basic RAG pipeline +├── requirements.txt # Requirements file +├── run.py # Script to run the pipelines +├── steps +│   ├── __init__.py +│   ├── populate_index.py # Step to populate the index +│   ├── url_scraper.py # Step to scrape the URLs +│   ├── url_scraping_utils.py # Utilities for the URL scraper +│   └── web_url_loader.py # Step to load the URLs +└── utils + ├── __init__.py + └── llm_utils.py # Utilities related to the LLM +``` diff --git a/llm-complete-guide/constants.py b/llm-complete-guide/constants.py new file mode 100644 index 00000000..bcc81902 --- /dev/null +++ b/llm-complete-guide/constants.py @@ -0,0 +1,19 @@ +# Vector Store constants +CHUNK_SIZE = 500 +CHUNK_OVERLAP = 50 +EMBEDDING_DIMENSIONALITY = ( + 384 # Update this to match the dimensionality of the new model +) + +# Scraping constants +RATE_LIMIT = 5 # Maximum number of requests per second + +# LLM Utils constants +OPENAI_MODEL = "gpt-3.5-turbo" +EMBEDDINGS_MODEL = "sentence-transformers/all-MiniLM-L12-v2" +MODEL_NAME_MAP = { + "gpt4": "gpt-4-0125-preview", + "gpt35": "gpt-3.5-turbo", + "claude3": "claude-3-opus-20240229", + "claudehaiku": "claude-3-haiku-20240307", +} diff --git a/llm-complete-guide/materializers/__init__.py b/llm-complete-guide/materializers/__init__.py new file mode 100644 index 00000000..757bd841 --- /dev/null +++ b/llm-complete-guide/materializers/__init__.py @@ -0,0 +1,16 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/llm-complete-guide/most_basic_rag_pipeline.py b/llm-complete-guide/most_basic_rag_pipeline.py new file mode 100644 index 00000000..f01f55f8 --- /dev/null +++ b/llm-complete-guide/most_basic_rag_pipeline.py @@ -0,0 +1,87 @@ +import os +import re +import string + +from openai import OpenAI + + +def preprocess_text(text): + text = text.lower() + text = text.translate(str.maketrans("", "", string.punctuation)) + text = re.sub(r"\s+", " ", text).strip() + return text + + +def tokenize(text): + return preprocess_text(text).split() + + +def retrieve_relevant_chunks(query, corpus, top_n=2): + query_tokens = set(tokenize(query)) + similarities = [] + for chunk in corpus: + chunk_tokens = set(tokenize(chunk)) + similarity = len(query_tokens.intersection(chunk_tokens)) / len( + query_tokens.union(chunk_tokens) + ) + similarities.append((chunk, similarity)) + similarities.sort(key=lambda x: x[1], reverse=True) + return [chunk for chunk, _ in similarities[:top_n]] + + +def answer_question(query, corpus, top_n=2): + relevant_chunks = retrieve_relevant_chunks(query, corpus, top_n) + if not relevant_chunks: + return "I don't have enough information to answer the question." + + context = "\n".join(relevant_chunks) + client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) + chat_completion = client.chat.completions.create( + messages=[ + { + "role": "system", + "content": f"Based on the provided context, answer the following question: {query}\n\nContext:\n{context}", + }, + { + "role": "user", + "content": query, + }, + ], + model="gpt-3.5-turbo", + ) + + return chat_completion.choices[0].message.content.strip() + + +# Sci-fi themed corpus about "ZenML World" +corpus = [ + "The luminescent forests of ZenML World are inhabited by glowing Zenbots that emit a soft, pulsating light as they roam the enchanted landscape.", + "In the neon skies of ZenML World, Cosmic Butterflies flutter gracefully, their iridescent wings leaving trails of stardust in their wake.", + "Telepathic Treants, ancient sentient trees, communicate through the quantum neural network that spans the entire surface of ZenML World, sharing wisdom and knowledge.", + "Deep within the melodic caverns of ZenML World, Fractal Fungi emit pulsating tones that resonate through the crystalline structures, creating a symphony of otherworldly sounds.", + "Near the ethereal waterfalls of ZenML World, Holographic Hummingbirds hover effortlessly, their translucent wings refracting the prismatic light into mesmerizing patterns.", + "Gravitational Geckos, masters of anti-gravity, traverse the inverted cliffs of ZenML World, defying the laws of physics with their extraordinary abilities.", + "Plasma Phoenixes, majestic creatures of pure energy, soar above the chromatic canyons of ZenML World, their fiery trails painting the sky in a dazzling display of colors.", + "Along the prismatic shores of ZenML World, Crystalline Crabs scuttle and burrow, their transparent exoskeletons refracting the light into a kaleidoscope of hues.", +] + +# Preprocess the corpus +corpus = [preprocess_text(sentence) for sentence in corpus] + +# Ask questions +question1 = "What are Plasma Phoenixes?" +answer1 = answer_question(question1, corpus) +print(f"Question: {question1}") +print(f"Answer: {answer1}") + +question2 = ( + "What kinds of creatures live on the prismatic shores of ZenML World?" +) +answer2 = answer_question(question2, corpus) +print(f"Question: {question2}") +print(f"Answer: {answer2}") + +irrelevant_question_3 = "What is the capital of Panglossia?" 
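+# Note: retrieval below still returns the two highest-scoring chunks, but for
+# an off-topic question like this their Jaccard similarity to the query is near
+# zero, so the model is given essentially no useful context to answer from.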
+answer3 = answer_question(irrelevant_question_3, corpus)
+print(f"Question: {irrelevant_question_3}")
+print(f"Answer: {answer3}")
diff --git a/llm-complete-guide/pipelines/__init__.py b/llm-complete-guide/pipelines/__init__.py
new file mode 100644
index 00000000..820059e9
--- /dev/null
+++ b/llm-complete-guide/pipelines/__init__.py
@@ -0,0 +1,17 @@
+# Apache Software License 2.0
+#
+# Copyright (c) ZenML GmbH 2024. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from pipelines.llm_basic_rag import llm_basic_rag
diff --git a/llm-complete-guide/pipelines/llm_basic_rag.py b/llm-complete-guide/pipelines/llm_basic_rag.py
new file mode 100644
index 00000000..0a4e5fdd
--- /dev/null
+++ b/llm-complete-guide/pipelines/llm_basic_rag.py
@@ -0,0 +1,28 @@
+from steps.populate_index import (
+    generate_embeddings,
+    index_generator,
+    preprocess_documents,
+)
+from steps.url_scraper import url_scraper
+from steps.web_url_loader import web_url_loader
+from zenml import pipeline
+
+
+@pipeline
+def llm_basic_rag() -> None:
+    """Executes the pipeline that builds the index for a basic RAG setup.
+
+    This function performs the following steps:
+    1. Scrapes URLs using the url_scraper function.
+    2. Loads documents from the scraped URLs using the web_url_loader function.
+    3. Preprocesses the loaded documents using the preprocess_documents function.
+    4. Generates embeddings for the preprocessed documents using the generate_embeddings function.
+    5. Generates an index for the embeddings and document chunks using the index_generator function.
+    """
+    urls = url_scraper()
+    docs = web_url_loader(urls=urls)
+    processed_docs = preprocess_documents(documents=docs)
+    embeddings = generate_embeddings(split_documents=processed_docs)
+    # Pass the split chunks rather than the raw documents so that each
+    # embedding lines up with the chunk it was generated from.
+    index_generator(embeddings=embeddings, documents=processed_docs)
diff --git a/llm-complete-guide/requirements.txt b/llm-complete-guide/requirements.txt
new file mode 100644
index 00000000..2fab8547
--- /dev/null
+++ b/llm-complete-guide/requirements.txt
@@ -0,0 +1,13 @@
+zenml
+langchain-community
+ratelimit
+langchain>=0.0.325
+langchain-openai
+pgvector
+psycopg2-binary
+beautifulsoup4
+unstructured
+pandas
+numpy
+sentence-transformers
+litellm
diff --git a/llm-complete-guide/run.py b/llm-complete-guide/run.py
new file mode 100644
index 00000000..f89da89b
--- /dev/null
+++ b/llm-complete-guide/run.py
@@ -0,0 +1,100 @@
+# Apache Software License 2.0
+#
+# Copyright (c) ZenML GmbH 2024. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional
+
+import click
+from constants import OPENAI_MODEL
+from pipelines import (
+    llm_basic_rag,
+)
+from utils.llm_utils import process_input_with_retrieval
+from zenml.logger import get_logger
+
+logger = get_logger(__name__)
+
+
+@click.command(
+    help="""
+ZenML LLM Complete Guide project CLI v0.1.0.

+Run the ZenML LLM RAG complete guide project pipelines.
+"""
+)
+@click.option(
+    "--basic-rag",
+    "basic_rag",
+    is_flag=True,
+    default=False,
+    help="Whether to run the basic RAG pipeline that builds the index.",
+)
+@click.option(
+    "--rag-query",
+    "rag_query",
+    type=str,
+    required=False,
+    help="Query the RAG model.",
+)
+@click.option(
+    "--no-cache",
+    "no_cache",
+    is_flag=True,
+    default=False,
+    help="Disable caching for the pipeline run.",
+)
+@click.option(
+    "--model",
+    "model",
+    type=click.Choice(
+        [
+            "gpt4",
+            "gpt35",
+            "claude3",
+            "claudehaiku",
+        ]
+    ),
+    required=False,
+    default="gpt4",
+    help="The model to use for the completion.",
+)
+def main(
+    basic_rag: bool = False,
+    rag_query: Optional[str] = None,
+    model: str = OPENAI_MODEL,
+    no_cache: bool = False,
+):
+    """Main entry point for the pipeline execution.
+
+    Args:
+        basic_rag (bool): If `True`, the basic RAG pipeline will be run.
+        rag_query (Optional[str]): If provided, the RAG model will be queried with this string.
+        model (str): The model to use for the completion. Defaults to `gpt4`.
+        no_cache (bool): If `True`, pipeline caching will be disabled.
+
+    """
+    pipeline_args = {"enable_cache": not no_cache}
+
+    if rag_query:
+        response = process_input_with_retrieval(rag_query, model=model)
+        print(response)
+
+    if basic_rag:
+        llm_basic_rag.with_options(**pipeline_args)()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm-complete-guide/steps/__init__.py b/llm-complete-guide/steps/__init__.py
new file mode 100644
index 00000000..757bd841
--- /dev/null
+++ b/llm-complete-guide/steps/__init__.py
@@ -0,0 +1,16 @@
+# Apache Software License 2.0
+#
+# Copyright (c) ZenML GmbH 2024. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# diff --git a/llm-complete-guide/steps/populate_index.py b/llm-complete-guide/steps/populate_index.py new file mode 100644 index 00000000..ad664e91 --- /dev/null +++ b/llm-complete-guide/steps/populate_index.py @@ -0,0 +1,170 @@ +# credit to +# https://www.timescale.com/blog/postgresql-as-a-vector-database-create-store-and-query-openai-embeddings-with-pgvector/ +# for providing the base implementation for this indexing functionality + +import logging +import math +from typing import Annotated, List + +import numpy as np +from constants import ( + CHUNK_OVERLAP, + CHUNK_SIZE, + EMBEDDING_DIMENSIONALITY, + EMBEDDINGS_MODEL, +) +from pgvector.psycopg2 import register_vector +from sentence_transformers import SentenceTransformer +from utils.llm_utils import get_db_conn, split_documents +from zenml import ArtifactConfig, log_artifact_metadata, step + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +@step(enable_cache=False) +def preprocess_documents( + documents: List[str], +) -> Annotated[List[str], ArtifactConfig(name="split_chunks")]: + """ + Preprocesses a list of documents by splitting them into chunks. + + Args: + documents (List[str]): A list of documents to be preprocessed. + + Returns: + Annotated[List[str], ArtifactConfig(name="split_chunks")]: A list of preprocessed documents annotated with an ArtifactConfig. + + Raises: + Exception: If an error occurs during preprocessing. + """ + try: + log_artifact_metadata( + artifact_name="split_chunks", + metadata={ + "chunk_size": CHUNK_SIZE, + "chunk_overlap": CHUNK_OVERLAP, + }, + ) + return split_documents( + documents, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP + ) + except Exception as e: + logger.error(f"Error in preprocess_documents: {e}") + raise + + +@step(enable_cache=False) +def generate_embeddings( + split_documents: List[str], +) -> Annotated[np.ndarray, ArtifactConfig(name="embeddings")]: + """ + Generates embeddings for a list of split documents using a SentenceTransformer model. + + Args: + split_documents (List[str]): A list of documents that have been split into chunks. + + Returns: + Annotated[np.ndarray, ArtifactConfig(name="embeddings")]: The generated embeddings for each document chunk, annotated with an ArtifactConfig. + + Raises: + Exception: If an error occurs during the generation of embeddings. + """ + try: + model = SentenceTransformer(EMBEDDINGS_MODEL) + + log_artifact_metadata( + artifact_name="embeddings", + metadata={ + "embedding_type": EMBEDDINGS_MODEL, + "embedding_dimensionality": EMBEDDING_DIMENSIONALITY, + }, + ) + return model.encode(split_documents) + except Exception as e: + logger.error(f"Error in generate_embeddings: {e}") + raise + + +@step(enable_cache=False) +def index_generator( + embeddings: np.ndarray, + documents: List[str], +) -> None: + """ + Generates an index for the given embeddings and documents. + + This function creates a database connection, installs the pgvector extension if not already installed, + creates an embeddings table if it doesn't exist, and inserts the embeddings and documents into the table. + It then calculates the index parameters according to best practices and creates an index on the embeddings + using the cosine distance measure. + + Args: + embeddings (np.ndarray): The embeddings to index. + documents (List[str]): The documents corresponding to the embeddings. + + Raises: + Exception: If an error occurs during the index generation. 
+    """
+    conn = None  # defined up front so the `finally` block can always check it
+    try:
+        conn = get_db_conn()
+        with conn.cursor() as cur:
+            # Install pgvector if not already installed
+            cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+            conn.commit()
+
+            # Create the embeddings table if it doesn't exist
+            table_create_command = f"""
+            CREATE TABLE IF NOT EXISTS embeddings (
+                id SERIAL PRIMARY KEY,
+                content TEXT,
+                tokens INTEGER,
+                embedding VECTOR({EMBEDDING_DIMENSIONALITY})
+            );
+            """
+            cur.execute(table_create_command)
+            conn.commit()
+
+            register_vector(conn)
+
+            # Insert data only if it doesn't already exist
+            for i, doc in enumerate(documents):
+                content = doc
+                tokens = len(
+                    content.split()
+                )  # Approximate token count based on word count
+                embedding = embeddings[i].tolist()
+
+                cur.execute(
+                    "SELECT COUNT(*) FROM embeddings WHERE content = %s",
+                    (content,),
+                )
+                count = cur.fetchone()[0]
+                if count == 0:
+                    cur.execute(
+                        "INSERT INTO embeddings (content, tokens, embedding) VALUES (%s, %s, %s)",
+                        (content, tokens, embedding),
+                    )
+            conn.commit()
+
+            cur.execute("SELECT COUNT(*) as cnt FROM embeddings;")
+            num_records = cur.fetchone()[0]
+            logger.info(f"Number of vector records in table: {num_records}")
+
+            # Calculate the index parameters according to best practices:
+            # one list per ~1000 records (minimum 10), switching to
+            # sqrt(num_records) above a million rows. The `lists` parameter
+            # must be an integer, so use integer division and an explicit cast.
+            num_lists = max(num_records // 1000, 10)
+            if num_records > 1000000:
+                num_lists = int(math.sqrt(num_records))
+
+            # use the cosine distance measure, which is what we'll later use for querying
+            cur.execute(
+                f"CREATE INDEX IF NOT EXISTS embeddings_idx ON embeddings USING ivfflat (embedding vector_cosine_ops) WITH (lists = {num_lists});"
+            )
+            conn.commit()
+
+    except Exception as e:
+        logger.error(f"Error in index_generator: {e}")
+        raise
+    finally:
+        if conn:
+            conn.close()
diff --git a/llm-complete-guide/steps/url_scraper.py b/llm-complete-guide/steps/url_scraper.py
new file mode 100644
index 00000000..213c6994
--- /dev/null
+++ b/llm-complete-guide/steps/url_scraper.py
@@ -0,0 +1,50 @@
+# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from typing import List
+
+from typing_extensions import Annotated
+from zenml import log_artifact_metadata, step
+
+from steps.url_scraping_utils import get_all_pages
+
+
+@step
+def url_scraper(
+    docs_url: str = "https://docs.zenml.io",
+    repo_url: str = "https://github.com/zenml-io/zenml",
+    website_url: str = "https://zenml.io",
+) -> Annotated[List[str], "urls"]:
+    """Generates a list of relevant URLs to scrape.
+
+    Args:
+        docs_url: URL to the documentation.
+        repo_url: URL to the repository.
+        website_url: URL to the website.
+
+    Returns:
+        List of URLs to scrape.
+ """ + # We comment this out to make this pipeline faster + # examples_readme_urls = get_nested_readme_urls(repo_url) + docs_urls = get_all_pages(docs_url) + # website_urls = get_all_pages(website_url) + # all_urls = docs_urls + website_urls + examples_readme_urls + all_urls = docs_urls + log_artifact_metadata( + metadata={ + "count": len(all_urls), + }, + ) + return all_urls diff --git a/llm-complete-guide/steps/url_scraping_utils.py b/llm-complete-guide/steps/url_scraping_utils.py new file mode 100644 index 00000000..4da721f1 --- /dev/null +++ b/llm-complete-guide/steps/url_scraping_utils.py @@ -0,0 +1,182 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +import re +from functools import lru_cache +from logging import getLogger +from time import sleep +from typing import List, Set, Tuple +from urllib.parse import urljoin, urlparse + +import requests +from bs4 import BeautifulSoup +from constants import RATE_LIMIT +from ratelimit import limits, sleep_and_retry + +logger = getLogger(__name__) + + +def is_valid_url(url: str, base: str) -> bool: + """ + Check if the given URL is valid, has the same base as the provided base, + and does not contain any version-specific paths. + + Args: + url (str): The URL to check. + base (str): The base URL to compare against. + + Returns: + bool: True if the URL is valid, has the same base, and does not contain version-specific paths, False otherwise. + """ + parsed = urlparse(url) + if not bool(parsed.netloc) or parsed.netloc != base: + return False + + # Check if the URL contains a version pattern (e.g., /v/0.x.x/) + version_pattern = r"/v/0\.\d+\.\d+/" + return not re.search(version_pattern, url) + + +def get_all_pages(url: str) -> List[str]: + """ + Retrieve all pages with the same base as the given URL. + + Args: + url (str): The URL to retrieve pages from. + + Returns: + List[str]: A list of all pages with the same base. + """ + logger.info(f"Scraping all pages from {url}...") + base_url = urlparse(url).netloc + pages = crawl(url, base_url) + logger.info(f"Found {len(pages)} pages.") + logger.info("Done scraping pages.") + return list(pages) + + +def crawl(url: str, base: str, visited: Set[str] = None) -> Set[str]: + """ + Recursively crawl a URL and its links, retrieving all valid links with the same base. + + Args: + url (str): The URL to crawl. + base (str): The base URL to compare against. + visited (Set[str]): A set of URLs that have been visited. Defaults to None. + + Returns: + Set[str]: A set of all valid links with the same base. 
+ """ + if visited is None: + visited = set() + + visited.add(url) + logger.debug(f"Crawling URL: {url}") + links = get_all_links(url, base) + + for link in links: + if link not in visited: + visited.update(crawl(link, base, visited)) + sleep(1 / RATE_LIMIT) # Rate limit the recursive calls + + return visited + + +@sleep_and_retry +@limits(calls=RATE_LIMIT, period=1) +@lru_cache(maxsize=128) +def get_all_links(url: str, base: str) -> List[str]: + """ + Retrieve all valid links from a given URL with the same base. + + Args: + url (str): The URL to retrieve links from. + base (str): The base URL to compare against. + + Returns: + List[str]: A list of valid links with the same base. + """ + logger.debug(f"Retrieving links from {url}") + response = requests.get(url) + soup = BeautifulSoup(response.text, "html.parser") + links = [] + + for link in soup.find_all("a", href=True): + href = link["href"] + full_url = urljoin(url, href) + parsed_url = urlparse(full_url) + cleaned_url = parsed_url._replace(fragment="").geturl() + if is_valid_url(cleaned_url, base): + links.append(cleaned_url) + + logger.debug(f"Found {len(links)} valid links from {url}") + return links + + +@sleep_and_retry +@limits(calls=RATE_LIMIT, period=1) +@lru_cache(maxsize=128) +def get_readme_urls(repo_url: str) -> Tuple[List[str], List[str]]: + """ + Retrieve folder and README links from a GitHub repository. + + Args: + repo_url (str): The URL of the GitHub repository. + + Returns: + Tuple[List[str], List[str]]: A tuple containing two lists: folder links and README links. + """ + logger.debug(f"Retrieving README links from {repo_url}") + headers = {"Accept": "application/vnd.github+json"} + r = requests.get(repo_url, headers=headers) + soup = BeautifulSoup(r.text, "html.parser") + + folder_links = [] + readme_links = [] + + for link in soup.find_all("a", class_="js-navigation-open Link--primary"): + href = link["href"] + full_url = f"https://github.com{href}" + if "tree" in href: + folder_links.append(full_url) + elif "README.md" in href: + readme_links.append(full_url) + + logger.debug( + f"Found {len(folder_links)} folder links and {len(readme_links)} README links from {repo_url}" + ) + return folder_links, readme_links + + +def get_nested_readme_urls(repo_url: str) -> List[str]: + """ + Retrieve all nested README links from a GitHub repository. + + Args: + repo_url (str): The URL of the GitHub repository. + + Returns: + List[str]: A list of all nested README links. + """ + logger.info(f"Retrieving nested README links from {repo_url}...") + folder_links, readme_links = get_readme_urls(repo_url) + + for folder_link in folder_links: + _, nested_readme_links = get_readme_urls(folder_link) + readme_links.extend(nested_readme_links) + + logger.info( + f"Found {len(readme_links)} nested README links from {repo_url}" + ) + return readme_links diff --git a/llm-complete-guide/steps/web_url_loader.py b/llm-complete-guide/steps/web_url_loader.py new file mode 100644 index 00000000..e953f4b8 --- /dev/null +++ b/llm-complete-guide/steps/web_url_loader.py @@ -0,0 +1,36 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at:
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+from typing import List
+
+from unstructured.partition.html import partition_html
+from zenml import step
+
+
+@step
+def web_url_loader(urls: List[str]) -> List[str]:
+    """Loads documents from a list of URLs.
+
+    Args:
+        urls: List of URLs to load documents from.
+
+    Returns:
+        List of document texts, one per URL.
+    """
+    document_texts = []
+    for url in urls:
+        elements = partition_html(url=url)
+        text = "\n\n".join([str(el) for el in elements])
+        document_texts.append(text)
+    return document_texts
diff --git a/llm-complete-guide/utils/__init__.py b/llm-complete-guide/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm-complete-guide/utils/llm_utils.py b/llm-complete-guide/utils/llm_utils.py
new file mode 100644
index 00000000..9a204145
--- /dev/null
+++ b/llm-complete-guide/utils/llm_utils.py
@@ -0,0 +1,295 @@
+# credit to langchain for the original base implementation of splitting
+# functionality
+# https://github.com/langchain-ai/langchain/blob/master/libs/text-splitters/langchain_text_splitters/character.py
+
+
+import logging
+import os
+import re
+from typing import Dict, List
+
+import litellm
+import numpy as np
+import psycopg2
+from constants import EMBEDDINGS_MODEL, MODEL_NAME_MAP, OPENAI_MODEL
+from pgvector.psycopg2 import register_vector
+from psycopg2.extensions import connection
+from sentence_transformers import SentenceTransformer
+from zenml.client import Client
+
+# Configure the logging level for the root logger
+logging.getLogger().setLevel(logging.WARNING)
+
+logger = logging.getLogger(__name__)
+
+
+def split_text_with_regex(
+    text: str, separator: str, keep_separator: bool
+) -> List[str]:
+    """Splits a given text using a specified separator.
+
+    This function splits the input text using the provided separator. The separator can be included or excluded
+    from the resulting splits based on the value of keep_separator.
+
+    Args:
+        text (str): The text to be split.
+        separator (str): The separator to use for splitting the text.
+        keep_separator (bool): If True, the separator is kept in the resulting splits. If False, the separator is removed.
+
+    Returns:
+        List[str]: A list of strings resulting from splitting the input text.
+    """
+    if separator:
+        if keep_separator:
+            _splits = re.split(f"({separator})", text)
+            splits = [
+                _splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)
+            ]
+            if len(_splits) % 2 == 0:
+                splits += _splits[-1:]
+            splits = [_splits[0]] + splits
+        else:
+            splits = re.split(separator, text)
+    else:
+        splits = list(text)
+    return [s for s in splits if s != ""]
+
+
+def split_text(
+    text: str,
+    separator: str = "\n\n",
+    chunk_size: int = 4000,
+    chunk_overlap: int = 200,
+    keep_separator: bool = False,
+    strip_whitespace: bool = True,
+) -> List[str]:
+    """Splits a given text into chunks of specified size with optional overlap.
+
+    Args:
+        text (str): The text to be split.
+        separator (str, optional): The separator to use for splitting the text. Defaults to "\n\n".
+        chunk_size (int, optional): The maximum size of each chunk. Defaults to 4000.
+ chunk_overlap (int, optional): The size of the overlap between consecutive chunks. Defaults to 200. + keep_separator (bool, optional): If True, the separator is kept in the resulting splits. If False, the separator is removed. Defaults to False. + strip_whitespace (bool, optional): If True, leading and trailing whitespace is removed from each split. Defaults to True. + + Raises: + ValueError: If chunk_overlap is larger than chunk_size. + + Returns: + List[str]: A list of strings resulting from splitting the input text into chunks. + """ + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." + ) + + separator_regex = re.escape(separator) + splits = split_text_with_regex(text, separator_regex, keep_separator) + _separator = "" if keep_separator else separator + + chunks = [] + current_chunk = "" + + for split in splits: + if strip_whitespace: + split = split.strip() + + if len(current_chunk) + len(split) + len(_separator) <= chunk_size: + current_chunk += split + _separator + else: + if current_chunk: + chunks.append(current_chunk.rstrip(_separator)) + current_chunk = split + _separator + + if current_chunk: + chunks.append(current_chunk.rstrip(_separator)) + + final_chunks = [] + for i in range(len(chunks)): + if i == 0: + final_chunks.append(chunks[i]) + else: + overlap = chunks[i - 1][-chunk_overlap:] + final_chunks.append(overlap + chunks[i]) + + return final_chunks + + +def split_documents( + documents: List[str], + separator: str = "\n\n", + chunk_size: int = 4000, + chunk_overlap: int = 200, + keep_separator: bool = False, + strip_whitespace: bool = True, +) -> List[str]: + """Splits a list of documents into chunks. + + Args: + documents (List[str]): The list of documents to be split. + separator (str, optional): The separator to use for splitting the documents. Defaults to "\n\n". + chunk_size (int, optional): The maximum size of each chunk. Defaults to 4000. + chunk_overlap (int, optional): The size of the overlap between consecutive chunks. Defaults to 200. + keep_separator (bool, optional): If True, the separator is kept in the resulting splits. If False, the separator is removed. Defaults to False. + strip_whitespace (bool, optional): If True, leading and trailing whitespace is removed from each split. Defaults to True. + + Returns: + List[str]: A list of chunked documents. + """ + chunked_documents = [] + for doc in documents: + chunked_documents.extend( + split_text( + doc, + separator=separator, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + keep_separator=keep_separator, + strip_whitespace=strip_whitespace, + ) + ) + return chunked_documents + + +def get_local_db_connection_details() -> Dict[str, str]: + """Returns the connection details for the local database. + + Returns: + dict: A dictionary containing the connection details for the local database. + """ + return { + "user": os.getenv("ZENML_SUPABASE_USER"), + "host": os.getenv("ZENML_SUPABASE_HOST"), + "port": os.getenv("ZENML_SUPABASE_PORT"), + } + + +def get_db_conn() -> connection: + """Establishes and returns a connection to the PostgreSQL database. + + This function retrieves the password for the PostgreSQL database from a secret store, + then uses it along with other connection details to establish a connection. + + Returns: + connection: A psycopg2 connection object to the PostgreSQL database. 
+    """
+    pg_password = (
+        Client().get_secret("supabase_postgres_db").secret_values["password"]
+    )
+
+    local_database_connection = get_local_db_connection_details()
+
+    CONNECTION_DETAILS = {
+        "user": local_database_connection["user"],
+        "password": pg_password,
+        "host": local_database_connection["host"],
+        "port": local_database_connection["port"],
+        "dbname": "postgres",
+    }
+
+    return psycopg2.connect(**CONNECTION_DETAILS)
+
+
+def get_topn_similar_docs(query_embedding, conn, n: int = 5):
+    """Fetches the top n most similar documents to the given query embedding from the database.
+
+    Args:
+        query_embedding (list): The query embedding to compare against.
+        conn (psycopg2.extensions.connection): The database connection object.
+        n (int, optional): The number of similar documents to fetch. Defaults to 5.
+
+    Returns:
+        list: A list of tuples containing the content of the top n most similar documents.
+    """
+    embedding_array = np.array(query_embedding)
+    register_vector(conn)
+    cur = conn.cursor()
+    # `<=>` is pgvector's cosine distance operator; pass the limit as a bound
+    # parameter instead of interpolating it into the SQL string.
+    cur.execute(
+        "SELECT content FROM embeddings ORDER BY embedding <=> %s LIMIT %s",
+        (embedding_array, n),
+    )
+    return cur.fetchall()
+
+
+def get_completion_from_messages(
+    messages, model=OPENAI_MODEL, temperature=0.4, max_tokens=1000
+):
+    """Generates a completion response from the given messages using the specified model.
+
+    Args:
+        messages (list): The list of messages to generate a completion from.
+        model (str, optional): The model to use for generating the completion. Defaults to OPENAI_MODEL.
+        temperature (float, optional): The temperature to use for the completion. Defaults to 0.4.
+        max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1000.
+
+    Returns:
+        str: The content of the completion response.
+    """
+    model = MODEL_NAME_MAP.get(model, model)
+    completion_response = litellm.completion(
+        model=model,
+        messages=messages,
+        temperature=temperature,
+        max_tokens=max_tokens,
+    )
+    return completion_response.choices[0].message.content
+
+
+def get_embeddings(text):
+    """Generates embeddings for the given text using a SentenceTransformer model.
+
+    Args:
+        text (str): The text to generate embeddings for.
+
+    Returns:
+        np.ndarray: The generated embeddings.
+    """
+    model = SentenceTransformer(EMBEDDINGS_MODEL)
+    return model.encode(text)
+
+
+def process_input_with_retrieval(input: str, model: str = OPENAI_MODEL) -> str:
+    """Process the input with retrieval.
+
+    Args:
+        input (str): The input to process.
+        model (str): The model to use for the completion. Defaults to OPENAI_MODEL.
+
+    Returns:
+        str: The processed output.
+    """
+    delimiter = "```"
+
+    # Step 1: Get documents related to the user input from the database
+    related_docs = get_topn_similar_docs(get_embeddings(input), get_db_conn())
+
+    # Step 2: Get a completion from the LLM (routed through litellm)
+    # Set system message to help set appropriate tone and context for model
+    system_message = f"""
+    You are a friendly chatbot. \
+    You can answer questions about ZenML, its features and its use cases. \
+    You respond in a concise, technically credible tone. \
+    You ONLY use the context from the ZenML documentation to provide relevant
+    answers. \
+    You do not make up answers or provide opinions that you don't have information to support. \
+    """
+
+    # Prepare messages to pass to the model.
+    # We use a delimiter to help the model understand where the user_input
+    # starts and ends.
+
+    messages = [
+        {"role": "system", "content": system_message},
+        {"role": "user", "content": f"{delimiter}{input}{delimiter}"},
+        {
+            "role": "assistant",
+            "content": f"Relevant ZenML documentation: \n"
+            + "\n".join(doc[0] for doc in related_docs),
+        },
+    ]
+    logger.debug("CONTEXT USED\n\n%s\n\n", messages[2]["content"])
+    return get_completion_from_messages(messages, model=model)