From d8053b47af54ec3b9c1330a3c0cef6adc537c800 Mon Sep 17 00:00:00 2001
From: Bhargav Dodla
Date: Mon, 6 Jan 2025 19:07:35 +0530
Subject: [PATCH] fix: Trying BatchStatement instead of execute_concurrent_with_args

---
 .../cassandra_online_store.py | 57 +++++++++++++------
 1 file changed, 40 insertions(+), 17 deletions(-)

diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py
index 1998de464a..8d20e312ef 100644
--- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py
+++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py
@@ -36,13 +36,14 @@
 from cassandra.cluster import (
     EXEC_PROFILE_DEFAULT,
     Cluster,
+    ConsistencyLevel,
     ExecutionProfile,
     ResultSet,
     Session,
 )
 from cassandra.concurrent import execute_concurrent_with_args
 from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
-from cassandra.query import PreparedStatement
+from cassandra.query import BatchStatement, PreparedStatement
 from pydantic import StrictFloat, StrictInt, StrictStr
 
 from feast import Entity, FeatureView, RepoConfig
@@ -352,12 +353,17 @@ def online_write_batch(
         """
         project = config.project
 
+        keyspace: str = self._keyspace
+        fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
+        insert_cql = self._get_cql_statement(config, "insert4", fqtable=fqtable)
+
         def unroll_insertion_tuples() -> Iterable[Tuple[str, bytes, str, datetime]]:
             """
             We craft an iterable over all rows to be inserted (entities->features),
             but this way we can call `progress` after each entity is done.
             """
             for entity_key, values, timestamp, created_ts in data:
+                batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM)
                 entity_key_bin = serialize_entity_key(
                     entity_key,
                     entity_key_serialization_version=config.entity_key_serialization_version,
@@ -369,15 +375,14 @@ def unroll_insertion_tuples() -> Iterable[Tuple[str, bytes, str, datetime]]:
                         entity_key_bin,
                         timestamp,
                     )
-                    yield params
+                    batch.add(insert_cql, params)
+                yield batch
                 # this happens N-1 times, will be corrected outside:
                 if progress:
                     progress(1)
 
         self._write_rows_concurrently(
             config,
-            project,
-            table,
             unroll_insertion_tuples(),
         )
         # correction for the last missing call to `progress`:
@@ -493,21 +498,39 @@ def _fq_table_name(keyspace: str, project: str, table: FeatureView) -> str:
     def _write_rows_concurrently(
         self,
         config: RepoConfig,
-        project: str,
-        table: FeatureView,
-        rows: Iterable[Tuple[str, bytes, str, datetime]],
+        batches: Iterable[BatchStatement],
     ):
         session: Session = self._get_session(config)
-        keyspace: str = self._keyspace
-        fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
-        insert_cql = self._get_cql_statement(config, "insert4", fqtable=fqtable)
-        #
-        execute_concurrent_with_args(
-            session,
-            insert_cql,
-            rows,
-            concurrency=config.online_store.write_concurrency,
-        )
+        futures = []
+        for batch in batches:
+            futures.append(session.execute_async(batch))
+            if len(futures) >= config.online_store.write_concurrency:
+                # Raises an exception if at least one of the batches fails
+                try:
+                    for future in futures:
+                        future.result()
+                    futures = []
+                except Exception as exc:
+                    logger.error(f"Error writing a batch: {exc}")
+                    print(f"Error writing a batch: {exc}")
+                    raise Exception("Error writing a batch") from exc
+
+        if len(futures) > 0:
+            try:
+                for future in futures:
+                    future.result()
+                futures = []
+            except Exception as exc:
+                logger.error(f"Error writing a batch: {exc}")
+                print(f"Error writing a batch: {exc}")
+                raise Exception("Error writing a batch") from exc
+
+        # execute_concurrent_with_args(
+        #     session,
+        #     insert_cql,
+        #     rows,
+        #     concurrency=config.online_store.write_concurrency,
+        # )
 
     def _read_rows_by_entity_keys(
         self,
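
For readers who want the new write path in isolation, the following is a minimal, self-contained sketch of the pattern this patch adopts: one QUORUM BatchStatement per entity, executed asynchronously, with the number of in-flight futures capped the way config.online_store.write_concurrency caps it in the diff. The contact point, keyspace, table schema, and the MAX_IN_FLIGHT constant are illustrative assumptions, not part of the Feast code.

    # Hedged, standalone sketch (not Feast code): batch-per-entity writes with a
    # cap on in-flight async requests, mirroring the pattern in the diff above.
    from datetime import datetime, timezone

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import BatchStatement

    MAX_IN_FLIGHT = 100  # stand-in for config.online_store.write_concurrency

    cluster = Cluster(["127.0.0.1"])            # assumed contact point
    session = cluster.connect("demo_keyspace")  # assumed keyspace

    insert_stmt = session.prepare(
        "INSERT INTO demo_table (entity_key, feature_name, value, event_ts) "
        "VALUES (?, ?, ?, ?)"
    )

    def write_entities(entities):
        """entities: iterable of (entity_key, [(feature_name, value, event_ts), ...])."""
        futures = []
        for entity_key, rows in entities:
            # All rows of one entity go into a single batch at QUORUM consistency.
            batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM)
            for feature_name, value, event_ts in rows:
                batch.add(insert_stmt, (entity_key, feature_name, value, event_ts))
            futures.append(session.execute_async(batch))

            # Once the cap is reached, drain outstanding futures; result()
            # re-raises the first failed write, as in the patch.
            if len(futures) >= MAX_IN_FLIGHT:
                for future in futures:
                    future.result()
                futures = []

        # Drain whatever is still outstanding after the loop.
        for future in futures:
            future.result()

    write_entities([
        ("user_1", [("clicks", "17", datetime.now(timezone.utc))]),
    ])

One design point worth noting: the diff builds each batch from the rows of a single entity key, so if that key maps to the table's partition key each batch touches only one partition, which is generally how Cassandra batches are intended to be used.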
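
For contrast, the call path being retired (still visible in the commented-out block above) handed one prepared statement per row to the driver and let execute_concurrent_with_args manage concurrency. A comparable standalone sketch, using the same illustrative schema and names as above:

    # Hedged sketch of the replaced approach: one INSERT per row, scheduled by
    # the driver via execute_concurrent_with_args instead of explicit batches.
    from datetime import datetime, timezone

    from cassandra.cluster import Cluster
    from cassandra.concurrent import execute_concurrent_with_args

    cluster = Cluster(["127.0.0.1"])            # assumed contact point
    session = cluster.connect("demo_keyspace")  # assumed keyspace

    insert_stmt = session.prepare(
        "INSERT INTO demo_table (entity_key, feature_name, value, event_ts) "
        "VALUES (?, ?, ?, ?)"
    )

    now = datetime.now(timezone.utc)
    rows = [
        ("user_1", "clicks", "17", now),
        ("user_1", "views", "42", now),
    ]

    # Keeps up to `concurrency` statements in flight; by default it raises on
    # the first failed write (raise_on_first_error=True).
    execute_concurrent_with_args(session, insert_stmt, rows, concurrency=100)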