fix: add and enable OpenAI strict mode #55

Merged · 11 commits · Dec 5, 2024
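In short: `extract_with_llm` gains a `strict` flag that, when enabled, asks OpenAI-compatible LLMs to enforce the response's JSON schema exactly. A minimal sketch of the intended call pattern (the `CityResponse` model below is illustrative, not part of this PR):

from typing import ClassVar

from pydantic import BaseModel, ConfigDict, Field

from raglite import RAGLiteConfig
from raglite._extract import extract_with_llm


class CityResponse(BaseModel):
    """The capital of a country."""

    # Strict mode requires that extra attributes are forbidden.
    model_config = ConfigDict(extra="forbid")
    capital: str = Field(..., description="The capital city of the given country.")
    system_prompt: ClassVar[str] = "Extract the capital of the country in the input."


city = extract_with_llm(
    CityResponse, "France", strict=True, config=RAGLiteConfig(llm="gpt-4o-mini")
)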
21 changes: 14 additions & 7 deletions src/raglite/_eval.py
@@ -5,7 +5,7 @@
 
 import numpy as np
 import pandas as pd
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from sqlmodel import Session, func, select
 from tqdm.auto import tqdm, trange
@@ -25,10 +25,11 @@ def insert_evals(  # noqa: C901
 class QuestionResponse(BaseModel):
     """A specific question about the content of a set of document contexts."""
 
+    model_config = ConfigDict(
+        extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+    )
     question: str = Field(
-        ...,
-        description="A specific question about the content of a set of document contexts.",
-        min_length=1,
+        ..., description="A specific question about the content of a set of document contexts."
     )
     system_prompt: ClassVar[str] = """
        You are given a set of contexts extracted from a document.
@@ -85,7 +86,7 @@ def validate_question(cls, value: str) -> str:
         # Extract a question from the seed chunk's related chunks.
         try:
             question_response = extract_with_llm(
-                QuestionResponse, related_chunks, config=config
+                QuestionResponse, related_chunks, strict=True, config=config
             )
         except ValueError:
             continue
@@ -101,6 +102,9 @@ def validate_question(cls, value: str) -> str:
 class ContextEvalResponse(BaseModel):
     """Indicate whether the provided context can be used to answer a given question."""
 
+    model_config = ConfigDict(
+        extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+    )
     hit: bool = Field(
         ...,
         description="True if the provided context contains (a part of) the answer to the given question, false otherwise.",
@@ -118,7 +122,7 @@ class ContextEvalResponse(BaseModel):
     ):
         try:
             context_eval_response = extract_with_llm(
-                ContextEvalResponse, str(candidate_chunk), config=config
+                ContextEvalResponse, str(candidate_chunk), strict=True, config=config
             )
         except ValueError:  # noqa: PERF203
             pass
@@ -132,10 +136,12 @@ class ContextEvalResponse(BaseModel):
 class AnswerResponse(BaseModel):
     """Answer a question using the provided context."""
 
+    model_config = ConfigDict(
+        extra="forbid"  # Forbid extra attributes as required by OpenAI's strict mode.
+    )
     answer: str = Field(
         ...,
         description="A complete answer to the given question using the provided context.",
-        min_length=1,
     )
     system_prompt: ClassVar[str] = f"""
        You are given a set of contexts extracted from a document.
@@ -152,6 +158,7 @@ class AnswerResponse(BaseModel):
             answer_response = extract_with_llm(
                 AnswerResponse,
                 [str(relevant_chunk) for relevant_chunk in relevant_chunks],
+                strict=True,
                 config=config,
             )
         except ValueError:
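Aside on the `_eval.py` changes above: OpenAI's strict mode requires `"additionalProperties": false` in the JSON schema, and Pydantic emits that key as false only when extra attributes are forbidden, which is why each response model gains `ConfigDict(extra="forbid")`. The `min_length=1` constraints are dropped because strict mode supports only a subset of JSON schema keywords (see link [2] in the diff below). A quick runnable sketch of the schema difference (model names are illustrative):

from pydantic import BaseModel, ConfigDict


class Forbidding(BaseModel):
    model_config = ConfigDict(extra="forbid")
    question: str


class Permissive(BaseModel):
    question: str


# Only the extra="forbid" model emits the key that strict mode requires.
assert Forbidding.model_json_schema()["additionalProperties"] is False
assert "additionalProperties" not in Permissive.model_json_schema()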
31 changes: 14 additions & 17 deletions src/raglite/_extract.py
@@ -2,7 +2,6 @@
 
 from typing import Any, TypeVar
 
-import litellm
 from litellm import completion, get_supported_openai_params  # type: ignore[attr-defined]
 from pydantic import BaseModel, ValidationError
 
@@ -14,6 +13,7 @@
 def extract_with_llm(
     return_type: type[T],
     user_prompt: str | list[str],
+    strict: bool = False,  # noqa: FBT001,FBT002
     config: RAGLiteConfig | None = None,
     **kwargs: Any,
 ) -> T:
@@ -33,29 +33,31 @@ class MyNameResponse(BaseModel):
     """
     # Load the default config if not provided.
     config = config or RAGLiteConfig()
-    # Update the system prompt with the JSON schema of the return type to help the LLM.
-    system_prompt = "\n".join(
-        (
-            return_type.system_prompt.strip(),  # type: ignore[attr-defined]
-            "Format your response according to this JSON schema:",
-            str(return_type.model_json_schema()),
-        )
-    )
-    # Constrain the response format to the JSON schema if it's supported by the LLM [1].
+    # Check if the LLM supports the response format.
+    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+    llm_supports_response_format = "response_format" in (
+        get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
+    )
+    # Update the system prompt with the JSON schema of the return type to help the LLM.
+    system_prompt = getattr(return_type, "system_prompt", "").strip()
+    if not llm_supports_response_format or llm_provider == "llama-cpp-python":
+        system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
+    # Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
+    # is disabled by default because it only supports a subset of JSON schema features [2].
     # [1] https://docs.litellm.ai/docs/completion/json_mode
+    # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported
     # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM.
-    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
     response_format: dict[str, Any] | None = (
         {
             "type": "json_schema",
             "json_schema": {
                 "name": return_type.__name__,
                 "description": return_type.__doc__ or "",
                 "schema": return_type.model_json_schema(),
+                "strict": strict,
             },
         }
-        if "response_format"
-        in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [])
+        if llm_supports_response_format
         else None
     )
     # Concatenate the user prompt if it is a list of strings.
@@ -64,9 +66,6 @@ class MyNameResponse(BaseModel):
         f'<context index="{i + 1}">\n{chunk.strip()}\n</context>'
         for i, chunk in enumerate(user_prompt)
     )
-    # Enable JSON schema validation.
-    enable_json_schema_validation = litellm.enable_json_schema_validation
-    litellm.enable_json_schema_validation = True
     # Extract structured data from the unstructured input.
     for _ in range(config.llm_max_tries):
         response = completion(
@@ -89,6 +88,4 @@ class MyNameResponse(BaseModel):
     else:
         error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message) from last_exception
-    # Restore the previous JSON schema validation setting.
-    litellm.enable_json_schema_validation = enable_json_schema_validation
     return instance
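To make the `_extract.py` change above concrete, here is roughly the `response_format` payload that `extract_with_llm` now hands to LiteLLM's `completion()` for a schema-capable LLM (reusing the `LoginResponse` model from the tests below):

from pydantic import BaseModel, ConfigDict, Field


class LoginResponse(BaseModel):
    """The response to a login request."""

    model_config = ConfigDict(extra="forbid")
    username: str = Field(..., description="The username.")
    password: str = Field(..., description="The password.")


# Mirrors the payload built in extract_with_llm; strict defaults to False because
# strict mode supports only a subset of JSON schema features.
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": LoginResponse.__name__,
        "description": LoginResponse.__doc__ or "",
        "schema": LoginResponse.model_json_schema(),
        "strict": True,
    },
}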
23 changes: 15 additions & 8 deletions tests/test_extract.py
@@ -3,7 +3,7 @@
 from typing import ClassVar
 
 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from raglite import RAGLiteConfig
 from raglite._extract import extract_with_llm
@@ -13,29 +13,36 @@
     params=[
         pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"),
         pytest.param("gpt-4o-mini", id="openai"),
-    ],
+    ]
 )
-def llm(
-    request: pytest.FixtureRequest,
-) -> str:
+def llm(request: pytest.FixtureRequest) -> str:
     """Get an LLM to test RAGLite with."""
     llm: str = request.param
     return llm
 
 
-def test_extract(llm: str) -> None:
+@pytest.mark.parametrize(
+    "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")]
+)
+def test_extract(llm: str, strict: bool) -> None:  # noqa: FBT001
     """Test extracting structured data."""
     # Set the LLM.
     config = RAGLiteConfig(llm=llm)
 
-    # Extract structured data.
+    # Define the JSON schema of the response.
     class LoginResponse(BaseModel):
         """The response to a login request."""
 
+        model_config = ConfigDict(extra="forbid" if strict else "allow")
         username: str = Field(..., description="The username.")
         password: str = Field(..., description="The password.")
         system_prompt: ClassVar[str] = "Extract the username and password from the input."
 
+    # Extract structured data.
     username, password = "cypher", "steak"
-    login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config)
+    login_response = extract_with_llm(
+        LoginResponse, f"{username} // {password}", strict=strict, config=config
+    )
     # Validate the response.
     assert isinstance(login_response, LoginResponse)
     assert login_response.username == username
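Finally, note the fallback path in `_extract.py`: when the provider doesn't advertise `response_format` support (or is llama-cpp-python), no `response_format` is sent and the JSON schema is appended to the system prompt instead. A self-contained sketch of that fallback, with an illustrative model:

from pydantic import BaseModel


class Fruit(BaseModel):
    """A fruit and its color."""

    name: str
    color: str


# Mirrors the fallback in extract_with_llm: guide the LLM with the schema in the
# system prompt instead of constraining the response format server-side.
system_prompt = "Extract the fruit from the input."
system_prompt += f"\n\nFormat your response according to this JSON schema:\n{Fruit.model_json_schema()!s}"
print(system_prompt)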