Skip to content

Commit

Permalink
refactor: switch tool factory functions to dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
lpm0073 committed Jan 25, 2024
1 parent d3359f8 commit a8e0a5e
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ def search_terms_are_in_messages(messages: list, search_terms: list = None, sear


def customized_prompt(config: RefersTo, messages: list) -> list:
"""Return a prompt for Lawrence McDaniel"""
custom_prompt = {
"role": "system",
"content": config.system_prompt.system_prompt,
}
"""Modify the system prompt based on the custom configuration object"""

for i, message in enumerate(messages):
if message.get("role") == "system":
system_prompt = message.get("content")
custom_prompt = {
"role": "system",
"content": system_prompt + "\n\n and also " + config.system_prompt.system_prompt,
}
messages[i] = custom_prompt
break

Expand All @@ -68,23 +69,21 @@ def info_tool_factory(config: RefersTo):
"""
Return a dictionary of chat completion tools.
"""
tools = [
{
"type": "function",
"function": {
"name": "get_additional_info",
"description": config.function_description,
"parameters": {
"type": "object",
"properties": {
"inquiry_type": {
"type": "string",
"enum": config.additional_information.keys,
},
tool = {
"type": "function",
"function": {
"name": "get_additional_info",
"description": config.function_description,
"parameters": {
"type": "object",
"properties": {
"inquiry_type": {
"type": "string",
"enum": config.additional_information.keys,
},
"required": ["inquiry_type"],
},
"required": ["inquiry_type"],
},
}
]
return tools
},
}
return tool
Original file line number Diff line number Diff line change
Expand Up @@ -111,24 +111,22 @@ def get_current_weather(location, unit="METRIC"):

def weather_tool_factory():
"""Return a list of tools that can be called by the OpenAI API"""
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["METRIC", "USCS"]},
tool = {
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"required": ["location"],
"unit": {"type": "string", "enum": ["METRIC", "USCS"]},
},
"required": ["location"],
},
}
]
return tools
},
}
return tool
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def handler(event, context):
OpenAI API endpoint based on the contents of the request.
"""
cloudwatch_handler(event, settings.dump, debug_mode=settings.debug_mode)
tools = weather_tool_factory()
weather_tool = weather_tool_factory()
tools = [weather_tool]

try:
openai_results = {}
Expand All @@ -82,10 +83,9 @@ def handler(event, context):
):
model = "gpt-3.5-turbo-1106"
messages = customized_prompt(config=config, messages=messages)
custom_tool = info_tool_factory(config=config)[0]
custom_tool = info_tool_factory(config=config)
tools.append(custom_tool)
print(f"Using custom configuration: {config.name} and adding custom tool: {custom_tool}")
break
print(f"Adding custom configuration: {config.name}")

# https://platform.openai.com/docs/guides/gpt/chat-completions-api
validate_item(
Expand All @@ -94,10 +94,6 @@ def handler(event, context):
item_type="ChatCompletion models",
)
validate_completion_request(request_body)
print("Calling OpenAI Chat Completion API...")
print(
f"model: {model}, messages: {messages}, tools: {tools}, temperature: {temperature}, max_tokens: {max_tokens}"
)
openai_results = openai.chat.completions.create(
model=model,
messages=messages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ def test_get_additional_info(self):
def test_info_tool_factory(self):
"""Test integrity info_tool_factory()"""
itf = info_tool_factory(config=self.config)
self.assertIsInstance(itf, list)
self.assertIsInstance(itf, dict)

d = itf[0]
self.assertIsInstance(d, dict)
self.assertTrue("type" in d)
self.assertTrue("function" in d)
self.assertIsInstance(itf, dict)
self.assertTrue("type" in itf)
self.assertTrue("function" in itf)
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ def test_get_current_weather(self):
def test_weather_tool_factory(self):
"""Test integrity weather_tool_factory()"""
wtf = weather_tool_factory()
self.assertIsInstance(wtf, list)
self.assertIsInstance(wtf, dict)

d = wtf[0]
self.assertIsInstance(d, dict)
self.assertTrue("type" in d)
self.assertTrue("function" in d)
self.assertIsInstance(wtf, dict)
self.assertTrue("type" in wtf)
self.assertTrue("function" in wtf)

0 comments on commit a8e0a5e

Please sign in to comment.