Skip to content

Commit

Permalink
feat: Add MVI data methods
Browse files Browse the repository at this point in the history
Add the upsert, delete, and search MVI methods to the vector client.

Add an internal MVI data client and grpc manager.

Update the protos to get the latest MVI changes.
  • Loading branch information
nand4011 committed Oct 26, 2023
1 parent f86ce06 commit 5aa0bad
Show file tree
Hide file tree
Showing 14 changed files with 1,447 additions and 10 deletions.
88 changes: 85 additions & 3 deletions src/Momento.Sdk/IPreviewVectorIndexClient.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Momento.Sdk.Requests.Vector;
using Momento.Sdk.Responses.Vector;
Expand All @@ -11,7 +12,7 @@ namespace Momento.Sdk;
///
/// Includes control operations and data operations.
/// </summary>
public interface IPreviewVectorIndexClient: IDisposable
public interface IPreviewVectorIndexClient : IDisposable
{
/// <summary>
/// Creates a vector index if it does not exist.
Expand Down Expand Up @@ -57,8 +58,9 @@ public interface IPreviewVectorIndexClient: IDisposable
/// </list>
/// </remarks>
/// </returns>
public Task<CreateVectorIndexResponse> CreateIndexAsync(string indexName, ulong numDimensions, SimilarityMetric similarityMetric = SimilarityMetric.CosineSimilarity);

public Task<CreateVectorIndexResponse> CreateIndexAsync(string indexName, ulong numDimensions,
SimilarityMetric similarityMetric = SimilarityMetric.CosineSimilarity);

/// <summary>
/// Lists all vector indexes.
/// </summary>
Expand Down Expand Up @@ -103,4 +105,84 @@ public interface IPreviewVectorIndexClient: IDisposable
/// </code>
///</returns>
public Task<DeleteVectorIndexResponse> DeleteIndexesAsync(string indexName);

/// <summary>
/// Upserts a batch of items into a vector index.
/// If an item with the same ID already exists in the index, it will be replaced.
/// Otherwise, it will be added to the index.
/// </summary>
/// <param name="indexName">The name of the vector index to upsert the items into.</param>
/// <param name="items">The items to upsert into the index.</param>
/// <returns>
/// Task representing the result of the upsert operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>VectorUpsertItemBatchResponse.Success</description></item>
/// <item><description>VectorUpsertItemBatchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is VectorUpsertItemBatchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
///</returns>
public Task<VectorUpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
IEnumerable<VectorIndexItem> items);

/// <summary>
/// Deletes all items with the given IDs from the index.
/// </summary>
/// <param name="indexName">The name of the vector index to delete the items from.</param>
/// <param name="ids">The IDs of the items to delete from the index.</param>
/// <returns>
/// Task representing the result of the delete operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>VectorDeleteItemBatchResponse.Success</description></item>
/// <item><description>VectorDeleteItemBatchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is VectorDeleteItemBatchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
///</returns>
public Task<VectorDeleteItemBatchResponse> DeleteItemBatchAsync(string indexName, IEnumerable<string> ids);

/// <summary>
/// Searches for the most similar vectors to the query vector in the index.
/// Ranks the vectors according to the similarity metric specified when the
/// index was created.
/// </summary>
/// <param name="indexName">The name of the vector index to search.</param>
/// <param name="queryVector">The vector to search for.</param>
/// <param name="topK">The number of results to return. Defaults to 10.</param>
/// <param name="metadataFields">A list of metadata fields to return with each result. Defaults to null.</param>
/// <returns>
/// Task representing the result of the search operation. The
/// response object is resolved to a type-safe object of one of
/// the following subtypes:
/// <list type="bullet">
/// <item><description>VectorSearchResponse.Success</description></item>
/// <item><description>VectorSearchResponse.Error</description></item>
/// </list>
/// Pattern matching can be used to operate on the appropriate subtype.
/// For example:
/// <code>
/// if (response is VectorSearchResponse.Error errorResponse)
/// {
/// // handle error as appropriate
/// }
/// </code>
///</returns>
public Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, uint topK = 10,
MetadataFields? metadataFields = null);
}
198 changes: 198 additions & 0 deletions src/Momento.Sdk/Internal/VectorIndexDataClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Grpc.Core;
using Microsoft.Extensions.Logging;
using Momento.Sdk.Config;
using Momento.Sdk.Exceptions;
using Momento.Sdk.Messages.Vector;
using Momento.Sdk.Requests.Vector;
using Momento.Sdk.Responses.Vector;
using Vectorindex;

namespace Momento.Sdk.Internal;

/// <summary>
/// Internal client for vector index data operations (upsert, delete, search).
/// Each operation validates its arguments, issues a gRPC call with a
/// per-request deadline, and reports any failure as an <c>Error</c> response
/// rather than letting the exception propagate to the caller.
/// </summary>
internal sealed class VectorIndexDataClient : IDisposable
{
    private readonly VectorIndexDataGrpcManager _grpcManager;

    // Time allowed for each individual request before the call deadline expires.
    private readonly TimeSpan _deadline = TimeSpan.FromSeconds(60);

    private readonly ILogger _logger;
    private readonly CacheExceptionMapper _exceptionMapper;

    /// <summary>
    /// Creates the data client along with its underlying gRPC manager.
    /// </summary>
    /// <param name="config">Configuration supplying the logger factory and passed to the gRPC manager.</param>
    /// <param name="authToken">Auth token passed to the gRPC manager.</param>
    /// <param name="endpoint">Endpoint passed to the gRPC manager.</param>
    public VectorIndexDataClient(IVectorIndexConfiguration config, string authToken, string endpoint)
    {
        _grpcManager = new VectorIndexDataGrpcManager(config, authToken, endpoint);
        _logger = config.LoggerFactory.CreateLogger<VectorIndexDataClient>();
        _exceptionMapper = new CacheExceptionMapper(config.LoggerFactory);
    }

    /// <summary>
    /// Upserts a batch of items into the given index.
    /// </summary>
    /// <param name="indexName">The name of the index to upsert into. Must be nonempty.</param>
    /// <param name="items">The items to upsert. Must not be null.</param>
    /// <returns>
    /// <c>VectorUpsertItemBatchResponse.Success</c> when the call completes, otherwise
    /// <c>VectorUpsertItemBatchResponse.Error</c> wrapping the mapped exception.
    /// </returns>
    public async Task<VectorUpsertItemBatchResponse> UpsertItemBatchAsync(string indexName,
        IEnumerable<VectorIndexItem> items)
    {
        try
        {
            _logger.LogTraceVectorIndexRequest("upsertItemBatch", indexName);
            CheckValidIndexName(indexName);
            // Validate explicitly so a null argument is reported as an invalid argument
            // instead of an ArgumentNullException raised deep inside LINQ/protobuf.
            CheckNotNull(items, nameof(items));
            var request = new _UpsertItemBatchRequest() { IndexName = indexName, Items = { items.Select(Convert) } };

            await _grpcManager.Client.UpsertItemBatchAsync(request, new CallOptions(deadline: CalculateDeadline()));
            return _logger.LogTraceVectorIndexRequestSuccess("upsertItemBatch", indexName,
                new VectorUpsertItemBatchResponse.Success());
        }
        catch (Exception e)
        {
            return _logger.LogTraceVectorIndexRequestError("upsertItemBatch", indexName,
                new VectorUpsertItemBatchResponse.Error(_exceptionMapper.Convert(e)));
        }
    }

    /// <summary>
    /// Deletes the items with the given IDs from the given index.
    /// </summary>
    /// <param name="indexName">The name of the index to delete from. Must be nonempty.</param>
    /// <param name="ids">The IDs of the items to delete. Must not be null.</param>
    /// <returns>
    /// <c>VectorDeleteItemBatchResponse.Success</c> when the call completes, otherwise
    /// <c>VectorDeleteItemBatchResponse.Error</c> wrapping the mapped exception.
    /// </returns>
    public async Task<VectorDeleteItemBatchResponse> DeleteItemBatchAsync(string indexName, IEnumerable<string> ids)
    {
        try
        {
            _logger.LogTraceVectorIndexRequest("deleteItemBatch", indexName);
            CheckValidIndexName(indexName);
            CheckNotNull(ids, nameof(ids));
            var request = new _DeleteItemBatchRequest() { IndexName = indexName, Ids = { ids } };

            await _grpcManager.Client.DeleteItemBatchAsync(request, new CallOptions(deadline: CalculateDeadline()));
            return _logger.LogTraceVectorIndexRequestSuccess("deleteItemBatch", indexName,
                new VectorDeleteItemBatchResponse.Success());
        }
        catch (Exception e)
        {
            return _logger.LogTraceVectorIndexRequestError("deleteItemBatch", indexName,
                new VectorDeleteItemBatchResponse.Error(_exceptionMapper.Convert(e)));
        }
    }

    /// <summary>
    /// Searches the given index for the items most similar to the query vector.
    /// </summary>
    /// <param name="indexName">The name of the index to search. Must be nonempty.</param>
    /// <param name="queryVector">The vector to search for. Must not be null.</param>
    /// <param name="topK">The number of results requested.</param>
    /// <param name="metadataFields">
    /// Which metadata fields to request with each hit. Null is treated as an empty list of fields.
    /// </param>
    /// <returns>
    /// <c>VectorSearchResponse.Success</c> carrying the converted hits, otherwise
    /// <c>VectorSearchResponse.Error</c> wrapping the mapped exception.
    /// </returns>
    public async Task<VectorSearchResponse> SearchAsync(string indexName, IEnumerable<float> queryVector, uint topK,
        MetadataFields? metadataFields)
    {
        try
        {
            _logger.LogTraceVectorIndexRequest("search", indexName);
            CheckValidIndexName(indexName);
            CheckNotNull(queryVector, nameof(queryVector));
            metadataFields ??= new List<string>();
            var metadataRequest = metadataFields switch
            {
                MetadataFields.AllFields => new _MetadataRequest { All = new _MetadataRequest.Types.All() },
                MetadataFields.List list => new _MetadataRequest
                {
                    Some = new _MetadataRequest.Types.Some { Fields = { list.Fields } }
                },
                _ => throw new InvalidArgumentException($"Unknown metadata fields type {metadataFields.GetType()}")
            };

            var request = new _SearchRequest
            {
                IndexName = indexName,
                QueryVector = new _Vector { Elements = { queryVector } },
                TopK = topK,
                MetadataFields = metadataRequest
            };

            var response =
                await _grpcManager.Client.SearchAsync(request, new CallOptions(deadline: CalculateDeadline()));
            var searchHits = response.Hits.Select(Convert).ToList();
            return _logger.LogTraceVectorIndexRequestSuccess("search", indexName,
                new VectorSearchResponse.Success(searchHits));
        }
        catch (Exception e)
        {
            return _logger.LogTraceVectorIndexRequestError("search", indexName,
                new VectorSearchResponse.Error(_exceptionMapper.Convert(e)));
        }
    }

    /// <summary>
    /// Converts an SDK item into its protobuf representation.
    /// </summary>
    private static _Item Convert(VectorIndexItem item)
    {
        return new _Item
        {
            Id = item.Id, Vector = new _Vector { Elements = { item.Vector } }, Metadata = { Convert(item.Metadata) }
        };
    }

    /// <summary>
    /// Converts SDK metadata into a list of protobuf metadata entries, one per dictionary pair.
    /// </summary>
    /// <exception cref="InvalidArgumentException">If a value is of an unrecognized metadata type.</exception>
    private static IEnumerable<_Metadata> Convert(Dictionary<string, MetadataValue> metadata)
    {
        var convertedMetadataList = new List<_Metadata>();
        foreach (var metadataPair in metadata)
        {
            _Metadata convertedMetadata;
            switch (metadataPair.Value)
            {
                case StringValue stringValue:
                    convertedMetadata = new _Metadata { Field = metadataPair.Key, StringValue = stringValue.Value };
                    break;
                case LongValue longValue:
                    convertedMetadata = new _Metadata { Field = metadataPair.Key, IntegerValue = longValue.Value };
                    break;
                case DoubleValue doubleValue:
                    convertedMetadata = new _Metadata { Field = metadataPair.Key, DoubleValue = doubleValue.Value };
                    break;
                case BoolValue boolValue:
                    convertedMetadata = new _Metadata { Field = metadataPair.Key, BooleanValue = boolValue.Value };
                    break;
                case StringListValue stringListValue:
                    var listOfStrings = new _Metadata.Types._ListOfStrings { Values = { stringListValue.Value } };
                    convertedMetadata = new _Metadata { Field = metadataPair.Key, ListOfStringsValue = listOfStrings };
                    break;
                default:
                    throw new InvalidArgumentException($"Unknown metadata type {metadataPair.Value.GetType()}");
            }

            convertedMetadataList.Add(convertedMetadata);
        }

        return convertedMetadataList;
    }

    /// <summary>
    /// Converts protobuf metadata entries into an SDK metadata dictionary keyed by field name.
    /// </summary>
    private static Dictionary<string, MetadataValue> Convert(IEnumerable<_Metadata> metadata)
    {
        return metadata.ToDictionary(m => m.Field, Convert);
    }

    /// <summary>
    /// Converts a single protobuf metadata entry into the matching SDK metadata value.
    /// </summary>
    /// <exception cref="UnknownException">If the entry has no value set or an unrecognized value case.</exception>
    private static MetadataValue Convert(_Metadata metadata)
    {
        switch (metadata.ValueCase)
        {
            case _Metadata.ValueOneofCase.StringValue:
                return new StringValue(metadata.StringValue);
            case _Metadata.ValueOneofCase.IntegerValue:
                return new LongValue(metadata.IntegerValue);
            case _Metadata.ValueOneofCase.DoubleValue:
                return new DoubleValue(metadata.DoubleValue);
            case _Metadata.ValueOneofCase.BooleanValue:
                return new BoolValue(metadata.BooleanValue);
            case _Metadata.ValueOneofCase.ListOfStringsValue:
                return new StringListValue(metadata.ListOfStringsValue.Values.ToList());
            case _Metadata.ValueOneofCase.None:
            default:
                throw new UnknownException($"Unknown metadata type {metadata.ValueCase}");
        }
    }

    /// <summary>
    /// Converts a protobuf search hit into an SDK search hit.
    /// </summary>
    private static SearchHit Convert(_SearchHit hit)
    {
        return new SearchHit(hit.Id, hit.Distance, Convert(hit.Metadata));
    }

    /// <summary>
    /// Ensures the index name is not null, empty, or whitespace.
    /// </summary>
    /// <exception cref="InvalidArgumentException">If the index name is null, empty, or whitespace.</exception>
    private static void CheckValidIndexName(string indexName)
    {
        if (string.IsNullOrWhiteSpace(indexName))
        {
            throw new InvalidArgumentException("Index name must be nonempty");
        }
    }

    /// <summary>
    /// Ensures a required argument was supplied.
    /// </summary>
    /// <param name="argument">The argument value to check.</param>
    /// <param name="argumentName">The argument's name, used in the error message.</param>
    /// <exception cref="InvalidArgumentException">If the argument is null.</exception>
    private static void CheckNotNull(object? argument, string argumentName)
    {
        if (argument is null)
        {
            throw new InvalidArgumentException($"{argumentName} must not be null");
        }
    }

    /// <summary>
    /// Computes the absolute UTC deadline for a request starting now.
    /// </summary>
    private DateTime CalculateDeadline()
    {
        return DateTime.UtcNow.Add(_deadline);
    }

    /// <summary>
    /// Disposes the underlying gRPC manager.
    /// </summary>
    public void Dispose()
    {
        _grpcManager.Dispose();
    }
}
Loading

0 comments on commit 5aa0bad

Please sign in to comment.