
fix: Trying BatchStatement instead of execute_concurrent_with_args
Bhargav Dodla committed Jan 6, 2025
1 parent 7ad16b7 commit d8053b4
Showing 1 changed file with 40 additions and 17 deletions.
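
The change in a nutshell: rather than streaming one INSERT per feature row through `execute_concurrent_with_args`, each entity's rows are grouped into a `BatchStatement`, batches are submitted with `execute_async`, and the in-flight futures are drained whenever their count reaches `write_concurrency`. Below is a minimal self-contained sketch of that pattern against the DataStax Python driver; the contact point, keyspace, table schema, and the `rows_by_entity` input shape are illustrative assumptions, not Feast internals.

    from cassandra import ConsistencyLevel
    from cassandra.cluster import Cluster
    from cassandra.query import BatchStatement

    cluster = Cluster(["127.0.0.1"])             # assumed contact point
    session = cluster.connect("feast_keyspace")  # assumed keyspace

    insert_cql = session.prepare(                # assumed table layout
        "INSERT INTO feature_table (entity_key, feature_name, value, event_ts)"
        " VALUES (?, ?, ?, ?)"
    )

    def write_batches(rows_by_entity, write_concurrency=100):
        """rows_by_entity: iterable of (entity_key, [(feature, value, ts), ...])."""
        futures = []
        for entity_key, rows in rows_by_entity:
            # One batch per entity, mirroring the commit's per-entity grouping.
            batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM)
            for feature_name, value, event_ts in rows:
                batch.add(insert_cql, (entity_key, feature_name, value, event_ts))
            futures.append(session.execute_async(batch))
            if len(futures) >= write_concurrency:
                for future in futures:
                    future.result()  # blocks; re-raises the first write error
                futures = []
        for future in futures:       # drain any remaining in-flight batches
            future.result()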
@@ -36,13 +36,14 @@
 from cassandra.cluster import (
     EXEC_PROFILE_DEFAULT,
     Cluster,
+    ConsistencyLevel,
     ExecutionProfile,
     ResultSet,
     Session,
 )
 from cassandra.concurrent import execute_concurrent_with_args
 from cassandra.policies import DCAwareRoundRobinPolicy, TokenAwarePolicy
-from cassandra.query import PreparedStatement
+from cassandra.query import BatchStatement, PreparedStatement
 from pydantic import StrictFloat, StrictInt, StrictStr
 
 from feast import Entity, FeatureView, RepoConfig
@@ -352,12 +353,17 @@ def online_write_batch(
         """
         project = config.project
 
+        keyspace: str = self._keyspace
+        fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
+        insert_cql = self._get_cql_statement(config, "insert4", fqtable=fqtable)
+
         def unroll_insertion_tuples() -> Iterable[Tuple[str, bytes, str, datetime]]:
             """
             We craft an iterable over all rows to be inserted (entities->features),
             but this way we can call `progress` after each entity is done.
             """
             for entity_key, values, timestamp, created_ts in data:
+                batch = BatchStatement(consistency_level=ConsistencyLevel.QUORUM)
                 entity_key_bin = serialize_entity_key(
                     entity_key,
                     entity_key_serialization_version=config.entity_key_serialization_version,
@@ -369,15 +375,14 @@ def unroll_insertion_tuples() -> Iterable[Tuple[str, bytes, str, datetime]]:
                         entity_key_bin,
                         timestamp,
                     )
-                    yield params
+                    batch.add(insert_cql, params)
+                yield batch
                 # this happens N-1 times, will be corrected outside:
                 if progress:
                     progress(1)
 
         self._write_rows_concurrently(
             config,
-            project,
-            table,
             unroll_insertion_tuples(),
         )
         # correction for the last missing call to `progress`:
@@ -493,21 +498,39 @@ def _fq_table_name(keyspace: str, project: str, table: FeatureView) -> str:
     def _write_rows_concurrently(
         self,
         config: RepoConfig,
-        project: str,
-        table: FeatureView,
-        rows: Iterable[Tuple[str, bytes, str, datetime]],
+        batches: Iterable[BatchStatement],
     ):
         session: Session = self._get_session(config)
-        keyspace: str = self._keyspace
-        fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
-        insert_cql = self._get_cql_statement(config, "insert4", fqtable=fqtable)
-        #
-        execute_concurrent_with_args(
-            session,
-            insert_cql,
-            rows,
-            concurrency=config.online_store.write_concurrency,
-        )
+        futures = []
+        for batch in batches:
+            futures.append(session.execute_async(batch))
+            if len(futures) >= config.online_store.write_concurrency:
+                # Raises an exception if at least one of the batches fails
+                try:
+                    for future in futures:
+                        future.result()
+                    futures = []
+                except Exception as exc:
+                    logger.error(f"Error writing a batch: {exc}")
+                    print(f"Error writing a batch: {exc}")
+                    raise Exception("Error writing a batch") from exc
+
+        if len(futures) > 0:
+            try:
+                for future in futures:
+                    future.result()
+                futures = []
+            except Exception as exc:
+                logger.error(f"Error writing a batch: {exc}")
+                print(f"Error writing a batch: {exc}")
+                raise Exception("Error writing a batch") from exc
+
+        # execute_concurrent_with_args(
+        #     session,
+        #     insert_cql,
+        #     rows,
+        #     concurrency=config.online_store.write_concurrency,
+        # )
 
     def _read_rows_by_entity_keys(
         self,
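
For contrast, the retired path (left commented out above) handed the driver a flat iterable of parameter tuples and let `execute_concurrent_with_args` manage the request window itself. A minimal sketch, reusing the assumed `session` and `insert_cql` from the sketch near the top; the row values are made up:

    from datetime import datetime, timezone

    from cassandra.concurrent import execute_concurrent_with_args

    # One request per row; the driver keeps up to `concurrency` requests
    # in flight and, by default, raises on the first failure.
    params = [("user:1", "clicks", "42", datetime.now(timezone.utc))]  # assumed row shape
    execute_concurrent_with_args(session, insert_cql, params, concurrency=100)

The trade-off behind the switch: `execute_concurrent_with_args` issues one request per row, while a batch sends one request per entity. Logged batches (the `BatchStatement` default) are cheap when all their rows share a partition key but put extra load on the coordinator when they span many partitions, so grouping by entity implicitly assumes each entity's rows land in one partition.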
