[0.7.0] adds async support for distributed chat
yashbonde committed Jan 3, 2025
1 parent ac3af6b commit 73823b9
Showing 11 changed files with 992 additions and 17 deletions.

cookbooks/function_calling.ipynb: 2 changes (1 addition, 1 deletion)

@@ -297,7 +297,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.13.0"
}
},
"nbformat": 4,

docs/changelog.rst: 6 changes (6 additions, 0 deletions)

@@ -7,6 +7,12 @@ minor versions.

All relevant steps to be taken will be mentioned here.

+0.7.0
+-----
+
+- All models now have ``<model>.distributed_chat_async`` that can be used in servers without blocking the main event
+  loop. This will give a much-needed UX improvement to the entire system.
+
0.6.3
-----
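
The 0.7.0 entry above is the whole user-facing change. As a minimal usage sketch (assuming the Anthropic class from this commit works with default constructor arguments and that tt.human builds a user message; neither detail is confirmed by this diff), a server-side handler might await the new coroutine like this:

# Hypothetical usage sketch, not part of the commit. Assumptions:
# Anthropic() works with defaults (model id and token from the environment)
# and tt.human(...) builds a user message; the distributed_chat_async
# signature itself matches the diff further down.
import asyncio

import tuneapi.types as tt
from tuneapi.apis.model_anthropic import Anthropic

async def handler():
    model = Anthropic()
    prompts = [
        tt.Thread(tt.human("Summarise httpx in one line.")),
        tt.Thread(tt.human("Summarise asyncio in one line.")),
    ]
    # Awaiting keeps the event loop free for other requests, unlike the
    # thread-pool based distributed_chat.
    return await model.distributed_chat_async(prompts, pbar=False)

results = asyncio.run(handler())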


pyproject.toml: 3 changes (2 additions, 1 deletion)

@@ -1,6 +1,6 @@
[tool.poetry]
name = "tuneapi"
version = "0.6.3"
version = "0.7.0"
description = "Tune AI APIs."
authors = ["Frello Technology Private Limited <[email protected]>"]
license = "MIT"
@@ -17,6 +17,7 @@ tqdm = "^4.66.1"
snowflake_id = "1.0.2"
nutree = "0.8.0"
pillow = "^10.2.0"
+httpx = "^0.28.1"
protobuf = { version = "^5.27.3", optional = true }
boto3 = { version = "1.29.6", optional = true }


tuneapi/apis/model_anthropic.py: 165 changes (161 additions, 4 deletions)

@@ -4,13 +4,13 @@

# Copyright © 2024- Frello Technology Private Limited

-import json
+import httpx
import requests
from typing import Optional, Dict, Any, Tuple, List

import tuneapi.utils as tu
import tuneapi.types as tt
-from tuneapi.apis.turbo import distributed_chat
+from tuneapi.apis.turbo import distributed_chat, distributed_chat_async


class Anthropic(tt.ModelInterface):
@@ -203,7 +203,7 @@ def stream_chat(

try:
# print(line)
-resp = json.loads(line.replace("data:", "").strip())
+resp = tu.from_json(line.replace("data:", "").strip())
if resp["type"] == "content_block_start":
if resp["content_block"]["type"] == "tool_use":
fn_call = {
@@ -229,20 +229,155 @@
fn_call["arguments"] += delta["partial_json"]
elif resp["type"] == "content_block_stop":
if fn_call:
fn_call["arguments"] = json.loads(fn_call["arguments"] or "{}")
fn_call["arguments"] = tu.from_json(
fn_call["arguments"] or "{}"
)
yield fn_call
fn_call = None
except:
break
return

async def chat_async(
self,
chats: tt.Thread | str,
model: Optional[str] = None,
max_tokens: int = 1024,
temperature: Optional[float] = None,
token: Optional[str] = None,
return_message: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
**kwargs,
):
output = ""
fn_call = None
async for i in self.stream_chat_async(
chats=chats,
model=model,
max_tokens=max_tokens,
temperature=temperature,
token=token,
extra_headers=extra_headers,
raw=False,
**kwargs,
):
if isinstance(i, dict):
fn_call = i.copy()
else:
output += i
if return_message:
return output, fn_call
if fn_call:
return fn_call
return output

async def stream_chat_async(
self,
chats: tt.Thread | str,
model: Optional[str] = None,
max_tokens: int = 1024,
temperature: Optional[float] = None,
token: Optional[str] = None,
timeout=(5, 30),
raw: bool = False,
debug: bool = False,
extra_headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> Any:

tools = []
if isinstance(chats, tt.Thread):
tools = [x.to_dict() for x in chats.tools]
for t in tools:
t["input_schema"] = t.pop("parameters")
headers, system, claude_messages = self._process_input(chats=chats, token=token)
extra_headers = extra_headers or self.extra_headers
if extra_headers:
headers.update(extra_headers)

data = {
"model": model or self.model_id,
"max_tokens": max_tokens,
"messages": claude_messages,
"system": system,
"tools": tools,
"stream": True,
}
if temperature:
data["temperature"] = temperature
if kwargs:
data.update(kwargs)

if debug:
fp = "sample_anthropic.json"
print("Saving at path " + fp)
tu.to_json(data, fp=fp)

async with httpx.AsyncClient() as client:
response = await client.post(
self.base_url,
headers=headers,
json=data,
timeout=timeout,
)
try:
response.raise_for_status()
except Exception as e:
yield str(e)
return

async for chunk in response.aiter_bytes():
for line in chunk.decode("utf-8").splitlines():
line = line.strip()
if not line or not "data:" in line:
continue

try:
# print(line)
resp = tu.from_json(line.replace("data:", "").strip())
if resp["type"] == "content_block_start":
if resp["content_block"]["type"] == "tool_use":
fn_call = {
"name": resp["content_block"]["name"],
"arguments": "",
}
elif resp["type"] == "content_block_delta":
delta = resp["delta"]
delta_type = delta["type"]
if delta_type == "text_delta":
if raw:
yield b"data: " + tu.to_json(
{
"object": delta_type,
"choices": [
{"delta": {"content": delta["text"]}}
],
},
tight=True,
).encode()
yield b"" # uncomment this line if you want 1:1 with OpenAI
else:
yield delta["text"]
elif delta_type == "input_json_delta":
fn_call["arguments"] += delta["partial_json"]
elif resp["type"] == "content_block_stop":
if fn_call:
fn_call["arguments"] = tu.from_json(
fn_call["arguments"] or "{}"
)
yield fn_call
fn_call = None
except:
break

def distributed_chat(
self,
prompts: List[tt.Thread],
post_logic: Optional[callable] = None,
max_threads: int = 10,
retry: int = 3,
pbar=True,
+debug=False,
**kwargs,
):
return distributed_chat(
@@ -252,5 +252,27 @@ def distributed_chat(
max_threads=max_threads,
retry=retry,
pbar=pbar,
+debug=debug,
**kwargs,
)

async def distributed_chat_async(
self,
prompts: List[tt.Thread],
post_logic: Optional[callable] = None,
max_threads: int = 10,
retry: int = 3,
pbar=True,
debug=False,
**kwargs,
):
return await distributed_chat_async(
self,
prompts=prompts,
post_logic=post_logic,
max_threads=max_threads,
retry=retry,
pbar=pbar,
debug=debug,
**kwargs,
)
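
The streaming coroutine above yields text deltas as str and finished tool calls as dict, with chat_async folding those into a single return value. An illustrative consumer, again assuming default constructor arguments:

# Illustrative consumer, not part of the commit. Constructor defaults are
# assumed; the yielded types (str for text deltas, dict for completed tool
# calls) follow the content_block branches in the diff above.
import asyncio

from tuneapi.apis.model_anthropic import Anthropic

async def main():
    model = Anthropic()
    async for chunk in model.stream_chat_async("Write a haiku about event loops"):
        if isinstance(chunk, dict):
            # emitted at content_block_stop once a tool call is complete
            print("\n[tool]", chunk["name"], chunk["arguments"])
        else:
            # a text_delta piece (the default raw=False path)
            print(chunk, end="", flush=True)

asyncio.run(main())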