diff --git a/poetry.lock b/poetry.lock index a96a8c614..b40ab0626 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "alfred-cli" @@ -2192,13 +2192,13 @@ files = [ [[package]] name = "writer-sdk" -version = "0.1.2" +version = "0.5.0" description = "The official Python library for the writer API" optional = false python-versions = ">=3.7" files = [ - {file = "writer_sdk-0.1.2-py3-none-any.whl", hash = "sha256:2072f5c13b8011a0c2ebe1b63fbc041642e9e9ca6eab7e976f7b3f20ad9931e5"}, - {file = "writer_sdk-0.1.2.tar.gz", hash = "sha256:3fa0b09a57ba969ca344024fa46305cd441cb45cb34d26d65a51a743477a4315"}, + {file = "writer_sdk-0.5.0-py3-none-any.whl", hash = "sha256:654e08fa0040126b8a8ae2468a9f2185997d94926a9322cee6278a721d098598"}, + {file = "writer_sdk-0.5.0.tar.gz", hash = "sha256:984f289f2576f8fbaec6893e9c83a0810a3084fbc536cc22666e0f620e8eff64"}, ] [package.dependencies] @@ -2212,4 +2212,4 @@ typing-extensions = ">=4.7,<5" [metadata] lock-version = "2.0" python-versions = ">=3.9.2, <4.0" -content-hash = "2b7dece7f8e0c504b0d1c393f324b8f86e382910106db130a9cfc129786fc138" +content-hash = "087a35e24f0286c5063c952de4bf460b036579e4956f1e598b9ac2172fa82cdd" diff --git a/pyproject.toml b/pyproject.toml index 217785bb0..fdf78e87b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ requests = "^2.31.0" uvicorn = ">= 0.20.0, < 1" watchdog = ">= 3.0.0, < 4" websockets = ">= 12, < 13" -writer-sdk = ">= 0.1.2, < 1" +writer-sdk = ">= 0.5.0, < 1" [tool.poetry.group.build] diff --git a/src/writer/ai.py b/src/writer/ai.py index 5e5e0c5a4..ae9655de2 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -1,31 +1,58 @@ import logging -from typing import Generator, Iterable, List, Literal, Optional, TypedDict, Union, cast +from datetime import datetime +from typing import ( + Generator, + Iterable, + List, + Literal, + Optional, + Set, + TypedDict, + Union, + cast, +) +from uuid import uuid4 from httpx import Timeout from writerai import Writer from writerai._exceptions import WriterError +from writerai._response import BinaryAPIResponse from writerai._streaming import Stream from writerai._types import Body, Headers, NotGiven, Query -from writerai.types import Chat, Completion, StreamingData +from writerai.resources import FilesResource, GraphsResource +from writerai.types import ( + Chat, + Completion, + FileDeleteResponse, + GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, + GraphUpdateResponse, + StreamingData, +) +from writerai.types import File as SDKFile +from writerai.types import Graph as SDKGraph from writerai.types.chat_chat_params import Message as WriterAIMessage from writer.core import get_app_process -class ChatOptions(TypedDict, total=False): +class APIOptions(TypedDict, total=False): + extra_headers: Optional[Headers] + extra_query: Optional[Query] + extra_body: Optional[Body] + timeout: Union[float, Timeout, None, NotGiven] + + +class ChatOptions(APIOptions, total=False): model: str max_tokens: Union[int, NotGiven] n: Union[int, NotGiven] stop: Union[List[str], str, NotGiven] temperature: Union[float, NotGiven] top_p: Union[float, NotGiven] - extra_headers: Optional[Headers] - extra_query: Optional[Query] - extra_body: Optional[Body] - timeout: Union[float, Timeout, None, NotGiven] -class CreateOptions(TypedDict, total=False): +class CreateOptions(APIOptions, total=False): model: str best_of: Union[int, NotGiven] max_tokens: Union[int, NotGiven] @@ -33,10 +60,13 @@ class CreateOptions(TypedDict, total=False): stop: Union[List[str], str, NotGiven] temperature: Union[float, NotGiven] top_p: Union[float, NotGiven] - extra_headers: Optional[Headers] - extra_query: Optional[Query] - extra_body: Optional[Body] - timeout: Union[float, Timeout, None, NotGiven] + + +class APIListOptions(APIOptions, total=False): + after: Union[str, NotGiven] + before: Union[str, NotGiven] + limit: Union[int, NotGiven] + order: Union[Literal["asc", "desc"], NotGiven] logger = logging.Logger(__name__) @@ -141,6 +171,510 @@ def acquire_client(cls) -> Writer: return instance.client +class SDKWrapper: + """ + A wrapper class for SDK objects, allowing dynamic access to properties. + + Attributes: + _wrapped (Union[SDKFile, SDKGraph]): The wrapped SDK object. + """ + _wrapped: Union[SDKFile, SDKGraph] + + def _get_property(self, property_name): + """ + Retrieves a property from the wrapped object. + + :param property_name: The name of the property to retrieve. + :type property_name: str + :returns: The value of the requested property. + :raises AttributeError: If the property does not exist. + """ + try: + return getattr(self._wrapped, property_name) + except AttributeError: + raise AttributeError( + f"type object '{self.__class__}' has no attribute {property_name}" + ) from None + + +class Graph(SDKWrapper): + """ + A wrapper class for SDKGraph objects, providing additional functionality. + + Attributes: + _wrapped (writerai.types.Graph): The wrapped SDK Graph object. + stale_ids (set): A set of stale graph IDs that need updates. + """ + _wrapped: SDKGraph + stale_ids: Set[str] = set() + + def __init__( + self, + graph_object: SDKGraph + ): + """ + Initializes the Graph with the given SDKGraph object. + + :param graph_object: The SDKGraph object to wrap. + :type graph_object: writerai.types.Graph + """ + self._wrapped = graph_object + + @staticmethod + def _retrieve_graphs_accessor() -> GraphsResource: + """ + Acquires the graphs accessor from the WriterAIManager singleton instance. + + :returns: The graphs accessor instance. + :rtype: GraphsResource + """ + return WriterAIManager.acquire_client().graphs + + @property + def id(self) -> str: + return self._get_property('id') + + @property + def created_at(self) -> datetime: + return self._get_property('created_at') + + def _fetch_object_updates(self): + """ + Fetches updates for the graph object if it is stale. + """ + if self.id in Graph.stale_ids: + graphs = self._retrieve_graphs_accessor() + fresh_object = graphs.retrieve(self.id) + self._wrapped = fresh_object + Graph.stale_ids.remove(self.id) + + @property + def name(self) -> str: + self._fetch_object_updates() + return self._wrapped.name + + @property + def description(self) -> Optional[str]: + self._fetch_object_updates() + return self._wrapped.description + + @property + def file_status(self): + self._fetch_object_updates() + return self._wrapped.file_status + + def update( + self, + name: Optional[str] = None, + description: Optional[str] = None, + config: Optional[APIOptions] = None + ) -> GraphUpdateResponse: + """ + Updates the graph with the given parameters. + + :param name: The new name for the graph. + :type name: Optional[str] + :param description: The new description for the graph. + :type description: Optional[str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The response from the update operation. + :rtype: GraphUpdateResponse + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + + # We use the payload dictionary + # to distinguish between None-values + # and NotGiven values + payload = {} + if name: + payload["name"] = name + if description: + payload["description"] = description + graphs = self._retrieve_graphs_accessor() + response = graphs.update(self.id, **payload, **config) + Graph.stale_ids.add(self.id) + return response + + def add_file( + self, + file_id_or_file: Union['File', str], + config: Optional[APIOptions] = None + ) -> 'File': + """ + Adds a file to the graph. + + :param file_id_or_file: The file object or file ID to add. + :type file_id_or_file: Union['File', str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The added file object. + :rtype: File + :raises ValueError: If the input is neither a File object nor a file ID string. + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + file_id = None + if isinstance(file_id_or_file, File): + file_id = file_id_or_file.id + elif isinstance(file_id_or_file, str): + file_id = file_id_or_file + else: + raise ValueError( + "'Graph.add_file' method accepts either 'File' object" + + f" or ID of file as string; got '{type(file_id_or_file)}'" + ) + graphs = self._retrieve_graphs_accessor() + response = graphs.add_file_to_graph( + graph_id=self.id, + file_id=file_id, + **config + ) + Graph.stale_ids.add(self.id) + return File(response) + + def remove_file( + self, + file_id_or_file: Union['File', str], + config: Optional[APIOptions] = None + ) -> Optional[GraphRemoveFileFromGraphResponse]: + """ + Removes a file from the graph. + + :param file_id_or_file: The file object or file ID to remove. + :type file_id_or_file: Union['File', str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The response from the remove operation. + :rtype: Optional[GraphRemoveFileFromGraphResponse] + :raises ValueError: If the input is neither a File object nor a file ID string. + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + file_id = None + if isinstance(file_id_or_file, File): + file_id = file_id_or_file.id + elif isinstance(file_id_or_file, str): + file_id = file_id_or_file + else: + raise ValueError( + "'Graph.remove_file' method accepts either 'File' object" + + f" or ID of file as string; got '{type(file_id_or_file)}'" + ) + graphs = self._retrieve_graphs_accessor() + response = graphs.remove_file_from_graph( + graph_id=self.id, + file_id=file_id + ) + Graph.stale_ids.add(self.id) + return response + + +def create_graph( + name: str, + description: Optional[str] = None, + config: Optional[APIOptions] = None + ) -> Graph: + """ + Creates a new graph with the given parameters. + + :param name: The name of the graph. + :type name: str + :param description: The description of the graph. + :type description: Optional[str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The created graph object. + :rtype: Graph + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + graphs = Graph._retrieve_graphs_accessor() + graph_object = graphs.create(name=name, description=description or NotGiven(), **config) + converted_object = cast(SDKGraph, graph_object) + graph = Graph(converted_object) + return graph + + +def retrieve_graph( + graph_id: str, + config: Optional[APIOptions] = None + ) -> Graph: + """ + Retrieves a graph by its ID. + + :param graph_id: The ID of the graph to retrieve. + :type graph_id: str + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The retrieved graph object. + :rtype: Graph + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + graphs = Graph._retrieve_graphs_accessor() + graph_object = graphs.retrieve(graph_id, **config) + graph = Graph(graph_object) + return graph + + +def list_graphs(config: Optional[APIListOptions] = None) -> List[Graph]: + """ + Lists all graphs with the given configuration. + + :param config: Additional configuration options. + :type config: Optional[APIListOptions] + :returns: A list of graph objects. + :rtype: List[Graph] + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + - `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor. + - `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor. + - `limit` (Union[int, NotGiven]): The number of items to retrieve. + - `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items. + """ + config = config or {} + graphs = Graph._retrieve_graphs_accessor() + sdk_graphs = graphs.list(**config) + return [Graph(sdk_graph) for sdk_graph in sdk_graphs] + + +def delete_graph(graph_id_or_graph: Union[Graph, str]) -> GraphDeleteResponse: + """ + Deletes a graph by its ID or object. + + :param graph_id_or_graph: The graph object or graph ID to delete. + :type graph_id_or_graph: Union[Graph, str] + :returns: The response from the delete operation. + :rtype: GraphDeleteResponse + :raises ValueError: If the input is neither a Graph object nor a graph ID string. + """ + graph_id = None + if isinstance(graph_id_or_graph, Graph): + graph_id = graph_id_or_graph.id + elif isinstance(graph_id_or_graph, str): + graph_id = graph_id_or_graph + else: + raise ValueError( + "'delete_graph' method accepts either 'Graph' object" + + f" or ID of graph as string; got '{type(graph_id_or_graph)}'" + ) + graphs = Graph._retrieve_graphs_accessor() + return graphs.delete(graph_id) + + +class File(SDKWrapper): + """ + A wrapper class for SDK File objects, providing additional functionality. + + Attributes: + _wrapped (writerai.types.File): The wrapped SDKFile object. + """ + _wrapped: SDKFile + + def __init__(self, file_object: SDKFile): + """ + Initializes the File with the given SDKFile object. + + :param file_object: The SDKFile object to wrap. + :type file_object: writerai.types.File + """ + self._wrapped = file_object + + @staticmethod + def _retrieve_files_accessor() -> FilesResource: + """ + Acquires the files client from the WriterAIManager singleton instance. + + :returns: The files client instance. + :rtype: FilesResource + """ + return WriterAIManager.acquire_client().files + + @property + def id(self) -> str: + return self._get_property('id') + + @property + def created_at(self) -> datetime: + return self._get_property('created_at') + + @property + def graph_ids(self) -> List[str]: + return self._get_property('graph_ids') + + @property + def name(self) -> str: + return self._get_property('name') + + def download(self) -> BinaryAPIResponse: + """ + Downloads the file content. + + :returns: The response containing the file content. + :rtype: BinaryAPIResponse + """ + files = self._retrieve_files_accessor() + return files.download(self.id) + + +def retrieve_file(file_id: str, config: Optional[APIOptions] = None) -> File: + """ + Retrieves a file by its ID. + + :param file_id: The ID of the file to retrieve. + :type file_id: str + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The retrieved file object. + :rtype: writerai.types.File + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + files = File._retrieve_files_accessor() + file_object = files.retrieve(file_id, **config) + file = File(file_object) + return file + + +def list_files(config: Optional[APIListOptions] = None) -> List[File]: + """ + Lists all files with the given configuration. + + :param config: Additional configuration options. + :type config: Optional[APIListOptions] + :returns: A list of file objects. + :rtype: List[File] + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + - `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor. + - `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor. + - `limit` (Union[int, NotGiven]): The number of items to retrieve. + - `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items. + """ + config = config or {} + files = File._retrieve_files_accessor() + sdk_files = files.list(**config) + return [File(sdk_file) for sdk_file in sdk_files] + + +def upload_file( + data: bytes, + type: str, + name: Optional[str] = None, + config: Optional[APIOptions] = None + ) -> File: + """ + Uploads a new file with the given parameters. + + :param data: The file content as bytes. + :type data: bytes + :param type: The MIME type of the file. + :type type: str + :param name: The name of the file. + :type name: Optional[str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The uploaded file object. + :rtype: writerai.types.File + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + files = File._retrieve_files_accessor() + + file_name = name or f"WF-{type}-{uuid4()}" + content_disposition = f'attachment; filename="{file_name}"' + + # Now calling the upload method with correct types. + sdk_file = files.upload( + content=data, + content_type=type, + content_disposition=content_disposition + ) + return File(sdk_file) + + +def delete_file( + file_id_or_file: Union['File', str], + config: Optional[APIOptions] = None + ) -> FileDeleteResponse: + """ + Deletes a file by its ID or object. + + :param file_id_or_file: The file object or file ID to delete. + :type file_id_or_file: Union['File', str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The response from the delete operation. + :rtype: FileDeleteResponse + :raises ValueError: If the input is neither a File object nor a file ID string. + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ + config = config or {} + file_id = None + if isinstance(file_id_or_file, File): + file_id = file_id_or_file.id + elif isinstance(file_id_or_file, str): + file_id = file_id_or_file + else: + raise ValueError( + "'delete_file' method accepts either 'File' object" + + f" or ID of file as string; got '{type(file_id_or_file)}'" + ) + + files = File._retrieve_files_accessor() + return files.delete(file_id, **config) + + class Conversation: """ Manages messages within a conversation flow with an AI system, including message validation, @@ -307,8 +841,7 @@ def complete(self, config: Optional['ChatOptions'] = None) -> 'Conversation.Mess :return: Generated message. :raises RuntimeError: If response data was not properly formatted to retrieve model text. """ - if not config: - config = {'max_tokens': 2048} + config = config or {'max_tokens': 2048} client = WriterAIManager.acquire_client() passed_messages: Iterable[WriterAIMessage] = [self._prepare_message(message) for message in self.messages] @@ -343,8 +876,7 @@ def stream_complete(self, config: Optional['ChatOptions'] = None) -> Generator[d :param config: Optional parameters to pass for processing. :yields: Model response chunks as they arrive from the stream. """ - if not config: - config = {'max_tokens': 2048} + config = config or {'max_tokens': 2048} client = WriterAIManager.acquire_client() passed_messages: Iterable[WriterAIMessage] = [self._prepare_message(message) for message in self.messages] @@ -403,8 +935,7 @@ def complete(initial_text: str, config: Optional['CreateOptions'] = None) -> str :return: The text of the first choice from the completion response. :raises RuntimeError: If response data was not properly formatted to retrieve model text. """ - if not config: - config = {} + config = config or {} client = WriterAIManager.acquire_client() request_model = config.get("model", None) or WriterAIManager.use_completion_model()