Skip to content

Commit

Permalink
Merge pull request #35 from ExpediaGroup/tipike/bug_milvus_ivflat
Browse files Browse the repository at this point in the history
Bugfix: rename IVFLAT to the expected string IVF_FLAT
  • Loading branch information
piket authored Sep 20, 2023
2 parents 6e17017 + af02212 commit 84a5c77
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 25 deletions.
5 changes: 3 additions & 2 deletions sdk/python/feast/expediagroup/vectordb/milvus_online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,12 @@ def _create_index_params(self, tags: Dict[str, str], data_type: DataType):

metric_type = "L2"
if "metric_type" in tags:
metric_type = tags["metric_type"]
metric_type = tags["metric_type"].upper()

return {
"metric_type": metric_type,
"index_type": index_type_name,
# Note: Milvus aliases variations of IVF_FLAT to the IVFLAT enum, but requires "IVF_FLAT" for index creation
"index_type": index_type_name.replace("IVFLAT", "IVF_FLAT"),
"params": params,
}

Expand Down
95 changes: 72 additions & 23 deletions sdk/python/tests/expediagroup/test_milvus_online_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
import random
from datetime import datetime
Expand Down Expand Up @@ -26,7 +27,7 @@
from feast.protos.feast.types.Value_pb2 import FloatList
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import RepoConfig
from feast.types import Array, Float32, Int64, String
from feast.types import Array, Bytes, Float32, Int64, String
from tests.expediagroup.milvus_online_store_creator import MilvusOnlineStoreCreator

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -77,7 +78,6 @@ def embedded_milvus():

class TestMilvusConnectionManager:
def test_connection_manager(self, repo_config, caplog, mocker):

mocker.patch("pymilvus.connections.connect")
with MilvusConnectionManager(repo_config.online_store):
assert (
Expand Down Expand Up @@ -116,7 +116,6 @@ def test_context_manager_exit(self, repo_config, caplog, mocker):


class TestMilvusOnlineStore:

collection_to_write = "Collection2"
collection_to_delete = "Collection1"
unavailable_collection = "abc"
Expand Down Expand Up @@ -180,7 +179,64 @@ def _create_n_customer_test_samples_milvus(self, n=10):
for i in range(n)
]

def test_milvus_update_add_collection(self, repo_config, caplog):
index_param_list = [
{
"metric_type": "L2",
"index_type": "FLAT",
"params": {},
},
{
"metric_type": "L2",
"index_type": "IVF_FLAT",
"params": {"nlist": 64},
},
{
"metric_type": "L2",
"index_type": "IVF_SQ8",
"params": {"nlist": 64},
},
{
"metric_type": "IP",
"index_type": "IVF_PQ",
"params": {"nlist": 64, "m": 2, "nbits": 8},
},
{
"metric_type": "L2",
"index_type": "HNSW",
"params": {"M": 32, "efConstruction": 256},
},
{
"metric_type": "HAMMING",
"index_type": "BIN_FLAT",
"params": {},
},
{
"metric_type": "JACCARD",
"index_type": "BIN_IVF_FLAT",
"params": {"nlist": 64},
},
]

@pytest.mark.parametrize("index_params", index_param_list)
def test_milvus_update_add_collection(self, repo_config, caplog, index_params):
dimensions = 16
vector_type = Float32
if "BIN" in index_params["index_type"]:
vector_type = Bytes

vector_tags = {
"is_primary": "False",
"description": vector_type.name,
"dimensions": dimensions,
"index_type": index_params["index_type"],
}

if "metric_type" in index_params and index_params["metric_type"]:
vector_tags["metric_type"] = index_params["metric_type"]

if "params" in index_params and index_params["params"]:
vector_tags["index_params"] = json.dumps(index_params["params"])

entity = Entity(name="feature2")
feast_schema = [
Field(
Expand All @@ -190,14 +246,8 @@ def test_milvus_update_add_collection(self, repo_config, caplog):
),
Field(
name="feature1",
dtype=Array(Float32),
tags={
"is_primary": "False",
"description": "float32",
"dimensions": 10,
"index_type": "HNSW",
"index_params": '{ "M": 32, "efConstruction": 256}',
},
dtype=Array(vector_type),
tags=vector_tags,
),
]

Expand Down Expand Up @@ -226,10 +276,12 @@ def test_milvus_update_add_collection(self, repo_config, caplog):
),
FieldSchema(
"feature1",
DataType.FLOAT_VECTOR,
description="float32",
DataType.FLOAT_VECTOR
if vector_type == Float32
else DataType.BINARY_VECTOR,
description=vector_type.name,
is_primary=False,
dim=10,
dim=dimensions,
),
],
)
Expand All @@ -239,22 +291,19 @@ def test_milvus_update_add_collection(self, repo_config, caplog):
fields=[
FieldSchema(
"feature1",
DataType.FLOAT_VECTOR,
description="float32",
DataType.FLOAT_VECTOR
if vector_type == Float32
else DataType.BINARY_VECTOR,
description=vector_type.name,
is_primary=False,
dim=10,
dim=dimensions,
),
FieldSchema(
"feature2", DataType.INT64, description="int64", is_primary=True
),
],
)

index_params = {
"metric_type": "L2",
"index_type": "HNSW",
"params": {"M": 32, "efConstruction": 256},
}
# Here we want to open and check whether the collection was added and then close the connection.
with MilvusConnectionManager(repo_config.online_store):
assert utility.has_collection(self.collection_to_write)
Expand Down

0 comments on commit 84a5c77

Please sign in to comment.