diff --git a/docs/reference/online-stores/milvus.md b/docs/reference/online-stores/milvus.md index 81433c4c4a..8182f91c86 100644 --- a/docs/reference/online-stores/milvus.md +++ b/docs/reference/online-stores/milvus.md @@ -54,15 +54,11 @@ An example feature view: Field( name="book_id", dtype=Int64, - tags={ - "is_primary": "True", - }, ), Field( name="book_embedding", dtype=Array(Float32), tags={ - "is_primary": "False", "description": "book embedding of the content", "dimensions": "2200", "index_type": IndexType.ivf_flat.value, @@ -70,6 +66,7 @@ An example feature view: "nlist": 1024, }, "metric_type": "L2", + } ), ], source=SOURCE, @@ -110,7 +107,7 @@ Below is a matrix indicating which functionality is supported by the Milvus onli | readable by Python SDK | yes | | readable by Java | no | | readable by Go | no | -| support for entityless feature views | yes | +| support for entityless feature views | no | | support for concurrent writing to the same key | yes | | support for ttl (time to live) at retrieval | no | | support for deleting expired data | no | diff --git a/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py b/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py index be10aba781..abd8509799 100644 --- a/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py +++ b/sdk/python/feast/expediagroup/vectordb/milvus_online_store.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import numpy as np +import pandas as pd from bidict import bidict from pydantic.typing import Literal from pymilvus import ( @@ -17,7 +18,6 @@ from pymilvus.client.types import IndexType from feast import Entity, FeatureView, RepoConfig -from feast.field import Field from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import FloatList @@ -102,14 +102,15 @@ def online_write_batch( progress: Optional[Callable[[int], Any]], ) -> None: with MilvusConnectionManager(config.online_store): - try: - rows = self._format_data_for_milvus(data) - collection_to_load_data = Collection(table.name) - collection_to_load_data.insert(rows) - # The flush call will seal any remaining segments and send them for indexing - collection_to_load_data.flush() - except Exception as e: - logger.error(f"Batch writing data failed due to {e}") + collection_to_load_data = Collection(table.name) + rows = self._format_data_for_milvus(data, collection_to_load_data) + collection_to_load_data.insert(rows) + # The flush call will seal any remaining segments and send them for indexing + collection_to_load_data.flush() + collection_to_load_data.load() + logger.info("loading data into memory") + utility.wait_for_loading_complete(table.name) + logger.info("loading data into memory complete") def online_read( self, @@ -130,7 +131,8 @@ def online_read( query_result, collection, requested_features ) - return results + # results do not have timestamps + return [(None, row) for row in results] @log_exceptions_and_usage(online_store="milvus") def update( @@ -145,28 +147,33 @@ def update( with MilvusConnectionManager(config.online_store): for table_to_keep in tables_to_keep: collection_available = utility.has_collection(table_to_keep.name) - try: - if collection_available: - logger.info(f"Collection {table_to_keep.name} already exists.") - else: - ( - schema, - indexes, - ) = self._convert_featureview_schema_to_milvus_readable( - table_to_keep.schema, + + if collection_available: + logger.info(f"Collection {table_to_keep.name} already exists.") + else: + if not table_to_keep.schema: + raise ValueError( + f"a schema must be provided for feature_view: {table_to_keep}" ) - collection = Collection(name=table_to_keep.name, schema=schema) + ( + schema, + indexes, + ) = self._convert_featureview_schema_to_milvus_readable( + table_to_keep, + ) + + logger.info( + f"creating collection {table_to_keep.name} with schema: {schema}" + ) + collection = Collection(name=table_to_keep.name, schema=schema) - for field_name, index_params in indexes.items(): - collection.create_index(field_name, index_params) + for field_name, index_params in indexes.items(): + collection.create_index(field_name, index_params) - logger.info(f"Collection name is {collection.name}") - logger.info( - f"Collection {table_to_keep.name} has been created successfully." - ) - except Exception as e: - logger.error(f"Collection update failed due to {e}") + logger.info( + f"Collection {table_to_keep.name} has been created successfully." + ) for table_to_delete in tables_to_delete: collection_available = utility.has_collection(table_to_delete.name) @@ -198,35 +205,33 @@ def teardown( utility.drop_collection(collection_name) def _convert_featureview_schema_to_milvus_readable( - self, feast_schema: List[Field] + self, feature_view: FeatureView ) -> Tuple[CollectionSchema, Dict]: """ Converting a schema understood by Feast to a schema that is readable by Milvus so that it can be used when a collection is created in Milvus. Parameters: - feast_schema (List[Field]): Schema stored in FeatureView. + feature_view (FeatureView): the FeatureView that contains the schema. Returns: (CollectionSchema): Schema readable by Milvus. (Dict): A dictionary of indexes to be created with the key as the vector field name and the value as the parameters """ - boolean_mapping_from_string = {"True": True, "False": False} field_list = [] indexes = {} - for field in feast_schema: + for field in feature_view.schema: field_name = field.name data_type = self._get_milvus_type(field.dtype) dimensions = 0 + description = "" + is_primary = True if field.name in feature_view.join_keys else False if field.tags: - description = field.tags.get("description", " ") - is_primary = boolean_mapping_from_string.get( - field.tags.get("is_primary", "False") - ) + description = field.tags.get("description", "") if self._data_type_is_supported_vector(data_type) and field.tags.get( "index_type" @@ -281,7 +286,7 @@ def _data_type_is_supported_vector(self, data_type: DataType) -> bool: return False - def _format_data_for_milvus(self, feast_data): + def _format_data_for_milvus(self, feast_data, collection: Collection): """ Format Feast input for Milvus: Data stored into Milvus takes the grouped representation approach where each feature value is grouped together: [[1,2], [1,3]], [John, Lucy], [3,4]] @@ -289,19 +294,49 @@ def _format_data_for_milvus(self, feast_data): Parameters: feast_data: List[ Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]: Data represented for batch write in Feast + collection: target collection Returns: - List[List]: transformed_data: Data that can be directly written into Milvus + pd.DataFrame: transformed_data: Data that can be directly written into Milvus """ + # get the order of columns so that return data frame has the correct order. Milvus does need the correct order + # and does not use the column names when a data frame is passed. + schema = collection.schema + field_names = [field.name for field in schema.fields] - milvus_data = [] + data = [] + feature_names = None for entity_key, values, timestamp, created_ts in feast_data: - feature = self._process_values_for_milvus(values) - milvus_data.append(feature) + feature_names = [entity_key.join_keys[0]] + feature = [self._get_value_from_value_proto(entity_key.entity_values[0])] + for feature_name, val in values.items(): + value = self._get_value_from_value_proto(val) + feature.append(value) + feature_names.append(feature_name) + data.append(feature) + + df = pd.DataFrame(data, columns=feature_names) + transformed_data = df.reindex(field_names, axis=1) - transformed_data = [list(item) for item in zip(*milvus_data)] return transformed_data + def _get_value_from_value_proto(self, proto: ValueProto): + """ + Get the raw value from a value proto. + + Parameters: + value (ValueProto): the value proto that contains the data. + + Returns: + value (Any): the extracted value. + """ + val_type = proto.WhichOneof("val") + value = getattr(proto, val_type) # type: ignore + if val_type == "float_list_val": + value = np.array(value.val) + + return value + def _create_index_params(self, tags: Dict[str, str], data_type: DataType): """ Parses the tags to generate the index_params needed to create the specified index @@ -351,7 +386,7 @@ def _create_index_params(self, tags: Dict[str, str], data_type: DataType): } def _convert_milvus_result_to_feast_type( - self, milvus_result, collection, features_to_request + self, query_result, collection, requested_features ): """ Convert Milvus result to Feast types. @@ -365,27 +400,22 @@ def _convert_milvus_result_to_feast_type( List[Dict[str, ValueProto]]: Processed data with Feast types. """ - # Here we are constructing the feature list to request from Milvus with their relevant types - + # constructing the feature list to request from Milvus with their respective types features_with_types = list(tuple()) for field in collection.schema.fields: - if field.name in features_to_request: + if field.name in requested_features: features_with_types.append( (field.name, self._get_feast_type(field.dtype)) ) - feast_type_result = [] + results = [] prefix = "valuetype." - - for row in milvus_result: + for row in query_result: result_row = {} for feature, feast_type in features_with_types: - value_proto = ValueProto() feature_value = row[feature] - if feature_value: - # Doing some pre-processing here to remove prefix value_type_method = f"{feast_type.to_value_type()}_val".lower() if value_type_method.startswith(prefix): value_type_method = value_type_method[len(prefix) :] @@ -394,8 +424,9 @@ def _convert_milvus_result_to_feast_type( ) result_row[feature] = value_proto # Append result after conversion to Feast Type - feast_type_result.append(result_row) - return feast_type_result + results.append(result_row) + + return results def _create_value_proto(self, val_proto, feature_val, value_type) -> ValueProto: """ @@ -409,11 +440,11 @@ def _create_value_proto(self, val_proto, feature_val, value_type) -> ValueProto: Returns: val_proto (ValueProto): Constructed result that Feast can understand. """ - if value_type == "float_list_val": val_proto = ValueProto(float_list_val=FloatList(val=feature_val)) else: setattr(val_proto, value_type, feature_val) + return val_proto def _construct_milvus_query(self, entities) -> str: @@ -435,7 +466,7 @@ def _construct_milvus_query(self, entities) -> str: for key in entity.join_keys: entity_join_key.append(key) for value in entity.entity_values: - value_to_search = self._get_value_to_search_in_milvus(value) + value_to_search = self._get_value_from_value_proto(value) values_to_search.append(value_to_search) # TODO: Enable multiple join key support. Currently only supporting a single primary key/ join key. This is a limitation in Feast. @@ -443,39 +474,6 @@ def _construct_milvus_query(self, entities) -> str: return milvus_query_expr - def _process_values_for_milvus(self, values) -> List: - """ - Process values to prepare them for using in Milvus. - - Parameters: - values: (Dict[str, ValueProto]): Dictionary of values from Feast data. - - Returns: - (List): Processed feature values ready for storing in Milvus. - """ - feature = [] - for feature_name, val in values.items(): - value = self._get_value_to_search_in_milvus(val) - feature.append(value) - return feature - - def _get_value_to_search_in_milvus(self, value) -> Any: - """ - Process a value to prepare it for searching in Milvus. - - Parameters: - value (ValueProto): A value from Feast data. - - Returns: - value (Any): Processed value ready for Milvus searching. - """ - val_type = value.WhichOneof("val") - if val_type == "float_list_val": - value = np.array(value.float_list_val.val) - else: - value = getattr(value, val_type) - return value - def _get_milvus_type(self, feast_type) -> DataType: """ Convert Feast type to Milvus type using the TYPE_MAPPING bidict. diff --git a/sdk/python/tests/expediagroup/milvus_online_store_creator.py b/sdk/python/tests/expediagroup/milvus_online_store_creator.py index bfc20a6bc9..4008ec4ab1 100644 --- a/sdk/python/tests/expediagroup/milvus_online_store_creator.py +++ b/sdk/python/tests/expediagroup/milvus_online_store_creator.py @@ -21,7 +21,7 @@ def create_online_store(self) -> Dict[str, str]: "Milvus Proxy successfully initialized and ready to serve!" ) wait_for_logs( - container=self.container, predicate=log_string_to_wait_for, timeout=30 + container=self.container, predicate=log_string_to_wait_for, timeout=60 ) exposed_port = self.container.get_exposed_port("19530") diff --git a/sdk/python/tests/expediagroup/test_milvus_online_store.py b/sdk/python/tests/expediagroup/test_milvus_online_store.py index d4f4c3fe26..cd0857292b 100644 --- a/sdk/python/tests/expediagroup/test_milvus_online_store.py +++ b/sdk/python/tests/expediagroup/test_milvus_online_store.py @@ -13,6 +13,7 @@ ) from feast import FeatureView +from feast.entity import Entity from feast.expediagroup.vectordb.milvus_online_store import ( MilvusConnectionManager, MilvusOnlineStore, @@ -180,11 +181,12 @@ def _create_n_customer_test_samples_milvus(self, n=10): ] def test_milvus_update_add_collection(self, repo_config, caplog): + entity = Entity(name="feature2") feast_schema = [ Field( name="feature2", dtype=Int64, - tags={"is_primary": "True", "description": "int64"}, + tags={"description": "int64"}, ), Field( name="feature1", @@ -205,6 +207,7 @@ def test_milvus_update_add_collection(self, repo_config, caplog): tables_to_keep=[ FeatureView( name=self.collection_to_write, + entities=[entity], schema=feast_schema, source=SOURCE, ) @@ -264,13 +267,13 @@ def test_milvus_update_add_collection(self, repo_config, caplog): assert indexes[0].params == index_params def test_milvus_update_add_existing_collection(self, repo_config, caplog): + entity = Entity(name="feature2") # Creating a common schema for collection feast_schema = [ Field( name="feature1", dtype=Array(Float32), tags={ - "is_primary": "False", "description": "float32", "dimensions": "128", "index_type": "HNSW", @@ -280,7 +283,7 @@ def test_milvus_update_add_existing_collection(self, repo_config, caplog): Field( name="feature2", dtype=Int64, - tags={"is_primary": "True", "description": "int64"}, + tags={"description": "int64"}, ), ] @@ -292,6 +295,7 @@ def test_milvus_update_add_existing_collection(self, repo_config, caplog): tables_to_keep=[ FeatureView( name=self.collection_to_write, + entities=[entity], schema=feast_schema, source=SOURCE, ) @@ -307,13 +311,13 @@ def test_milvus_update_add_existing_collection(self, repo_config, caplog): assert len(utility.list_collections()) == 1 def test_milvus_update_delete_collection(self, repo_config, caplog): + entity = Entity(name="feature2") # Creating a common schema for collection which is compatible with FEAST feast_schema = [ Field( name="feature1", dtype=Array(Float32), tags={ - "is_primary": "False", "description": "float32", "dimensions": "128", "index_type": "HNSW", @@ -323,7 +327,7 @@ def test_milvus_update_delete_collection(self, repo_config, caplog): Field( name="feature2", dtype=Int64, - tags={"is_primary": "True", "description": "int64"}, + tags={"description": "int64"}, ), ] @@ -349,6 +353,7 @@ def test_milvus_update_delete_collection(self, repo_config, caplog): tables_to_delete=[ FeatureView( name=self.collection_to_write, + entities=[entity], schema=feast_schema, source=SOURCE, ) @@ -364,12 +369,12 @@ def test_milvus_update_delete_collection(self, repo_config, caplog): assert utility.has_collection(self.collection_to_write) is False def test_milvus_update_delete_unavailable_collection(self, repo_config, caplog): + entity = Entity(name="feature2") feast_schema = [ Field( name="feature1", dtype=Array(Float32), tags={ - "is_primary": "False", "description": "float32", "dimensions": "128", "index_type": "HNSW", @@ -379,7 +384,7 @@ def test_milvus_update_delete_unavailable_collection(self, repo_config, caplog): Field( name="feature2", dtype=Int64, - tags={"is_primary": "True", "description": "int64"}, + tags={"description": "int64"}, ), ] @@ -388,6 +393,7 @@ def test_milvus_update_delete_unavailable_collection(self, repo_config, caplog): tables_to_delete=[ FeatureView( name=self.unavailable_collection, + entities=[entity], schema=feast_schema, source=SOURCE, ) @@ -502,10 +508,13 @@ def _create_collection_in_milvus(self, collection_name, repo_config): def _write_data_to_milvus(self, collection_name, data, repo_config): with MilvusConnectionManager(repo_config.online_store): - rows = MilvusOnlineStore()._format_data_for_milvus(data) collection_to_load_data = Collection(collection_name) + rows = MilvusOnlineStore()._format_data_for_milvus( + data, collection_to_load_data + ) collection_to_load_data.insert(rows) collection_to_load_data.flush() + collection_to_load_data.load() def test_milvus_online_read(self, repo_config, caplog): @@ -578,7 +587,7 @@ def test_milvus_online_read(self, repo_config, caplog): assert result is not None assert len(result) == 10 - assert result[0]["film_id"].int64_val == 0 - assert result[0]["film_date"].int64_val == 2000 - assert result[-1]["film_id"].int64_val == 9 - assert result[-1]["film_date"].int64_val == 2009 + assert result[0][1]["film_id"].int64_val == 0 + assert result[0][1]["film_date"].int64_val == 2000 + assert result[9][1]["film_id"].int64_val == 9 + assert result[9][1]["film_date"].int64_val == 2009 diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py index f5f3523bc6..178ce82f5e 100644 --- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py +++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py @@ -306,13 +306,13 @@ def create_vector_feature_view(source): name="driver_profile", entities=[driver_entity], schema=[ + Field(name=driver_entity.join_key, dtype=Int64), + Field(name="lifetime_trip_count", dtype=Int64), Field( name="profile_embedding", dtype=Array(base_type=Float32), tags=vector_tags, ), - Field(name="lifetime_trip_count", dtype=Int32), - Field(name=driver_entity.join_key, dtype=Int32), ], source=source, ) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index d6be30aa96..7b255d44a4 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -8,6 +8,7 @@ import assertpy import numpy as np import pandas as pd +import pandas.api.types as ptypes import pytest import requests from botocore.exceptions import BotoCoreError @@ -580,8 +581,6 @@ def test_write_vectors_to_online_store(environment, universal_data_sources): "driver_id": [123], "profile_embedding": [np.random.default_rng().uniform(-100, 100, 50)], "lifetime_trip_count": [85], - "avg_passenger_count": [0.067], - "current_balance": [0.78325], "event_timestamp": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], "created": [pd.Timestamp(datetime.datetime.utcnow()).round("ms")], } @@ -598,7 +597,8 @@ def test_write_vectors_to_online_store(environment, universal_data_sources): ], entity_rows=[{"driver_id": 123}], ).to_df() - assertpy.assert_that(df["profile_embedding"].iloc[0]).is_type_of(np.array) + + assert ptypes.is_array_like(df["profile_embedding"]) assertpy.assert_that(df["profile_embedding"].iloc[0]).is_length(50) assertpy.assert_that(df["lifetime_trip_count"].iloc[0]).is_equal_to(85)