Skip to content

Commit

Permalink
Fix OpenAI client init
Browse files Browse the repository at this point in the history
  • Loading branch information
NivekT committed Nov 8, 2023
1 parent 57ec4d4 commit e7655ad
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 12 deletions.
5 changes: 2 additions & 3 deletions examples/prompttests/test_openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
from prompttools.mock.mock import mock_openai_completion_fn


client = OpenAI()


if not (("OPENAI_API_KEY" in os.environ) or ("DEBUG" in os.environ)):
print("Error: This example requires you to set either your OPENAI_API_KEY or DEBUG=1")
exit(1)
Expand Down Expand Up @@ -48,6 +45,7 @@ def json_completion_fn(prompt: str):
if os.getenv("DEBUG", default=False):
response = mock_openai_completion_fn(**{"prompt": prompt})
else:
client = OpenAI()
response = client.completions.create(model="babbage-002", prompt=prompt)
return response.choices[0].text

Expand All @@ -65,6 +63,7 @@ def completion_fn(prompt: str):
if os.getenv("DEBUG", default=False):
response = mock_openai_completion_fn(**{"prompt": prompt})
else:
client = OpenAI()
response = client.completions.create(prompt)
return response.choices[0].text

Expand Down
2 changes: 1 addition & 1 deletion prompttools/utils/autoeval.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
RESPONSE: {{response}}
ANSWER:
"""
client = OpenAI()


def _get_messages(prompt: str, response: str):
Expand All @@ -49,6 +48,7 @@ def compute(prompt: str, response: str, model: str = "gpt-4") -> float:
"""
if not os.environ["OPENAI_API_KEY"]:
raise PromptToolsUtilityError
client = OpenAI()
evaluation = client.chat.completions.create(model=model, messages=_get_messages(prompt, response))
return 1.0 if "RIGHT" in evaluation.choices[0].message.content else 0.0

Expand Down
5 changes: 3 additions & 2 deletions prompttools/utils/autoeval_from_expected.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
ANSWER:
"""

client = OpenAI()


def _get_messages(prompt: str, expected: str, response: str):
environment = jinja2.Environment()
Expand All @@ -53,6 +51,9 @@ def compute(prompt: str, expected: str, response: str, model: str = "gpt-4") ->
"""
if not os.environ["OPENAI_API_KEY"]:
raise PromptToolsUtilityError("Missing API key for evaluation.")
global client
if client is None:
client = OpenAI()
evaluation = client.chat.completions.create(model=model, messages=_get_messages(prompt, expected, response))
return 1.0 if "RIGHT" in evaluation.choices[0].message.content else 0.0

Expand Down
4 changes: 1 addition & 3 deletions prompttools/utils/autoeval_with_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
"""


client = OpenAI()


def _get_messages(documents: list[str], response: str):
environment = jinja2.Environment()
template = environment.from_string(EVALUATION_USER_TEMPLATE)
Expand All @@ -52,6 +49,7 @@ def compute(documents: list[str], response: str, model: str = "gpt-4") -> float:
"""
if not os.environ["OPENAI_API_KEY"]:
raise PromptToolsUtilityError
client = OpenAI()
evaluation = client.chat.completions.create(model=model, messages=_get_messages(documents, response))
score_text = evaluation.choices[0].message.content
return int(score_text)
Expand Down
4 changes: 1 addition & 3 deletions prompttools/utils/expected.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from . import similarity


client = OpenAI()


def compute(prompt: str, model: str = "gpt-4") -> str:
r"""
Computes the expected result of a given prompt by using a high
Expand All @@ -27,6 +24,7 @@ def compute(prompt: str, model: str = "gpt-4") -> str:
"""
if not os.environ["OPENAI_API_KEY"]:
raise PromptToolsUtilityError
client = OpenAI()
response = client.chat.completions.create(model=model, prompt=prompt)
return response.choices[0].message.content

Expand Down

0 comments on commit e7655ad

Please sign in to comment.