Skip to content
This repository has been archived by the owner on Mar 4, 2024. It is now read-only.

Commit

Permalink
Add disable/enable endpoints for rulebooks/rulesets
Browse files Browse the repository at this point in the history
  • Loading branch information
hsong-rh committed Dec 16, 2022
1 parent 752604d commit 9d2c52d
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 0 deletions.
138 changes: 138 additions & 0 deletions src/eda_server/api/rulebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

import yaml
from fastapi import APIRouter, Depends, HTTPException, status
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession

from eda_server import schema
from eda_server.auth import requires_permission
from eda_server.db import models
from eda_server.db.dependency import get_db_session

# Rule, Ruleset, Rulebook query builder, enums, etc
Expand Down Expand Up @@ -255,6 +257,74 @@ async def read_ruleset(
return response


@router.patch(
"/api/rulesets/{ruleset_id}/enable",
response_model=schema.RulesetDetail,
operation_id="enable_ruleset",
dependencies=[
Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)),
],
)
async def enable_ruleset(
ruleset_id: int, db: AsyncSession = Depends(get_db_session)
):
ruleset = await rsql.get_ruleset(db, ruleset_id)
if not ruleset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Ruleset not found.",
)

await db.execute(
sa.update(models.rulesets)
.where(models.rulesets.c.id == ruleset_id)
.values(enabled=True)
)
await db.commit()

updated_ruleset = await rsql.get_ruleset(db, ruleset_id)
ruleset_counts = await rsql.get_ruleset_counts(db, ruleset_id)
response = updated_ruleset._asdict()
response["fired_stats"] = await build_detail_object_totals(
ruleset_counts, updated_ruleset.id
)
return response


@router.patch(
"/api/rulesets/{ruleset_id}/disable",
response_model=schema.RulesetDetail,
operation_id="disable_ruleset",
dependencies=[
Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)),
],
)
async def disable_ruleset(
ruleset_id: int, db: AsyncSession = Depends(get_db_session)
):
ruleset = await rsql.get_ruleset(db, ruleset_id)
if not ruleset:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Ruleset not found.",
)

await db.execute(
sa.update(models.rulesets)
.where(models.rulesets.c.id == ruleset_id)
.values(enabled=False)
)
await db.commit()

updated_ruleset = await rsql.get_ruleset(db, ruleset_id)
ruleset_counts = await rsql.get_ruleset_counts(db, ruleset_id)
response = updated_ruleset._asdict()
response["fired_stats"] = await build_detail_object_totals(
ruleset_counts, updated_ruleset.id
)
return response


@router.get(
"/api/rulesets/{ruleset_id}/rules",
response_model=List[schema.RuleList],
Expand Down Expand Up @@ -352,6 +422,74 @@ async def read_rulebook(
return result


@router.patch(
"/api/rulebooks/{rulebook_id}/enable",
operation_id="enable_rulebook",
response_model=schema.RulebookRead,
dependencies=[
Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)),
],
)
async def enable_rulebook(
rulebook_id: int, db: AsyncSession = Depends(get_db_session)
):
result = await rsql.get_rulebook(db, rulebook_id)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Rulebook Not Found.",
)

await db.execute(
sa.update(models.rulebooks)
.where(models.rulebooks.c.id == rulebook_id)
.values(enabled=True)
)
await db.execute(
sa.update(models.rulesets)
.where(models.rulesets.c.rulebook_id == rulebook_id)
.values(enabled=True)
)
await db.commit()

updated_rulebook = await rsql.get_rulebook(db, rulebook_id)
return updated_rulebook


@router.patch(
"/api/rulebooks/{rulebook_id}/disable",
operation_id="disable_rulebook",
response_model=schema.RulebookRead,
dependencies=[
Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)),
],
)
async def disable_rulebook(
rulebook_id: int, db: AsyncSession = Depends(get_db_session)
):
result = await rsql.get_rulebook(db, rulebook_id)
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Rulebook Not Found.",
)

await db.execute(
sa.update(models.rulebooks)
.where(models.rulebooks.c.id == rulebook_id)
.values(enabled=False)
)
await db.execute(
sa.update(models.rulesets)
.where(models.rulesets.c.rulebook_id == rulebook_id)
.values(enabled=False)
)
await db.commit()

updated_rulebook = await rsql.get_rulebook(db, rulebook_id)
return updated_rulebook


@router.get(
"/api/rulebook_json/{rulebook_id}",
operation_id="read_rulebook_json",
Expand Down
14 changes: 14 additions & 0 deletions src/eda_server/db/models/rulebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@
server_default=func.now(),
onupdate=func.now(),
),
sa.Column(
"enabled",
sa.Boolean,
nullable=False,
default=True,
server_default=sa.true(),
),
)


Expand Down Expand Up @@ -91,6 +98,13 @@
server_default=func.now(),
onupdate=func.now(),
),
sa.Column(
"enabled",
sa.Boolean,
nullable=False,
default=True,
server_default=sa.true(),
),
)

rules = sa.Table(
Expand Down
3 changes: 3 additions & 0 deletions src/eda_server/schema/rulebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class RulebookRead(BaseModel):
ruleset_count: int
created_at: datetime
modified_at: datetime
enabled: bool = True


class RulebookRulesetList(BaseModel):
Expand Down Expand Up @@ -84,6 +85,7 @@ class RulesetList(BaseModel):
source_types: Optional[List[str]]
created_at: datetime
modified_at: datetime
enabled: bool = True
fired_stats: Optional[List[FireCountsListRef]]


Expand All @@ -108,6 +110,7 @@ class RulesetDetail(BaseModel):
rule_count: int
created_at: datetime
modified_at: datetime
enabled: bool = True
sources: Optional[List[RulesetSourceRef]]
rulebook: Optional[RulebookRef]
project: Optional[RulesetProjectRef]
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/api/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ async def test_list_rulesets(
"pct_window_total": 100,
}
],
"enabled": True,
}
]

Expand All @@ -374,6 +375,7 @@ async def test_list_rulesets_no_stats(
"modified_at": ruleset.created_at.isoformat(),
"source_types": ["range"],
"fired_stats": [],
"enabled": True,
}
]

Expand Down
Loading

0 comments on commit 9d2c52d

Please sign in to comment.