diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 75ed648..ad22b81 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -101,6 +101,7 @@ def get_posts( offset: int = Query(default=0, ge=0), limit: int = Query(default=100, gt=0, le=1000), search_text: Union[None, str] = Query(default=None), + search_url: Union[None, HttpUrl] = Query(default=None), ) -> PostListResponse: if created_at_from is not None and isinstance(created_at_from, str): created_at_from = ensure_twitter_timestamp(created_at_from) @@ -113,6 +114,7 @@ def get_posts( start=created_at_from, end=created_at_to, search_text=search_text, + search_url=search_url, offset=offset, limit=limit, ) @@ -123,6 +125,7 @@ def get_posts( start=created_at_from, end=created_at_to, search_text=search_text, + search_url=search_url, ) for post in posts: diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 3b8e677..208acdd 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -9,12 +9,14 @@ from polyfactory import Use from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.pytest_plugin import register_fixture +from pydantic import HttpUrl from pytest import fixture from birdxplorer_common.exceptions import UserEnrollmentNotFoundError from birdxplorer_common.models import ( LanguageIdentifier, Link, + LinkId, Note, NoteId, ParticipantId, @@ -263,6 +265,7 @@ def mock_storage( topic_samples: List[Topic], post_samples: List[Post], note_samples: List[Note], + link_samples: List[Link], ) -> Generator[MagicMock, None, None]: mock = MagicMock(spec=Storage) @@ -311,11 +314,17 @@ def _get_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, offset: Union[int, None] = None, limit: Union[int, None] = None, ) -> Generator[Post, None, None]: gen_count = 0 actual_gen_count = 0 + url_id: LinkId | None = None + if search_url is not None: + url_candidates = [link.link_id for link in link_samples if link.url == search_url] + if len(url_candidates) > 0: + url_id = url_candidates[0] for idx, post in enumerate(post_samples): if limit is not None and actual_gen_count >= limit: break @@ -331,6 +340,8 @@ def _get_posts( continue if search_text is not None and search_text not in post.text: continue + if search_url is not None and url_id not in [link.link_id for link in post.links]: + continue gen_count += 1 if offset is not None and gen_count <= offset: continue @@ -345,8 +356,9 @@ def _get_number_of_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[str, None] = None, ) -> int: - return len(list(_get_posts(post_ids, note_ids, start, end, search_text))) + return len(list(_get_posts(post_ids, note_ids, start, end, search_text, search_url))) mock.get_number_of_posts.side_effect = _get_number_of_posts