Skip to content

Commit

Permalink
Extract ResponseFormat to standalone class
Browse files Browse the repository at this point in the history
 - Extracts ResponseFormat from being a nested record in OpenAiApi to
   a dedicated class with builder pattern support.
 - Resolve the issue with constructor bindings for the Boog property
   binding.
 - Re-enables previously disabled response format integration tests.

 - Add checkstyle changes
 - Add schema field in ResponseFormat and set jsonSchema via the setter for schema,
   this way schema set via a Boot property also sets the correct JsonSchema
 - Add default constructors in ResponseFormat and JsonSchema

 Resolves #1681
  • Loading branch information
tzolov authored and sobychacko committed Nov 6, 2024
1 parent fb65ed0 commit 83f7164
Show file tree
Hide file tree
Showing 6 changed files with 293 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder;
import org.springframework.ai.openai.api.ResponseFormat;
import org.springframework.util.Assert;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
Expand Down Expand Up @@ -870,74 +869,6 @@ public static Object FUNCTION(String functionName) {
}
}

/**
* An object specifying the format that the model must output.
* @param type Must be one of 'text' or 'json_object'.
* @param jsonSchema JSON schema object that describes the format of the JSON object.
* Only applicable when type is 'json_schema'.
*/
@JsonInclude(Include.NON_NULL)
public record ResponseFormat(
@JsonProperty("type") Type type,
@JsonProperty("json_schema") JsonSchema jsonSchema) {

public ResponseFormat(Type type) {
this(type, (JsonSchema) null);
}

public ResponseFormat(Type type, String schema) {
this(type, "custom_schema", schema, true);
}

public ResponseFormat(Type type, String name, String schema, Boolean strict) {
this(type, StringUtils.hasText(schema) ? new JsonSchema(name, schema, strict) : null);
}

public enum Type {
/**
* Generates a text response. (default)
*/
@JsonProperty("text")
TEXT,

/**
* Enables JSON mode, which guarantees the message
* the model generates is valid JSON.
*/
@JsonProperty("json_object")
JSON_OBJECT,

/**
* Enables Structured Outputs which guarantees the model
* will match your supplied JSON schema.
*/
@JsonProperty("json_schema")
JSON_SCHEMA
}

/**
* JSON schema object that describes the format of the JSON object.
* Applicable for the 'json_schema' type only.
* @param name The name of the schema.
* @param schema The JSON schema object that describes the format of the JSON object.
* @param strict If true, the model will only generate outputs that match the schema.
*/
@JsonInclude(Include.NON_NULL)
public record JsonSchema(
@JsonProperty("name") String name,
@JsonProperty("schema") Map<String, Object> schema,
@JsonProperty("strict") Boolean strict) {

public JsonSchema(String name, String schema) {
this(name, ModelOptionsUtils.jsonToMap(schema), true);
}

public JsonSchema(String name, String schema, Boolean strict) {
this(StringUtils.hasText(name) ? name : "custom_schema", ModelOptionsUtils.jsonToMap(schema), strict);
}
}

}
/**
* @param includeUsage If set, an additional chunk will be streamed
* before the data: [DONE] message. The usage field on this chunk
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.openai.api;

import java.util.Map;
import java.util.Objects;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.util.StringUtils;

/**
* An object specifying the format that the model must output.
*
* Setting the type to JSON_SCHEMA, enables Structured Outputs which ensures the model
* will match your supplied JSON schema. Learn more in the
* <a href="https://platform.openai.com/docs/guides/structured-outputs"> Structured
* Outputs guide.</a <br/>
*
* References: <a href=
* "https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format">OpenAi
* API - ResponseFormat</a>,
* <a href="https://platform.openai.com/docs/guides/structured-outputs#json-mode">JSON
* Mode</a>, <a href=
* "https://platform.openai.com/docs/guides/structured-outputs#structured-outputs-vs-json-mode">Structured
* Outputs vs JSON mode</a>
*
* @author Christian Tzolov
* @since 1.0.0
*/

@JsonInclude(Include.NON_NULL)
public class ResponseFormat {

/**
* Type Must be one of 'text', 'json_object' or 'json_schema'.
*/
@JsonProperty("type")
private Type type;

/**
* JSON schema object that describes the format of the JSON object. Only applicable
* when type is 'json_schema'.
*/
@JsonProperty("json_schema")
private JsonSchema jsonSchema = null;

private String schema;

public ResponseFormat() {

}

public Type getType() {
return this.type;
}

public void setType(Type type) {
this.type = type;
}

public JsonSchema getJsonSchema() {
return this.jsonSchema;
}

public void setJsonSchema(JsonSchema jsonSchema) {
this.jsonSchema = jsonSchema;
}

public String getSchema() {
return this.schema;
}

public void setSchema(String schema) {
this.schema = schema;
if (schema != null) {
this.jsonSchema = JsonSchema.builder().schema(schema).strict(true).build();
}
}

private ResponseFormat(Type type, JsonSchema jsonSchema) {
this.type = type;
this.jsonSchema = jsonSchema;
}

public ResponseFormat(Type type, String schema) {
this(type, StringUtils.hasText(schema) ? JsonSchema.builder().schema(schema).strict(true).build() : null);
}

public static Builder builder() {
return new Builder();
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
ResponseFormat that = (ResponseFormat) o;
return this.type == that.type && Objects.equals(this.jsonSchema, that.jsonSchema);
}

@Override
public int hashCode() {
return Objects.hash(this.type, this.jsonSchema);
}

@Override
public String toString() {
return "ResponseFormat{" + "type=" + this.type + ", jsonSchema=" + this.jsonSchema + '}';
}

public static final class Builder {

private Type type;

private JsonSchema jsonSchema;

private Builder() {
}

public Builder type(Type type) {
this.type = type;
return this;
}

public Builder jsonSchema(JsonSchema jsonSchema) {
this.jsonSchema = jsonSchema;
return this;
}

public Builder jsonSchema(String jsonSchema) {
this.jsonSchema = JsonSchema.builder().schema(jsonSchema).build();
return this;
}

public ResponseFormat build() {
return new ResponseFormat(this.type, this.jsonSchema);
}

}

public enum Type {

/**
* Generates a text response. (default)
*/
@JsonProperty("text")
TEXT,

/**
* Enables JSON mode, which guarantees the message the model generates is valid
* JSON.
*/
@JsonProperty("json_object")
JSON_OBJECT,

/**
* Enables Structured Outputs which guarantees the model will match your supplied
* JSON schema.
*/
@JsonProperty("json_schema")
JSON_SCHEMA

}

/**
* JSON schema object that describes the format of the JSON object. Applicable for the
* 'json_schema' type only.
*/
@JsonInclude(Include.NON_NULL)
public static class JsonSchema {

@JsonProperty("name")
private String name;

@JsonProperty("schema")
private Map<String, Object> schema;

@JsonProperty("strict")
private Boolean strict;

public JsonSchema() {

}

public String getName() {
return this.name;
}

public Map<String, Object> getSchema() {
return this.schema;
}

public Boolean getStrict() {
return this.strict;
}

private JsonSchema(String name, Map<String, Object> schema, Boolean strict) {
this.name = name;
this.schema = schema;
this.strict = strict;
}

public static Builder builder() {
return new Builder();
}

@Override
public int hashCode() {
return Objects.hash(this.name, this.schema, this.strict);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
JsonSchema that = (JsonSchema) o;
return Objects.equals(this.name, that.name) && Objects.equals(this.schema, that.schema)
&& Objects.equals(this.strict, that.strict);
}

public static final class Builder {

private String name = "custom_schema";

private Map<String, Object> schema;

private Boolean strict = true;

private Builder() {
}

public Builder name(String name) {
this.name = name;
return this;
}

public Builder schema(Map<String, Object> schema) {
this.schema = schema;
return this;
}

public Builder schema(String schema) {
this.schema = ModelOptionsUtils.jsonToMap(schema);
return this;
}

public Builder strict(Boolean strict) {
this.strict = strict;
return this;
}

public JsonSchema build() {
return new JsonSchema(this.name, this.schema, this.strict);
}

}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ResponseFormat;
import org.springframework.ai.openai.api.OpenAiApi.ChatModel;
import org.springframework.ai.openai.api.ResponseFormat;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
Expand Down Expand Up @@ -80,7 +80,7 @@ void jsonObject() throws JsonMappingException, JsonProcessingException {

Prompt prompt = new Prompt("List 8 planets. Use JSON response",
OpenAiChatOptions.builder()
.withResponseFormat(new ResponseFormat(ResponseFormat.Type.JSON_OBJECT))
.withResponseFormat(ResponseFormat.builder().type(ResponseFormat.Type.JSON_OBJECT).build())
.build());

ChatResponse response = this.openAiChatModel.call(prompt);
Expand Down
Loading

0 comments on commit 83f7164

Please sign in to comment.