generated from ghga-de/microservice-repository-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use list with mock event for test_repository
- Loading branch information
1 parent
6ba34d3
commit 3241996
Showing
2 changed files
with
31 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,11 +16,10 @@ | |
|
||
"""Test the access request repository""" | ||
|
||
from collections import defaultdict | ||
from collections.abc import AsyncIterator, Mapping | ||
from datetime import timedelta | ||
from operator import attrgetter | ||
from typing import Any, Optional | ||
from typing import Any, NamedTuple, Optional | ||
|
||
import pytest | ||
from ghga_service_commons.auth.ghga import AcademicTitle, AuthContext, UserStatus | ||
|
@@ -36,7 +35,6 @@ | |
from ars.ports.outbound.access_grants import AccessGrantsPort | ||
from ars.ports.outbound.dao import AccessRequestDaoPort, ResourceNotFoundError | ||
from ars.ports.outbound.event_pub import EventPublisherPort | ||
from tests.fixtures import AccessRequestDetails | ||
|
||
pytestmark = pytest.mark.asyncio(scope="session") | ||
|
||
|
@@ -191,40 +189,43 @@ async def update(self, dto: AccessRequest) -> None: | |
self.last_upsert = dto | ||
|
||
|
||
class MockAccessRequestEvent(NamedTuple): | ||
"""Mock of AccessRequestDetails plus status field to represent event type""" | ||
|
||
user_id: str | ||
dataset_id: str | ||
status: str | ||
|
||
|
||
class EventPublisherDummy(EventPublisherPort): | ||
"""Dummy event publisher for testing.""" | ||
|
||
events: dict[AccessRequestDetails, list[str]] | ||
events: list[MockAccessRequestEvent] | ||
|
||
def reset(self) -> None: | ||
"""Reset the recorded events.""" | ||
self.events = defaultdict(list) | ||
self.events = [] | ||
|
||
@property | ||
def num_events(self): | ||
"""Get total number of recorded events.""" | ||
return len(self.events) | ||
|
||
def status_for(self, request: AccessRequest) -> list[str]: | ||
"""Get the statuses used in the events published for a given request.""" | ||
details = AccessRequestDetails( | ||
user_id=request.user_id, dataset_id=request.dataset_id | ||
) | ||
try: | ||
status = self.events[details] | ||
except KeyError as err: | ||
raise RuntimeError( | ||
f"No events recorded for request with user id '{details.user_id}'" | ||
+ f" and dataset id '{details.dataset_id}'" | ||
) from err | ||
return status | ||
def events_for(self, request: AccessRequest) -> list[MockAccessRequestEvent]: | ||
"""Get the events published for a given request.""" | ||
return [ | ||
event | ||
for event in self.events | ||
if event.user_id == request.user_id | ||
and event.dataset_id == request.dataset_id | ||
] | ||
|
||
def _record_request(self, *, request: AccessRequest, request_state: str): | ||
"""Record a request as either created, allowed, or denied for a user and dataset.""" | ||
details = AccessRequestDetails( | ||
user_id=request.user_id, dataset_id=request.dataset_id | ||
mock_event = MockAccessRequestEvent( | ||
request.user_id, request.dataset_id, request_state | ||
) | ||
self.events[details].append(request_state) | ||
self.events.append(mock_event) | ||
|
||
async def publish_request_allowed(self, *, request: AccessRequest) -> None: | ||
"""Mark an access request as allowed via event publish.""" | ||
|
@@ -318,12 +319,12 @@ async def test_can_create_request(): | |
|
||
assert dao.last_upsert == request | ||
|
||
# there will be exactly 1 'event' published (a call to the dummy publisher) | ||
assert event_publisher.num_events == 1 | ||
# the 'publish_request_created' method should have been called, get events for request | ||
events = event_publisher.events_for(request=request) | ||
|
||
# the 'publish_request_created' method should have been called | ||
request_states = event_publisher.status_for(request=request) | ||
assert request_states[0] == "created" | ||
# there will be exactly 1 'event' published (a call to the dummy publisher) | ||
assert len(events) == 1 | ||
assert events[0].status == "created" | ||
|
||
assert access_grants.last_grant == "nothing granted so far" | ||
|
||
|
@@ -368,7 +369,7 @@ async def test_silently_correct_request_that_is_too_early(): | |
assert dao.last_upsert == request | ||
|
||
# There should be one event published which communicates the state of the request | ||
assert event_publisher.num_events == 1 | ||
assert len(event_publisher.events) == 1 | ||
|
||
|
||
async def test_cannot_create_request_too_much_in_advance(): | ||
|
@@ -553,9 +554,9 @@ async def test_set_status_to_allowed(): | |
assert changed_dict.pop("changed_by") == "[email protected]" | ||
assert changed_dict == original_dict | ||
|
||
assert event_publisher.num_events == 1 | ||
request_states = event_publisher.status_for(request=changed_request) | ||
assert request_states[0] == "allowed" | ||
events = event_publisher.events_for(request=changed_request) | ||
assert len(events) == 1 | ||
assert events[0].status == "allowed" | ||
|
||
assert ( | ||
access_grants.last_grant == "to [email protected] for new-dataset" | ||
|