diff --git a/fixbackend/app.py b/fixbackend/app.py index 02632793..f073363d 100644 --- a/fixbackend/app.py +++ b/fixbackend/app.py @@ -17,7 +17,17 @@ from dataclasses import replace from datetime import timedelta from ssl import Purpose, create_default_context -from typing import Any, AsyncIterator, Awaitable, Callable, ClassVar, Optional, Set, Tuple, cast +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + ClassVar, + Optional, + Set, + Tuple, + cast, +) import boto3 import httpx @@ -46,7 +56,10 @@ from fixbackend.certificates.cert_store import CertificateStore from fixbackend.cloud_accounts.account_setup import AwsAccountSetupHelper from fixbackend.cloud_accounts.repository import CloudAccountRepositoryImpl -from fixbackend.cloud_accounts.router import cloud_accounts_callback_router, cloud_accounts_router +from fixbackend.cloud_accounts.router import ( + cloud_accounts_callback_router, + cloud_accounts_router, +) from fixbackend.cloud_accounts.service_impl import CloudAccountServiceImpl from fixbackend.collect.collect_queue import RedisCollectQueue from fixbackend.config import Config @@ -63,7 +76,11 @@ from fixbackend.inventory.inventory_client import InventoryClient from fixbackend.inventory.inventory_service import InventoryService from fixbackend.inventory.router import inventory_router -from fixbackend.logging_context import get_logging_context, set_fix_cloud_account_id, set_workspace_id +from fixbackend.logging_context import ( + get_logging_context, + set_fix_cloud_account_id, + set_workspace_id, +) from fixbackend.metering.metering_repository import MeteringRepository from fixbackend.middleware.x_real_ip import RealIpMiddleware from fixbackend.subscription.aws_marketplace import AwsMarketplaceHandler @@ -88,6 +105,7 @@ def fast_api_app(cfg: Config) -> FastAPI: client_context = create_default_context(purpose=Purpose.SERVER_AUTH) if ca_cert_path: client_context.load_verify_locations(ca_cert_path) + http_client = deps.add(SN.http_client, AsyncClient(verify=ca_cert_path or True)) def create_redis(url: str) -> Redis: kwargs = dict(ssl_ca_certs=ca_cert_path) if url.startswith("rediss://") else {} @@ -97,7 +115,6 @@ def create_redis(url: str) -> Redis: @asynccontextmanager async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: - http_client = deps.add(SN.http_client, AsyncClient(verify=ca_cert_path or True)) arq_redis = deps.add( SN.arq_redis, await create_pool( @@ -111,7 +128,8 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: deps.add(SN.readonly_redis, create_redis(cfg.redis_readonly_url)) readwrite_redis = deps.add(SN.readwrite_redis, create_redis(cfg.redis_readwrite_url)) domain_event_subscriber = deps.add( - SN.domain_event_subscriber, DomainEventSubscriber(readwrite_redis, cfg, "fixbackend") + SN.domain_event_subscriber, + DomainEventSubscriber(readwrite_redis, cfg, "fixbackend"), ) engine = deps.add( SN.async_engine, @@ -130,7 +148,10 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: deps.add(SN.collect_queue, RedisCollectQueue(arq_redis)) graph_db_access = deps.add(SN.graph_db_access, GraphDatabaseAccessManager(cfg, session_maker)) inventory_client = deps.add(SN.inventory_client, InventoryClient(cfg.inventory_url, http_client)) - deps.add(SN.inventory, InventoryService(inventory_client, graph_db_access, domain_event_subscriber)) + deps.add( + SN.inventory, + InventoryService(inventory_client, graph_db_access, domain_event_subscriber), + ) fixbackend_events = deps.add( SN.domain_event_redis_stream_publisher, RedisStreamPublisher( @@ -142,7 +163,8 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: ) domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events)) workspace_repo = deps.add( - SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher) + SN.workspace_repo, + WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher), ) subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker)) deps.add( @@ -160,7 +182,9 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: CustomerIoEventConsumer(http_client, cfg, domain_event_subscriber), ) cloud_accounts_redis_publisher = RedisPubSubPublisher( - redis=readwrite_redis, channel="cloud_accounts", publisher_name="cloud_account_service" + redis=readwrite_redis, + channel="cloud_accounts", + publisher_name="cloud_account_service", ) deps.add( SN.cloud_account_service, @@ -173,6 +197,9 @@ async def setup_teardown_application(_: FastAPI) -> AsyncIterator[None]: cfg, AwsAccountSetupHelper(boto_session), dispatching=False, + http_client=http_client, + boto_session=boto_session, + cf_stack_queue_url=cfg.aws_cf_stack_notification_sqs_url, ), ) @@ -199,7 +226,8 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]: ) rw_redis = deps.add(SN.readwrite_redis, create_redis(cfg.redis_readwrite_url)) domain_event_subscriber = deps.add( - SN.domain_event_subscriber, DomainEventSubscriber(rw_redis, cfg, "dispatching") + SN.domain_event_subscriber, + DomainEventSubscriber(rw_redis, cfg, "dispatching"), ) temp_store_redis = deps.add(SN.temp_store_redis, create_redis(cfg.redis_temp_store_url)) engine = deps.add( @@ -231,10 +259,13 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]: domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events)) workspace_repo = deps.add( - SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher) + SN.workspace_repo, + WorkspaceRepositoryImpl(session_maker, db_access, domain_event_publisher), ) cloud_accounts_redis_publisher = RedisPubSubPublisher( - redis=rw_redis, channel="cloud_accounts", publisher_name="cloud_account_service" + redis=rw_redis, + channel="cloud_accounts", + publisher_name="cloud_account_service", ) deps.add( SN.cloud_account_service, @@ -247,6 +278,9 @@ async def setup_teardown_dispatcher(_: FastAPI) -> AsyncIterator[None]: cfg, AwsAccountSetupHelper(boto_session), dispatching=True, + http_client=http_client, + boto_session=boto_session, + cf_stack_queue_url=cfg.aws_cf_stack_notification_sqs_url, ), ) deps.add( @@ -297,7 +331,8 @@ async def setup_teardown_billing(_: FastAPI) -> AsyncIterator[None]: domain_event_publisher = deps.add(SN.domain_event_sender, DomainEventPublisherImpl(fixbackend_events)) metering_repo = deps.add(SN.metering_repo, MeteringRepository(session_maker)) workspace_repo = deps.add( - SN.workspace_repo, WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher) + SN.workspace_repo, + WorkspaceRepositoryImpl(session_maker, graph_db_access, domain_event_publisher), ) subscription_repo = deps.add(SN.subscription_repo, SubscriptionRepository(session_maker)) aws_marketplace = deps.add( @@ -434,7 +469,11 @@ async def refresh_session(request: Request, call_next: Callable[[Request], Await return response if cfg.static_assets: - app.mount("/", StaticFiles(directory=cfg.static_assets, html=True), name="static_assets") + app.mount( + "/", + StaticFiles(directory=cfg.static_assets, html=True), + name="static_assets", + ) @app.get("/") async def root(request: Request) -> Response: @@ -466,7 +505,11 @@ def setup_process() -> FastAPI: """ current_config = config.get_config() level = logging.DEBUG if current_config.args.debug else logging.INFO - setup_logger(f"fixbackend_{current_config.args.mode}", level=level, get_logging_context=get_logging_context) + setup_logger( + f"fixbackend_{current_config.args.mode}", + level=level, + get_logging_context=get_logging_context, + ) # Replace all special uvicorn handlers for logger in ["uvicorn", "uvicorn.error", "uvicorn.access"]: diff --git a/fixbackend/cloud_accounts/models/__init__.py b/fixbackend/cloud_accounts/models/__init__.py index f7c52a7a..06c89797 100644 --- a/fixbackend/cloud_accounts/models/__init__.py +++ b/fixbackend/cloud_accounts/models/__init__.py @@ -53,6 +53,9 @@ class GcpCloudAccess(CloudAccess): class CloudAccountState(ABC): state_name: ClassVar[str] + def cloud_access(self) -> Optional[CloudAccess]: + return None + class CloudAccountStates: """ @@ -85,6 +88,9 @@ class Discovered(CloudAccountState): state_name: ClassVar[str] = "discovered" access: CloudAccess + def cloud_access(self) -> Optional[CloudAccess]: + return self.access + @frozen class Configured(CloudAccountState): """ @@ -95,6 +101,9 @@ class Configured(CloudAccountState): access: CloudAccess enabled: bool # is enabled for collection + def cloud_access(self) -> Optional[CloudAccess]: + return self.access + @frozen class Degraded(CloudAccountState): """ @@ -105,6 +114,9 @@ class Degraded(CloudAccountState): access: CloudAccess error: str + def cloud_access(self) -> Optional[CloudAccess]: + return self.access + @frozen(kw_only=True) class CloudAccount: diff --git a/fixbackend/cloud_accounts/service_impl.py b/fixbackend/cloud_accounts/service_impl.py index c1035973..6c24be11 100644 --- a/fixbackend/cloud_accounts/service_impl.py +++ b/fixbackend/cloud_accounts/service_impl.py @@ -11,8 +11,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - - +import json import uuid from collections import defaultdict from datetime import timedelta @@ -20,12 +19,14 @@ from logging import getLogger from typing import Any, Dict, List, Optional +import boto3 from attrs import evolve from fixcloudutils.asyncio.periodic import Periodic from fixcloudutils.redis.event_stream import Backoff, DefaultBackoff, Json, MessageContext, RedisStreamListener from fixcloudutils.redis.pub_sub import RedisPubSubPublisher from fixcloudutils.service import Service from fixcloudutils.util import utc +from httpx import AsyncClient from redis.asyncio import Redis from fixbackend.cloud_accounts.account_setup import AssumeRoleResults, AwsAccountSetupHelper @@ -54,6 +55,8 @@ WorkspaceId, ) from fixbackend.logging_context import set_cloud_account_id, set_fix_cloud_account_id, set_workspace_id +from fixbackend.sqs import SQSRawListener +from fixbackend.utils import uid from fixbackend.workspaces.repository import WorkspaceRepository log = getLogger(__name__) @@ -70,12 +73,14 @@ def __init__( config: Config, account_setup_helper: AwsAccountSetupHelper, dispatching: bool, + http_client: AsyncClient, + boto_session: boto3.Session, + cf_stack_queue_url: Optional[str] = None, ) -> None: self.workspace_repository = workspace_repository self.cloud_account_repository = cloud_account_repository self.pubsub_publisher = pubsub_publisher self.domain_events = domain_event_publisher - backoff_config: Dict[str, Backoff] = defaultdict(lambda: DefaultBackoff) backoff_config[AwsAccountDiscovered.kind] = Backoff( base_delay=5, @@ -105,17 +110,139 @@ def __init__( self.dispatching = dispatching self.fast_lane_timeout = timedelta(minutes=1) self.become_degraded_timeout = timedelta(minutes=15) + self.cf_listener = ( + SQSRawListener( + session=boto_session, + queue_url=cf_stack_queue_url, + message_processor=self.process_cf_stack_event, + consider_failed_after=timedelta(minutes=5), + max_nr_of_messages_in_one_batch=1, + wait_for_new_messages_to_arrive=timedelta(seconds=10), + ) + if cf_stack_queue_url + else None + ) + self.http_client = http_client async def start(self) -> Any: await self.domain_event_listener.start() if self.periodic: await self.periodic.start() + if self.cf_listener: + await self.cf_listener.start() async def stop(self) -> Any: + if self.cf_listener: + await self.cf_listener.stop() await self.domain_event_listener.stop() if self.periodic: await self.periodic.stop() + async def process_cf_stack_event(self, message: Json) -> Optional[CloudAccount]: + log.info(f"Received CF stack event: {message}") + + async def send_response( + msg: Json, physical_resource_id: Optional[str] = None, error_message: Optional[str] = None + ) -> None: + try: + physical_resource_id = physical_resource_id or msg["PhysicalResourceId"] + request_id = msg["RequestId"] + logical_resource_id = msg["LogicalResourceId"] + response_url = msg["ResponseURL"] + resource_properties = msg["ResourceProperties"] + role_name = AwsRoleName(resource_properties["RoleName"]) + stack_id = resource_properties["StackId"] + except Exception as e: + log.warning(f"Not enough data to inform CF: {msg}. Error: {e}") + return None + + # Signal CF that we're done + response = await self.http_client.put( + response_url, + json={ + "Status": "FAILURE" if error_message else "SUCCESS", + "Reason": error_message or "OK", + "LogicalResourceId": logical_resource_id, + "PhysicalResourceId": physical_resource_id, + "StackId": stack_id, + "RequestId": request_id, + "Data": {"RoleName": role_name}, + }, + ) + if response.is_error: + raise RuntimeError(f"Failed to signal CF that we're done: {response}") + + async def handle_stack_created(msg: Json) -> Optional[CloudAccount]: + try: + resource_properties = msg["ResourceProperties"] + workspace_id = WorkspaceId(uuid.UUID(resource_properties["WorkspaceId"])) + external_id = ExternalId(uuid.UUID(resource_properties["ExternalId"])) + role_name = AwsRoleName(resource_properties["RoleName"]) + stack_id = resource_properties["StackId"] + assert stack_id.startswith("arn:aws:cloudformation:") + assert stack_id.count(":") == 5 + account_id = CloudAccountId(stack_id.split(":")[4]) + except Exception as e: + log.warning(f"Received invalid CF stack create event: {msg}. Error: {e}") + await send_response(msg, str(uid()), "Invalid format for CF stack create/update event") + return None + # Create/Update the account on our side + set_workspace_id(str(workspace_id)) + set_cloud_account_id(account_id) + account = await self.create_aws_account( + workspace_id=workspace_id, + account_id=account_id, + role_name=role_name, + external_id=external_id, + account_name=None, + ) + # Signal to CF that we're done + await send_response(msg, str(account.id)) + return account + + async def handle_stack_deleted(msg: Json) -> Optional[CloudAccount]: + try: + resource_properties = msg["ResourceProperties"] + role_name = AwsRoleName(resource_properties["RoleName"]) + external_id = ExternalId(uuid.UUID(resource_properties["ExternalId"])) + cloud_account_id = FixCloudAccountId(uuid.UUID(msg["PhysicalResourceId"])) + except Exception as e: + log.warning(f"Received invalid CF stack delete event: {msg}. Error: {e}") + await send_response(msg, str(uid()), "Invalid format for CF stack delete event") + return None + if ( + (account := await self.cloud_account_repository.get(cloud_account_id)) + and isinstance(access := account.state.cloud_access(), AwsCloudAccess) + # also make sure the stack refers to the same role and external id + and access.role_name == role_name + and access.external_id == external_id + ): + account = await self.__degrade_account( + FixCloudAccountId(cloud_account_id), "CloudformationStack deleted" + ) + await send_response(msg, str(cloud_account_id)) + return account + + try: + body = json.loads(message["Body"]) + assert body["Type"] == "Notification" + content = json.loads(body["Message"]) + kind = content["RequestType"] + match kind: + case "Create": + return await handle_stack_created(content) + case "Delete": + return await handle_stack_deleted(content) + case "Update": + return await handle_stack_created(content) + case _: + log.info(f"Received a CF stack event that is currently not handled. Ignore. {kind}") + await send_response(message) # still try to acknowledge the message + return None + except Exception as e: + log.warning(f"Received invalid CF stack event: {message}. Error: {e}") + return None + async def process_domain_event(self, message: Json, context: MessageContext) -> None: match context.kind: case TenantAccountsCollected.kind: @@ -205,24 +332,7 @@ def fast_lane_should_end() -> bool: if should_move_to_degraded(): log.info("failed to assume role, but timeout is reached, moving account to degraded state") error = "Cannot assume role" - - def update_fn(cloud_account: CloudAccount) -> CloudAccount: - if isinstance(cloud_account.state, CloudAccountStates.Discovered): - return evolve( - cloud_account, state=CloudAccountStates.Degraded(access, error), state_updated_at=utc() - ) - else: - return cloud_account - - await self.cloud_account_repository.update(account.id, update_fn) - await self.domain_events.publish( - AwsAccountDegraded( - cloud_account_id=account.id, - tenant_id=account.workspace_id, - aws_account_id=account.account_id, - error=error, - ) - ) + await self.__degrade_account(account.id, error) return None elif fast_lane_should_end(): log.info("Can't assume role, leaving account in discovered state") @@ -453,3 +563,25 @@ def update_state(cloud_account: CloudAccount) -> CloudAccount: raise ValueError(f"Account {cloud_account_id} is not configured, cannot enable account") return await self.cloud_account_repository.update(cloud_account_id, update_state) + + async def __degrade_account( + self, + account_id: FixCloudAccountId, + error: str, + ) -> CloudAccount: + def set_degraded(cloud_account: CloudAccount) -> CloudAccount: + if access := cloud_account.state.cloud_access(): + return evolve(cloud_account, state=CloudAccountStates.Degraded(access, error), state_updated_at=utc()) + else: + return cloud_account + + account = await self.cloud_account_repository.update(account_id, set_degraded) + await self.domain_events.publish( + AwsAccountDegraded( + cloud_account_id=account.id, + tenant_id=account.workspace_id, + aws_account_id=account.account_id, + error=error, + ) + ) + return account diff --git a/fixbackend/config.py b/fixbackend/config.py index b61d7895..fa35b6f3 100644 --- a/fixbackend/config.py +++ b/fixbackend/config.py @@ -61,6 +61,7 @@ class Config(BaseSettings): customerio_site_id: Optional[str] customerio_api_key: Optional[str] cloud_account_service_event_parallelism: int + aws_cf_stack_notification_sqs_url: Optional[str] def frontend_cdn_origin(self) -> str: return f"{self.cdn_endpoint}/{self.cdn_bucket}/{self.fixui_sha}" @@ -123,6 +124,9 @@ def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace: parser.add_argument( "--aws-marketplace-metering-sqs-url", default=os.environ.get("AWS_MARKETPLACE_METERING_SQS_URL") ) + parser.add_argument( + "--aws-cf-stack-notification-sqs-url", default=os.environ.get("AWS_CF_STACK_NOTIFICATION_SQS_URL") + ) parser.add_argument("--ca-cert", type=Path, default=os.environ.get("CA_CERT")) parser.add_argument("--host-cert", type=Path, default=os.environ.get("HOST_CERT")) parser.add_argument("--host-key", type=Path, default=os.environ.get("HOST_KEY")) diff --git a/fixbackend/subscription/aws_marketplace.py b/fixbackend/subscription/aws_marketplace.py index a4d6c79d..7dc8c863 100644 --- a/fixbackend/subscription/aws_marketplace.py +++ b/fixbackend/subscription/aws_marketplace.py @@ -65,7 +65,7 @@ def __init__( self.handle_message, consider_failed_after=timedelta(minutes=5), max_nr_of_messages_in_one_batch=1, - wait_for_new_messages_to_arrive=timedelta(seconds=5), + wait_for_new_messages_to_arrive=timedelta(seconds=10), ) if sqs_queue_url is not None else None diff --git a/tests/fixbackend/cloud_accounts/service_test.py b/tests/fixbackend/cloud_accounts/service_test.py index 3bd15369..d6c0ee09 100644 --- a/tests/fixbackend/cloud_accounts/service_test.py +++ b/tests/fixbackend/cloud_accounts/service_test.py @@ -11,17 +11,18 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . - - +import json import uuid from datetime import datetime, timedelta from typing import Callable, Dict, List, Optional, Tuple +import boto3 import pytest from attrs import evolve from fixcloudutils.redis.event_stream import MessageContext, RedisStreamPublisher from fixcloudutils.redis.pub_sub import RedisPubSubPublisher from fixcloudutils.types import Json +from httpx import AsyncClient, Request, Response from redis.asyncio import Redis from fixbackend.cloud_accounts.account_setup import AssumeRoleResult, AssumeRoleResults, AwsAccountSetupHelper @@ -54,6 +55,8 @@ from fixbackend.workspaces.repository import WorkspaceRepositoryImpl from fixcloudutils.util import utc +from tests.fixbackend.conftest import RequestHandlerMock + class CloudAccountRepositoryMock(CloudAccountRepository): def __init__(self) -> None: @@ -112,6 +115,7 @@ async def list_all_discovered_accounts(self) -> List[CloudAccount]: class OrganizationServiceMock(WorkspaceRepositoryImpl): + # noinspection PyMissingConstructor def __init__(self) -> None: pass @@ -122,6 +126,7 @@ async def get_workspace(self, workspace_id: WorkspaceId, with_users: bool = Fals class RedisStreamPublisherMock(RedisStreamPublisher): + # noinspection PyMissingConstructor def __init__(self) -> None: self.last_message: Optional[Tuple[str, Json]] = None @@ -130,6 +135,7 @@ async def publish(self, kind: str, message: Json) -> None: class RedisPubSubPublisherMock(RedisPubSubPublisher): + # noinspection PyMissingConstructor def __init__(self) -> None: self.last_message: Optional[Tuple[str, Json, Optional[str]]] = None @@ -146,6 +152,7 @@ async def publish(self, event: Event) -> None: class AwsAccountSetupHelperMock(AwsAccountSetupHelper): + # noinspection PyMissingConstructor def __init__(self) -> None: self.can_assume = True self.org_accounts: Dict[CloudAccountId, CloudAccountName] = {} @@ -170,17 +177,44 @@ async def list_account_aliases(self, assume_role_result: AssumeRoleResults.Succe now = datetime.utcnow() -@pytest.mark.asyncio -async def test_create_aws_account( +@pytest.fixture +def repository() -> CloudAccountRepositoryMock: + return CloudAccountRepositoryMock() + + +@pytest.fixture +def organization_repository() -> OrganizationServiceMock: + return OrganizationServiceMock() + + +@pytest.fixture +def pubsub_publisher() -> RedisPubSubPublisherMock: + return RedisPubSubPublisherMock() + + +@pytest.fixture +def domain_sender() -> DomainEventSenderMock: + return DomainEventSenderMock() + + +@pytest.fixture +def account_setup_helper() -> AwsAccountSetupHelperMock: + return AwsAccountSetupHelperMock() + + +@pytest.fixture +def service( + organization_repository: OrganizationServiceMock, + repository: CloudAccountRepositoryMock, + pubsub_publisher: RedisPubSubPublisherMock, + domain_sender: DomainEventSenderMock, + account_setup_helper: AwsAccountSetupHelperMock, arq_redis: Redis, default_config: Config, -) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( + boto_session: boto3.Session, + http_client: AsyncClient, +) -> CloudAccountServiceImpl: + return CloudAccountServiceImpl( organization_repository, repository, pubsub_publisher, @@ -189,8 +223,19 @@ async def test_create_aws_account( default_config, account_setup_helper, dispatching=False, + boto_session=boto_session, + http_client=http_client, + cf_stack_queue_url=None, ) + +@pytest.mark.asyncio +async def test_create_aws_account( + repository: CloudAccountRepositoryMock, + pubsub_publisher: RedisPubSubPublisherMock, + domain_sender: DomainEventSenderMock, + service: CloudAccountServiceImpl, +) -> None: # happy case acc = await service.create_aws_account( workspace_id=test_workspace_id, @@ -297,23 +342,10 @@ async def test_create_aws_account( @pytest.mark.asyncio -async def test_delete_aws_account(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_delete_aws_account( + repository: CloudAccountRepositoryMock, + service: CloudAccountServiceImpl, +) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -334,23 +366,7 @@ async def test_delete_aws_account(arq_redis: Redis, default_config: Config) -> N @pytest.mark.asyncio -async def test_store_last_run_info(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_store_last_run_info(service: CloudAccountServiceImpl) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -379,23 +395,7 @@ async def test_store_last_run_info(arq_redis: Redis, default_config: Config) -> @pytest.mark.asyncio -async def test_get_cloud_account(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_get_cloud_account(repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -428,24 +428,7 @@ async def test_get_cloud_account(arq_redis: Redis, default_config: Config) -> No @pytest.mark.asyncio -async def test_list_cloud_accounts(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_list_cloud_accounts(repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -469,24 +452,9 @@ async def test_list_cloud_accounts(arq_redis: Redis, default_config: Config) -> @pytest.mark.asyncio -async def test_update_cloud_account_name(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_update_cloud_account_name( + repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl +) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -523,24 +491,9 @@ async def test_update_cloud_account_name(arq_redis: Redis, default_config: Confi @pytest.mark.asyncio -async def test_handle_account_discovered_success(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_handle_account_discovered_success( + repository: CloudAccountRepositoryMock, domain_sender: DomainEventSenderMock, service: CloudAccountServiceImpl +) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -568,24 +521,12 @@ async def test_handle_account_discovered_success(arq_redis: Redis, default_confi @pytest.mark.asyncio -async def test_handle_account_discovered_assume_role_failure(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_handle_account_discovered_assume_role_failure( + repository: CloudAccountRepositoryMock, + domain_sender: DomainEventSenderMock, + service: CloudAccountServiceImpl, + account_setup_helper: AwsAccountSetupHelperMock, +) -> None: # boto3 cannot assume right away account_id = CloudAccountId("foobar") role_name = AwsRoleName("FooBarRole") @@ -633,24 +574,12 @@ async def test_handle_account_discovered_assume_role_failure(arq_redis: Redis, d @pytest.mark.asyncio -async def test_handle_account_discovered_list_accounts_success(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_handle_account_discovered_list_accounts_success( + repository: CloudAccountRepositoryMock, + domain_sender: DomainEventSenderMock, + service: CloudAccountServiceImpl, + account_setup_helper: AwsAccountSetupHelperMock, +) -> None: account_id = CloudAccountId("foobar") role_name = AwsRoleName("FooBarRole") account = await service.create_aws_account( @@ -687,24 +616,12 @@ async def test_handle_account_discovered_list_accounts_success(arq_redis: Redis, @pytest.mark.asyncio -async def test_handle_account_discovered_list_aliases_success(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_handle_account_discovered_list_aliases_success( + repository: CloudAccountRepositoryMock, + domain_sender: DomainEventSenderMock, + service: CloudAccountServiceImpl, + account_setup_helper: AwsAccountSetupHelperMock, +) -> None: # boto3 cannot assume right away account_id = CloudAccountId("foobar") role_name = AwsRoleName("FooBarRole") @@ -743,24 +660,9 @@ async def test_handle_account_discovered_list_aliases_success(arq_redis: Redis, @pytest.mark.asyncio -async def test_enable_disable_cloud_account(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) - +async def test_enable_disable_cloud_account( + repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl +) -> None: account = await service.create_aws_account( workspace_id=test_workspace_id, account_id=account_id, @@ -808,7 +710,7 @@ async def test_enable_disable_cloud_account(arq_redis: Redis, default_config: Co ) with pytest.raises(Exception): - updated = await service.enable_cloud_account( + await service.enable_cloud_account( test_workspace_id, account.id, ) @@ -823,23 +725,12 @@ async def test_enable_disable_cloud_account(arq_redis: Redis, default_config: Co @pytest.mark.asyncio -async def test_configure_account(arq_redis: Redis, default_config: Config) -> None: - repository = CloudAccountRepositoryMock() - organization_repository = OrganizationServiceMock() - pubsub_publisher = RedisPubSubPublisherMock() - domain_sender = DomainEventSenderMock() - - account_setup_helper = AwsAccountSetupHelperMock() - service = CloudAccountServiceImpl( - organization_repository, - repository, - pubsub_publisher, - domain_sender, - arq_redis, - default_config, - account_setup_helper, - dispatching=False, - ) +async def test_configure_account( + repository: CloudAccountRepositoryMock, + domain_sender: DomainEventSenderMock, + service: CloudAccountServiceImpl, + account_setup_helper: AwsAccountSetupHelperMock, +) -> None: account_setup_helper.can_assume = False def get_account(state_updated_at: datetime) -> CloudAccount: @@ -899,3 +790,60 @@ def get_account(state_updated_at: datetime) -> CloudAccount: assert event.cloud_account_id == account.id assert event.aws_account_id == account_id assert event.tenant_id == account.workspace_id + + +@pytest.mark.asyncio +async def test_handle_cf_sqs_message( + repository: CloudAccountRepositoryMock, service: CloudAccountServiceImpl, request_handler_mock: RequestHandlerMock +) -> None: + async def handle_request(_: Request) -> Response: + return Response(200, content=b"ok") + + def notification(kind: str, physical_resource_id: Optional[str] = None) -> Json: + base = { + "RequestType": kind, + "ServiceToken": "arn:aws:sns:us-east-1:12345:SomeCallbacks", + "ResponseURL": "https://cloudformation-custom.test.com/", + "StackId": "arn:aws:cloudformation:us-east-1:12345:stack/name/some-id", + "RequestId": "855e25d5-3b80-4aed-b9f4-af8682deaf79", + "LogicalResourceId": "FixAccessFunction", + "ResourceType": "Custom::Function", + "ResourceProperties": { + "ServiceToken": "arn:aws:sns:us-east-1:12345:SomeCallbacks", + "RoleName": role_name, + "ExternalId": str(external_id), + "WorkspaceId": str(test_workspace_id), + "StackId": "arn:aws:cloudformation:us-east-1:12345:stack/name/some-id", + }, + } + if physical_resource_id: + base["PhysicalResourceId"] = physical_resource_id + cf_message = { + "Type": "Notification", + "MessageId": "38ccbec9-9999-5871-9383-e31e58450b68", + "TopicArn": "arn:aws:sns:us-east-1:12345:SomeCallbacks", + "Subject": "AWS CloudFormation custom resource request", + "Message": json.dumps(base), + "Timestamp": "2023-11-22T08:45:16.159Z", + "SignatureVersion": "1", + "Signature": "sig", + "SigningCertURL": "https://cert/pem", + "UnsubscribeURL": "https://unsubscribe", + } + return {"Body": json.dumps(cf_message)} + + # Handle Create Message + request_handler_mock.append(handle_request) + assert len(repository.accounts) == 0 + account = await service.process_cf_stack_event(notification("Create")) + assert account is not None + assert len(repository.accounts) == 1 + assert repository.accounts[account.id] == account + + # Handle Delete Message + repository.accounts[account.id] = evolve( + account, state=CloudAccountStates.Configured(AwsCloudAccess(external_id, role_name), enabled=True) + ) + account = await service.process_cf_stack_event(notification("Delete", str(account.id))) + assert account is not None + assert isinstance(account.state, CloudAccountStates.Degraded) diff --git a/tests/fixbackend/conftest.py b/tests/fixbackend/conftest.py index 5fb0b2ac..3e51f820 100644 --- a/tests/fixbackend/conftest.py +++ b/tests/fixbackend/conftest.py @@ -64,7 +64,7 @@ DATABASE_URL = "mysql+aiomysql://root@127.0.0.1:3306/fixbackend-testdb" # only used to create/drop the database SYNC_DATABASE_URL = "mysql+pymysql://root@127.0.0.1:3306/fixbackend-testdb" -InventoryMock = List[Callable[[Request], Awaitable[Response]]] +RequestHandlerMock = List[Callable[[Request], Awaitable[Response]]] os.environ["LOCAL_DEV_ENV"] = "true" @@ -117,6 +117,7 @@ def default_config() -> Config: customerio_site_id=None, customerio_api_key=None, cloud_account_service_event_parallelism=1000, + aws_cf_stack_notification_sqs_url=None, ) @@ -341,22 +342,28 @@ def nd_json_response(content: Sequence[JsonElement]) -> Response: @pytest.fixture -async def inventory_mock() -> InventoryMock: +async def request_handler_mock() -> RequestHandlerMock: return [] @pytest.fixture -async def inventory_client(inventory_mock: InventoryMock) -> AsyncIterator[InventoryClient]: +async def http_client(request_handler_mock: RequestHandlerMock) -> AsyncClient: async def app(request: Request) -> Response: - for mock in inventory_mock: + for mock in request_handler_mock: try: return await mock(request) except AttributeError: pass raise AttributeError(f'Unexpected request: {request.url.path} with content {request.content.decode("utf-8")}') - async_client = AsyncClient(transport=MockTransport(app)) - async with InventoryClient("http://localhost:8980", client=async_client) as client: + return AsyncClient(transport=MockTransport(app)) + + +@pytest.fixture +async def inventory_client( + http_client: AsyncClient, request_handler_mock: RequestHandlerMock +) -> AsyncIterator[InventoryClient]: + async with InventoryClient("http://localhost:8980", client=http_client) as client: yield client diff --git a/tests/fixbackend/inventory/inventory_client_test.py b/tests/fixbackend/inventory/inventory_client_test.py index 9d7b4ce2..0cf5236a 100644 --- a/tests/fixbackend/inventory/inventory_client_test.py +++ b/tests/fixbackend/inventory/inventory_client_test.py @@ -30,7 +30,7 @@ from fixbackend.ids import WorkspaceId, CloudAccountId, NodeId from fixbackend.inventory.inventory_client import InventoryClient from fixbackend.inventory.schemas import CompletePathRequest -from tests.fixbackend.conftest import InventoryMock, nd_json_response, json_response +from tests.fixbackend.conftest import RequestHandlerMock, nd_json_response, json_response db_access = GraphDatabaseAccess(WorkspaceId(uuid.uuid1()), "server", "database", "username", "password") @@ -38,7 +38,7 @@ @pytest.fixture def mocked_inventory_client( inventory_client: InventoryClient, - inventory_mock: InventoryMock, + request_handler_mock: RequestHandlerMock, azure_virtual_machine_resource_json: Json, aws_ec2_model_json: Json, ) -> InventoryClient: @@ -82,7 +82,7 @@ async def mock(request: Request) -> Response: else: raise AttributeError(f"Unexpected request: {request.method} {request.url.path} with content {content}") - inventory_mock.append(mock) + request_handler_mock.append(mock) return inventory_client diff --git a/tests/fixbackend/inventory/inventory_service_test.py b/tests/fixbackend/inventory/inventory_service_test.py index ff61b78d..6b6c652b 100644 --- a/tests/fixbackend/inventory/inventory_service_test.py +++ b/tests/fixbackend/inventory/inventory_service_test.py @@ -44,7 +44,7 @@ HistorySearch, HistoryChange, ) -from tests.fixbackend.conftest import InventoryMock, json_response, nd_json_response +from tests.fixbackend.conftest import RequestHandlerMock, json_response, nd_json_response from fixbackend.domain_events.subscriber import DomainEventSubscriber db = GraphDatabaseAccess(WorkspaceId(uuid.uuid1()), "server", "database", "username", "password") @@ -60,10 +60,10 @@ @pytest.fixture def mocked_answers( - inventory_mock: InventoryMock, + request_handler_mock: RequestHandlerMock, benchmark_json: List[Json], azure_virtual_machine_resource_json: Json, -) -> InventoryMock: +) -> RequestHandlerMock: async def mock(request: Request) -> Response: content = request.content.decode("utf-8") if request.url.path == "/cli/execute" and content.endswith("jq --no-rewrite .group"): @@ -109,18 +109,18 @@ async def mock(request: Request) -> Response: else: raise AttributeError(f"Unexpected request: {request.url.path} with content {content}") - inventory_mock.append(mock) - return inventory_mock + request_handler_mock.append(mock) + return request_handler_mock async def test_benchmark_command( - inventory_service: InventoryService, benchmark_json: List[Json], mocked_answers: InventoryMock + inventory_service: InventoryService, benchmark_json: List[Json], mocked_answers: RequestHandlerMock ) -> None: response = [a async for a in await inventory_service.benchmark(db, "benchmark_name")] assert response == benchmark_json -async def test_summary(inventory_service: InventoryService, mocked_answers: InventoryMock) -> None: +async def test_summary(inventory_service: InventoryService, mocked_answers: RequestHandlerMock) -> None: summary = await inventory_service.summary(db) assert len(summary.benchmarks) == 2 assert summary.overall_score == 42 @@ -188,7 +188,7 @@ async def test_dict_values_by() -> None: assert [a for a in dict_values_by(inv, lambda x: -x)] == [1, 2, 3, 11, 12, 13, 21, 22, 23] -async def test_search_list(inventory_service: InventoryService, mocked_answers: InventoryMock) -> None: +async def test_search_list(inventory_service: InventoryService, mocked_answers: RequestHandlerMock) -> None: expected = [ { "columns": [ @@ -208,7 +208,7 @@ async def test_search_list(inventory_service: InventoryService, mocked_answers: assert result == expected -async def test_search_start_data(inventory_service: InventoryService, mocked_answers: InventoryMock) -> None: +async def test_search_start_data(inventory_service: InventoryService, mocked_answers: RequestHandlerMock) -> None: result = [ SearchCloudResource(id="234", name="bla", cloud="gcp"), SearchCloudResource(id="123", name="foo", cloud="aws"), @@ -221,7 +221,7 @@ async def test_search_start_data(inventory_service: InventoryService, mocked_ans async def test_resource( - inventory_service: InventoryService, mocked_answers: InventoryMock, azure_virtual_machine_resource_json: Json + inventory_service: InventoryService, mocked_answers: RequestHandlerMock, azure_virtual_machine_resource_json: Json ) -> None: res = await inventory_service.resource(db, NodeId("some_node_id")) assert res["neighborhood"] == neighborhood