From fc41d11adaed7633e4b4227eb092e1c17a20ead7 Mon Sep 17 00:00:00 2001 From: Nate Anderson Date: Fri, 10 Nov 2023 17:30:44 -0800 Subject: [PATCH] feat: add score threshold to MVI search (#516) Add score threshold to the MVI search method. Update the protos and change 'distance' to 'score' in internal search response. --- src/Momento.Sdk/IPreviewVectorIndexClient.cs | 58 +++++++------ .../Internal/VectorIndexDataClient.cs | 15 +++- src/Momento.Sdk/Momento.Sdk.csproj | 2 +- src/Momento.Sdk/PreviewVectorIndexClient.cs | 4 +- .../Momento.Sdk.Tests/VectorIndexDataTest.cs | 85 +++++++++++++++++++ 5 files changed, 131 insertions(+), 33 deletions(-) diff --git a/src/Momento.Sdk/IPreviewVectorIndexClient.cs b/src/Momento.Sdk/IPreviewVectorIndexClient.cs index b97e91f0..a28bdaba 100644 --- a/src/Momento.Sdk/IPreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/IPreviewVectorIndexClient.cs @@ -156,32 +156,36 @@ public Task UpsertItemBatchAsync(string indexName, /// public Task DeleteItemBatchAsync(string indexName, IEnumerable ids); - /// - /// Searches for the most similar vectors to the query vector in the index. - /// Ranks the vectors according to the similarity metric specified when the - /// index was created. - /// - /// The name of the vector index to search in. - /// The vector to search for. - /// The number of results to return. Defaults to 10. - /// A list of metadata fields to return with each result. - /// - /// Task representing the result of the upsert operation. The - /// response object is resolved to a type-safe object of one of - /// the following subtypes: - /// - /// SearchResponse.Success - /// SearchResponse.Error - /// - /// Pattern matching can be used to operate on the appropriate subtype. - /// For example: - /// - /// if (response is SearchResponse.Error errorResponse) - /// { - /// // handle error as appropriate - /// } - /// - /// + /// + /// Searches for the most similar vectors to the query vector in the index. + /// Ranks the vectors according to the similarity metric specified when the + /// index was created. + /// + /// The name of the vector index to search in. + /// The vector to search for. + /// The number of results to return. Defaults to 10. + /// A list of metadata fields to return with each result. + /// A score threshold to filter results by. For cosine + /// similarity and inner product, scores lower than the threshold are excluded. For + /// euclidean similarity, scores higher than the threshold are excluded. The threshold + /// is exclusive. Defaults to None, ie no threshold. + /// + /// Task representing the result of the upsert operation. The + /// response object is resolved to a type-safe object of one of + /// the following subtypes: + /// + /// SearchResponse.Success + /// SearchResponse.Error + /// + /// Pattern matching can be used to operate on the appropriate subtype. + /// For example: + /// + /// if (response is SearchResponse.Error errorResponse) + /// { + /// // handle error as appropriate + /// } + /// + /// public Task SearchAsync(string indexName, IEnumerable queryVector, int topK = 10, - MetadataFields? metadataFields = null); + MetadataFields? metadataFields = null, float? scoreThreshold = null); } \ No newline at end of file diff --git a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs index 8999823a..a04c50fd 100644 --- a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs +++ b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs @@ -68,7 +68,7 @@ public async Task DeleteItemBatchAsync(string indexName } public async Task SearchAsync(string indexName, IEnumerable queryVector, int topK, - MetadataFields? metadataFields) + MetadataFields? metadataFields, float? scoreThreshold) { try { @@ -91,9 +91,18 @@ public async Task SearchAsync(string indexName, IEnumerable - + diff --git a/src/Momento.Sdk/PreviewVectorIndexClient.cs b/src/Momento.Sdk/PreviewVectorIndexClient.cs index 6fa3debf..59633bfe 100644 --- a/src/Momento.Sdk/PreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/PreviewVectorIndexClient.cs @@ -69,9 +69,9 @@ public async Task DeleteItemBatchAsync(string indexName /// public async Task SearchAsync(string indexName, IEnumerable queryVector, - int topK = 10, MetadataFields? metadataFields = null) + int topK = 10, MetadataFields? metadataFields = null, float? searchThreshold = null) { - return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields); + return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields, searchThreshold); } /// diff --git a/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs b/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs index 391a05a7..a892bd55 100644 --- a/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs +++ b/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs @@ -317,4 +317,89 @@ public async Task UpsertAndSearch_WithDiverseMetadata() await vectorIndexClient.DeleteIndexAsync(indexName); } } + + public static IEnumerable SearchThresholdTestCases => + new List + { + // similarity metric, scores, thresholds + new object[] + { + SimilarityMetric.CosineSimilarity, + new List { 1.0f, 0.0f, -1.0f }, + new List { 0.5f, -1.01f, 1.0f } + }, + new object[] + { + SimilarityMetric.InnerProduct, + new List { 4.0f, 0.0f, -4.0f }, + new List { 0.0f, -4.01f, 4.0f } + }, + new object[] + { + SimilarityMetric.EuclideanSimilarity, + new List { 2.0f, 10.0f, 18.0f }, + new List { 3.0f, 20.0f, -0.01f } + } + }; + + [Theory] + [MemberData(nameof(SearchThresholdTestCases))] + public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric, List scores, + List thresholds) + { + var indexName = $"index-{Utils.NewGuidString()}"; + + var createResponse = await vectorIndexClient.CreateIndexAsync(indexName, 2, similarityMetric); + Assert.True(createResponse is CreateIndexResponse.Success, $"Unexpected response: {createResponse}"); + + try + { + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + { + new("test_item_1", new List { 1.0f, 1.0f }), + new("test_item_2", new List { -1.0f, 1.0f }), + new("test_item_3", new List { -1.0f, -1.0f }) + }); + Assert.True(upsertResponse is UpsertItemBatchResponse.Success, + $"Unexpected response: {upsertResponse}"); + + await Task.Delay(2_000); + + var queryVector = new List { 2.0f, 2.0f }; + var searchHits = new List + { + new("test_item_1", scores[0]), + new("test_item_2", scores[1]), + new("test_item_3", scores[2]) + }; + + // Test threshold to get only the top result + var searchResponse = + await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[0]); + Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); + var successResponse = (SearchResponse.Success)searchResponse; + Assert.Equal(new List + { + searchHits[0] + }, successResponse.Hits); + + // Test threshold to get all results + searchResponse = + await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[1]); + Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); + successResponse = (SearchResponse.Success)searchResponse; + Assert.Equal(searchHits, successResponse.Hits); + + // Test threshold to get no results + searchResponse = + await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[2]); + Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); + successResponse = (SearchResponse.Success)searchResponse; + Assert.Empty(successResponse.Hits); + } + finally + { + await vectorIndexClient.DeleteIndexAsync(indexName); + } + } } \ No newline at end of file