Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the implementation of ChatClient-based interfaces #1663

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package org.springframework.ai.chat.client;

import java.lang.reflect.Type;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
Expand All @@ -37,21 +39,24 @@
import org.springframework.ai.model.Media;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;

/**
* Client to perform stateless requests to an AI Model, using a fluent API.
*
* Client used to perform stateless requests to an AI Model, using a fluent API.
* <p/>
* Use {@link ChatClient#builder(ChatModel)} to prepare an instance.
*
* @author Mark Pollack
* @author Christian Tzolov
* @author Josh Long
* @author Arjen Poutsma
* @author Thomas Vitale
* @author John Blum
* @see ChatModel
* @since 1.0.0
*/
public interface ChatClient {
Expand Down Expand Up @@ -97,15 +102,22 @@ static Builder builder(ChatModel chatModel, ObservationRegistry observationRegis

interface PromptUserSpec {

PromptUserSpec text(String text);
default PromptUserSpec text(String text) {
Charset defaultCharset = Charset.defaultCharset();
return text(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

PromptUserSpec text(Resource text, Charset charset);
default PromptUserSpec text(Resource text) {
return text(text, Charset.defaultCharset());
}

PromptUserSpec text(Resource text);
PromptUserSpec text(Resource text, Charset charset);

PromptUserSpec params(Map<String, Object> p);
default PromptUserSpec param(String key, Object value) {
return params(Map.of(key, value));
}

PromptUserSpec param(String k, Object v);
PromptUserSpec params(Map<String, Object> params);

PromptUserSpec media(Media... media);

Expand All @@ -117,25 +129,36 @@ interface PromptUserSpec {

interface PromptSystemSpec {

PromptSystemSpec text(String text);
default PromptSystemSpec text(String text) {
Charset defaultCharset = Charset.defaultCharset();
return text(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

PromptSystemSpec text(Resource text, Charset charset);
default PromptSystemSpec text(Resource text) {
return text(text, Charset.defaultCharset());
}

PromptSystemSpec text(Resource text);
PromptSystemSpec text(Resource text, Charset charset);

PromptSystemSpec params(Map<String, Object> p);
default PromptSystemSpec param(String key, Object value) {
return params(Map.of(key, value));
}

PromptSystemSpec param(String k, Object v);
PromptSystemSpec params(Map<String, Object> params);

}

interface AdvisorSpec {

AdvisorSpec param(String k, Object v);
default AdvisorSpec param(String key, Object value) {
return params(Map.of(key, value));
}

AdvisorSpec params(Map<String, Object> p);
AdvisorSpec params(Map<String, Object> params);

AdvisorSpec advisors(Advisor... advisors);
default AdvisorSpec advisors(Advisor... advisors) {
return advisors(Arrays.asList(advisors));
}

AdvisorSpec advisors(List<Advisor> advisors);

Expand All @@ -144,21 +167,39 @@ interface AdvisorSpec {
interface CallResponseSpec {

@Nullable
<T> T entity(ParameterizedTypeReference<T> type);
default <T> T entity(Class<T> type) {

return entity(new ParameterizedTypeReference<>() {

@Override
public Type getType() {
return type;
}
});
}

@Nullable
<T> T entity(StructuredOutputConverter<T> structuredOutputConverter);
<T> T entity(ParameterizedTypeReference<T> type);

@Nullable
<T> T entity(Class<T> type);
<T> T entity(StructuredOutputConverter<T> structuredOutputConverter);

@Nullable
ChatResponse chatResponse();

@Nullable
String content();

<T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type);
default <T> ResponseEntity<ChatResponse, T> responseEntity(Class<T> type) {

return responseEntity(new ParameterizedTypeReference<T>() {

@Override
public Type getType() {
return type;
}
});
}

<T> ResponseEntity<ChatResponse, T> responseEntity(ParameterizedTypeReference<T> type);

Expand Down Expand Up @@ -202,11 +243,15 @@ interface ChatClientRequestSpec {

ChatClientRequestSpec advisors(Consumer<AdvisorSpec> consumer);

ChatClientRequestSpec advisors(Advisor... advisors);
default ChatClientRequestSpec advisors(Advisor... advisors) {
return advisors(Arrays.asList(advisors));
}

ChatClientRequestSpec advisors(List<Advisor> advisors);

ChatClientRequestSpec messages(Message... messages);
default ChatClientRequestSpec messages(Message... messages) {
return messages(Arrays.asList(messages));
}

ChatClientRequestSpec messages(List<Message> messages);

Expand All @@ -227,19 +272,29 @@ <I, O> ChatClientRequestSpec function(String name, String description, Class<I>

ChatClientRequestSpec toolContext(Map<String, Object> toolContext);

ChatClientRequestSpec system(String text);
default ChatClientRequestSpec system(String text) {
Charset defaultCharset = Charset.defaultCharset();
return system(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

ChatClientRequestSpec system(Resource textResource, Charset charset);
default ChatClientRequestSpec system(Resource text) {
return system(text, Charset.defaultCharset());
}

ChatClientRequestSpec system(Resource text);
ChatClientRequestSpec system(Resource textResource, Charset charset);

ChatClientRequestSpec system(Consumer<PromptSystemSpec> consumer);

ChatClientRequestSpec user(String text);
default ChatClientRequestSpec user(String text) {
Charset defaultCharset = Charset.defaultCharset();
return user(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

ChatClientRequestSpec user(Resource text, Charset charset);
default ChatClientRequestSpec user(Resource text) {
return user(text, Charset.defaultCharset());
}

ChatClientRequestSpec user(Resource text);
ChatClientRequestSpec user(Resource text, Charset charset);

ChatClientRequestSpec user(Consumer<PromptUserSpec> consumer);

Expand All @@ -254,27 +309,39 @@ <I, O> ChatClientRequestSpec function(String name, String description, Class<I>
*/
interface Builder {

Builder defaultAdvisors(Advisor... advisor);

Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);
default Builder defaultAdvisors(Advisor... advisors) {
return defaultAdvisors(Arrays.asList(advisors));
}

Builder defaultAdvisors(List<Advisor> advisors);

Builder defaultAdvisors(Consumer<AdvisorSpec> advisorSpecConsumer);

Builder defaultOptions(ChatOptions chatOptions);

Builder defaultUser(String text);
default Builder defaultUser(String text) {
Charset defaulCharset = Charset.defaultCharset();
return defaultUser(new ByteArrayResource(text.getBytes(defaulCharset)), defaulCharset);
}

Builder defaultUser(Resource text, Charset charset);
default Builder defaultUser(Resource text) {
return defaultUser(text, Charset.defaultCharset());
}

Builder defaultUser(Resource text);
Builder defaultUser(Resource text, Charset charset);

Builder defaultUser(Consumer<PromptUserSpec> userSpecConsumer);

Builder defaultSystem(String text);
default Builder defaultSystem(String text) {
Charset defaultCharset = Charset.defaultCharset();
return defaultSystem(new ByteArrayResource(text.getBytes(defaultCharset)), defaultCharset);
}

Builder defaultSystem(Resource text, Charset charset);
default Builder defaultSystem(Resource text) {
return defaultSystem(text, Charset.defaultCharset());
}

Builder defaultSystem(Resource text);
Builder defaultSystem(Resource text, Charset charset);

Builder defaultSystem(Consumer<PromptSystemSpec> systemSpecConsumer);

Expand Down