Commit 2468f25

fix: improve code for unit tests and fix LLM client logic (#39)

yoomlam authored Jun 8, 2024
1 parent dc872da commit 2468f25

Showing 15 changed files with 67 additions and 56 deletions.
6 changes: 3 additions & 3 deletions 05-assistive-chatbot/README.md

```diff
@@ -18,7 +18,7 @@ After running the chatbot and providing feedback in the UI, review the feedback
 ## Running an application
 
 There are several ways to run the chatbot application, offering different ways to interact with the chatbot.
-All apps use configurations set in `.env`, which is *not* checked into git. These configurations (like `CHAT_ENGINE` and `LLM_MODEL_NAME`) can be overridden by environment variables set on the commandline. See `_init_settings()` in `chatbot/__init__.py` for other variables.
+All apps use configurations set in `.env`, which is *not* checked into git. These configurations (like `CHAT_ENGINE` and `LLM_MODEL_NAME`) can be overridden by environment variables set on the commandline. See `chatbot/__init__.py` for other variables.
 
 ### Run commandline app
 
@@ -53,7 +53,7 @@ This application runs the chatbot API for other applications to make requests to
 
 - Application entrypoints are in the root folder of the repo. Other Python files are under the `chatbot` folder.
 - The chatbot package `chatbot/__init__.py` is run for all apps because they `import chatbot`.
-- It initializes settings (`_init_settings()`) and creates a specified chat engine (`create_chat_engine(settings)`).
+- It initializes settings and creates a specified chat engine (`create_chat_engine(settings)`).
 
 ### Adding a chat engine
 
@@ -79,7 +79,7 @@ To create a new LLM client, add a new Python file under `chatbot/llms` with:
 - an LLM client class that:
   - sets `self.client` based on the provided `settings`, and
   - implements a `generate_reponse(self, message)` function that uses `self.client` to generate a response, which may need to be parsed so that a string is returned to `chat_engine.gen_response(self, query)`.
-- (optional) a `requirements_satisfied(settings)` function that checks if necessary environment variable(s) and other LLM client preconditions are satisfied;
+- (optional) a `requirements_satisfied()` function that checks if necessary environment variable(s) and other LLM client preconditions are satisfied;
 The new Python file will be automatically discovered and registered for display in the Chainlit settings web UI.
 
 An LLM client can be used in any arbitrary program by:
```
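Taken together, the README's convention maps onto a module shaped roughly like the sketch below. This is illustrative only: the `echo` client, its model name, and its gating condition are invented here, and only the hook names (`CLIENT_NAME`, `MODEL_NAMES`, `requirements_satisfied()`, `init_client()`, `generate_reponse()`) follow the convention this commit establishes.

```python
# chatbot/llms/echo_client.py -- hypothetical example module, not part of this commit
import logging
import os

logger = logging.getLogger(__name__)

CLIENT_NAME = "echo"
MODEL_NAMES = ["echo :: v1"]


def requirements_satisfied():
    # Zero-argument hook per this commit; gate on an env var as the real clients do
    return os.environ.get("ENV") != "PROD"


def init_client(model_name, settings):
    return EchoClient(model_name, settings)


class EchoClient:
    def __init__(self, model_name, settings):
        self.model_name = model_name
        self.client = None  # a real client object would be built from settings here

    def generate_reponse(self, message):  # spelling matches the repo's interface
        return f"{self.model_name}> {message}"
```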
20 changes: 11 additions & 9 deletions 05-assistive-chatbot/chatbot-chainlit.py

```diff
@@ -20,7 +20,7 @@
 
 logger = logging.getLogger(f"chatbot.{__name__}")
 
-if chatbot.initial_settings["enable_api"]:
+if utils.is_env_var_true("ENABLE_CHATBOT_API", False):
     import chatbot_api
 
     logger.info("Chatbot API loaded: %s", chatbot_api.__name__)
@@ -33,8 +33,9 @@ async def init_chat():
     logger.debug("init_chat")
    git_sha = os.environ.get("GIT_SHA", "")
     build_date = os.environ.get("BUILD_DATE", "unknown")
+    initial_settings = chatbot.create_init_settings()
     metadata = {
-        **chatbot.initial_settings,
+        **initial_settings,
         "build_date": build_date,
         "git_sha": git_sha,
         "hostname": socket.gethostname(),
@@ -50,26 +51,26 @@ async def init_chat():
             id="chat_engine",
             label="Chat Mode",
             values=engines.available_engines(),
-            initial_value=chatbot.initial_settings["chat_engine"],
+            initial_value=initial_settings["chat_engine"],
         ),
         Select(
             id="model",
             label="Primary LLM Model",
             values=available_llms,
-            initial_value=chatbot.initial_settings["model"],
+            initial_value=initial_settings["model"],
         ),
         Slider(
             id="temperature",
             label="Temperature for primary LLM",
-            initial=chatbot.initial_settings["temperature"],
+            initial=initial_settings["temperature"],
             min=0,
             max=2,
             step=0.1,
         ),
         Slider(
             id="retrieve_k",
             label="Guru cards to retrieve",
-            initial=chatbot.initial_settings["retrieve_k"],
+            initial=initial_settings["retrieve_k"],
             min=1,
             max=10,
             step=1,
@@ -78,12 +79,12 @@ async def init_chat():
             id="model2",
             label="LLM Model for summarizer",
             values=available_llms,
-            initial_value=chatbot.initial_settings["model2"],
+            initial_value=initial_settings["model2"],
         ),
         Slider(
             id="temperature2",
             label="Temperature for summarizer",
-            initial=chatbot.initial_settings["temperature2"],
+            initial=initial_settings["temperature2"],
             min=0,
             max=2,
             step=0.1,
@@ -98,7 +99,8 @@ async def init_chat():
     if error:
         assert False, f"Validation error: {error}"
 
-    if chatbot.initial_settings["preload_chat_engine"]:
+    preload_chat_engine_default = "ENGINE_MODULES" in os.environ and "LLM_MODULES" in os.environ
+    if utils.is_env_var_true("PRELOAD_CHAT_ENGINE", preload_chat_engine_default):
         logger.info("Preloading chat engine")
         await apply_settings()
 
```
25 changes: 6 additions & 19 deletions 05-assistive-chatbot/chatbot/__init__.py

```diff
@@ -14,6 +14,8 @@
 ## Set default environment variables
 
 
+os.environ.setdefault("ENV", "DEV")
+
 # Opt out of telemetry -- https://docs.trychroma.com/telemetry
 os.environ.setdefault("ANONYMIZED_TELEMETRY", "False")
 
@@ -44,7 +46,7 @@ def configure_logging():
     logging.info("Configured logging level for root logger: %s", root_log_level)
 
 
-env = os.environ.get("ENV", "DEV")
+env = os.environ.get("ENV")
 print(f"Loading .env-{env}")
 dotenv.load_dotenv(f".env-{env}")
 dotenv.load_dotenv()
@@ -63,14 +65,10 @@ def configure_logging():
 
 
 @utils.verbose_timer(logger)
-def _init_settings():
-    # Remember to update ChatSettings in chatbot-chainlit.py when adding new settings
-    # and update chatbot/engines/__init.py:CHATBOT_SETTING_KEYS
-    preload_chat_engine_default = "ENGINE_MODULES" in os.environ and "LLM_MODULES" in os.environ
+def create_init_settings():
+    # REMINDER: when adding new settings, update ChatSettings in chatbot-chainlit.py
+    # and chatbot/engines/__init.py:LLM_SETTING_KEYS, if applicable
     return {
-        "env": env,
-        "enable_api": is_env_var_true("ENABLE_CHATBOT_API", False),
-        "preload_chat_engine": is_env_var_true("PRELOAD_CHAT_ENGINE", preload_chat_engine_default),
         "chat_engine": os.environ.get("CHAT_ENGINE", "Direct"),
         "model": os.environ.get("LLM_MODEL_NAME", "mock :: llm"),
         "temperature": float(os.environ.get("LLM_TEMPERATURE", 0.1)),
@@ -81,21 +79,10 @@ def _init_settings():
     }
 
 
-def is_env_var_true(var_name, default=False):
-    if value := os.environ.get(var_name, None):
-        return value.lower() not in ["false", "f", "no", "n"]
-    return default
-
-
-initial_settings = _init_settings()
-
-
 def reset():
     configure_logging()
     engines._engines.clear()
     llms._llms.clear()
-    global initial_settings
-    initial_settings = _init_settings()
 
 
 @utils.verbose_timer(logger)
```
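With the module-level `initial_settings` global gone, each entrypoint builds a fresh settings dict on demand, which is what makes env-var overrides visible to unit tests. A minimal sketch of the pattern (assuming `CHAT_ENGINE` is unset, so the `Direct` default applies):

```python
import os
import chatbot

os.environ["LLM_TEMPERATURE"] = "0.5"  # overrides the baked-in 0.1 default

settings = chatbot.create_init_settings()
assert settings["temperature"] == 0.5
assert settings["chat_engine"] == "Direct"  # default when CHAT_ENGINE is unset

chatbot.validate_settings(settings)            # as in cmdline.py and chatbot_api.py
engine = chatbot.create_chat_engine(settings)  # per the README
```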
10 changes: 4 additions & 6 deletions 05-assistive-chatbot/chatbot/engines/__init__.py

```diff
@@ -4,7 +4,6 @@
 from types import ModuleType
 from typing import Dict
 
-import chatbot
 from chatbot import llms, utils
 
 logger = logging.getLogger(__name__)
@@ -26,12 +25,11 @@ def _discover_chat_engines(force=False):
     if not engine_modules:
         engine_modules = utils.scan_modules(__package__)
 
-    settings = chatbot.initial_settings
     for module_name, module in engine_modules.items():
         if not hasattr(module, "ENGINE_NAME"):
             logger.debug("Skipping module without an ENGINE_NAME: %s", module_name)
             continue
-        if hasattr(module, "requirements_satisfied") and not module.requirements_satisfied(settings):
+        if hasattr(module, "requirements_satisfied") and not module.requirements_satisfied():
             logger.debug("Engine requirements not satisfied; skipping: %s", module_name)
             continue
         engine_name = module.ENGINE_NAME
@@ -50,13 +48,13 @@ def create_engine(engine_name, settings=None):
 
 ## Utility functions
 
-# Settings that are specific to our chatbot and shouldn't be passed onto the LLM client
-CHATBOT_SETTING_KEYS = ["env", "enable_api", "chat_engine", "model", "model2", "temperature2", "retrieve_k"]
+# Settings that are specific to LLMs and should be passed onto the LLM client
+LLM_SETTING_KEYS = ["temperature"]
 
 
 @utils.timer
 def create_llm_client(settings):
     llm_name = settings["model"]
-    remaining_settings = {k: settings[k] for k in settings if k not in CHATBOT_SETTING_KEYS}
+    remaining_settings = {k: settings[k] for k in settings if k in LLM_SETTING_KEYS}
     client = llms.init_client(llm_name, remaining_settings)
     return client
```
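Note the flip from a deny-list (`CHATBOT_SETTING_KEYS`) to an allow-list (`LLM_SETTING_KEYS`): a newly added chatbot-level setting can no longer leak into an LLM client constructor by default. A small illustration of the new filtering (the settings values here are made up):

```python
settings = {"model": "mock :: llm", "temperature": 0.1, "retrieve_k": 4, "model2": "mock :: llm"}
LLM_SETTING_KEYS = ["temperature"]

# Only explicitly allow-listed keys are forwarded to the LLM client.
forwarded = {k: settings[k] for k in settings if k in LLM_SETTING_KEYS}
assert forwarded == {"temperature": 0.1}  # retrieve_k, model2, etc. stay behind
```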
6 changes: 2 additions & 4 deletions 05-assistive-chatbot/chatbot/llms/__init__.py

```diff
@@ -4,7 +4,6 @@
 from types import ModuleType
 from typing import Dict, Tuple
 
-import chatbot
 from chatbot import utils
 
 logger = logging.getLogger(__name__)
@@ -26,17 +25,16 @@ def _discover_llms(force=False):
     if not llm_modules:
         llm_modules = utils.scan_modules(__package__)
 
-    settings = chatbot.initial_settings
     for module_name, module in llm_modules.items():
         if not module or ignore(module_name):
             logger.debug("Skipping module: %s", module_name)
             continue
-        if hasattr(module, "requirements_satisfied") and not module.requirements_satisfied(settings):
+        if hasattr(module, "requirements_satisfied") and not module.requirements_satisfied():
             logger.debug("Module requirements not satisfied; skipping: %s", module_name)
             continue
         client_name = module.CLIENT_NAME or module_name
         if hasattr(module, "model_names"):
-            model_names = module.model_names(settings)
+            model_names = module.model_names()
         else:
             model_names = module.MODEL_NAMES
 
```
4 changes: 2 additions & 2 deletions 05-assistive-chatbot/chatbot/llms/dspy_client.py

```diff
@@ -12,9 +12,9 @@
 MODEL_NAMES = _OLLAMA_LLMS + _OPENAI_LLMS + _GOOGLE_LLMS + _GROQ_LLMS
 
 
-def model_names(settings):
+def model_names():
     available_models = []
-    if settings["env"] != "PROD":
+    if os.environ.get("ENV") != "PROD":
         # Include Ollama models if not in production b/c it requires a local Ollama installation
         available_models += _OLLAMA_LLMS
 
```
2 changes: 1 addition & 1 deletion 05-assistive-chatbot/chatbot/llms/google_gemini_client.py

```diff
@@ -9,7 +9,7 @@
 MODEL_NAMES = ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash"]
 
 
-def requirements_satisfied(_settings):
+def requirements_satisfied():
     if not os.environ.get("GOOGLE_API_KEY"):
         return False
     return True
```
11 changes: 7 additions & 4 deletions 05-assistive-chatbot/chatbot/llms/groq_client.py

```diff
@@ -9,7 +9,7 @@
 MODEL_NAMES = ["llama3-70b-8192", "mixtral-8x7b-32768"]
 
 
-def requirements_satisfied(_settings):
+def requirements_satisfied():
     if not os.environ.get("GROQ_API_KEY"):
         return False
     return True
@@ -20,12 +20,15 @@ def init_client(model_name, settings):
 
 
 class GroqClient:
+    INIT_ARGS = ["timeout", "max_retries", "default_headers", "default_query", "base_url", "http_client"]
+
     def __init__(self, model_name, settings):
         self.model_name = model_name
         self.settings = settings
-        logger.info("Creating LLM client '%s' with %s", model_name, self.settings)
-        # TODO: remove temperature from settings
-        self.client = Groq(**self.settings)
+
+        init_settings = {k: settings[k] for k in settings if k in self.INIT_ARGS}
+        logger.info("Creating LLM client '%s' with %s", model_name, init_settings)
+        self.client = Groq(**init_settings)
 
     def generate_reponse(self, message):
         chat_completion = self.client.chat.completions.create(
```
```diff
@@ -9,7 +9,7 @@
 MODEL_NAMES = ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash"]
 
 
-def requirements_satisfied(_settings):
+def requirements_satisfied():
     if not os.environ.get("GOOGLE_API_KEY"):
         return False
     return True
```
7 changes: 4 additions & 3 deletions 05-assistive-chatbot/chatbot/llms/langchain_ollama_client.py

```diff
@@ -1,15 +1,16 @@
 import logging
+import os
 
 from langchain_community.llms.ollama import Ollama
 
 logger = logging.getLogger(__name__)
 
 CLIENT_NAME = "langchain.ollama"
-MODEL_NAMES = ["openhermes", "llama2", "mistral"]
+MODEL_NAMES = ["openhermes", "llama2", "llama2:chat", "llama3", "mistral", "mistral:instruct"]
 
 
-def requirements_satisfied(settings):
-    if settings["env"] == "PROD":
+def requirements_satisfied():
+    if os.environ.get("ENV") == "PROD":
         # Exclude Ollama models in production b/c it requires a local Ollama installation
         return False
     return True
```
17 changes: 16 additions & 1 deletion 05-assistive-chatbot/chatbot/llms/mock_llm_client.py

```diff
@@ -13,9 +13,24 @@ def init_client(model_name, settings):
 class MockLlmClient:
     "Mock client that returns the mock_responses or the message itself."
 
+    sample_qa_pairs = {
+        "test Q1": {
+            "errorMsg": "Some error message",
+            "systemMsg": "Some system message",
+            "content": "Some content",
+        },
+        "test Q2": {
+            "errorMsg": "Some error message",
+            "systemMsg": {"content": "Some system message", "metadata": {"errorObj": 123}},
+            "content": "Some content",
+            "metadata": {"key1": "value1", "key2": "value2"},
+        },
+    }
+
     def __init__(self, model_name, settings):
         logger.info("Creating Mock LLM client '%s' with %s", model_name, settings)
-        self.mock_responses = settings
+
+        self.mock_responses = self.sample_qa_pairs | settings
 
     def generate_reponse(self, message):
         return self.mock_responses.get(message, f"Mock LLM> Your query was: {message}")
```
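The baked-in `sample_qa_pairs` let unit tests exercise structured responses without per-test setup, while the dict union (`|`, Python 3.9+) lets caller-supplied settings override the samples on conflicting keys. A pytest-style sketch, assuming the package layout above:

```python
from chatbot.llms.mock_llm_client import MockLlmClient

def test_mock_llm_responses():
    client = MockLlmClient("mock llm", {"test Q1": "overridden"})
    # caller-supplied settings win over the built-in samples (right side of | wins)
    assert client.generate_reponse("test Q1") == "overridden"
    # built-in sample pairs are available with no extra setup
    assert client.generate_reponse("test Q2")["content"] == "Some content"
    # anything else echoes the query back
    assert client.generate_reponse("other") == "Mock LLM> Your query was: other"
```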
2 changes: 1 addition & 1 deletion 05-assistive-chatbot/chatbot/llms/openai_client.py

```diff
@@ -13,7 +13,7 @@
 MODEL_NAMES = _CHAT_MODELS + _LEGACY_MODELS
 
 
-def requirements_satisfied(_settings):
+def requirements_satisfied():
     if not os.environ.get("OPENAI_API_KEY"):
         return False
     return True
```
7 changes: 7 additions & 0 deletions 05-assistive-chatbot/chatbot/utils.py

```diff
@@ -2,12 +2,19 @@
 import importlib
 import inspect
 import logging
+import os
 import pkgutil
 import pprint
 import textwrap
 import time
 
 
+def is_env_var_true(var_name, default=False):
+    if value := os.environ.get(var_name, None):
+        return value.lower() not in ["false", "f", "no", "n"]
+    return default
+
+
 def timer(func):
     "A decorator that logs the time it takes for the decorated function to run"
     module = inspect.getmodule(func)
```
2 changes: 1 addition & 1 deletion 05-assistive-chatbot/chatbot_api.py

```diff
@@ -35,7 +35,7 @@ class ApiState:
     @cached_property
     def chat_engine(self):
         # Load the initial settings
-        settings = chatbot.initial_settings
+        settings = chatbot.create_init_settings()
         chatbot.validate_settings(settings)
 
         # Create the chat engine
```
2 changes: 1 addition & 1 deletion 05-assistive-chatbot/cmdline.py

```diff
@@ -15,7 +15,7 @@
 logger = logging.getLogger(f"chatbot.{__name__}")
 
 # Load the initial settings
-settings = chatbot.initial_settings
+settings = chatbot.create_init_settings()
 chatbot.validate_settings(settings)
 
 # List LLMs, when CHATBOT_LOG_LEVEL=DEBUG
```
