diff --git a/api/src/main/java/com/theokanning/openai/runs/MessageCreation.java b/api/src/main/java/com/theokanning/openai/runs/MessageCreation.java new file mode 100644 index 00000000..fe59d845 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/MessageCreation.java @@ -0,0 +1,16 @@ +package com.theokanning.openai.runs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class MessageCreation { + @JsonProperty("message_id") + String messageId; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/Run.java b/api/src/main/java/com/theokanning/openai/runs/Run.java new file mode 100644 index 00000000..2c14dd10 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/Run.java @@ -0,0 +1,45 @@ +package com.theokanning.openai.runs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class Run { + + @JsonProperty("assistant_id") + String assistantId; + @JsonProperty("cancelled_at") + Long cancelledAt; + @JsonProperty("completed_at") + Long completedAt; + @JsonProperty("created_at") + Long createdAt; + @JsonProperty("expires_at") + Long expiresAt; + @JsonProperty("failed_at") + Long failedAt; + @JsonProperty("file_ids") + List fileIds; + String id; + String instructions; + @JsonProperty("last_error") + String lastError; + Map metadata; + String model; + String object; + @JsonProperty("started_at") + Long startedAt; + String status; + @JsonProperty("thread_id") + String threadId; + List tools; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/RunCreateRequest.java b/api/src/main/java/com/theokanning/openai/runs/RunCreateRequest.java new file mode 100644 index 00000000..93744d44 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/RunCreateRequest.java @@ -0,0 +1,23 @@ +package com.theokanning.openai.runs; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class RunCreateRequest { + String assistantId; + + // Optional + String model; + String instructions; + List tools; + Map metadata; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/RunStep.java b/api/src/main/java/com/theokanning/openai/runs/RunStep.java new file mode 100644 index 00000000..98c1eee3 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/RunStep.java @@ -0,0 +1,39 @@ +package com.theokanning.openai.runs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class RunStep { + + @JsonProperty("assistant_id") + String assistantId; + @JsonProperty("canelled_at") + Long cancelledAt; + @JsonProperty("completed_at") + Long completedAt; + @JsonProperty("created_at") + Long createdAt; + @JsonProperty("expired_at") + Long expiredAt; + @JsonProperty("failed_at") + Long failedAt; + String id; + @JsonProperty("last_error") + String lastError; + String object; + @JsonProperty("run_id") + String runId; + String status; + @JsonProperty("step_details") + StepDetails stepDetails; + @JsonProperty("thread_id") + String threadId; + String type; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/RunSteps.java b/api/src/main/java/com/theokanning/openai/runs/RunSteps.java new file mode 100644 index 00000000..c7260cfe --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/RunSteps.java @@ -0,0 +1,17 @@ +package com.theokanning.openai.runs; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public class RunSteps { + + String object; + List data; + @JsonProperty("first_id") + String firstId; + @JsonProperty("last_id") + String lastId; + @JsonProperty("has_more") + boolean hasMore; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/Runs.java b/api/src/main/java/com/theokanning/openai/runs/Runs.java new file mode 100644 index 00000000..2b07de7d --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/Runs.java @@ -0,0 +1,17 @@ +package com.theokanning.openai.runs; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public class Runs { + + String object; + List data; + @JsonProperty("first_id") + String firstId; + @JsonProperty("last_id") + String lastId; + @JsonProperty("has_more") + boolean hasMore; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/StepDetails.java b/api/src/main/java/com/theokanning/openai/runs/StepDetails.java new file mode 100644 index 00000000..c17f9224 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/StepDetails.java @@ -0,0 +1,18 @@ +package com.theokanning.openai.runs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class StepDetails { + + @JsonProperty("message_creation") + MessageCreation messageCreation; + String type; +} diff --git a/api/src/main/java/com/theokanning/openai/runs/Tool.java b/api/src/main/java/com/theokanning/openai/runs/Tool.java new file mode 100644 index 00000000..abeee04b --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/runs/Tool.java @@ -0,0 +1,15 @@ +package com.theokanning.openai.runs; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Builder +@NoArgsConstructor +@AllArgsConstructor +@Data +public class Tool { + + String type; +} diff --git a/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java b/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java index cf09ff90..bb57aa53 100644 --- a/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java +++ b/client/src/main/java/com/theokanning/openai/client/OpenAiApi.java @@ -33,6 +33,8 @@ import com.theokanning.openai.model.Model; import com.theokanning.openai.moderation.ModerationRequest; import com.theokanning.openai.moderation.ModerationResult; +import com.theokanning.openai.runs.Run; +import com.theokanning.openai.runs.RunCreateRequest; import com.theokanning.openai.threads.Thread; import com.theokanning.openai.threads.ThreadRequest; import io.reactivex.Single; @@ -278,4 +280,12 @@ public interface OpenAiApi { @GET("/v1/threads/{thread_id}/messages/{message_id}/files") Single> listMessageFiles(@Path("thread_id") String threadId, @Path("message_id") String messageId, @QueryMap Map filterRequest); + @Headers("OpenAI-Beta: assistants=v1") + @POST("/v1/threads/{thread_id}/runs") + Single createRun(@Path("thread_id") String threadId, @Body RunCreateRequest runCreateRequest); + + @Headers("OpenAI-Beta: assistants=v1") + @GET("/v1/threads/{thread_id}/runs/{run_id}") + Single retrieveRun(@Path("thread_id") String threadId, @Path("run_id") String runId); + } diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 9c15522b..bbd6d550 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -8,11 +8,7 @@ import com.fasterxml.jackson.databind.node.TextNode; import com.theokanning.openai.*; import com.theokanning.openai.assistants.*; -import com.theokanning.openai.audio.CreateSpeechRequest; -import com.theokanning.openai.audio.CreateTranscriptionRequest; -import com.theokanning.openai.audio.CreateTranslationRequest; -import com.theokanning.openai.audio.TranscriptionResult; -import com.theokanning.openai.audio.TranslationResult; +import com.theokanning.openai.audio.*; import com.theokanning.openai.billing.BillingUsage; import com.theokanning.openai.billing.Subscription; import com.theokanning.openai.client.OpenAiApi; @@ -42,6 +38,8 @@ import com.theokanning.openai.model.Model; import com.theokanning.openai.moderation.ModerationRequest; import com.theokanning.openai.moderation.ModerationResult; +import com.theokanning.openai.runs.Run; +import com.theokanning.openai.runs.RunCreateRequest; import com.theokanning.openai.threads.Thread; import com.theokanning.openai.threads.ThreadRequest; import io.reactivex.BackpressureStrategy; @@ -166,7 +164,7 @@ public List listFiles() { public File uploadFile(String purpose, String filepath) { java.io.File file = new java.io.File(filepath); - RequestBody purposeBody = RequestBody.create(okhttp3.MultipartBody.FORM, purpose); + RequestBody purposeBody = RequestBody.create(MultipartBody.FORM, purpose); RequestBody fileBody = RequestBody.create(MediaType.parse("text"), file); MultipartBody.Part body = MultipartBody.Part.createFormData("file", filepath, fileBody); @@ -364,7 +362,7 @@ public ModerationResult createModeration(ModerationRequest request) { public ResponseBody createSpeech(CreateSpeechRequest request) { return execute(api.createSpeech(request)); } - + public Assistant createAssistant(AssistantRequest request) { return execute(api.createAssistant(request)); } @@ -382,7 +380,8 @@ public DeleteResult deleteAssistant(String assistantId) { } public ListAssistant listAssistants(ListAssistantQueryRequest filterRequest) { - Map queryParameters = mapper.convertValue(filterRequest, new TypeReference>() {}); + Map queryParameters = mapper.convertValue(filterRequest, new TypeReference>() { + }); return execute(api.listAssistants(queryParameters)); } @@ -399,7 +398,8 @@ public DeleteResult deleteAssistantFile(String assistantId, String fileId) { } public ListAssistant listAssistantFiles(String assistantId, ListAssistantQueryRequest filterRequest) { - Map queryParameters = mapper.convertValue(filterRequest, new TypeReference>() {}); + Map queryParameters = mapper.convertValue(filterRequest, new TypeReference>() { + }); return execute(api.listAssistantFiles(assistantId, queryParameters)); } @@ -424,11 +424,11 @@ public Message createMessage(String threadId, MessageRequest request) { } public Message retrieveMessage(String threadId, String messageId) { - return execute(api.retrieveMessage(threadId,messageId)); + return execute(api.retrieveMessage(threadId, messageId)); } public Message modifyMessage(String threadId, String messageId, ModifyMessageRequest request) { - return execute(api.modifyMessage(threadId,messageId, request)); + return execute(api.modifyMessage(threadId, messageId, request)); } public OpenAiResponse listMessages(String threadId) { @@ -436,23 +436,32 @@ public OpenAiResponse listMessages(String threadId) { } public OpenAiResponse listMessages(String threadId, ListSearchParameters params) { - Map queryParameters = mapper.convertValue(params, new TypeReference>() {}); - return execute(api.listMessages(threadId,queryParameters)); + Map queryParameters = mapper.convertValue(params, new TypeReference>() { + }); + return execute(api.listMessages(threadId, queryParameters)); } public MessageFile retrieveMessageFile(String threadId, String messageId, String fileId) { - return execute(api.retrieveMessageFile(threadId,messageId, fileId)); + return execute(api.retrieveMessageFile(threadId, messageId, fileId)); } public OpenAiResponse listMessageFiles(String threadId, String messageId) { - return execute(api.listMessageFiles(threadId,messageId)); + return execute(api.listMessageFiles(threadId, messageId)); } public OpenAiResponse listMessageFiles(String threadId, String messageId, ListSearchParameters params) { - Map queryParameters = mapper.convertValue(params, new TypeReference>() {}); - return execute(api.listMessageFiles(threadId,messageId, queryParameters)); + Map queryParameters = mapper.convertValue(params, new TypeReference>() { + }); + return execute(api.listMessageFiles(threadId, messageId, queryParameters)); + } + + public Run createRun(String threadId, RunCreateRequest runCreateRequest) { + return execute(api.createRun(threadId, runCreateRequest)); } + public Run retrieveRun(String threadId, String runId) { + return execute(api.retrieveRun(threadId, runId)); + } /** * Calls the Open AI api, returns the response, and parses error messages if the request fails diff --git a/service/src/test/java/com/theokanning/openai/service/RunTest.java b/service/src/test/java/com/theokanning/openai/service/RunTest.java new file mode 100644 index 00000000..2bd0c166 --- /dev/null +++ b/service/src/test/java/com/theokanning/openai/service/RunTest.java @@ -0,0 +1,69 @@ +package com.theokanning.openai.service; + +import com.theokanning.openai.OpenAiResponse; +import com.theokanning.openai.assistants.Assistant; +import com.theokanning.openai.assistants.AssistantRequest; +import com.theokanning.openai.messages.Message; +import com.theokanning.openai.messages.MessageRequest; +import com.theokanning.openai.runs.Run; +import com.theokanning.openai.runs.RunCreateRequest; +import com.theokanning.openai.threads.Thread; +import com.theokanning.openai.threads.ThreadRequest; +import com.theokanning.openai.utils.TikTokensUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +class RunTest { + String token = System.getenv("OPENAI_TOKEN"); + OpenAiService service = new OpenAiService(token); + + @Test + @Timeout(10) + void createRetrieveRun() { + AssistantRequest assistantRequest = AssistantRequest.builder() + .model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName()) + .name("MATH_TUTOR") + .instructions("You are a personal Math Tutor.") + .build(); + Assistant assistant = service.createAssistant(assistantRequest); + + ThreadRequest threadRequest = ThreadRequest.builder() + .build(); + Thread thread = service.createThread(threadRequest); + + MessageRequest messageRequest = MessageRequest.builder() + .content("Hello") + .build(); + + Message message = service.createMessage(thread.getId(), messageRequest); + + RunCreateRequest runCreateRequest = RunCreateRequest.builder() + .assistantId(assistant.getId()) + .build(); + + Run run = service.createRun(thread.getId(), runCreateRequest); + assertNotNull(run); + + Run retrievedRun; + do { + retrievedRun = service.retrieveRun(thread.getId(), run.getId()); + assertEquals(run.getId(), retrievedRun.getId()); + } + while (!(retrievedRun.getStatus().equals("completed")) && !(retrievedRun.getStatus().equals("failed"))); + + + assertNotNull(retrievedRun); + + OpenAiResponse response = service.listMessages(thread.getId()); + + List messages = response.getData(); + assertEquals(2, messages.size()); + assertEquals("user", messages.get(1).getRole()); + assertEquals("assistant", messages.get(0).getRole()); + } +}