diff --git a/src/damavand/base/controllers/llm.py b/src/damavand/base/controllers/llm.py
index b45ba63..44915ae 100644
--- a/src/damavand/base/controllers/llm.py
+++ b/src/damavand/base/controllers/llm.py
@@ -1,6 +1,10 @@
+from functools import cache
 import logging
+from typing import Optional
 
 from damavand.base.controllers import ApplicationController
+from damavand.base.controllers.base_controller import runtime
+from damavand.errors import RuntimeException
 
 logger = logging.getLogger(__name__)
 
@@ -12,7 +16,57 @@ class LlmController(ApplicationController):
     def __init__(
         self,
         name,
+        model: Optional[str] = None,
         tags: dict[str, str] = {},
         **kwargs,
     ) -> None:
         ApplicationController.__init__(self, name, tags, **kwargs)
+        self._model_name = model
+
+    @property
+    def model_id(self) -> str:
+        """Return the model name/ID."""
+
+        return self._model_name or "microsoft/Phi-3-mini-4k-instruct"
+
+    @property
+    @runtime
+    @cache
+    def base_url(self) -> str:
+        """Return the base URL for the LLM API."""
+
+        raise NotImplementedError
+
+    @property
+    @runtime
+    @cache
+    def default_api_key(self) -> str:
+        """Return the default API key."""
+
+        raise NotImplementedError
+
+    @property
+    @runtime
+    @cache
+    def chat_completions_url(self) -> str:
+        """Return the chat completions URL."""
+
+        return f"{self.base_url}/chat/completions"
+
+    @property
+    @runtime
+    @cache
+    def client(self) -> "openai.OpenAI":  # type: ignore # noqa
+        """Return an OpenAI client."""
+
+        try:
+            import openai  # type: ignore # noqa
+        except ImportError:
+            raise RuntimeException(
+                "Failed to import the OpenAI library. Damavand provides it as an optional dependency; install it with `pip install damavand[openai]` or add it through your dependency manager."
+            )
+
+        return openai.OpenAI(
+            api_key=self.default_api_key,
+            base_url=f"{self.base_url}",
+        )
diff --git a/src/damavand/cloud/aws/controllers/llm.py b/src/damavand/cloud/aws/controllers/llm.py
index d07d70b..14f05a1 100644
--- a/src/damavand/cloud/aws/controllers/llm.py
+++ b/src/damavand/cloud/aws/controllers/llm.py
@@ -24,17 +24,10 @@ def __init__(
         tags: dict[str, str] = {},
         **kwargs,
     ) -> None:
-        super().__init__(name, tags, **kwargs)
+        super().__init__(name, model, tags, **kwargs)
         self._parameter_store = boto3.client("ssm")
-        self._model_name = model
         self._region = region
 
-    @property
-    def model_id(self) -> str:
-        """Return the model name/ID."""
-
-        return self._model_name or "microsoft/Phi-3-mini-4k-instruct"
-
     @property
     def _base_url_ssm_name(self) -> str:
         """Return the SSM parameter name for the base url."""
@@ -78,28 +71,10 @@ def base_url(self) -> str:
     @property
     @runtime
     @cache
-    def chat_completions_url(self) -> str:
-        """Return the chat completions URL."""
-
-        return f"{self.base_url}/chat/completions"
+    def default_api_key(self) -> str:
+        """Return the default API key."""
 
-    @property
-    @runtime
-    @cache
-    def client(self) -> "openai.OpenAI":  # type: ignore # noqa
-        """Return an OpenAI client."""
-
-        try:
-            import openai  # type: ignore # noqa
-        except ImportError:
-            raise RuntimeException(
-                "Failed to import OpenAI library. Damavand provide this library as an optional dependency. Try to install it using `pip install damavand[openai]` or directly install it using pip or your dependency manager."
-            )
-
-        return openai.OpenAI(
-            api_key="EMPTY",
-            base_url=f"{self.base_url}",
-        )
+        return "EMPTY"
 
     @buildtime
     @cache
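
Usage sketch (not part of the patch): a minimal illustration of how a concrete controller is expected to plug into the refactored base class, assuming the `openai` extra is installed, that `ApplicationController` can be constructed from just a name, and that the properties are accessed in a runtime context (the `@runtime` guard). The class name, endpoint, and key below are placeholders; a subclass only overrides `base_url` and `default_api_key`, while `model_id`, `chat_completions_url`, and `client` are inherited from `LlmController`.

from functools import cache

from damavand.base.controllers.base_controller import runtime
from damavand.base.controllers.llm import LlmController


class ExampleLlmController(LlmController):
    """Hypothetical controller pointing at a local OpenAI-compatible API."""

    @property
    @runtime
    @cache
    def base_url(self) -> str:
        # Placeholder endpoint; real controllers resolve this per cloud provider.
        return "http://localhost:8000/v1"

    @property
    @runtime
    @cache
    def default_api_key(self) -> str:
        # Placeholder credential, mirroring the AWS controller's "EMPTY" default.
        return "EMPTY"


controller = ExampleLlmController("my-llm")
response = controller.client.chat.completions.create(
    model=controller.model_id,  # falls back to "microsoft/Phi-3-mini-4k-instruct"
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)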