-
Notifications
You must be signed in to change notification settings - Fork 824
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added VectorStore, Retriever with implementations
* Moved some classes out of loader into their own packages
- Loading branch information
1 parent
b83b190
commit 049ac8a
Showing
17 changed files
with
648 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...k/ai/core/loader/DocumentTransformer.java → ...ai/core/document/DocumentTransformer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...ramework/ai/core/loader/MetadataMode.java → ...mework/ai/core/document/MetadataMode.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 changes: 4 additions & 0 deletions
4
spring-ai-core/src/main/java/org/springframework/ai/core/embedding/EmbeddingClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 2 additions & 1 deletion
3
spring-ai-core/src/main/java/org/springframework/ai/core/loader/Loader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
16 changes: 16 additions & 0 deletions
16
spring-ai-core/src/main/java/org/springframework/ai/core/retriever/Retriever.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package org.springframework.ai.core.retriever; | ||
|
||
import org.springframework.ai.core.document.Document; | ||
|
||
import java.util.List; | ||
|
||
public interface Retriever { | ||
|
||
/** | ||
* Retrieves relevant documents however the implementation sees fit. | ||
* @param query query string | ||
* @return relevant documents | ||
*/ | ||
List<Document> retrieve(String query); | ||
|
||
} |
58 changes: 58 additions & 0 deletions
58
...i-core/src/main/java/org/springframework/ai/core/retriever/impl/VectorStoreRetriever.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package org.springframework.ai.core.retriever.impl; | ||
|
||
import org.springframework.ai.core.document.Document; | ||
import org.springframework.ai.core.retriever.Retriever; | ||
import org.springframework.ai.core.vectorstore.VectorStore; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.Optional; | ||
|
||
public class VectorStoreRetriever implements Retriever { | ||
|
||
private VectorStore vectorStore; | ||
|
||
int k; | ||
|
||
Optional<Double> threshold = Optional.empty(); | ||
|
||
public VectorStoreRetriever(VectorStore vectorStore) { | ||
this(vectorStore, 4); | ||
} | ||
|
||
public VectorStoreRetriever(VectorStore vectorStore, int k) { | ||
Objects.requireNonNull(vectorStore, "VectorStore must not be null"); | ||
this.vectorStore = vectorStore; | ||
this.k = k; | ||
} | ||
|
||
public VectorStoreRetriever(VectorStore vectorStore, int k, double threshold) { | ||
Objects.requireNonNull(vectorStore, "VectorStore must not be null"); | ||
this.vectorStore = vectorStore; | ||
this.k = k; | ||
this.threshold = Optional.of(threshold); | ||
} | ||
|
||
public VectorStore getVectorStore() { | ||
return vectorStore; | ||
} | ||
|
||
public int getK() { | ||
return k; | ||
} | ||
|
||
public Optional<Double> getThreshold() { | ||
return threshold; | ||
} | ||
|
||
@Override | ||
public List<Document> retrieve(String query) { | ||
if (threshold.isPresent()) { | ||
return this.vectorStore.similaritySearch(query, this.k, this.threshold.get()); | ||
} | ||
else { | ||
return this.vectorStore.similaritySearch(query, this.k); | ||
} | ||
} | ||
|
||
} |
6 changes: 3 additions & 3 deletions
6
...ai/core/loader/splitter/TextSplitter.java → ...mework/ai/core/splitter/TextSplitter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2 changes: 1 addition & 1 deletion
2
...re/loader/splitter/TokenTextSplitter.java → ...k/ai/core/splitter/TokenTextSplitter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
spring-ai-core/src/main/java/org/springframework/ai/core/vectorstore/VectorStore.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package org.springframework.ai.core.vectorstore; | ||
|
||
import org.springframework.ai.core.document.Document; | ||
import org.springframework.ai.core.embedding.EmbeddingClient; | ||
|
||
import java.util.List; | ||
import java.util.Optional; | ||
|
||
public interface VectorStore { | ||
|
||
/** | ||
* Adds Documents to the vector store. | ||
* @param documents the list of documents to store Will throw an exception if the | ||
* underlying provider checks for duplicate IDs on add | ||
*/ | ||
void add(List<Document> documents); | ||
|
||
Optional<Boolean> delete(List<String> idList); | ||
|
||
List<Document> similaritySearch(String query); | ||
|
||
List<Document> similaritySearch(String query, int k); | ||
|
||
/** | ||
* @param query The query to send, it will be converted to an embeddeing based on the | ||
* configuration of the vector store. | ||
* @param k the top 'k' similar results | ||
* @param threshold the lower bound of the similarity score | ||
* @return similar documents | ||
*/ | ||
List<Document> similaritySearch(String query, int k, double threshold); | ||
|
||
} |
132 changes: 132 additions & 0 deletions
132
...-core/src/main/java/org/springframework/ai/core/vectorstore/impl/InMemoryVectorStore.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
package org.springframework.ai.core.vectorstore.impl; | ||
|
||
import org.springframework.ai.core.document.Document; | ||
import org.springframework.ai.core.embedding.EmbeddingClient; | ||
import org.springframework.ai.core.vectorstore.VectorStore; | ||
|
||
import java.util.*; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
|
||
/*** | ||
* @author Raphael Yu | ||
* @author Dingmeng Xue | ||
* @author Mark Pollack | ||
*/ | ||
public class InMemoryVectorStore implements VectorStore { | ||
|
||
private Map<String, Document> store = new ConcurrentHashMap<>(); | ||
|
||
private EmbeddingClient embeddingClient; | ||
|
||
public InMemoryVectorStore(EmbeddingClient embeddingClient) { | ||
Objects.requireNonNull(embeddingClient, "EmbeddingClient must not be null"); | ||
this.embeddingClient = embeddingClient; | ||
} | ||
|
||
@Override | ||
public void add(List<Document> documents) { | ||
for (Document document : documents) { | ||
List<Double> embedding = this.embeddingClient.createEmbedding(document); | ||
document.setEmbedding(embedding); | ||
this.store.put(document.getId(), document); | ||
} | ||
} | ||
|
||
@Override | ||
public Optional<Boolean> delete(List<String> idList) { | ||
for (String id : idList) { | ||
this.store.remove(id); | ||
} | ||
return Optional.of(true); | ||
} | ||
|
||
@Override | ||
public List<Document> similaritySearch(String query) { | ||
return similaritySearch(query, 4); | ||
} | ||
|
||
@Override | ||
public List<Document> similaritySearch(String query, int k) { | ||
List<Double> userQueryEmbedding = getUserQueryEmbedding(query); | ||
var similarities = this.store.values() | ||
.stream() | ||
.map(entry -> new Similarity(entry.getId(), | ||
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) | ||
.sorted(Comparator.<Similarity>comparingDouble(s -> s.similarity).reversed()) | ||
.limit(k) | ||
.map(s -> store.get(s.key)) | ||
.toList(); | ||
return similarities; | ||
} | ||
|
||
@Override | ||
public List<Document> similaritySearch(String query, int k, double threshold) { | ||
List<Double> userQueryEmbedding = getUserQueryEmbedding(query); | ||
var similarities = this.store.values() | ||
.stream() | ||
.map(entry -> new Similarity(entry.getId(), | ||
EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) | ||
.filter(s -> s.similarity >= threshold) | ||
.sorted(Comparator.<Similarity>comparingDouble(s -> s.similarity).reversed()) | ||
.limit(k) | ||
.map(s -> store.get(s.key)) | ||
.toList(); | ||
return similarities; | ||
} | ||
|
||
private List<Double> getUserQueryEmbedding(String query) { | ||
List<Double> userQueryEmbedding = this.embeddingClient.createEmbedding(query); | ||
return userQueryEmbedding; | ||
} | ||
|
||
public static class Similarity { | ||
|
||
private String key; | ||
|
||
private double similarity; | ||
|
||
public Similarity(String key, double similarity) { | ||
this.key = key; | ||
this.similarity = similarity; | ||
} | ||
|
||
} | ||
|
||
public class EmbeddingMath { | ||
|
||
public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) { | ||
if (vectorX.size() != vectorY.size()) { | ||
throw new IllegalArgumentException("Vectors lengths must be equal"); | ||
} | ||
|
||
double dotProduct = dotProduct(vectorX, vectorY); | ||
double normX = norm(vectorX); | ||
double normY = norm(vectorY); | ||
|
||
if (normX == 0 || normY == 0) { | ||
throw new IllegalArgumentException("Vectors cannot have zero norm"); | ||
} | ||
|
||
return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY)); | ||
} | ||
|
||
public static double dotProduct(List<Double> vectorX, List<Double> vectorY) { | ||
if (vectorX.size() != vectorY.size()) { | ||
throw new IllegalArgumentException("Vectors lengths must be equal"); | ||
} | ||
|
||
double result = 0; | ||
for (int i = 0; i < vectorX.size(); ++i) { | ||
result += vectorX.get(i) * vectorY.get(i); | ||
} | ||
|
||
return result; | ||
} | ||
|
||
public static double norm(List<Double> vector) { | ||
return dotProduct(vector, vector); | ||
} | ||
|
||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.