Skip to content

Commit

Permalink
Use list with mock event for test_repository
Browse files Browse the repository at this point in the history
  • Loading branch information
TheByronHimes committed Mar 20, 2024
1 parent 6ba34d3 commit 3241996
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 37 deletions.
7 changes: 0 additions & 7 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,6 @@ def fixture_auth_headers_steward() -> dict[str, str]:
return headers_for_token(token)


class AccessRequestDetails(NamedTuple):
"""Hashable version of the AccessRequestDetails event schema"""

user_id: str
dataset_id: str


class JointFixture(NamedTuple):
"""Joint fixture object."""

Expand Down
61 changes: 31 additions & 30 deletions tests/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@

"""Test the access request repository"""

from collections import defaultdict
from collections.abc import AsyncIterator, Mapping
from datetime import timedelta
from operator import attrgetter
from typing import Any, Optional
from typing import Any, NamedTuple, Optional

import pytest
from ghga_service_commons.auth.ghga import AcademicTitle, AuthContext, UserStatus
Expand All @@ -36,7 +35,6 @@
from ars.ports.outbound.access_grants import AccessGrantsPort
from ars.ports.outbound.dao import AccessRequestDaoPort, ResourceNotFoundError
from ars.ports.outbound.event_pub import EventPublisherPort
from tests.fixtures import AccessRequestDetails

pytestmark = pytest.mark.asyncio(scope="session")

Expand Down Expand Up @@ -191,40 +189,43 @@ async def update(self, dto: AccessRequest) -> None:
self.last_upsert = dto


class MockAccessRequestEvent(NamedTuple):
"""Mock of AccessRequestDetails plus status field to represent event type"""

user_id: str
dataset_id: str
status: str


class EventPublisherDummy(EventPublisherPort):
"""Dummy event publisher for testing."""

events: dict[AccessRequestDetails, list[str]]
events: list[MockAccessRequestEvent]

def reset(self) -> None:
"""Reset the recorded events."""
self.events = defaultdict(list)
self.events = []

@property
def num_events(self):
"""Get total number of recorded events."""
return len(self.events)

def status_for(self, request: AccessRequest) -> list[str]:
"""Get the statuses used in the events published for a given request."""
details = AccessRequestDetails(
user_id=request.user_id, dataset_id=request.dataset_id
)
try:
status = self.events[details]
except KeyError as err:
raise RuntimeError(
f"No events recorded for request with user id '{details.user_id}'"
+ f" and dataset id '{details.dataset_id}'"
) from err
return status
def events_for(self, request: AccessRequest) -> list[MockAccessRequestEvent]:
"""Get the events published for a given request."""
return [
event
for event in self.events
if event.user_id == request.user_id
and event.dataset_id == request.dataset_id
]

def _record_request(self, *, request: AccessRequest, request_state: str):
"""Record a request as either created, allowed, or denied for a user and dataset."""
details = AccessRequestDetails(
user_id=request.user_id, dataset_id=request.dataset_id
mock_event = MockAccessRequestEvent(
request.user_id, request.dataset_id, request_state
)
self.events[details].append(request_state)
self.events.append(mock_event)

async def publish_request_allowed(self, *, request: AccessRequest) -> None:
"""Mark an access request as allowed via event publish."""
Expand Down Expand Up @@ -318,12 +319,12 @@ async def test_can_create_request():

assert dao.last_upsert == request

# there will be exactly 1 'event' published (a call to the dummy publisher)
assert event_publisher.num_events == 1
# the 'publish_request_created' method should have been called, get events for request
events = event_publisher.events_for(request=request)

# the 'publish_request_created' method should have been called
request_states = event_publisher.status_for(request=request)
assert request_states[0] == "created"
# there will be exactly 1 'event' published (a call to the dummy publisher)
assert len(events) == 1
assert events[0].status == "created"

assert access_grants.last_grant == "nothing granted so far"

Expand Down Expand Up @@ -368,7 +369,7 @@ async def test_silently_correct_request_that_is_too_early():
assert dao.last_upsert == request

# There should be one event published which communicates the state of the request
assert event_publisher.num_events == 1
assert len(event_publisher.events) == 1


async def test_cannot_create_request_too_much_in_advance():
Expand Down Expand Up @@ -553,9 +554,9 @@ async def test_set_status_to_allowed():
assert changed_dict.pop("changed_by") == "[email protected]"
assert changed_dict == original_dict

assert event_publisher.num_events == 1
request_states = event_publisher.status_for(request=changed_request)
assert request_states[0] == "allowed"
events = event_publisher.events_for(request=changed_request)
assert len(events) == 1
assert events[0].status == "allowed"

assert (
access_grants.last_grant == "to [email protected] for new-dataset"
Expand Down

0 comments on commit 3241996

Please sign in to comment.