Skip to content

Commit

Permalink
fix: Share a single database sessions starting from __handle_formatte…
Browse files Browse the repository at this point in the history
…d_events to RuleEngine (#2053)
  • Loading branch information
VladimirFilonov authored Oct 1, 2024
1 parent c635ab0 commit b3669f1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
48 changes: 25 additions & 23 deletions keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import random
import uuid
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Tuple, Union
Expand Down Expand Up @@ -79,6 +80,15 @@ def get_order_by(self):
return col(getattr(Incident, self.value))


@contextmanager
def existed_or_new_session(session: Optional[Session] = None) -> Session:
if session:
yield session
else:
with Session(engine) as session:
yield session


def get_session() -> Session:
"""
Creates a database session.
Expand Down Expand Up @@ -1561,11 +1571,11 @@ def delete_rule(tenant_id, rule_id):


def get_incident_for_grouping_rule(
tenant_id, rule, timeframe, rule_fingerprint
tenant_id, rule, timeframe, rule_fingerprint, session: Optional[Session] = None
) -> Incident:
# checks if incident with the incident criteria exists, if not it creates it
# and then assign the alert to the incident
with Session(engine) as session:
with existed_or_new_session(session) as session:
incident = session.exec(
select(Incident)
.options(joinedload(Incident.alerts))
Expand Down Expand Up @@ -1595,13 +1605,7 @@ def get_incident_for_grouping_rule(
)
session.add(incident)
session.commit()

# Re-query the incident with joinedload to set up future automatic loading of alerts
incident = session.exec(
select(Incident)
.options(joinedload(Incident.alerts))
.where(Incident.id == incident.id)
).first()
session.refresh(incident)

return incident

Expand Down Expand Up @@ -2330,8 +2334,8 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres
return preset


def assign_alert_to_incident(alert_id: UUID | str, incident_id: UUID, tenant_id: str):
return add_alerts_to_incident_by_incident_id(tenant_id, incident_id, [alert_id])
def assign_alert_to_incident(alert_id: UUID | str, incident_id: UUID, tenant_id: str, session: Optional[Session]=None):
return add_alerts_to_incident_by_incident_id(tenant_id, incident_id, [alert_id], session=session)


def is_alert_assigned_to_incident(
Expand Down Expand Up @@ -2670,15 +2674,15 @@ def get_alerts_data_for_incident(
Returns: dict {sources: list[str], services: list[str], count: int}
"""

def inner(db_session: Session):
with existed_or_new_session(session) as session:

fields = (
get_json_extract_field(session, Alert.event, "service"),
Alert.provider_type,
get_json_extract_field(session, Alert.event, "severity"),
)

alerts_data = db_session.exec(
alerts_data = session.exec(
select(*fields).where(
col(Alert.id).in_(alert_ids),
)
Expand Down Expand Up @@ -2706,22 +2710,19 @@ def inner(db_session: Session):
"count": len(alerts_data),
}

# Ensure that we have a session to execute the query. If not - make new one
if not session:
with Session(engine) as session:
return inner(session)
return inner(session)


def add_alerts_to_incident_by_incident_id(
tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID]
tenant_id: str,
incident_id: str | UUID,
alert_ids: List[UUID],
session: Optional[Session] = None,
) -> Optional[Incident]:
logger.info(
f"Adding alerts to incident {incident_id} in database, total {len(alert_ids)} alerts",
extra={"tags": {"tenant_id": tenant_id, "incident_id": incident_id}},
)

with Session(engine) as session:
with existed_or_new_session(session) as session:
query = select(Incident).where(
Incident.tenant_id == tenant_id,
Incident.id == incident_id,
Expand Down Expand Up @@ -3136,9 +3137,10 @@ def get_provider_by_type_and_id(


def bulk_upsert_alert_fields(
tenant_id: str, fields: List[str], provider_id: str, provider_type: str
tenant_id: str, fields: List[str], provider_id: str, provider_type: str,
session: Optional[Session] = None,
):
with Session(engine) as session:
with existed_or_new_session(session) as session:
try:
# Prepare the data for bulk insert
data = [
Expand Down
3 changes: 2 additions & 1 deletion keep/api/tasks/process_event_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def __handle_formatted_events(
fields=fields,
provider_id=enriched_formatted_event.providerId,
provider_type=enriched_formatted_event.providerType,
session=session
)

logger.debug(
Expand Down Expand Up @@ -384,7 +385,7 @@ def __handle_formatted_events(
# Now we need to run the rules engine
try:
rules_engine = RulesEngine(tenant_id=tenant_id)
incidents: List[IncidentDto] = rules_engine.run_rules(enriched_formatted_events)
incidents: List[IncidentDto] = rules_engine.run_rules(enriched_formatted_events, session=session)

# TODO: Replace with incidents workflow triggers. Ticket: https://github.com/keephq/keep/issues/1527
# if new grouped incidents were created, we need to push them to the client
Expand Down
8 changes: 6 additions & 2 deletions keep/rulesengine/rulesengine.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
import logging
from typing import Optional

import celpy
import celpy.c7nlib
import celpy.celparser
import celpy.celtypes
import celpy.evaluation
from sqlmodel import Session

from keep.api.consts import STATIC_PRESETS
from keep.api.core.db import assign_alert_to_incident, get_incident_for_grouping_rule
Expand Down Expand Up @@ -38,7 +40,7 @@ def __init__(self, tenant_id=None):
self.logger = logging.getLogger(__name__)
self.env = celpy.Environment()

def run_rules(self, events: list[AlertDto]) -> list[IncidentDto]:
def run_rules(self, events: list[AlertDto], session: Optional[Session] = None) -> list[IncidentDto]:
self.logger.info("Running rules")
rules = get_rules_db(tenant_id=self.tenant_id)

Expand All @@ -64,13 +66,15 @@ def run_rules(self, events: list[AlertDto]) -> list[IncidentDto]:
rule_fingerprint = self._calc_rule_fingerprint(event, rule)

incident = get_incident_for_grouping_rule(
self.tenant_id, rule, rule.timeframe, rule_fingerprint
self.tenant_id, rule, rule.timeframe, rule_fingerprint,
session=session
)

incident = assign_alert_to_incident(
alert_id=event.event_id,
incident_id=incident.id,
tenant_id=self.tenant_id,
session=session
)

incidents_dto[incident.id] = IncidentDto.from_db_incident(incident)
Expand Down

0 comments on commit b3669f1

Please sign in to comment.