diff --git a/cookbooks/function_calling.ipynb b/cookbooks/function_calling.ipynb index 2b8daf7..ce28c1e 100644 --- a/cookbooks/function_calling.ipynb +++ b/cookbooks/function_calling.ipynb @@ -297,7 +297,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.13.0" } }, "nbformat": 4, diff --git a/docs/changelog.rst b/docs/changelog.rst index ff253ca..c7aec59 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,12 @@ minor versions. All relevant steps to be taken will be mentioned here. +0.7.0 +----- + +- All models now have ``.distributed_chat_async`` that can be used in servers without blocking the main event + loop. This will give a much needed UX improvement to the entire system. + 0.6.3 ----- diff --git a/pyproject.toml b/pyproject.toml index d47b82f..4f2e092 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "tuneapi" -version = "0.6.3" +version = "0.7.0" description = "Tune AI APIs." authors = ["Frello Technology Private Limited "] license = "MIT" @@ -17,6 +17,7 @@ tqdm = "^4.66.1" snowflake_id = "1.0.2" nutree = "0.8.0" pillow = "^10.2.0" +httpx = "^0.28.1" protobuf = { version = "^5.27.3", optional = true } boto3 = { version = "1.29.6", optional = true } diff --git a/tuneapi/apis/model_anthropic.py b/tuneapi/apis/model_anthropic.py index 8f8f503..ea16505 100644 --- a/tuneapi/apis/model_anthropic.py +++ b/tuneapi/apis/model_anthropic.py @@ -4,13 +4,13 @@ # Copyright © 2024- Frello Technology Private Limited -import json +import httpx import requests from typing import Optional, Dict, Any, Tuple, List import tuneapi.utils as tu import tuneapi.types as tt -from tuneapi.apis.turbo import distributed_chat +from tuneapi.apis.turbo import distributed_chat, distributed_chat_async class Anthropic(tt.ModelInterface): @@ -203,7 +203,7 @@ def stream_chat( try: # print(line) - resp = json.loads(line.replace("data:", "").strip()) + resp = tu.from_json(line.replace("data:", "").strip()) if resp["type"] == "content_block_start": if resp["content_block"]["type"] == "tool_use": fn_call = { @@ -229,13 +229,147 @@ def stream_chat( fn_call["arguments"] += delta["partial_json"] elif resp["type"] == "content_block_stop": if fn_call: - fn_call["arguments"] = json.loads(fn_call["arguments"] or "{}") + fn_call["arguments"] = tu.from_json( + fn_call["arguments"] or "{}" + ) yield fn_call fn_call = None except: break return + async def chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: Optional[float] = None, + token: Optional[str] = None, + return_message: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ): + output = "" + fn_call = None + async for i in self.stream_chat_async( + chats=chats, + model=model, + max_tokens=max_tokens, + temperature=temperature, + token=token, + extra_headers=extra_headers, + raw=False, + **kwargs, + ): + if isinstance(i, dict): + fn_call = i.copy() + else: + output += i + if return_message: + return output, fn_call + if fn_call: + return fn_call + return output + + async def stream_chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: Optional[float] = None, + token: Optional[str] = None, + timeout=(5, 30), + raw: bool = False, + debug: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> Any: + + tools = [] + if isinstance(chats, tt.Thread): + tools = [x.to_dict() for x in 
chats.tools] + for t in tools: + t["input_schema"] = t.pop("parameters") + headers, system, claude_messages = self._process_input(chats=chats, token=token) + extra_headers = extra_headers or self.extra_headers + if extra_headers: + headers.update(extra_headers) + + data = { + "model": model or self.model_id, + "max_tokens": max_tokens, + "messages": claude_messages, + "system": system, + "tools": tools, + "stream": True, + } + if temperature: + data["temperature"] = temperature + if kwargs: + data.update(kwargs) + + if debug: + fp = "sample_anthropic.json" + print("Saving at path " + fp) + tu.to_json(data, fp=fp) + + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url, + headers=headers, + json=data, + timeout=timeout, + ) + try: + response.raise_for_status() + except Exception as e: + yield str(e) + return + + async for chunk in response.aiter_bytes(): + for line in chunk.decode("utf-8").splitlines(): + line = line.strip() + if not line or not "data:" in line: + continue + + try: + # print(line) + resp = tu.from_json(line.replace("data:", "").strip()) + if resp["type"] == "content_block_start": + if resp["content_block"]["type"] == "tool_use": + fn_call = { + "name": resp["content_block"]["name"], + "arguments": "", + } + elif resp["type"] == "content_block_delta": + delta = resp["delta"] + delta_type = delta["type"] + if delta_type == "text_delta": + if raw: + yield b"data: " + tu.to_json( + { + "object": delta_type, + "choices": [ + {"delta": {"content": delta["text"]}} + ], + }, + tight=True, + ).encode() + yield b"" # uncomment this line if you want 1:1 with OpenAI + else: + yield delta["text"] + elif delta_type == "input_json_delta": + fn_call["arguments"] += delta["partial_json"] + elif resp["type"] == "content_block_stop": + if fn_call: + fn_call["arguments"] = tu.from_json( + fn_call["arguments"] or "{}" + ) + yield fn_call + fn_call = None + except: + break + def distributed_chat( self, prompts: List[tt.Thread], @@ -243,6 +377,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): return distributed_chat( @@ -252,5 +387,27 @@ def distributed_chat( max_threads=max_threads, retry=retry, pbar=pbar, + debug=debug, + **kwargs, + ) + + async def distributed_chat_async( + self, + prompts: List[tt.Thread], + post_logic: Optional[callable] = None, + max_threads: int = 10, + retry: int = 3, + pbar=True, + debug=False, + **kwargs, + ): + return await distributed_chat_async( + self, + prompts=prompts, + post_logic=post_logic, + max_threads=max_threads, + retry=retry, + pbar=pbar, + debug=debug, **kwargs, ) diff --git a/tuneapi/apis/model_gemini.py b/tuneapi/apis/model_gemini.py index e474325..07108d9 100644 --- a/tuneapi/apis/model_gemini.py +++ b/tuneapi/apis/model_gemini.py @@ -5,13 +5,13 @@ # Copyright © 2024- Frello Technology Private Limited # https://ai.google.dev/gemini-api/docs/function-calling -import json +import httpx import requests from typing import Optional, Any, Dict, List import tuneapi.utils as tu import tuneapi.types as tt -from tuneapi.apis.turbo import distributed_chat +from tuneapi.apis.turbo import distributed_chat, distributed_chat_async class Gemini(tt.ModelInterface): @@ -267,7 +267,7 @@ def stream_chat( # print(f"{block_lines=}") if done: - part_data = json.loads(block_lines)["candidates"][0]["content"][ + part_data = tu.from_json(block_lines)["candidates"][0]["content"][ "parts" ][0] if "text" in part_data: @@ -288,6 +288,187 @@ def stream_chat( yield fn_call block_lines = "" + 
async def chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = None, + temperature: float = 1, + token: Optional[str] = None, + timeout=None, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> Any: + output = "" + x = None + try: + async for x in self.stream_chat_async( + chats=chats, + model=model, + max_tokens=max_tokens, + temperature=temperature, + token=token, + timeout=timeout, + extra_headers=extra_headers, + raw=False, + **kwargs, + ): + if isinstance(x, dict): + output = x + else: + output += x + except Exception as e: + if not x: + raise e + else: + raise ValueError(x) + return output + + async def stream_chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 4096, + temperature: float = 1, + token: Optional[str] = None, + timeout=(5, 60), + raw: bool = False, + debug: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ): + tools = [] + if isinstance(chats, tt.Thread): + tools = [x.to_dict() for x in chats.tools] + headers, system, messages, params = self._process_input(chats, token) + extra_headers = extra_headers or self.extra_headers + if extra_headers: + headers.update(extra_headers) + + data = { + "systemInstruction": { + "parts": [{"text": system}], + }, + "contents": messages, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE", + }, + ], + } + + generation_config = { + "temperature": temperature, + "maxOutputTokens": max_tokens, + "stopSequences": [], + } + + if chats.gen_schema: + generation_config.update( + { + "response_mime_type": "application/json", + "response_schema": chats.gen_schema, + } + ) + data["generationConfig"] = generation_config + + if tools: + data["tool_config"] = { + "function_calling_config": { + "mode": "ANY", + } + } + std_tools = [] + for i, t in enumerate(tools): + props = t["parameters"]["properties"] + t_copy = t.copy() + if not props: + t_copy.pop("parameters") + std_tools.append(t_copy) + data["tools"] = [{"function_declarations": std_tools}] + data.update(kwargs) + + if debug: + fp = "sample_gemini.json" + print("Saving at path " + fp) + tu.to_json(data, fp=fp) + + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url.format( + id=model or self.model_id, + rpc="streamGenerateContent", + ), + headers=headers, + params=params, + json=data, + timeout=timeout, + ) + try: + response.raise_for_status() + except Exception as e: + yield str(e) + return + + block_lines = "" + done = False + + async for chunk in response.aiter_bytes(): + for line in chunk.decode("utf-8").splitlines(): + # print(f"[{lno:03d}] {line}") + + # get the clean line for block + if line == ("[{"): # first line + line = line[1:] + elif line == "," or line == "]": # intermediate or last line + continue + block_lines += line + + done = False + try: + tu.from_json(block_lines) + done = True + except Exception as e: + pass + + # print(f"{block_lines=}") + if done: + part_data = tu.from_json(block_lines)["candidates"][0]["content"][ + "parts" + ][0] + if "text" in part_data: + if raw: + yield b"data: " + tu.to_json( + { + "object": "gemini_text", + "choices": [ + {"delta": 
{"content": part_data["text"]}} + ], + }, + tight=True, + ).encode() + yield b"" + else: + yield part_data["text"] + elif "functionCall" in part_data: + fn_call = part_data["functionCall"] + fn_call["arguments"] = fn_call.pop("args") + yield fn_call + block_lines = "" + def distributed_chat( self, prompts: List[tt.Thread], @@ -295,6 +476,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): return distributed_chat( @@ -304,5 +486,27 @@ def distributed_chat( max_threads=max_threads, retry=retry, pbar=pbar, + debug=debug, + **kwargs, + ) + + async def distributed_chat_async( + self, + prompts: List[tt.Thread], + post_logic: Optional[callable] = None, + max_threads: int = 10, + retry: int = 3, + pbar=True, + debug=False, + **kwargs, + ): + return await distributed_chat_async( + self, + prompts=prompts, + post_logic=post_logic, + max_threads=max_threads, + retry=retry, + pbar=pbar, + debug=debug, **kwargs, ) diff --git a/tuneapi/apis/model_groq.py b/tuneapi/apis/model_groq.py index d12542a..148f874 100644 --- a/tuneapi/apis/model_groq.py +++ b/tuneapi/apis/model_groq.py @@ -4,13 +4,13 @@ # Copyright © 2024- Frello Technology Private Limited -import json +import httpx import requests from typing import Optional, Dict, Any, Tuple, List import tuneapi.utils as tu import tuneapi.types as tt -from tuneapi.apis.turbo import distributed_chat +from tuneapi.apis.turbo import distributed_chat, distributed_chat_async class Groq(tt.ModelInterface): @@ -174,7 +174,7 @@ def stream_chat( line = line.decode().strip() if line: try: - x = json.loads(line.replace("data: ", ""))["choices"][0]["delta"] + x = tu.from_json(line.replace("data: ", ""))["choices"][0]["delta"] if "tool_calls" in x: y = x["tool_calls"][0]["function"] if fn_call is None: @@ -192,6 +192,110 @@ def stream_chat( yield fn_call return + async def chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 0.7, + token: Optional[str] = None, + timeout=(5, 60), + stop: Optional[List[str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> str | Dict[str, Any]: + output = "" + async for x in self.stream_chat_async( + chats=chats, + model=model, + max_tokens=max_tokens, + temperature=temperature, + token=token, + timeout=timeout, + stop=stop, + extra_headers=extra_headers, + raw=False, + **kwargs, + ): + if isinstance(x, dict): + output = x + else: + output += x + return output + + async def stream_chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 0.7, + token: Optional[str] = None, + timeout=(5, 60), + stop: Optional[List[str]] = None, + raw: bool = False, + debug: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + ): + tools = [] + if isinstance(chats, tt.Thread): + tools = [{"type": "function", "function": x.to_dict()} for x in chats.tools] + headers, messages = self._process_input(chats, token) + extra_headers = extra_headers or self.extra_headers + if extra_headers: + headers.update(extra_headers) + data = { + "temperature": temperature, + "messages": messages, + "model": model or self.model_id, + "stream": True, + "max_tokens": max_tokens, + "tools": tools, + } + if debug: + fp = "sample_groq.json" + print("Saving at path " + fp) + tu.to_json(data, fp=fp) + + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url, + headers=headers, + json=data, + timeout=timeout, + ) + 
response.raise_for_status() + + fn_call = None + async for line in response.aiter_text(): + if raw: + yield line + continue + + line = line.strip() + if line: + try: + x = tu.from_json(line.replace("data: ", ""))["choices"][0][ + "delta" + ] + if "tool_calls" in x: + y = x["tool_calls"][0]["function"] + if fn_call is None: + fn_call = { + "name": y["name"], + "arguments": y["arguments"], + } + else: + fn_call["arguments"] += y["arguments"] + elif "content" in x: + c = x["content"] + if c: + yield c + except: + break + if fn_call: + fn_call["arguments"] = tu.from_json(fn_call["arguments"]) + yield fn_call + def distributed_chat( self, prompts: List[tt.Thread], @@ -199,6 +303,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): return distributed_chat( @@ -208,5 +313,27 @@ def distributed_chat( max_threads=max_threads, retry=retry, pbar=pbar, + debug=debug, + **kwargs, + ) + + def distributed_chat_async( + self, + prompts: List[tt.Thread], + post_logic: Optional[callable] = None, + max_threads: int = 10, + retry: int = 3, + pbar=True, + debug=False, + **kwargs, + ): + return distributed_chat_async( + self, + prompts=prompts, + post_logic=post_logic, + max_threads=max_threads, + retry=retry, + pbar=pbar, + debug=debug, **kwargs, ) diff --git a/tuneapi/apis/model_mistral.py b/tuneapi/apis/model_mistral.py index 16d871a..bf240c6 100644 --- a/tuneapi/apis/model_mistral.py +++ b/tuneapi/apis/model_mistral.py @@ -194,6 +194,115 @@ def stream_chat( yield fn_call return + async def chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 0.7, + token: Optional[str] = None, + timeout=(5, 60), + stop: Optional[List[str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> str | Dict[str, Any]: + output = "" + async for x in self.stream_chat_async( + chats=chats, + model=model, + max_tokens=max_tokens, + temperature=temperature, + token=token, + timeout=timeout, + stop=stop, + extra_headers=extra_headers, + raw=False, + **kwargs, + ): + if isinstance(x, dict): + output = x + else: + output += x + return output + + async def stream_chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 0.7, + token: Optional[str] = None, + timeout=(5, 60), + stop: Optional[List[str]] = None, + raw: bool = False, + debug: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + ): + tools = [] + if isinstance(chats, tt.Thread): + tools = [{"type": "function", "function": x.to_dict()} for x in chats.tools] + headers, messages = self._process_input(chats, token) + extra_headers = extra_headers or self.extra_headers + if extra_headers: + headers.update(extra_headers) + data = { + "messages": messages, + "model": model or self.model_id, + "stream": True, + "max_tokens": max_tokens, + "tools": tools, + } + if temperature: + data["temperature"] = temperature + if debug: + fp = "sample_mistral.json" + print("Saving at path " + fp) + tu.to_json(data, fp=fp) + + async with requests.post( + self.base_url, + headers=headers, + json=data, + stream=True, + timeout=timeout, + ) as response: + try: + response.raise_for_status() + except Exception as e: + yield response.text + raise e + + fn_call = None + async for line in response.aiter_lines(): + if raw: + yield line + continue + + line = line.decode().strip() + if line: + try: + x = json.loads(line.replace("data: ", ""))["choices"][0][ + "delta" + ] + if 
"tool_calls" in x: + y = x["tool_calls"][0]["function"] + if fn_call is None: + fn_call = { + "name": y["name"], + "arguments": y["arguments"], + } + else: + fn_call["arguments"] += y["arguments"] + elif "content" in x: + c = x["content"] + if c: + yield c + except: + break + if fn_call: + fn_call["arguments"] = tu.from_json(fn_call["arguments"]) + yield fn_call + def distributed_chat( self, prompts: List[tt.Thread], @@ -201,6 +310,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): return distributed_chat( @@ -210,5 +320,6 @@ def distributed_chat( max_threads=max_threads, retry=retry, pbar=pbar, + debug=debug, **kwargs, ) diff --git a/tuneapi/apis/model_openai.py b/tuneapi/apis/model_openai.py index 325ccf8..f0faa55 100644 --- a/tuneapi/apis/model_openai.py +++ b/tuneapi/apis/model_openai.py @@ -5,13 +5,14 @@ # Copyright © 2024- Frello Technology Private Limited import json +import httpx import requests from typing import Optional, Any, List, Dict import tuneapi.utils as tu import tuneapi.types as tt -from tuneapi.apis.turbo import distributed_chat +from tuneapi.apis.turbo import distributed_chat, distributed_chat_async class Openai(tt.ModelInterface): @@ -194,6 +195,118 @@ def stream_chat( yield fn_call return + async def chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = None, + temperature: float = 1, + parallel_tool_calls: bool = False, + token: Optional[str] = None, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> Any: + output = "" + async for x in self.stream_chat_async( + chats=chats, + model=model, + max_tokens=max_tokens, + temperature=temperature, + parallel_tool_calls=parallel_tool_calls, + token=token, + extra_headers=extra_headers, + raw=False, + **kwargs, + ): + if isinstance(x, dict): + output = x + else: + output += x + return output + + async def stream_chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = None, + temperature: float = 1, + parallel_tool_calls: bool = False, + token: Optional[str] = None, + timeout=(5, 60), + extra_headers: Optional[Dict[str, str]] = None, + debug: bool = False, + raw: bool = False, + **kwargs, + ): + headers, messages = self._process_input(chats, token) + extra_headers = extra_headers or self.extra_headers + if extra_headers: + headers.update(extra_headers) + data = { + "temperature": temperature, + "messages": messages, + "model": model or self.model_id, + "stream": True, + } + if max_tokens: + data["max_tokens"] = max_tokens + if isinstance(chats, tt.Thread) and len(chats.tools): + data["tools"] = [ + {"type": "function", "function": x.to_dict()} for x in chats.tools + ] + data["parallel_tool_calls"] = parallel_tool_calls + if kwargs: + data.update(kwargs) + if debug: + fp = "sample_oai.json" + print("Saving at path " + fp) + tu.to_json(data, fp=fp) + + async with httpx.AsyncClient() as client: + try: + response = await client.post( + self.base_url, + headers=headers, + json=data, + timeout=timeout, + ) + response.raise_for_status() + except Exception as e: + yield str(e) + return + + fn_call = None + async for chunk in response.aiter_bytes(): + for line in chunk.decode("utf-8").splitlines(): + if raw: + yield line + continue + + line = line.strip() + if line.startswith("data: "): + line = line[6:] + if line: + try: + x = json.loads(line)["choices"][0]["delta"] + if "tool_calls" not in x: + yield x["content"] + else: + y = x["tool_calls"][0]["function"] + if fn_call is None: + 
fn_call = { + "name": y["name"], + "arguments": y["arguments"], + } + else: + fn_call["arguments"] += y["arguments"] + except: + break + if fn_call: + fn_call["arguments"] = tu.from_json(fn_call["arguments"]) + yield fn_call + + # Distributed chat functionalities + def distributed_chat( self, prompts: List[tt.Thread], @@ -201,6 +314,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): return distributed_chat( @@ -210,9 +324,33 @@ def distributed_chat( max_threads=max_threads, retry=retry, pbar=pbar, + debug=debug, **kwargs, ) + async def distributed_chat_async( + self, + prompts: List[tt.Thread], + post_logic: Optional[callable] = None, + max_threads: int = 10, + retry: int = 3, + pbar=True, + debug=False, + **kwargs, + ): + return await distributed_chat_async( + self, + prompts=prompts, + post_logic=post_logic, + max_threads=max_threads, + retry=retry, + pbar=pbar, + debug=debug, + **kwargs, + ) + + # Embedding models + def embedding( self, chats: tt.Thread | List[str] | str, diff --git a/tuneapi/apis/model_tune.py b/tuneapi/apis/model_tune.py index cfd22af..98675c8 100644 --- a/tuneapi/apis/model_tune.py +++ b/tuneapi/apis/model_tune.py @@ -4,13 +4,13 @@ # Copyright © 2024- Frello Technology Private Limited -import json +import httpx import requests from typing import Optional, Dict, Any, List import tuneapi.utils as tu import tuneapi.types as tt -from tuneapi.apis.turbo import distributed_chat +from tuneapi.apis.turbo import distributed_chat, distributed_chat_async class TuneModel(tt.ModelInterface): @@ -198,7 +198,7 @@ def stream_chat( line = line.decode().strip() if line: try: - delta = json.loads(line.replace("data: ", ""))["choices"][0][ + delta = tu.from_json(line.replace("data: ", ""))["choices"][0][ "delta" ] if "tool_calls" in delta: @@ -219,6 +219,116 @@ def stream_chat( yield fn_call return + async def chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 0.7, + token: Optional[str] = None, + timeout=(5, 60), + stop: Optional[List[str]] = None, + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> str | Dict[str, Any]: + output = "" + async for x in self.stream_chat_async( + chats=chats, + model=model, + max_tokens=max_tokens, + temperature=temperature, + token=token, + timeout=timeout, + stop=stop, + extra_headers=extra_headers, + raw=False, + **kwargs, + ): + if isinstance(x, dict): + output = x + else: + output += x + return output + + async def stream_chat_async( + self, + chats: tt.Thread | str, + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 0.7, + token: Optional[str] = None, + timeout=(5, 60), + stop: Optional[List[str]] = None, + raw: bool = False, + debug: bool = False, + extra_headers: Optional[Dict[str, str]] = None, + ): + model = model or self.model_id + if not model: + raise Exception( + "Tune Model ID not found. 
Please set TUNEAPI_MODEL environment variable or pass through function" + ) + headers, messages = self._process_input(chats, token) + extra_headers = extra_headers or self.extra_headers + if extra_headers: + headers.update(extra_headers) + data = { + "temperature": temperature, + "messages": messages, + "model": model, + "stream": True, + "max_tokens": max_tokens, + } + if stop: + data["stop"] = stop + if isinstance(chats, tt.Thread) and len(chats.tools): + data["tools"] = [ + {"type": "function", "function": x.to_dict()} for x in chats.tools + ] + if debug: + fp = "sample_tune.json" + tu.logger.info("Saving at path " + fp) + tu.to_json(data, fp=fp) + + async with httpx.AsyncClient() as client: + response = await client.post( + self.base_url, + headers=headers, + json=data, + timeout=timeout, + ) + response.raise_for_status() + + fn_call = None + async for line in response.aiter_lines(): + if raw: + yield line + continue + + line = line.strip() + if line: + try: + delta = tu.from_json(line.replace("data: ", ""))["choices"][0][ + "delta" + ] + if "tool_calls" in delta: + y = delta["tool_calls"][0]["function"] + if fn_call is None: + fn_call = { + "name": y["name"], + "arguments": y.get("arguments", ""), + } + else: + fn_call["arguments"] += y["arguments"] + elif "content" in delta: + yield delta["content"] + except: + break + if fn_call: + fn_call["arguments"] = tu.from_json(fn_call["arguments"]) + yield fn_call + return + def distributed_chat( self, prompts: List[tt.Thread], @@ -226,6 +336,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): return distributed_chat( @@ -235,5 +346,27 @@ def distributed_chat( max_threads=max_threads, retry=retry, pbar=pbar, + debug=debug, + **kwargs, + ) + + async def distributed_chat_async( + self, + prompts: List[tt.Thread], + post_logic: Optional[callable] = None, + max_threads: int = 10, + retry: int = 3, + pbar=True, + debug=False, + **kwargs, + ): + return await distributed_chat_async( + self, + prompts=prompts, + post_logic=post_logic, + max_threads=max_threads, + retry=retry, + pbar=pbar, + debug=debug, **kwargs, ) diff --git a/tuneapi/apis/turbo.py b/tuneapi/apis/turbo.py index 5333ddb..c3bf170 100644 --- a/tuneapi/apis/turbo.py +++ b/tuneapi/apis/turbo.py @@ -1,12 +1,14 @@ # Copyright © 2024- Frello Technology Private Limited import queue +import asyncio import threading from tqdm import trange from typing import List, Optional, Dict from dataclasses import dataclass from tuneapi.types import Thread, ModelInterface, human, system +from tuneapi.utils import logger def distributed_chat( @@ -16,6 +18,7 @@ def distributed_chat( max_threads: int = 10, retry: int = 3, pbar=True, + debug=False, **kwargs, ): """ @@ -79,7 +82,7 @@ def worker(): break try: - out = task.model.chat(chat=task.prompt, **task.kwargs) + out = task.model.chat(chats=task.prompt, **task.kwargs) if post_logic: out = post_logic(out) result_channel.put(_Result(task.index, out, True)) @@ -117,6 +120,9 @@ def worker(): t.start() workers.append(t) + if debug: + logger.info(f"Processing {len(prompts)} prompts with {max_threads} workers") + # Initialize progress bar _pbar = trange(len(prompts), desc="Processing", unit=" input") if pbar else None @@ -163,6 +169,74 @@ def worker(): return results +async def distributed_chat_async( + model: ModelInterface, + prompts: List[Thread], + post_logic: Optional[callable] = None, + max_threads: int = 10, + retry: int = 3, + pbar=True, + debug=False, + **kwargs, +): + results = [None for _ in 
range(len(prompts))] + + async def process_prompt(index, prompt, retry_count=0): + try: + out = await model.chat_async(chats=prompt, **kwargs) + if post_logic: + out = post_logic(out) + return (index, out, True) + except Exception as e: + if retry_count < retry: + # create new async model + nm = model.__class__( + id=model.model_id, + base_url=model.base_url, + extra_headers=model.extra_headers, + ) + nm.set_api_token(model.api_token) + + return await process_prompt(index, prompt, retry_count + 1) + else: + return (index, None, False, e) + + # Run all tasks concurrently using asyncio.gather + tasks = [] + for i, prompt in enumerate(prompts): + nm = model.__class__( + id=model.model_id, + base_url=model.base_url, + extra_headers=model.extra_headers, + ) + nm.set_api_token(model.api_token) + tasks.append(process_prompt(i, prompt)) + + if debug: + logger.info(f"Processing {len(prompts)} prompts with {max_threads} workers") + + _pbar = trange(len(prompts), desc="Processing", unit=" input") if pbar else None + + results_from_gather = await asyncio.gather(*tasks) + + # Process results + for r in results_from_gather: + index, data, success, *error = r + + if success: + results[index] = data + else: + results[index] = error[0] if error else None + + if _pbar: + _pbar.update(1) + + if _pbar: + _pbar.close() + + return results + + # helpers diff --git a/tuneapi/types/chats.py b/tuneapi/types/chats.py index c14bed9..791577d 100644 --- a/tuneapi/types/chats.py +++ b/tuneapi/types/chats.py @@ -346,6 +346,30 @@ def stream_chat( ): """This is the main function to stream chat with the model where each token is iteratively generated""" + async def chat_async( + chats: "Thread", + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 1, + token: Optional[str] = None, + timeout=(5, 30), + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> str | Dict[str, Any]: + """This is the async function to block chat with the model""" + + async def stream_chat_async( + chats: "Thread", + model: Optional[str] = None, + max_tokens: int = 1024, + temperature: float = 1, + token: Optional[str] = None, + timeout=(5, 30), + extra_headers: Optional[Dict[str, str]] = None, + **kwargs, + ) -> str | Dict[str, Any]: + """This is the async function to stream chat with the model where each token is iteratively generated""" + def distributed_chat( self, prompts: List["Thread"],
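Note (illustrative only, not part of the patch): a minimal sketch of how the new async entry points introduced above might be called from an async application so that batch work does not block the event loop. It assumes the public import paths tuneapi.apis / tuneapi.types, that Openai is re-exported from tuneapi.apis, that the API token is picked up from the environment when none is passed, and that tt.Thread accepts messages built with tt.human(...); adjust names to match the installed version.

import asyncio

from tuneapi import apis as ta   # assumed re-export location of the model classes
from tuneapi import types as tt


async def main():
    # Assumption: the constructor falls back to an API token from the environment.
    model = ta.Openai()

    # Single non-blocking call: the awaitable counterpart of model.chat(...).
    answer = await model.chat_async(tt.Thread(tt.human("Say hi in one word.")))
    print(answer)

    # Fan out many prompts concurrently via the new distributed_chat_async.
    prompts = [
        tt.Thread(tt.human(f"Give me one fact about the number {i}"))
        for i in range(5)
    ]
    results = await model.distributed_chat_async(prompts, retry=2, pbar=False)
    for r in results:
        print(r)


if __name__ == "__main__":
    asyncio.run(main())

Per the distributed_chat_async implementation in tuneapi/apis/turbo.py above, results are returned in the same order as the input prompts, and a prompt that still fails after the configured number of retries is returned as the captured exception in its slot.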