From 7b82c7827e00a8bbc6e9ac858c68cdb6ad3d3391 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Tue, 9 Jul 2024 09:39:34 +0300 Subject: [PATCH 01/14] feat: management for graphs & files in writer.ai --- src/writer/ai.py | 202 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 189 insertions(+), 13 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 5e5e0c5a4..48a2371f7 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -1,31 +1,41 @@ import logging -from typing import Generator, Iterable, List, Literal, Optional, TypedDict, Union, cast +from typing import (Generator, Iterable, List, Literal, Optional, TypedDict, + Union, cast) from httpx import Timeout from writerai import Writer from writerai._exceptions import WriterError +from writerai._response import BinaryAPIResponse from writerai._streaming import Stream -from writerai._types import Body, Headers, NotGiven, Query -from writerai.types import Chat, Completion, StreamingData +from writerai._types import Body, FileTypes, Headers, NotGiven, Query +from writerai.pagination import SyncCursorPage +from writerai.types import (Chat, Completion, File, FileDeleteResponse, Graph, + GraphCreateResponse, GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, + GraphUpdateResponse, StreamingData) from writerai.types.chat_chat_params import Message as WriterAIMessage from writer.core import get_app_process +from writer.ss_types import WriterFileItem -class ChatOptions(TypedDict, total=False): +class APIOptions(TypedDict, total=False): + extra_headers: Optional[Headers] + extra_query: Optional[Query] + extra_body: Optional[Body] + timeout: Union[float, Timeout, None, NotGiven] + + +class ChatOptions(APIOptions, total=False): model: str max_tokens: Union[int, NotGiven] n: Union[int, NotGiven] stop: Union[List[str], str, NotGiven] temperature: Union[float, NotGiven] top_p: Union[float, NotGiven] - extra_headers: Optional[Headers] - extra_query: Optional[Query] - extra_body: Optional[Body] - timeout: Union[float, Timeout, None, NotGiven] -class CreateOptions(TypedDict, total=False): +class CreateOptions(APIOptions, total=False): model: str best_of: Union[int, NotGiven] max_tokens: Union[int, NotGiven] @@ -33,10 +43,28 @@ class CreateOptions(TypedDict, total=False): stop: Union[List[str], str, NotGiven] temperature: Union[float, NotGiven] top_p: Union[float, NotGiven] - extra_headers: Optional[Headers] - extra_query: Optional[Query] - extra_body: Optional[Body] - timeout: Union[float, Timeout, None, NotGiven] + + +class GraphCreateUpdateOptions(APIOptions, total=False): + name: str + description: Union[str, NotGiven] + + +class APIListOptions(APIOptions, total=False): + after: Union[str, NotGiven] + before: Union[str, NotGiven] + limit: Union[int, NotGiven] + order: Union[Literal["asc", "desc"], NotGiven] + + +class GraphAddFileOptions(APIOptions, total=False): + file_id: str + + +class FileAddOptions(APIOptions, total=False): + content: FileTypes + content_disposition: str + content_type: str logger = logging.Logger(__name__) @@ -141,6 +169,154 @@ def acquire_client(cls) -> Writer: return instance.client +class WriterGraphManager: + """ + Manages graph-related operations using the Writer AI API. + + Provides methods to create, retrieve, update, delete, and manage files within graphs. + """ + + def __init__(self): + """ + Initializes a WriterGraphManager instance. + """ + pass + + @classmethod + def retrieve_graphs_accessor(cls): + """ + Acquires the graphs accessor from the WriterAIManager singleton instance. + + :returns: The graphs accessor instance. + """ + return WriterAIManager.acquire_client().graphs + + @classmethod + def create_graph( + cls, + name: str, + description: str, + config: APIOptions = None + ) -> GraphCreateResponse: + if not config: + config = {} + graphs = cls.retrieve_graphs_accessor() + return graphs.create(name=name, description=description, **config) + + @classmethod + def retrieve_graph(cls, graph_id: str) -> Graph: + graphs = cls.retrieve_graphs_accessor() + return graphs.retrieve(graph_id) + + @classmethod + def update_graph( + cls, + graph_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + config: APIOptions = None + ) -> GraphUpdateResponse: + if not config: + config = {} + + # We use the payload dictionary + # to distinguish between None-values + # and NotGiven values + payload = {} + if name: + payload["name"] = name + if description: + payload["description"] = description + graphs = cls.retrieve_graphs_accessor() + return graphs.update(graph_id, **payload, **config) + + @classmethod + def list_graphs(cls, config: APIListOptions = None) -> SyncCursorPage[Graph]: + if not config: + config = {} + graphs = cls.retrieve_graphs_accessor() + return graphs.list(**config) + + @classmethod + def delete_graph(cls, graph_id: str) -> GraphDeleteResponse: + graphs = cls.retrieve_graphs_accessor() + return graphs.delete(graph_id) + + @classmethod + def add_file_to_graph(cls, graph_id: str, config: GraphAddFileOptions = None) -> File: + if not config: + config = {} + graphs = cls.retrieve_graphs_accessor() + return graphs.add_file_to_graph(graph_id, **config) + + @classmethod + def remove_file_from_graph(cls, graph_id: str, file_id: str) -> GraphRemoveFileFromGraphResponse: + graphs = cls.retrieve_graphs_accessor() + return graphs.remove_file_from_graph(graph_id, file_id) + + +class WriterFileManager: + """ + Manages file-related operations using the Writer AI API. + + Provides methods to retrieve, list, delete, download, and upload files. + """ + + def __init__(self): + """ + Initializes a WriterFileManager instance. + """ + pass + + @classmethod + def retrieve_files_accessor(cls): + """ + Acquires the files client from the WriterAIManager singleton instance. + + :returns: The files client instance. + """ + return WriterAIManager.acquire_client().files + + @classmethod + def retrieve_file(cls, file_id: str) -> File: + files = cls.retrieve_files_accessor() + return files.retrieve(file_id) + + @classmethod + def list_files(cls, config: APIListOptions = None) -> SyncCursorPage[File]: + if not config: + config = {} + files = cls.retrieve_files_accessor() + return files.list(**config) + + @classmethod + def delete_file(cls, file_id: str) -> FileDeleteResponse: + files = cls.retrieve_files_accessor() + return files.delete(file_id) + + @classmethod + def download_file(cls, file_id: str) -> BinaryAPIResponse: + files = cls.retrieve_files_accessor() + return files.download(file_id) + + @classmethod + def upload_file(cls, file: WriterFileItem, config: APIOptions) -> File: + files = cls.retrieve_files_accessor() + if "data" not in file: + raise ValueError("Missing `data` in file payload") + if "type" not in file: + raise ValueError("Missing `type` in file payload") + if "name" not in file: + raise ValueError("Missing `name` in file payload") + + uploaded_file = { + "content": file["data"], + "content_type": file["type"], + "content_disposition": f'attachment;filename="{file["name"]}"' + } + return files.upload(**uploaded_file, **config) + + class Conversation: """ Manages messages within a conversation flow with an AI system, including message validation, From adbea60f9c48f7ad38b9ad0339f0e89e3b2972f9 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Tue, 9 Jul 2024 09:44:48 +0300 Subject: [PATCH 02/14] fix: lint --- src/writer/ai.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 48a2371f7..003a15364 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -1,6 +1,5 @@ import logging -from typing import (Generator, Iterable, List, Literal, Optional, TypedDict, - Union, cast) +from typing import Generator, Iterable, List, Literal, Optional, TypedDict, Union, cast from httpx import Timeout from writerai import Writer @@ -9,10 +8,18 @@ from writerai._streaming import Stream from writerai._types import Body, FileTypes, Headers, NotGiven, Query from writerai.pagination import SyncCursorPage -from writerai.types import (Chat, Completion, File, FileDeleteResponse, Graph, - GraphCreateResponse, GraphDeleteResponse, - GraphRemoveFileFromGraphResponse, - GraphUpdateResponse, StreamingData) +from writerai.types import ( + Chat, + Completion, + File, + FileDeleteResponse, + Graph, + GraphCreateResponse, + GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, + GraphUpdateResponse, + StreamingData, +) from writerai.types.chat_chat_params import Message as WriterAIMessage from writer.core import get_app_process From 6cf44cebd53eacba391277ae9ae912ae8353cdbd Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Tue, 9 Jul 2024 09:50:57 +0300 Subject: [PATCH 03/14] fix: typing --- src/writer/ai.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 003a15364..82d8f18dc 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -203,7 +203,7 @@ def create_graph( cls, name: str, description: str, - config: APIOptions = None + config: Optional[APIOptions] = None ) -> GraphCreateResponse: if not config: config = {} @@ -221,7 +221,7 @@ def update_graph( graph_id: str, name: Optional[str] = None, description: Optional[str] = None, - config: APIOptions = None + config: Optional[APIOptions] = None ) -> GraphUpdateResponse: if not config: config = {} @@ -238,7 +238,7 @@ def update_graph( return graphs.update(graph_id, **payload, **config) @classmethod - def list_graphs(cls, config: APIListOptions = None) -> SyncCursorPage[Graph]: + def list_graphs(cls, config: Optional[APIListOptions] = None) -> SyncCursorPage[Graph]: if not config: config = {} graphs = cls.retrieve_graphs_accessor() @@ -250,7 +250,7 @@ def delete_graph(cls, graph_id: str) -> GraphDeleteResponse: return graphs.delete(graph_id) @classmethod - def add_file_to_graph(cls, graph_id: str, config: GraphAddFileOptions = None) -> File: + def add_file_to_graph(cls, graph_id: str, config: Optional[GraphAddFileOptions] = None) -> File: if not config: config = {} graphs = cls.retrieve_graphs_accessor() @@ -290,7 +290,7 @@ def retrieve_file(cls, file_id: str) -> File: return files.retrieve(file_id) @classmethod - def list_files(cls, config: APIListOptions = None) -> SyncCursorPage[File]: + def list_files(cls, config: Optional[APIListOptions] = None) -> SyncCursorPage[File]: if not config: config = {} files = cls.retrieve_files_accessor() From c0c74e2ff9b3003aa375427ddde966d4a02ea1f7 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Tue, 9 Jul 2024 19:00:19 +0300 Subject: [PATCH 04/14] fix: removing redundant TypedDicts --- src/writer/ai.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 82d8f18dc..c5190c86b 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -6,7 +6,7 @@ from writerai._exceptions import WriterError from writerai._response import BinaryAPIResponse from writerai._streaming import Stream -from writerai._types import Body, FileTypes, Headers, NotGiven, Query +from writerai._types import Body, Headers, NotGiven, Query from writerai.pagination import SyncCursorPage from writerai.types import ( Chat, @@ -52,11 +52,6 @@ class CreateOptions(APIOptions, total=False): top_p: Union[float, NotGiven] -class GraphCreateUpdateOptions(APIOptions, total=False): - name: str - description: Union[str, NotGiven] - - class APIListOptions(APIOptions, total=False): after: Union[str, NotGiven] before: Union[str, NotGiven] @@ -64,16 +59,6 @@ class APIListOptions(APIOptions, total=False): order: Union[Literal["asc", "desc"], NotGiven] -class GraphAddFileOptions(APIOptions, total=False): - file_id: str - - -class FileAddOptions(APIOptions, total=False): - content: FileTypes - content_disposition: str - content_type: str - - logger = logging.Logger(__name__) @@ -250,11 +235,12 @@ def delete_graph(cls, graph_id: str) -> GraphDeleteResponse: return graphs.delete(graph_id) @classmethod - def add_file_to_graph(cls, graph_id: str, config: Optional[GraphAddFileOptions] = None) -> File: - if not config: - config = {} + def add_file_to_graph(cls, graph_id: str, file_id: Optional[str] = None) -> File: + payload = {} + if file_id: + payload["file_id"] = file_id graphs = cls.retrieve_graphs_accessor() - return graphs.add_file_to_graph(graph_id, **config) + return graphs.add_file_to_graph(graph_id, **payload) @classmethod def remove_file_from_graph(cls, graph_id: str, file_id: str) -> GraphRemoveFileFromGraphResponse: From fec4ea4a1f2f46a7fbd98a76b5388f3f9fdadf55 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Tue, 9 Jul 2024 19:01:52 +0300 Subject: [PATCH 05/14] fix: proper signature for add_file_to_graph --- src/writer/ai.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index c5190c86b..5a67223d4 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -235,12 +235,11 @@ def delete_graph(cls, graph_id: str) -> GraphDeleteResponse: return graphs.delete(graph_id) @classmethod - def add_file_to_graph(cls, graph_id: str, file_id: Optional[str] = None) -> File: - payload = {} - if file_id: - payload["file_id"] = file_id + def add_file_to_graph(cls, graph_id: str, file_id: str, config: Optional[APIOptions] = None) -> File: + if not config: + config = {} graphs = cls.retrieve_graphs_accessor() - return graphs.add_file_to_graph(graph_id, **payload) + return graphs.add_file_to_graph(graph_id, file_id, **config) @classmethod def remove_file_from_graph(cls, graph_id: str, file_id: str) -> GraphRemoveFileFromGraphResponse: From d6fd59b6339011a6944666425ed5519610369d9b Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Wed, 10 Jul 2024 22:41:38 +0300 Subject: [PATCH 06/14] chore: extend KG/file management functionality to Graph & File classes - class Graph to manage file addition/removal & update - class File to manage downolading - static methods to manage creation/retrieval/removal --- src/writer/ai.py | 334 ++++++++++++++++++++++++++++++----------------- 1 file changed, 212 insertions(+), 122 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 5a67223d4..0355c1d34 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -1,5 +1,8 @@ import logging -from typing import Generator, Iterable, List, Literal, Optional, TypedDict, Union, cast +from datetime import datetime +from typing import (Generator, Iterable, List, Literal, Optional, TypedDict, + Union, cast) +from uuid import uuid4 from httpx import Timeout from writerai import Writer @@ -7,23 +10,17 @@ from writerai._response import BinaryAPIResponse from writerai._streaming import Stream from writerai._types import Body, Headers, NotGiven, Query -from writerai.pagination import SyncCursorPage -from writerai.types import ( - Chat, - Completion, - File, - FileDeleteResponse, - Graph, - GraphCreateResponse, - GraphDeleteResponse, - GraphRemoveFileFromGraphResponse, - GraphUpdateResponse, - StreamingData, -) +from writerai.resources import FilesResource, GraphsResource +from writerai.types import Chat, Completion +from writerai.types import File as SDKFile +from writerai.types import FileDeleteResponse +from writerai.types import Graph as SDKGraph +from writerai.types import (GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, + GraphUpdateResponse, StreamingData) from writerai.types.chat_chat_params import Message as WriterAIMessage from writer.core import get_app_process -from writer.ss_types import WriterFileItem class APIOptions(TypedDict, total=False): @@ -161,21 +158,54 @@ def acquire_client(cls) -> Writer: return instance.client -class WriterGraphManager: - """ - Manages graph-related operations using the Writer AI API. +class SDKWrapper: + _wrapped: Union[SDKFile, SDKGraph] - Provides methods to create, retrieve, update, delete, and manage files within graphs. - """ + def _get_property(self, property_name): + try: + return getattr(self._wrapped, property_name) + except AttributeError: + raise AttributeError( + f"type object '{self.__class__}' has no attribute {property_name}" + ) from None - def __init__(self): - """ - Initializes a WriterGraphManager instance. - """ - pass - @classmethod - def retrieve_graphs_accessor(cls): +class Graph(SDKWrapper): + _wrapped: SDKGraph + stale_ids = set() + + def __init__( + self, + name: str, + description: Optional[str] = None, + config: Optional[APIOptions] = None, + new_graph: bool = True + ): + config = config or {} + if new_graph is True: + graphs = self._retrieve_graphs_accessor() + graph_response = \ + graphs.create( + name=name, + description=description, + **config + ) + graph_object = graphs.retrieve(graph_response.id) + self._wrapped = graph_object + + @staticmethod + def _init_from_object(graph_object: SDKGraph): + instance = Graph( + name=graph_object.name, + description=graph_object.description, + config=None, + new_graph=False + ) + instance._wrapped = graph_object + return instance + + @staticmethod + def _retrieve_graphs_accessor() -> GraphsResource: """ Acquires the graphs accessor from the WriterAIManager singleton instance. @@ -183,33 +213,43 @@ def retrieve_graphs_accessor(cls): """ return WriterAIManager.acquire_client().graphs - @classmethod - def create_graph( - cls, - name: str, - description: str, - config: Optional[APIOptions] = None - ) -> GraphCreateResponse: - if not config: - config = {} - graphs = cls.retrieve_graphs_accessor() - return graphs.create(name=name, description=description, **config) + @property + def id(self) -> str: + return self._get_property('id') - @classmethod - def retrieve_graph(cls, graph_id: str) -> Graph: - graphs = cls.retrieve_graphs_accessor() - return graphs.retrieve(graph_id) + @property + def created_at(self) -> datetime: + return self._get_property('created_at') - @classmethod - def update_graph( - cls, - graph_id: str, + def _fetch_object_updates(self): + if self.id in Graph.stale_ids: + graphs = self._retrieve_graphs_accessor() + fresh_object = graphs.retrieve(self.id) + self._wrapped = fresh_object + Graph.stale_ids.remove(self.id) + + @property + def name(self) -> str: + self._fetch_object_updates() + return self._wrapped.name + + @property + def description(self) -> Optional[str]: + self._fetch_object_updates() + return self._wrapped.description + + @property + def file_status(self): + self._fetch_object_updates() + return self._wrapped.file_status + + def update( + self, name: Optional[str] = None, description: Optional[str] = None, config: Optional[APIOptions] = None ) -> GraphUpdateResponse: - if not config: - config = {} + config = config or {} # We use the payload dictionary # to distinguish between None-values @@ -219,49 +259,73 @@ def update_graph( payload["name"] = name if description: payload["description"] = description - graphs = cls.retrieve_graphs_accessor() - return graphs.update(graph_id, **payload, **config) + graphs = self._retrieve_graphs_accessor() + response = graphs.update(self.id, **payload, **config) + Graph.stale_ids.add(self.id) + return response + + def add_file( + self, + file_id: str, + config: Optional[APIOptions] = None + ) -> 'File': + config = config or {} + graphs = self._retrieve_graphs_accessor() + response = graphs.add_file_to_graph( + graph_id=self.id, + file_id=file_id, + **config + ) + Graph.stale_ids.add(self.id) + return response + + def remove_file(self, file_id: str) -> GraphRemoveFileFromGraphResponse: + graphs = self._retrieve_graphs_accessor() + response = graphs.remove_file_from_graph( + graph_id=self.id, + file_id=file_id + ) + Graph.stale_ids.add(self.id) + return response - @classmethod - def list_graphs(cls, config: Optional[APIListOptions] = None) -> SyncCursorPage[Graph]: - if not config: - config = {} - graphs = cls.retrieve_graphs_accessor() - return graphs.list(**config) - @classmethod - def delete_graph(cls, graph_id: str) -> GraphDeleteResponse: - graphs = cls.retrieve_graphs_accessor() - return graphs.delete(graph_id) +def retrieve_graph(graph_id: str) -> Graph: + graphs = Graph._retrieve_graphs_accessor() + graph_object = graphs.retrieve(graph_id) + graph = Graph._init_from_object(graph_object) + return graph - @classmethod - def add_file_to_graph(cls, graph_id: str, file_id: str, config: Optional[APIOptions] = None) -> File: - if not config: - config = {} - graphs = cls.retrieve_graphs_accessor() - return graphs.add_file_to_graph(graph_id, file_id, **config) - @classmethod - def remove_file_from_graph(cls, graph_id: str, file_id: str) -> GraphRemoveFileFromGraphResponse: - graphs = cls.retrieve_graphs_accessor() - return graphs.remove_file_from_graph(graph_id, file_id) +def list_graphs(config: Optional[APIListOptions] = None) -> List[Graph]: + config = config or {} + graphs = Graph._retrieve_graphs_accessor() + sdk_graphs = graphs.list(**config) + return [Graph._init_from_object(sdk_graph) for sdk_graph in sdk_graphs] -class WriterFileManager: - """ - Manages file-related operations using the Writer AI API. +def delete_graph(graph_id_or_graph: Union[Graph, str]) -> GraphDeleteResponse: + graph_id = None + if isinstance(graph_id_or_graph, Graph): + graph_id = graph_id_or_graph.id + elif isinstance(graph_id_or_graph, str): + graph_id = graph_id_or_graph + else: + raise ValueError( + "'delete_graph' method accepts either 'Graph' object" + + f" or ID of graph as string; got '{type(graph_id_or_graph)}'" + ) + graphs = Graph._retrieve_graphs_accessor() + return graphs.delete(graph_id) - Provides methods to retrieve, list, delete, download, and upload files. - """ - def __init__(self): - """ - Initializes a WriterFileManager instance. - """ - pass +class File(SDKWrapper): + _wrapped: SDKFile - @classmethod - def retrieve_files_accessor(cls): + def __init__(self, file_object: SDKFile): + self._wrapped = file_object + + @staticmethod + def _retrieve_files_accessor() -> FilesResource: """ Acquires the files client from the WriterAIManager singleton instance. @@ -269,44 +333,73 @@ def retrieve_files_accessor(cls): """ return WriterAIManager.acquire_client().files - @classmethod - def retrieve_file(cls, file_id: str) -> File: - files = cls.retrieve_files_accessor() - return files.retrieve(file_id) + @property + def id(self) -> str: + return self._get_property('id') - @classmethod - def list_files(cls, config: Optional[APIListOptions] = None) -> SyncCursorPage[File]: - if not config: - config = {} - files = cls.retrieve_files_accessor() - return files.list(**config) + @property + def created_at(self) -> datetime: + return self._get_property('created_at') - @classmethod - def delete_file(cls, file_id: str) -> FileDeleteResponse: - files = cls.retrieve_files_accessor() - return files.delete(file_id) + @property + def graph_ids(self) -> List[str]: + return self._get_property('graph_ids') - @classmethod - def download_file(cls, file_id: str) -> BinaryAPIResponse: - files = cls.retrieve_files_accessor() - return files.download(file_id) + @property + def name(self) -> str: + return self._get_property('name') + + def download(self) -> BinaryAPIResponse: + files = self._retrieve_files_accessor() + return files.download(self.id) + + +def retrieve_file(file_id: str) -> File: + files = File._retrieve_files_accessor() + file_object = files.retrieve(file_id) + file = File(file_object) + return file + + +def list_files(config: Optional[APIListOptions] = None) -> List[File]: + config = config or {} + files = File._retrieve_files_accessor() + sdk_files = files.list(**config) + return [File(sdk_file) for sdk_file in sdk_files] + + +def upload_file( + data: bytes, + type: str, + name: Optional[str] = None, + config: Optional[APIOptions] = None + ) -> File: + config = config or {} + files = File._retrieve_files_accessor() + uploaded_file = { + "content": data, + "content_type": type, + "content_disposition": + f'attachment;filename="{name or f"WF-{type}-{uuid4()}"}"' + } + sdk_file = files.upload(**uploaded_file, **config) + return File(sdk_file) + + +def delete_file(file_id_or_file: Union['File', str]) -> FileDeleteResponse: + file_id = None + if isinstance(file_id_or_file, File): + file_id = file_id_or_file.id + elif isinstance(file_id_or_file, str): + file_id = file_id_or_file + else: + raise ValueError( + "'delete_file' method accepts either 'File' object" + + f" or ID of file as string; got '{type(file_id_or_file)}'" + ) - @classmethod - def upload_file(cls, file: WriterFileItem, config: APIOptions) -> File: - files = cls.retrieve_files_accessor() - if "data" not in file: - raise ValueError("Missing `data` in file payload") - if "type" not in file: - raise ValueError("Missing `type` in file payload") - if "name" not in file: - raise ValueError("Missing `name` in file payload") - - uploaded_file = { - "content": file["data"], - "content_type": file["type"], - "content_disposition": f'attachment;filename="{file["name"]}"' - } - return files.upload(**uploaded_file, **config) + files = File._retrieve_files_accessor() + return files.delete(file_id) class Conversation: @@ -475,8 +568,7 @@ def complete(self, config: Optional['ChatOptions'] = None) -> 'Conversation.Mess :return: Generated message. :raises RuntimeError: If response data was not properly formatted to retrieve model text. """ - if not config: - config = {'max_tokens': 2048} + config = config or {'max_tokens': 2048} client = WriterAIManager.acquire_client() passed_messages: Iterable[WriterAIMessage] = [self._prepare_message(message) for message in self.messages] @@ -511,8 +603,7 @@ def stream_complete(self, config: Optional['ChatOptions'] = None) -> Generator[d :param config: Optional parameters to pass for processing. :yields: Model response chunks as they arrive from the stream. """ - if not config: - config = {'max_tokens': 2048} + config = config or {'max_tokens': 2048} client = WriterAIManager.acquire_client() passed_messages: Iterable[WriterAIMessage] = [self._prepare_message(message) for message in self.messages] @@ -571,8 +662,7 @@ def complete(initial_text: str, config: Optional['CreateOptions'] = None) -> str :return: The text of the first choice from the completion response. :raises RuntimeError: If response data was not properly formatted to retrieve model text. """ - if not config: - config = {} + config = config or {} client = WriterAIManager.acquire_client() request_model = config.get("model", None) or WriterAIManager.use_completion_model() From 1d5c4bce6c709a06ba7786124ebcca4d409c07f5 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Wed, 10 Jul 2024 22:51:31 +0300 Subject: [PATCH 07/14] fix: APIOptions for file methods --- src/writer/ai.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 0355c1d34..e341bbc64 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -354,9 +354,10 @@ def download(self) -> BinaryAPIResponse: return files.download(self.id) -def retrieve_file(file_id: str) -> File: +def retrieve_file(file_id: str, config: Optional[APIOptions]) -> File: + config = config or {} files = File._retrieve_files_accessor() - file_object = files.retrieve(file_id) + file_object = files.retrieve(file_id, **config) file = File(file_object) return file @@ -386,7 +387,11 @@ def upload_file( return File(sdk_file) -def delete_file(file_id_or_file: Union['File', str]) -> FileDeleteResponse: +def delete_file( + file_id_or_file: Union['File', str], + config: Optional[APIOptions] + ) -> FileDeleteResponse: + config = config or {} file_id = None if isinstance(file_id_or_file, File): file_id = file_id_or_file.id @@ -399,7 +404,7 @@ def delete_file(file_id_or_file: Union['File', str]) -> FileDeleteResponse: ) files = File._retrieve_files_accessor() - return files.delete(file_id) + return files.delete(file_id, **config) class Conversation: From 15b37051e0732e4173a921fe4ea2090849ff6684 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Wed, 10 Jul 2024 22:58:44 +0300 Subject: [PATCH 08/14] fix: accept both id and File object for Graph methods --- src/writer/ai.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index e341bbc64..d12060260 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -266,10 +266,20 @@ def update( def add_file( self, - file_id: str, + file_id_or_file: Union['File', str], config: Optional[APIOptions] = None ) -> 'File': config = config or {} + file_id = None + if isinstance(file_id_or_file, File): + file_id = file_id_or_file.id + elif isinstance(file_id_or_file, str): + file_id = file_id_or_file + else: + raise ValueError( + "'Graph.add_file' method accepts either 'File' object" + + f" or ID of file as string; got '{type(file_id_or_file)}'" + ) graphs = self._retrieve_graphs_accessor() response = graphs.add_file_to_graph( graph_id=self.id, @@ -279,7 +289,22 @@ def add_file( Graph.stale_ids.add(self.id) return response - def remove_file(self, file_id: str) -> GraphRemoveFileFromGraphResponse: + def remove_file( + self, + file_id_or_file: Union['File', str], + config: Optional[APIOptions] = None + ) -> GraphRemoveFileFromGraphResponse: + config = config or {} + file_id = None + if isinstance(file_id_or_file, File): + file_id = file_id_or_file.id + elif isinstance(file_id_or_file, str): + file_id = file_id_or_file + else: + raise ValueError( + "'Graph.remove_file' method accepts either 'File' object" + + f" or ID of file as string; got '{type(file_id_or_file)}'" + ) graphs = self._retrieve_graphs_accessor() response = graphs.remove_file_from_graph( graph_id=self.id, From e45e8b4919c42087164afa949b0d7e5cead6877b Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Thu, 11 Jul 2024 12:30:50 +0300 Subject: [PATCH 09/14] chore: create_graph method and shift from class-managed API calls --- src/writer/ai.py | 58 +++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 33 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index d12060260..574bb416f 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -171,38 +171,14 @@ def _get_property(self, property_name): class Graph(SDKWrapper): - _wrapped: SDKGraph + _wrapped: SDKGraph = None stale_ids = set() def __init__( self, - name: str, - description: Optional[str] = None, - config: Optional[APIOptions] = None, - new_graph: bool = True + graph_object: SDKGraph ): - config = config or {} - if new_graph is True: - graphs = self._retrieve_graphs_accessor() - graph_response = \ - graphs.create( - name=name, - description=description, - **config - ) - graph_object = graphs.retrieve(graph_response.id) - self._wrapped = graph_object - - @staticmethod - def _init_from_object(graph_object: SDKGraph): - instance = Graph( - name=graph_object.name, - description=graph_object.description, - config=None, - new_graph=False - ) - instance._wrapped = graph_object - return instance + self._wrapped = graph_object @staticmethod def _retrieve_graphs_accessor() -> GraphsResource: @@ -287,13 +263,13 @@ def add_file( **config ) Graph.stale_ids.add(self.id) - return response + return File(response) def remove_file( self, file_id_or_file: Union['File', str], config: Optional[APIOptions] = None - ) -> GraphRemoveFileFromGraphResponse: + ) -> Optional[GraphRemoveFileFromGraphResponse]: config = config or {} file_id = None if isinstance(file_id_or_file, File): @@ -314,10 +290,26 @@ def remove_file( return response -def retrieve_graph(graph_id: str) -> Graph: +def create_graph( + name: str, + description: Optional[str] = None, + config: Optional[APIOptions] = None + ) -> Graph: + config = config or {} + graphs = Graph._retrieve_graphs_accessor() + graph_object = graphs.create(name=name, description=description, **config) + graph = Graph(graph_object) + return graph + + +def retrieve_graph( + graph_id: str, + config: Optional[APIListOptions] = None + ) -> Graph: + config = config or {} graphs = Graph._retrieve_graphs_accessor() - graph_object = graphs.retrieve(graph_id) - graph = Graph._init_from_object(graph_object) + graph_object = graphs.retrieve(graph_id, **config) + graph = Graph(graph_object) return graph @@ -325,7 +317,7 @@ def list_graphs(config: Optional[APIListOptions] = None) -> List[Graph]: config = config or {} graphs = Graph._retrieve_graphs_accessor() sdk_graphs = graphs.list(**config) - return [Graph._init_from_object(sdk_graph) for sdk_graph in sdk_graphs] + return [Graph(sdk_graph) for sdk_graph in sdk_graphs] def delete_graph(graph_id_or_graph: Union[Graph, str]) -> GraphDeleteResponse: From 6ae9337ee47a386d1e8c305eeb5aa2fc80bd02ad Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Mon, 15 Jul 2024 13:02:58 +0300 Subject: [PATCH 10/14] fix: default None value for optional configs on file --- src/writer/ai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 574bb416f..a91dc4e0a 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -371,7 +371,7 @@ def download(self) -> BinaryAPIResponse: return files.download(self.id) -def retrieve_file(file_id: str, config: Optional[APIOptions]) -> File: +def retrieve_file(file_id: str, config: Optional[APIOptions] = None) -> File: config = config or {} files = File._retrieve_files_accessor() file_object = files.retrieve(file_id, **config) @@ -406,7 +406,7 @@ def upload_file( def delete_file( file_id_or_file: Union['File', str], - config: Optional[APIOptions] + config: Optional[APIOptions] = None ) -> FileDeleteResponse: config = config or {} file_id = None From 4ed8eb9045bb1a8bdf0b27a09ef17c7227016c02 Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Mon, 15 Jul 2024 13:33:07 +0300 Subject: [PATCH 11/14] chore: docstrings --- src/writer/ai.py | 238 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) diff --git a/src/writer/ai.py b/src/writer/ai.py index a91dc4e0a..79c264362 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -159,9 +159,23 @@ def acquire_client(cls) -> Writer: class SDKWrapper: + """ + A wrapper class for SDK objects, allowing dynamic access to properties. + + Attributes: + _wrapped (Union[SDKFile, SDKGraph]): The wrapped SDK object. + """ _wrapped: Union[SDKFile, SDKGraph] def _get_property(self, property_name): + """ + Retrieves a property from the wrapped object. + + :param property_name: The name of the property to retrieve. + :type property_name: str + :returns: The value of the requested property. + :raises AttributeError: If the property does not exist. + """ try: return getattr(self._wrapped, property_name) except AttributeError: @@ -171,6 +185,13 @@ def _get_property(self, property_name): class Graph(SDKWrapper): + """ + A wrapper class for SDKGraph objects, providing additional functionality. + + Attributes: + _wrapped (writerai.types.Graph): The wrapped SDK Graph object. + stale_ids (set): A set of stale graph IDs that need updates. + """ _wrapped: SDKGraph = None stale_ids = set() @@ -178,6 +199,12 @@ def __init__( self, graph_object: SDKGraph ): + """ + Initializes the Graph with the given SDKGraph object. + + :param graph_object: The SDKGraph object to wrap. + :type graph_object: writerai.types.Graph + """ self._wrapped = graph_object @staticmethod @@ -186,6 +213,7 @@ def _retrieve_graphs_accessor() -> GraphsResource: Acquires the graphs accessor from the WriterAIManager singleton instance. :returns: The graphs accessor instance. + :rtype: GraphsResource """ return WriterAIManager.acquire_client().graphs @@ -198,6 +226,9 @@ def created_at(self) -> datetime: return self._get_property('created_at') def _fetch_object_updates(self): + """ + Fetches updates for the graph object if it is stale. + """ if self.id in Graph.stale_ids: graphs = self._retrieve_graphs_accessor() fresh_object = graphs.retrieve(self.id) @@ -225,6 +256,24 @@ def update( description: Optional[str] = None, config: Optional[APIOptions] = None ) -> GraphUpdateResponse: + """ + Updates the graph with the given parameters. + + :param name: The new name for the graph. + :type name: Optional[str] + :param description: The new description for the graph. + :type description: Optional[str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The response from the update operation. + :rtype: GraphUpdateResponse + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} # We use the payload dictionary @@ -245,6 +294,23 @@ def add_file( file_id_or_file: Union['File', str], config: Optional[APIOptions] = None ) -> 'File': + """ + Adds a file to the graph. + + :param file_id_or_file: The file object or file ID to add. + :type file_id_or_file: Union['File', str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The added file object. + :rtype: File + :raises ValueError: If the input is neither a File object nor a file ID string. + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} file_id = None if isinstance(file_id_or_file, File): @@ -270,6 +336,23 @@ def remove_file( file_id_or_file: Union['File', str], config: Optional[APIOptions] = None ) -> Optional[GraphRemoveFileFromGraphResponse]: + """ + Removes a file from the graph. + + :param file_id_or_file: The file object or file ID to remove. + :type file_id_or_file: Union['File', str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The response from the remove operation. + :rtype: Optional[GraphRemoveFileFromGraphResponse] + :raises ValueError: If the input is neither a File object nor a file ID string. + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} file_id = None if isinstance(file_id_or_file, File): @@ -295,6 +378,24 @@ def create_graph( description: Optional[str] = None, config: Optional[APIOptions] = None ) -> Graph: + """ + Creates a new graph with the given parameters. + + :param name: The name of the graph. + :type name: str + :param description: The description of the graph. + :type description: Optional[str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The created graph object. + :rtype: Graph + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} graphs = Graph._retrieve_graphs_accessor() graph_object = graphs.create(name=name, description=description, **config) @@ -306,6 +407,26 @@ def retrieve_graph( graph_id: str, config: Optional[APIListOptions] = None ) -> Graph: + """ + Retrieves a graph by its ID. + + :param graph_id: The ID of the graph to retrieve. + :type graph_id: str + :param config: Additional configuration options. + :type config: Optional[APIListOptions] + :returns: The retrieved graph object. + :rtype: Graph + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + - `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor. + - `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor. + - `limit` (Union[int, NotGiven]): The number of items to retrieve. + - `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items. + """ config = config or {} graphs = Graph._retrieve_graphs_accessor() graph_object = graphs.retrieve(graph_id, **config) @@ -314,6 +435,24 @@ def retrieve_graph( def list_graphs(config: Optional[APIListOptions] = None) -> List[Graph]: + """ + Lists all graphs with the given configuration. + + :param config: Additional configuration options. + :type config: Optional[APIListOptions] + :returns: A list of graph objects. + :rtype: List[Graph] + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + - `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor. + - `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor. + - `limit` (Union[int, NotGiven]): The number of items to retrieve. + - `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items. + """ config = config or {} graphs = Graph._retrieve_graphs_accessor() sdk_graphs = graphs.list(**config) @@ -321,6 +460,15 @@ def list_graphs(config: Optional[APIListOptions] = None) -> List[Graph]: def delete_graph(graph_id_or_graph: Union[Graph, str]) -> GraphDeleteResponse: + """ + Deletes a graph by its ID or object. + + :param graph_id_or_graph: The graph object or graph ID to delete. + :type graph_id_or_graph: Union[Graph, str] + :returns: The response from the delete operation. + :rtype: GraphDeleteResponse + :raises ValueError: If the input is neither a Graph object nor a graph ID string. + """ graph_id = None if isinstance(graph_id_or_graph, Graph): graph_id = graph_id_or_graph.id @@ -336,9 +484,21 @@ def delete_graph(graph_id_or_graph: Union[Graph, str]) -> GraphDeleteResponse: class File(SDKWrapper): + """ + A wrapper class for SDK File objects, providing additional functionality. + + Attributes: + _wrapped (writerai.types.File): The wrapped SDKFile object. + """ _wrapped: SDKFile def __init__(self, file_object: SDKFile): + """ + Initializes the File with the given SDKFile object. + + :param file_object: The SDKFile object to wrap. + :type file_object: writerai.types.File + """ self._wrapped = file_object @staticmethod @@ -347,6 +507,7 @@ def _retrieve_files_accessor() -> FilesResource: Acquires the files client from the WriterAIManager singleton instance. :returns: The files client instance. + :rtype: FilesResource """ return WriterAIManager.acquire_client().files @@ -367,11 +528,33 @@ def name(self) -> str: return self._get_property('name') def download(self) -> BinaryAPIResponse: + """ + Downloads the file content. + + :returns: The response containing the file content. + :rtype: BinaryAPIResponse + """ files = self._retrieve_files_accessor() return files.download(self.id) def retrieve_file(file_id: str, config: Optional[APIOptions] = None) -> File: + """ + Retrieves a file by its ID. + + :param file_id: The ID of the file to retrieve. + :type file_id: str + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The retrieved file object. + :rtype: writerai.types.File + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} files = File._retrieve_files_accessor() file_object = files.retrieve(file_id, **config) @@ -380,6 +563,24 @@ def retrieve_file(file_id: str, config: Optional[APIOptions] = None) -> File: def list_files(config: Optional[APIListOptions] = None) -> List[File]: + """ + Lists all files with the given configuration. + + :param config: Additional configuration options. + :type config: Optional[APIListOptions] + :returns: A list of file objects. + :rtype: List[File] + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + - `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor. + - `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor. + - `limit` (Union[int, NotGiven]): The number of items to retrieve. + - `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items. + """ config = config or {} files = File._retrieve_files_accessor() sdk_files = files.list(**config) @@ -392,6 +593,26 @@ def upload_file( name: Optional[str] = None, config: Optional[APIOptions] = None ) -> File: + """ + Uploads a new file with the given parameters. + + :param data: The file content as bytes. + :type data: bytes + :param type: The MIME type of the file. + :type type: str + :param name: The name of the file. + :type name: Optional[str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The uploaded file object. + :rtype: writerai.types.File + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} files = File._retrieve_files_accessor() uploaded_file = { @@ -408,6 +629,23 @@ def delete_file( file_id_or_file: Union['File', str], config: Optional[APIOptions] = None ) -> FileDeleteResponse: + """ + Deletes a file by its ID or object. + + :param file_id_or_file: The file object or file ID to delete. + :type file_id_or_file: Union['File', str] + :param config: Additional configuration options. + :type config: Optional[APIOptions] + :returns: The response from the delete operation. + :rtype: FileDeleteResponse + :raises ValueError: If the input is neither a File object nor a file ID string. + + The `config` dictionary can include the following keys: + - `extra_headers` (Optional[Headers]): Additional headers for the request. + - `extra_query` (Optional[Query]): Additional query parameters for the request. + - `extra_body` (Optional[Body]): Additional body parameters for the request. + - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. + """ config = config or {} file_id = None if isinstance(file_id_or_file, File): From 79f4074d896320e19869234cf8f2759a74befdbb Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Mon, 15 Jul 2024 14:06:22 +0300 Subject: [PATCH 12/14] chore: SDK version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 217785bb0..fdf78e87b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ requests = "^2.31.0" uvicorn = ">= 0.20.0, < 1" watchdog = ">= 3.0.0, < 4" websockets = ">= 12, < 13" -writer-sdk = ">= 0.1.2, < 1" +writer-sdk = ">= 0.5.0, < 1" [tool.poetry.group.build] From 99035e4f4ea8b8b5304d58398ab0097cd066ad1a Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Mon, 15 Jul 2024 14:11:37 +0300 Subject: [PATCH 13/14] fix: poetry lock file --- poetry.lock | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index a96a8c614..b40ab0626 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "alfred-cli" @@ -2192,13 +2192,13 @@ files = [ [[package]] name = "writer-sdk" -version = "0.1.2" +version = "0.5.0" description = "The official Python library for the writer API" optional = false python-versions = ">=3.7" files = [ - {file = "writer_sdk-0.1.2-py3-none-any.whl", hash = "sha256:2072f5c13b8011a0c2ebe1b63fbc041642e9e9ca6eab7e976f7b3f20ad9931e5"}, - {file = "writer_sdk-0.1.2.tar.gz", hash = "sha256:3fa0b09a57ba969ca344024fa46305cd441cb45cb34d26d65a51a743477a4315"}, + {file = "writer_sdk-0.5.0-py3-none-any.whl", hash = "sha256:654e08fa0040126b8a8ae2468a9f2185997d94926a9322cee6278a721d098598"}, + {file = "writer_sdk-0.5.0.tar.gz", hash = "sha256:984f289f2576f8fbaec6893e9c83a0810a3084fbc536cc22666e0f620e8eff64"}, ] [package.dependencies] @@ -2212,4 +2212,4 @@ typing-extensions = ">=4.7,<5" [metadata] lock-version = "2.0" python-versions = ">=3.9.2, <4.0" -content-hash = "2b7dece7f8e0c504b0d1c393f324b8f86e382910106db130a9cfc129786fc138" +content-hash = "087a35e24f0286c5063c952de4bf460b036579e4956f1e598b9ac2172fa82cdd" From 2780bbcd4c255151d277ecf42015b545ae84d5ac Mon Sep 17 00:00:00 2001 From: mmikita95 Date: Mon, 15 Jul 2024 14:31:13 +0300 Subject: [PATCH 14/14] fix: typing & lint --- src/writer/ai.py | 61 +++++++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/src/writer/ai.py b/src/writer/ai.py index 79c264362..ae9655de2 100644 --- a/src/writer/ai.py +++ b/src/writer/ai.py @@ -1,7 +1,16 @@ import logging from datetime import datetime -from typing import (Generator, Iterable, List, Literal, Optional, TypedDict, - Union, cast) +from typing import ( + Generator, + Iterable, + List, + Literal, + Optional, + Set, + TypedDict, + Union, + cast, +) from uuid import uuid4 from httpx import Timeout @@ -11,13 +20,17 @@ from writerai._streaming import Stream from writerai._types import Body, Headers, NotGiven, Query from writerai.resources import FilesResource, GraphsResource -from writerai.types import Chat, Completion +from writerai.types import ( + Chat, + Completion, + FileDeleteResponse, + GraphDeleteResponse, + GraphRemoveFileFromGraphResponse, + GraphUpdateResponse, + StreamingData, +) from writerai.types import File as SDKFile -from writerai.types import FileDeleteResponse from writerai.types import Graph as SDKGraph -from writerai.types import (GraphDeleteResponse, - GraphRemoveFileFromGraphResponse, - GraphUpdateResponse, StreamingData) from writerai.types.chat_chat_params import Message as WriterAIMessage from writer.core import get_app_process @@ -192,8 +205,8 @@ class Graph(SDKWrapper): _wrapped (writerai.types.Graph): The wrapped SDK Graph object. stale_ids (set): A set of stale graph IDs that need updates. """ - _wrapped: SDKGraph = None - stale_ids = set() + _wrapped: SDKGraph + stale_ids: Set[str] = set() def __init__( self, @@ -398,14 +411,15 @@ def create_graph( """ config = config or {} graphs = Graph._retrieve_graphs_accessor() - graph_object = graphs.create(name=name, description=description, **config) - graph = Graph(graph_object) + graph_object = graphs.create(name=name, description=description or NotGiven(), **config) + converted_object = cast(SDKGraph, graph_object) + graph = Graph(converted_object) return graph def retrieve_graph( graph_id: str, - config: Optional[APIListOptions] = None + config: Optional[APIOptions] = None ) -> Graph: """ Retrieves a graph by its ID. @@ -413,7 +427,7 @@ def retrieve_graph( :param graph_id: The ID of the graph to retrieve. :type graph_id: str :param config: Additional configuration options. - :type config: Optional[APIListOptions] + :type config: Optional[APIOptions] :returns: The retrieved graph object. :rtype: Graph @@ -422,10 +436,6 @@ def retrieve_graph( - `extra_query` (Optional[Query]): Additional query parameters for the request. - `extra_body` (Optional[Body]): Additional body parameters for the request. - `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request. - - `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor. - - `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor. - - `limit` (Union[int, NotGiven]): The number of items to retrieve. - - `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items. """ config = config or {} graphs = Graph._retrieve_graphs_accessor() @@ -615,13 +625,16 @@ def upload_file( """ config = config or {} files = File._retrieve_files_accessor() - uploaded_file = { - "content": data, - "content_type": type, - "content_disposition": - f'attachment;filename="{name or f"WF-{type}-{uuid4()}"}"' - } - sdk_file = files.upload(**uploaded_file, **config) + + file_name = name or f"WF-{type}-{uuid4()}" + content_disposition = f'attachment; filename="{file_name}"' + + # Now calling the upload method with correct types. + sdk_file = files.upload( + content=data, + content_type=type, + content_disposition=content_disposition + ) return File(sdk_file)