Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : upgrade to 1.0.0-M1 #61

Closed
wants to merge 10 commits into from
Closed
20 changes: 20 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@
"projectName": "neo4j-springai",
"args": "--spring.profiles.active=local",
"envFile": "${workspaceFolder}/.env"
},
{
"type": "java",
"name": "Spring Boot-TestLlmRagWithSpringAiApplication<rag-springai-openai-llm>",
"request": "launch",
"cwd": "${workspaceFolder}",
"mainClass": "com.learning.ai.llmragwithspringai.TestLlmRagWithSpringAiApplication",
"projectName": "rag-springai-openai-llm",
"args": "",
"envFile": "${workspaceFolder}/.env"
},
{
"type": "java",
"name": "Spring Boot-LlmRagWithSpringAiApplication<rag-springai-openai-llm>",
"request": "launch",
"cwd": "${workspaceFolder}",
"mainClass": "com.learning.ai.llmragwithspringai.LlmRagWithSpringAiApplication",
"projectName": "rag-springai-openai-llm",
"args": "--spring.profiles.active=local",
"envFile": "${workspaceFolder}/.env"
}
]
}
2 changes: 1 addition & 1 deletion rag/rag-springai-openai-llm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

<properties>
<java.version>21</java.version>
<spring-ai.version>0.8.1</spring-ai.version>
<spring-ai.version>1.0.0-M1</spring-ai.version>
<spotless.version>2.43.0</spotless.version>
</properties>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,68 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.boot.web.client.RestClientCustomizer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpRequest;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StreamUtils;

@Configuration(proxyBeanMethods = false)
@ConditionalOnProperty(value = "spring.ai.openai.api-key", havingValue = "demo")
public class ResponseHeadersModification {

private static final Logger LOGGER = LoggerFactory.getLogger(ResponseHeadersModification.class);

@Bean
RestClientCustomizer restClientCustomizer() {
return restClientBuilder -> restClientBuilder.requestInterceptor((request, body, execution) -> {
ClientHttpResponse response = execution.execute(request, body);
return new CustomClientHttpResponse(response);
});
return restClientBuilder -> restClientBuilder
.requestFactory(new BufferingClientHttpRequestFactory(
ClientHttpRequestFactories.get(ClientHttpRequestFactorySettings.DEFAULTS)))
.requestInterceptor((request, body, execution) -> {
logRequest(request, body);
ClientHttpResponse response = execution.execute(request, body);
logResponse(response);
return new CustomClientHttpResponse(response);
})
.defaultHeaders(httpHeaders -> {
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.setAccept(List.of(MediaType.ALL));
});
}

private void logResponse(ClientHttpResponse response) throws IOException {
LOGGER.info("============================response begin==========================================");
LOGGER.info("Status code : {}", response.getStatusCode());
LOGGER.info("Status text : {}", response.getStatusText());
LOGGER.info("Headers : {}", response.getHeaders());
LOGGER.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
LOGGER.info("=======================response end=================================================");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enhance the logResponse method to handle potential exceptions more gracefully.

- LOGGER.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
+ try {
+     LOGGER.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
+ } catch (IOException e) {
+     LOGGER.error("Error reading response body", e);
+ }

This change is recommended to prevent the application from crashing in case of an IOException when reading the response body.

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
private void logResponse(ClientHttpResponse response) throws IOException {
LOGGER.info("============================response begin==========================================");
LOGGER.info("Status code : {}", response.getStatusCode());
LOGGER.info("Status text : {}", response.getStatusText());
LOGGER.info("Headers : {}", response.getHeaders());
LOGGER.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
LOGGER.info("=======================response end=================================================");
}
private void logResponse(ClientHttpResponse response) throws IOException {
LOGGER.info("============================response begin==========================================");
LOGGER.info("Status code : {}", response.getStatusCode());
LOGGER.info("Status text : {}", response.getStatusText());
LOGGER.info("Headers : {}", response.getHeaders());
try {
LOGGER.info("Response body: {}", StreamUtils.copyToString(response.getBody(), Charset.defaultCharset()));
} catch (IOException e) {
LOGGER.error("Error reading response body", e);
}
LOGGER.info("=======================response end=================================================");
}


private void logRequest(HttpRequest request, byte[] body) {

LOGGER.info("===========================request begin================================================");
LOGGER.info("URI : {}", request.getURI());
LOGGER.info("Method : {}", request.getMethod());
LOGGER.info("Headers : {}", request.getHeaders());
LOGGER.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
LOGGER.info("==========================request end================================================");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor the logRequest method to handle potential exceptions more gracefully.

- LOGGER.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
+ try {
+     LOGGER.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
+ } catch (Exception e) {
+     LOGGER.error("Error logging request body", e);
+ }

This change ensures that any exceptions during the logging of the request body are caught and handled appropriately.

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
private void logRequest(HttpRequest request, byte[] body) {
LOGGER.info("===========================request begin================================================");
LOGGER.info("URI : {}", request.getURI());
LOGGER.info("Method : {}", request.getMethod());
LOGGER.info("Headers : {}", request.getHeaders());
LOGGER.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
LOGGER.info("==========================request end================================================");
private void logRequest(HttpRequest request, byte[] body) {
LOGGER.info("===========================request begin================================================");
LOGGER.info("URI : {}", request.getURI());
LOGGER.info("Method : {}", request.getMethod());
LOGGER.info("Headers : {}", request.getHeaders());
try {
LOGGER.info("Request body: {}", new String(body, StandardCharsets.UTF_8));
} catch (Exception e) {
LOGGER.error("Error logging request body", e);
}
LOGGER.info("==========================request end================================================");

}

private static class CustomClientHttpResponse implements ClientHttpResponse {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package com.learning.ai.llmragwithspringai.service;

import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.document.Document;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.stereotype.Service;

Expand Down Expand Up @@ -40,8 +40,15 @@ with one player from the fielding team (the bowler) bowling the ball towards the
private final ChatClient aiClient;
private final VectorStore vectorStore;

public AIChatService(ChatClient aiClient, VectorStore vectorStore) {
this.aiClient = aiClient;
public AIChatService(ChatClient.Builder modelBuilder, VectorStore vectorStore) {
this.aiClient = modelBuilder
.defaultSystem(template)
.defaultAdvisors(
new PromptChatMemoryAdvisor(new InMemoryChatMemory()),
// new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY
new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) // RAG
.defaultFunctions("currentDateFunction") // FUNCTION CALLING
.build();
this.vectorStore = vectorStore;
}

Expand All @@ -53,12 +60,12 @@ public String chat(String searchQuery) {
.collect(Collectors.joining(System.lineSeparator()));
// Constructing the systemMessage to indicate the AI model to use the passed information
// to answer the question.
Message systemMessage = new SystemPromptTemplate(template).createMessage(Map.of("documents", documents));
UserMessage userMessage = new UserMessage(searchQuery);
OpenAiChatOptions chatOptions =
OpenAiChatOptions.builder().withFunction("currentDateFunction").build();
Prompt prompt = new Prompt(List.of(systemMessage, userMessage), chatOptions);
ChatResponse aiResponse = aiClient.call(prompt);
ChatResponse aiResponse = aiClient.prompt()
.system(sp -> sp.param("documents", documents))
.user(searchQuery)
.advisors(a -> a.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100))
.call()
.chatResponse();
Generation generation = aiResponse.getResult();
return (generation != null) ? generation.getOutput().getContent() : "";
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
spring.datasource.password=secret
spring.datasource.username=appuser
spring.datasource.url=jdbc:postgresql://localhost/appdb

# default value of openai is 1536
spring.ai.vectorstore.pgvector.dimensions=384
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ spring.ai.openai.chat.options.model=gpt-3.5-turbo
spring.ai.openai.chat.options.temperature=0.2
spring.ai.openai.chat.options.responseFormat=json_object

spring.ai.openai.embedding.options.model=text-embedding-ada-002

#spring.ai.openai.image.model=dall-e-3
Loading