diff --git a/examples/prompttests/test_openai_chat.py b/examples/prompttests/test_openai_chat.py
index c21a73b..47f71aa 100644
--- a/examples/prompttests/test_openai_chat.py
+++ b/examples/prompttests/test_openai_chat.py
@@ -14,9 +14,6 @@
 from prompttools.mock.mock import mock_openai_completion_fn
 
-client = OpenAI()
-
-
 if not (("OPENAI_API_KEY" in os.environ) or ("DEBUG" in os.environ)):
     print("Error: This example requires you to set either your OPENAI_API_KEY or DEBUG=1")
     exit(1)
 
@@ -48,6 +45,7 @@ def json_completion_fn(prompt: str):
     if os.getenv("DEBUG", default=False):
         response = mock_openai_completion_fn(**{"prompt": prompt})
     else:
+        client = OpenAI()
         response = client.completions.create(model="babbage-002", prompt=prompt)
     return response.choices[0].text
 
@@ -65,6 +63,7 @@ def completion_fn(prompt: str):
     if os.getenv("DEBUG", default=False):
         response = mock_openai_completion_fn(**{"prompt": prompt})
     else:
+        client = OpenAI()
         response = client.completions.create(prompt)
     return response.choices[0].text
 
diff --git a/prompttools/utils/autoeval.py b/prompttools/utils/autoeval.py
index 9cd3d95..2dc4167 100644
--- a/prompttools/utils/autoeval.py
+++ b/prompttools/utils/autoeval.py
@@ -23,7 +23,6 @@ RESPONSE: {{response}}
 ANSWER:
 """
 
-client = OpenAI()
 
 
 def _get_messages(prompt: str, response: str):
@@ -49,6 +48,7 @@ def compute(prompt: str, response: str, model: str = "gpt-4") -> float:
     """
     if not os.environ["OPENAI_API_KEY"]:
         raise PromptToolsUtilityError
+    client = OpenAI()
     evaluation = client.chat.completions.create(model=model, messages=_get_messages(prompt, response))
     return 1.0 if "RIGHT" in evaluation.choices[0].message.content else 0.0
 
diff --git a/prompttools/utils/autoeval_from_expected.py b/prompttools/utils/autoeval_from_expected.py
index cb16b39..8dcdf61 100644
--- a/prompttools/utils/autoeval_from_expected.py
+++ b/prompttools/utils/autoeval_from_expected.py
@@ -27,8 +27,6 @@
 """
 
 
-client = OpenAI()
-
 
 def _get_messages(prompt: str, expected: str, response: str):
     environment = jinja2.Environment()
@@ -53,6 +51,7 @@ def compute(prompt: str, expected: str, response: str, model: str = "gpt-4") -> float:
     """
     if not os.environ["OPENAI_API_KEY"]:
         raise PromptToolsUtilityError("Missing API key for evaluation.")
+    client = OpenAI()
     evaluation = client.chat.completions.create(model=model, messages=_get_messages(prompt, expected, response))
     return 1.0 if "RIGHT" in evaluation.choices[0].message.content else 0.0
 
diff --git a/prompttools/utils/autoeval_with_docs.py b/prompttools/utils/autoeval_with_docs.py
index 0cbd696..f7962a9 100644
--- a/prompttools/utils/autoeval_with_docs.py
+++ b/prompttools/utils/autoeval_with_docs.py
@@ -27,9 +27,6 @@
 """
 
 
-client = OpenAI()
-
-
 def _get_messages(documents: list[str], response: str):
     environment = jinja2.Environment()
     template = environment.from_string(EVALUATION_USER_TEMPLATE)
@@ -52,6 +49,7 @@ def compute(documents: list[str], response: str, model: str = "gpt-4") -> float:
     """
     if not os.environ["OPENAI_API_KEY"]:
         raise PromptToolsUtilityError
+    client = OpenAI()
     evaluation = client.chat.completions.create(model=model, messages=_get_messages(documents, response))
     score_text = evaluation.choices[0].message.content
     return int(score_text)
diff --git a/prompttools/utils/expected.py b/prompttools/utils/expected.py
index aa2e07c..6d8244d 100644
--- a/prompttools/utils/expected.py
+++ b/prompttools/utils/expected.py
@@ -12,9 +12,6 @@
 
 from . import similarity
 
-client = OpenAI()
-
-
 def compute(prompt: str, model: str = "gpt-4") -> str:
     r"""
     Computes the expected result of a given prompt by using a high
@@ -27,6 +24,7 @@ def compute(prompt: str, model: str = "gpt-4") -> str:
     """
     if not os.environ["OPENAI_API_KEY"]:
         raise PromptToolsUtilityError
+    client = OpenAI()
     response = client.chat.completions.create(model=model, prompt=prompt)
     return response.choices[0].message.content
 