Skip to content

Commit

Permalink
feat : polish spring chat model (#36)
Browse files Browse the repository at this point in the history
* feat : polish spring chat model

* remove running on mac and windows

* feat: using customizer to increase read and connect timeout

* feat : correct way of customizing restClient
  • Loading branch information
rajadilipkolli authored Apr 7, 2024
1 parent ef10569 commit c894cd0
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rag-springai-ollama-llm.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
matrix:
distribution: [ 'temurin' ]
java: [ '21' ]
os: [ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v4
with:
Expand Down
24 changes: 24 additions & 0 deletions chatmodel-springai/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
<properties>
<java.version>17</java.version>
<spring-ai.version>0.8.1</spring-ai.version>
<spotless.version>2.43.0</spotless.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -72,6 +73,29 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
<plugin>
<groupId>com.diffplug.spotless</groupId>
<artifactId>spotless-maven-plugin</artifactId>
<version>${spotless.version}</version>
<configuration>
<java>
<palantirJavaFormat>
<version>2.40.0</version>
</palantirJavaFormat>
<importOrder />
<removeUnusedImports />
<formatAnnotations />
</java>
</configuration>
<executions>
<execution>
<phase>compile</phase>
<goals>
<goal>check</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpRequest;
Expand All @@ -15,7 +16,6 @@
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.util.StreamUtils;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
Expand All @@ -24,8 +24,8 @@ public class LoggingConfig {
private static final Logger LOGGER = LoggerFactory.getLogger(LoggingConfig.class);

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder()
public RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder
.requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory()))
.requestInterceptor((request, body, execution) -> {
logRequest(request, body);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
package com.example.ai.controller;

import com.example.ai.model.request.AIChatRequest;
import com.example.ai.model.response.AIChatResponse;
import java.util.List;
import java.util.Map;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

@RestController
Expand All @@ -20,32 +23,42 @@ public class ChatController {

private final ChatClient chatClient;

ChatController(ChatClient chatClient) {
private final EmbeddingClient embeddingClient;

ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) {
this.chatClient = chatClient;
this.embeddingClient = embeddingClient;
}

@GetMapping("/chat")
Map<String, String> chat(@RequestParam String question) {
var response = chatClient.call(question);
return Map.of("question", question, "answer", response);
@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) {
var answer = chatClient.call(aiChatRequest.query());
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestParam String subject) {
@PostMapping("/chat-with-prompt")
AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) {
PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}");
Prompt prompt = promptTemplate.create(Map.of("subject", subject));
Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query()));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
Generation generation = response.getResult();
String answer = (generation != null) ? generation.getOutput().getContent() : "";
return new AIChatResponse(answer);
}

@GetMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestParam String subject) {
@PostMapping("/chat-with-system-prompt")
AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) {
SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot");
UserMessage userMessage = new UserMessage("Tell me a joke about " + subject);
UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query());
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse response = chatClient.call(prompt);
String answer = response.getResult().getOutput().getContent();
return new AIChatResponse(answer);
}

@PostMapping("/emebedding-client-conversion")
AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) {
List<Double> embed = embeddingClient.embed(aiChatRequest.query());
return new AIChatResponse(embed.toString());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.example.ai.model.request;

public record AIChatRequest(String query) {}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ spring.ai.openai.chat.options.model=gpt-3.5-turbo
spring.ai.openai.chat.options.temperature=0.2
spring.ai.openai.chat.options.responseFormat=json_object

spring.ai.openai.embedding.enabled=false
spring.ai.openai.embedding.enabled=true

##logging
logging.level.org.apache.hc.client5.http=INFO
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import static io.restassured.RestAssured.given;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.containsStringIgnoringCase;

import com.example.ai.model.request.AIChatRequest;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
Expand All @@ -25,30 +26,32 @@ void setUp() {

@Test
void testChat() {
given().param("question", "Hello?")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Hello?"))
.when()
.get("/api/ai/chat")
.post("/api/ai/chat")
.then()
.statusCode(200)
.body("question", containsStringIgnoringCase("Hello?"))
.body("answer", containsString("Hello!"));
}

@Test
void chatWithPrompt() {
given().param("subject", "java")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("java"))
.when()
.get("/api/ai/chat-with-prompt")
.post("/api/ai/chat-with-prompt")
.then()
.statusCode(200)
.body("answer", containsString("Java"));
}

@Test
void chatWithSystemPrompt() {
given().param("subject", "cricket")
given().contentType(ContentType.JSON)
.body(new AIChatRequest("cricket"))
.when()
.get("/api/ai/chat-with-system-prompt")
.post("/api/ai/chat-with-system-prompt")
.then()
.statusCode(200)
.body("answer", containsString("cricket"));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,20 @@
package com.learning.ai.llmragwithspringai.config;

import java.time.Duration;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
public class RestClientBuilderConfig {

@Bean
RestClient.Builder restClientBuilder(JdkClientHttpRequestFactory jdkClientHttpRequestFactory) {
return RestClient.builder().requestFactory(jdkClientHttpRequestFactory);
}

@Bean
JdkClientHttpRequestFactory jdkClientHttpRequestFactory() {
JdkClientHttpRequestFactory jdkClientHttpRequestFactory = new JdkClientHttpRequestFactory();
jdkClientHttpRequestFactory.setReadTimeout(Duration.ofMinutes(5));
return jdkClientHttpRequestFactory;
public RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder.requestFactory(
ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS
.withConnectTimeout(Duration.ofSeconds(60))
.withReadTimeout(Duration.ofMinutes(5))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -64,6 +65,7 @@ public String chat(String query) {
LOGGER.info("Calling ai with prompt :{}", prompt);
ChatResponse aiResponse = aiClient.call(prompt);
LOGGER.info("Response received from call :{}", aiResponse);
return aiResponse.getResult().getOutput().getContent();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import java.io.InputStream;
import java.util.Collections;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
Expand All @@ -12,15 +13,14 @@
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.RestClient;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
public class ResponseHeadersModification {

@Bean
RestClient.Builder restClientBuilder() {
return RestClient.builder().requestInterceptor((request, body, execution) -> {
public RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder.requestInterceptor((request, body, execution) -> {
ClientHttpResponse response = execution.execute(request, body);
return new CustomClientHttpResponse(response);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.stream.Collectors;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
Expand Down Expand Up @@ -55,6 +56,7 @@ public String chat(String searchQuery) {
UserMessage userMessage = new UserMessage(searchQuery);
Prompt prompt = new Prompt(List.of(systemMessage, userMessage));
ChatResponse aiResponse = aiClient.call(prompt);
return aiResponse.getResult().getOutput().getContent();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
}

0 comments on commit c894cd0

Please sign in to comment.