Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deployment to llm-complete #160

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions llm-complete-guide/.env.local
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
MODELS=[{"name":"llm-complete-rag-webui","parameters":{"temperature":0.5,"max_new_tokens":1024},"endpoints":[{"type":"openai","baseURL":"http://localhost:3000/generate"}]}]

14 changes: 6 additions & 8 deletions llm-complete-guide/gh_action_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,10 @@

import click
import yaml
from zenml.enums import PluginSubType

from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
from zenml.client import Client
from zenml import Model
from zenml.exceptions import ZenKeyError
from zenml.client import Client
from zenml.enums import PluginSubType


@click.command(
Expand Down Expand Up @@ -89,7 +87,7 @@ def main(
zenml_model_name: Optional[str] = "zenml-docs-qa-rag",
zenml_model_version: Optional[str] = None,
):
"""
"""
Executes the pipeline to train a basic RAG model.

Args:
Expand All @@ -108,14 +106,14 @@ def main(
config = yaml.safe_load(file)

# Read the model version from a file in the root of the repo
# called "ZENML_VERSION.txt".
# called "ZENML_VERSION.txt".
if zenml_model_version == "staging":
postfix = "-rc0"
elif zenml_model_version == "production":
postfix = ""
else:
postfix = "-dev"

if Path("ZENML_VERSION.txt").exists():
with open("ZENML_VERSION.txt", "r") as file:
zenml_model_version = file.read().strip()
Expand Down Expand Up @@ -177,7 +175,7 @@ def main(
service_account_id=service_account_id,
auth_window=0,
flavor="builtin",
action_type=PluginSubType.PIPELINE_RUN
action_type=PluginSubType.PIPELINE_RUN,
).id
client.create_trigger(
name="Production Trigger LLM-Complete",
Expand Down
4 changes: 3 additions & 1 deletion llm-complete-guide/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,7 @@
from pipelines.generate_chunk_questions import generate_chunk_questions
from pipelines.llm_basic_rag import llm_basic_rag
from pipelines.llm_eval import llm_eval
from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
from pipelines.local_deployment import local_deployment
from pipelines.prod_deployment import production_deployment
from pipelines.rag_deployment import rag_deployment
from pipelines.llm_index_and_evaluate import llm_index_and_evaluate
1 change: 0 additions & 1 deletion llm-complete-guide/pipelines/finetune_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.

from constants import EMBEDDINGS_MODEL_NAME_ZENML
from steps.finetune_embeddings import (
evaluate_base_model,
evaluate_finetuned_model,
Expand Down
1 change: 0 additions & 1 deletion llm-complete-guide/pipelines/llm_basic_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from litellm import config_path

from steps.populate_index import (
generate_embeddings,
Expand Down
3 changes: 2 additions & 1 deletion llm-complete-guide/pipelines/llm_index_and_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# limitations under the License.
#

from pipelines import llm_basic_rag, llm_eval
from zenml import pipeline

from pipelines import llm_basic_rag, llm_eval


@pipeline
def llm_index_and_evaluate() -> None:
Expand Down
9 changes: 9 additions & 0 deletions llm-complete-guide/pipelines/local_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from steps.bento_builder import bento_builder
from steps.bento_deployment import bento_deployment
from zenml import pipeline


@pipeline(enable_cache=False)
def local_deployment():
    """Deploy the RAG service locally via BentoML.

    Two-step pipeline: first builds a Bento bundle from the model
    (``bento_builder``), then deploys that bundle as a local BentoML
    service (``bento_deployment``).

    Caching is disabled so every run rebuilds and redeploys, ensuring the
    running service always reflects the latest model state.
    """
    # Package the model and serving code into a Bento bundle.
    bento = bento_builder()
    # Serve the freshly built bundle locally.
    bento_deployment(bento)
32 changes: 32 additions & 0 deletions llm-complete-guide/pipelines/prod_deployment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Apache Software License 2.0
#
# Copyright (c) ZenML GmbH 2024. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from steps.bento_dockerizer import bento_dockerizer
from steps.k8s_deployment import k8s_deployment
from steps.visualize_chat import create_chat_interface
from zenml import pipeline


@pipeline(enable_cache=False)
def production_deployment():
    """Model deployment pipeline for production (Kubernetes).

    This pipeline deploys the trained model for future inference:
    it containerizes the BentoML service into a Docker image, deploys
    that image to a Kubernetes cluster, and finally creates a chat
    interface wired to the deployed endpoint.

    Caching is disabled so each run produces a fresh image and deployment.
    """
    # Build a Docker image from the Bento service bundle.
    bento_model_image = bento_dockerizer()
    # Roll the image out to Kubernetes; returns endpoint/deployment details.
    deployment_info = k8s_deployment(bento_model_image)
    # Expose a chat UI against the deployed inference endpoint.
    create_chat_interface(deployment_info)
52 changes: 42 additions & 10 deletions llm-complete-guide/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,14 @@
generate_synthetic_data,
llm_basic_rag,
llm_eval,
rag_deployment,
llm_index_and_evaluate,
local_deployment,
production_deployment,
rag_deployment,
)
from structures import Document
from zenml.materializers.materializer_registry import materializer_registry
from zenml import Model
from zenml.materializers.materializer_registry import materializer_registry

logger = get_logger(__name__)

Expand Down Expand Up @@ -95,6 +97,13 @@
default="gpt4",
help="The model to use for the completion.",
)
@click.option(
"--query-text",
"query_text",
required=False,
default=None,
help="The query text to use for the completion.",
)
@click.option(
"--zenml-model-name",
"zenml_model_name",
Expand Down Expand Up @@ -136,6 +145,12 @@
default=None,
help="Path to config",
)
@click.option(
"--env",
"env",
default="local",
help="The environment to use for the completion.",
)
def main(
pipeline: str,
query_text: Optional[str] = None,
Expand All @@ -146,6 +161,7 @@ def main(
use_argilla: bool = False,
use_reranker: bool = False,
config: Optional[str] = None,
env: str = "local",
):
"""Main entry point for the pipeline execution.

Expand All @@ -159,6 +175,7 @@ def main(
        use_argilla (bool): If True, Argilla annotations will be used
use_reranker (bool): If True, rerankers will be used
config (Optional[str]): Path to config file
env (str): The environment to use for the deployment (local, huggingface space, k8s etc.)
"""
pipeline_args = {"enable_cache": not no_cache}
embeddings_finetune_args = {
Expand All @@ -169,9 +186,9 @@ def main(
}
},
}

# Read the model version from a file in the root of the repo
# called "ZENML_VERSION.txt".
# called "ZENML_VERSION.txt".
if zenml_model_version == "staging":
postfix = "-rc0"
elif zenml_model_version == "production":
Expand All @@ -181,8 +198,10 @@ def main(

if Path("ZENML_VERSION.txt").exists():
with open("ZENML_VERSION.txt", "r") as file:
zenml_model_version = file.read().strip()
zenml_model_version += postfix
zenml_version = file.read().strip()
zenml_version += postfix
# zenml_model_version = file.read().strip()
# zenml_model_version += postfix
else:
raise RuntimeError(
"No model version file found. Please create a file called ZENML_VERSION.txt in the root of the repo with the model version."
Expand All @@ -191,7 +210,7 @@ def main(
# Create ZenML model
zenml_model = Model(
name=zenml_model_name,
version=zenml_model_version,
version=zenml_version,
license="Apache 2.0",
description="RAG application for ZenML docs",
tags=["rag", "finetuned", "chatbot"],
Expand Down Expand Up @@ -251,8 +270,19 @@ def main(
)()

elif pipeline == "deploy":
rag_deployment.with_options(model=zenml_model, **pipeline_args)()

zenml_model.version = zenml_model_version
if env == "local":
local_deployment.with_options(
model=zenml_model, config_path=config_path, **pipeline_args
)()
elif env == "huggingface":
rag_deployment.with_options(
model=zenml_model, config_path=config_path, **pipeline_args
)()
elif env == "k8s":
production_deployment.with_options(
model=zenml_model, config_path=config_path, **pipeline_args
)()
elif pipeline == "evaluation":
pipeline_args["enable_cache"] = False
llm_eval.with_options(model=zenml_model, config_path=config_path)()
Expand All @@ -264,7 +294,9 @@ def main(

elif pipeline == "embeddings":
finetune_embeddings.with_options(
model=zenml_model, config_path=config_path, **embeddings_finetune_args
model=zenml_model,
config_path=config_path,
**embeddings_finetune_args,
)()

elif pipeline == "chunks":
Expand Down
Loading
Loading