Skip to content

Commit

Permalink
feat: Add MVI SearchAndFetchVectors method (#517)
Browse files Browse the repository at this point in the history
Add a new version of the MVI search method that returns the vectors
with the hits.

Parameterize the search integration tests so that they cover both search
methods.
  • Loading branch information
nand4011 authored Nov 14, 2023
1 parent fc41d11 commit f1c5f3f
Show file tree
Hide file tree
Showing 7 changed files with 414 additions and 91 deletions.
33 changes: 33 additions & 0 deletions src/Momento.Sdk/IPreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,37 @@ public Task<UpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
/// </returns>
public Task<SearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, int topK = 10,
MetadataFields? metadataFields = null, float? scoreThreshold = null);

/// <summary>
/// Searches for the most similar vectors to the query vector in the index.
/// Ranks the vectors according to the similarity metric specified when the
/// index was created. Also returns the vectors associated with each result.
/// </summary>
/// <param name="indexName">The name of the vector index to search in.</param>
/// <param name="queryVector">The vector to search for.</param>
/// <param name="topK">The number of results to return. Defaults to 10.</param>
/// <param name="metadataFields">A list of metadata fields to return with each result.
/// Defaults to null.</param>
/// <param name="scoreThreshold">A score threshold to filter results by. For cosine
/// similarity and inner product, scores lower than the threshold are excluded. For
/// euclidean similarity, scores higher than the threshold are excluded. The threshold
/// is exclusive. Defaults to null, i.e. no threshold.</param>
/// <returns>
/// Task representing the result of the search operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>SearchAndFetchVectorsResponse.Success</description></item>
/// <item><description>SearchAndFetchVectorsResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is SearchAndFetchVectorsResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
/// </returns>
public Task<SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(string indexName, IEnumerable<float> queryVector, int topK = 10,
MetadataFields? metadataFields = null, float? scoreThreshold = null);
}
58 changes: 57 additions & 1 deletion src/Momento.Sdk/Internal/VectorIndexDataClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<floa
{
request.NoScoreThreshold = new _NoScoreThreshold();
}

var response =
await grpcManager.Client.SearchAsync(request, new CallOptions(deadline: CalculateDeadline()));
var searchHits = response.Hits.Select(Convert).ToList();
Expand All @@ -116,6 +116,57 @@ public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<floa
}
}

/// <summary>
/// Searches the named index for the vectors most similar to the query vector and
/// returns each hit together with its vector and the requested metadata.
/// </summary>
/// <param name="indexName">Name of the index to search; checked by CheckValidIndexName.</param>
/// <param name="queryVector">Elements of the query vector, copied into the request.</param>
/// <param name="topK">Number of results to request; checked by ValidateTopK.</param>
/// <param name="metadataFields">Which metadata fields to return. A null value is
/// replaced with an empty list of fields.</param>
/// <param name="scoreThreshold">Optional score threshold. When null, the request is
/// explicitly marked as having no score threshold.</param>
/// <returns>
/// SearchAndFetchVectorsResponse.Success carrying the converted hits, or
/// SearchAndFetchVectorsResponse.Error wrapping the mapped exception when anything
/// inside the try block throws.
/// </returns>
public async Task<SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(string indexName,
    IEnumerable<float> queryVector, int topK, MetadataFields? metadataFields, float? scoreThreshold)
{
    try
    {
        _logger.LogTraceVectorIndexRequest("searchAndFetchVectors", indexName);
        CheckValidIndexName(indexName);
        var validatedTopK = ValidateTopK(topK);
        metadataFields ??= new List<string>();
        var metadataRequest = metadataFields switch
        {
            MetadataFields.AllFields => new _MetadataRequest { All = new _MetadataRequest.Types.All() },
            MetadataFields.List list => new _MetadataRequest
            {
                Some = new _MetadataRequest.Types.Some { Fields = { list.Fields } }
            },
            _ => throw new InvalidArgumentException($"Unknown metadata fields type {metadataFields.GetType()}")
        };

        var request = new _SearchAndFetchVectorsRequest()
        {
            IndexName = indexName,
            QueryVector = new _Vector { Elements = { queryVector } },
            TopK = validatedTopK,
            MetadataFields = metadataRequest,
        };

        // Exactly one of the two threshold fields is set on every request.
        if (scoreThreshold.HasValue)
        {
            request.ScoreThreshold = scoreThreshold.Value;
        }
        else
        {
            request.NoScoreThreshold = new _NoScoreThreshold();
        }

        var response =
            await grpcManager.Client.SearchAndFetchVectorsAsync(request,
                new CallOptions(deadline: CalculateDeadline()));

        // Use the shared Convert(_SearchAndFetchVectorsHit) helper instead of repeating
        // its body inline, so the hit mapping lives in one place.
        var searchHits = response.Hits.Select(hit => Convert(hit)).ToList();
        return _logger.LogTraceVectorIndexRequestSuccess("searchAndFetchVectors", indexName,
            new SearchAndFetchVectorsResponse.Success(searchHits));
    }
    catch (Exception e)
    {
        return _logger.LogTraceVectorIndexRequestError("searchAndFetchVectors", indexName,
            new SearchAndFetchVectorsResponse.Error(_exceptionMapper.Convert(e)));
    }
}

private static _Item Convert(Item item)
{
return new _Item
Expand Down Expand Up @@ -179,6 +230,11 @@ private static SearchHit Convert(_SearchHit hit)
return new SearchHit(hit.Id, hit.Score, Convert(hit.Metadata));
}

/// <summary>
/// Maps a wire-format search-and-fetch-vectors hit onto the SDK's
/// SearchAndFetchVectorsHit, copying the vector elements into a new list.
/// </summary>
private static SearchAndFetchVectorsHit Convert(_SearchAndFetchVectorsHit hit)
{
    var vectorElements = hit.Vector.Elements.ToList();
    var convertedMetadata = Convert(hit.Metadata);
    return new SearchAndFetchVectorsHit(hit.Id, hit.Score, vectorElements, convertedMetadata);
}

private static void CheckValidIndexName(string indexName)
{
if (string.IsNullOrWhiteSpace(indexName))
Expand Down
7 changes: 7 additions & 0 deletions src/Momento.Sdk/Internal/VectorIndexDataGrpcManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public interface IVectorIndexDataClient
{
/// <summary>
/// Asynchronously performs an upsert-item-batch call for the given request and
/// call options, yielding the response message.
/// NOTE(review): presumably forwarded to the generated gRPC client like the
/// search methods; the implementation is not visible here, so confirm.
/// </summary>
public Task<_UpsertItemBatchResponse> UpsertItemBatchAsync(_UpsertItemBatchRequest request, CallOptions callOptions);
/// <summary>
/// Asynchronously performs a search call for the given request and call
/// options, yielding the response message.
/// </summary>
public Task<_SearchResponse> SearchAsync(_SearchRequest request, CallOptions callOptions);
/// <summary>
/// Asynchronously performs a search-and-fetch-vectors call for the given
/// request and call options, yielding the response message.
/// </summary>
public Task<_SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(_SearchAndFetchVectorsRequest request, CallOptions callOptions);
/// <summary>
/// Asynchronously performs a delete-item-batch call for the given request and
/// call options, yielding the response message.
/// NOTE(review): presumably forwarded to the generated gRPC client like the
/// search methods; the implementation is not fully visible here, so confirm.
/// </summary>
public Task<_DeleteItemBatchResponse> DeleteItemBatchAsync(_DeleteItemBatchRequest request, CallOptions callOptions);
}

Expand Down Expand Up @@ -55,6 +56,12 @@ public async Task<_SearchResponse> SearchAsync(_SearchRequest request, CallOptio
var wrapped = await _middlewares.WrapRequest(request, callOptions, (r, o) => _generatedClient.SearchAsync(r, o));
return await wrapped.ResponseAsync;
}

/// <summary>
/// Passes the request through the configured middlewares, invokes the generated
/// client's SearchAndFetchVectorsAsync, and awaits the resulting response.
/// </summary>
public async Task<_SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(_SearchAndFetchVectorsRequest request, CallOptions callOptions)
{
    var middlewareCall = await _middlewares.WrapRequest(
        request,
        callOptions,
        (wrappedRequest, wrappedOptions) => _generatedClient.SearchAndFetchVectorsAsync(wrappedRequest, wrappedOptions));
    var reply = await middlewareCall.ResponseAsync;
    return reply;
}

public async Task<_DeleteItemBatchResponse> DeleteItemBatchAsync(_DeleteItemBatchRequest request, CallOptions callOptions)
{
Expand Down
11 changes: 10 additions & 1 deletion src/Momento.Sdk/PreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Momento.Sdk;
///
/// Includes control operations and data operations.
/// </summary>
public class PreviewVectorIndexClient: IPreviewVectorIndexClient
public class PreviewVectorIndexClient : IPreviewVectorIndexClient
{
private readonly VectorIndexControlClient controlClient;
private readonly VectorIndexDataClient dataClient;
Expand Down Expand Up @@ -74,6 +74,15 @@ public async Task<SearchResponse> SearchAsync(string indexName, IEnumerable<floa
return await dataClient.SearchAsync(indexName, queryVector, topK, metadataFields, searchThreshold);
}

/// <inheritdoc />
public async Task<SearchAndFetchVectorsResponse> SearchAndFetchVectorsAsync(
    string indexName,
    IEnumerable<float> queryVector,
    int topK = 10,
    MetadataFields? metadataFields = null,
    float? scoreThreshold = null) =>
    // Pure delegation: every argument is handed unchanged to the data client.
    await dataClient.SearchAndFetchVectorsAsync(indexName, queryVector, topK, metadataFields, scoreThreshold);

/// <inheritdoc />
public void Dispose()
{
Expand Down
83 changes: 83 additions & 0 deletions src/Momento.Sdk/Responses/Vector/SearchAndFetchVectorsResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using System.Collections.Generic;
using System.Linq;
using Momento.Sdk.Exceptions;

namespace Momento.Sdk.Responses.Vector;

/// <summary>
/// Parent response type for a search and fetch vectors request. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>SearchAndFetchVectorsResponse.Success</description></item>
/// <item><description>SearchAndFetchVectorsResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is SearchAndFetchVectorsResponse.Success successResponse)
/// {
///     return successResponse.Hits;
/// }
/// else if (response is SearchAndFetchVectorsResponse.Error errorResponse)
/// {
///     // handle error as appropriate
/// }
/// else
/// {
///     // handle unexpected response
/// }
/// </code>
/// </summary>
public abstract class SearchAndFetchVectorsResponse
{
    /// <include file="../../docs.xml" path='docs/class[@name="Success"]/description/*' />
    public class Success : SearchAndFetchVectorsResponse
    {
        // Upper bound on how many hits ToString() prints.
        private const int MaxDisplayedHits = 5;

        /// <summary>
        /// The list of hits returned by the search.
        /// </summary>
        public List<SearchAndFetchVectorsHit> Hits { get; }

        /// <include file="../../docs.xml" path='docs/class[@name="Success"]/description/*' />
        /// <param name="hits">the search results</param>
        public Success(List<SearchAndFetchVectorsHit> hits)
        {
            Hits = hits;
        }

        /// <summary>
        /// Renders the type name followed by the ID and score of up to the first
        /// five hits. A trailing "..." is appended only when further hits were omitted.
        /// </summary>
        public override string ToString()
        {
            var displayedHits = Hits.Take(MaxDisplayedHits).Select(hit => $"{hit.Id} ({hit.Score})");
            // Only signal truncation when hits were actually left out of the listing.
            var ellipsis = Hits.Count > MaxDisplayedHits ? "..." : string.Empty;
            return $"{base.ToString()}: {string.Join(", ", displayedHits)}{ellipsis}";
        }
    }

    /// <include file="../../docs.xml" path='docs/class[@name="Error"]/description/*' />
    public class Error : SearchAndFetchVectorsResponse, IError
    {
        /// <include file="../../docs.xml" path='docs/class[@name="Error"]/constructor/*' />
        public Error(SdkException error)
        {
            InnerException = error;
        }

        /// <inheritdoc />
        public SdkException InnerException { get; }

        /// <inheritdoc />
        public MomentoErrorCode ErrorCode => InnerException.ErrorCode;

        /// <inheritdoc />
        public string Message => $"{InnerException.MessageWrapper}: {InnerException.Message}";

        /// <inheritdoc />
        public override string ToString()
        {
            return $"{base.ToString()}: {Message}";
        }
    }
}
75 changes: 70 additions & 5 deletions src/Momento.Sdk/Responses/Vector/SearchHit.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Linq;
using Momento.Sdk.Messages.Vector;

namespace Momento.Sdk.Responses.Vector;
Expand All @@ -14,17 +15,17 @@ public class SearchHit
/// The ID of the hit.
/// </summary>
public string Id { get; }

/// <summary>
/// The similarity to the query vector.
/// </summary>
public double Score { get; }

/// <summary>
/// Requested metadata associated with the hit.
/// </summary>
public Dictionary<string, MetadataValue> Metadata { get; }

/// <summary>
/// Constructs a SearchHit with no metadata.
/// </summary>
Expand All @@ -36,7 +37,7 @@ public SearchHit(string id, double score)
Score = score;
Metadata = new Dictionary<string, MetadataValue>();
}

/// <summary>
/// Constructs a SearchHit.
/// </summary>
Expand All @@ -61,7 +62,7 @@ public override bool Equals(object obj)

// ReSharper disable once CompareOfFloatsByEqualityOperator
if (Id != other.Id || Score != other.Score) return false;

// Compare Metadata dictionaries
if (Metadata.Count != other.Metadata.Count) return false;

Expand Down Expand Up @@ -95,3 +96,67 @@ public override int GetHashCode()
}
}

/// <summary>
/// A hit from a vector search and fetch vectors. Contains the ID of the vector, the search score,
/// the vector, and any requested metadata.
/// </summary>
public class SearchAndFetchVectorsHit : SearchHit
{
    /// <summary>
    /// The vector of the hit.
    /// </summary>
    public List<float> Vector { get; }

    /// <summary>
    /// Constructs a SearchAndFetchVectorsHit with no metadata.
    /// </summary>
    /// <param name="id">The ID of the hit.</param>
    /// <param name="score">The similarity to the query vector.</param>
    /// <param name="vector">The vector of the hit.</param>
    public SearchAndFetchVectorsHit(string id, double score, List<float> vector) : base(id, score)
    {
        Vector = vector;
    }

    /// <summary>
    /// Constructs a SearchAndFetchVectorsHit.
    /// </summary>
    /// <param name="id">The ID of the hit.</param>
    /// <param name="score">The similarity to the query vector.</param>
    /// <param name="vector">The vector of the hit.</param>
    /// <param name="metadata">Requested metadata associated with the hit</param>
    public SearchAndFetchVectorsHit(string id, double score, List<float> vector,
        Dictionary<string, MetadataValue> metadata) : base(id, score, metadata)
    {
        Vector = vector;
    }

    /// <summary>
    /// Constructs a SearchAndFetchVectorsHit from a SearchHit.
    /// </summary>
    /// <param name="searchHit">A SearchHit containing an ID, score, and metadata</param>
    /// <param name="vector">The vector of the hit.</param>
    public SearchAndFetchVectorsHit(SearchHit searchHit, List<float> vector) : base(searchHit.Id, searchHit.Score,
        searchHit.Metadata)
    {
        Vector = vector;
    }

    /// <summary>
    /// Two hits are equal when they have the same runtime type, the base
    /// SearchHit comparison succeeds, and their vectors contain the same
    /// elements in the same order.
    /// </summary>
    /// <param name="obj">The object to compare against.</param>
    /// <returns>true when the hits are equal; otherwise false.</returns>
    public override bool Equals(object obj)
    {
        if (ReferenceEquals(this, obj)) return true;
        // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
        if (obj is null || GetType() != obj.GetType()) return false;

        var other = (SearchAndFetchVectorsHit)obj;

        return base.Equals(other) && Vector.SequenceEqual(other.Vector);
    }

    /// <summary>
    /// Combines the base hash with a hash of the vector's elements.
    /// </summary>
    /// <returns>A hash code derived from the base hit and the vector contents.</returns>
    public override int GetHashCode()
    {
        // List<T>.GetHashCode() is reference-based, so hashing the list itself would give
        // different codes to hits that Equals (via SequenceEqual) treats as equal.
        // Fold the individual elements instead so equal hits always hash alike.
        unchecked
        {
            var hash = base.GetHashCode();
            foreach (var element in Vector)
            {
                hash = (hash * 31) ^ element.GetHashCode();
            }

            return hash;
        }
    }
}
Loading

0 comments on commit f1c5f3f

Please sign in to comment.