From 9d2c52dca7a513386f4940e6b08dcbcffd2d5723 Mon Sep 17 00:00:00 2001 From: hsong-rh Date: Thu, 15 Dec 2022 11:43:58 -0500 Subject: [PATCH] Add disable/enable endpoints for rulebooks/rulesets --- src/eda_server/api/rulebook.py | 138 ++++++++++++++++++ src/eda_server/db/models/rulebook.py | 14 ++ src/eda_server/schema/rulebook.py | 3 + tests/integration/api/test_rule.py | 2 + tests/integration/api/test_rulebook.py | 190 +++++++++++++++++++++++++ 5 files changed, 347 insertions(+) diff --git a/src/eda_server/api/rulebook.py b/src/eda_server/api/rulebook.py index cd4108a2..594d61e3 100644 --- a/src/eda_server/api/rulebook.py +++ b/src/eda_server/api/rulebook.py @@ -17,10 +17,12 @@ import yaml from fastapi import APIRouter, Depends, HTTPException, status +import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession from eda_server import schema from eda_server.auth import requires_permission +from eda_server.db import models from eda_server.db.dependency import get_db_session # Rule, Ruleset, Rulebook query builder, enums, etc @@ -255,6 +257,74 @@ async def read_ruleset( return response +@router.patch( + "/api/rulesets/{ruleset_id}/enable", + response_model=schema.RulesetDetail, + operation_id="enable_ruleset", + dependencies=[ + Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)), + ], +) +async def enable_ruleset( + ruleset_id: int, db: AsyncSession = Depends(get_db_session) +): + ruleset = await rsql.get_ruleset(db, ruleset_id) + if not ruleset: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Ruleset not found.", + ) + + await db.execute( + sa.update(models.rulesets) + .where(models.rulesets.c.id == ruleset_id) + .values(enabled=True) + ) + await db.commit() + + updated_ruleset = await rsql.get_ruleset(db, ruleset_id) + ruleset_counts = await rsql.get_ruleset_counts(db, ruleset_id) + response = updated_ruleset._asdict() + response["fired_stats"] = await build_detail_object_totals( + ruleset_counts, updated_ruleset.id + ) + return response + + +@router.patch( + "/api/rulesets/{ruleset_id}/disable", + response_model=schema.RulesetDetail, + operation_id="disable_ruleset", + dependencies=[ + Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)), + ], +) +async def disable_ruleset( + ruleset_id: int, db: AsyncSession = Depends(get_db_session) +): + ruleset = await rsql.get_ruleset(db, ruleset_id) + if not ruleset: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Ruleset not found.", + ) + + await db.execute( + sa.update(models.rulesets) + .where(models.rulesets.c.id == ruleset_id) + .values(enabled=False) + ) + await db.commit() + + updated_ruleset = await rsql.get_ruleset(db, ruleset_id) + ruleset_counts = await rsql.get_ruleset_counts(db, ruleset_id) + response = updated_ruleset._asdict() + response["fired_stats"] = await build_detail_object_totals( + ruleset_counts, updated_ruleset.id + ) + return response + + @router.get( "/api/rulesets/{ruleset_id}/rules", response_model=List[schema.RuleList], @@ -352,6 +422,74 @@ async def read_rulebook( return result +@router.patch( + "/api/rulebooks/{rulebook_id}/enable", + operation_id="enable_rulebook", + response_model=schema.RulebookRead, + dependencies=[ + Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)), + ], +) +async def enable_rulebook( + rulebook_id: int, db: AsyncSession = Depends(get_db_session) +): + result = await rsql.get_rulebook(db, rulebook_id) + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Rulebook Not Found.", + ) + + await db.execute( + sa.update(models.rulebooks) + .where(models.rulebooks.c.id == rulebook_id) + .values(enabled=True) + ) + await db.execute( + sa.update(models.rulesets) + .where(models.rulesets.c.rulebook_id == rulebook_id) + .values(enabled=True) + ) + await db.commit() + + updated_rulebook = await rsql.get_rulebook(db, rulebook_id) + return updated_rulebook + + +@router.patch( + "/api/rulebooks/{rulebook_id}/disable", + operation_id="disable_rulebook", + response_model=schema.RulebookRead, + dependencies=[ + Depends(requires_permission(ResourceType.RULEBOOK, Action.UPDATE)), + ], +) +async def disable_rulebook( + rulebook_id: int, db: AsyncSession = Depends(get_db_session) +): + result = await rsql.get_rulebook(db, rulebook_id) + if not result: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Rulebook Not Found.", + ) + + await db.execute( + sa.update(models.rulebooks) + .where(models.rulebooks.c.id == rulebook_id) + .values(enabled=False) + ) + await db.execute( + sa.update(models.rulesets) + .where(models.rulesets.c.rulebook_id == rulebook_id) + .values(enabled=False) + ) + await db.commit() + + updated_rulebook = await rsql.get_rulebook(db, rulebook_id) + return updated_rulebook + + @router.get( "/api/rulebook_json/{rulebook_id}", operation_id="read_rulebook_json", diff --git a/src/eda_server/db/models/rulebook.py b/src/eda_server/db/models/rulebook.py index c4abe414..45b01e23 100644 --- a/src/eda_server/db/models/rulebook.py +++ b/src/eda_server/db/models/rulebook.py @@ -55,6 +55,13 @@ server_default=func.now(), onupdate=func.now(), ), + sa.Column( + "enabled", + sa.Boolean, + nullable=False, + default=True, + server_default=sa.true(), + ), ) @@ -91,6 +98,13 @@ server_default=func.now(), onupdate=func.now(), ), + sa.Column( + "enabled", + sa.Boolean, + nullable=False, + default=True, + server_default=sa.true(), + ), ) rules = sa.Table( diff --git a/src/eda_server/schema/rulebook.py b/src/eda_server/schema/rulebook.py index 8ccb357d..d8a80b6c 100644 --- a/src/eda_server/schema/rulebook.py +++ b/src/eda_server/schema/rulebook.py @@ -32,6 +32,7 @@ class RulebookRead(BaseModel): ruleset_count: int created_at: datetime modified_at: datetime + enabled: bool = True class RulebookRulesetList(BaseModel): @@ -84,6 +85,7 @@ class RulesetList(BaseModel): source_types: Optional[List[str]] created_at: datetime modified_at: datetime + enabled: bool = True fired_stats: Optional[List[FireCountsListRef]] @@ -108,6 +110,7 @@ class RulesetDetail(BaseModel): rule_count: int created_at: datetime modified_at: datetime + enabled: bool = True sources: Optional[List[RulesetSourceRef]] rulebook: Optional[RulebookRef] project: Optional[RulesetProjectRef] diff --git a/tests/integration/api/test_rule.py b/tests/integration/api/test_rule.py index 57c87520..ef050f91 100644 --- a/tests/integration/api/test_rule.py +++ b/tests/integration/api/test_rule.py @@ -348,6 +348,7 @@ async def test_list_rulesets( "pct_window_total": 100, } ], + "enabled": True, } ] @@ -374,6 +375,7 @@ async def test_list_rulesets_no_stats( "modified_at": ruleset.created_at.isoformat(), "source_types": ["range"], "fired_stats": [], + "enabled": True, } ] diff --git a/tests/integration/api/test_rulebook.py b/tests/integration/api/test_rulebook.py index 8ea13b54..f253d9dd 100644 --- a/tests/integration/api/test_rulebook.py +++ b/tests/integration/api/test_rulebook.py @@ -151,3 +151,193 @@ async def test_list_rulebook_rulesets(client: AsyncClient, db: AsyncSession): assert rulebook_rulesets[ix]["id"] == rulesets[ix].id assert rulebook_rulesets[ix]["name"] == rulesets[ix].name assert rulebook_rulesets[ix]["rule_count"] == rulesets[ix].rule_count + + +async def _create_rulebook_dependent_objects(db: AsyncSession): + (project_id,) = ( + await db.execute( + sa.insert(models.projects).values( + name="test_project_name", url="http://example.com" + ) + ) + ).inserted_primary_key + + (rulebook_id,) = ( + await db.execute( + sa.insert(models.rulebooks).values( + name="test_rulebook_name", + rulesets=TEST_RULESETS_SIMPLE, + project_id=project_id, + ) + ) + ).inserted_primary_key + + (ruleset_id_1,) = ( + await db.execute( + sa.insert(models.rulesets).values( + name="test_ruleset_name_1", + rulebook_id=rulebook_id, + ) + ) + ).inserted_primary_key + + (ruleset_id_2,) = ( + await db.execute( + sa.insert(models.rulesets).values( + name="test_ruleset_name_2", + rulebook_id=rulebook_id, + ) + ) + ).inserted_primary_key + + foreign_keys = { + "rulebook_id": rulebook_id, + "ruleset_ids": [ruleset_id_1, ruleset_id_2], + } + + return foreign_keys + + +async def test_disable_rulebooks(client: AsyncClient, db: AsyncSession): + foreign_keys = await _create_rulebook_dependent_objects(db) + rulebook_id = foreign_keys["rulebook_id"] + + response = await client.patch( + f"/api/rulebooks/{rulebook_id}/disable", + ) + + assert response.status_code == status_codes.HTTP_200_OK + + data = response.json() + + assert data["id"] == rulebook_id + assert data["enabled"] is False + + rulesets = await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.rulebook_id == rulebook_id, + ) + ) + + for ruleset in rulesets.all(): + assert ruleset["enabled"] is False + + +async def test_enable_rulebook(client: AsyncClient, db: AsyncSession): + foreign_keys = await _create_rulebook_dependent_objects(db) + rulebook_id = foreign_keys["rulebook_id"] + + response = await client.patch( + f"/api/rulebooks/{rulebook_id}/disable", + ) + + assert response.status_code == status_codes.HTTP_200_OK + + data = response.json() + assert data["id"] == rulebook_id + assert data["enabled"] is False + + rulesets = await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.rulebook_id == rulebook_id, + ) + ) + + for ruleset in rulesets.all(): + assert ruleset["enabled"] is False + + response = await client.patch( + f"/api/rulebooks/{rulebook_id}/enable", + ) + + assert response.status_code == status_codes.HTTP_200_OK + + data = response.json() + assert data["id"] == rulebook_id + assert data["enabled"] is True + + rulesets = await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.rulebook_id == rulebook_id, + ) + ) + + for ruleset in rulesets.all(): + assert ruleset["enabled"] is True + + +async def test_disable_ruleset(client: AsyncClient, db: AsyncSession): + foreign_keys = await _create_rulebook_dependent_objects(db) + ruleset_id_1 = foreign_keys["ruleset_ids"][0] + + response = await client.patch( + f"/api/rulesets/{ruleset_id_1}/disable", + ) + + assert response.status_code == status_codes.HTTP_200_OK + + ruleset_1 = (await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.id == ruleset_id_1, + ) + )).first() + + ruleset_2 = (await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.id == foreign_keys["ruleset_ids"][1], + ) + )).first() + + assert ruleset_1["id"] == foreign_keys["ruleset_ids"][0] + assert ruleset_1["enabled"] is False + + assert ruleset_2["id"] == foreign_keys["ruleset_ids"][1] + assert ruleset_2["enabled"] is True + + +async def test_enable_ruleset(client: AsyncClient, db: AsyncSession): + foreign_keys = await _create_rulebook_dependent_objects(db) + ruleset_id_1 = foreign_keys["ruleset_ids"][0] + + response = await client.patch( + f"/api/rulesets/{ruleset_id_1}/disable", + ) + + assert response.status_code == status_codes.HTTP_200_OK + + ruleset_1 = (await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.id == ruleset_id_1, + ) + )).first() + + ruleset_2 = (await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.id == foreign_keys["ruleset_ids"][1], + ) + )).first() + + assert ruleset_1["enabled"] is False + assert ruleset_2["enabled"] is True + + response = await client.patch( + f"/api/rulesets/{ruleset_id_1}/enable", + ) + + assert response.status_code == status_codes.HTTP_200_OK + + ruleset_1 = (await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.id == ruleset_id_1, + ) + )).first() + + ruleset_2 = (await db.execute( + sa.select(models.rulesets).where( + models.rulesets.c.id == foreign_keys["ruleset_ids"][1], + ) + )).first() + + assert ruleset_1["enabled"] is True + assert ruleset_2["enabled"] is True +