diff --git a/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py b/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py index 00e7ca6f30..7293c63ca3 100644 --- a/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py +++ b/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py @@ -411,11 +411,12 @@ def _create_index_params(self, tags: Dict[str, str], data_type: DataType): metric_type = "L2" if "metric_type" in tags: - metric_type = tags["metric_type"] + metric_type = tags["metric_type"].upper() return { "metric_type": metric_type, - "index_type": index_type_name, + # Note: Milvus aliases variations of IVF_FLAT to the IVFLAT enum, but requires "IVF_FLAT" for index creation + "index_type": index_type_name.replace("IVFLAT", "IVF_FLAT"), "params": params, } diff --git a/sdk/python/tests/expediagroup/test_milvus_online_store.py b/sdk/python/tests/expediagroup/test_milvus_online_store.py index 55ee1fa966..1362c58f27 100644 --- a/sdk/python/tests/expediagroup/test_milvus_online_store.py +++ b/sdk/python/tests/expediagroup/test_milvus_online_store.py @@ -1,3 +1,4 @@ +import json import logging import random from datetime import datetime @@ -26,7 +27,7 @@ from feast.protos.feast.types.Value_pb2 import FloatList from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig -from feast.types import Array, Float32, Int64, String +from feast.types import Array, Bytes, Float32, Int64, String from tests.expediagroup.milvus_online_store_creator import MilvusOnlineStoreCreator logging.basicConfig(level=logging.INFO) @@ -77,7 +78,6 @@ def embedded_milvus(): class TestMilvusConnectionManager: def test_connection_manager(self, repo_config, caplog, mocker): - mocker.patch("pymilvus.connections.connect") with MilvusConnectionManager(repo_config.online_store): assert ( @@ -116,7 +116,6 @@ def test_context_manager_exit(self, repo_config, caplog, mocker): class TestMilvusOnlineStore: - collection_to_write = "Collection2" collection_to_delete = "Collection1" unavailable_collection = "abc" @@ -180,7 +179,64 @@ def _create_n_customer_test_samples_milvus(self, n=10): for i in range(n) ] - def test_milvus_update_add_collection(self, repo_config, caplog): + index_param_list = [ + { + "metric_type": "L2", + "index_type": "FLAT", + "params": {}, + }, + { + "metric_type": "L2", + "index_type": "IVF_FLAT", + "params": {"nlist": 64}, + }, + { + "metric_type": "L2", + "index_type": "IVF_SQ8", + "params": {"nlist": 64}, + }, + { + "metric_type": "IP", + "index_type": "IVF_PQ", + "params": {"nlist": 64, "m": 2, "nbits": 8}, + }, + { + "metric_type": "L2", + "index_type": "HNSW", + "params": {"M": 32, "efConstruction": 256}, + }, + { + "metric_type": "HAMMING", + "index_type": "BIN_FLAT", + "params": {}, + }, + { + "metric_type": "JACCARD", + "index_type": "BIN_IVF_FLAT", + "params": {"nlist": 64}, + }, + ] + + @pytest.mark.parametrize("index_params", index_param_list) + def test_milvus_update_add_collection(self, repo_config, caplog, index_params): + dimensions = 16 + vector_type = Float32 + if "BIN" in index_params["index_type"]: + vector_type = Bytes + + vector_tags = { + "is_primary": "False", + "description": vector_type.name, + "dimensions": dimensions, + "index_type": index_params["index_type"], + } + + if "metric_type" in index_params and index_params["metric_type"]: + vector_tags["metric_type"] = index_params["metric_type"] + + if "params" in index_params and index_params["params"]: + vector_tags["index_params"] = json.dumps(index_params["params"]) + entity = Entity(name="feature2") feast_schema = [ Field( @@ -190,14 +246,8 @@ def test_milvus_update_add_collection(self, repo_config, caplog): ), Field( name="feature1", - dtype=Array(Float32), - tags={ - "is_primary": "False", - "description": "float32", - "dimensions": 10, - "index_type": "HNSW", - "index_params": '{ "M": 32, "efConstruction": 256}', - }, + dtype=Array(vector_type), + tags=vector_tags, ), ] @@ -226,10 +276,12 @@ def test_milvus_update_add_collection(self, repo_config, caplog): ), FieldSchema( "feature1", - DataType.FLOAT_VECTOR, - description="float32", + DataType.FLOAT_VECTOR + if vector_type == Float32 + else DataType.BINARY_VECTOR, + description=vector_type.name, is_primary=False, - dim=10, + dim=dimensions, ), ], ) @@ -239,10 +291,12 @@ def test_milvus_update_add_collection(self, repo_config, caplog): fields=[ FieldSchema( "feature1", - DataType.FLOAT_VECTOR, - description="float32", + DataType.FLOAT_VECTOR + if vector_type == Float32 + else DataType.BINARY_VECTOR, + description=vector_type.name, is_primary=False, - dim=10, + dim=dimensions, ), FieldSchema( "feature2", DataType.INT64, description="int64", is_primary=True @@ -250,11 +304,6 @@ def test_milvus_update_add_collection(self, repo_config, caplog): ], ) - index_params = { - "metric_type": "L2", - "index_type": "HNSW", - "params": {"M": 32, "efConstruction": 256}, - } # Here we want to open and check whether the collection was added and then close the connection. with MilvusConnectionManager(repo_config.online_store): assert utility.has_collection(self.collection_to_write)