From c894cd0a7ebe5051e70a4e16b7d0ddfb6ac0fae5 Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Sun, 7 Apr 2024 10:06:32 +0530 Subject: [PATCH] feat : polish spring chat model (#36) * feat : polish spring chat model * remove running on mac and windows * feat: using customizer to increase read and connect timeout * feat : correct way of customizing restClient --- .github/workflows/rag-springai-ollama-llm.yml | 2 +- chatmodel-springai/pom.xml | 24 +++++++++++ .../com/example/ai/config/LoggingConfig.java | 6 +-- .../example/ai/controller/ChatController.java | 41 ++++++++++++------- .../ai/model/request/AIChatRequest.java | 3 ++ .../src/main/resources/application.properties | 2 +- .../ai/controller/ChatControllerTest.java | 19 +++++---- .../config/RestClientBuilderConfig.java | 19 ++++----- .../service/AIChatService.java | 4 +- .../config/ResponseHeadersModification.java | 6 +-- .../service/AIChatService.java | 4 +- 11 files changed, 87 insertions(+), 43 deletions(-) create mode 100644 chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java diff --git a/.github/workflows/rag-springai-ollama-llm.yml b/.github/workflows/rag-springai-ollama-llm.yml index 177c0ae..5f879ac 100644 --- a/.github/workflows/rag-springai-ollama-llm.yml +++ b/.github/workflows/rag-springai-ollama-llm.yml @@ -25,7 +25,7 @@ jobs: matrix: distribution: [ 'temurin' ] java: [ '21' ] - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest] steps: - uses: actions/checkout@v4 with: diff --git a/chatmodel-springai/pom.xml b/chatmodel-springai/pom.xml index 1a84adf..870ebdb 100644 --- a/chatmodel-springai/pom.xml +++ b/chatmodel-springai/pom.xml @@ -16,6 +16,7 @@ 17 0.8.1 + 2.43.0 @@ -72,6 +73,29 @@ org.springframework.boot spring-boot-maven-plugin + + com.diffplug.spotless + spotless-maven-plugin + ${spotless.version} + + + + 2.40.0 + + + + + + + + + compile + + check + + + + diff --git a/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java b/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java index 9017df6..1aa66d2 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java +++ b/chatmodel-springai/src/main/java/com/example/ai/config/LoggingConfig.java @@ -7,6 +7,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.web.client.RestClientCustomizer; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpRequest; @@ -15,7 +16,6 @@ import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.util.StreamUtils; -import org.springframework.web.client.RestClient; @Configuration(proxyBeanMethods = false) @ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo") @@ -24,8 +24,8 @@ public class LoggingConfig { private static final Logger LOGGER = LoggerFactory.getLogger(LoggingConfig.class); @Bean - RestClient.Builder restClientBuilder() { - return RestClient.builder() + public RestClientCustomizer restClientCustomizer() { + return restClientBuilder -> restClientBuilder .requestFactory(new BufferingClientHttpRequestFactory(new HttpComponentsClientHttpRequestFactory())) .requestInterceptor((request, body, execution) -> { logRequest(request, body); diff --git a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java index 15b1d52..8e80cd5 100644 --- a/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java +++ b/chatmodel-springai/src/main/java/com/example/ai/controller/ChatController.java @@ -1,17 +1,20 @@ package com.example.ai.controller; +import com.example.ai.model.request.AIChatRequest; import com.example.ai.model.response.AIChatResponse; import java.util.List; import java.util.Map; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.PromptTemplate; -import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; @RestController @@ -20,32 +23,42 @@ public class ChatController { private final ChatClient chatClient; - ChatController(ChatClient chatClient) { + private final EmbeddingClient embeddingClient; + + ChatController(ChatClient chatClient, EmbeddingClient embeddingClient) { this.chatClient = chatClient; + this.embeddingClient = embeddingClient; } - @GetMapping("/chat") - Map chat(@RequestParam String question) { - var response = chatClient.call(question); - return Map.of("question", question, "answer", response); + @PostMapping("/chat") + AIChatResponse chat(@RequestBody AIChatRequest aiChatRequest) { + var answer = chatClient.call(aiChatRequest.query()); + return new AIChatResponse(answer); } - @GetMapping("/chat-with-prompt") - AIChatResponse chatWithPrompt(@RequestParam String subject) { + @PostMapping("/chat-with-prompt") + AIChatResponse chatWithPrompt(@RequestBody AIChatRequest aiChatRequest) { PromptTemplate promptTemplate = new PromptTemplate("Tell me a joke about {subject}"); - Prompt prompt = promptTemplate.create(Map.of("subject", subject)); + Prompt prompt = promptTemplate.create(Map.of("subject", aiChatRequest.query())); ChatResponse response = chatClient.call(prompt); - String answer = response.getResult().getOutput().getContent(); + Generation generation = response.getResult(); + String answer = (generation != null) ? generation.getOutput().getContent() : ""; return new AIChatResponse(answer); } - @GetMapping("/chat-with-system-prompt") - AIChatResponse chatWithSystemPrompt(@RequestParam String subject) { + @PostMapping("/chat-with-system-prompt") + AIChatResponse chatWithSystemPrompt(@RequestBody AIChatRequest aiChatRequest) { SystemMessage systemMessage = new SystemMessage("You are a sarcastic and funny chatbot"); - UserMessage userMessage = new UserMessage("Tell me a joke about " + subject); + UserMessage userMessage = new UserMessage("Tell me a joke about " + aiChatRequest.query()); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); ChatResponse response = chatClient.call(prompt); String answer = response.getResult().getOutput().getContent(); return new AIChatResponse(answer); } + + @PostMapping("/emebedding-client-conversion") + AIChatResponse chatWithEmbeddingClient(@RequestBody AIChatRequest aiChatRequest) { + List embed = embeddingClient.embed(aiChatRequest.query()); + return new AIChatResponse(embed.toString()); + } } diff --git a/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java b/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java new file mode 100644 index 0000000..b0d7c3c --- /dev/null +++ b/chatmodel-springai/src/main/java/com/example/ai/model/request/AIChatRequest.java @@ -0,0 +1,3 @@ +package com.example.ai.model.request; + +public record AIChatRequest(String query) {} diff --git a/chatmodel-springai/src/main/resources/application.properties b/chatmodel-springai/src/main/resources/application.properties index e11ba56..8f173ff 100644 --- a/chatmodel-springai/src/main/resources/application.properties +++ b/chatmodel-springai/src/main/resources/application.properties @@ -6,7 +6,7 @@ spring.ai.openai.chat.options.model=gpt-3.5-turbo spring.ai.openai.chat.options.temperature=0.2 spring.ai.openai.chat.options.responseFormat=json_object -spring.ai.openai.embedding.enabled=false +spring.ai.openai.embedding.enabled=true ##logging logging.level.org.apache.hc.client5.http=INFO diff --git a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java index 1ef927a..973f72f 100644 --- a/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java +++ b/chatmodel-springai/src/test/java/com/example/ai/controller/ChatControllerTest.java @@ -2,9 +2,10 @@ import static io.restassured.RestAssured.given; import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.containsStringIgnoringCase; +import com.example.ai.model.request.AIChatRequest; import io.restassured.RestAssured; +import io.restassured.http.ContentType; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -25,20 +26,21 @@ void setUp() { @Test void testChat() { - given().param("question", "Hello?") + given().contentType(ContentType.JSON) + .body(new AIChatRequest("Hello?")) .when() - .get("/api/ai/chat") + .post("/api/ai/chat") .then() .statusCode(200) - .body("question", containsStringIgnoringCase("Hello?")) .body("answer", containsString("Hello!")); } @Test void chatWithPrompt() { - given().param("subject", "java") + given().contentType(ContentType.JSON) + .body(new AIChatRequest("java")) .when() - .get("/api/ai/chat-with-prompt") + .post("/api/ai/chat-with-prompt") .then() .statusCode(200) .body("answer", containsString("Java")); @@ -46,9 +48,10 @@ void chatWithPrompt() { @Test void chatWithSystemPrompt() { - given().param("subject", "cricket") + given().contentType(ContentType.JSON) + .body(new AIChatRequest("cricket")) .when() - .get("/api/ai/chat-with-system-prompt") + .post("/api/ai/chat-with-system-prompt") .then() .statusCode(200) .body("answer", containsString("cricket")); diff --git a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/RestClientBuilderConfig.java b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/RestClientBuilderConfig.java index 726829e..9407e24 100644 --- a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/RestClientBuilderConfig.java +++ b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/config/RestClientBuilderConfig.java @@ -1,23 +1,20 @@ package com.learning.ai.llmragwithspringai.config; import java.time.Duration; +import org.springframework.boot.web.client.ClientHttpRequestFactories; +import org.springframework.boot.web.client.ClientHttpRequestFactorySettings; +import org.springframework.boot.web.client.RestClientCustomizer; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.http.client.JdkClientHttpRequestFactory; -import org.springframework.web.client.RestClient; @Configuration(proxyBeanMethods = false) public class RestClientBuilderConfig { @Bean - RestClient.Builder restClientBuilder(JdkClientHttpRequestFactory jdkClientHttpRequestFactory) { - return RestClient.builder().requestFactory(jdkClientHttpRequestFactory); - } - - @Bean - JdkClientHttpRequestFactory jdkClientHttpRequestFactory() { - JdkClientHttpRequestFactory jdkClientHttpRequestFactory = new JdkClientHttpRequestFactory(); - jdkClientHttpRequestFactory.setReadTimeout(Duration.ofMinutes(5)); - return jdkClientHttpRequestFactory; + public RestClientCustomizer restClientCustomizer() { + return restClientBuilder -> restClientBuilder.requestFactory( + ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS + .withConnectTimeout(Duration.ofSeconds(60)) + .withReadTimeout(Duration.ofMinutes(5)))); } } diff --git a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java index 2b62d9c..2030415 100644 --- a/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java +++ b/rag/rag-springai-ollama-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java @@ -7,6 +7,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -64,6 +65,7 @@ public String chat(String query) { LOGGER.info("Calling ai with prompt :{}", prompt); ChatResponse aiResponse = aiClient.call(prompt); LOGGER.info("Response received from call :{}", aiResponse); - return aiResponse.getResult().getOutput().getContent(); + Generation generation = aiResponse.getResult(); + return (generation != null) ? generation.getOutput().getContent() : ""; } } diff --git a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/ResponseHeadersModification.java b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/ResponseHeadersModification.java index c808f4a..b9f0c2a 100644 --- a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/ResponseHeadersModification.java +++ b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/ResponseHeadersModification.java @@ -4,6 +4,7 @@ import java.io.InputStream; import java.util.Collections; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.web.client.RestClientCustomizer; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.http.HttpHeaders; @@ -12,15 +13,14 @@ import org.springframework.http.client.ClientHttpResponse; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; -import org.springframework.web.client.RestClient; @Configuration(proxyBeanMethods = false) @ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo") public class ResponseHeadersModification { @Bean - RestClient.Builder restClientBuilder() { - return RestClient.builder().requestInterceptor((request, body, execution) -> { + public RestClientCustomizer restClientCustomizer() { + return restClientBuilder -> restClientBuilder.requestInterceptor((request, body, execution) -> { ClientHttpResponse response = execution.execute(request, body); return new CustomClientHttpResponse(response); }); diff --git a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java index 66a76a5..7d48aa7 100644 --- a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java +++ b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java @@ -5,6 +5,7 @@ import java.util.stream.Collectors; import org.springframework.ai.chat.ChatClient; import org.springframework.ai.chat.ChatResponse; +import org.springframework.ai.chat.Generation; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; @@ -55,6 +56,7 @@ public String chat(String searchQuery) { UserMessage userMessage = new UserMessage(searchQuery); Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); ChatResponse aiResponse = aiClient.call(prompt); - return aiResponse.getResult().getOutput().getContent(); + Generation generation = aiResponse.getResult(); + return (generation != null) ? generation.getOutput().getContent() : ""; } }