Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat : adds metadata filtering #47

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package com.learning.ai.config;

import io.swagger.v3.oas.annotations.OpenAPIDefinition;
import io.swagger.v3.oas.annotations.info.Info;
import io.swagger.v3.oas.annotations.servers.Server;
import org.springframework.context.annotation.Configuration;

@Configuration(proxyBeanMethods = false)
@OpenAPIDefinition(
info = @Info(title = "pgvector-springai", version = "v1.0.0"),
servers = @Server(url = "/"))
public class SwaggerConfig {}
package com.learning.ai.config;

import io.swagger.v3.oas.annotations.OpenAPIDefinition;
import io.swagger.v3.oas.annotations.info.Info;
import io.swagger.v3.oas.annotations.servers.Server;
import org.springframework.context.annotation.Configuration;

@Configuration(proxyBeanMethods = false)
@OpenAPIDefinition(info = @Info(title = "pgvector-springai", version = "v1.0.0"), servers = @Server(url = "/"))
public class SwaggerConfig {}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public QueryController(PgVectorStoreService pgVectorStoreService) {
}

@GetMapping("/query")
AIChatResponse queryEmbeddedStore(@RequestParam String question) {
return pgVectorStoreService.queryEmbeddingStore(question);
AIChatResponse queryEmbeddedStore(@RequestParam String question, @RequestParam(required = false) Integer userId) {
return pgVectorStoreService.queryEmbeddingStore(question, userId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.learning.ai.model.response.AIChatResponse;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -23,14 +24,19 @@ public PgVectorStoreService(VectorStore vectorStore) {

public void storeEmbeddings() {
// Store embeddings
List<Document> documents =
List.of(new Document("I like football."), new Document("The weather is good today."));
List<Document> documents = List.of(
new Document("I like football.", Map.of("userId", 1)),
new Document("I like cricket.", Map.of("userId", 2)),
new Document("The weather is good today."));
vectorStore.add(documents);
}

public AIChatResponse queryEmbeddingStore(String question) {
public AIChatResponse queryEmbeddingStore(String question, Integer userId) {
// Retrieve embeddings
SearchRequest query = SearchRequest.query(question).withTopK(1);
if (userId != null) {
query.withFilterExpression("userId == " + userId);
}
List<Document> similarDocuments = vectorStore.similaritySearch(query);
String relevantData =
similarDocuments.stream().map(Document::getContent).collect(Collectors.joining(System.lineSeparator()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,32 @@ void setUp() {
@Test
void queryEmbeddedStore() {
given().param("question", "What is your favourite sport")
.param("userId", 1)
.when()
.get("/api/ai/query")
.then()
.statusCode(200)
.body("answer", equalTo("I like football."));
}

@Test
void queryEmbeddedStoreWithMetadata() {
given().param("question", "What is your favourite sport")
.param("userId", 2)
.when()
.get("/api/ai/query")
.then()
.statusCode(200)
.body("answer", equalTo("I like cricket."));
}

@Test
void queryEmbeddedStoreWithOutMetadata() {
given().param("question", "What is weather today")
.when()
.get("/api/ai/query")
.then()
.statusCode(200)
.body("answer", equalTo("The weather is good today."));
}
}