Skip to content

Commit

Permalink
tests for gpt4visionapi
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 25, 2023
1 parent 6abe9a0 commit 208032c
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 39 deletions.
4 changes: 2 additions & 2 deletions multi_modal_auto_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
## Initialize the workflow
flow = Flow(
llm=llm,
max_loops='auto',
max_loops="auto",
dashboard=True,
)

flow.run(task=task, img=img)
flow.run(task=task, img=img)
4 changes: 2 additions & 2 deletions playground/demos/positive_med/positive_med.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def social_media_prompt(article: str, goal: str = "Clicks and engagement"):
"Generate 10 topics on gaining mental clarity using ancient practices"
)
topics = llm(
f"Your System Instructions: {TOPIC_GENERATOR_SYSTEM_PROMPT}, Your current task:"
f" {topic_selection_task}"
f"Your System Instructions: {TOPIC_GENERATOR_SYSTEM_PROMPT}, Your current"
f" task: {topic_selection_task}"
)

dashboard = print(
Expand Down
1 change: 0 additions & 1 deletion swarms/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,4 @@
# "Dalle3",
# "DistilWhisperModel",
"GPT4VisionAPI",

]
6 changes: 2 additions & 4 deletions swarms/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,9 @@ def build_extra_kwargs(
if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.")
if field_name not in all_required_field_names:
warnings.warn(
f"""WARNING! {field_name} is not default parameter.
warnings.warn(f"""WARNING! {field_name} is not default parameter.
{field_name} was transferred to model_kwargs.
Please confirm that {field_name} is what you intended."""
)
Please confirm that {field_name} is what you intended.""")
extra_kwargs[field_name] = values.pop(field_name)

invalid_model_kwargs = all_required_field_names.intersection(
Expand Down
18 changes: 12 additions & 6 deletions swarms/models/dalle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,10 @@ def __call__(self, task: str):
# Handling exceptions and printing the errors details
print(
colored(
f"Error running Dalle3: {error} try optimizing your api"
" key and or try again",
(
f"Error running Dalle3: {error} try optimizing your api"
" key and or try again"
),
"red",
)
)
Expand Down Expand Up @@ -231,8 +233,10 @@ def create_variations(self, img: str):
except (Exception, openai.OpenAIError) as error:
print(
colored(
f"Error running Dalle3: {error} try optimizing your api"
" key and or try again",
(
f"Error running Dalle3: {error} try optimizing your api"
" key and or try again"
),
"red",
)
)
Expand Down Expand Up @@ -306,8 +310,10 @@ def process_batch_concurrently(
except Exception as error:
print(
colored(
f"Error running Dalle3: {error} try optimizing"
" your api key and or try again",
(
f"Error running Dalle3: {error} try optimizing"
" your api key and or try again"
),
"red",
)
)
Expand Down
23 changes: 13 additions & 10 deletions swarms/models/gpt4_vision_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")


class GPT4VisionAPI:
"""
GPT-4 Vision API
Expand Down Expand Up @@ -34,13 +35,11 @@ class GPT4VisionAPI:
>>> task = "What is the color of the object?"
>>> img = "https://i.imgur.com/2M2ZGwC.jpeg"
>>> llm.run(task, img)
"""
def __init__(
self,
openai_api_key: str = openai_api_key
):

def __init__(self, openai_api_key: str = openai_api_key):
super().__init__()
self.openai_api_key = openai_api_key

Expand All @@ -52,7 +51,7 @@ def encode_image(self, img: str):
# Function to handle vision tasks
def run(self, task: str, img: str):
"""Run the model."""
try:
try:
base64_image = self.encode_image(img)
headers = {
"Content-Type": "application/json",
Expand All @@ -68,7 +67,9 @@ def run(self, task: str, img: str):
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
"url": (
f"data:image/jpeg;base64,{base64_image}"
)
},
},
],
Expand All @@ -92,7 +93,7 @@ def run(self, task: str, img: str):

def __call__(self, task: str, img: str):
"""Run the model."""
try:
try:
base64_image = self.encode_image(img)
headers = {
"Content-Type": "application/json",
Expand All @@ -108,7 +109,9 @@ def __call__(self, task: str, img: str):
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
"url": (
f"data:image/jpeg;base64,{base64_image}"
)
},
},
],
Expand Down
6 changes: 4 additions & 2 deletions swarms/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,10 @@ def run(self, task: str):
except Exception as e:
print(
colored(
"HuggingfaceLLM could not generate text because of"
f" error: {e}, try optimizing your arguments",
(
"HuggingfaceLLM could not generate text because of"
f" error: {e}, try optimizing your arguments"
),
"red",
)
)
Expand Down
12 changes: 8 additions & 4 deletions swarms/models/ssd_1b.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,10 @@ def __call__(self, task: str, neg_prompt: str):
# Handling exceptions and printing the errors details
print(
colored(
f"Error running SSD1B: {error} try optimizing your api"
" key and or try again",
(
f"Error running SSD1B: {error} try optimizing your api"
" key and or try again"
),
"red",
)
)
Expand Down Expand Up @@ -226,8 +228,10 @@ def process_batch_concurrently(
except Exception as error:
print(
colored(
f"Error running SSD1B: {error} try optimizing"
" your api key and or try again",
(
f"Error running SSD1B: {error} try optimizing"
" your api key and or try again"
),
"red",
)
)
Expand Down
2 changes: 0 additions & 2 deletions swarms/prompts/autobloggen.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,3 @@
- Flag any bold claims that lack credible evidence for fact-checker review.
"""


8 changes: 5 additions & 3 deletions swarms/structs/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,10 @@ def activate_autonomous_agent(self):
except Exception as error:
print(
colored(
"Error activating autonomous agent. Try optimizing your"
" parameters...",
(
"Error activating autonomous agent. Try optimizing your"
" parameters..."
),
"red",
)
)
Expand Down Expand Up @@ -657,7 +659,7 @@ async def arun(self, task: str, **kwargs):
while attempt < self.retry_attempts:
try:
response = self.llm(
task ** kwargs,
task**kwargs,
)
if self.interactive:
print(f"AI: {response}")
Expand Down
8 changes: 5 additions & 3 deletions swarms/structs/sequential_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,11 @@ def run(self) -> None:
except Exception as e:
print(
colored(
f"Error initializing the Sequential workflow: {e} try"
" optimizing your inputs like the flow class and task"
" description",
(
f"Error initializing the Sequential workflow: {e} try"
" optimizing your inputs like the flow class and task"
" description"
),
"red",
attrs=["bold", "underline"],
)
Expand Down
122 changes: 122 additions & 0 deletions tests/models/test_gpt4_vision_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest
from unittest.mock import mock_open, patch, Mock
from requests.exceptions import RequestException
from swarms.models.gpt4_vision_api import GPT4VisionAPI
import os
from dotenv import load_dotenv

load_dotenv()


custom_api_key = os.environ.get("OPENAI_API_KEY")
img = "images/swarms.jpeg"


@pytest.fixture
def vision_api():
return GPT4VisionAPI(openai_api_key="test_api_key")


def test_init(vision_api):
assert vision_api.openai_api_key == "test_api_key"


def test_encode_image(vision_api):
with patch(
"builtins.open", mock_open(read_data=b"test_image_data"), create=True
):
encoded_image = vision_api.encode_image("test_image.jpg")
assert encoded_image == "dGVzdF9pbWFnZV9kYXRh"


def test_run_success(vision_api):
expected_response = {"choices": [{"text": "This is the model's response."}]}
with patch(
"requests.post", return_value=Mock(json=lambda: expected_response)
) as mock_post:
result = vision_api.run("What is this?", "test_image.jpg")
mock_post.assert_called_once()
assert result == "This is the model's response."


def test_run_request_error(vision_api):
with patch(
"requests.post", side_effect=RequestException("Request Error")
) as mock_post:
with pytest.raises(RequestException):
vision_api.run("What is this?", "test_image.jpg")


def test_run_response_error(vision_api):
expected_response = {"error": "Model Error"}
with patch(
"requests.post", return_value=Mock(json=lambda: expected_response)
) as mock_post:
with pytest.raises(RuntimeError):
vision_api.run("What is this?", "test_image.jpg")


def test_call(vision_api):
expected_response = {"choices": [{"text": "This is the model's response."}]}
with patch(
"requests.post", return_value=Mock(json=lambda: expected_response)
) as mock_post:
result = vision_api("What is this?", "test_image.jpg")
mock_post.assert_called_once()
assert result == "This is the model's response."


@pytest.fixture
def gpt_api():
return GPT4VisionAPI()


def test_initialization_with_default_key():
api = GPT4VisionAPI()
assert api.openai_api_key == custom_api_key


def test_initialization_with_custom_key():
custom_key = custom_api_key
api = GPT4VisionAPI(openai_api_key=custom_key)
assert api.openai_api_key == custom_key


def test_run_successful_response(gpt_api):
task = "What is in the image?"
img_url = img
response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]}
mock_response = Mock()
mock_response.json.return_value = response_json
with patch("requests.post", return_value=mock_response) as mock_post:
result = gpt_api.run(task, img_url)
mock_post.assert_called_once()
assert result == response_json["choices"][0]["text"]


def test_run_with_exception(gpt_api):
task = "What is in the image?"
img_url = img
with patch("requests.post", side_effect=Exception("Test Exception")):
with pytest.raises(Exception):
gpt_api.run(task, img_url)


def test_call_method_successful_response(gpt_api):
task = "What is in the image?"
img_url = img
response_json = {"choices": [{"text": "Answer from GPT-4 Vision"}]}
mock_response = Mock()
mock_response.json.return_value = response_json
with patch("requests.post", return_value=mock_response) as mock_post:
result = gpt_api(task, img_url)
mock_post.assert_called_once()
assert result == response_json


def test_call_method_with_exception(gpt_api):
task = "What is in the image?"
img_url = img
with patch("requests.post", side_effect=Exception("Test Exception")):
with pytest.raises(Exception):
gpt_api(task, img_url)

0 comments on commit 208032c

Please sign in to comment.