diff --git a/fixbackend/app.py b/fixbackend/app.py
index 02632793..f073363d 100644
--- a/fixbackend/app.py
+++ b/fixbackend/app.py
@@ -17,7 +17,17 @@
from dataclasses import replace
from datetime import timedelta
from ssl import Purpose, create_default_context
-from typing import Any, AsyncIterator, Awaitable, Callable, ClassVar, Optional, Set, Tuple, cast
+from typing import (
+ Any,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ ClassVar,
+ Optional,
+ Set,
+ Tuple,
+ cast,
+)
import boto3
import httpx
@@ -46,7 +56,10 @@
from fixbackend.certificates.cert_store import CertificateStore
from fixbackend.cloud_accounts.account_setup import AwsAccountSetupHelper
from fixbackend.cloud_accounts.repository import CloudAccountRepositoryImpl
-from fixbackend.cloud_accounts.router import cloud_accounts_callback_router, cloud_accounts_router
+from fixbackend.cloud_accounts.router import (
+ cloud_accounts_callback_router,
+ cloud_accounts_router,
+)
from fixbackend.cloud_accounts.service_impl import CloudAccountServiceImpl
from fixbackend.collect.collect_queue import RedisCollectQueue
from fixbackend.config import Config
@@ -63,7 +76,11 @@
from fixbackend.inventory.inventory_client import InventoryClient
from fixbackend.inventory.inventory_service import InventoryService
from fixbackend.inventory.router import inventory_router
-from fixbackend.logging_context import get_logging_context, set_fix_cloud_account_id, set_workspace_id
+from fixbackend.logging_context import (
+ get_logging_context,
+ set_fix_cloud_account_id,
+ set_workspace_id,
+)
from fixbackend.metering.metering_repository import MeteringRepository
from fixbackend.middleware.x_real_ip import RealIpMiddleware
from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler
@@ -88,6 +105,7 @@ def fast_api_app(cfg: Config) -> FastAPI:
client_context = create_default_context(purpose=Purpose.SERVER_AUTH)
if ca_cert_path:
client_context.load_verify_locations(ca_cert_path)
+ http_client = deps.add(SN.http_client, AsyncClient(verify=ca_cert_path or True))
def create_redis(url: str) -> Redis:
kwargs = dict(ssl_ca_certs=ca_cert_path) if url.startswith("rediss://") else {}
@@ -97,7 +115,6 @@ def create_redis(url: str) -> Redis:
@asynccontextmanager
async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
- http_client = deps.add(SN.http_client, AsyncClient(verify=ca_cert_path or True))
arq_redis = deps.add(
SN.arq_redis,
await create_pool(
@@ -111,7 +128,8 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
deps.add(SN.readonly_redis, create_redis(cfg.redis_readonly_url))
readwrite_redis = deps.add(SN.readwrite_redis, create_redis(cfg.redis_readwrite_url))
domain_event_subscriber = deps.add(
- SN.domain_event_subscriber, DomainEventSubscriber(readwrite_redis, cfg, "fixbackend")
+ SN.domain_event_subscriber,
+ DomainEventSubscriber(readwrite_redis, cfg, "fixbackend"),
)
engine = deps.add(
SN.async_engine,
@@ -130,7 +148,10 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
deps.add(SN.collect_queue, RedisCollectQueue(arq_redis))
graph_db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker))
inventory_client = deps.add(SN.inventory_client, InventoryClient(cfg.inventory_url, http_client))
- deps.add(SN.inventory, InventoryService(inventory_client, graph_db_access, domain_event_subscriber))
+ deps.add(
+ SN.inventory,
+ InventoryService(inventory_client, graph_db_access, domain_event_subscriber),
+ )
fixbackend_events = deps.add(
SN.domain_event_redis_stream_publisher,
RedisStreamPublisher(
@@ -142,7 +163,8 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
)
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
workspace_repo = deps.add(
- SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher)
+ SN.workspace_repo,
+ WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher),
)
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
deps.add(
@@ -160,7 +182,9 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
CustomerIoEventConsumer(http_client, cfg, domain_event_subscriber),
)
cloud_accounts_redis_publisher = RedisPubSubPublisher(
- redis=readwrite_redis, channel="cloud_accounts", publisher_name="cloud_account_service"
+ redis=readwrite_redis,
+ channel="cloud_accounts",
+ publisher_name="cloud_account_service",
)
deps.add(
SN.cloud_account_service,
@@ -173,6 +197,9 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]:
cfg,
AwsAccountSetupHelper(boto_session),
dispatching=False,
+ http_client=http_client,
+ boto_session=boto_session,
+ cf_stack_queue_url=cfg.aws_cf_stack_notification_sqs_url,
),
)
@@ -199,7 +226,8 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
)
rw_redis = deps.add(SN.readwrite_redis, create_redis(cfg.redis_readwrite_url))
domain_event_subscriber = deps.add(
- SN.domain_event_subscriber, DomainEventSubscriber(rw_redis, cfg, "dispatching")
+ SN.domain_event_subscriber,
+ DomainEventSubscriber(rw_redis, cfg, "dispatching"),
)
temp_store_redis = deps.add(SN.temp_store_redis, create_redis(cfg.redis_temp_store_url))
engine = deps.add(
@@ -231,10 +259,13 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
workspace_repo = deps.add(
- SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher)
+ SN.workspace_repo,
+ WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher),
)
cloud_accounts_redis_publisher = RedisPubSubPublisher(
- redis=rw_redis, channel="cloud_accounts", publisher_name="cloud_account_service"
+ redis=rw_redis,
+ channel="cloud_accounts",
+ publisher_name="cloud_account_service",
)
deps.add(
SN.cloud_account_service,
@@ -247,6 +278,9 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]:
cfg,
AwsAccountSetupHelper(boto_session),
dispatching=True,
+ http_client=http_client,
+ boto_session=boto_session,
+ cf_stack_queue_url=cfg.aws_cf_stack_notification_sqs_url,
),
)
deps.add(
@@ -297,7 +331,8 @@ async def setup_teardown_billing(_: FastAPI) -> AsyncIterator[None]:
domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events))
metering_repo = deps.add(SN.metering_repo, MeteringRepository(session_maker))
workspace_repo = deps.add(
- SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher)
+ SN.workspace_repo,
+ WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher),
)
subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker))
aws_marketplace = deps.add(
@@ -434,7 +469,11 @@ async def refresh_session(request: Request, call_next: Callable[[Request], Await
return response
if cfg.static_assets:
- app.mount("/", StaticFiles(directory=cfg.static_assets, html=True), name="static_assets")
+ app.mount(
+ "/",
+ StaticFiles(directory=cfg.static_assets, html=True),
+ name="static_assets",
+ )
@app.get("/")
async def root(request: Request) -> Response:
@@ -466,7 +505,11 @@ def setup_process() -> FastAPI:
"""
current_config = config.get_config()
level = logging.DEBUG if current_config.args.debug else logging.INFO
- setup_logger(f"fixbackend_{current_config.args.mode}", level=level, get_logging_context=get_logging_context)
+ setup_logger(
+ f"fixbackend_{current_config.args.mode}",
+ level=level,
+ get_logging_context=get_logging_context,
+ )
# Replace all special uvicorn handlers
for logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]:
diff --git a/fixbackend/cloud_accounts/models/__init__.py b/fixbackend/cloud_accounts/models/__init__.py
index f7c52a7a..06c89797 100644
--- a/fixbackend/cloud_accounts/models/__init__.py
+++ b/fixbackend/cloud_accounts/models/__init__.py
@@ -53,6 +53,9 @@ class GcpCloudAccess(CloudAccess):
class CloudAccountState(ABC):
state_name: ClassVar[str]
+ def cloud_access(self) -> Optional[CloudAccess]:
+ return None
+
class CloudAccountStates:
"""
@@ -85,6 +88,9 @@ class Discovered(CloudAccountState):
state_name: ClassVar[str] = "discovered"
access: CloudAccess
+ def cloud_access(self) -> Optional[CloudAccess]:
+ return self.access
+
@frozen
class Configured(CloudAccountState):
"""
@@ -95,6 +101,9 @@ class Configured(CloudAccountState):
access: CloudAccess
enabled: bool # is enabled for collection
+ def cloud_access(self) -> Optional[CloudAccess]:
+ return self.access
+
@frozen
class Degraded(CloudAccountState):
"""
@@ -105,6 +114,9 @@ class Degraded(CloudAccountState):
access: CloudAccess
error: str
+ def cloud_access(self) -> Optional[CloudAccess]:
+ return self.access
+
@frozen(kw_only=True)
class CloudAccount:
diff --git a/fixbackend/cloud_accounts/service_impl.py b/fixbackend/cloud_accounts/service_impl.py
index c1035973..6c24be11 100644
--- a/fixbackend/cloud_accounts/service_impl.py
+++ b/fixbackend/cloud_accounts/service_impl.py
@@ -11,8 +11,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-
-
+import json
import uuid
from collections import defaultdict
from datetime import timedelta
@@ -20,12 +19,14 @@
from logging import getLogger
from typing import Any, Dict, List, Optional
+import boto3
from attrs import evolve
from fixcloudutils.asyncio.periodic import Periodic
from fixcloudutils.redis.event_stream import Backoff, DefaultBackoff, Json, MessageContext, RedisStreamListener
from fixcloudutils.redis.pub_sub import RedisPubSubPublisher
from fixcloudutils.service import Service
from fixcloudutils.util import utc
+from httpx import AsyncClient
from redis.asyncio import Redis
from fixbackend.cloud_accounts.account_setup import AssumeRoleResults, AwsAccountSetupHelper
@@ -54,6 +55,8 @@
WorkspaceId,
)
from fixbackend.logging_context import set_cloud_account_id, set_fix_cloud_account_id, set_workspace_id
+from fixbackend.sqs import SQSRawListener
+from fixbackend.utils import uid
from fixbackend.workspaces.repository import WorkspaceRepository
log = getLogger(__name__)
@@ -70,12 +73,14 @@ def __init__(
config: Config,
account_setup_helper: AwsAccountSetupHelper,
dispatching: bool,
+ http_client: AsyncClient,
+ boto_session: boto3.Session,
+ cf_stack_queue_url: Optional[str] = None,
) -> None:
self.workspace_repository = workspace_repository
self.cloud_account_repository = cloud_account_repository
self.pubsub_publisher = pubsub_publisher
self.domain_events = domain_event_publisher
-
backoff_config: Dict[str, Backoff] = defaultdict(lambda: DefaultBackoff)
backoff_config[AwsAccountDiscovered.kind] = Backoff(
base_delay=5,
@@ -105,17 +110,139 @@ def __init__(
self.dispatching = dispatching
self.fast_lane_timeout = timedelta(minutes=1)
self.become_degraded_timeout = timedelta(minutes=15)
+ self.cf_listener = (
+ SQSRawListener(
+ session=boto_session,
+ queue_url=cf_stack_queue_url,
+ message_processor=self.process_cf_stack_event,
+ consider_failed_after=timedelta(minutes=5),
+ max_nr_of_messages_in_one_batch=1,
+ wait_for_new_messages_to_arrive=timedelta(seconds=10),
+ )
+ if cf_stack_queue_url
+ else None
+ )
+ self.http_client = http_client
async def start(self) -> Any:
await self.domain_event_listener.start()
if self.periodic:
await self.periodic.start()
+ if self.cf_listener:
+ await self.cf_listener.start()
async def stop(self) -> Any:
+ if self.cf_listener:
+ await self.cf_listener.stop()
await self.domain_event_listener.stop()
if self.periodic:
await self.periodic.stop()
+ async def process_cf_stack_event(self, message: Json) -> Optional[CloudAccount]:
+ log.info(f"Received CF stack event: {message}")
+
+ async def send_response(
+ msg: Json, physical_resource_id: Optional[str] = None, error_message: Optional[str] = None
+ ) -> None:
+ try:
+ physical_resource_id = physical_resource_id or msg["PhysicalResourceId"]
+ request_id = msg["RequestId"]
+ logical_resource_id = msg["LogicalResourceId"]
+ response_url = msg["ResponseURL"]
+ resource_properties = msg["ResourceProperties"]
+ role_name = AwsRoleName(resource_properties["RoleName"])
+ stack_id = resource_properties["StackId"]
+ except Exception as e:
+ log.warning(f"Not enough data to inform CF: {msg}. Error: {e}")
+ return None
+
+ # Signal CF that we're done
+ response = await self.http_client.put(
+ response_url,
+ json={
+ "Status": "FAILURE" if error_message else "SUCCESS",
+ "Reason": error_message or "OK",
+ "LogicalResourceId": logical_resource_id,
+ "PhysicalResourceId": physical_resource_id,
+ "StackId": stack_id,
+ "RequestId": request_id,
+ "Data": {"RoleName": role_name},
+ },
+ )
+ if response.is_error:
+ raise RuntimeError(f"Failed to signal CF that we're done: {response}")
+
+ async def handle_stack_created(msg: Json) -> Optional[CloudAccount]:
+ try:
+ resource_properties = msg["ResourceProperties"]
+ workspace_id = WorkspaceId(uuid.UUID(resource_properties["WorkspaceId"]))
+ external_id = ExternalId(uuid.UUID(resource_properties["ExternalId"]))
+ role_name = AwsRoleName(resource_properties["RoleName"])
+ stack_id = resource_properties["StackId"]
+ assert stack_id.startswith("arn:aws:cloudformation:")
+ assert stack_id.count(":") == 5
+ account_id = CloudAccountId(stack_id.split(":")[4])
+ except Exception as e:
+ log.warning(f"Received invalid CF stack create event: {msg}. Error: {e}")
+ await send_response(msg, str(uid()), "Invalid format for CF stack create/update event")
+ return None
+ # Create/Update the account on our side
+ set_workspace_id(str(workspace_id))
+ set_cloud_account_id(account_id)
+ account = await self.create_aws_account(
+ workspace_id=workspace_id,
+ account_id=account_id,
+ role_name=role_name,
+ external_id=external_id,
+ account_name=None,
+ )
+ # Signal to CF that we're done
+ await send_response(msg, str(account.id))
+ return account
+
+ async def handle_stack_deleted(msg: Json) -> Optional[CloudAccount]:
+ try:
+ resource_properties = msg["ResourceProperties"]
+ role_name = AwsRoleName(resource_properties["RoleName"])
+ external_id = ExternalId(uuid.UUID(resource_properties["ExternalId"]))
+ cloud_account_id = FixCloudAccountId(uuid.UUID(msg["PhysicalResourceId"]))
+ except Exception as e:
+ log.warning(f"Received invalid CF stack delete event: {msg}. Error: {e}")
+ await send_response(msg, str(uid()), "Invalid format for CF stack delete event")
+ return None
+ if (
+ (account := await self.cloud_account_repository.get(cloud_account_id))
+ and isinstance(access := account.state.cloud_access(), AwsCloudAccess)
+ # also make sure the stack refers to the same role and external id
+ and access.role_name == role_name
+ and access.external_id == external_id
+ ):
+ account = await self.__degrade_account(
+ FixCloudAccountId(cloud_account_id), "CloudformationStack deleted"
+ )
+ await send_response(msg, str(cloud_account_id))
+ return account
+
+ try:
+ body = json.loads(message["Body"])
+ assert body["Type"] == "Notification"
+ content = json.loads(body["Message"])
+ kind = content["RequestType"]
+ match kind:
+ case "Create":
+ return await handle_stack_created(content)
+ case "Delete":
+ return await handle_stack_deleted(content)
+ case "Update":
+ return await handle_stack_created(content)
+ case _:
+ log.info(f"Received a CF stack event that is currently not handled. Ignore. {kind}")
+ await send_response(message) # still try to acknowledge the message
+ return None
+ except Exception as e:
+ log.warning(f"Received invalid CF stack event: {message}. Error: {e}")
+ return None
+
async def process_domain_event(self, message: Json, context: MessageContext) -> None:
match context.kind:
case TenantAccountsCollected.kind:
@@ -205,24 +332,7 @@ def fast_lane_should_end() -> bool:
if should_move_to_degraded():
log.info("failed to assume role, but timeout is reached, moving account to degraded state")
error = "Cannot assume role"
-
- def update_fn(cloud_account: CloudAccount) -> CloudAccount:
- if isinstance(cloud_account.state, CloudAccountStates.Discovered):
- return evolve(
- cloud_account, state=CloudAccountStates.Degraded(access, error), state_updated_at=utc()
- )
- else:
- return cloud_account
-
- await self.cloud_account_repository.update(account.id, update_fn)
- await self.domain_events.publish(
- AwsAccountDegraded(
- cloud_account_id=account.id,
- tenant_id=account.workspace_id,
- aws_account_id=account.account_id,
- error=error,
- )
- )
+ await self.__degrade_account(account.id, error)
return None
elif fast_lane_should_end():
log.info("Can't assume role, leaving account in discovered state")
@@ -453,3 +563,25 @@ def update_state(cloud_account: CloudAccount) -> CloudAccount:
raise ValueError(f"Account {cloud_account_id} is not configured, cannot enable account")
return await self.cloud_account_repository.update(cloud_account_id, update_state)
+
+ async def __degrade_account(
+ self,
+ account_id: FixCloudAccountId,
+ error: str,
+ ) -> CloudAccount:
+ def set_degraded(cloud_account: CloudAccount) -> CloudAccount:
+ if access := cloud_account.state.cloud_access():
+ return evolve(cloud_account, state=CloudAccountStates.Degraded(access, error), state_updated_at=utc())
+ else:
+ return cloud_account
+
+ account = await self.cloud_account_repository.update(account_id, set_degraded)
+ await self.domain_events.publish(
+ AwsAccountDegraded(
+ cloud_account_id=account.id,
+ tenant_id=account.workspace_id,
+ aws_account_id=account.account_id,
+ error=error,
+ )
+ )
+ return account
diff --git a/fixbackend/config.py b/fixbackend/config.py
index b61d7895..fa35b6f3 100644
--- a/fixbackend/config.py
+++ b/fixbackend/config.py
@@ -61,6 +61,7 @@ class Config(BaseSettings):
customerio_site_id: Optional[str]
customerio_api_key: Optional[str]
cloud_account_service_event_parallelism: int
+ aws_cf_stack_notification_sqs_url: Optional[str]
def frontend_cdn_origin(self) -> str:
return f"{self.cdn_endpoint}/{self.cdn_bucket}/{self.fixui_sha}"
@@ -123,6 +124,9 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
parser.add_argument(
"--aws-marketplace-metering-sqs-url", default=os.environ.get("AWS_MARKETPLACE_METERING_SQS_URL")
)
+ parser.add_argument(
+ "--aws-cf-stack-notification-sqs-url", default=os.environ.get("AWS_CF_STACK_NOTIFICATION_SQS_URL")
+ )
parser.add_argument("--ca-cert", type=Path, default=os.environ.get("CA_CERT"))
parser.add_argument("--host-cert", type=Path, default=os.environ.get("HOST_CERT"))
parser.add_argument("--host-key", type=Path, default=os.environ.get("HOST_KEY"))
diff --git a/fixbackend/subscription/aws_marketplace.py b/fixbackend/subscription/aws_marketplace.py
index a4d6c79d..7dc8c863 100644
--- a/fixbackend/subscription/aws_marketplace.py
+++ b/fixbackend/subscription/aws_marketplace.py
@@ -65,7 +65,7 @@ def __init__(
self.handle_message,
consider_failed_after=timedelta(minutes=5),
max_nr_of_messages_in_one_batch=1,
- wait_for_new_messages_to_arrive=timedelta(seconds=5),
+ wait_for_new_messages_to_arrive=timedelta(seconds=10),
)
if sqs_queue_url is not None
else None
diff --git a/tests/fixbackend/cloud_accounts/service_test.py b/tests/fixbackend/cloud_accounts/service_test.py
index 3bd15369..d6c0ee09 100644
--- a/tests/fixbackend/cloud_accounts/service_test.py
+++ b/tests/fixbackend/cloud_accounts/service_test.py
@@ -11,17 +11,18 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-
-
+import json
import uuid
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional, Tuple
+import boto3
import pytest
from attrs import evolve
from fixcloudutils.redis.event_stream import MessageContext, RedisStreamPublisher
from fixcloudutils.redis.pub_sub import RedisPubSubPublisher
from fixcloudutils.types import Json
+from httpx import AsyncClient, Request, Response
from redis.asyncio import Redis
from fixbackend.cloud_accounts.account_setup import AssumeRoleResult, AssumeRoleResults, AwsAccountSetupHelper
@@ -54,6 +55,8 @@
from fixbackend.workspaces.repository import WorkspaceRepositoryImpl
from fixcloudutils.util import utc
+from tests.fixbackend.conftest import RequestHandlerMock
+
class CloudAccountRepositoryMock(CloudAccountRepository):
def __init__(self) -> None:
@@ -112,6 +115,7 @@ async def list_all_discovered_accounts(self) -> List[CloudAccount]:
class OrganizationServiceMock(WorkspaceRepositoryImpl):
+ # noinspection PyMissingConstructor
def __init__(self) -> None:
pass
@@ -122,6 +126,7 @@ async def get_workspace(self, workspace_id: WorkspaceId, with_users: bool = Fals
class RedisStreamPublisherMock(RedisStreamPublisher):
+ # noinspection PyMissingConstructor
def __init__(self) -> None:
self.last_message: Optional[Tuple[str, Json]] = None
@@ -130,6 +135,7 @@ async def publish(self, kind: str, message: Json) -> None:
class RedisPubSubPublisherMock(RedisPubSubPublisher):
+ # noinspection PyMissingConstructor
def __init__(self) -> None:
self.last_message: Optional[Tuple[str, Json, Optional[str]]] = None
@@ -146,6 +152,7 @@ async def publish(self, event: Event) -> None:
class AwsAccountSetupHelperMock(AwsAccountSetupHelper):
+ # noinspection PyMissingConstructor
def __init__(self) -> None:
self.can_assume = True
self.org_accounts: Dict[CloudAccountId, CloudAccountName] = {}
@@ -170,17 +177,44 @@ async def list_account_aliases(self, assume_role_result: AssumeRoleResults.Succe
now = datetime.utcnow()
-@pytest.mark.asyncio
-async def test_create_aws_account(
+@pytest.fixture
+def repository() -> CloudAccountRepositoryMock:
+ return CloudAccountRepositoryMock()
+
+
+@pytest.fixture
+def organization_repository() -> OrganizationServiceMock:
+ return OrganizationServiceMock()
+
+
+@pytest.fixture
+def pubsub_publisher() -> RedisPubSubPublisherMock:
+ return RedisPubSubPublisherMock()
+
+
+@pytest.fixture
+def domain_sender() -> DomainEventSenderMock:
+ return DomainEventSenderMock()
+
+
+@pytest.fixture
+def account_setup_helper() -> AwsAccountSetupHelperMock:
+ return AwsAccountSetupHelperMock()
+
+
+@pytest.fixture
+def service(
+ organization_repository: OrganizationServiceMock,
+ repository: CloudAccountRepositoryMock,
+ pubsub_publisher: RedisPubSubPublisherMock,
+ domain_sender: DomainEventSenderMock,
+ account_setup_helper: AwsAccountSetupHelperMock,
arq_redis: Redis,
default_config: Config,
-) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
+ boto_session: boto3.Session,
+ http_client: AsyncClient,
+) -> CloudAccountServiceImpl:
+ return CloudAccountServiceImpl(
organization_repository,
repository,
pubsub_publisher,
@@ -189,8 +223,19 @@ async def test_create_aws_account(
default_config,
account_setup_helper,
dispatching=False,
+ boto_session=boto_session,
+ http_client=http_client,
+ cf_stack_queue_url=None,
)
+
+@pytest.mark.asyncio
+async def test_create_aws_account(
+ repository: CloudAccountRepositoryMock,
+ pubsub_publisher: RedisPubSubPublisherMock,
+ domain_sender: DomainEventSenderMock,
+ service: CloudAccountServiceImpl,
+) -> None:
# happy case
acc = await service.create_aws_account(
workspace_id=test_workspace_id,
@@ -297,23 +342,10 @@ async def test_create_aws_account(
@pytest.mark.asyncio
-async def test_delete_aws_account(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_delete_aws_account(
+ repository: CloudAccountRepositoryMock,
+ service: CloudAccountServiceImpl,
+) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -334,23 +366,7 @@ async def test_delete_aws_account(arq_redis: Redis, default_config: Config) -> N
@pytest.mark.asyncio
-async def test_store_last_run_info(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_store_last_run_info(service: CloudAccountServiceImpl) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -379,23 +395,7 @@ async def test_store_last_run_info(arq_redis: Redis, default_config: Config) ->
@pytest.mark.asyncio
-async def test_get_cloud_account(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_get_cloud_account(repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -428,24 +428,7 @@ async def test_get_cloud_account(arq_redis: Redis, default_config: Config) -> No
@pytest.mark.asyncio
-async def test_list_cloud_accounts(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_list_cloud_accounts(repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -469,24 +452,9 @@ async def test_list_cloud_accounts(arq_redis: Redis, default_config: Config) ->
@pytest.mark.asyncio
-async def test_update_cloud_account_name(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_update_cloud_account_name(
+ repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl
+) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -523,24 +491,9 @@ async def test_update_cloud_account_name(arq_redis: Redis, default_config: Confi
@pytest.mark.asyncio
-async def test_handle_account_discovered_success(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_handle_account_discovered_success(
+ repository: CloudAccountRepositoryMock, domain_sender: DomainEventSenderMock, service: CloudAccountServiceImpl
+) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -568,24 +521,12 @@ async def test_handle_account_discovered_success(arq_redis: Redis, default_confi
@pytest.mark.asyncio
-async def test_handle_account_discovered_assume_role_failure(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_handle_account_discovered_assume_role_failure(
+ repository: CloudAccountRepositoryMock,
+ domain_sender: DomainEventSenderMock,
+ service: CloudAccountServiceImpl,
+ account_setup_helper: AwsAccountSetupHelperMock,
+) -> None:
# boto3 cannot assume right away
account_id = CloudAccountId("foobar")
role_name = AwsRoleName("FooBarRole")
@@ -633,24 +574,12 @@ async def test_handle_account_discovered_assume_role_failure(arq_redis: Redis, d
@pytest.mark.asyncio
-async def test_handle_account_discovered_list_accounts_success(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_handle_account_discovered_list_accounts_success(
+ repository: CloudAccountRepositoryMock,
+ domain_sender: DomainEventSenderMock,
+ service: CloudAccountServiceImpl,
+ account_setup_helper: AwsAccountSetupHelperMock,
+) -> None:
account_id = CloudAccountId("foobar")
role_name = AwsRoleName("FooBarRole")
account = await service.create_aws_account(
@@ -687,24 +616,12 @@ async def test_handle_account_discovered_list_accounts_success(arq_redis: Redis,
@pytest.mark.asyncio
-async def test_handle_account_discovered_list_aliases_success(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_handle_account_discovered_list_aliases_success(
+ repository: CloudAccountRepositoryMock,
+ domain_sender: DomainEventSenderMock,
+ service: CloudAccountServiceImpl,
+ account_setup_helper: AwsAccountSetupHelperMock,
+) -> None:
# boto3 cannot assume right away
account_id = CloudAccountId("foobar")
role_name = AwsRoleName("FooBarRole")
@@ -743,24 +660,9 @@ async def test_handle_account_discovered_list_aliases_success(arq_redis: Redis,
@pytest.mark.asyncio
-async def test_enable_disable_cloud_account(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
-
+async def test_enable_disable_cloud_account(
+ repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl
+) -> None:
account = await service.create_aws_account(
workspace_id=test_workspace_id,
account_id=account_id,
@@ -808,7 +710,7 @@ async def test_enable_disable_cloud_account(arq_redis: Redis, default_config: Co
)
with pytest.raises(Exception):
- updated = await service.enable_cloud_account(
+ await service.enable_cloud_account(
test_workspace_id,
account.id,
)
@@ -823,23 +725,12 @@ async def test_enable_disable_cloud_account(arq_redis: Redis, default_config: Co
@pytest.mark.asyncio
-async def test_configure_account(arq_redis: Redis, default_config: Config) -> None:
- repository = CloudAccountRepositoryMock()
- organization_repository = OrganizationServiceMock()
- pubsub_publisher = RedisPubSubPublisherMock()
- domain_sender = DomainEventSenderMock()
-
- account_setup_helper = AwsAccountSetupHelperMock()
- service = CloudAccountServiceImpl(
- organization_repository,
- repository,
- pubsub_publisher,
- domain_sender,
- arq_redis,
- default_config,
- account_setup_helper,
- dispatching=False,
- )
+async def test_configure_account(
+ repository: CloudAccountRepositoryMock,
+ domain_sender: DomainEventSenderMock,
+ service: CloudAccountServiceImpl,
+ account_setup_helper: AwsAccountSetupHelperMock,
+) -> None:
account_setup_helper.can_assume = False
def get_account(state_updated_at: datetime) -> CloudAccount:
@@ -899,3 +790,60 @@ def get_account(state_updated_at: datetime) -> CloudAccount:
assert event.cloud_account_id == account.id
assert event.aws_account_id == account_id
assert event.tenant_id == account.workspace_id
+
+
+@pytest.mark.asyncio
+async def test_handle_cf_sqs_message(
+ repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl, request_handler_mock: RequestHandlerMock
+) -> None:
+ async def handle_request(_: Request) -> Response:
+ return Response(200, content=b"ok")
+
+ def notification(kind: str, physical_resource_id: Optional[str] = None) -> Json:
+ base = {
+ "RequestType": kind,
+ "ServiceToken": "arn:aws:sns:us-east-1:12345:SomeCallbacks",
+ "ResponseURL": "https://cloudformation-custom.test.com/",
+ "StackId": "arn:aws:cloudformation:us-east-1:12345:stack/name/some-id",
+ "RequestId": "855e25d5-3b80-4aed-b9f4-af8682deaf79",
+ "LogicalResourceId": "FixAccessFunction",
+ "ResourceType": "Custom::Function",
+ "ResourceProperties": {
+ "ServiceToken": "arn:aws:sns:us-east-1:12345:SomeCallbacks",
+ "RoleName": role_name,
+ "ExternalId": str(external_id),
+ "WorkspaceId": str(test_workspace_id),
+ "StackId": "arn:aws:cloudformation:us-east-1:12345:stack/name/some-id",
+ },
+ }
+ if physical_resource_id:
+ base["PhysicalResourceId"] = physical_resource_id
+ cf_message = {
+ "Type": "Notification",
+ "MessageId": "38ccbec9-9999-5871-9383-e31e58450b68",
+ "TopicArn": "arn:aws:sns:us-east-1:12345:SomeCallbacks",
+ "Subject": "AWS CloudFormation custom resource request",
+ "Message": json.dumps(base),
+ "Timestamp": "2023-11-22T08:45:16.159Z",
+ "SignatureVersion": "1",
+ "Signature": "sig",
+ "SigningCertURL": "https://cert/pem",
+ "UnsubscribeURL": "https://unsubscribe",
+ }
+ return {"Body": json.dumps(cf_message)}
+
+ # Handle Create Message
+ request_handler_mock.append(handle_request)
+ assert len(repository.accounts) == 0
+ account = await service.process_cf_stack_event(notification("Create"))
+ assert account is not None
+ assert len(repository.accounts) == 1
+ assert repository.accounts[account.id] == account
+
+ # Handle Delete Message
+ repository.accounts[account.id] = evolve(
+ account, state=CloudAccountStates.Configured(AwsCloudAccess(external_id, role_name), enabled=True)
+ )
+ account = await service.process_cf_stack_event(notification("Delete", str(account.id)))
+ assert account is not None
+ assert isinstance(account.state, CloudAccountStates.Degraded)
diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py
index 5fb0b2ac..3e51f820 100644
--- a/tests/fixbackend/conftest.py
+++ b/tests/fixbackend/conftest.py
@@ -64,7 +64,7 @@
DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb"
# only used to create/drop the database
SYNC_DATABASE_URL = "mysql+pymysql://root@127.0.0.1:3306/fixbackend-testdb"
-InventoryMock = List[Callable[[Request], Awaitable[Response]]]
+RequestHandlerMock = List[Callable[[Request], Awaitable[Response]]]
os.environ["LOCAL_DEV_ENV"] = "true"
@@ -117,6 +117,7 @@ def default_config() -> Config:
customerio_site_id=None,
customerio_api_key=None,
cloud_account_service_event_parallelism=1000,
+ aws_cf_stack_notification_sqs_url=None,
)
@@ -341,22 +342,28 @@ def nd_json_response(content: Sequence[JsonElement]) -> Response:
@pytest.fixture
-async def inventory_mock() -> InventoryMock:
+async def request_handler_mock() -> RequestHandlerMock:
return []
@pytest.fixture
-async def inventory_client(inventory_mock: InventoryMock) -> AsyncIterator[InventoryClient]:
+async def http_client(request_handler_mock: RequestHandlerMock) -> AsyncClient:
async def app(request: Request) -> Response:
- for mock in inventory_mock:
+ for mock in request_handler_mock:
try:
return await mock(request)
except AttributeError:
pass
raise AttributeError(f'Unexpected request: {request.url.path} with content {request.content.decode("utf-8")}')
- async_client = AsyncClient(transport=MockTransport(app))
- async with InventoryClient("http://localhost:8980", client=async_client) as client:
+ return AsyncClient(transport=MockTransport(app))
+
+
+@pytest.fixture
+async def inventory_client(
+ http_client: AsyncClient, request_handler_mock: RequestHandlerMock
+) -> AsyncIterator[InventoryClient]:
+ async with InventoryClient("http://localhost:8980", client=http_client) as client:
yield client
diff --git a/tests/fixbackend/inventory/inventory_client_test.py b/tests/fixbackend/inventory/inventory_client_test.py
index 9d7b4ce2..0cf5236a 100644
--- a/tests/fixbackend/inventory/inventory_client_test.py
+++ b/tests/fixbackend/inventory/inventory_client_test.py
@@ -30,7 +30,7 @@
from fixbackend.ids import WorkspaceId, CloudAccountId, NodeId
from fixbackend.inventory.inventory_client import InventoryClient
from fixbackend.inventory.schemas import CompletePathRequest
-from tests.fixbackend.conftest import InventoryMock, nd_json_response, json_response
+from tests.fixbackend.conftest import RequestHandlerMock, nd_json_response, json_response
db_access = GraphDatabaseAccess(WorkspaceId(uuid.uuid1()), "server", "database", "username", "password")
@@ -38,7 +38,7 @@
@pytest.fixture
def mocked_inventory_client(
inventory_client: InventoryClient,
- inventory_mock: InventoryMock,
+ request_handler_mock: RequestHandlerMock,
azure_virtual_machine_resource_json: Json,
aws_ec2_model_json: Json,
) -> InventoryClient:
@@ -82,7 +82,7 @@ async def mock(request: Request) -> Response:
else:
raise AttributeError(f"Unexpected request: {request.method} {request.url.path} with content {content}")
- inventory_mock.append(mock)
+ request_handler_mock.append(mock)
return inventory_client
diff --git a/tests/fixbackend/inventory/inventory_service_test.py b/tests/fixbackend/inventory/inventory_service_test.py
index ff61b78d..6b6c652b 100644
--- a/tests/fixbackend/inventory/inventory_service_test.py
+++ b/tests/fixbackend/inventory/inventory_service_test.py
@@ -44,7 +44,7 @@
HistorySearch,
HistoryChange,
)
-from tests.fixbackend.conftest import InventoryMock, json_response, nd_json_response
+from tests.fixbackend.conftest import RequestHandlerMock, json_response, nd_json_response
from fixbackend.domain_events.subscriber import DomainEventSubscriber
db = GraphDatabaseAccess(WorkspaceId(uuid.uuid1()), "server", "database", "username", "password")
@@ -60,10 +60,10 @@
@pytest.fixture
def mocked_answers(
- inventory_mock: InventoryMock,
+ request_handler_mock: RequestHandlerMock,
benchmark_json: List[Json],
azure_virtual_machine_resource_json: Json,
-) -> InventoryMock:
+) -> RequestHandlerMock:
async def mock(request: Request) -> Response:
content = request.content.decode("utf-8")
if request.url.path == "/cli/execute" and content.endswith("jq --no-rewrite .group"):
@@ -109,18 +109,18 @@ async def mock(request: Request) -> Response:
else:
raise AttributeError(f"Unexpected request: {request.url.path} with content {content}")
- inventory_mock.append(mock)
- return inventory_mock
+ request_handler_mock.append(mock)
+ return request_handler_mock
async def test_benchmark_command(
- inventory_service: InventoryService, benchmark_json: List[Json], mocked_answers: InventoryMock
+ inventory_service: InventoryService, benchmark_json: List[Json], mocked_answers: RequestHandlerMock
) -> None:
response = [a async for a in await inventory_service.benchmark(db, "benchmark_name")]
assert response == benchmark_json
-async def test_summary(inventory_service: InventoryService, mocked_answers: InventoryMock) -> None:
+async def test_summary(inventory_service: InventoryService, mocked_answers: RequestHandlerMock) -> None:
summary = await inventory_service.summary(db)
assert len(summary.benchmarks) == 2
assert summary.overall_score == 42
@@ -188,7 +188,7 @@ async def test_dict_values_by() -> None:
assert [a for a in dict_values_by(inv, lambda x: -x)] == [1, 2, 3, 11, 12, 13, 21, 22, 23]
-async def test_search_list(inventory_service: InventoryService, mocked_answers: InventoryMock) -> None:
+async def test_search_list(inventory_service: InventoryService, mocked_answers: RequestHandlerMock) -> None:
expected = [
{
"columns": [
@@ -208,7 +208,7 @@ async def test_search_list(inventory_service: InventoryService, mocked_answers:
assert result == expected
-async def test_search_start_data(inventory_service: InventoryService, mocked_answers: InventoryMock) -> None:
+async def test_search_start_data(inventory_service: InventoryService, mocked_answers: RequestHandlerMock) -> None:
result = [
SearchCloudResource(id="234", name="bla", cloud="gcp"),
SearchCloudResource(id="123", name="foo", cloud="aws"),
@@ -221,7 +221,7 @@ async def test_search_start_data(inventory_service: InventoryService, mocked_ans
async def test_resource(
- inventory_service: InventoryService, mocked_answers: InventoryMock, azure_virtual_machine_resource_json: Json
+ inventory_service: InventoryService, mocked_answers: RequestHandlerMock, azure_virtual_machine_resource_json: Json
) -> None:
res = await inventory_service.resource(db, NodeId("some_node_id"))
assert res["neighborhood"] == neighborhood