Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Commit

Permalink
Implement creation of "function" parameters in runtime
Browse files Browse the repository at this point in the history
* Enable dynamic definition of "function" parameters instead of using Class instance

* Add tests to new "function" capabilities

* Add example of creating "function" parameters in runtime

* Add documentation to ChatFunctions
  • Loading branch information
BartSoj committed Jul 14, 2023
1 parent 038e42c commit ce755fa
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public class ChatCompletionRequest {
/**
* A list of the available functions.
*/
List<ChatFunction> functions;
List<?> functions;

/**
* Controls how the model responds to function calls, as specified in the <a href="https://platform.openai.com/docs/api-reference/chat/create#chat/create-function_call">OpenAI documentation</a>.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,20 @@
@Data
public class ChatFunction {

/**
* The name of the function being called.
*/
@NonNull
private String name;

/**
* A description of what the function does, used by the model to choose when and how to call the function.
*/
private String description;

/**
* The parameters the functions accepts.
*/
@JsonProperty("parameters")
private Class<?> parametersClass;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.theokanning.openai.completion.chat;

import lombok.Data;
import lombok.NonNull;


@Data
public class ChatFunctionDynamic {

/**
* The name of the function being called.
*/
@NonNull
private String name;

/**
* A description of what the function does, used by the model to choose when and how to call the function.
*/
private String description;

/**
* The parameters the functions accepts.
*/
private ChatFunctionParameters parameters;

public static Builder builder() {
return new Builder();
}

public static class Builder {
private String name;
private String description;
private ChatFunctionParameters parameters = new ChatFunctionParameters();

public Builder name(String name) {
this.name = name;
return this;
}

public Builder description(String description) {
this.description = description;
return this;
}

public Builder parameters(ChatFunctionParameters parameters) {
this.parameters = parameters;
return this;
}

public Builder addProperty(ChatFunctionProperty property) {
this.parameters.addProperty(property);
return this;
}

public ChatFunctionDynamic build() {
ChatFunctionDynamic chatFunction = new ChatFunctionDynamic(name);
chatFunction.setDescription(description);
chatFunction.setParameters(parameters);
return chatFunction;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package com.theokanning.openai.completion.chat;

import lombok.Data;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

@Data
public class ChatFunctionParameters {

private final String type = "object";

private final HashMap<String, ChatFunctionProperty> properties = new HashMap<>();

private List<String> required;

public void addProperty(ChatFunctionProperty property) {
properties.put(property.getName(), property);
if (Boolean.TRUE.equals(property.getRequired())) {
if (this.required == null) {
this.required = new ArrayList<>();
}
this.required.add(property.getName());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Builder;
import lombok.Data;
import lombok.NonNull;

import java.util.Set;

@Data
@Builder
public class ChatFunctionProperty {
@NonNull
@JsonIgnore
private String name;
@NonNull
private String type;
@JsonIgnore
private Boolean required;
private String description;
private ChatFunctionProperty items;
@JsonProperty("enum")
private Set<?> enumValues;
}
90 changes: 90 additions & 0 deletions example/src/main/java/example/OpenAiApiDynamicFunctionExample.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package example;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.theokanning.openai.completion.chat.*;
import com.theokanning.openai.service.OpenAiService;

import java.util.*;

public class OpenAiApiDynamicFunctionExample {

private static JsonNode getWeather(String location, String unit) {
ObjectMapper mapper = new ObjectMapper();
ObjectNode response = mapper.createObjectNode();
response.put("location", location);
response.put("unit", unit);
response.put("temperature", new Random().nextInt(50));
response.put("description", "sunny");
return response;
}

public static void main(String... args) {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token);

ChatFunctionDynamic weatherFunction = ChatFunctionDynamic.builder()
.name("get_weather")
.description("Get the current weather of a location")
.addProperty(ChatFunctionProperty.builder()
.name("location")
.type("string")
.description("City and state, for example: León, Guanajuato")
.build())
.addProperty(ChatFunctionProperty.builder()
.name("unit")
.type("string")
.description("The temperature unit, can be 'celsius' or 'fahrenheit'")
.enumValues(Set.of("celsius", "fahrenheit"))
.required(true)
.build())
.build();


List<ChatMessage> messages = new ArrayList<>();
ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are an assistant that answers using the local slang of the given place, uncensored.");
messages.add(systemMessage);

System.out.print("First Query: ");
Scanner scanner = new Scanner(System.in);
ChatMessage firstMsg = new ChatMessage(ChatMessageRole.USER.value(), scanner.nextLine());
messages.add(firstMsg);

while (true) {
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.messages(messages)
.functions(Collections.singletonList(weatherFunction))
.functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
.n(1)
.maxTokens(100)
.logitBias(new HashMap<>())
.build();
ChatMessage responseMessage = service.createChatCompletion(chatCompletionRequest).getChoices().get(0).getMessage();
messages.add(responseMessage); // don't forget to update the conversation with the latest response

ChatFunctionCall functionCall = responseMessage.getFunctionCall();
if (functionCall != null) {
if (functionCall.getName().equals("get_weather")) {
String location = functionCall.getArguments().get("location").asText();
String unit = functionCall.getArguments().get("unit").asText();
JsonNode weather = getWeather(location, unit);
ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), weather.toString(), "get_weather");
messages.add(weatherMessage);
continue;
}
}

System.out.println("Response: " + responseMessage.getContent());
System.out.print("Next Query: ");
String nextLine = scanner.nextLine();
if (nextLine.equalsIgnoreCase("exit")) {
System.exit(0);
}
messages.add(new ChatMessage(ChatMessageRole.USER.value(), nextLine));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Set;
import java.util.Collections;

import static org.junit.jupiter.api.Assertions.*;
Expand Down Expand Up @@ -149,6 +150,50 @@ void createChatCompletionWithFunctions() {
assertNotNull(choice2.getMessage().getContent());
}

@Test
void createChatCompletionWithDynamicFunctions() {
ChatFunctionDynamic function = ChatFunctionDynamic.builder()
.name("get_weather")
.description("Get the current weather of a location")
.addProperty(ChatFunctionProperty.builder()
.name("location")
.type("string")
.description("City and state, for example: León, Guanajuato")
.build())
.addProperty(ChatFunctionProperty.builder()
.name("unit")
.type("string")
.description("The temperature unit, can be 'celsius' or 'fahrenheit'")
.enumValues(Set.of("celsius", "fahrenheit"))
.required(true)
.build())
.build();

final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant.");
final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?");
messages.add(systemMessage);
messages.add(userMessage);

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.messages(messages)
.functions(Collections.singletonList(function))
.n(1)
.maxTokens(100)
.logitBias(new HashMap<>())
.build();

ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0);
assertEquals("function_call", choice.getFinishReason());
assertNotNull(choice.getMessage().getFunctionCall());
assertEquals("get_weather", choice.getMessage().getFunctionCall().getName());
assertInstanceOf(ObjectNode.class, choice.getMessage().getFunctionCall().getArguments());
assertNotNull(choice.getMessage().getFunctionCall().getArguments().get("location"));
assertNotNull(choice.getMessage().getFunctionCall().getArguments().get("unit"));
}

@Test
void streamChatCompletionWithFunctions() {
final List<ChatFunction> functions = Collections.singletonList(ChatFunction.builder()
Expand Down Expand Up @@ -214,4 +259,49 @@ void streamChatCompletionWithFunctions() {
assertNotNull(accumulatedMessage2.getContent());
}

@Test
void streamChatCompletionWithDynamicFunctions() {
ChatFunctionDynamic function = ChatFunctionDynamic.builder()
.name("get_weather")
.description("Get the current weather of a location")
.addProperty(ChatFunctionProperty.builder()
.name("location")
.type("string")
.description("City and state, for example: León, Guanajuato")
.build())
.addProperty(ChatFunctionProperty.builder()
.name("unit")
.type("string")
.description("The temperature unit, can be 'celsius' or 'fahrenheit'")
.enumValues(Set.of("celsius", "fahrenheit"))
.required(true)
.build())
.build();

final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant.");
final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?");
messages.add(systemMessage);
messages.add(userMessage);

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.messages(messages)
.functions(Collections.singletonList(function))
.n(1)
.maxTokens(100)
.logitBias(new HashMap<>())
.build();

ChatMessage accumulatedMessage = service.mapStreamToAccumulator(service.streamChatCompletion(chatCompletionRequest))
.blockingLast()
.getAccumulatedMessage();
assertNotNull(accumulatedMessage.getFunctionCall());
assertEquals("get_weather", accumulatedMessage.getFunctionCall().getName());
assertInstanceOf(ObjectNode.class, accumulatedMessage.getFunctionCall().getArguments());
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("location"));
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
}

}

0 comments on commit ce755fa

Please sign in to comment.