From d9a1d416ae7bba6edc9a8cc4cda649aba0ce0370 Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Fri, 10 May 2024 16:29:12 +0530 Subject: [PATCH] feat : adds metadata filtering (#48) --- .../com/learning/ai/config/Initializer.java | 7 +++++- .../com/learning/ai/config/SwaggerConfig.java | 22 ++++++++--------- .../ai/controller/QueryController.java | 4 ++-- .../ai/service/PgVectorStoreService.java | 24 +++++++++++++------ .../ai/controller/TestQueryController.java | 22 ++++++++++++++++- .../ai/service/CustomerSupportService.java | 11 ++++++--- .../config/FunctionConfiguration.java | 22 +++++++++++++++++ .../service/AIChatService.java | 5 +++- .../src/main/resources/application.properties | 2 +- 9 files changed, 91 insertions(+), 28 deletions(-) create mode 100644 rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/FunctionConfiguration.java diff --git a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/Initializer.java b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/Initializer.java index 5fc98f2..4465ee1 100644 --- a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/Initializer.java +++ b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/Initializer.java @@ -1,5 +1,6 @@ package com.learning.ai.config; +import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; @@ -20,10 +21,14 @@ public Initializer(EmbeddingModel embeddingModel, EmbeddingStore em @Override public void run(String... args) throws Exception { - TextSegment segment1 = TextSegment.from("I like football."); + TextSegment segment1 = TextSegment.from("I like football.", Metadata.metadata("userId", "1")); Embedding embedding1 = embeddingModel.embed(segment1).content(); embeddingStore.add(embedding1, segment1); + segment1 = TextSegment.from("I like cricket.", Metadata.metadata("userId", "2")); + embedding1 = embeddingModel.embed(segment1).content(); + embeddingStore.add(embedding1, segment1); + TextSegment segment2 = TextSegment.from("The weather is good today."); Embedding embedding2 = embeddingModel.embed(segment2).content(); embeddingStore.add(embedding2, segment2); diff --git a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/SwaggerConfig.java b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/SwaggerConfig.java index 8079283..adf18e8 100644 --- a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/SwaggerConfig.java +++ b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/config/SwaggerConfig.java @@ -1,12 +1,10 @@ -package com.learning.ai.config; - -import io.swagger.v3.oas.annotations.OpenAPIDefinition; -import io.swagger.v3.oas.annotations.info.Info; -import io.swagger.v3.oas.annotations.servers.Server; -import org.springframework.context.annotation.Configuration; - -@Configuration(proxyBeanMethods = false) -@OpenAPIDefinition( - info = @Info(title = "pgvector-langchain4j", version = "v1.0.0"), - servers = @Server(url = "/")) -public class SwaggerConfig {} +package com.learning.ai.config; + +import io.swagger.v3.oas.annotations.OpenAPIDefinition; +import io.swagger.v3.oas.annotations.info.Info; +import io.swagger.v3.oas.annotations.servers.Server; +import org.springframework.context.annotation.Configuration; + +@Configuration(proxyBeanMethods = false) +@OpenAPIDefinition(info = @Info(title = "pgvector-langchain4j", version = "v1.0.0"), servers = @Server(url = "/")) +public class SwaggerConfig {} diff --git a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/controller/QueryController.java b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/controller/QueryController.java index db99711..1921b89 100644 --- a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/controller/QueryController.java +++ b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/controller/QueryController.java @@ -18,7 +18,7 @@ public QueryController(PgVectorStoreService pgVectorStoreService) { } @GetMapping("/query") - AIChatResponse queryEmbeddedStore(@RequestParam String question) { - return pgVectorStoreService.queryEmbeddingStore(question); + AIChatResponse queryEmbeddedStore(@RequestParam String question, @RequestParam(required = false) Integer userId) { + return pgVectorStoreService.queryEmbeddingStore(question, userId); } } diff --git a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/service/PgVectorStoreService.java b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/service/PgVectorStoreService.java index e52ef61..d4d75b4 100644 --- a/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/service/PgVectorStoreService.java +++ b/embeddingstores/pgvector-langchain4j/src/main/java/com/learning/ai/service/PgVectorStoreService.java @@ -5,8 +5,11 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; -import java.util.List; +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; @@ -24,14 +27,21 @@ public PgVectorStoreService(EmbeddingModel embeddingModel, EmbeddingStore> relevant = embeddingStore.findRelevant(queryEmbedding, 1); - EmbeddingMatch embeddingMatch = relevant.get(0); + EmbeddingSearchRequest.EmbeddingSearchRequestBuilder embeddingSearchRequestBuilder = + EmbeddingSearchRequest.builder().queryEmbedding(queryEmbedding).maxResults(1); + if (userId != null) { + Filter equalTo = MetadataFilterBuilder.metadataKey("userId").isEqualTo(userId); + embeddingSearchRequestBuilder.filter(equalTo); + } + EmbeddingSearchRequest embeddingSearchRequest = embeddingSearchRequestBuilder.build(); + EmbeddingSearchResult relevant = embeddingStore.search(embeddingSearchRequest); + EmbeddingMatch embeddingMatch = relevant.matches().get(0); LOGGER.info("Score : {}", embeddingMatch.score()); // 0.8144288608390052 - LOGGER.info("Embedded Segment : {}", embeddingMatch.embedded()); - // I like football. - return new AIChatResponse(embeddingMatch.embedded().text()); + String answer = embeddingMatch.embedded().text(); + LOGGER.info("Embedded Segment : {}", answer); // I like football. + return new AIChatResponse(answer); } } diff --git a/embeddingstores/pgvector-langchain4j/src/test/java/com/learning/ai/controller/TestQueryController.java b/embeddingstores/pgvector-langchain4j/src/test/java/com/learning/ai/controller/TestQueryController.java index f0e4d33..be5eba2 100644 --- a/embeddingstores/pgvector-langchain4j/src/test/java/com/learning/ai/controller/TestQueryController.java +++ b/embeddingstores/pgvector-langchain4j/src/test/java/com/learning/ai/controller/TestQueryController.java @@ -6,14 +6,34 @@ import com.learning.ai.config.AbstractIntegrationTest; import org.hamcrest.Matchers; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; class TestQueryController extends AbstractIntegrationTest { @Test void queryEmbeddedStore() throws Exception { - mockMvc.perform(get("/api/ai/query").param("question", "What is your favourite sport")) + mockMvc.perform(get("/api/ai/query") + .param("question", "What is your favourite sport") + .param("userId", "1")) .andExpect(status().isOk()) .andExpect(jsonPath("$.answer", Matchers.is("I like football."))); } + + @Test + @Disabled("Fixed in later version of langchain4j > 0.30.0") + void queryEmbeddedStoreWithMetadata() throws Exception { + mockMvc.perform(get("/api/ai/query") + .param("question", "What is your favourite sport") + .param("userId", "2")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.answer", Matchers.is("I like cricket."))); + } + + @Test + void queryEmbeddedStoreWithOutMetadata() throws Exception { + mockMvc.perform(get("/api/ai/query").param("question", "How is weather today")) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.answer", Matchers.is("The weather is good today."))); + } } diff --git a/rag/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java b/rag/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java index 7b1e4d1..2f0cca0 100644 --- a/rag/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java +++ b/rag/rag-langchain4j-AllMiniLmL6V2-llm/src/main/java/com/learning/ai/service/CustomerSupportService.java @@ -5,8 +5,9 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; -import java.util.List; import org.springframework.stereotype.Service; @Service @@ -28,8 +29,12 @@ public CustomerSupportService( public AICustomerSupportResponse chat(String question) { Embedding queryEmbedding = embeddingModel.embed(question).content(); - List> relevant = embeddingStore.findRelevant(queryEmbedding, 1); - EmbeddingMatch embeddingMatch = relevant.get(0); + EmbeddingSearchRequest embeddingSearchRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(queryEmbedding) + .maxResults(1) + .build(); + EmbeddingSearchResult relevant = embeddingStore.search(embeddingSearchRequest); + EmbeddingMatch embeddingMatch = relevant.matches().get(0); String embeddedText = embeddingMatch.embedded().text(); return aiCustomerSupportAgent.chat(question, embeddedText); diff --git a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/FunctionConfiguration.java b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/FunctionConfiguration.java new file mode 100644 index 0000000..9aa1a4e --- /dev/null +++ b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/config/FunctionConfiguration.java @@ -0,0 +1,22 @@ +package com.learning.ai.llmragwithspringai.config; + +import java.time.LocalDate; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; + +@Configuration(proxyBeanMethods = false) +public class FunctionConfiguration { + + private static final Logger log = LoggerFactory.getLogger(FunctionConfiguration.class); + + @Bean + @Description("Get the current date or as of today.") + Function currentDateFunction() { + log.info("fetching from function"); + return unused -> LocalDate.now(); + } +} diff --git a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java index 7d48aa7..403e5cb 100644 --- a/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java +++ b/rag/rag-springai-openai-llm/src/main/java/com/learning/ai/llmragwithspringai/service/AIChatService.java @@ -11,6 +11,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.document.Document; +import org.springframework.ai.openai.OpenAiChatOptions; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.stereotype.Service; @@ -54,7 +55,9 @@ public String chat(String searchQuery) { // to answer the question. Message systemMessage = new SystemPromptTemplate(template).createMessage(Map.of("documents", documents)); UserMessage userMessage = new UserMessage(searchQuery); - Prompt prompt = new Prompt(List.of(systemMessage, userMessage)); + OpenAiChatOptions chatOptions = + OpenAiChatOptions.builder().withFunction("currentDateFunction").build(); + Prompt prompt = new Prompt(List.of(systemMessage, userMessage), chatOptions); ChatResponse aiResponse = aiClient.call(prompt); Generation generation = aiResponse.getResult(); return (generation != null) ? generation.getOutput().getContent() : ""; diff --git a/rag/rag-springai-openai-llm/src/main/resources/application.properties b/rag/rag-springai-openai-llm/src/main/resources/application.properties index b375ce1..fe4dfac 100644 --- a/rag/rag-springai-openai-llm/src/main/resources/application.properties +++ b/rag/rag-springai-openai-llm/src/main/resources/application.properties @@ -6,7 +6,7 @@ spring.mvc.problemdetails.enabled=true spring.ai.openai.api-key=demo spring.ai.openai.base-url=http://langchain4j.dev/demo/openai spring.ai.openai.chat.options.model=gpt-3.5-turbo -spring.ai.openai.chat.options.temperature=0.7 +spring.ai.openai.chat.options.temperature=0.2 spring.ai.openai.chat.options.responseFormat=json_object #spring.ai.openai.image.model=dall-e-3