Skip to content

Commit

Permalink
Added VectorStore, Retriever with implementations
Browse files Browse the repository at this point in the history
* Removed some classes out of loader to their own packages
  • Loading branch information
markpollack committed Aug 14, 2023
1 parent b83b190 commit 049ac8a
Show file tree
Hide file tree
Showing 17 changed files with 648 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.springframework.ai.core.loader;
package org.springframework.ai.core.document;

import org.springframework.util.StringUtils;

Expand All @@ -11,11 +11,11 @@ public class Document {
private static String DEFAULT_METADATA_TEMPLATE = "{key}: {value}";

/**
* Unique ID, creates UUID by default
* Unique ID
*/
private String id;
private String id = UUID.randomUUID().toString();

// Embedding List<Float>
private List<Double> embedding = new ArrayList<>();

/**
* Metadata for the document. It should not be nested and values should be restricted
Expand Down Expand Up @@ -51,10 +51,18 @@ public Document(String text, Map<String, Object> metadata) {
this.metadata = metadata;
}

/**
 * Returns the unique ID of this document.
 * @return the document ID
 */
public String getId() {
    return this.id;
}

/**
* Returns the text held by this document, without any metadata applied.
* @return the document text
*/
public String getText() {
return this.text;
}

/**
 * Returns the document content rendered with {@link MetadataMode#ALL}.
 * Shorthand for {@code getContent(MetadataMode.ALL)}.
 * @return the content produced for {@code MetadataMode.ALL}
 */
public String getContent() {
    return this.getContent(MetadataMode.ALL);
}

public String getContent(MetadataMode metadataMode) {
if (metadataMode == MetadataMode.NONE) {
return this.text;
Expand Down Expand Up @@ -111,6 +119,14 @@ public Map<String, Object> getMetadata() {
return metadata;
}

/**
 * Returns the embedding currently attached to this document.
 * The list held by the document is returned directly, not a copy.
 * @return the embedding vector
 */
public List<Double> getEmbedding() {
    return this.embedding;
}

/**
 * Replaces the embedding attached to this document.
 * The supplied list is stored as-is (no defensive copy is made).
 * @param embedding the embedding vector to attach
 */
public void setEmbedding(final List<Double> embedding) {
    this.embedding = embedding;
}

@Override
public String toString() {
return "Document{" + "id='" + id + '\'' + ", metadata=" + metadata + ", text='" + text + '\'' + '}';
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.springframework.ai.core.loader;
package org.springframework.ai.core.document;

import java.util.List;
import java.util.function.Function;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.springframework.ai.core.loader;
package org.springframework.ai.core.document;

public enum MetadataMode {

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package org.springframework.ai.core.embedding;

import org.springframework.ai.core.document.Document;

import java.util.List;

public interface EmbeddingClient {

List<Double> createEmbedding(String text);

List<Double> createEmbedding(Document document);

List<List<Double>> createEmbedding(List<String> texts);

EmbeddingResponse createEmbeddingResult(List<String> texts);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.springframework.ai.core.loader;

import org.springframework.ai.core.loader.splitter.TextSplitter;
import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.splitter.TextSplitter;

import java.util.List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.ai.core.loader.Document;
import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.loader.Loader;
import org.springframework.ai.core.loader.splitter.TextSplitter;
import org.springframework.ai.core.loader.splitter.TokenTextSplitter;
import org.springframework.ai.core.splitter.TextSplitter;
import org.springframework.ai.core.splitter.TokenTextSplitter;
import org.springframework.core.io.Resource;

import java.io.IOException;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package org.springframework.ai.core.retriever;

import org.springframework.ai.core.document.Document;

import java.util.List;

/**
* Contract for components that look up documents relevant to a textual query.
* How relevance is determined is left entirely to the implementation.
*/
public interface Retriever {

/**
* Retrieves relevant documents however the implementation sees fit.
* @param query the query string to find relevant documents for
* @return the documents the implementation considers relevant to the query
*/
List<Document> retrieve(String query);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package org.springframework.ai.core.retriever.impl;

import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.retriever.Retriever;
import org.springframework.ai.core.vectorstore.VectorStore;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * {@link Retriever} that answers queries by delegating to the similarity search of a
 * {@link VectorStore}.
 *
 * <p>The retriever asks for the top {@code k} results and, when a threshold was
 * supplied at construction time, passes that threshold through to the store.
 */
public class VectorStoreRetriever implements Retriever {

    private VectorStore vectorStore;

    int k;

    Optional<Double> threshold = Optional.empty();

    /**
     * Creates a retriever that requests the top 4 results and applies no threshold.
     * @param vectorStore the store to search, must not be null
     */
    public VectorStoreRetriever(VectorStore vectorStore) {
        this(vectorStore, 4);
    }

    /**
     * Creates a retriever that requests the top {@code k} results and applies no
     * threshold.
     * @param vectorStore the store to search, must not be null
     * @param k the number of results to request from the store
     */
    public VectorStoreRetriever(VectorStore vectorStore, int k) {
        this.vectorStore = Objects.requireNonNull(vectorStore, "VectorStore must not be null");
        this.k = k;
    }

    /**
     * Creates a retriever that requests the top {@code k} results and forwards the
     * given threshold to the store.
     * @param vectorStore the store to search, must not be null
     * @param k the number of results to request from the store
     * @param threshold the threshold passed to the store's similarity search
     */
    public VectorStoreRetriever(VectorStore vectorStore, int k, double threshold) {
        this(vectorStore, k);
        this.threshold = Optional.of(threshold);
    }

    /**
     * @return the vector store this retriever delegates to
     */
    public VectorStore getVectorStore() {
        return this.vectorStore;
    }

    /**
     * @return the number of results requested from the store
     */
    public int getK() {
        return this.k;
    }

    /**
     * @return the configured threshold, or an empty Optional when none was supplied
     */
    public Optional<Double> getThreshold() {
        return this.threshold;
    }

    /**
     * Runs a similarity search on the underlying store, using the three-argument
     * variant only when a threshold has been configured.
     * @param query query string
     * @return the documents returned by the store
     */
    @Override
    public List<Document> retrieve(String query) {
        // No threshold configured: use the plain top-k search.
        if (!this.threshold.isPresent()) {
            return this.vectorStore.similaritySearch(query, this.k);
        }
        double minimumScore = this.threshold.get();
        return this.vectorStore.similaritySearch(query, this.k, minimumScore);
    }

}
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package org.springframework.ai.core.loader.splitter;
package org.springframework.ai.core.splitter;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.core.loader.Document;
import org.springframework.ai.core.loader.DocumentTransformer;
import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.document.DocumentTransformer;

import java.util.ArrayList;
import java.util.HashMap;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.springframework.ai.core.loader.splitter;
package org.springframework.ai.core.splitter;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package org.springframework.ai.core.vectorstore;

import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.embedding.EmbeddingClient;

import java.util.List;
import java.util.Optional;

/**
* Contract for a store that holds {@link Document}s and can search them by similarity
* to a query string.
*/
public interface VectorStore {

/**
* Adds Documents to the vector store.
* @param documents the list of documents to store. An exception will be thrown if
* the underlying provider checks for duplicate IDs on add.
*/
void add(List<Document> documents);

/**
* Deletes the documents with the given IDs from the store.
* @param idList the IDs of the documents to delete
* @return an Optional Boolean; presumably reports whether the deletion succeeded,
* with empty meaning the provider cannot tell - TODO confirm with implementations
*/
Optional<Boolean> delete(List<String> idList);

/**
* Searches for documents similar to the query, with the number of results left to
* the implementation.
* @param query the query to send
* @return similar documents
*/
List<Document> similaritySearch(String query);

/**
* Searches for the top 'k' documents similar to the query.
* @param query the query to send
* @param k the top 'k' similar results
* @return similar documents
*/
List<Document> similaritySearch(String query, int k);

/**
* @param query The query to send, it will be converted to an embedding based on the
* configuration of the vector store.
* @param k the top 'k' similar results
* @param threshold the lower bound of the similarity score
* @return similar documents
*/
List<Document> similaritySearch(String query, int k, double threshold);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package org.springframework.ai.core.vectorstore.impl;

import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.embedding.EmbeddingClient;
import org.springframework.ai.core.vectorstore.VectorStore;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * A {@link VectorStore} that keeps documents in a {@link ConcurrentHashMap} keyed by
 * document ID and ranks them by cosine similarity to the query embedding.
 *
 * <p>Embeddings for stored documents and for queries are both produced by the
 * {@link EmbeddingClient} supplied at construction time.
 *
 * @author Raphael Yu
 * @author Dingmeng Xue
 * @author Mark Pollack
 */
public class InMemoryVectorStore implements VectorStore {

    /** Number of results returned when the caller does not specify {@code k}. */
    private static final int DEFAULT_TOP_K = 4;

    private final Map<String, Document> store = new ConcurrentHashMap<>();

    private final EmbeddingClient embeddingClient;

    /**
     * Creates an empty store.
     * @param embeddingClient client used to embed documents and queries
     * @throws NullPointerException if {@code embeddingClient} is null
     */
    public InMemoryVectorStore(EmbeddingClient embeddingClient) {
        Objects.requireNonNull(embeddingClient, "EmbeddingClient must not be null");
        this.embeddingClient = embeddingClient;
    }

    /**
     * Embeds each document, attaches the embedding to the document itself and stores
     * it under its ID. A document whose ID is already present replaces the stored one.
     * @param documents the documents to embed and store
     */
    @Override
    public void add(List<Document> documents) {
        for (Document document : documents) {
            List<Double> embedding = this.embeddingClient.createEmbedding(document);
            document.setEmbedding(embedding);
            this.store.put(document.getId(), document);
        }
    }

    /**
     * Removes the documents with the given IDs. IDs that are not present are ignored.
     * @param idList the IDs to remove
     * @return always {@code Optional.of(true)}, whether or not anything was removed
     */
    @Override
    public Optional<Boolean> delete(List<String> idList) {
        for (String id : idList) {
            this.store.remove(id);
        }
        return Optional.of(true);
    }

    /**
     * Returns the {@value #DEFAULT_TOP_K} stored documents most similar to the query.
     * @param query the query text
     * @return the most similar documents, best match first
     */
    @Override
    public List<Document> similaritySearch(String query) {
        return similaritySearch(query, DEFAULT_TOP_K);
    }

    /**
     * Returns the {@code k} stored documents most similar to the query.
     * @param query the query text
     * @param k the maximum number of documents to return
     * @return the most similar documents, best match first
     * @throws IllegalArgumentException if {@code k} is negative, or if the query
     * embedding or any stored embedding has zero norm or a different length
     */
    @Override
    public List<Document> similaritySearch(String query, int k) {
        return topK(scoreAll(query), k);
    }

    /**
     * Returns up to {@code k} stored documents whose similarity to the query is at
     * least {@code threshold}.
     * @param query the query text
     * @param k the maximum number of documents to return
     * @param threshold the lower bound (inclusive) of the similarity score
     * @return the matching documents, best match first
     * @throws IllegalArgumentException if {@code k} is negative, or if the query
     * embedding or any stored embedding has zero norm or a different length
     */
    @Override
    public List<Document> similaritySearch(String query, int k, double threshold) {
        List<ScoredDocument> candidates = scoreAll(query);
        // Written as !(x >= t) so a NaN score is dropped, exactly as a
        // filter(x >= t) would drop it.
        candidates.removeIf(candidate -> !(candidate.similarity() >= threshold));
        return topK(candidates, k);
    }

    /**
     * Embeds the query and scores every stored document against it.
     * @return a mutable list with one entry per stored document, in no particular order
     */
    private List<ScoredDocument> scoreAll(String query) {
        List<Double> queryEmbedding = this.embeddingClient.createEmbedding(query);
        List<ScoredDocument> scored = new ArrayList<>();
        for (Document document : this.store.values()) {
            double similarity = EmbeddingMath.cosineSimilarity(queryEmbedding, document.getEmbedding());
            scored.add(new ScoredDocument(document, similarity));
        }
        return scored;
    }

    /**
     * Orders candidates by descending similarity and returns the first {@code k}
     * documents as an unmodifiable list.
     */
    private static List<Document> topK(List<ScoredDocument> candidates, int k) {
        return candidates.stream()
            .sorted(Comparator.<ScoredDocument>comparingDouble(ScoredDocument::similarity).reversed())
            .limit(k)
            // The scored document is returned directly instead of being looked up again
            // by ID, so a concurrent delete cannot inject null into the result.
            .map(ScoredDocument::document)
            .toList();
    }

    /** A stored document paired with its similarity to the current query. */
    private record ScoredDocument(Document document, double similarity) {
    }

    /**
     * Pairing of a document key with a similarity score.
     *
     * <p>Retained for backward compatibility; the store no longer uses it internally.
     */
    public static class Similarity {

        private String key;

        private double similarity;

        public Similarity(String key, double similarity) {
            this.key = key;
            this.similarity = similarity;
        }

    }

    /**
     * Static vector helpers used for ranking. Declared {@code static} so that it does
     * not carry a reference to an enclosing store instance.
     */
    public static class EmbeddingMath {

        private EmbeddingMath() {
            // static helpers only
        }

        /**
         * Computes the cosine similarity of two vectors.
         * @param vectorX first vector
         * @param vectorY second vector, same length as {@code vectorX}
         * @return dot(x, y) / (|x| * |y|)
         * @throws IllegalArgumentException if the lengths differ or either vector has
         * zero norm
         */
        public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
            if (vectorX.size() != vectorY.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }

            double dotProduct = dotProduct(vectorX, vectorY);
            // norm() yields the squared length, hence the square roots below.
            double normX = norm(vectorX);
            double normY = norm(vectorY);

            if (normX == 0 || normY == 0) {
                throw new IllegalArgumentException("Vectors cannot have zero norm");
            }

            return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
        }

        /**
         * Computes the dot product of two vectors.
         * @param vectorX first vector
         * @param vectorY second vector, same length as {@code vectorX}
         * @return the sum of the element-wise products
         * @throws IllegalArgumentException if the lengths differ
         */
        public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
            if (vectorX.size() != vectorY.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }

            double result = 0;
            for (int i = 0; i < vectorX.size(); ++i) {
                result += vectorX.get(i) * vectorY.get(i);
            }

            return result;
        }

        /**
         * Returns the dot product of the vector with itself, i.e. the <em>squared</em>
         * Euclidean norm. Callers that need the actual length must take the square root.
         * @param vector the vector
         * @return the squared Euclidean norm of {@code vector}
         */
        public static double norm(List<Double> vector) {
            return dotProduct(vector, vector);
        }

    }

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.springframework.ai.core.loader;

import org.junit.jupiter.api.Test;
import org.springframework.ai.core.document.Document;
import org.springframework.ai.core.loader.impl.JsonLoader;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
Expand All @@ -23,7 +24,7 @@ void loadJson() {
List<Document> documents = jsonLoader.load();
assertThat(documents).isNotEmpty();
for (Document document : documents) {
System.out.println(document);
assertThat(document.getText()).isNotEmpty();
}
}

Expand Down
Loading

0 comments on commit 049ac8a

Please sign in to comment.