From 03abbc739c3edbc1a1b90ecdb9783a3e19eafa6c Mon Sep 17 00:00:00 2001 From: Raja Kolli Date: Fri, 10 May 2024 15:39:37 +0530 Subject: [PATCH] feat : adds metadata filtering (#47) --- .../com/learning/ai/config/SwaggerConfig.java | 22 +++++++++---------- .../ai/controller/QueryController.java | 4 ++-- .../ai/service/PgVectorStoreService.java | 12 +++++++--- .../ai/controller/QueryControllerTest.java | 22 +++++++++++++++++++ 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/config/SwaggerConfig.java b/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/config/SwaggerConfig.java index 8bba440..c543547 100644 --- a/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/config/SwaggerConfig.java +++ b/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/config/SwaggerConfig.java @@ -1,12 +1,10 @@ -package com.learning.ai.config; - -import io.swagger.v3.oas.annotations.OpenAPIDefinition; -import io.swagger.v3.oas.annotations.info.Info; -import io.swagger.v3.oas.annotations.servers.Server; -import org.springframework.context.annotation.Configuration; - -@Configuration(proxyBeanMethods = false) -@OpenAPIDefinition( - info = @Info(title = "pgvector-springai", version = "v1.0.0"), - servers = @Server(url = "/")) -public class SwaggerConfig {} +package com.learning.ai.config; + +import io.swagger.v3.oas.annotations.OpenAPIDefinition; +import io.swagger.v3.oas.annotations.info.Info; +import io.swagger.v3.oas.annotations.servers.Server; +import org.springframework.context.annotation.Configuration; + +@Configuration(proxyBeanMethods = false) +@OpenAPIDefinition(info = @Info(title = "pgvector-springai", version = "v1.0.0"), servers = @Server(url = "/")) +public class SwaggerConfig {} diff --git a/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/controller/QueryController.java b/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/controller/QueryController.java index 874ab72..c36d6a5 100644 --- a/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/controller/QueryController.java +++ b/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/controller/QueryController.java @@ -18,7 +18,7 @@ public QueryController(PgVectorStoreService pgVectorStoreService) { } @GetMapping("/query") - AIChatResponse queryEmbeddedStore(@RequestParam String question) { - return pgVectorStoreService.queryEmbeddingStore(question); + AIChatResponse queryEmbeddedStore(@RequestParam String question, @RequestParam(required = false) Integer userId) { + return pgVectorStoreService.queryEmbeddingStore(question, userId); } } diff --git a/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/service/PgVectorStoreService.java b/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/service/PgVectorStoreService.java index 141d39b..444b403 100644 --- a/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/service/PgVectorStoreService.java +++ b/embeddingstores/pgvector-springai/src/main/java/com/learning/ai/service/PgVectorStoreService.java @@ -2,6 +2,7 @@ import com.learning.ai.model.response.AIChatResponse; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -23,14 +24,19 @@ public PgVectorStoreService(VectorStore vectorStore) { public void storeEmbeddings() { // Store embeddings - List documents = - List.of(new Document("I like football."), new Document("The weather is good today.")); + List documents = List.of( + new Document("I like football.", Map.of("userId", 1)), + new Document("I like cricket.", Map.of("userId", 2)), + new Document("The weather is good today.")); vectorStore.add(documents); } - public AIChatResponse queryEmbeddingStore(String question) { + public AIChatResponse queryEmbeddingStore(String question, Integer userId) { // Retrieve embeddings SearchRequest query = SearchRequest.query(question).withTopK(1); + if (userId != null) { + query.withFilterExpression("userId == " + userId); + } List similarDocuments = vectorStore.similaritySearch(query); String relevantData = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining(System.lineSeparator())); diff --git a/embeddingstores/pgvector-springai/src/test/java/com/learning/ai/controller/QueryControllerTest.java b/embeddingstores/pgvector-springai/src/test/java/com/learning/ai/controller/QueryControllerTest.java index 949cd47..7f5e214 100644 --- a/embeddingstores/pgvector-springai/src/test/java/com/learning/ai/controller/QueryControllerTest.java +++ b/embeddingstores/pgvector-springai/src/test/java/com/learning/ai/controller/QueryControllerTest.java @@ -28,10 +28,32 @@ void setUp() { @Test void queryEmbeddedStore() { given().param("question", "What is your favourite sport") + .param("userId", 1) .when() .get("/api/ai/query") .then() .statusCode(200) .body("answer", equalTo("I like football.")); } + + @Test + void queryEmbeddedStoreWithMetadata() { + given().param("question", "What is your favourite sport") + .param("userId", 2) + .when() + .get("/api/ai/query") + .then() + .statusCode(200) + .body("answer", equalTo("I like cricket.")); + } + + @Test + void queryEmbeddedStoreWithOutMetadata() { + given().param("question", "What is weather today") + .when() + .get("/api/ai/query") + .then() + .statusCode(200) + .body("answer", equalTo("The weather is good today.")); + } }