Skip to content

Commit

Permalink
fix: Enhance cluster connection management and session handling in Ca…
Browse files Browse the repository at this point in the history
…ssandraOnlineStore
  • Loading branch information
Bhargav Dodla committed Jan 12, 2025
1 parent c810ee8 commit b4cd604
Showing 1 changed file with 142 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,111 @@ class CassandraOnlineStore(OnlineStore):
_keyspace: str = "feast_keyspace"
_prepared_statements: Dict[str, PreparedStatement] = {}

def _get_cluster(self, config: RepoConfig):
"""
Establish the database connection, if not yet created,
and return it.
Also perform basic config validation checks.
"""

online_store_config = config.online_store
if not isinstance(online_store_config, CassandraOnlineStoreConfig):
raise CassandraInvalidConfig(E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS)

if self._cluster:
if not self._cluster.is_shutdown:
print("Reusing existing cluster..")
return self._cluster
else:
self._cluster = None
print("Creating a new cluster..")
if not self._cluster:
# configuration consistency checks
hosts = online_store_config.hosts
secure_bundle_path = online_store_config.secure_bundle_path
port = online_store_config.port or 9042
keyspace = online_store_config.keyspace
username = online_store_config.username
password = online_store_config.password
protocol_version = online_store_config.protocol_version

db_directions = hosts or secure_bundle_path
if not db_directions or not keyspace:
raise CassandraInvalidConfig(E_CASSANDRA_NOT_CONFIGURED)
if hosts and secure_bundle_path:
raise CassandraInvalidConfig(E_CASSANDRA_MISCONFIGURED)
if (username is None) ^ (password is None):
raise CassandraInvalidConfig(E_CASSANDRA_INCONSISTENT_AUTH)

if username is not None:
auth_provider = PlainTextAuthProvider(
username=username,
password=password,
)
else:
auth_provider = None

# handling of load-balancing policy (optional)
if online_store_config.load_balancing:
# construct a proper execution profile embedding
# the configured LB policy
_lbp_name = online_store_config.load_balancing.load_balancing_policy
if _lbp_name == "DCAwareRoundRobinPolicy":
lb_policy = DCAwareRoundRobinPolicy(
local_dc=online_store_config.load_balancing.local_dc,
)
elif _lbp_name == "TokenAwarePolicy(DCAwareRoundRobinPolicy)":
lb_policy = TokenAwarePolicy(
DCAwareRoundRobinPolicy(
local_dc=online_store_config.load_balancing.local_dc,
)
)
else:
raise CassandraInvalidConfig(E_CASSANDRA_UNKNOWN_LB_POLICY)

# wrap it up in a map of ex.profiles with a default
exe_profile = ExecutionProfile(
request_timeout=online_store_config.request_timeout,
load_balancing_policy=lb_policy,
)
execution_profiles = {EXEC_PROFILE_DEFAULT: exe_profile}
else:
execution_profiles = None

# additional optional keyword args to Cluster
cluster_kwargs = {
k: v
for k, v in {
"protocol_version": protocol_version,
"execution_profiles": execution_profiles,
"idle_heartbeat_interval": None,
}.items()
if v is not None
}

# creation of Cluster (Cassandra vs. Astra)
if hosts:
self._cluster = Cluster(
hosts,
port=port,
auth_provider=auth_provider,
**cluster_kwargs,
)
else:
# we use 'secure_bundle_path'
self._cluster = Cluster(
cloud={"secure_connect_bundle": secure_bundle_path},
auth_provider=auth_provider,
**cluster_kwargs,
)

# creation of Session
self._keyspace = keyspace
# self._session = self._cluster.connect(self._keyspace)

return self._cluster

def _get_session(self, config: RepoConfig):
"""
Establish the database connection, if not yet created,
Expand Down Expand Up @@ -350,35 +455,47 @@ def online_write_batch(
display progress.
"""
project = config.project
session: Session = self._get_session(config)
cluster: Cluster = self._get_cluster(config)
# session: Session = self._get_session(config)
keyspace: str = self._keyspace
fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table)
insert_cql = self._get_cql_statement(
config, "insert4", fqtable=fqtable, session=session
)

futures = []
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)
# this happens N-1 times, will be corrected outside:
if progress:
progress(1)

futures.append(session.execute_async(batch))
if len(futures) >= config.online_store.write_concurrency:
# Raises exception if at least one of the batch fails
with cluster.connect(keyspace) as session:
insert_cql = self._get_cql_statement(
config, "insert4", fqtable=fqtable, session=session
)
for entity_key, values, timestamp, created_ts in data:
batch = BatchStatement(batch_type=BatchType.UNLOGGED)
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()
for feature_name, val in values.items():
params: Tuple[str, bytes, str, datetime] = (
feature_name,
val.SerializeToString(),
entity_key_bin,
timestamp,
)
batch.add(insert_cql, params)
# this happens N-1 times, will be corrected outside:
if progress:
progress(1)

futures.append(session.execute_async(batch))
if len(futures) >= config.online_store.write_concurrency:
# Raises exception if at least one of the batch fails
try:
for future in futures:
future.result()
futures = []
except Exception as exc:
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

if len(futures) > 0:
try:
for future in futures:
future.result()
Expand All @@ -387,16 +504,6 @@ def online_write_batch(
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc

if len(futures) > 0:
try:
for future in futures:
future.result()
futures = []
except Exception as exc:
logger.error(f"Error writing a batch: {exc}")
print(f"Error writing a batch: {exc}")
raise Exception("Error writing a batch") from exc
# correction for the last missing call to `progress`:
if progress:
progress(1)
Expand Down

0 comments on commit b4cd604

Please sign in to comment.