Skip to content

Commit

Permalink
fix: typing & lint
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikita95 committed Jul 15, 2024
1 parent 99035e4 commit 2780bbc
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions src/writer/ai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import logging
from datetime import datetime
from typing import (Generator, Iterable, List, Literal, Optional, TypedDict,
Union, cast)
from typing import (
Generator,
Iterable,
List,
Literal,
Optional,
Set,
TypedDict,
Union,
cast,
)
from uuid import uuid4

from httpx import Timeout
Expand All @@ -11,13 +20,17 @@
from writerai._streaming import Stream
from writerai._types import Body, Headers, NotGiven, Query
from writerai.resources import FilesResource, GraphsResource
from writerai.types import Chat, Completion
from writerai.types import (
Chat,
Completion,
FileDeleteResponse,
GraphDeleteResponse,
GraphRemoveFileFromGraphResponse,
GraphUpdateResponse,
StreamingData,
)
from writerai.types import File as SDKFile
from writerai.types import FileDeleteResponse
from writerai.types import Graph as SDKGraph
from writerai.types import (GraphDeleteResponse,
GraphRemoveFileFromGraphResponse,
GraphUpdateResponse, StreamingData)
from writerai.types.chat_chat_params import Message as WriterAIMessage

from writer.core import get_app_process
Expand Down Expand Up @@ -192,8 +205,8 @@ class Graph(SDKWrapper):
_wrapped (writerai.types.Graph): The wrapped SDK Graph object.
stale_ids (set): A set of stale graph IDs that need updates.
"""
_wrapped: SDKGraph = None
stale_ids = set()
_wrapped: SDKGraph
stale_ids: Set[str] = set()

def __init__(
self,
Expand Down Expand Up @@ -398,22 +411,23 @@ def create_graph(
"""
config = config or {}
graphs = Graph._retrieve_graphs_accessor()
graph_object = graphs.create(name=name, description=description, **config)
graph = Graph(graph_object)
graph_object = graphs.create(name=name, description=description or NotGiven(), **config)
converted_object = cast(SDKGraph, graph_object)
graph = Graph(converted_object)
return graph


def retrieve_graph(
graph_id: str,
config: Optional[APIListOptions] = None
config: Optional[APIOptions] = None
) -> Graph:
"""
Retrieves a graph by its ID.
:param graph_id: The ID of the graph to retrieve.
:type graph_id: str
:param config: Additional configuration options.
:type config: Optional[APIListOptions]
:type config: Optional[APIOptions]
:returns: The retrieved graph object.
:rtype: Graph
Expand All @@ -422,10 +436,6 @@ def retrieve_graph(
- `extra_query` (Optional[Query]): Additional query parameters for the request.
- `extra_body` (Optional[Body]): Additional body parameters for the request.
- `timeout` (Union[float, httpx.Timeout, None, NotGiven]): Timeout for the request.
- `after` (Union[str, NotGiven]): Filter to retrieve items created after a specific cursor.
- `before` (Union[str, NotGiven]): Filter to retrieve items created before a specific cursor.
- `limit` (Union[int, NotGiven]): The number of items to retrieve.
- `order` (Union[Literal["asc", "desc"], NotGiven]): The order in which to retrieve items.
"""
config = config or {}
graphs = Graph._retrieve_graphs_accessor()
Expand Down Expand Up @@ -615,13 +625,16 @@ def upload_file(
"""
config = config or {}
files = File._retrieve_files_accessor()
uploaded_file = {
"content": data,
"content_type": type,
"content_disposition":
f'attachment;filename="{name or f"WF-{type}-{uuid4()}"}"'
}
sdk_file = files.upload(**uploaded_file, **config)

file_name = name or f"WF-{type}-{uuid4()}"
content_disposition = f'attachment; filename="{file_name}"'

# Now calling the upload method with correct types.
sdk_file = files.upload(
content=data,
content_type=type,
content_disposition=content_disposition
)
return File(sdk_file)


Expand Down

0 comments on commit 2780bbc

Please sign in to comment.