From b2692dda4a3187ee6eb3cbd02dc2e2bda9556c1d Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Mon, 27 May 2024 12:45:53 +0530 Subject: [PATCH] feat : adds vector store to chatbot (#58) * feat : adds vector store to chatbot * Update application.properties * upgrade to latest version * Update pom.xml * fix : wrong pom update * adds sequence diagram * tweak prompt for more data accuracy * feat : removes test case which doesn't add context * feat : return conversationId * reads conversationId and passes it to the next question * minor polish --- chatbot-ollama-springai/ReadMe.md | 26 +++++++++ chatbot-ollama-springai/pom.xml | 18 +++--- .../example/chatbot/config/ChatConfig.java | 57 +++++++++++++++++++ .../chatbot/controller/ChatbotController.java | 2 +- .../chatbot/model/request/AIChatRequest.java | 2 +- .../model/response/AIChatResponse.java | 2 +- .../chatbot/service/ChatbotService.java | 30 +++++----- .../src/main/resources/application.properties | 6 +- .../ChatbotOllamaApplicationTests.java | 46 +++++++++------ .../chatbot/TestChatbotOllamaApplication.java | 7 +++ 10 files changed, 152 insertions(+), 44 deletions(-) diff --git a/chatbot-ollama-springai/ReadMe.md b/chatbot-ollama-springai/ReadMe.md index c351fb0..3af27ff 100644 --- a/chatbot-ollama-springai/ReadMe.md +++ b/chatbot-ollama-springai/ReadMe.md @@ -3,6 +3,8 @@ ## Sequence Diagram +Before Vector Store + ```mermaid sequenceDiagram participant User @@ -19,4 +21,28 @@ participant ChatMemory ChatbotService-->>ChatbotController: response ChatbotController-->>User: response +``` + +After Vector Store + +```mermaid +sequenceDiagram + participant User + participant ChatbotController + participant ChatbotService + participant ChatService + participant ChatMemory + participant VectorStore + + User->>ChatbotController: POST /api/ai/chat + ChatbotController->>ChatbotService: chat(message) + ChatbotService->>ChatService: Process Chat Request + ChatService->>ChatMemory: Retrieve Memory + ChatService->>VectorStore: 
Retrieve Vectors + ChatMemory-->>ChatService: Return Memory Data + VectorStore-->>ChatService: Return Vector Data + ChatService-->>ChatbotService: Processed Response + ChatbotService-->>ChatbotController: response + ChatbotController-->>User: response + ``` \ No newline at end of file diff --git a/chatbot-ollama-springai/pom.xml b/chatbot-ollama-springai/pom.xml index d0da21e..279e116 100644 --- a/chatbot-ollama-springai/pom.xml +++ b/chatbot-ollama-springai/pom.xml @@ -10,7 +10,7 @@ com.example.chatbot - chatbot-ollama-lamma3 + chatbot-ollama-springai 0.0.1-SNAPSHOT chatbot-ollama Demo project for Chatbot using Ollama @@ -34,6 +34,10 @@ org.springframework.ai spring-ai-ollama-spring-boot-starter + + org.springframework.ai + spring-ai-chroma-store-spring-boot-starter + org.springframework.boot @@ -45,11 +49,6 @@ rest-assured test - - org.springframework.boot - spring-boot-testcontainers - test - org.springframework.ai spring-ai-spring-boot-testcontainers @@ -65,6 +64,11 @@ ollama test + + org.testcontainers + chromadb + test + @@ -92,7 +96,7 @@ - 2.41.0 + 2.47.0 diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java index 2a7ceea..74525be 100644 --- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java +++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java @@ -1,9 +1,26 @@ package com.example.chatbot.config; +import java.util.List; +import java.util.Map; +import java.util.Set; import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.ChatMemoryChatServiceListener; +import org.springframework.ai.chat.memory.ChatMemoryRetriever; import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.memory.LastMaxTokenSizeContentTransformer; +import org.springframework.ai.chat.memory.SystemPromptChatMemoryAugmentor; +import 
org.springframework.ai.chat.memory.VectorStoreChatMemoryChatServiceListener; +import org.springframework.ai.chat.memory.VectorStoreChatMemoryRetriever; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor; +import org.springframework.ai.chat.prompt.transformer.TransformerContentType; +import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever; +import org.springframework.ai.chat.service.ChatService; +import org.springframework.ai.chat.service.PromptTransformingChatService; import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator; import org.springframework.ai.tokenizer.TokenCountEstimator; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -19,4 +36,44 @@ ChatMemory chatHistory() { TokenCountEstimator tokenCountEstimator() { return new JTokkitTokenCountEstimator(); } + + @Bean + ChatService chatService( + ChatModel chatModel, + ChatMemory chatMemory, + TokenCountEstimator tokenCountEstimator, + VectorStore vectorStore) { + return PromptTransformingChatService.builder(chatModel) + .withRetrievers(List.of( + new VectorStoreRetriever(vectorStore, SearchRequest.defaults()), + ChatMemoryRetriever.builder() + .withChatHistory(chatMemory) + .withMetadata(Map.of(TransformerContentType.SHORT_TERM_MEMORY, "")) + .build(), + new VectorStoreChatMemoryRetriever( + vectorStore, 10, Map.of(TransformerContentType.LONG_TERM_MEMORY, "")))) + .withContentPostProcessors(List.of( + new LastMaxTokenSizeContentTransformer( + tokenCountEstimator, 1000, Set.of(TransformerContentType.SHORT_TERM_MEMORY)), + new LastMaxTokenSizeContentTransformer( + tokenCountEstimator, 1000, Set.of(TransformerContentType.LONG_TERM_MEMORY)), + new LastMaxTokenSizeContentTransformer( + tokenCountEstimator, 2000, 
Set.of(TransformerContentType.EXTERNAL_KNOWLEDGE)))) + .withAugmentors(List.of( + new QuestionContextAugmentor(), + new SystemPromptChatMemoryAugmentor( + """ + Use the long term conversation history from the LONG TERM HISTORY section to provide accurate answers. + + LONG TERM HISTORY: + {history} + """, + Set.of(TransformerContentType.LONG_TERM_MEMORY)), + new SystemPromptChatMemoryAugmentor(Set.of(TransformerContentType.SHORT_TERM_MEMORY)))) + .withChatServiceListeners(List.of( + new ChatMemoryChatServiceListener(chatMemory), + new VectorStoreChatMemoryChatServiceListener( + vectorStore, Map.of(TransformerContentType.LONG_TERM_MEMORY, "")))) + .build(); + } } diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java index a77aed9..7231a0b 100644 --- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java +++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java @@ -20,6 +20,6 @@ class ChatbotController { @PostMapping("/chat") AIChatResponse chat(@RequestBody AIChatRequest request) { - return chatbotService.chat(request.query()); + return chatbotService.chat(request); } } diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java index 34712bc..1726c2a 100644 --- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java +++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java @@ -1,3 +1,3 @@ package com.example.chatbot.model.request; -public record AIChatRequest(String query) {} +public record AIChatRequest(String query, String conversationId) {} diff --git 
a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java index e9196bd..06133c8 100644 --- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java +++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java @@ -1,3 +1,3 @@ package com.example.chatbot.model.response; -public record AIChatResponse(String answer) {} +public record AIChatResponse(String answer, String conversationId) {} diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java index c53dc62..7c828db 100644 --- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java +++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java @@ -1,36 +1,34 @@ package com.example.chatbot.service; +import com.example.chatbot.model.request.AIChatRequest; import com.example.chatbot.model.response.AIChatResponse; -import java.util.List; -import org.springframework.ai.chat.memory.*; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.transformer.ChatServiceContext; import org.springframework.ai.chat.service.ChatService; import org.springframework.ai.chat.service.ChatServiceResponse; -import org.springframework.ai.chat.service.PromptTransformingChatService; -import org.springframework.ai.tokenizer.TokenCountEstimator; import org.springframework.stereotype.Service; @Service public class ChatbotService { + private static final Logger LOGGER = LoggerFactory.getLogger(ChatbotService.class); + private final ChatService 
chatService; - ChatbotService(ChatModel chatModel, ChatMemory chatMemory, TokenCountEstimator tokenCountEstimator) { - this.chatService = PromptTransformingChatService.builder(chatModel) - .withRetrievers(List.of(new ChatMemoryRetriever(chatMemory))) - .withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000))) - .withAugmentors(List.of(new SystemPromptChatMemoryAugmentor())) - .withChatServiceListeners(List.of(new ChatMemoryChatServiceListener(chatMemory))) - .build(); + public ChatbotService(ChatService chatService) { + this.chatService = chatService; } - public AIChatResponse chat(String message) { - Prompt prompt = new Prompt(new UserMessage(message)); - ChatServiceResponse chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt)); + public AIChatResponse chat(AIChatRequest request) { + Prompt prompt = new Prompt(new UserMessage(request.query())); + String conversationId = request.conversationId() == null ? "default" : request.conversationId(); + ChatServiceResponse chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt, conversationId)); + LOGGER.info("Response :{}", chatServiceResponse.getChatResponse().getResult()); return new AIChatResponse( - chatServiceResponse.getChatResponse().getResult().getOutput().getContent()); + chatServiceResponse.getChatResponse().getResult().getOutput().getContent(), + chatServiceResponse.getPromptContext().getConversationId()); } } diff --git a/chatbot-ollama-springai/src/main/resources/application.properties b/chatbot-ollama-springai/src/main/resources/application.properties index edce8bf..8823695 100644 --- a/chatbot-ollama-springai/src/main/resources/application.properties +++ b/chatbot-ollama-springai/src/main/resources/application.properties @@ -1,5 +1,9 @@ spring.application.name=chatbot-ollama +spring.threads.virtual.enabled=true +spring.mvc.problemdetails.enabled=true + spring.ai.ollama.chat.options.model=llama3 
+spring.ai.ollama.embedding.options.model=llama3 -spring.threads.virtual.enabled=true +spring.testcontainers.beans.startup=parallel diff --git a/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java b/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java index 30a0806..bee559e 100644 --- a/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java +++ b/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java @@ -4,12 +4,19 @@ import static org.hamcrest.Matchers.containsString; import com.example.chatbot.model.request.AIChatRequest; +import com.example.chatbot.model.response.AIChatResponse; +import com.fasterxml.jackson.core.exc.StreamReadException; +import com.fasterxml.jackson.databind.DatabindException; +import com.fasterxml.jackson.databind.ObjectMapper; import io.restassured.RestAssured; import io.restassured.http.ContentType; +import io.restassured.response.Response; +import java.io.IOException; import org.apache.http.HttpStatus; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.web.server.LocalServerPort; @@ -19,6 +26,9 @@ @TestInstance(TestInstance.Lifecycle.PER_CLASS) class ChatbotOllamaApplicationTests { + @Autowired + private ObjectMapper objectMapper; + @LocalServerPort private int localServerPort; @@ -28,35 +38,37 @@ void setUp() { } @Test - void contextLoads() { - given().contentType(ContentType.JSON) - .body(new AIChatRequest("Hello?")) - .when() - .post("/api/ai/chat") - .then() - .statusCode(HttpStatus.SC_OK) - .contentType(ContentType.JSON) - .body("answer", containsString("help")); - } + void chat() throws StreamReadException, DatabindException, IOException { - @Test - void 
chat() { - given().contentType(ContentType.JSON) - .body(new AIChatRequest("How many cricket centuries did Sachin Tendulkar scored ?")) + Response response = given().contentType(ContentType.JSON) + .body(new AIChatRequest( + "As a cricketer, how many centuries did Sachin Tendulkar scored adding up both One Day International (ODI) and Test centuries ?", + null)) .when() .post("/api/ai/chat") .then() .statusCode(HttpStatus.SC_OK) .contentType(ContentType.JSON) - .body("answer", containsString("100")); + .body("answer", containsString("100")) + .log() + .all(true) + .extract() + .response(); + + AIChatResponse aiChatResponse = objectMapper.readValue(response.asByteArray(), AIChatResponse.class); + System.out.println("conversationalId :: " + aiChatResponse.conversationId()); given().contentType(ContentType.JSON) - .body(new AIChatRequest("What is his age ?")) + .body(new AIChatRequest( + "How many One Day International (ODI) centuries did he scored ?", + aiChatResponse.conversationId())) .when() .post("/api/ai/chat") .then() .statusCode(HttpStatus.SC_OK) .contentType(ContentType.JSON) - .body("answer", containsString("50")); + .body("answer", containsString("49")) + .log() + .all(true); } } diff --git a/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java b/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java index 21ea7b2..043780d 100644 --- a/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java +++ b/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java @@ -4,6 +4,7 @@ import org.springframework.boot.test.context.TestConfiguration; import org.springframework.boot.testcontainers.service.connection.ServiceConnection; import org.springframework.context.annotation.Bean; +import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.ollama.OllamaContainer; import 
org.testcontainers.utility.DockerImageName; @@ -17,6 +18,12 @@ OllamaContainer ollama() { DockerImageName.parse("langchain4j/ollama-llama3:latest").asCompatibleSubstituteFor("ollama/ollama")); } + @Bean + @ServiceConnection + ChromaDBContainer chromadb() { + return new ChromaDBContainer(DockerImageName.parse("chromadb/chroma").withTag("0.5.0")); + } + public static void main(String[] args) { SpringApplication.from(ChatbotOllamaApplication::main) .with(TestChatbotOllamaApplication.class)