diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
index e4479ff3..099bff37 100644
--- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
+++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java
@@ -15,11 +15,6 @@
 @NoArgsConstructor
 public class ChatCompletionRequest {
 
-    /**
-     * ID of the model to use.
-     */
-    String model;
-
     /**
      * The messages to generate chat completions for, in the chat format.
@@ -28,36 +23,38 @@ public class ChatCompletionRequest {
     List<ChatMessage> messages;
 
     /**
-     * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
-     * values like 0.2 will make it more focused and deterministic.<br>
-     * We generally recommend altering this or top_p but not both.
-     */
-    Double temperature;
-
-    /**
-     * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
-     * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>
-     * We generally recommend altering this or temperature but not both.
+     * ID of the model to use.
      */
-    @JsonProperty("top_p")
-    Double topP;
+    String model;
 
     /**
-     * How many chat completion chatCompletionChoices to generate for each input message.
+     * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+     * decreasing the model's likelihood to repeat the same line verbatim.
      */
-    Integer n;
+    @JsonProperty("frequency_penalty")
+    Double frequencyPenalty;
 
     /**
-     * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent
-     * events as they become available, with the stream terminated by a data: [DONE] message.
+     * <p>An object specifying the format that the model must output.</p>
+     *
+     * <p>Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.</p>
+     *
+     * <p>Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.
+     * Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting
+     * in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if
+     * finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.</p>
      */
-    Boolean stream;
+    @JsonProperty("response_format")
+    ResponseFormat responseFormat;
 
     /**
-     * Up to 4 sequences where the API will stop generating further tokens.
+     * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
+     * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
+     * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
+     * should result in a ban or exclusive selection of the relevant token.
      */
-    List<String> stop;
+    @JsonProperty("logit_bias")
+    Map<String, Integer> logitBias;
 
     /**
      * The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will
@@ -66,6 +63,11 @@ public class ChatCompletionRequest {
     @JsonProperty("max_tokens")
     Integer maxTokens;
 
+    /**
+     * How many chat completion chatCompletionChoices to generate for each input message.
+     */
+    Integer n;
+
     /**
      * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
      * increasing the model's likelihood to talk about new topics.
@@ -74,31 +76,36 @@ public class ChatCompletionRequest {
     Double presencePenalty;
 
     /**
-     * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
-     * decreasing the model's likelihood to repeat the same line verbatim.
+     * Up to 4 sequences where the API will stop generating further tokens.
      */
-    @JsonProperty("frequency_penalty")
-    Double frequencyPenalty;
+    List<String> stop;
 
     /**
-     * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
-     * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
-     * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
-     * should result in a ban or exclusive selection of the relevant token.
+     * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent
+     * events as they become available, with the stream terminated by a data: [DONE] message.
      */
-    @JsonProperty("logit_bias")
-    Map<String, Integer> logitBias;
+    Boolean stream;
 
+    /**
+     * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
+     * values like 0.2 will make it more focused and deterministic.<br>
+     * We generally recommend altering this or top_p but not both.
+     */
+    Double temperature;
 
     /**
-     * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
+     * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
+     * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.<br>
+     * We generally recommend altering this or temperature but not both.
      */
-    String user;
+    @JsonProperty("top_p")
+    Double topP;
 
     /**
-     * A list of the available functions.
+     * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
      */
-    List<ChatFunction> functions;
+    String user;
 
     /**
      * Controls how the model responds to function calls, as specified in the OpenAI documentation.
@@ -106,6 +113,11 @@ public class ChatCompletionRequest {
     @JsonProperty("function_call")
     ChatCompletionRequestFunctionCall functionCall;
 
+    /**
+     * A list of the available functions.
+     */
+    List<ChatFunction> functions;
+
     @Data
     @Builder
     @AllArgsConstructor
@@ -118,4 +130,18 @@ public static ChatCompletionRequestFunctionCall of(String name) {
         }
     }
 
+
+    @Data
+    @Builder
+    @AllArgsConstructor
+    @NoArgsConstructor
+    public static class ResponseFormat {
+        String type;
+
+        public static ResponseFormat of(String type) {
+            return new ResponseFormat(type);
+        }
+
+    }
+
 }
diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
index 25f0defb..e10144dc 100644
--- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
+++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
@@ -2,11 +2,14 @@
 
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonParser;
 import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.databind.node.ObjectNode;
 import com.theokanning.openai.completion.chat.*;
 import org.junit.jupiter.api.Test;
 
+import java.io.IOException;
 import java.util.*;
 
 import static org.junit.jupiter.api.Assertions.*;
@@ -23,7 +26,7 @@ static class Weather {
     }
 
     enum WeatherUnit {
-        CELSIUS, FAHRENHEIT;
+        CELSIUS, FAHRENHEIT
     }
 
     static class WeatherResponse {
@@ -300,4 +303,46 @@ void streamChatCompletionWithDynamicFunctions() {
         assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
     }
 
+    @Test
+    void createChatCompletionWithJsonResponseFormat() {
+        final List<ChatMessage> messages = new ArrayList<>();
+
+        // The system message is deliberately vague so as not to steer the shape of the response too much.
+        // The point is simply that the chat completion should always contain JSON content.
+        final ChatMessage systemMessage = new ChatMessage(
+                ChatMessageRole.SYSTEM.value(),
+                "You are a dog and will speak as such - but please do it in JSON."
+        );
+
+        messages.add(systemMessage);
+
+        ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
+                .builder()
+                .model("gpt-4-1106-preview")
+                .messages(messages)
+                .n(1)
+                .maxTokens(256)
+                .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
+                .build();
+
+        ChatCompletionResult chatCompletion = service.createChatCompletion(chatCompletionRequest);
+
+        ChatCompletionChoice chatCompletionChoice = chatCompletion.getChoices().get(0);
+        String jsonContent = chatCompletionChoice.getMessage().getContent();
+
+        assertTrue(isValidJSON(jsonContent), "Invalid JSON response:\n\n" + jsonContent);
+    }
+
+    private boolean isValidJSON(String json) {
+        try (final JsonParser parser = new ObjectMapper().createParser(json)) {
+            while (parser.nextToken() != null) {
+                // Just read through all the tokens to verify that this is valid JSON.
+            }
+            return true;
+        } catch (IOException ioe) {
+            ioe.printStackTrace();
+            return false;
+        }
+    }
+
 }
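
For reviewers, a minimal caller-side sketch of how the new response_format option is meant to be used, following the javadoc's advice to also request JSON in the prompt itself. This is not part of the diff: it assumes an OpenAiService built from an API token (the OPENAI_API_KEY variable name is illustrative) and a JSON-mode-capable model such as the gpt-4-1106-preview used in the test.

    import com.fasterxml.jackson.databind.JsonNode;
    import com.fasterxml.jackson.databind.ObjectMapper;
    import com.theokanning.openai.completion.chat.*;
    import com.theokanning.openai.service.OpenAiService;

    import java.util.ArrayList;
    import java.util.List;

    public class JsonModeExample {
        public static void main(String[] args) throws Exception {
            OpenAiService service = new OpenAiService(System.getenv("OPENAI_API_KEY"));

            // Per the new javadoc: with JSON mode enabled, the prompt itself must also ask
            // for JSON, otherwise the model may emit whitespace until it hits max_tokens.
            List<ChatMessage> messages = new ArrayList<>();
            messages.add(new ChatMessage(ChatMessageRole.SYSTEM.value(),
                    "You answer with a single JSON object."));
            messages.add(new ChatMessage(ChatMessageRole.USER.value(),
                    "Name three dog breeds and one fact about each."));

            ChatCompletionRequest request = ChatCompletionRequest.builder()
                    .model("gpt-4-1106-preview")
                    .messages(messages)
                    .maxTokens(256)
                    .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
                    .build();

            String content = service.createChatCompletion(request)
                    .getChoices().get(0).getMessage().getContent();

            // JSON mode guarantees syntactically valid JSON unless the completion was cut
            // off (finish_reason = "length"), so parsing is expected to succeed here.
            JsonNode json = new ObjectMapper().readTree(content);
            System.out.println(json.toPrettyString());
        }
    }

Keeping ResponseFormat as a small String-typed nested class with an of(...) factory keeps the change additive; if more formats are introduced later, the type field could be promoted to an enum without breaking the builder API.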