From 9a99d034788f24800e98818c34fa4a7a1be90e8e Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Wed, 4 Dec 2024 15:37:09 +0100 Subject: [PATCH 01/11] fix: OpenAI json_schema doesn't validate. --- src/raglite/_eval.py | 14 ++++++++++---- src/raglite/_extract.py | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index f26789c..23ebf48 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -26,9 +26,7 @@ class QuestionResponse(BaseModel): """A specific question about the content of a set of document contexts.""" question: str = Field( - ..., - description="A specific question about the content of a set of document contexts.", - min_length=1, + ..., description="A specific question about the content of a set of document contexts." ) system_prompt: ClassVar[str] = """ You are given a set of contexts extracted from a document. @@ -43,6 +41,9 @@ class QuestionResponse(BaseModel): - The question MUST treat the context as if its contents are entirely part of your working memory. """.strip() + class Config: + extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API json schema. + @field_validator("question") @classmethod def validate_question(cls, value: str) -> str: @@ -112,6 +113,9 @@ class ContextEvalResponse(BaseModel): An example of a context that does NOT contain (a part of) the answer is a table of contents. """.strip() + class Config: + extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API json schema. + relevant_chunks = [] for candidate_chunk in tqdm( candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True @@ -135,7 +139,6 @@ class AnswerResponse(BaseModel): answer: str = Field( ..., description="A complete answer to the given question using the provided context.", - min_length=1, ) system_prompt: ClassVar[str] = f""" You are given a set of contexts extracted from a document. @@ -148,6 +151,9 @@ class AnswerResponse(BaseModel): - The answer MUST treat the context as if its contents are entirely part of your working memory. """.strip() + class Config: + extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API json schema. + try: answer_response = extract_with_llm( AnswerResponse, diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index f3d73ff..323e902 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -52,6 +52,7 @@ class MyNameResponse(BaseModel): "name": return_type.__name__, "description": return_type.__doc__ or "", "schema": return_type.model_json_schema(), + "strict": True, }, } if "response_format" From 85918fc98ea7d3719f0ed3b2f44773211b3631be Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Wed, 4 Dec 2024 17:21:07 +0100 Subject: [PATCH 02/11] fix: test extract with OpenAI --- tests/test_extract.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_extract.py b/tests/test_extract.py index 3ff2a85..23eaf2b 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -13,11 +13,9 @@ params=[ pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"), pytest.param("gpt-4o-mini", id="openai"), - ], + ] ) -def llm( - request: pytest.FixtureRequest, -) -> str: +def llm(request: pytest.FixtureRequest) -> str: """Get an LLM to test RAGLite with.""" llm: str = request.param return llm @@ -34,6 +32,9 @@ class LoginResponse(BaseModel): password: str = Field(..., description="The password.") system_prompt: ClassVar[str] = "Extract the username and password from the input." + class Config: + extra = "forbid" + username, password = "cypher", "steak" login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config) # Validate the response. From 475e461d1f19319d2246a13a44d588128bd8d60d Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 13:36:32 +0100 Subject: [PATCH 03/11] Update src/raglite/_eval.py Co-authored-by: Laurent Sorber --- src/raglite/_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 23ebf48..05653c9 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -42,7 +42,7 @@ class QuestionResponse(BaseModel): """.strip() class Config: - extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API json schema. + extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI's strict mode. @field_validator("question") @classmethod From 266b93aa5e0df00d9a93270c4e2171f7bdca1bec Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 13:36:41 +0100 Subject: [PATCH 04/11] Update src/raglite/_eval.py Co-authored-by: Laurent Sorber --- src/raglite/_eval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 05653c9..2691de6 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -114,8 +114,7 @@ class ContextEvalResponse(BaseModel): """.strip() class Config: - extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API json schema. - + extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API's strict mode. relevant_chunks = [] for candidate_chunk in tqdm( candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True From 54241a58b0748f83676125c9a7c3129220620da6 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 13:36:54 +0100 Subject: [PATCH 05/11] Update src/raglite/_eval.py Co-authored-by: Laurent Sorber --- src/raglite/_eval.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index 2691de6..ceb805a 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -151,8 +151,7 @@ class AnswerResponse(BaseModel): """.strip() class Config: - extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API json schema. - + extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API's strict mode. try: answer_response = extract_with_llm( AnswerResponse, From 6e2d05322b9489da2d1e1d6067ac0d30a5511f3f Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 14:22:24 +0100 Subject: [PATCH 06/11] fix: Remove strict --- src/raglite/_extract.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index 323e902..f3d73ff 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -52,7 +52,6 @@ class MyNameResponse(BaseModel): "name": return_type.__name__, "description": return_type.__doc__ or "", "schema": return_type.model_json_schema(), - "strict": True, }, } if "response_format" From 31ce3fe7beeb8eef1a0e54d7c35e9b4316e7d5e3 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 14:28:17 +0100 Subject: [PATCH 07/11] fix: format --- src/raglite/_eval.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index ceb805a..a163ba5 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -42,7 +42,9 @@ class QuestionResponse(BaseModel): """.strip() class Config: - extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI's strict mode. + extra = ( + "forbid" # Ensure no extra fields are allowed as required by OpenAI's strict mode. + ) @field_validator("question") @classmethod @@ -115,6 +117,7 @@ class ContextEvalResponse(BaseModel): class Config: extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API's strict mode. + relevant_chunks = [] for candidate_chunk in tqdm( candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True @@ -152,6 +155,7 @@ class AnswerResponse(BaseModel): class Config: extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API's strict mode. + try: answer_response = extract_with_llm( AnswerResponse, From 1fcf1773fd1a46f19d1bb24230661501ea74d981 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Thu, 5 Dec 2024 16:06:58 +0100 Subject: [PATCH 08/11] feat: add strict mode option --- src/raglite/_eval.py | 27 +++++++++++++-------------- src/raglite/_extract.py | 12 +++++------- tests/test_extract.py | 26 +++++++++++++++----------- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py index a163ba5..b5dd058 100644 --- a/src/raglite/_eval.py +++ b/src/raglite/_eval.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from sqlmodel import Session, func, select from tqdm.auto import tqdm, trange @@ -25,6 +25,9 @@ def insert_evals( # noqa: C901 class QuestionResponse(BaseModel): """A specific question about the content of a set of document contexts.""" + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) question: str = Field( ..., description="A specific question about the content of a set of document contexts." ) @@ -41,11 +44,6 @@ class QuestionResponse(BaseModel): - The question MUST treat the context as if its contents are entirely part of your working memory. """.strip() - class Config: - extra = ( - "forbid" # Ensure no extra fields are allowed as required by OpenAI's strict mode. - ) - @field_validator("question") @classmethod def validate_question(cls, value: str) -> str: @@ -88,7 +86,7 @@ def validate_question(cls, value: str) -> str: # Extract a question from the seed chunk's related chunks. try: question_response = extract_with_llm( - QuestionResponse, related_chunks, config=config + QuestionResponse, related_chunks, strict=True, config=config ) except ValueError: continue @@ -104,6 +102,9 @@ def validate_question(cls, value: str) -> str: class ContextEvalResponse(BaseModel): """Indicate whether the provided context can be used to answer a given question.""" + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) hit: bool = Field( ..., description="True if the provided context contains (a part of) the answer to the given question, false otherwise.", @@ -115,16 +116,13 @@ class ContextEvalResponse(BaseModel): An example of a context that does NOT contain (a part of) the answer is a table of contents. """.strip() - class Config: - extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API's strict mode. - relevant_chunks = [] for candidate_chunk in tqdm( candidate_chunks, desc="Evaluating chunks", unit="chunk", dynamic_ncols=True ): try: context_eval_response = extract_with_llm( - ContextEvalResponse, str(candidate_chunk), config=config + ContextEvalResponse, str(candidate_chunk), strict=True, config=config ) except ValueError: # noqa: PERF203 pass @@ -138,6 +136,9 @@ class Config: class AnswerResponse(BaseModel): """Answer a question using the provided context.""" + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) answer: str = Field( ..., description="A complete answer to the given question using the provided context.", @@ -153,13 +154,11 @@ class AnswerResponse(BaseModel): - The answer MUST treat the context as if its contents are entirely part of your working memory. """.strip() - class Config: - extra = "forbid" # Ensure no extra fields are allowed as required by OpenAI API's strict mode. - try: answer_response = extract_with_llm( AnswerResponse, [str(relevant_chunk) for relevant_chunk in relevant_chunks], + strict=True, config=config, ) except ValueError: diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index f3d73ff..634f6dc 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -2,7 +2,6 @@ from typing import Any, TypeVar -import litellm from litellm import completion, get_supported_openai_params # type: ignore[attr-defined] from pydantic import BaseModel, ValidationError @@ -14,6 +13,7 @@ def extract_with_llm( return_type: type[T], user_prompt: str | list[str], + strict: bool = False, # noqa: FBT001,FBT002 config: RAGLiteConfig | None = None, **kwargs: Any, ) -> T: @@ -41,8 +41,10 @@ class MyNameResponse(BaseModel): str(return_type.model_json_schema()), ) ) - # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. + # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. Strict mode + # is disabled by default because it only supports a subset of JSON schema features [2]. # [1] https://docs.litellm.ai/docs/completion/json_mode + # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM. llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None response_format: dict[str, Any] | None = ( @@ -52,6 +54,7 @@ class MyNameResponse(BaseModel): "name": return_type.__name__, "description": return_type.__doc__ or "", "schema": return_type.model_json_schema(), + "strict": strict, }, } if "response_format" @@ -64,9 +67,6 @@ class MyNameResponse(BaseModel): f'\n{chunk.strip()}\n' for i, chunk in enumerate(user_prompt) ) - # Enable JSON schema validation. - enable_json_schema_validation = litellm.enable_json_schema_validation - litellm.enable_json_schema_validation = True # Extract structured data from the unstructured input. for _ in range(config.llm_max_tries): response = completion( @@ -89,6 +89,4 @@ class MyNameResponse(BaseModel): else: error_message = f"Failed to extract {return_type} from input {user_prompt}." raise ValueError(error_message) from last_exception - # Restore the previous JSON schema validation setting. - litellm.enable_json_schema_validation = enable_json_schema_validation return instance diff --git a/tests/test_extract.py b/tests/test_extract.py index 23eaf2b..ab9d881 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -3,7 +3,7 @@ from typing import ClassVar import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from raglite import RAGLiteConfig from raglite._extract import extract_with_llm @@ -26,18 +26,22 @@ def test_extract(llm: str) -> None: # Set the LLM. config = RAGLiteConfig(llm=llm) - # Extract structured data. + # Define the JSON schema of the response. class LoginResponse(BaseModel): + model_config = ConfigDict( + extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. + ) username: str = Field(..., description="The username.") password: str = Field(..., description="The password.") system_prompt: ClassVar[str] = "Extract the username and password from the input." - class Config: - extra = "forbid" - - username, password = "cypher", "steak" - login_response = extract_with_llm(LoginResponse, f"{username} // {password}", config=config) - # Validate the response. - assert isinstance(login_response, LoginResponse) - assert login_response.username == username - assert login_response.password == password + for strict in (False, True): + # Extract structured data. + username, password = "cypher", "steak" + login_response = extract_with_llm( + LoginResponse, f"{username} // {password}", strict=strict, config=config + ) + # Validate the response. + assert isinstance(login_response, LoginResponse) + assert login_response.username == username + assert login_response.password == password From a3cf1ed32840b98b907eeff2172baa9f115c1d0d Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Thu, 5 Dec 2024 17:01:46 +0100 Subject: [PATCH 09/11] test: parametrize test_extract --- tests/test_extract.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/tests/test_extract.py b/tests/test_extract.py index ab9d881..90a4e9a 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -21,27 +21,28 @@ def llm(request: pytest.FixtureRequest) -> str: return llm -def test_extract(llm: str) -> None: +@pytest.mark.parametrize( + "strict", + [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")], +) +def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 """Test extracting structured data.""" # Set the LLM. config = RAGLiteConfig(llm=llm) # Define the JSON schema of the response. class LoginResponse(BaseModel): - model_config = ConfigDict( - extra="forbid" # Forbid extra attributes as required by OpenAI's strict mode. - ) + model_config = ConfigDict(extra="forbid" if strict else "allow") username: str = Field(..., description="The username.") password: str = Field(..., description="The password.") system_prompt: ClassVar[str] = "Extract the username and password from the input." - for strict in (False, True): - # Extract structured data. - username, password = "cypher", "steak" - login_response = extract_with_llm( - LoginResponse, f"{username} // {password}", strict=strict, config=config - ) - # Validate the response. - assert isinstance(login_response, LoginResponse) - assert login_response.username == username - assert login_response.password == password + # Extract structured data. + username, password = "cypher", "steak" + login_response = extract_with_llm( + LoginResponse, f"{username} // {password}", strict=strict, config=config + ) + # Validate the response. + assert isinstance(login_response, LoginResponse) + assert login_response.username == username + assert login_response.password == password From cb049a6a412bd30dd09c7161d3f1688a500364ac Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Thu, 5 Dec 2024 17:32:07 +0100 Subject: [PATCH 10/11] fix: Add schema to the prompt conditionally. --- src/raglite/_extract.py | 20 ++++++++++---------- tests/test_extract.py | 5 +++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index 634f6dc..476420b 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -33,20 +33,21 @@ class MyNameResponse(BaseModel): """ # Load the default config if not provided. config = config or RAGLiteConfig() - # Update the system prompt with the JSON schema of the return type to help the LLM. - system_prompt = "\n".join( - ( - return_type.system_prompt.strip(), # type: ignore[attr-defined] - "Format your response according to this JSON schema:", - str(return_type.model_json_schema()), - ) + # Check if the LLM supports the response format. + llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None + supports_response_format = "response_format" in ( + get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [] ) + # Update the system prompt with the JSON schema of the return type to help the LLM. + system_prompt = return_type.system_prompt.strip() # type: ignore[attr-defined] + if not supports_response_format: + system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}" + # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. Strict mode # is disabled by default because it only supports a subset of JSON schema features [2]. # [1] https://docs.litellm.ai/docs/completion/json_mode # [2] https://platform.openai.com/docs/guides/structured-outputs#some-type-specific-keywords-are-not-yet-supported # TODO: Fall back to {"type": "json_object"} if JSON schema is not supported by the LLM. - llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None response_format: dict[str, Any] | None = ( { "type": "json_schema", @@ -57,8 +58,7 @@ class MyNameResponse(BaseModel): "strict": strict, }, } - if "response_format" - in (get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []) + if supports_response_format else None ) # Concatenate the user prompt if it is a list of strings. diff --git a/tests/test_extract.py b/tests/test_extract.py index 90a4e9a..33ef6e0 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -22,8 +22,7 @@ def llm(request: pytest.FixtureRequest) -> str: @pytest.mark.parametrize( - "strict", - [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")], + "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")] ) def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 """Test extracting structured data.""" @@ -32,6 +31,8 @@ def test_extract(llm: str, strict: bool) -> None: # noqa: FBT001 # Define the JSON schema of the response. class LoginResponse(BaseModel): + """The response to a login request.""" + model_config = ConfigDict(extra="forbid" if strict else "allow") username: str = Field(..., description="The username.") password: str = Field(..., description="The password.") From dc279c04561f2e17818362eb31887eff63820108 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Thu, 5 Dec 2024 20:55:42 +0100 Subject: [PATCH 11/11] fix: add schema to system prompt for llama.cpp models --- src/raglite/_extract.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index 476420b..bd85d47 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -35,14 +35,13 @@ class MyNameResponse(BaseModel): config = config or RAGLiteConfig() # Check if the LLM supports the response format. llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None - supports_response_format = "response_format" in ( + llm_supports_response_format = "response_format" in ( get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [] ) # Update the system prompt with the JSON schema of the return type to help the LLM. - system_prompt = return_type.system_prompt.strip() # type: ignore[attr-defined] - if not supports_response_format: + system_prompt = getattr(return_type, "system_prompt", "").strip() + if not llm_supports_response_format or llm_provider == "llama-cpp-python": system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}" - # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. Strict mode # is disabled by default because it only supports a subset of JSON schema features [2]. # [1] https://docs.litellm.ai/docs/completion/json_mode @@ -58,7 +57,7 @@ class MyNameResponse(BaseModel): "strict": strict, }, } - if supports_response_format + if llm_supports_response_format else None ) # Concatenate the user prompt if it is a list of strings.