Skip to content

Commit

Permalink
Merge pull request #96 from aurelio-labs/bogdan/local-llm
Browse files Browse the repository at this point in the history
feat: Adds LlamaCpp LLM
  • Loading branch information
jamescalam authored Jan 13, 2024
2 parents 666b361 + 2285771 commit 311e909
Show file tree
Hide file tree
Showing 12 changed files with 1,083 additions and 117 deletions.
699 changes: 699 additions & 0 deletions docs/05-local-execution.ipynb

Large diffs are not rendered by default.

114 changes: 73 additions & 41 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ pinecone-text = {version = "^0.7.1", optional = true}
fastembed = {version = "^0.1.3", optional = true, python = "<3.12"}
torch = {version = "^2.1.2", optional = true}
transformers = {version = "^4.36.2", optional = true}
llama-cpp-python = {version = "^0.2.28", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
fastembed = ["fastembed"]
local = ["torch", "transformers"]
local = ["torch", "transformers", "llama-cpp-python"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
3 changes: 2 additions & 1 deletion semantic_router/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM

__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM"]
__all__ = ["BaseLLM", "OpenAILLM", "OpenRouterLLM", "CohereLLM", "LlamaCppLLM"]
74 changes: 73 additions & 1 deletion semantic_router/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import List, Optional
import json
from typing import Any, List, Optional

from pydantic import BaseModel

from semantic_router.schema import Message
from semantic_router.utils.logger import logger


class BaseLLM(BaseModel):
Expand All @@ -11,5 +13,75 @@ class BaseLLM(BaseModel):
class Config:
arbitrary_types_allowed = True

def __init__(self, name: str, **kwargs):
    """Create the LLM wrapper, recording its model *name*.

    :param name: Identifier of the underlying model (e.g. "gpt-3.5-turbo").
    :param kwargs: Extra fields forwarded untouched to pydantic.
    """
    # Route everything through the pydantic BaseModel initializer so
    # field validation still runs on subclasses' extra fields.
    super().__init__(name=name, **kwargs)

def __call__(self, messages: List[Message]) -> Optional[str]:
    """Generate a completion for a chat-style message list.

    Abstract hook: every concrete LLM (OpenAI, Cohere, LlamaCpp, ...)
    must override this with its own inference call.

    :param messages: Ordered conversation messages to send to the model.
    :return: The model's text reply, or None when nothing was produced.
    :raises NotImplementedError: always, on the base class itself.
    """
    raise NotImplementedError("Subclasses must implement this method")

def _is_valid_inputs(
    self, inputs: dict[str, Any], function_schema: dict[str, Any]
) -> bool:
    """Check that every parameter named in the schema's signature string
    has a corresponding key in the extracted *inputs*.

    Only parameter *presence* is validated; declared types are not
    checked against the values. (The original code parsed the type
    strings too, but never used them — that dead work, and the bug it
    hid, are removed here.)

    :param inputs: Mapping of parameter name -> extracted value.
    :param function_schema: Schema dict containing a "signature" string
        such as "(location: str, degree: str) -> str".
    :return: True when all required parameter names are present; False
        on any missing name or on a malformed signature.
    """
    try:
        signature = function_schema["signature"]
        # Drop any "-> ReturnType" suffix before stripping the parens;
        # the original [1:-1] slice left the return annotation glued to
        # the last parameter's type string.
        param_section = signature.split("->")[0].strip()
        inner = param_section[1:-1].strip()
        if not inner:
            # Zero-parameter function: nothing is required, so any
            # inputs dict is acceptable. (Previously this path raised
            # internally and wrongly returned False.)
            return True
        # NOTE(review): a naive comma split breaks on nested generics
        # like dict[str, int] in the signature — same limitation as the
        # original parser; confirm callers only emit simple annotations.
        param_names = [part.split(":")[0].strip() for part in inner.split(",")]
        for name in param_names:
            if name not in inputs:
                logger.error(f"Input {name} missing from query")
                return False
        return True
    except Exception as e:
        logger.error(f"Input validation error: {str(e)}")
        return False

def extract_function_inputs(
    self, query: str, function_schema: dict[str, Any]
) -> dict:
    """Ask the LLM to extract the target function's argument values from
    a natural-language *query*, returning them as a dict.

    :param query: User query to pull argument values from.
    :param function_schema: Schema dict (name/description/signature/...)
        describing the function whose inputs we want.
    :return: JSON object mapping parameter names to extracted values.
    :raises Exception: if the LLM returns no output at all.
    :raises ValueError: if the parsed inputs fail schema validation.
    """
    logger.info("Extracting function input...")

    # Few-shot prompt: the example answer now matches its example query
    # (the original showed "London" for a question about Hawaii, which
    # actively taught the model to hallucinate values).
    prompt = f"""
    You are a helpful assistant designed to output JSON.
    Given the following function schema
    << {function_schema} >>
    and query
    << {query} >>
    extract the parameters values from the query, in a valid JSON format.
    Example:
    Input:
    query: "How is the weather in Hawaii right now in International units?"
    schema:
    {{
        "name": "get_weather",
        "description": "Useful to get the weather in a specific location",
        "signature": "(location: str, degree: str) -> str",
        "output": "<class 'str'>",
    }}
    Result: {{
        "location": "Hawaii",
        "degree": "Celsius",
    }}
    Input:
    query: {query}
    schema: {function_schema}
    Result:
    """
    llm_input = [Message(role="user", content=prompt)]
    output = self(llm_input)
    if not output:
        raise Exception("No output generated for extract function input")

    output = output.strip().rstrip(",")
    try:
        # Parse the model output as-is first: blindly swapping every
        # single quote (as before) corrupts legitimate apostrophes
        # inside string values ("Hawai'i").
        function_inputs = json.loads(output)
    except json.JSONDecodeError:
        # Fallback for models that emit single-quoted pseudo-JSON.
        function_inputs = json.loads(output.replace("'", '"'))

    if not self._is_valid_inputs(function_inputs, function_schema):
        raise ValueError("Invalid inputs")
    return function_inputs
25 changes: 25 additions & 0 deletions semantic_router/llms/grammars/json.gbnf
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# GBNF grammar that constrains llama.cpp token sampling to valid JSON.
# The start symbol requires a JSON *object* (not a bare string/number),
# matching what extract_function_inputs expects to json.loads().
root ::= object
# Any JSON value. Literals carry trailing optional whitespace via `ws`.
value ::= object | array | string | number | ("true" | "false" | "null") ws

# `{ "k": v, ... }` — the inner group is optional, so `{}` is allowed.
object ::=
  "{" ws (
            string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

# `[ v, ... ]` — the inner group is optional, so `[]` is allowed.
array  ::=
  "[" ws (
            value
    ("," ws value)*
  )? "]" ws

# Double-quoted string: any char except `"` and `\`, or a JSON escape
# sequence (including \uXXXX with exactly four hex digits).
string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

# JSON number: optional sign, no leading zeros on multi-digit integers,
# optional fraction and exponent.
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
Loading

0 comments on commit 311e909

Please sign in to comment.