List<ChatMessage> messages;
/**
- * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
- * values like 0.2 will make it more focused and deterministic.
- * We generally recommend altering this or top_p but not both.
- */
- Double temperature;
-
- /**
- * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
- * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
- * We generally recommend altering this or temperature but not both.
+ * ID of the model to use.
*/
- @JsonProperty("top_p")
- Double topP;
+ String model;
/**
- * How many chat completion chatCompletionChoices to generate for each input message.
+ * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
+ * decreasing the model's likelihood to repeat the same line verbatim.
*/
- Integer n;
+ @JsonProperty("frequency_penalty")
+ Double frequencyPenalty;
/**
- * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent
- * events as they become available, with the stream terminated by a data: [DONE] message.
+ * An object specifying the format that the model must output.
+ *
+ * Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.
+ *
+ * Important: when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.
+ * Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting
+ * in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if
+ * finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.
*/
- Boolean stream;
+ @JsonProperty("response_format")
+ ResponseFormat responseFormat;
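+
+ // A minimal usage sketch (the builder calls and model name mirror the test added below):
+ //   ChatCompletionRequest.builder()
+ //       .model("gpt-4-1106-preview")
+ //       .messages(messages)
+ //       .responseFormat(ResponseFormat.of("json_object"))
+ //       .build();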
/**
- * Up to 4 sequences where the API will stop generating further tokens.
+ * Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
+ * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
+ * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
+ * should result in a ban or exclusive selection of the relevant token.
*/
- List<String> stop;
+ @JsonProperty("logit_bias")
+ Map<String, Integer> logitBias;
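+
+ // Illustrative sketch only; the token ID below is hypothetical and depends on the model's tokenizer.
+ // Banning a single token, per the -100 semantics above, could look like:
+ //   logitBias(Collections.singletonMap("50256", -100))
+ // on the Lombok builder.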
/**
* The maximum number of tokens allowed for the generated answer. By default, the number of tokens the model can return will
@@ -66,6 +63,11 @@ public class ChatCompletionRequest {
@JsonProperty("max_tokens")
Integer maxTokens;
+ /**
+ * How many chat completion choices to generate for each input message.
+ */
+ Integer n;
+
/**
* Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,
* increasing the model's likelihood to talk about new topics.
@@ -74,31 +76,36 @@ public class ChatCompletionRequest {
Double presencePenalty;
/**
- * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far,
- * decreasing the model's likelihood to repeat the same line verbatim.
+ * Up to 4 sequences where the API will stop generating further tokens.
*/
- @JsonProperty("frequency_penalty")
- Double frequencyPenalty;
+ List<String> stop;
/**
- * Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
- * to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
- * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100
- * should result in a ban or exclusive selection of the relevant token.
+ * If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only server-sent
+ * events as they become available, with the stream terminated by a data: [DONE] message.
*/
- @JsonProperty("logit_bias")
- Map<String, Integer> logitBias;
+ Boolean stream;
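+
+ // Sketch, assuming the service module's streamChatCompletion helper: with stream=true the
+ // deltas arrive as an RxJava Flowable of chunks, e.g.
+ //   service.streamChatCompletion(request).blockingForEach(System.out::println);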
+ /**
+ * What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower
+ * values like 0.2 will make it more focused and deterministic.
+ * We generally recommend altering this or top_p but not both.
+ */
+ Double temperature;
/**
- * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
+ * An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens
+ * with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
+ * We generally recommend altering this or temperature but not both.
*/
- String user;
+ @JsonProperty("top_p")
+ Double topP;
/**
- * A list of the available functions.
+ * A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
*/
- List<?> functions;
+ String user;
/**
* Controls how the model responds to function calls, as specified in the OpenAI documentation.
@@ -106,6 +113,11 @@ public class ChatCompletionRequest {
@JsonProperty("function_call")
ChatCompletionRequestFunctionCall functionCall;
+ /**
+ * A list of the available functions.
+ */
+ List<?> functions;
+
@Data
@Builder
@AllArgsConstructor
@@ -118,4 +130,18 @@ public static ChatCompletionRequestFunctionCall of(String name) {
}
}
+
+ @Data
+ @Builder
+ @AllArgsConstructor
+ @NoArgsConstructor
+ public static class ResponseFormat {
+ String type;
+
+ public static ResponseFormat of(String type) {
+ return new ResponseFormat(type);
+ }
+
+ }
+
}
diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
index 25f0defb..e10144dc 100644
--- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
+++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java
@@ -2,11 +2,14 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.theokanning.openai.completion.chat.*;
import org.junit.jupiter.api.Test;
+import java.io.IOException;
import java.util.*;
import static org.junit.jupiter.api.Assertions.*;
@@ -23,7 +26,7 @@ static class Weather {
}
enum WeatherUnit {
- CELSIUS, FAHRENHEIT;
+ CELSIUS, FAHRENHEIT
}
static class WeatherResponse {
@@ -300,4 +303,46 @@ void streamChatCompletionWithDynamicFunctions() {
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
}
+ @Test
+ void streamChatCompletionWithJsonResponseFormat() {
+ final List<ChatMessage> messages = new ArrayList<>();
+
+ // The system message is deliberately vague so as not to steer what the response should look like.
+ // The main point is that the chat completion should always contain JSON content.
+ final ChatMessage systemMessage = new ChatMessage(
+ ChatMessageRole.SYSTEM.value(),
+ "You are a dog and will speak as such - but please do it in JSON."
+ );
+
+ messages.add(systemMessage);
+
+ ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
+ .builder()
+ .model("gpt-4-1106-preview")
+ .messages(messages)
+ .n(1)
+ .maxTokens(256)
+ .responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
+ .build();
+
+ ChatCompletionResult chatCompletion = service.createChatCompletion(chatCompletionRequest);
+
+ ChatCompletionChoice chatCompletionChoice = chatCompletion.getChoices().get(0);
+ String jsonContent = chatCompletionChoice.getMessage().getContent();
+
+ assertTrue(isValidJSON(jsonContent), "Invalid JSON response:\n\n" + jsonContent);
+ }
+
+ private boolean isValidJSON(String json) {
+ try (final JsonParser parser = new ObjectMapper().createParser(json)) {
+ while (parser.nextToken() != null) {
+ // Just try to read all tokens in order to verify whether this is valid json.
+ }
+ return true;
+ } catch (IOException ioe) {
+ ioe.printStackTrace();
+ return false;
+ }
+ }
+
}