diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 0ada1f33ba..e2a4087529 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -34,9 +34,9 @@ import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; +import org.springframework.ai.openai.api.ResponseFormat; import org.springframework.util.Assert; /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index cb0ca13ca1..9f7111bb9d 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -42,7 +42,6 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.util.StringUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; @@ -870,74 +869,6 @@ public static Object FUNCTION(String functionName) { } } - /** - * An object specifying the format that the model must output. - * @param type Must be one of 'text' or 'json_object'. - * @param jsonSchema JSON schema object that describes the format of the JSON object. - * Only applicable when type is 'json_schema'. - */ - @JsonInclude(Include.NON_NULL) - public record ResponseFormat( - @JsonProperty("type") Type type, - @JsonProperty("json_schema") JsonSchema jsonSchema) { - - public ResponseFormat(Type type) { - this(type, (JsonSchema) null); - } - - public ResponseFormat(Type type, String schema) { - this(type, "custom_schema", schema, true); - } - - public ResponseFormat(Type type, String name, String schema, Boolean strict) { - this(type, StringUtils.hasText(schema) ? new JsonSchema(name, schema, strict) : null); - } - - public enum Type { - /** - * Generates a text response. (default) - */ - @JsonProperty("text") - TEXT, - - /** - * Enables JSON mode, which guarantees the message - * the model generates is valid JSON. - */ - @JsonProperty("json_object") - JSON_OBJECT, - - /** - * Enables Structured Outputs which guarantees the model - * will match your supplied JSON schema. - */ - @JsonProperty("json_schema") - JSON_SCHEMA - } - - /** - * JSON schema object that describes the format of the JSON object. - * Applicable for the 'json_schema' type only. - * @param name The name of the schema. - * @param schema The JSON schema object that describes the format of the JSON object. - * @param strict If true, the model will only generate outputs that match the schema. - */ - @JsonInclude(Include.NON_NULL) - public record JsonSchema( - @JsonProperty("name") String name, - @JsonProperty("schema") Map schema, - @JsonProperty("strict") Boolean strict) { - - public JsonSchema(String name, String schema) { - this(name, ModelOptionsUtils.jsonToMap(schema), true); - } - - public JsonSchema(String name, String schema, Boolean strict) { - this(StringUtils.hasText(name) ? name : "custom_schema", ModelOptionsUtils.jsonToMap(schema), strict); - } - } - - } /** * @param includeUsage If set, an additional chunk will be streamed * before the data: [DONE] message. The usage field on this chunk diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ResponseFormat.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ResponseFormat.java new file mode 100644 index 0000000000..50711099dc --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/ResponseFormat.java @@ -0,0 +1,286 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import java.util.Map; +import java.util.Objects; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.util.StringUtils; + +/** + * An object specifying the format that the model must output. + * + * Setting the type to JSON_SCHEMA, enables Structured Outputs which ensures the model + * will match your supplied JSON schema. Learn more in the + * Structured + * Outputs guide. + * + * References: OpenAi + * API - ResponseFormat, + * JSON + * Mode, Structured + * Outputs vs JSON mode + * + * @author Christian Tzolov + * @since 1.0.0 + */ + +@JsonInclude(Include.NON_NULL) +public class ResponseFormat { + + /** + * Type Must be one of 'text', 'json_object' or 'json_schema'. + */ + @JsonProperty("type") + private Type type; + + /** + * JSON schema object that describes the format of the JSON object. Only applicable + * when type is 'json_schema'. + */ + @JsonProperty("json_schema") + private JsonSchema jsonSchema = null; + + private String schema; + + public ResponseFormat() { + + } + + public Type getType() { + return this.type; + } + + public void setType(Type type) { + this.type = type; + } + + public JsonSchema getJsonSchema() { + return this.jsonSchema; + } + + public void setJsonSchema(JsonSchema jsonSchema) { + this.jsonSchema = jsonSchema; + } + + public String getSchema() { + return this.schema; + } + + public void setSchema(String schema) { + this.schema = schema; + if (schema != null) { + this.jsonSchema = JsonSchema.builder().schema(schema).strict(true).build(); + } + } + + private ResponseFormat(Type type, JsonSchema jsonSchema) { + this.type = type; + this.jsonSchema = jsonSchema; + } + + public ResponseFormat(Type type, String schema) { + this(type, StringUtils.hasText(schema) ? JsonSchema.builder().schema(schema).strict(true).build() : null); + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ResponseFormat that = (ResponseFormat) o; + return this.type == that.type && Objects.equals(this.jsonSchema, that.jsonSchema); + } + + @Override + public int hashCode() { + return Objects.hash(this.type, this.jsonSchema); + } + + @Override + public String toString() { + return "ResponseFormat{" + "type=" + this.type + ", jsonSchema=" + this.jsonSchema + '}'; + } + + public static final class Builder { + + private Type type; + + private JsonSchema jsonSchema; + + private Builder() { + } + + public Builder type(Type type) { + this.type = type; + return this; + } + + public Builder jsonSchema(JsonSchema jsonSchema) { + this.jsonSchema = jsonSchema; + return this; + } + + public Builder jsonSchema(String jsonSchema) { + this.jsonSchema = JsonSchema.builder().schema(jsonSchema).build(); + return this; + } + + public ResponseFormat build() { + return new ResponseFormat(this.type, this.jsonSchema); + } + + } + + public enum Type { + + /** + * Generates a text response. (default) + */ + @JsonProperty("text") + TEXT, + + /** + * Enables JSON mode, which guarantees the message the model generates is valid + * JSON. + */ + @JsonProperty("json_object") + JSON_OBJECT, + + /** + * Enables Structured Outputs which guarantees the model will match your supplied + * JSON schema. + */ + @JsonProperty("json_schema") + JSON_SCHEMA + + } + + /** + * JSON schema object that describes the format of the JSON object. Applicable for the + * 'json_schema' type only. + */ + @JsonInclude(Include.NON_NULL) + public static class JsonSchema { + + @JsonProperty("name") + private String name; + + @JsonProperty("schema") + private Map schema; + + @JsonProperty("strict") + private Boolean strict; + + public JsonSchema() { + + } + + public String getName() { + return this.name; + } + + public Map getSchema() { + return this.schema; + } + + public Boolean getStrict() { + return this.strict; + } + + private JsonSchema(String name, Map schema, Boolean strict) { + this.name = name; + this.schema = schema; + this.strict = strict; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public int hashCode() { + return Objects.hash(this.name, this.schema, this.strict); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + JsonSchema that = (JsonSchema) o; + return Objects.equals(this.name, that.name) && Objects.equals(this.schema, that.schema) + && Objects.equals(this.strict, that.strict); + } + + public static final class Builder { + + private String name = "custom_schema"; + + private Map schema; + + private Boolean strict = true; + + private Builder() { + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder schema(Map schema) { + this.schema = schema; + return this; + } + + public Builder schema(String schema) { + this.schema = ModelOptionsUtils.jsonToMap(schema); + return this; + } + + public Builder strict(Boolean strict) { + this.strict = strict; + return this; + } + + public JsonSchema build() { + return new JsonSchema(this.name, this.schema, this.strict); + } + + } + + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java index a9c9f2724e..3f24a24b36 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelResponseFormatIT.java @@ -33,8 +33,8 @@ import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiApi.ChatModel; +import org.springframework.ai.openai.api.ResponseFormat; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.SpringBootTest; @@ -80,7 +80,7 @@ void jsonObject() throws JsonMappingException, JsonProcessingException { Prompt prompt = new Prompt("List 8 planets. Use JSON response", OpenAiChatOptions.builder() - .withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT)) + .withResponseFormat(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT).build()) .build()); ChatResponse response = this.openAiChatModel.call(prompt); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index 6e15eff938..ec09403ea6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -16,7 +16,6 @@ package org.springframework.ai.openai.chat; -import java.time.Duration; import java.util.List; import java.util.Optional; diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java index 6ba9e1af51..c7f5baaec1 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/openai/OpenAiResponseFormatPropertiesTests.java @@ -16,7 +16,6 @@ package org.springframework.ai.autoconfigure.openai; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.springframework.ai.openai.OpenAiAudioSpeechModel; @@ -24,8 +23,8 @@ import org.springframework.ai.openai.OpenAiChatModel; import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.ai.openai.OpenAiImageModel; -import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.openai.api.OpenAiAudioApi; +import org.springframework.ai.openai.api.ResponseFormat; import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -39,7 +38,6 @@ public class OpenAiResponseFormatPropertiesTests { @Test - @Disabled("GH-1645") public void responseFormatJsonSchema() { String responseFormatJsonSchema = """ @@ -72,13 +70,12 @@ public void responseFormatJsonSchema() { assertThat(connectionProperties.getApiKey()).isEqualTo("API_KEY"); - assertThat(chatProperties.getOptions().getResponseFormat()).isEqualTo( - new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, "MyName", responseFormatJsonSchema, true)); + assertThat(chatProperties.getOptions().getResponseFormat()) + .isEqualTo(new ResponseFormat(ResponseFormat.Type.JSON_SCHEMA, responseFormatJsonSchema)); }); } @Test - @Disabled("GH-1645") public void responseFormatJsonObject() { new ApplicationContextRunner() @@ -90,7 +87,7 @@ public void responseFormatJsonObject() { var chatProperties = context.getBean(OpenAiChatProperties.class); assertThat(chatProperties.getOptions().getResponseFormat()) - .isEqualTo(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT)); + .isEqualTo(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT).build()); }); }