feat : adds vector store to chatbot (#58)
* feat : adds vector store to chatbot

* Update application.properties

* upgrade to latest version

* Update pom.xml

* fix : wrong pom update

* adds sequence diagram

* tweak prompt for more data accuracy

* feat : removes test case which doesn't add context

* feat : return conversationalId

* reads conversational Id and passes to next question

* minor polish
rajadilipkolli authored May 27, 2024
1 parent f9f3956 commit b2692dd
Showing 10 changed files with 152 additions and 44 deletions.
26 changes: 26 additions & 0 deletions chatbot-ollama-springai/ReadMe.md
@@ -3,6 +3,8 @@

## Sequence Diagram

Before Vector Store

```mermaid
sequenceDiagram
participant User
@@ -19,4 +21,28 @@ participant ChatMemory
ChatbotService-->>ChatbotController: response
ChatbotController-->>User: response
```

After Vector Store

```mermaid
sequenceDiagram
participant User
participant ChatbotController
participant ChatbotService
participant ChatService
participant ChatMemory
participant VectorStore
User->>ChatbotController: POST /api/ai/chat
ChatbotController->>ChatbotService: chat(message)
ChatbotService->>ChatService: Process Chat Request
ChatService->>ChatMemory: Retrieve Memory
ChatService->>VectorStore: Retrieve Vectors
ChatMemory-->>ChatService: Return Memory Data
VectorStore-->>ChatService: Return Vector Data
ChatService-->>ChatbotService: Processed Response
ChatbotService-->>ChatbotController: response
ChatbotController-->>User: response
```
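
For reference, a minimal client-side sketch of the conversational flow shown above, assuming the app runs locally on port 8080: the first request omits `conversationId`, and the follow-up reuses the id returned in the response so ChatMemory and the VectorStore can supply earlier context. Host, port, and payload values are illustrative assumptions, not part of the commit.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ChatClientSketch {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();

        // First question: no conversationId yet, so the service falls back to "default"
        HttpRequest first = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/ai/chat")) // assumed host/port
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(
                        "{\"query\":\"How many centuries did Sachin Tendulkar score?\",\"conversationId\":null}"))
                .build();
        HttpResponse<String> firstResponse = client.send(first, HttpResponse.BodyHandlers.ofString());
        System.out.println(firstResponse.body()); // body carries "answer" and "conversationId"

        // Follow-up question reuses the returned conversationId (hard-coded here for brevity;
        // extract it from the first response with a JSON library in real code)
        HttpRequest followUp = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/ai/chat"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(
                        "{\"query\":\"How many ODI centuries did he score?\",\"conversationId\":\"default\"}"))
                .build();
        System.out.println(client.send(followUp, HttpResponse.BodyHandlers.ofString()).body());
    }
}
```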
18 changes: 11 additions & 7 deletions chatbot-ollama-springai/pom.xml
@@ -10,7 +10,7 @@
<relativePath /> <!-- lookup parent from repository -->
</parent>
<groupId>com.example.chatbot</groupId>
<artifactId>chatbot-ollama-lamma3</artifactId>
<artifactId>chatbot-ollama-springai</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>chatbot-ollama</name>
<description>Demo project for Chatbot using Ollama</description>
@@ -34,6 +34,10 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-chroma-store-spring-boot-starter</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
@@ -45,11 +49,6 @@
<artifactId>rest-assured</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-spring-boot-testcontainers</artifactId>
@@ -65,6 +64,11 @@
<artifactId>ollama</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>chromadb</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<dependencyManagement>
@@ -92,7 +96,7 @@
<configuration>
<java>
<palantirJavaFormat>
<version>2.41.0</version>
<version>2.47.0</version>
</palantirJavaFormat>
<importOrder />
<removeUnusedImports />
@@ -1,9 +1,26 @@
package com.example.chatbot.config;

import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.ChatMemoryChatServiceListener;
import org.springframework.ai.chat.memory.ChatMemoryRetriever;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.memory.LastMaxTokenSizeContentTransformer;
import org.springframework.ai.chat.memory.SystemPromptChatMemoryAugmentor;
import org.springframework.ai.chat.memory.VectorStoreChatMemoryChatServiceListener;
import org.springframework.ai.chat.memory.VectorStoreChatMemoryRetriever;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.transformer.QuestionContextAugmentor;
import org.springframework.ai.chat.prompt.transformer.TransformerContentType;
import org.springframework.ai.chat.prompt.transformer.VectorStoreRetriever;
import org.springframework.ai.chat.service.ChatService;
import org.springframework.ai.chat.service.PromptTransformingChatService;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@@ -19,4 +36,44 @@ ChatMemory chatHistory() {
TokenCountEstimator tokenCountEstimator() {
return new JTokkitTokenCountEstimator();
}

@Bean
ChatService chatService(
ChatModel chatModel,
ChatMemory chatMemory,
TokenCountEstimator tokenCountEstimator,
VectorStore vectorStore) {
return PromptTransformingChatService.builder(chatModel)
.withRetrievers(List.of(
new VectorStoreRetriever(vectorStore, SearchRequest.defaults()),
ChatMemoryRetriever.builder()
.withChatHistory(chatMemory)
.withMetadata(Map.of(TransformerContentType.SHORT_TERM_MEMORY, ""))
.build(),
new VectorStoreChatMemoryRetriever(
vectorStore, 10, Map.of(TransformerContentType.LONG_TERM_MEMORY, ""))))
.withContentPostProcessors(List.of(
new LastMaxTokenSizeContentTransformer(
tokenCountEstimator, 1000, Set.of(TransformerContentType.SHORT_TERM_MEMORY)),
new LastMaxTokenSizeContentTransformer(
tokenCountEstimator, 1000, Set.of(TransformerContentType.LONG_TERM_MEMORY)),
new LastMaxTokenSizeContentTransformer(
tokenCountEstimator, 2000, Set.of(TransformerContentType.EXTERNAL_KNOWLEDGE))))
.withAugmentors(List.of(
new QuestionContextAugmentor(),
new SystemPromptChatMemoryAugmentor(
"""
Use the long term conversation history from the LONG TERM HISTORY section to provide accurate answers.
LONG TERM HISTORY:
{history}
""",
Set.of(TransformerContentType.LONG_TERM_MEMORY)),
new SystemPromptChatMemoryAugmentor(Set.of(TransformerContentType.SHORT_TERM_MEMORY))))
.withChatServiceListeners(List.of(
new ChatMemoryChatServiceListener(chatMemory),
new VectorStoreChatMemoryChatServiceListener(
vectorStore, Map.of(TransformerContentType.LONG_TERM_MEMORY, ""))))
.build();
}
}
@@ -20,6 +20,6 @@ class ChatbotController {

@PostMapping("/chat")
AIChatResponse chat(@RequestBody AIChatRequest request) {
return chatbotService.chat(request.query());
return chatbotService.chat(request);
}
}
@@ -1,3 +1,3 @@
package com.example.chatbot.model.request;

public record AIChatRequest(String query) {}
public record AIChatRequest(String query, String conversationId) {}
@@ -1,3 +1,3 @@
package com.example.chatbot.model.response;

public record AIChatResponse(String answer) {}
public record AIChatResponse(String answer, String conversationId) {}
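
For illustration only, a small sketch of what the updated request and response records look like when serialized with Jackson (the same ObjectMapper the tests below use); the literal values are made up.

```java
import com.example.chatbot.model.request.AIChatRequest;
import com.example.chatbot.model.response.AIChatResponse;
import com.fasterxml.jackson.databind.ObjectMapper;

public class PayloadSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();

        // First request: conversationId is null, so the service falls back to "default"
        String requestJson = mapper.writeValueAsString(new AIChatRequest("Hello?", null));
        // -> {"query":"Hello?","conversationId":null}

        // The response now echoes a conversationId to pass with the next question
        AIChatResponse response = mapper.readValue(
                "{\"answer\":\"Hi, how can I help?\",\"conversationId\":\"default\"}",
                AIChatResponse.class);

        System.out.println(requestJson + " -> follow up with id " + response.conversationId());
    }
}
```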
@@ -1,36 +1,34 @@
package com.example.chatbot.service;

import com.example.chatbot.model.request.AIChatRequest;
import com.example.chatbot.model.response.AIChatResponse;
import java.util.List;
import org.springframework.ai.chat.memory.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.transformer.ChatServiceContext;
import org.springframework.ai.chat.service.ChatService;
import org.springframework.ai.chat.service.ChatServiceResponse;
import org.springframework.ai.chat.service.PromptTransformingChatService;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.stereotype.Service;

@Service
public class ChatbotService {

private static final Logger LOGGER = LoggerFactory.getLogger(ChatbotService.class);

private final ChatService chatService;

ChatbotService(ChatModel chatModel, ChatMemory chatMemory, TokenCountEstimator tokenCountEstimator) {
this.chatService = PromptTransformingChatService.builder(chatModel)
.withRetrievers(List.of(new ChatMemoryRetriever(chatMemory)))
.withContentPostProcessors(List.of(new LastMaxTokenSizeContentTransformer(tokenCountEstimator, 1000)))
.withAugmentors(List.of(new SystemPromptChatMemoryAugmentor()))
.withChatServiceListeners(List.of(new ChatMemoryChatServiceListener(chatMemory)))
.build();
public ChatbotService(ChatService chatService) {
this.chatService = chatService;
}

public AIChatResponse chat(String message) {
Prompt prompt = new Prompt(new UserMessage(message));
ChatServiceResponse chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt));
public AIChatResponse chat(AIChatRequest request) {
Prompt prompt = new Prompt(new UserMessage(request.query()));
String conversationId = request.conversationId() == null ? "default" : request.conversationId();
ChatServiceResponse chatServiceResponse = this.chatService.call(new ChatServiceContext(prompt, conversationId));
LOGGER.info("Response :{}", chatServiceResponse.getChatResponse().getResult());
return new AIChatResponse(
chatServiceResponse.getChatResponse().getResult().getOutput().getContent());
chatServiceResponse.getChatResponse().getResult().getOutput().getContent(),
chatServiceResponse.getPromptContext().getConversationId());
}
}
@@ -1,5 +1,9 @@
spring.application.name=chatbot-ollama

spring.threads.virtual.enabled=true
spring.mvc.problemdetails.enabled=true

spring.ai.ollama.chat.options.model=llama3
spring.ai.ollama.embedding.options.model=llama3

spring.threads.virtual.enabled=true
spring.testcontainers.beans.startup=parallel
@@ -4,12 +4,19 @@
import static org.hamcrest.Matchers.containsString;

import com.example.chatbot.model.request.AIChatRequest;
import com.example.chatbot.model.response.AIChatResponse;
import com.fasterxml.jackson.core.exc.StreamReadException;
import com.fasterxml.jackson.databind.DatabindException;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.restassured.RestAssured;
import io.restassured.http.ContentType;
import io.restassured.response.Response;
import java.io.IOException;
import org.apache.http.HttpStatus;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;

@@ -19,6 +26,9 @@
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class ChatbotOllamaApplicationTests {

@Autowired
private ObjectMapper objectMapper;

@LocalServerPort
private int localServerPort;

@@ -28,35 +38,37 @@ void setUp() {
}

@Test
void contextLoads() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("Hello?"))
.when()
.post("/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("help"));
}
void chat() throws StreamReadException, DatabindException, IOException {

@Test
void chat() {
given().contentType(ContentType.JSON)
.body(new AIChatRequest("How many cricket centuries did Sachin Tendulkar scored ?"))
Response response = given().contentType(ContentType.JSON)
.body(new AIChatRequest(
"As a cricketer, how many centuries did Sachin Tendulkar scored adding up both One Day International (ODI) and Test centuries ?",
null))
.when()
.post("/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("100"));
.body("answer", containsString("100"))
.log()
.all(true)
.extract()
.response();

AIChatResponse aiChatResponse = objectMapper.readValue(response.asByteArray(), AIChatResponse.class);
System.out.println("conversationalId :: " + aiChatResponse.conversationId());

given().contentType(ContentType.JSON)
.body(new AIChatRequest("What is his age ?"))
.body(new AIChatRequest(
"How many One Day International (ODI) centuries did he scored ?",
aiChatResponse.conversationId()))
.when()
.post("/api/ai/chat")
.then()
.statusCode(HttpStatus.SC_OK)
.contentType(ContentType.JSON)
.body("answer", containsString("50"));
.body("answer", containsString("49"))
.log()
.all(true);
}
}
@@ -4,6 +4,7 @@
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.testcontainers.service.connection.ServiceConnection;
import org.springframework.context.annotation.Bean;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.ollama.OllamaContainer;
import org.testcontainers.utility.DockerImageName;

@@ -17,6 +18,12 @@ OllamaContainer ollama() {
DockerImageName.parse("langchain4j/ollama-llama3:latest").asCompatibleSubstituteFor("ollama/ollama"));
}

@Bean
@ServiceConnection
ChromaDBContainer chromadb() {
return new ChromaDBContainer(DockerImageName.parse("chromadb/chroma").withTag("0.5.0"));
}

public static void main(String[] args) {
SpringApplication.from(ChatbotOllamaApplication::main)
.with(TestChatbotOllamaApplication.class)