diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py
index 31bf279..9f1d0ea 100644
--- a/src/raglite/_litellm.py
+++ b/src/raglite/_litellm.py
@@ -22,6 +22,7 @@
     get_model_info,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import custom_llm_setup
 from llama_cpp import (  # type: ignore[attr-defined]
     ChatCompletionRequestMessage,
     CreateChatCompletionResponse,
@@ -33,6 +34,7 @@
 from raglite._config import RAGLiteConfig
 
 # Reduce the logging level for LiteLLM, flashrank, and httpx.
+litellm.suppress_debug_info = True
 os.environ["LITELLM_LOG"] = "WARNING"
 logging.getLogger("LiteLLM").setLevel(logging.WARNING)
 logging.getLogger("flashrank").setLevel(logging.WARNING)
@@ -125,24 +127,23 @@ def llm(model: str, **kwargs: Any) -> Llama:
         # Enable caching.
         llm.set_cache(LlamaRAMCache())
         # Register the model info with LiteLLM.
-        litellm.register_model(  # type: ignore[attr-defined]
-            {
-                model: {
-                    "max_tokens": llm.n_ctx(),
-                    "max_input_tokens": llm.n_ctx(),
-                    "max_output_tokens": None,
-                    "input_cost_per_token": 0.0,
-                    "output_cost_per_token": 0.0,
-                    "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
-                    "litellm_provider": "llama-cpp-python",
-                    "mode": "embedding" if kwargs.get("embedding") else "completion",
-                    "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
-                    "supports_function_calling": True,
-                    "supports_parallel_function_calling": True,
-                    "supports_vision": False,
-                }
+        model_info = {
+            model: {
+                "max_tokens": llm.n_ctx(),
+                "max_input_tokens": llm.n_ctx(),
+                "max_output_tokens": None,
+                "input_cost_per_token": 0.0,
+                "output_cost_per_token": 0.0,
+                "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
+                "litellm_provider": "llama-cpp-python",
+                "mode": "embedding" if kwargs.get("embedding") else "completion",
+                "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
+                "supports_function_calling": True,
+                "supports_parallel_function_calling": True,
+                "supports_vision": False,
             }
-        )
+        }
+        litellm.register_model(model_info)  # type: ignore[attr-defined]
         return llm
 
     def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, Any]:
@@ -307,7 +308,7 @@ async def astreaming(  # type: ignore[misc,override]  # noqa: PLR0913
     litellm.custom_provider_map.append(
         {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
     )
-    litellm.suppress_debug_info = True
+    custom_llm_setup()  # type: ignore[no-untyped-call]
 
 
 @cache
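For context, the last hunk relies on LiteLLM's custom-provider mechanism: a handler is appended to `litellm.custom_provider_map` and `custom_llm_setup()` is then called so LiteLLM re-reads the map and routes `"<provider>/<model>"` strings to the handler. Below is a minimal, standalone sketch of that pattern (not raglite code); the `echo-llm` provider name, the `EchoLLM` class, and the model string are made up for illustration.

```python
# Sketch of the LiteLLM custom-provider pattern used in the diff above:
# append a handler to litellm.custom_provider_map, then call custom_llm_setup().
# The "echo-llm" provider, EchoLLM class, and model string are illustrative only.
import litellm
from litellm import CustomLLM, ModelResponse
from litellm.utils import custom_llm_setup


class EchoLLM(CustomLLM):
    def completion(self, *args: object, **kwargs: object) -> ModelResponse:
        # A real handler would call its backing model here; return a canned reply instead.
        return litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "ping"}],
            mock_response="pong",
        )


litellm.custom_provider_map.append({"provider": "echo-llm", "custom_handler": EchoLLM()})
custom_llm_setup()  # type: ignore[no-untyped-call]

response = litellm.completion(
    model="echo-llm/any-model", messages=[{"role": "user", "content": "ping"}]
)
print(response.choices[0].message.content)  # -> "pong"
```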