diff --git a/api/src/main/java/com/theokanning/openai/assistants/AssistantFunction.java b/api/src/main/java/com/theokanning/openai/assistants/AssistantFunction.java new file mode 100644 index 00000000..3abee5ae --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/assistants/AssistantFunction.java @@ -0,0 +1,28 @@ +package com.theokanning.openai.assistants; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +/** + * @description: + * @author: vacuity + * @create: 2023-11-20 10:09 + **/ + + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class AssistantFunction { + + private String description; + + private String name; + + private Map parameters; +} diff --git a/api/src/main/java/com/theokanning/openai/assistants/Tool.java b/api/src/main/java/com/theokanning/openai/assistants/Tool.java index e759db80..f35af0ef 100644 --- a/api/src/main/java/com/theokanning/openai/assistants/Tool.java +++ b/api/src/main/java/com/theokanning/openai/assistants/Tool.java @@ -1,6 +1,5 @@ package com.theokanning.openai.assistants; -import com.theokanning.openai.completion.chat.ChatFunction; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -17,5 +16,5 @@ public class Tool { /** * Function definition, only used if type is "function" */ - ChatFunction function; + AssistantFunction function; } diff --git a/api/src/main/java/com/theokanning/openai/runs/SubmitToolOutputsRequest.java b/api/src/main/java/com/theokanning/openai/runs/SubmitToolOutputsRequest.java index 3086610c..f892f168 100644 --- a/api/src/main/java/com/theokanning/openai/runs/SubmitToolOutputsRequest.java +++ b/api/src/main/java/com/theokanning/openai/runs/SubmitToolOutputsRequest.java @@ -22,5 +22,5 @@ public class SubmitToolOutputsRequest { @JsonProperty("tool_outputs") - private List tool_outputs; + private List toolOutputs; } diff --git a/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java b/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java new file mode 100644 index 00000000..9ad819a7 --- /dev/null +++ b/service/src/test/java/com/theokanning/openai/service/AssistantFunctionTest.java @@ -0,0 +1,148 @@ +package com.theokanning.openai.service; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.PropertyNamingStrategy; +import com.theokanning.openai.ListSearchParameters; +import com.theokanning.openai.OpenAiResponse; +import com.theokanning.openai.assistants.Assistant; +import com.theokanning.openai.assistants.AssistantFunction; +import com.theokanning.openai.assistants.AssistantRequest; +import com.theokanning.openai.assistants.AssistantToolsEnum; +import com.theokanning.openai.assistants.Tool; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatFunction; +import com.theokanning.openai.completion.chat.ChatFunctionCall; +import com.theokanning.openai.messages.Message; +import com.theokanning.openai.messages.MessageRequest; +import com.theokanning.openai.runs.RequiredAction; +import com.theokanning.openai.runs.Run; +import com.theokanning.openai.runs.RunCreateRequest; +import com.theokanning.openai.runs.RunStep; +import com.theokanning.openai.runs.SubmitToolOutputRequestItem; +import com.theokanning.openai.runs.SubmitToolOutputs; +import com.theokanning.openai.runs.SubmitToolOutputsRequest; +import com.theokanning.openai.runs.ToolCall; +import com.theokanning.openai.threads.Thread; +import com.theokanning.openai.threads.ThreadRequest; +import com.theokanning.openai.utils.TikTokensUtil; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class AssistantFunctionTest { + String token = System.getenv("OPENAI_TOKEN"); + OpenAiService service = new OpenAiService(token, Duration.ofMinutes(1)); + + @Test + void createRetrieveRun() throws JsonProcessingException { + + ObjectMapper mapper = new ObjectMapper(); + mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL); + mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE); + mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class); + mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class); + mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class); + + String funcDef = "{\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"location\": {\n" + + " \"type\": \"string\",\n" + + " \"description\": \"The city and state, e.g. San Francisco, CA\"\n" + + " },\n" + + " \"unit\": {\n" + + " \"type\": \"string\",\n" + + " \"enum\": [\"celsius\", \"fahrenheit\"]\n" + + " }\n" + + " },\n" + + " \"required\": [\"location\"]\n" + + "}"; + Map funcParameters = mapper.readValue(funcDef, new TypeReference>() {}); + AssistantFunction function = AssistantFunction.builder() + .name("weather_reporter") + .description("Get the current weather of a location") + .parameters(funcParameters) + .build(); + + List toolList = new ArrayList<>(); + Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function); + toolList.add(funcTool); + + + AssistantRequest assistantRequest = AssistantRequest.builder() + .model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()) + .name("MATH_TUTOR") + .instructions("You are a personal Math Tutor.") + .tools(toolList) + .build(); + Assistant assistant = service.createAssistant(assistantRequest); + + ThreadRequest threadRequest = ThreadRequest.builder() + .build(); + Thread thread = service.createThread(threadRequest); + + MessageRequest messageRequest = MessageRequest.builder() + .content("What's the weather of Xiamen?") + .build(); + + Message message = service.createMessage(thread.getId(), messageRequest); + + RunCreateRequest runCreateRequest = RunCreateRequest.builder() + .assistantId(assistant.getId()) + .build(); + + Run run = service.createRun(thread.getId(), runCreateRequest); + assertNotNull(run); + + Run retrievedRun = service.retrieveRun(thread.getId(), run.getId()); + while (!(retrievedRun.getStatus().equals("completed")) + && !(retrievedRun.getStatus().equals("failed")) + && !(retrievedRun.getStatus().equals("requires_action"))){ + retrievedRun = service.retrieveRun(thread.getId(), run.getId()); + } + if (retrievedRun.getStatus().equals("requires_action")) { + RequiredAction requiredAction = retrievedRun.getRequiredAction(); + System.out.println("requiredAction"); + System.out.println(mapper.writeValueAsString(requiredAction)); + List toolCalls = requiredAction.getSubmitToolOutputs().getToolCalls(); + ToolCall toolCall = toolCalls.get(0); + String toolCallId = toolCall.getId(); + + SubmitToolOutputRequestItem toolOutputRequestItem = SubmitToolOutputRequestItem.builder() + .toolCallId(toolCallId) + .output("sunny") + .build(); + List toolOutputRequestItems = new ArrayList<>(); + toolOutputRequestItems.add(toolOutputRequestItem); + SubmitToolOutputsRequest submitToolOutputsRequest = SubmitToolOutputsRequest.builder() + .toolOutputs(toolOutputRequestItems) + .build(); + retrievedRun = service.submitToolOutputs(retrievedRun.getThreadId(), retrievedRun.getId(), submitToolOutputsRequest); + + while (!(retrievedRun.getStatus().equals("completed")) + && !(retrievedRun.getStatus().equals("failed")) + && !(retrievedRun.getStatus().equals("requires_action"))){ + retrievedRun = service.retrieveRun(thread.getId(), run.getId()); + } + + OpenAiResponse response = service.listMessages(thread.getId()); + + List messages = response.getData(); + + System.out.println(mapper.writeValueAsString(messages)); + + } + } +}