diff --git a/src/Momento.Sdk/IPreviewVectorIndexClient.cs b/src/Momento.Sdk/IPreviewVectorIndexClient.cs index a28bdaba..3ae6f414 100644 --- a/src/Momento.Sdk/IPreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/IPreviewVectorIndexClient.cs @@ -188,4 +188,37 @@ public Task UpsertItemBatchAsync(string indexName, /// public Task SearchAsync(string indexName, IEnumerable queryVector, int topK = 10, MetadataFields? metadataFields = null, float? scoreThreshold = null); + + /// + /// Searches for the most similar vectors to the query vector in the index. + /// Ranks the vectors according to the similarity metric specified when the + /// index was created. Also returns the vectors associated with each result. + /// + /// The name of the vector index to search in. + /// The vector to search for. + /// The number of results to return. Defaults to 10. + /// A list of metadata fields to return with each result. + /// A score threshold to filter results by. For cosine + /// similarity and inner product, scores lower than the threshold are excluded. For + /// euclidean similarity, scores higher than the threshold are excluded. The threshold + /// is exclusive. Defaults to None, ie no threshold. + /// + /// Task representing the result of the upsert operation. The + /// response object is resolved to a type-safe object of one of + /// the following subtypes: + /// + /// SearchResponse.Success + /// SearchResponse.Error + /// + /// Pattern matching can be used to operate on the appropriate subtype. + /// For example: + /// + /// if (response is SearchResponse.Error errorResponse) + /// { + /// // handle error as appropriate + /// } + /// + /// + public Task SearchAndFetchVectorsAsync(string indexName, IEnumerable queryVector, int topK = 10, + MetadataFields? metadataFields = null, float? scoreThreshold = null); } \ No newline at end of file diff --git a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs index a04c50fd..b62f90f7 100644 --- a/src/Momento.Sdk/Internal/VectorIndexDataClient.cs +++ b/src/Momento.Sdk/Internal/VectorIndexDataClient.cs @@ -102,7 +102,7 @@ public async Task SearchAsync(string indexName, IEnumerable SearchAsync(string indexName, IEnumerable SearchAndFetchVectorsAsync(string indexName, + IEnumerable queryVector, int topK, MetadataFields? metadataFields, float? scoreThreshold) + { + try + { + _logger.LogTraceVectorIndexRequest("searchAndFetchVectors", indexName); + CheckValidIndexName(indexName); + var validatedTopK = ValidateTopK(topK); + metadataFields ??= new List(); + var metadataRequest = metadataFields switch + { + MetadataFields.AllFields => new _MetadataRequest { All = new _MetadataRequest.Types.All() }, + MetadataFields.List list => new _MetadataRequest + { + Some = new _MetadataRequest.Types.Some { Fields = { list.Fields } } + }, + _ => throw new InvalidArgumentException($"Unknown metadata fields type {metadataFields.GetType()}") + }; + + var request = new _SearchAndFetchVectorsRequest() + { + IndexName = indexName, + QueryVector = new _Vector { Elements = { queryVector } }, + TopK = validatedTopK, + MetadataFields = metadataRequest, + }; + + if (scoreThreshold != null) + { + request.ScoreThreshold = scoreThreshold.Value; + } + else + { + request.NoScoreThreshold = new _NoScoreThreshold(); + } + + var response = + await grpcManager.Client.SearchAndFetchVectorsAsync(request, + new CallOptions(deadline: CalculateDeadline())); + var searchHits = response.Hits.Select(h => + new SearchAndFetchVectorsHit(h.Id, h.Score, h.Vector.Elements.ToList(), Convert(h.Metadata))).ToList(); + return _logger.LogTraceVectorIndexRequestSuccess("searchAndFetchVectors", indexName, + new SearchAndFetchVectorsResponse.Success(searchHits)); + } + catch (Exception e) + { + return _logger.LogTraceVectorIndexRequestError("searchAndFetchVectors", indexName, + new SearchAndFetchVectorsResponse.Error(_exceptionMapper.Convert(e))); + } + } + private static _Item Convert(Item item) { return new _Item @@ -179,6 +230,11 @@ private static SearchHit Convert(_SearchHit hit) return new SearchHit(hit.Id, hit.Score, Convert(hit.Metadata)); } + private static SearchAndFetchVectorsHit Convert(_SearchAndFetchVectorsHit hit) + { + return new SearchAndFetchVectorsHit(hit.Id, hit.Score, hit.Vector.Elements.ToList(), Convert(hit.Metadata)); + } + private static void CheckValidIndexName(string indexName) { if (string.IsNullOrWhiteSpace(indexName)) diff --git a/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs b/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs index f30a865e..f255fd7b 100644 --- a/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs +++ b/src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs @@ -21,6 +21,7 @@ public interface IVectorIndexDataClient { public Task<_UpsertItemBatchResponse> UpsertItemBatchAsync(_UpsertItemBatchRequest request, CallOptions callOptions); public Task<_SearchResponse> SearchAsync(_SearchRequest request, CallOptions callOptions); + public Task<_SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(_SearchAndFetchVectorsRequest request, CallOptions callOptions); public Task<_DeleteItemBatchResponse> DeleteItemBatchAsync(_DeleteItemBatchRequest request, CallOptions callOptions); } @@ -55,6 +56,12 @@ public async Task<_SearchResponse> SearchAsync(_SearchRequest request, CallOptio var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.SearchAsync(r, o)); return await wrapped.ResponseAsync; } + + public async Task<_SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(_SearchAndFetchVectorsRequest request, CallOptions callOptions) + { + var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.SearchAndFetchVectorsAsync(r, o)); + return await wrapped.ResponseAsync; + } public async Task<_DeleteItemBatchResponse> DeleteItemBatchAsync(_DeleteItemBatchRequest request, CallOptions callOptions) { diff --git a/src/Momento.Sdk/PreviewVectorIndexClient.cs b/src/Momento.Sdk/PreviewVectorIndexClient.cs index 59633bfe..e085d66c 100644 --- a/src/Momento.Sdk/PreviewVectorIndexClient.cs +++ b/src/Momento.Sdk/PreviewVectorIndexClient.cs @@ -14,7 +14,7 @@ namespace Momento.Sdk; /// /// Includes control operations and data operations. /// -public class PreviewVectorIndexClient: IPreviewVectorIndexClient +public class PreviewVectorIndexClient : IPreviewVectorIndexClient { private readonly VectorIndexControlClient controlClient; private readonly VectorIndexDataClient dataClient; @@ -74,6 +74,15 @@ public async Task SearchAsync(string indexName, IEnumerable + public async Task SearchAndFetchVectorsAsync(string indexName, + IEnumerable queryVector, int topK = 10, MetadataFields? metadataFields = null, + float? scoreThreshold = null) + { + return await dataClient.SearchAndFetchVectorsAsync(indexName, queryVector, topK, metadataFields, + scoreThreshold); + } + /// public void Dispose() { diff --git a/src/Momento.Sdk/Responses/Vector/SearchAndFetchVectorsResponse.cs b/src/Momento.Sdk/Responses/Vector/SearchAndFetchVectorsResponse.cs new file mode 100644 index 00000000..8798341d --- /dev/null +++ b/src/Momento.Sdk/Responses/Vector/SearchAndFetchVectorsResponse.cs @@ -0,0 +1,83 @@ +using System.Collections.Generic; +using System.Linq; +using Momento.Sdk.Exceptions; + +namespace Momento.Sdk.Responses.Vector; + +/// +/// Parent response type for a list vector indexes request. The +/// response object is resolved to a type-safe object of one of +/// the following subtypes: +/// +/// SearchAndFetchVectorsResponse.Success +/// SearchAndFetchVectorsResponse.Error +/// +/// Pattern matching can be used to operate on the appropriate subtype. +/// For example: +/// +/// if (response is SearchAndFetchVectorsResponse.Success successResponse) +/// { +/// return successResponse.Hits; +/// } +/// else if (response is SearchAndFetchVectorsResponse.Error errorResponse) +/// { +/// // handle error as appropriate +/// } +/// else +/// { +/// // handle unexpected response +/// } +/// +/// +public abstract class SearchAndFetchVectorsResponse +{ + /// + public class Success : SearchAndFetchVectorsResponse + { + /// + /// The list of hits returned by the search. + /// + public List Hits { get; } + + /// + /// the search results + public Success(List hits) + { + Hits = hits; + } + + /// + public override string ToString() + { + var displayedHits = Hits.Take(5).Select(hit => $"{hit.Id} ({hit.Score})"); + return $"{base.ToString()}: {string.Join(", ", displayedHits)}..."; + } + + } + + /// + public class Error : SearchAndFetchVectorsResponse, IError + { + /// + public Error(SdkException error) + { + InnerException = error; + } + + /// + public SdkException InnerException { get; } + + /// + public MomentoErrorCode ErrorCode => InnerException.ErrorCode; + + /// + public string Message => $"{InnerException.MessageWrapper}: {InnerException.Message}"; + + /// + public override string ToString() + { + return $"{base.ToString()}: {Message}"; + } + + } +} diff --git a/src/Momento.Sdk/Responses/Vector/SearchHit.cs b/src/Momento.Sdk/Responses/Vector/SearchHit.cs index eeb746b7..7419ad32 100644 --- a/src/Momento.Sdk/Responses/Vector/SearchHit.cs +++ b/src/Momento.Sdk/Responses/Vector/SearchHit.cs @@ -1,3 +1,4 @@ +using System.Linq; using Momento.Sdk.Messages.Vector; namespace Momento.Sdk.Responses.Vector; @@ -14,17 +15,17 @@ public class SearchHit /// The ID of the hit. /// public string Id { get; } - + /// /// The similarity to the query vector. /// public double Score { get; } - + /// /// Requested metadata associated with the hit. /// public Dictionary Metadata { get; } - + /// /// Constructs a SearchHit with no metadata. /// @@ -36,7 +37,7 @@ public SearchHit(string id, double score) Score = score; Metadata = new Dictionary(); } - + /// /// Constructs a SearchHit. /// @@ -61,7 +62,7 @@ public override bool Equals(object obj) // ReSharper disable once CompareOfFloatsByEqualityOperator if (Id != other.Id || Score != other.Score) return false; - + // Compare Metadata dictionaries if (Metadata.Count != other.Metadata.Count) return false; @@ -95,3 +96,67 @@ public override int GetHashCode() } } +/// +/// A hit from a vector search and fetch vectors. Contains the ID of the vector, the search score, +/// the vector, and any requested metadata. +/// +public class SearchAndFetchVectorsHit : SearchHit +{ + /// + /// The similarity to the query vector. + /// + public List Vector { get; } + + /// + /// Constructs a SearchAndFetchVectorsHit with no metadata. + /// + /// The ID of the hit. + /// The similarity to the query vector. + /// The vector of the hit. + public SearchAndFetchVectorsHit(string id, double score, List vector) : base(id, score) + { + Vector = vector; + } + + /// + /// Constructs a SearchAndFetchVectorsHit. + /// + /// The ID of the hit. + /// The similarity to the query vector. + /// The vector of the hit. + /// Requested metadata associated with the hit + public SearchAndFetchVectorsHit(string id, double score, List vector, + Dictionary metadata) : base(id, score, metadata) + { + Vector = vector; + } + + /// + /// Constructs a SearchAndFetchVectorsHit from a SearchHit. + /// + /// A SearchHit containing an ID, score, and metadata + /// The vector of the hit. + public SearchAndFetchVectorsHit(SearchHit searchHit, List vector) : base(searchHit.Id, searchHit.Score, + searchHit.Metadata) + { + Vector = vector; + } + + /// + public override bool Equals(object obj) + { + if (ReferenceEquals(this, obj)) return true; + // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract + if (obj is null || GetType() != obj.GetType()) return false; + + var other = (SearchAndFetchVectorsHit)obj; + + return base.Equals(other) && Vector.SequenceEqual(other.Vector); + } + + /// + public override int GetHashCode() + { + return base.GetHashCode() ^ Vector.GetHashCode(); + } +} \ No newline at end of file diff --git a/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs b/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs index a892bd55..e9a2b9a3 100644 --- a/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs +++ b/tests/Integration/Momento.Sdk.Tests/VectorIndexDataTest.cs @@ -1,4 +1,5 @@ using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using Momento.Sdk.Messages.Vector; using Momento.Sdk.Requests.Vector; @@ -57,8 +58,56 @@ public async Task SearchAsync_InvalidIndexName() Assert.Equal(MomentoErrorCode.INVALID_ARGUMENT_ERROR, error.InnerException.ErrorCode); } - [Fact] - public async Task UpsertAndSearch_InnerProduct() + public delegate Task SearchDelegate(IPreviewVectorIndexClient client, string indexName, + IEnumerable queryVector, int topK = 10, + MetadataFields? metadataFields = null, float? scoreThreshold = null); + + public delegate void AssertOnSearchResponse(T response, List expectedHits, + List> expectedVectors); + + public static IEnumerable UpsertAndSearchTestData + { + get + { + return new List + { + new object[] + { + new SearchDelegate( + (client, indexName, queryVector, topK, metadata, scoreThreshold) => + client.SearchAsync(indexName, queryVector, topK, metadata, scoreThreshold)), + new AssertOnSearchResponse((response, expectedHits, _) => + { + Assert.True(response is SearchResponse.Success, $"Unexpected response: {response}"); + var successResponse = (SearchResponse.Success)response; + Assert.Equal(expectedHits, successResponse.Hits); + }) + }, + new object[] + { + new SearchDelegate( + (client, indexName, queryVector, topK, metadata, scoreThreshold) => + client.SearchAndFetchVectorsAsync(indexName, queryVector, topK, metadata, + scoreThreshold)), + new AssertOnSearchResponse( + (response, expectedHits, expectedVectors) => + { + Assert.True(response is SearchAndFetchVectorsResponse.Success, + $"Unexpected response: {response}"); + var successResponse = (SearchAndFetchVectorsResponse.Success)response; + var expectedHitsAndVectors = expectedHits.Zip(expectedVectors, + (h, v) => new SearchAndFetchVectorsHit(h, v)); + Assert.Equal(expectedHitsAndVectors, successResponse.Hits); + }) + } + }; + } + } + + [Theory] + [MemberData(nameof(UpsertAndSearchTestData))] + public async Task UpsertAndSearch_InnerProduct(SearchDelegate searchDelegate, + AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -67,22 +116,23 @@ public async Task UpsertAndSearch_InnerProduct() try { - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item", new List { 1.0f, 2.0f }) - }); + }; + + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); await Task.Delay(2_000); - var searchResponse = await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 2.0f }); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + var searchResponse = + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 2.0f }); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item", 5.0f) - }, successResponse.Hits); + }, items.Select(i => i.Vector).ToList()); } finally { @@ -90,8 +140,10 @@ public async Task UpsertAndSearch_InnerProduct() } } - [Fact] - public async Task UpsertAndSearch_CosineSimilarity() + [Theory] + [MemberData(nameof(UpsertAndSearchTestData))] + public async Task UpsertAndSearch_CosineSimilarity(SearchDelegate searchDelegate, + AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -100,26 +152,26 @@ public async Task UpsertAndSearch_CosineSimilarity() try { - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item_1", new List { 1.0f, 1.0f }), new("test_item_2", new List { -1.0f, 1.0f }), new("test_item_3", new List { -1.0f, -1.0f }) - }); + }; + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); await Task.Delay(2_000); - var searchResponse = await vectorIndexClient.SearchAsync(indexName, new List { 2.0f, 2.0f }); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + var searchResponse = + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 2.0f, 2.0f }); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_1", 1.0f), new("test_item_2", 0.0f), new("test_item_3", -1.0f) - }, successResponse.Hits); + }, items.Select(i => i.Vector).ToList()); } finally { @@ -127,8 +179,10 @@ public async Task UpsertAndSearch_CosineSimilarity() } } - [Fact] - public async Task UpsertAndSearch_EuclideanSimilarity() + [Theory] + [MemberData(nameof(UpsertAndSearchTestData))] + public async Task UpsertAndSearch_EuclideanSimilarity(SearchDelegate searchDelegate, + AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -138,26 +192,26 @@ public async Task UpsertAndSearch_EuclideanSimilarity() try { - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item_1", new List { 1.0f, 1.0f }), new("test_item_2", new List { -1.0f, 1.0f }), new("test_item_3", new List { -1.0f, -1.0f }) - }); + }; + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); await Task.Delay(2_000); - var searchResponse = await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 1.0f }); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + var searchResponse = + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 1.0f }); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_1", 0.0f), new("test_item_2", 4.0f), new("test_item_3", 8.0f) - }, successResponse.Hits); + }, items.Select(i => i.Vector).ToList()); } finally { @@ -165,8 +219,10 @@ public async Task UpsertAndSearch_EuclideanSimilarity() } } - [Fact] - public async Task UpsertAndSearch_TopKLimit() + [Theory] + [MemberData(nameof(UpsertAndSearchTestData))] + public async Task UpsertAndSearch_TopKLimit(SearchDelegate searchDelegate, + AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -175,25 +231,29 @@ public async Task UpsertAndSearch_TopKLimit() try { - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item_1", new List { 1.0f, 2.0f }), new("test_item_2", new List { 3.0f, 4.0f }), new("test_item_3", new List { 5.0f, 6.0f }) - }); + }; + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); await Task.Delay(2_000); - var searchResponse = await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 2.0f }, 2); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + var searchResponse = + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 2.0f }, topK: 2); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_3", 17.0f), new("test_item_2", 11.0f) - }, successResponse.Hits); + }, new List> + { + new() { 5.0f, 6.0f }, + new() { 3.0f, 4.0f } + }); } finally { @@ -201,8 +261,10 @@ public async Task UpsertAndSearch_TopKLimit() } } - [Fact] - public async Task UpsertAndSearch_WithMetadata() + [Theory] + [MemberData(nameof(UpsertAndSearchTestData))] + public async Task UpsertAndSearch_WithMetadata(SearchDelegate searchDelegate, + AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -211,7 +273,7 @@ public async Task UpsertAndSearch_WithMetadata() try { - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item_1", new List { 1.0f, 2.0f }, new Dictionary { { "key1", "value1" } }), @@ -220,30 +282,35 @@ public async Task UpsertAndSearch_WithMetadata() new("test_item_3", new List { 5.0f, 6.0f }, new Dictionary { { "key1", "value3" }, { "key3", "value3" } }) - }); + }; + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); await Task.Delay(2_000); - var searchResponse = await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 2.0f }, 3, - new List { "key1" }); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + var expectedVectors = new List> + { + new() { 5.0f, 6.0f }, + new() { 3.0f, 4.0f }, + new() { 1.0f, 2.0f } + }; + var searchResponse = + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 2.0f }, 3, + new List { "key1" }); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_3", 17.0f, new Dictionary { { "key1", "value3" } }), new("test_item_2", 11.0f, new Dictionary()), new("test_item_1", 5.0f, new Dictionary { { "key1", "value1" } }) - }, successResponse.Hits); + }, expectedVectors); - searchResponse = await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 2.0f }, 3, - new List { "key1", "key2", "key3", "key4" }); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + searchResponse = + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 2.0f }, 3, + new List { "key1", "key2", "key3", "key4" }); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_3", 17.0f, new Dictionary @@ -252,13 +319,12 @@ public async Task UpsertAndSearch_WithMetadata() new Dictionary { { "key2", "value2" } }), new("test_item_1", 5.0f, new Dictionary { { "key1", "value1" } }) - }, successResponse.Hits); + }, expectedVectors); searchResponse = - await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 2.0f }, 3, MetadataFields.All); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 2.0f }, 3, + MetadataFields.All); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_3", 17.0f, new Dictionary @@ -267,7 +333,7 @@ public async Task UpsertAndSearch_WithMetadata() new Dictionary { { "key2", "value2" } }), new("test_item_1", 5.0f, new Dictionary { { "key1", "value1" } }) - }, successResponse.Hits); + }, expectedVectors); } finally { @@ -275,8 +341,10 @@ public async Task UpsertAndSearch_WithMetadata() } } - [Fact] - public async Task UpsertAndSearch_WithDiverseMetadata() + [Theory] + [MemberData(nameof(UpsertAndSearchTestData))] + public async Task UpsertAndSearch_WithDiverseMetadata(SearchDelegate searchDelegate, + AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -294,23 +362,24 @@ public async Task UpsertAndSearch_WithDiverseMetadata() { "list_key", new List { "a", "b", "c" } }, { "empty_list_key", new List() } }; - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item_1", new List { 1.0f, 2.0f }, metadata) - }); + }; + + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); await Task.Delay(2_000); var searchResponse = - await vectorIndexClient.SearchAsync(indexName, new List { 1.0f, 2.0f }, 1, MetadataFields.All); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + await searchDelegate.Invoke(vectorIndexClient, indexName, new List { 1.0f, 2.0f }, 1, + MetadataFields.All); + assertOnSearchResponse.Invoke(searchResponse, new List { new("test_item_1", 5.0f, metadata) - }, successResponse.Hits); + }, items.Select(i => i.Vector).ToList()); } finally { @@ -341,11 +410,17 @@ public async Task UpsertAndSearch_WithDiverseMetadata() new List { 3.0f, 20.0f, -0.01f } } }; + + // Combine the search threshold parameters and the search/search with vectors parameters + public static IEnumerable UpsertAndSearchThresholdTestCases => + SearchThresholdTestCases.SelectMany( + _ => UpsertAndSearchTestData, + (firstArray, secondArray) => firstArray.Concat(secondArray).ToArray()); [Theory] - [MemberData(nameof(SearchThresholdTestCases))] - public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric, List scores, - List thresholds) + [MemberData(nameof(UpsertAndSearchThresholdTestCases))] + public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric, List scores, + List thresholds, SearchDelegate searchDelegate, AssertOnSearchResponse assertOnSearchResponse) { var indexName = $"index-{Utils.NewGuidString()}"; @@ -354,12 +429,13 @@ public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric try { - var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, new List + var items = new List { new("test_item_1", new List { 1.0f, 1.0f }), new("test_item_2", new List { -1.0f, 1.0f }), new("test_item_3", new List { -1.0f, -1.0f }) - }); + }; + var upsertResponse = await vectorIndexClient.UpsertItemBatchAsync(indexName, items); Assert.True(upsertResponse is UpsertItemBatchResponse.Success, $"Unexpected response: {upsertResponse}"); @@ -375,27 +451,21 @@ public async Task Search_PruneBasedOnThreshold(SimilarityMetric similarityMetric // Test threshold to get only the top result var searchResponse = - await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[0]); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - var successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(new List + await searchDelegate.Invoke(vectorIndexClient, indexName, queryVector, 3, scoreThreshold: thresholds[0]); + assertOnSearchResponse.Invoke(searchResponse, new List { searchHits[0] - }, successResponse.Hits); + }, items.FindAll(i => i.Id == "test_item_1").Select(i => i.Vector).ToList()); // Test threshold to get all results searchResponse = - await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[1]); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - successResponse = (SearchResponse.Success)searchResponse; - Assert.Equal(searchHits, successResponse.Hits); + await searchDelegate.Invoke(vectorIndexClient, indexName, queryVector, 3, scoreThreshold: thresholds[1]); + assertOnSearchResponse.Invoke(searchResponse, searchHits, items.Select(i => i.Vector).ToList()); // Test threshold to get no results searchResponse = - await vectorIndexClient.SearchAsync(indexName, queryVector, 3, scoreThreshold: thresholds[2]); - Assert.True(searchResponse is SearchResponse.Success, $"Unexpected response: {searchResponse}"); - successResponse = (SearchResponse.Success)searchResponse; - Assert.Empty(successResponse.Hits); + await searchDelegate.Invoke(vectorIndexClient, indexName, queryVector, 3, scoreThreshold: thresholds[2]); + assertOnSearchResponse.Invoke(searchResponse, new List(), new List>()); } finally {