From 83799be3282c3f1d3400a4c45f05bc69d675c20d Mon Sep 17 00:00:00 2001
From: moes-91
Date: Tue, 17 Dec 2024 17:03:10 +0100
Subject: [PATCH] feat: implement classes for proper typing and dot notation
 in responses

chore: refactor dataclasses into new file
chore: update Readme
chore: bump to 0.0.8
---
 README.md                    | 263 +++++++++++++++++++++++++++++++----
 pyproject.toml               |   2 +-
 src/xai_grok_sdk/__init__.py |  24 +++-
 src/xai_grok_sdk/models.py   | 136 ++++++++++++++++++
 src/xai_grok_sdk/xai.py      | 192 ++++++++++++++-----------
 5 files changed, 508 insertions(+), 109 deletions(-)
 create mode 100644 src/xai_grok_sdk/models.py

diff --git a/README.md b/README.md
index 0fb594b..0c5a1f8 100644
--- a/README.md
+++ b/README.md
@@ -45,7 +45,9 @@ response = xai.invoke(
     ]
 )
 
-print(response["message"])
+response_message = response.choices[0].message
+print(response_message)
+# Response: Message(role='assistant', content="Hello! I can help you with a wide range of tasks and questions. Whether you need assistance with information, problem-solving, learning something new, or just want to have a conversation, I'm here to help. What specifically would you like assistance with today?", tool_calls=None, tool_results=None, refusal=None)
 ```
 
 ## Parameters
@@ -109,27 +111,33 @@ tools = [
     }
 ]
 
-# Implement the function
+# Implement the tool function
 def get_weather(location: str) -> str:
     return f"The weather in {location} is sunny."
 
 # Initialize the client with tools and function implementations
-xai = XAI(
-    api_key="your_api_key",
+llm = XAI(
+    api_key=api_key,
     model="grok-2-1212",
     tools=tools,
-    function_map={"get_weather": get_weather}  # Map function names to implementations
+    function_map={"get_weather": get_weather}
 )
 
-# Make a request that might trigger function calling
-response = xai.invoke(
+# Make a request that will use function calling
+response = llm.invoke(
     messages=[
-        {"role": "user", "content": "What's the weather like in San Francisco?"}
+        {"role": "user", "content": "What is the weather in San Francisco?"},
     ],
-    tool_choice="auto"  # Can be 'auto', 'required', 'none', or a specific function
+    tool_choice="auto"  # Can be "auto", "required", or "none"
 )
 
-print(response["message"])
+response_message = response.choices[0].message
+print(response_message)
+# Response: Message(role='assistant', content='I am retrieving the weather for San Francisco.', tool_calls=[ToolCall(id='0', function=Function(name='get_weather', arguments={'location': 'San Francisco, CA'}), type='function')], tool_results=[ToolResult(tool_call_id='0', role='tool', name='get_weather', content='The weather in San Francisco, CA is sunny.')], refusal=None)
+
+tool_result_content = response_message.tool_results[0].content
+print(tool_result_content)
+# Response: The weather in San Francisco, CA is sunny.
 ```
 
 The SDK supports various function calling modes through the `tool_choice` parameter:
@@ -144,37 +152,242 @@ The SDK supports various function calling modes through the `tool_choice` parame
 
 For more details, see the [xAI Function Calling Guide](https://docs.x.ai/docs/guides/function-calling) and the [API Reference](https://docs.x.ai/api/endpoints#chat-completions).
 
-> **Note**: Currently, using `"required"` or specific function calls may produce unexpected outputs. It is recommended to use either `"auto"` or `"none"` for more reliable results.
-
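+> **Note**: In addition to the string modes above, the linked API reference documents an object form of `tool_choice` that forces a specific function. A minimal sketch, assuming the OpenAI-compatible object form described there:
+
+```python
+# Sketch: force the model to call get_weather via the object form of tool_choice.
+# The exact shape of this dict follows the xAI API reference linked above.
+response = llm.invoke(
+    messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
+    tool_choice={"type": "function", "function": {"name": "get_weather"}},
+)
+```
+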
 ### Required Parameters for Function Calling
 
 When using function calling, you need to provide:
 
 - `tools`: List of tool definitions with their schemas
-- `function_map`: Dictionary mapping function names to their actual implementations
 
-> **Note**: The `function_map` parameter is required when tools are provided. Each tool name must have a corresponding implementation in the function map.
+### Function Map
+
+The optional `function_map` parameter maps tool names to their Python implementations. It lets the SDK actually execute a requested function and append the result to the response. Each function in the map must:
+
+- Have a name matching a tool definition
+- Accept the parameters specified in the tool's JSON Schema
+- Return a value that can be converted to a string
+
+> **Note**: The `function_map` parameter is not required when tools are provided. When it is omitted, however, the response contains only the tool calls (the function names and arguments chosen by the model); no tool results are produced.
 
 ## API Reference
 
 ### XAI Class
 
-#### `__init__(api_key: str, model: str, tools: Optional[List[Dict[str, Any]]] = None, function_map: Optional[Dict[str, Any]] = None)`
+The main class for interacting with the xAI API.
+
+#### Constructor Parameters
+
+- `api_key` (str, required): Your xAI API key
+- `model` (ModelType, required): Model to use ("grok-2-1212" or "grok-beta")
+- `tools` (List[Dict[str, Any]], optional): List of available tools
+- `function_map` (Dict[str, Callable], optional): Map of function names to implementations
+
+#### Methods
+
+##### invoke
+
+Makes a chat completion request to the xAI API.
+
+```python
+def invoke(
+    messages: List[Dict[str, Any]],  # REQUIRED
+    frequency_penalty: Optional[float] = None,  # Range: -2.0 to 2.0
+    logit_bias: Optional[Dict[int, float]] = None,
+    logprobs: Optional[bool] = None,
+    max_tokens: Optional[int] = None,
+    n: Optional[int] = None,
+    presence_penalty: Optional[float] = None,  # Range: -2.0 to 2.0
+    response_format: Optional[Any] = None,
+    seed: Optional[int] = None,
+    stop: Optional[List[str]] = None,
+    stream: Optional[bool] = None,
+    stream_options: Optional[Any] = None,
+    temperature: Optional[float] = None,  # Range: 0 to 2
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+    top_logprobs: Optional[int] = None,  # Range: 0 to 20
+    top_p: Optional[float] = None,  # Range: 0 to 1
+    user: Optional[str] = None
+) -> ChatCompletionResponse
+```
+
+### Response Models
+
+The SDK uses several dataclasses to represent the API response structure:
+
+#### Message
+
+Represents a message in the chat completion response.
+
+```python
+@dataclass
+class Message:
+    role: str  # Role of the message sender (e.g., "assistant", "user")
+    content: str  # Content of the message
+    tool_calls: Optional[List[ToolCall]] = None  # List of tool calls made by the model
+    tool_results: Optional[List[ToolResult]] = None  # Results from tool executions
+    refusal: Optional[Any] = None  # Information about message refusal if applicable
+```
+
+#### ToolCall
+
+Represents a tool call in the chat completion response.
+
+```python
+@dataclass
+class ToolCall:
+    id: str  # Unique identifier for the tool call
+    function: Function  # Function details
+    type: str  # Type of the tool call
+```
+
+#### Function
+
+Represents a function in a tool call.
+
+```python
+@dataclass
+class Function:
+    name: str  # Name of the function
+    arguments: Dict[str, Any]  # Arguments passed to the function, parsed from the model's JSON string
+```
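+
+Because `tool_calls` holds `ToolCall` objects whose `arguments` are already parsed into a dict, a call can be inspected with plain dot notation. A small sketch, reusing the `response` from the function calling example above:
+
+```python
+# Sketch: inspect the first tool call on the response message.
+tool_call = response.choices[0].message.tool_calls[0]
+print(tool_call.function.name)       # get_weather
+print(tool_call.function.arguments)  # {'location': 'San Francisco, CA'}
+```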
+
+#### ToolResult
+
+Represents a tool result in the chat completion response.
+
+```python
+@dataclass
+class ToolResult:
+    tool_call_id: str  # ID of the associated tool call
+    role: str  # Role (typically "tool")
+    name: str  # Name of the tool
+    content: Any  # Result content from the tool execution
+```
+
+#### ChatCompletionResponse
+
+The main response object returned by the `invoke` method.
+
+```python
+@dataclass
+class ChatCompletionResponse:
+    id: str  # Unique identifier for the completion
+    choices: List[Choice]  # List of completion choices
+    created: int  # Unix timestamp of creation
+    model: str  # Model used for completion
+    object: str  # Object type ("chat.completion")
+    system_fingerprint: str  # System fingerprint
+    usage: Optional[Usage]  # Token usage statistics
+```
+
+#### Choice
+
+Represents a single completion choice in the response.
+
+```python
+@dataclass
+class Choice:
+    index: int  # Index of this choice
+    message: Message  # The generated message
+    finish_reason: Optional[str]  # Why the model stopped generating
+    logprobs: Optional[Dict[str, Any]] = None  # Log probabilities if requested
+```
+
+#### Usage
+
+Contains token usage statistics for the request.
+
+```python
+@dataclass
+class Usage:
+    prompt_tokens: int  # Tokens in the prompt
+    completion_tokens: int  # Tokens in the completion
+    total_tokens: int  # Total tokens used
+    prompt_tokens_details: Optional[Dict[str, Any]] = None  # Detailed token usage
+```
+
+### Example Response
+
+Here's an example of a typical response when using function calling:
+
+```python
+ChatCompletionResponse(
+    id='...',
+    choices=[
+        Choice(
+            index=0,
+            message=Message(
+                role='assistant',
+                content='I am retrieving the weather for San Francisco.',
+                tool_calls=[
+                    ToolCall(
+                        id='0',
+                        function=Function(
+                            name='get_weather',
+                            arguments={'location': 'San Francisco, CA'}
+                        ),
+                        type='function'
+                    )
+                ],
+                tool_results=[
+                    ToolResult(
+                        tool_call_id='0',
+                        role='tool',
+                        name='get_weather',
+                        content='The weather in San Francisco, CA is sunny.'
+                    )
+                ],
+                refusal=None
+            ),
+            finish_reason='stop',
+            logprobs=None
+        )
+    ],
+    created=1703...,
+    model='grok-2-1212',
+    object='chat.completion',
+    system_fingerprint='...',
+    usage=Usage(
+        prompt_tokens=50,
+        completion_tokens=20,
+        total_tokens=70,
+        prompt_tokens_details=None
+    )
+)
+```
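+
+All of these fields are reachable with dot notation rather than dictionary lookups. For instance, given the example response above:
+
+```python
+# Sketch: read a few fields off the example response above.
+print(response.usage.total_tokens)        # 70
+print(response.choices[0].finish_reason)  # stop
+print(response.model)                     # grok-2-1212
+```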
+
+## Security Best Practices
+
+When using this SDK, follow these security best practices:
+
+1. **API Key Management**
+
+   - Never hardcode your API key directly in your code
+   - Use environment variables to store your API key:
+
+   ```python
+   import os
 
-Initialize the XAI client.
 
+   api_key = os.getenv("XAI_API_KEY")
+   xai = XAI(api_key=api_key, model="grok-2-1212")
+   ```
 
-- `api_key`: Your xAI API key
-- `model`: The model to use for chat completions
-- `tools`: Optional list of tools available for the model to use
-- `function_map`: Optional dictionary mapping function names to their implementations
+   - Consider using a secure secrets management service in production
+   - Keep your API key private and never commit it to version control
 
-#### `invoke(messages: List[Dict[str, Any]], tool_choice: str = "auto") -> Dict[str, Any]`
 
+2. **Environment Variables**
 
-Run a conversation with the model.
 
+   - Create a `.env` file for local development (and add it to `.gitignore`)
+   - Example `.env` file:
+     ```
+     XAI_API_KEY=your_api_key_here
+     ```
 
-- `messages`: List of conversation messages
-- `tool_choice`: Function calling mode ('auto', 'required', 'none', or specific function)
-- Returns: Dictionary containing the model's response
+3. **Request Validation**
+   - The SDK automatically validates all parameters before making API calls
+   - Always handle potential exceptions in your code:
+   ```python
+   try:
+       response = xai.invoke(messages=[{"role": "user", "content": "Hello"}])
+   except Exception as e:
+       # Handle the error appropriately
+       print(f"An error occurred: {e}")
+   ```
 
 ## License
 
diff --git a/pyproject.toml b/pyproject.toml
index 90251ae..011f438 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "xai-grok-sdk"
-version = "0.0.7"
+version = "0.0.8"
 description = "Lightweight xAI SDK with minimal dependencies"
 dependencies = [
     "requests>=2.32.3",
diff --git a/src/xai_grok_sdk/__init__.py b/src/xai_grok_sdk/__init__.py
index 2dab1a9..1eac5f7 100644
--- a/src/xai_grok_sdk/__init__.py
+++ b/src/xai_grok_sdk/__init__.py
@@ -1,3 +1,25 @@
 from xai_grok_sdk.xai import XAI
+from xai_grok_sdk.models import (
+    ModelType,
+    ChatCompletionRequest,
+    Message,
+    Usage,
+    Choice,
+    ChatCompletionResponse,
+    Function,
+    ToolCall,
+    ToolResult,
+)
 
-__all__ = ["XAI"]
\ No newline at end of file
+__all__ = [
+    "XAI",
+    "ModelType",
+    "ChatCompletionRequest",
+    "Message",
+    "Usage",
+    "Choice",
+    "ChatCompletionResponse",
+    "Function",
+    "ToolCall",
+    "ToolResult",
+]
diff --git a/src/xai_grok_sdk/models.py b/src/xai_grok_sdk/models.py
new file mode 100644
index 0000000..67625df
--- /dev/null
+++ b/src/xai_grok_sdk/models.py
@@ -0,0 +1,136 @@
+"""Data models for XAI API interactions."""
+
+from typing import List, Dict, Any, Optional, Union, Literal
+from dataclasses import dataclass
+
+
+# Supported xAI models
+ModelType = Literal["grok-2-1212", "grok-beta"]
+
+
+@dataclass
+class ChatCompletionRequest:
+    """Request structure for chat completions API."""
+
+    messages: List[Dict[str, Any]]  # Required: List of messages in the chat conversation
+    model: str  # Required: Model name to use
+
+    # Optional parameters, defaulting to None
+    frequency_penalty: Optional[float] = None  # Range: -2.0 to 2.0
+    logit_bias: Optional[Dict[int, float]] = None
+    logprobs: Optional[bool] = None  # Whether to return log probabilities
+    max_tokens: Optional[int] = None  # Maximum tokens to generate
+    n: Optional[int] = None  # Number of chat completion choices
+    presence_penalty: Optional[float] = None  # Range: -2.0 to 2.0
+    response_format: Optional[Any] = None
+    seed: Optional[int] = None  # For deterministic sampling
+    stop: Optional[List[Any]] = None  # Up to 4 sequences to stop generation
+    stream: Optional[bool] = None  # Whether to stream partial message deltas
+    stream_options: Optional[Any] = None
+    temperature: Optional[float] = None  # Range: 0 to 2
+    tool_choice: Optional[Union[str, Dict[str, Any]]] = None
+    tools: Optional[List[Dict[str, Any]]] = None  # List of available tools
+    top_logprobs: Optional[int] = None  # Range: 0 to 20
+    top_p: Optional[float] = None  # Range: 0 to 1
+    user: Optional[str] = None  # Unique end-user identifier
+
+    def __post_init__(self):
+        """Validate and normalize the request parameters."""
+        if not self.messages:
+            raise ValueError("messages is required and cannot be empty")
+        if not self.model:
+            raise ValueError("model is required and cannot be empty")
+
+        # Initialize defaults if None
+        self.logit_bias = self.logit_bias or {}
+        self.stop = self.stop or []
+        self.tools = self.tools or []
+        if not self.tools:
+            self.tool_choice = None
+
+        # Clamp values to their valid ranges
+        if self.frequency_penalty is not None:
+            self.frequency_penalty = max(-2.0, min(2.0, self.frequency_penalty))
+        if self.presence_penalty is not None:
+            self.presence_penalty = max(-2.0, min(2.0, self.presence_penalty))
+        if self.temperature is not None:
+            self.temperature = max(0.0, min(2.0, self.temperature))
+        if self.top_p is not None:
+            self.top_p = max(0.0, min(1.0, self.top_p))
+        if self.top_logprobs is not None:
+            self.top_logprobs = max(0, min(20, self.top_logprobs))
+
+        # Ensure stop sequences don't exceed limit
+        if len(self.stop) > 4:
+            self.stop = self.stop[:4]
+
+
+@dataclass
+class Function:
+    """Represents a function in a tool call."""
+
+    name: str
+    arguments: Dict[str, Any]  # Parsed from the model's JSON arguments string
+
+
+@dataclass
+class ToolCall:
+    """Represents a tool call in the chat completion response."""
+
+    id: str
+    function: Function
+    type: str
+
+
+@dataclass
+class ToolResult:
+    """Represents a tool result in the chat completion response."""
+
+    tool_call_id: str
+    role: str
+    name: str
+    content: Any
+
+
+@dataclass
+class Message:
+    """Represents a message in the chat completion response."""
+
+    role: str
+    content: str
+    tool_calls: Optional[List[ToolCall]] = None
+    tool_results: Optional[List[ToolResult]] = None
+    refusal: Optional[Any] = None
+
+
+@dataclass
+class Usage:
+    """Usage information for the API response."""
+
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+    prompt_tokens_details: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class Choice:
+    """Represents a single completion choice in the API response."""
+
+    index: int
+    message: Message
+    finish_reason: Optional[str]
+    logprobs: Optional[Dict[str, Any]] = None
+
+
+@dataclass
+class ChatCompletionResponse:
+    """Response structure for chat completions API."""
+
+    id: str  # Unique identifier for the completion
+    choices: List[Choice]  # List of completion choices
+    created: int  # Unix timestamp of creation
+    model: str  # Model used for completion
+    object: str  # Object type, typically "chat.completion"
+    system_fingerprint: str  # System fingerprint for the response
+    usage: Optional[Usage]  # Token usage statistics
diff --git a/src/xai_grok_sdk/xai.py b/src/xai_grok_sdk/xai.py
index 8d47036..2c0d118 100644
--- a/src/xai_grok_sdk/xai.py
+++ b/src/xai_grok_sdk/xai.py
@@ -2,10 +2,18 @@
 
 import json
 import requests
-from typing import List, Dict, Any, Optional, Callable, Union, Literal
-
-# Supported xAI models
-ModelType = Literal["grok-2-1212", "grok-beta"]
+from typing import List, Dict, Any, Optional, Callable, Union
+from xai_grok_sdk.models import (
+    ModelType,
+    ChatCompletionRequest,
+    Message,
+    Usage,
+    Choice,
+    ChatCompletionResponse,
+    ToolCall,
+    ToolResult,
+    Function,
+)
 
 
 class XAI:
@@ -36,22 +44,18 @@ def __init__(
         self.function_map = {}
 
         if tools:
-            if not function_map:
-                raise ValueError(
-                    "function_map must be provided when tools are specified"
-                )
-
             for tool in tools:
                 if "name" not in tool:
                     raise ValueError("Each tool must have a 'name' field")
-
-                func_name = tool["name"]
-                if func_name not in function_map:
-                    raise ValueError(
-                        f"Function '{func_name}' not found in function_map"
-                    )
                 self.tools.append({"type": "function", "function": tool})
-                self.function_map[func_name] = function_map[func_name]
+
+                if function_map:
+                    func_name = tool["name"]
+                    if func_name not in function_map:
+                        raise ValueError(
+                            f"Function '{func_name}' not found in function_map"
+                        )
+                    self.function_map[func_name] = function_map[func_name]
 
     def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
         """Make an API call to XAI."""
@@ -68,23 +72,23 @@ def _make_api_call(self, payload: Dict[str, Any]) -> Dict[str, Any]:
     def invoke(
         self,
         messages: List[Dict[str, Any]],
-        frequency_penalty: float = 0,
-        logit_bias: Dict[str, float] = None,
-        logprobs: bool = False,
-        max_tokens: int = 0,
-        n: int = 0,
-        presence_penalty: float = 0,
-        response_format: Optional[Dict[str, str]] = None,
-        seed: int = 0,
-        stop: Optional[Union[str, List[str]]] = None,
-        stream: bool = False,
-        stream_options: Optional[Dict[str, Any]] = None,
-        temperature: float = 0,
-        tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto",
-        top_logprobs: int = 0,
-        top_p: float = 0,
-        user: str = "",
-    ) -> Dict[str, Any]:
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[int, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Any] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[Any]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Any] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> ChatCompletionResponse:
         """
         Run a conversation with the model.
 
@@ -108,50 +112,30 @@ def invoke(
             user: End-user identifier
 
         Returns:
-            Dict containing the model's response
+            ChatCompletionResponse containing the model's response
         """
         # Prepare initial payload
-        payload = {
-            "model": self.model,
-            "messages": messages,
-        }
-
-        # Add optional parameters if they differ from defaults
-        if frequency_penalty != 0:
-            payload["frequency_penalty"] = frequency_penalty
-        if logit_bias:
-            payload["logit_bias"] = logit_bias
-        if logprobs:
-            payload["logprobs"] = logprobs
-        if max_tokens != 0:
-            payload["max_tokens"] = max_tokens
-        if n != 0:
-            payload["n"] = n
-        if presence_penalty != 0:
-            payload["presence_penalty"] = presence_penalty
-        if response_format is not None:
-            payload["response_format"] = response_format
-        if seed != 0:
-            payload["seed"] = seed
-        if stop:
-            payload["stop"] = stop
-        if stream:
-            payload["stream"] = stream
-        if stream_options is not None:
-            payload["stream_options"] = stream_options
-        if temperature != 0:
-            payload["temperature"] = temperature
-        if top_logprobs != 0:
-            payload["top_logprobs"] = top_logprobs
-        if top_p != 0:
-            payload["top_p"] = top_p
-        if user:
-            payload["user"] = user
-
-        # Include tools if configured
-        if len(self.tools) > 0:
-            payload["tools"] = self.tools
-            payload["tool_choice"] = tool_choice
+        request = ChatCompletionRequest(
+            messages=messages,
+            model=self.model,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tools=self.tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+        # Drop parameters left as None so they are not sent as JSON nulls
+        payload = {k: v for k, v in request.__dict__.items() if v is not None}
 
         # Make API call
         response_data = self._make_api_call(payload)
@@ -173,16 +157,60 @@ def invoke(
                     arguments = json.loads(tool_call["function"]["arguments"])
                     result = self.function_map[function_name](**arguments)
                     tool_results.append(
-                        {
-                            "tool_call_id": tool_call["id"],
-                            "role": "tool",
-                            "name": function_name,
-                            "content": str(result),
-                        }
+                        ToolResult(
+                            tool_call_id=tool_call["id"],
+                            role="tool",
+                            name=function_name,
+                            content=str(result),
+                        )
                     )
 
         # Add tool results to the message if any were generated
         if tool_results:
            message["tool_results"] = tool_results
 
-        return response_data["choices"][0]
+        # Convert response to ChatCompletionResponse
+        response = ChatCompletionResponse(
+            id=response_data["id"],
+            choices=[
+                Choice(
+                    index=choice.get("index", 0),
+                    message=Message(
+                        role=choice["message"]["role"],
+                        content=choice["message"]["content"],
+                        tool_calls=(
+                            [
+                                ToolCall(
+                                    id=tc["id"],
+                                    type=tc["type"],
+                                    function=Function(
+                                        name=tc["function"]["name"],
+                                        arguments=json.loads(
+                                            tc["function"]["arguments"]
+                                        ),
+                                    ),
+                                )
+                                for tc in choice["message"]["tool_calls"]
+                            ]
+                            if choice["message"].get("tool_calls")
+                            else None
+                        ),
+                        tool_results=choice["message"].get("tool_results"),
+                    ),
+                    finish_reason=choice["finish_reason"],
+                    logprobs=choice.get("logprobs"),
+                )
+                for choice in response_data["choices"]
+            ],
+            created=response_data["created"],
+            model=response_data["model"],
+            object=response_data["object"],
+            system_fingerprint=response_data.get("system_fingerprint", ""),
+            usage=(
+                Usage(
+                    prompt_tokens=response_data["usage"]["prompt_tokens"],
+                    completion_tokens=response_data["usage"]["completion_tokens"],
+                    total_tokens=response_data["usage"]["total_tokens"],
+                    prompt_tokens_details=response_data["usage"].get(
+                        "prompt_tokens_details"
+                    ),
+                )
+                if "usage" in response_data
+                else None
+            ),
+        )
+
+        return response