Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Document structure with source to VectorStore #761

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@ constructor(
val conversationId: ConversationId? = ConversationId(UUID.generateUUID().toString())
) {

@AiDsl
@JvmSynthetic
suspend fun addContext(vararg docs: String) {
store.addTexts(docs.toList())
}

@AiDsl
@JvmSynthetic
suspend fun addContext(docs: Iterable<String>): Unit {
store.addTexts(docs.toList())
}

companion object {

@JvmSynthetic
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@ class CombinedVectorStore(private val top: VectorStore, private val bottom: Vect
.reversed()
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
override suspend fun similaritySearch(query: String, limit: Int): List<VectorStore.Document> {
val topResults = top.similaritySearch(query, limit)
return when {
topResults.size >= limit -> topResults
else -> topResults + bottom.similaritySearch(query, limit - topResults.size)
}
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
override suspend fun similaritySearchByVector(
embedding: Embedding,
limit: Int
): List<VectorStore.Document> {
val topResults = top.similaritySearchByVector(embedding, limit)
return when {
topResults.size >= limit -> topResults
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import kotlin.math.sqrt

private data class State(
val orderedMemories: Map<ConversationId, List<Memory>>,
val documents: List<String>,
val documents: List<VectorStore.Document>,
val precomputedEmbeddings: Map<String, Embedding>
) {
companion object {
Expand Down Expand Up @@ -75,26 +75,30 @@ private constructor(
.reversed()
}

override suspend fun addTexts(texts: List<String>) {
override suspend fun addDocuments(texts: List<VectorStore.Document>) {
val docsContent = texts.map { it.content }
val embeddingsList =
embeddings.embedDocuments(texts, embeddingRequestModel = embeddingRequestModel)
embeddings.embedDocuments(docsContent, embeddingRequestModel = embeddingRequestModel)
state.getAndUpdate { prevState ->
val newEmbeddings = prevState.precomputedEmbeddings + texts.zip(embeddingsList)
val newEmbeddings = prevState.precomputedEmbeddings + docsContent.zip(embeddingsList)
State(prevState.orderedMemories, prevState.documents + texts, newEmbeddings)
}
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> {
override suspend fun similaritySearch(query: String, limit: Int): List<VectorStore.Document> {
val queryEmbedding =
embeddings.embedQuery(query, embeddingRequestModel = embeddingRequestModel).firstOrNull()
return queryEmbedding?.let { similaritySearchByVector(it, limit) }.orEmpty()
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> {
override suspend fun similaritySearchByVector(
embedding: Embedding,
limit: Int
): List<VectorStore.Document> {
val state0 = state.get()
return state0.documents
.asSequence()
.mapNotNull { doc -> state0.precomputedEmbeddings[doc]?.let { doc to it } }
.mapNotNull { doc -> state0.precomputedEmbeddings[doc.content]?.let { doc to it } }
.map { (doc, e) -> doc to embedding.cosineSimilarity(e) }
.sortedByDescending { (_, similarity) -> similarity }
.take(limit)
Expand All @@ -103,9 +107,9 @@ private constructor(
}

private fun Embedding.cosineSimilarity(other: Embedding): Double {
val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b).toDouble() }
val magnitudeA = sqrt(this.embedding.sumOf { (it * it).toDouble() })
val magnitudeB = sqrt(other.embedding.sumOf { (it * it).toDouble() })
val dotProduct = this.embedding.zip(other.embedding).sumOf { (a, b) -> (a * b) }
val magnitudeA = sqrt(this.embedding.sumOf { (it * it) })
val magnitudeB = sqrt(other.embedding.sumOf { (it * it) })
return dotProduct / (magnitudeA * magnitudeB)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,22 @@ package com.xebia.functional.xef.store
import arrow.atomic.AtomicInt
import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestModel
import com.xebia.functional.openai.generated.model.Embedding
import com.xebia.functional.xef.Config
import kotlin.jvm.JvmStatic
import kotlinx.serialization.Serializable

interface VectorStore {

@Serializable
data class Document(val content: String, val source: String) {
fun toJson(): String = Config.DEFAULT.json.encodeToString(serializer(), this)

companion object {
fun fromJson(json: String): Document =
Config.DEFAULT.json.decodeFromString(serializer(), json)
}
}

val indexValue: AtomicInt

fun incrementIndexAndGet(): Int = indexValue.addAndGet(1)
Expand All @@ -27,9 +39,9 @@ interface VectorStore {
* @param texts list of text to add to the vector store
* @return a list of IDs from adding the texts to the vector store
*/
suspend fun addTexts(texts: List<String>)
suspend fun addDocuments(texts: List<Document>)

suspend fun addText(texts: String) = addTexts(listOf(texts))
suspend fun addDocument(texts: Document) = addDocuments(listOf(texts))

/**
* Return the docs most similar to the query
Expand All @@ -38,7 +50,7 @@ interface VectorStore {
* @param limit number of documents to return
* @return a list of Documents most similar to query
*/
suspend fun similaritySearch(query: String, limit: Int): List<String>
suspend fun similaritySearch(query: String, limit: Int): List<Document>

/**
* Return the docs most similar to the embedding
Expand All @@ -47,7 +59,7 @@ interface VectorStore {
* @param limit number of documents to return
* @return list of Documents most similar to the embedding
*/
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String>
suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<Document>

companion object {
@JvmStatic
Expand All @@ -65,14 +77,15 @@ interface VectorStore {
limitTokens: Int
): List<Memory> = emptyList()

override suspend fun addTexts(texts: List<String>) {}
override suspend fun addDocuments(texts: List<Document>) {}

override suspend fun similaritySearch(query: String, limit: Int): List<String> = emptyList()
override suspend fun similaritySearch(query: String, limit: Int): List<Document> =
emptyList()

override suspend fun similaritySearchByVector(
embedding: Embedding,
limit: Int
): List<String> = emptyList()
): List<Document> = emptyList()
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.xebia.functional.xef.vectorstore

import com.xebia.functional.xef.OpenAI
import com.xebia.functional.xef.store.LocalVectorStore
import com.xebia.functional.xef.store.VectorStore.Document

suspend fun main() {
val embeddings = OpenAI().embeddings
val vectorStore = LocalVectorStore(embeddings)
val helloDoc = Document("Hello, how are you?", "source1")
val unrelatedDoc = Document("Unrelated text", "source2")
vectorStore.addDocuments(listOf(helloDoc, unrelatedDoc))
val maybeHelloDoc = vectorStore.similaritySearch("Hello", 1).first()
assert(maybeHelloDoc == helloDoc) { "Expected $helloDoc but got $maybeHelloDoc" }
val maybeUnrelatedDoc = vectorStore.similaritySearch("Unrelated", 1).first()
assert(maybeUnrelatedDoc == unrelatedDoc) { "Expected $unrelatedDoc but got $maybeUnrelatedDoc" }
println("All expected documents found!")
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,22 +84,23 @@ class PGVectorStore(
}
}

override suspend fun addTexts(texts: List<String>): Unit =
override suspend fun addDocuments(texts: List<VectorStore.Document>): Unit =
dataSource.connection {
val embeddings = embeddings.embedDocuments(texts, chunkSize, embeddingRequestModel)
val docsContent = texts.map { it.content }
val embeddings = embeddings.embedDocuments(docsContent, chunkSize, embeddingRequestModel)
val collection = getCollection(collectionName)
texts.zip(embeddings) { text, embedding ->
val uuid = UUID.generateUUID()
update(addNewText) {
bind(uuid.toString())
bind(collection.uuid.toString())
bind(embedding.embedding.toString())
bind(text)
bind(text.toJson())
}
}
}

override suspend fun similaritySearch(query: String, limit: Int): List<String> =
override suspend fun similaritySearch(query: String, limit: Int): List<VectorStore.Document> =
dataSource.connection {
val collection = getCollection(collectionName)

Expand All @@ -123,10 +124,12 @@ class PGVectorStore(
}
) {
string()
}.map { json ->
VectorStore.Document.fromJson(json)
}
}

override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<String> =
override suspend fun similaritySearchByVector(embedding: Embedding, limit: Int): List<VectorStore.Document> =
dataSource.connection {
val collection = getCollection(collectionName)
queryAsList(
Expand All @@ -138,6 +141,8 @@ class PGVectorStore(
}
) {
string()
}.map { json ->
VectorStore.Document.fromJson(json)
}
}

Expand Down
14 changes: 8 additions & 6 deletions integrations/postgresql/src/test/kotlin/xef/PGVectorStoreSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.xebia.functional.openai.generated.model.CreateChatCompletionRequestMo
import com.xebia.functional.openai.generated.model.CreateEmbeddingRequestModel
import com.xebia.functional.openai.generated.model.Embedding
import com.xebia.functional.xef.store.PGVectorStore
import com.xebia.functional.xef.store.VectorStore
import com.xebia.functional.xef.store.migrations.runDatabaseMigrations
import com.xebia.functional.xef.store.postgresql.PGDistanceStrategy
import com.zaxxer.hikari.HikariConfig
Expand All @@ -17,7 +18,6 @@ import io.kotest.matchers.shouldBe
import org.junit.jupiter.api.assertThrows
import org.testcontainers.containers.PostgreSQLContainer
import org.testcontainers.utility.DockerImageName
import kotlin.coroutines.coroutineContext

val postgres: PostgreSQLContainer<Nothing> =
PostgreSQLContainer(
Expand Down Expand Up @@ -67,10 +67,12 @@ class PGVectorStoreSpec :
postgresVector.createCollection()
}

val docs = listOf(VectorStore.Document(content = "foo", source = "tests"), VectorStore.Document(content = "bar", source = "tests"))

"initialDbSetup should configure the DB properly" { pg().initialDbSetup() }

"addTexts should fail with a CollectionNotFoundError if collection isn't present in the DB" {
assertThrows<IllegalStateException> { pg().addTexts(listOf("foo", "bar")) }.message shouldBe
assertThrows<IllegalStateException> { pg().addDocuments(docs) }.message shouldBe
"Collection 'test_collection' not found"
}

Expand All @@ -82,13 +84,13 @@ class PGVectorStoreSpec :
"createCollection should create collection" { pg().createCollection() }

"addTexts should not fail now that we created the collection" {
pg().addTexts(listOf("foo", "bar"))
pg().addDocuments(docs)
}

"similaritySearchByVector should return both documents" {
pg().addTexts(listOf("bar", "foo"))
pg().addDocuments(docs.reversed())
pg().similaritySearchByVector(Embedding(0, listOf(4.0, 5.0, 6.0), Embedding.Object.embedding), 2) shouldBe
listOf("bar", "foo")
docs.reversed()
}

"similaritySearch should return 2 documents" {
Expand All @@ -104,7 +106,7 @@ class PGVectorStoreSpec :
pg().similaritySearchByVector(
Embedding(0, listOf(1.0, 2.0, 3.0), Embedding.Object.embedding),
1
) shouldBe listOf("foo")
) shouldBe listOf(docs[0])
}

"the added memories sorted by index should be obtained in the same order" {
Expand Down
Loading