From 23f7d78e3497525e5779fd106a43c3388df9729d Mon Sep 17 00:00:00 2001 From: Timon Pike Date: Mon, 30 Oct 2023 11:04:12 -0700 Subject: [PATCH] fix tests and expanded value proto parsing --- .../vectordb/elasticsearch_online_store.py | 46 ++++++++++++++++--- .../test_elasticsearch_online_store.py | 13 ++++-- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py index 3d6a94eeb4..56b05d08ab 100644 --- a/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py +++ b/sdk/python/feast/expediagroup/vectordb/elasticsearch_online_store.py @@ -10,7 +10,15 @@ from feast import Entity, FeatureView, RepoConfig from feast.infra.online_stores.online_store import OnlineStore from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import FloatList +from feast.protos.feast.types.Value_pb2 import ( + BoolList, + BytesList, + DoubleList, + FloatList, + Int32List, + Int64List, + StringList, +) from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel from feast.types import ( @@ -245,30 +253,54 @@ def _get_value_from_value_proto(self, proto: ValueProto): value (Any): the extracted value. """ val_type = proto.WhichOneof("val") + if not val_type: + return None + value = getattr(proto, val_type) # type: ignore if val_type == "bytes_val": value = base64.b64encode(value).decode() - if val_type == "float_list_val": + if val_type == "bytes_list_val": + value = [base64.b64encode(v).decode() for v in value.val] + elif "_list_val" in val_type: value = list(value.val) return value def _create_value_proto(self, feature_val, value_type) -> ValueProto: """ - Construct Value Proto so that Feast can interpret Milvus results + Construct Value Proto so that Feast can interpret Elasticsearch results Parameters: - val_proto (ValueProto): Initialised Value Proto - feature_val (Union[list, int, str, double, float, bool, bytes]): A row/ an item in the result that Milvus returns. + feature_val (Union[list, int, str, double, float, bool, bytes]): An item in the result that Elasticsearch returns. value_type (Str): Feast Value type; example: int64_val, float_val, etc. Returns: val_proto (ValueProto): Constructed result that Feast can understand. """ - if value_type == "float_list_val": - val_proto = ValueProto(float_list_val=FloatList(val=feature_val)) + if value_type == "bytes_list_val": + val_proto = ValueProto( + bytes_list_val=BytesList(val=[base64.b64decode(f) for f in feature_val]) + ) elif value_type == "bytes_val": val_proto = ValueProto(bytes_val=base64.b64decode(feature_val)) + elif value_type == "string_list_val": + val_proto = ValueProto(string_list_val=StringList(val=feature_val)) + elif value_type == "int32_list_val": + val_proto = ValueProto(int32_list_val=Int32List(val=feature_val)) + elif value_type == "int64_list_val": + val_proto = ValueProto(int64_list_val=Int64List(val=feature_val)) + elif value_type == "double_list_val": + val_proto = ValueProto(double_list_val=DoubleList(val=feature_val)) + elif value_type == "float_list_val": + val_proto = ValueProto(float_list_val=FloatList(val=feature_val)) + elif value_type == "bool_list_val": + val_proto = ValueProto(bool_list_val=BoolList(val=feature_val)) + elif value_type == "unix_timestamp_list_val": + nanos_list = [ + int(datetime.strptime(f, "%Y-%m-%dT%H:%M:%S.%fZ").timestamp() * 1000) + for f in feature_val + ] + val_proto = ValueProto(unix_timestamp_list_val=Int64List(val=nanos_list)) elif value_type == "unix_timestamp_val": nanos = ( datetime.strptime(feature_val, "%Y-%m-%dT%H:%M:%S.%fZ").timestamp() diff --git a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py index 4a4cab34ea..ee5be3dfc6 100644 --- a/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py +++ b/sdk/python/tests/expediagroup/test_elasticsearch_online_store.py @@ -16,7 +16,7 @@ from feast.infra.offline_stores.file import FileOfflineStoreConfig from feast.infra.offline_stores.file_source import FileSource from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto -from feast.protos.feast.types.Value_pb2 import FloatList +from feast.protos.feast.types.Value_pb2 import BytesList, FloatList from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import RepoConfig from feast.types import ( @@ -171,7 +171,7 @@ def test_elasticsearch_update_add_index(self, repo_config, caplog, index_params) "type": index_params["index_type"].lower(), **index_params["index_params"], } - + with ElasticsearchConnectionManager(repo_config.online_store) as es: created_index = es.indices.get(index=self.index_to_write) assert created_index.body[self.index_to_write]["mappings"] == mapping @@ -314,7 +314,7 @@ def test_elasticsearch_online_write_batch(self, repo_config, caplog): es.indices.refresh(index=self.index_to_write) res = es.cat.count(index=self.index_to_write, params={"format": "json"}) assert res[0]["count"] == f"{total_rows_to_write}" - doc = es.get(index=self.index_to_write, id="0")["_source"]["doc"] + doc = es.get(index=self.index_to_write, id="0")["_source"] for feature in feature_view.schema: assert feature.name in doc @@ -466,6 +466,10 @@ def _create_n_customer_test_samples_elasticsearch_online_read(self, name, n=10): name="timestamp", dtype=UnixTimestamp, ), + Field( + name="byte_list", + dtype=Array(Bytes), + ), ], ) return fv, [ @@ -490,6 +494,9 @@ def _create_n_customer_test_samples_elasticsearch_online_read(self, name, n=10): "timestamp": ValueProto( unix_timestamp_val=int(datetime.utcnow().timestamp() * 1000) ), + "byte_list": ValueProto( + bytes_list_val=BytesList(val=[b"a", b"b", b"c"]) + ), }, datetime.utcnow(), None,