feat: Add Fireworks Python SDK support #592
@@ -0,0 +1,181 @@
import logging
from typing import Optional, AsyncGenerator
import pprint
from agentops.session import Session
from agentops.helpers import get_ISO_time
from agentops.event import LLMEvent
from agentops.enums import EventType
from .instrumented_provider import InstrumentedProvider


logger = logging.getLogger(__name__)


class FireworksProvider(InstrumentedProvider):
    """Provider for Fireworks.ai API."""

    def __init__(self, client):
        super().__init__(client)
        self._provider_name = "Fireworks"
        self._original_completion = None
        self._original_async_completion = None
        self._session = None  # Initialize session attribute
        logger.info(f"Initializing {self._provider_name} provider")

    def set_session(self, session: Session):
        """Set the session for event tracking."""
        self._session = session
        logger.debug(f"Set session {session.session_id} for {self._provider_name} provider")

    def handle_response(self, response, kwargs, init_timestamp, session: Optional[Session] = None) -> dict:
        """Handle the response from the Fireworks API."""
        if session:
            self._session = session
            logger.debug(f"Updated session to {session.session_id} for {self._provider_name} provider")

        try:
            # Handle streaming response
            if kwargs.get('stream', False):
                async def async_generator(stream):
                    async for chunk in stream:
                        try:
                            # Parse the chunk data
                            if hasattr(chunk, 'choices') and chunk.choices:
                                content = chunk.choices[0].delta.content if hasattr(chunk.choices[0].delta, 'content') else None
                            else:
                                # Handle raw string chunks from streaming response
                                content = chunk

                            if content:
                                # Create event data for streaming chunk
                                event = LLMEvent(
                                    event_type=EventType.LLM.value,
                                    init_timestamp=init_timestamp,
                                    end_timestamp=get_ISO_time(),
                                    model=kwargs.get('model', 'unknown'),
                                    prompt=str(kwargs.get('messages', [])),
                                    completion="[Streaming Response]",
                                    prompt_tokens=0,
                                    completion_tokens=0,
                                    cost=0.0
                                )
                                if self._session:
                                    self._session.record(event)
                                    logger.debug(f"Recorded streaming chunk for session {self._session.session_id}")
                                yield content
                        except Exception as e:
                            logger.error(f"Error processing streaming chunk: {str(e)}")
                            continue

                def generator(stream):
                    for chunk in stream:
                        try:
                            # Parse the chunk data
                            if hasattr(chunk, 'choices') and chunk.choices:
                                content = chunk.choices[0].delta.content if hasattr(chunk.choices[0].delta, 'content') else None
                            else:
                                # Handle raw string chunks from streaming response
                                content = chunk

                            if content:
                                # Create event data for streaming chunk
                                event = LLMEvent(
                                    event_type=EventType.LLM.value,
                                    init_timestamp=init_timestamp,
                                    end_timestamp=get_ISO_time(),
                                    model=kwargs.get('model', 'unknown'),
                                    prompt=str(kwargs.get('messages', [])),
                                    completion="[Streaming Response]",
                                    prompt_tokens=0,
                                    completion_tokens=0,
                                    cost=0.0
                                )
                                if self._session:
                                    self._session.record(event)
                                    logger.debug(f"Recorded streaming chunk for session {self._session.session_id}")
                                yield content
                        except Exception as e:
                            logger.error(f"Error processing streaming chunk: {str(e)}")
                            continue

                if hasattr(response, '__aiter__'):
                    return async_generator(response)
                else:
                    return generator(response)

            # Handle non-streaming response
            if hasattr(response, 'choices') and response.choices:
                content = response.choices[0].message.content if hasattr(response.choices[0], 'message') else ""

                # Create event data for non-streaming response
                event = LLMEvent(
                    event_type=EventType.LLM.value,
                    init_timestamp=init_timestamp,
                    end_timestamp=get_ISO_time(),
                    model=kwargs.get('model', 'unknown'),
                    prompt=str(kwargs.get('messages', [])),
                    completion=content,
                    prompt_tokens=0,
                    completion_tokens=0,
                    cost=0.0
                )
                if self._session:
                    self._session.record(event)
                    logger.debug(f"Recorded non-streaming response for session {self._session.session_id}")

            return response

        except Exception as e:
            logger.error(f"Error handling Fireworks response: {str(e)}")
            raise

    def override(self):
        """Override Fireworks API methods with instrumented versions."""
        logger.info(f"Overriding {self._provider_name} provider methods")

        # Store original methods
        self._original_completion = self.client.chat.completions.create
        self._original_async_completion = getattr(self.client.chat.completions, 'acreate', None)

        # Override methods
        self._override_fireworks_completion()
        if self._original_async_completion:
            self._override_fireworks_async_completion()

    def _override_fireworks_completion(self):
        """Override synchronous completion method."""
        original_create = self._original_completion
        provider = self

        def patched_function(*args, **kwargs):
            try:
                init_timestamp = get_ISO_time()
                response = original_create(*args, **kwargs)
                return provider.handle_response(response, kwargs, init_timestamp, provider._session)
            except Exception as e:
                logger.error(f"Error in Fireworks completion: {str(e)}")
                raise

        self.client.chat.completions.create = patched_function

    def _override_fireworks_async_completion(self):
        """Override asynchronous completion method."""
        original_acreate = self._original_async_completion
        provider = self

        async def patched_function(*args, **kwargs):
            try:
                init_timestamp = get_ISO_time()
                response = await original_acreate(*args, **kwargs)
                return provider.handle_response(response, kwargs, init_timestamp, provider._session)
            except Exception as e:
                logger.error(f"Error in Fireworks async completion: {str(e)}")
                raise

        self.client.chat.completions.acreate = patched_function

    def undo_override(self):
        """Restore original Fireworks API methods."""
        logger.info(f"Restoring original {self._provider_name} provider methods")
        if self._original_completion:
            self.client.chat.completions.create = self._original_completion
        if self._original_async_completion:
            self.client.chat.completions.acreate = self._original_async_completion
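
For reviewers, a minimal usage sketch of how this provider gets wired up; it mirrors what the tracker change in the next file does, the import path is assumed from this PR's package layout, and the `Fireworks` client is only instantiated, not called:

```python
from fireworks.client import Fireworks

# Import path assumed from this PR's package layout.
from agentops.llms.providers.fireworks import FireworksProvider

client = Fireworks()

provider = FireworksProvider(client)
provider.override()  # patch chat.completions.create (and acreate, if present)

# ... calls made through `client` now pass through handle_response() ...

provider.undo_override()  # restore the original Fireworks methods
```

In the PR itself this wiring happens inside the tracker (next file), which instantiates the provider and calls `override()` once a supported `fireworks-ai` version is detected.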
@@ -16,6 +16,7 @@
from .providers.anthropic import AnthropicProvider
from .providers.mistral import MistralProvider
from .providers.ai21 import AI21Provider
from .providers.fireworks import FireworksProvider

original_func = {}
original_create = None

@@ -48,6 +49,9 @@
    "mistralai": {
        "1.0.1": ("chat.complete", "chat.stream"),
    },
    "fireworks-ai": {
        "0.1.0": ("chat.completions.create",),
    },
    "ai21": {
        "2.0.0": (
            "chat.completions.create",

@@ -155,6 +159,15 @@
            else:
                logger.warning(f"Only AI21>=2.0.0 supported. v{module_version} found.")

        if api == "fireworks-ai":
            module_version = version(api)

            if Version(module_version) >= parse("0.1.0"):
                provider = FireworksProvider(self.client)
                provider.override()
            else:
                logger.warning(f"Only Fireworks>=0.1.0 supported. v{module_version} found.")

        if api == "llama_stack_client":
            module_version = version(api)

Comment on lines 159 to 173

🤖 Bug Fix: raise an exception when an unsupported `fireworks-ai` version is detected, instead of only logging a warning and continuing with an uninstrumented client.

🔧 Suggested Code Diff:

      else:
          logger.warning(f"Only Fireworks>=0.1.0 supported. v{module_version} found.")
    +     raise RuntimeError(f"Unsupported Fireworks version: {module_version}. Please upgrade to >=0.1.0.")

📜 Guidelines
• Use exceptions for error handling rather than return codes
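
Applied to the hunk above, the check would read roughly as follows; this is a sketch of the reviewer's suggestion in the context of that diff (imports and surrounding code as in the PR), not code from the PR itself:

```python
if api == "fireworks-ai":
    module_version = version(api)

    if Version(module_version) >= parse("0.1.0"):
        provider = FireworksProvider(self.client)
        provider.override()
    else:
        # Warn, then fail fast rather than continuing uninstrumented.
        logger.warning(f"Only Fireworks>=0.1.0 supported. v{module_version} found.")
        raise RuntimeError(
            f"Unsupported Fireworks version: {module_version}. Please upgrade to >=0.1.0."
        )
```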
@@ -174,3 +187,4 @@
        MistralProvider(self.client).undo_override()
        AI21Provider(self.client).undo_override()
        LlamaStackClientProvider(self.client).undo_override()
        FireworksProvider(self.client).undo_override()
@@ -0,0 +1,174 @@
---
title: Fireworks
description: "AgentOps provides support for Fireworks AI's LLM models"
---

import CodeTooltip from '/snippets/add-code-tooltip.mdx'
import EnvTooltip from '/snippets/add-env-tooltip.mdx'

<Card title="Fireworks" icon="robot" href="https://fireworks.ai">
  First class support for Fireworks AI models including Llama-v3
</Card>

<Steps>
  <Step title="Install the AgentOps SDK">
    <CodeGroup>
      ```bash pip
      pip install agentops
      ```
      ```bash poetry
      poetry add agentops
      ```
    </CodeGroup>
  </Step>
  <Step title="Install the Fireworks SDK">
    <CodeGroup>
      ```bash pip
      pip install --upgrade fireworks-ai
      ```
      ```bash poetry
      poetry add fireworks-ai
      ```
    </CodeGroup>
  </Step>
  <Step title="Add 3 lines of code">
    <CodeTooltip/>
    <CodeGroup>
      ```python python
      import agentops
      from fireworks.client import Fireworks

      agentops.init(<INSERT YOUR API KEY HERE>)
      client = Fireworks()
      ...
      # End of program (e.g. main.py)
      agentops.end_session("Success")
      ```
    </CodeGroup>
    <EnvTooltip />
    <CodeGroup>
      ```python .env
      AGENTOPS_API_KEY=<YOUR API KEY>
      FIREWORKS_API_KEY=<YOUR FIREWORKS API KEY>
      ```
    </CodeGroup>
    Read more about environment variables in [Advanced Configuration](/v1/usage/advanced-configuration)
  </Step>
  <Step title="Run your Agent">
    Execute your program and visit [app.agentops.ai/drilldown](https://app.agentops.ai/drilldown) to observe your Agent! 🕵️
    <Tip>
      After your run, AgentOps prints a clickable url to console linking directly to your session in the Dashboard
    </Tip>
    <div/>
    <Frame type="glass" caption="Clickable link to session">
      <img height="200" src="https://raw.githubusercontent.com/AgentOps-AI/agentops/refs/heads/main/docs/images/external/app_screenshots/session-replay.png?raw=true" />
    </Frame>
  </Step>
</Steps>

## Full Examples

<CodeGroup>
```python sync
from fireworks.client import Fireworks
import agentops

agentops.init(<INSERT YOUR API KEY HERE>)
client = Fireworks()

response = client.chat.completions.create(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    messages=[{
        "role": "user",
        "content": "Write a haiku about AI and humans working together"
    }]
)

print(response.choices[0].message.content)
agentops.end_session('Success')
```

```python async
from fireworks.client import AsyncFireworks
import agentops
import asyncio

async def main():
    agentops.init(<INSERT YOUR API KEY HERE>)
    client = AsyncFireworks()

    response = await client.chat.completions.create(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        messages=[{
            "role": "user",
            "content": "Write a haiku about AI and humans working together"
        }]
    )

    print(response.choices[0].message.content)
    agentops.end_session('Success')

asyncio.run(main())
```
</CodeGroup>

### Streaming examples

<CodeGroup>
```python sync
from fireworks.client import Fireworks
import agentops

agentops.init(<INSERT YOUR API KEY HERE>)
client = Fireworks()

stream = client.chat.completions.create(
    model="accounts/fireworks/models/llama-v3p1-8b-instruct",
    stream=True,
    messages=[{
        "role": "user",
        "content": "Write a haiku about AI and humans working together"
    }],
)

for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")

agentops.end_session('Success')
```

```python async
from fireworks.client import AsyncFireworks
import agentops
import asyncio

async def main():
    agentops.init(<INSERT YOUR API KEY HERE>)
    client = AsyncFireworks()

    stream = await client.chat.completions.create(
        model="accounts/fireworks/models/llama-v3p1-8b-instruct",
        stream=True,
        messages=[{
            "role": "user",
            "content": "Write a haiku about AI and humans working together"
        }],
    )

    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

    agentops.end_session('Success')

asyncio.run(main())
```
</CodeGroup>

<script type="module" src="/scripts/github_stars.js"></script>
<script type="module" src="/scripts/scroll-img-fadein-animation.js"></script>
<script type="module" src="/scripts/button_heartbeat_animation.js"></script>
<script type="css" src="/styles/styles.css"></script>
<script type="module" src="/scripts/adjust_api_dynamically.js"></script>
🤖 Bug Fix: Missing asynchronous generator for streaming responses

The implementation handles streaming responses by checking for the `__aiter__` attribute to decide whether to use an asynchronous generator. If the `async_generator` function is not defined, asynchronous responses will fail at runtime, breaking the `FireworksProvider` for asynchronous data streams.

Actionable suggestions:
• Implement the `async_generator` function to handle asynchronous streaming responses; it should mirror the logic of the synchronous `generator` function but support asynchronous iteration.
• Ensure the `async_generator` function correctly yields content from the asynchronous response and handles any exceptions raised during iteration.

With these changes, the code will be robust for both synchronous and asynchronous streaming responses, ensuring consistent behavior across response types.

📜 Guidelines
• Use type annotations to improve code clarity and maintainability
• Follow PEP 8 style guide for consistent code formatting
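
As a self-contained illustration of the pattern the reviewer describes — wrap an async stream, record each chunk through a callback, and re-yield the content — this sketch is independent of the AgentOps types, and all names in it are hypothetical:

```python
import asyncio
from typing import AsyncGenerator, Callable, Optional


async def wrap_stream(
    stream: AsyncGenerator[str, None],
    record: Optional[Callable[[str], None]] = None,
) -> AsyncGenerator[str, None]:
    """Yield chunks from an async stream, invoking a recording callback per chunk."""
    async for chunk in stream:
        if record is not None:
            record(chunk)  # in the provider this would be session.record(LLMEvent(...))
        yield chunk


async def fake_stream() -> AsyncGenerator[str, None]:
    """Stand-in for a streaming chat completion response."""
    for piece in ("Fireworks ", "streams ", "chunks"):
        yield piece
        await asyncio.sleep(0)


async def main() -> None:
    collected: list[str] = []
    async for chunk in wrap_stream(fake_stream(), record=collected.append):
        print(chunk, end="")
    print()
    assert collected == ["Fireworks ", "streams ", "chunks"]


asyncio.run(main())
```

In the provider diff above, the same shape appears as the nested `async_generator` inside `handle_response`, with the callback replaced by `self._session.record(...)`.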