From 21f9420743372126b07015521c7bb9836784f983 Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Mon, 3 Jun 2024 05:45:34 +0000 Subject: [PATCH] Update RAG sample --- .../service/AIChatService.java | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java index 3a06340..05656d8 100644 --- a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java +++ b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java @@ -1,17 +1,17 @@ package com.learning.ai.llmragwithspringai.service; +import static org.springframework.ai.chat.client.advisor.AbstractChatMemoryAdvisor.CHAT_MEMORY_RETRIEVE_SIZE_KEY; + import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import org.springframework.ai.chat.client.ChatClient; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.client.advisor.PromptChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; +import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; -import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; -import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; @@ -40,8 +40,15 @@ with one player from the fielding team (the bowler) bowling the ball towards the private final ChatClient aiClient; private final VectorStore vectorStore; - public AIChatService(ChatClient.Builder chatClientBuilder, VectorStore vectorStore) { - this.aiClient = chatClientBuilder.build(); + public AIChatService(ChatClient.Builder modelBuilder, VectorStore vectorStore) { + this.aiClient = modelBuilder + .defaultSystem(template) + .defaultAdvisors( + new PromptChatMemoryAdvisor(new InMemoryChatMemory()), + // new MessageChatMemoryAdvisor(chatMemory), // CHAT MEMORY + new QuestionAnswerAdvisor(vectorStore, SearchRequest.defaults())) // RAG + .defaultFunctions("currentDateFunction") // FUNCTION CALLING + .build(); this.vectorStore = vectorStore; } @@ -53,12 +60,12 @@ public String chat(String searchQuery) { .collect(Collectors.joining(System.lineSeparator())); // Constructing the systemMessage to indicate the AI model to use the passed information // to answer the question. - Message systemMessage = new SystemPromptTemplate(template).createMessage(Map.of("documents", documents)); - UserMessage userMessage = new UserMessage(searchQuery); - OpenAiChatOptions chatOptions = - OpenAiChatOptions.builder().withFunction("currentDateFunction").build(); - Prompt prompt = new Prompt(List.of(systemMessage, userMessage), chatOptions); - ChatResponse aiResponse = aiClient.prompt(prompt).call().chatResponse(); + ChatResponse aiResponse = aiClient.prompt() + .system(sp -> sp.param("documents", documents)) + .user(searchQuery) + .advisors(a -> a.param(CHAT_MEMORY_RETRIEVE_SIZE_KEY, 100)) + .call() + .chatResponse(); Generation generation = aiResponse.getResult(); return (generation != null) ? generation.getOutput().getContent() : ""; }