diff --git a/chatbot-ollama-springai/ReadMe.md b/chatbot-ollama-springai/ReadMe.md
index c351fb0..3af27ff 100644
--- a/chatbot-ollama-springai/ReadMe.md
+++ b/chatbot-ollama-springai/ReadMe.md
@@ -3,6 +3,8 @@
## Sequence Diagram
+Before Vector Store
+
```mermaid
sequenceDiagram
participant User
@@ -19,4 +21,28 @@ participant ChatMemory
ChatbotService-->>ChatbotController: response
ChatbotController-->>User: response
+```
+
+After Vector Store
+
+```mermaid
+sequenceDiagram
+ participant User
+ participant ChatbotController
+ participant ChatbotService
+ participant ChatService
+ participant ChatMemory
+ participant VectorStore
+
+ User->>ChatbotController: POST /api/ai/chat
+ ChatbotController->>ChatbotService: chat(message)
+ ChatbotService->>ChatService: Process Chat Request
+ ChatService->>ChatMemory: Retrieve Memory
+ ChatService->>VectorStore: Retrieve Vectors
+ ChatMemory-->>ChatService: Return Memory Data
+ VectorStore-->>ChatService: Return Vector Data
+ ChatService-->>ChatbotService: Processed Response
+ ChatbotService-->>ChatbotController: response
+ ChatbotController-->>User: response
+
```
\ No newline at end of file
diff --git a/chatbot-ollama-springai/pom.xml b/chatbot-ollama-springai/pom.xml
index d0da21e..279e116 100644
--- a/chatbot-ollama-springai/pom.xml
+++ b/chatbot-ollama-springai/pom.xml
@@ -10,7 +10,7 @@
com.example.chatbot
- chatbot-ollama-lamma3
+ chatbot-ollama-springai
0.0.1-SNAPSHOT
chatbot-ollama
Demo project for Chatbot using Ollama
@@ -34,6 +34,10 @@
org.springframework.ai
spring-ai-ollama-spring-boot-starter
+
+ org.springframework.ai
+ spring-ai-chroma-store-spring-boot-starter
+
org.springframework.boot
@@ -45,11 +49,6 @@
rest-assured
test
-
- org.springframework.boot
- spring-boot-testcontainers
- test
-
org.springframework.ai
spring-ai-spring-boot-testcontainers
@@ -65,6 +64,11 @@
ollama
test
+
+ org.testcontainers
+ chromadb
+ test
+
@@ -92,7 +96,7 @@
- 2.41.0
+ 2.47.0
diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java
index 2a7ceea..74525be 100644
--- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java
+++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/config/ChatConfig.java
@@ -1,9 +1,26 @@
package com.example.chatbot.config;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.ai.chat.memory.ChatMemoryChatServiceListener;
+import org.springframework.ai.chat.memory.ChatMemoryRetriever;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
+import org.springframework.ai.chat.memory.LastMaxTokenSizeContentTransformer;
+import org.springframework.ai.chat.memory.SystemPromptChatMemoryAugmentor;
+import org.springframework.ai.chat.memory.VectorStoreChatMemoryChatServiceListener;
+import org.springframework.ai.chat.memory.VectorStoreChatMemoryRetriever;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor;
+import org.springframework.ai.chat.prompt.transformer.TransformerContentType;
+import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever;
+import org.springframework.ai.chat.service.ChatService;
+import org.springframework.ai.chat.service.PromptTransformingChatService;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@@ -19,4 +36,44 @@ ChatMemory chatHistory() {
TokenCountEstimator tokenCountEstimator() {
return new JTokkitTokenCountEstimator();
}
+
+ @Bean
+ ChatService chatService(
+ ChatModel chatModel,
+ ChatMemory chatMemory,
+ TokenCountEstimator tokenCountEstimator,
+ VectorStore vectorStore) {
+ return PromptTransformingChatService.builder(chatModel)
+ .withRetrievers(List.of(
+ new VectorStoreRetriever(vectorStore, SearchRequest.defaults()),
+ ChatMemoryRetriever.builder()
+ .withChatHistory(chatMemory)
+ .withMetadata(Map.of(TransformerContentType.SHORT_TERM_MEMORY, ""))
+ .build(),
+ new VectorStoreChatMemoryRetriever(
+ vectorStore, 10, Map.of(TransformerContentType.LONG_TERM_MEMORY, ""))))
+ .withContentPostProcessors(List.of(
+ new LastMaxTokenSizeContentTransformer(
+ tokenCountEstimator, 1000, Set.of(TransformerContentType.SHORT_TERM_MEMORY)),
+ new LastMaxTokenSizeContentTransformer(
+ tokenCountEstimator, 1000, Set.of(TransformerContentType.LONG_TERM_MEMORY)),
+ new LastMaxTokenSizeContentTransformer(
+ tokenCountEstimator, 2000, Set.of(TransformerContentType.EXTERNAL_KNOWLEDGE))))
+ .withAugmentors(List.of(
+ new QuestionContextAugmentor(),
+ new SystemPromptChatMemoryAugmentor(
+ """
+ Use the long term conversation history from the LONG TERM HISTORY section to provide accurate answers.
+
+ LONG TERM HISTORY:
+ {history}
+ """,
+ Set.of(TransformerContentType.LONG_TERM_MEMORY)),
+ new SystemPromptChatMemoryAugmentor(Set.of(TransformerContentType.SHORT_TERM_MEMORY))))
+ .withChatServiceListeners(List.of(
+ new ChatMemoryChatServiceListener(chatMemory),
+ new VectorStoreChatMemoryChatServiceListener(
+ vectorStore, Map.of(TransformerContentType.LONG_TERM_MEMORY, ""))))
+ .build();
+ }
}
diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java
index a77aed9..7231a0b 100644
--- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java
+++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/controller/ChatbotController.java
@@ -20,6 +20,6 @@ class ChatbotController {
@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest request) {
- return chatbotService.chat(request.query());
+ return chatbotService.chat(request);
}
}
diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java
index 34712bc..1726c2a 100644
--- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java
+++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/request/AIChatRequest.java
@@ -1,3 +1,3 @@
package com.example.chatbot.model.request;
-public record AIChatRequest(String query) {}
+public record AIChatRequest(String query, String conversationId) {}
diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java
index e9196bd..06133c8 100644
--- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java
+++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/model/response/AIChatResponse.java
@@ -1,3 +1,3 @@
package com.example.chatbot.model.response;
-public record AIChatResponse(String answer) {}
+public record AIChatResponse(String answer, String conversationId) {}
diff --git a/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java b/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java
index c53dc62..7c828db 100644
--- a/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java
+++ b/chatbot-ollama-springai/src/main/java/com/example/chatbot/service/ChatbotService.java
@@ -1,36 +1,34 @@
package com.example.chatbot.service;
+import com.example.chatbot.model.request.AIChatRequest;
import com.example.chatbot.model.response.AIChatResponse;
-import java.util.List;
-import org.springframework.ai.chat.memory.*;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.transformer.ChatServiceContext;
import org.springframework.ai.chat.service.ChatService;
import org.springframework.ai.chat.service.ChatServiceResponse;
-import org.springframework.ai.chat.service.PromptTransformingChatService;
-import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.stereotype.Service;
@Service
public class ChatbotService {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ChatbotService.class);
+
private final ChatService chatService;
- ChatbotService(ChatModel chatModel, ChatMemory chatMemory, TokenCountEstimator tokenCountEstimator) {
- this.chatService = PromptTransformingChatService.builder(chatModel)
- .withRetrievers(List.of(new ChatMemoryRetriever(chatMemory)))
- .withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
- .withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
- .withChatServiceListeners(List.of(new ChatMemoryChatServiceListener(chatMemory)))
- .build();
+ public ChatbotService(ChatService chatService) {
+ this.chatService = chatService;
}
- public AIChatResponse chat(String message) {
- Prompt prompt = new Prompt(new UserMessage(message));
- ChatServiceResponse chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt));
+ public AIChatResponse chat(AIChatRequest request) {
+ Prompt prompt = new Prompt(new UserMessage(request.query()));
+ String conversationId = request.conversationId() == null ? "default" : request.conversationId();
+ ChatServiceResponse chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt, conversationId));
+ LOGGER.info("Response :{}", chatServiceResponse.getChatResponse().getResult());
return new AIChatResponse(
- chatServiceResponse.getChatResponse().getResult().getOutput().getContent());
+ chatServiceResponse.getChatResponse().getResult().getOutput().getContent(),
+ chatServiceResponse.getPromptContext().getConversationId());
}
}
diff --git a/chatbot-ollama-springai/src/main/resources/application.properties b/chatbot-ollama-springai/src/main/resources/application.properties
index edce8bf..8823695 100644
--- a/chatbot-ollama-springai/src/main/resources/application.properties
+++ b/chatbot-ollama-springai/src/main/resources/application.properties
@@ -1,5 +1,9 @@
spring.application.name=chatbot-ollama
+spring.threads.virtual.enabled=true
+spring.mvc.problemdetails.enabled=true
+
spring.ai.ollama.chat.options.model=llama3
+spring.ai.ollama.embedding.options.model=llama3
-spring.threads.virtual.enabled=true
+spring.testcontainers.beans.startup=parallel
diff --git a/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java b/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java
index 30a0806..bee559e 100644
--- a/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java
+++ b/chatbot-ollama-springai/src/test/java/com/example/chatbot/ChatbotOllamaApplicationTests.java
@@ -4,12 +4,19 @@
import static org.hamcrest.Matchers.containsString;
import com.example.chatbot.model.request.AIChatRequest;
+import com.example.chatbot.model.response.AIChatResponse;
+import com.fasterxml.jackson.core.exc.StreamReadException;
+import com.fasterxml.jackson.databind.DatabindException;
+import com.fasterxml.jackson.databind.ObjectMapper;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
+import io.restassured.response.Response;
+import java.io.IOException;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
+import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;
@@ -19,6 +26,9 @@
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class ChatbotOllamaApplicationTests {
+ @Autowired
+ private ObjectMapper objectMapper;
+
@LocalServerPort
private int localServerPort;
@@ -28,35 +38,37 @@ void setUp() {
}
@Test
- void contextLoads() {
- given().contentType(ContentType.JSON)
- .body(new AIChatRequest("Hello?"))
- .when()
- .post("/api/ai/chat")
- .then()
- .statusCode(HttpStatus.SC_OK)
- .contentType(ContentType.JSON)
- .body("answer", containsString("help"));
- }
+ void chat() throws StreamReadException, DatabindException, IOException {
- @Test
- void chat() {
- given().contentType(ContentType.JSON)
- .body(new AIChatRequest("How many cricket centuries did Sachin Tendulkar scored ?"))
+ Response response = given().contentType(ContentType.JSON)
+ .body(new AIChatRequest(
+ "As a cricketer, how many centuries did Sachin Tendulkar scored adding up both One Day International (ODI) and Test centuries ?",
+ null))
.when()
.post("/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
- .body("answer", containsString("100"));
+ .body("answer", containsString("100"))
+ .log()
+ .all(true)
+ .extract()
+ .response();
+
+ AIChatResponse aiChatResponse = objectMapper.readValue(response.asByteArray(), AIChatResponse.class);
+ System.out.println("conversationalId :: " + aiChatResponse.conversationId());
given().contentType(ContentType.JSON)
- .body(new AIChatRequest("What is his age ?"))
+ .body(new AIChatRequest(
+ "How many One Day International (ODI) centuries did he scored ?",
+ aiChatResponse.conversationId()))
.when()
.post("/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
- .body("answer", containsString("50"));
+ .body("answer", containsString("49"))
+ .log()
+ .all(true);
}
}
diff --git a/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java b/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java
index 21ea7b2..043780d 100644
--- a/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java
+++ b/chatbot-ollama-springai/src/test/java/com/example/chatbot/TestChatbotOllamaApplication.java
@@ -4,6 +4,7 @@
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
+import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.ollama.OllamaContainer;
import org.testcontainers.utility.DockerImageName;
@@ -17,6 +18,12 @@ OllamaContainer ollama() {
DockerImageName.parse("langchain4j/ollama-llama3:latest").asCompatibleSubstituteFor("ollama/ollama"));
}
+ @Bean
+ @ServiceConnection
+ ChromaDBContainer chromadb() {
+ return new ChromaDBContainer(DockerImageName.parse("chromadb/chroma").withTag("0.5.0"));
+ }
+
public static void main(String[] args) {
SpringApplication.from(ChatbotOllamaApplication::main)
.with(TestChatbotOllamaApplication.class)