Skip to content

Commit

Permalink
migration for openai
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Nov 11, 2023
1 parent bcf2b88 commit 7e3e6ca
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 24 deletions.
4 changes: 2 additions & 2 deletions docs/swarms/models/fuyu.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ To use Fuyu, follow these steps:
1. Initialize the Fuyu instance:

```python
from swarms.models import Fuyu
from swarms.models.fuyu import Fuyu

fuyu = Fuyu()
```
Expand All @@ -54,7 +54,7 @@ output_text = fuyu(text, img_path)
### Example 2 - Text Generation

```python
from swarms.models import Fuyu
from swarms.models.fuyu import Fuyu

fuyu = Fuyu()

Expand Down
4 changes: 2 additions & 2 deletions swarms/models/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _create_retry_decorator(
import openai

errors = [
openai.error.Timeout,
openai.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
Expand Down Expand Up @@ -547,7 +547,7 @@ def _client_params(self) -> Dict[str, Any]:
if self.openai_proxy:
import openai

openai.proxy = {"http": self.openai_proxy, "https": self.openai_proxy} # type: ignore[assignment] # noqa: E501
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={"http": self.openai_proxy, "https": self.openai_proxy})'") # type: ignore[assignment] # noqa: E501
return {**self._default_params, **openai_creds}

def _get_invocation_params(
Expand Down
16 changes: 8 additions & 8 deletions swarms/models/openai_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def _create_retry_decorator(
import openai

errors = [
openai.error.Timeout,
openai.Timeout,
openai.error.APIError,
openai.error.APIConnectionError,
openai.error.RateLimitError,
Expand Down Expand Up @@ -500,10 +500,10 @@ def _invocation_params(self) -> Dict[str, Any]:
if self.openai_proxy:
import openai

openai.proxy = {
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={
"http": self.openai_proxy,
"https": self.openai_proxy,
} # type: ignore[assignment] # noqa: E501
})'") # type: ignore[assignment] # noqa: E501
return {**openai_creds, **self._default_params}

@property
Expand Down Expand Up @@ -782,16 +782,16 @@ def validate_environment(cls, values: Dict) -> Dict:
try:
import openai

openai.api_key = openai_api_key

if openai_api_base:
openai.api_base = openai_api_base
raise Exception("The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(api_base=openai_api_base)'")
if openai_organization:
openai.organization = openai_organization
raise Exception("The 'openai.organization' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(organization=openai_organization)'")
if openai_proxy:
openai.proxy = {
raise Exception("The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy={
"http": openai_proxy,
"https": openai_proxy,
} # type: ignore[assignment] # noqa: E501
})'") # type: ignore[assignment] # noqa: E501
except ImportError:
raise ImportError(
"Could not import openai python package. "
Expand Down
12 changes: 6 additions & 6 deletions swarms/models/simple_ada.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import openai
from openai import OpenAI

client = OpenAI(api_key=getenv("OPENAI_API_KEY"))
from dotenv import load_dotenv
from os import getenv

Expand All @@ -14,13 +16,11 @@ def get_ada_embeddings(text: str, model: str = "text-embedding-ada-002"):
>>> get_ada_embeddings("Hello World", model="text-embedding-ada-001")
"""
openai.api_key = getenv("OPENAI_API_KEY")


text = text.replace("\n", " ")

return openai.Embedding.create(
input=[text],
model=model,
)["data"][
return client.embeddings.create(input=[text],
model=model)["data"][
0
]["embedding"]
12 changes: 6 additions & 6 deletions tests/models/ada.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_texts():

# Basic Test
def test_get_ada_embeddings_basic(test_texts):
with patch("openai.Embedding.create") as mock_create:
with patch("openai.resources.Embeddings.create") as mock_create:
# Mocking the OpenAI API call
mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

Expand All @@ -49,7 +49,7 @@ def test_get_ada_embeddings_basic(test_texts):
],
)
def test_get_ada_embeddings_models(text, model, expected_call_model):
with patch("openai.Embedding.create") as mock_create:
with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

_ = get_ada_embeddings(text, model=model)
Expand All @@ -58,16 +58,16 @@ def test_get_ada_embeddings_models(text, model, expected_call_model):

# Exception Test
def test_get_ada_embeddings_exception():
with patch("openai.Embedding.create") as mock_create:
mock_create.side_effect = openai.error.OpenAIError("Test error")
with pytest.raises(openai.error.OpenAIError):
with patch("openai.resources.Embeddings.create") as mock_create:
mock_create.side_effect = openai.OpenAIError("Test error")
with pytest.raises(openai.OpenAIError):
get_ada_embeddings("Some text")


# Tests for environment variable loading
def test_env_var_loading(monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "testkey123")
with patch("openai.Embedding.create"):
with patch("openai.resources.Embeddings.create"):
assert (
getenv("OPENAI_API_KEY") == "testkey123"
), "Environment variable for API key is not set correctly"
Expand Down

0 comments on commit 7e3e6ca

Please sign in to comment.