Skip to content

Commit

Permalink
clean up outputs of multi modal autonomous agents
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 25, 2023
1 parent 208032c commit 1d1d0f0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion multi_modal_auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
flow = Flow(
llm=llm,
max_loops="auto",
dashboard=True,

)

flow.run(task=task, img=img)
20 changes: 14 additions & 6 deletions swarms/models/gpt4_vision_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ class GPT4VisionAPI:
----------
openai_api_key : str
The OpenAI API key. Defaults to the OPENAI_API_KEY environment variable.
max_tokens : int
The maximum number of tokens to generate. Defaults to 300.
Methods
-------
encode_image(img: str)
Expand All @@ -39,9 +42,10 @@ class GPT4VisionAPI:
"""

def __init__(self, openai_api_key: str = openai_api_key, max_tokens: int = 300):
    """Initialize the GPT-4 Vision API client.

    Parameters
    ----------
    openai_api_key : str
        The OpenAI API key. Defaults to the module-level ``openai_api_key``
        (read from the OPENAI_API_KEY environment variable per the class doc).
    max_tokens : int
        Maximum number of tokens the model may generate per completion.
        Defaults to 300; sent as ``"max_tokens"`` in the request payload.
    """
    # NOTE(review): the diff residue left two stacked ``def __init__`` lines;
    # only the new signature is kept. The annotation is corrected from
    # ``max_tokens: str`` to ``int`` — the default (300) and its use as the
    # JSON "max_tokens" field are integral.
    super().__init__()
    self.openai_api_key = openai_api_key
    self.max_tokens = max_tokens

def encode_image(self, img: str):
"""Encode image to base64."""
Expand Down Expand Up @@ -75,7 +79,7 @@ def run(self, task: str, img: str):
],
}
],
"max_tokens": 300,
"max_tokens": self.max_tokens,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions",
Expand All @@ -84,8 +88,8 @@ def run(self, task: str, img: str):
)

out = response.json()

out = out["choices"][0]["text"]
content = out["choices"][0]["message"]["content"]
print(content)
except Exception as error:
print(f"Error with the request: {error}")
raise error
Expand Down Expand Up @@ -117,14 +121,18 @@ def __call__(self, task: str, img: str):
],
}
],
"max_tokens": 300,
"max_tokens": self.max_tokens,
}
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=headers,
json=payload,
)
return response.json()

out = response.json()
content = out["choices"][0]["message"]["content"]
print(content)
except Exception as error:
print(f"Error with the request: {error}")
raise error
# Function to handle vision tasks
10 changes: 5 additions & 5 deletions tests/models/test_gpt4_vision_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_encode_image(vision_api):
with patch(
"builtins.open", mock_open(read_data=b"test_image_data"), create=True
):
encoded_image = vision_api.encode_image("test_image.jpg")
encoded_image = vision_api.encode_image(img)
assert encoded_image == "dGVzdF9pbWFnZV9kYXRh"


Expand All @@ -34,7 +34,7 @@ def test_run_success(vision_api):
with patch(
"requests.post", return_value=Mock(json=lambda: expected_response)
) as mock_post:
result = vision_api.run("What is this?", "test_image.jpg")
result = vision_api.run("What is this?", img)
mock_post.assert_called_once()
assert result == "This is the model's response."

Expand All @@ -44,7 +44,7 @@ def test_run_request_error(vision_api):
"requests.post", side_effect=RequestException("Request Error")
) as mock_post:
with pytest.raises(RequestException):
vision_api.run("What is this?", "test_image.jpg")
vision_api.run("What is this?", img)


def test_run_response_error(vision_api):
Expand All @@ -53,15 +53,15 @@ def test_run_response_error(vision_api):
"requests.post", return_value=Mock(json=lambda: expected_response)
) as mock_post:
with pytest.raises(RuntimeError):
vision_api.run("What is this?", "test_image.jpg")
vision_api.run("What is this?", img)


def test_call(vision_api):
    """``__call__`` should POST exactly once and surface the model's reply.

    NOTE(review): the implementation now reads
    ``out["choices"][0]["message"]["content"]``, so the mocked response must
    use the chat-completion shape (``{"message": {"content": ...}}``) — the
    legacy ``{"text": ...}`` shape this test previously mocked would raise
    ``KeyError: 'message'`` inside ``__call__``.
    """
    expected_response = {
        "choices": [
            {"message": {"content": "This is the model's response."}}
        ]
    }
    with patch(
        "requests.post", return_value=Mock(json=lambda: expected_response)
    ) as mock_post:
        result = vision_api("What is this?", img)
        mock_post.assert_called_once()
        # presumes __call__ returns the extracted content; the diffed
        # __call__ body only prints it — TODO confirm it returns `content`.
        assert result == "This is the model's response."

Expand Down

0 comments on commit 1d1d0f0

Please sign in to comment.