diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 36be229..09c5f2e 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -103,6 +103,7 @@ def get_posts( offset: int = Query(default=0, ge=0), limit: int = Query(default=100, gt=0, le=1000), search_text: Union[None, str] = Query(default=None), + search_url: Union[None, HttpUrl] = Query(default=None), ) -> PostListResponse: if created_at_from is not None and isinstance(created_at_from, str): created_at_from = ensure_twitter_timestamp(created_at_from) @@ -115,6 +116,7 @@ def get_posts( start=created_at_from, end=created_at_to, search_text=search_text, + search_url=search_url, offset=offset, limit=limit, ) @@ -125,6 +127,7 @@ def get_posts( start=created_at_from, end=created_at_to, search_text=search_text, + search_url=search_url, ) for post in posts: diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 8d1e50a..141c9f2 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -9,12 +9,14 @@ from polyfactory import Use from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.pytest_plugin import register_fixture +from pydantic import HttpUrl from pytest import fixture from birdxplorer_common.exceptions import UserEnrollmentNotFoundError from birdxplorer_common.models import ( LanguageIdentifier, Link, + LinkId, Note, NoteId, ParticipantId, @@ -268,6 +270,7 @@ def mock_storage( topic_samples: List[Topic], post_samples: List[Post], note_samples: List[Note], + link_samples: List[Link], ) -> Generator[MagicMock, None, None]: mock = MagicMock(spec=Storage) @@ -319,11 +322,17 @@ def _get_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, offset: Union[int, None] = None, limit: Union[int, None] = None, ) -> Generator[Post, None, None]: gen_count = 0 actual_gen_count = 0 + url_id: LinkId | None = None + if search_url is not None: + url_candidates = [link.link_id for link in link_samples if link.url == search_url] + if len(url_candidates) > 0: + url_id = url_candidates[0] for idx, post in enumerate(post_samples): if limit is not None and actual_gen_count >= limit: break @@ -339,6 +348,8 @@ def _get_posts( continue if search_text is not None and search_text not in post.text: continue + if search_url is not None and url_id not in [link.link_id for link in post.links]: + continue gen_count += 1 if offset is not None and gen_count <= offset: continue @@ -353,8 +364,9 @@ def _get_number_of_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, ) -> int: - return len(list(_get_posts(post_ids, note_ids, start, end, search_text))) + return len(list(_get_posts(post_ids, note_ids, start, end, search_text, search_url))) mock.get_number_of_posts.side_effect = _get_number_of_posts diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 6ba8f7a..e35f6e4 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -132,6 +132,16 @@ def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> N } +def test_posts_search_by_url(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?searchUrl=https%3A%2F%2Fexample.com%2Fsh3") + assert response.status_code == 200 + res_json = response.json() + assert res_json == { + "data": [json.loads(post_samples[i].model_dump_json()) for i in (2, 3)], + "meta": {"next": None, "prev": None}, + } + + def test_notes_get(client: TestClient, note_samples: List[Note]) -> None: response = client.get("/api/v1/data/notes") assert response.status_code == 200 diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 3561358..2b7868b 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -337,6 +337,7 @@ def get_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, offset: Union[int, None] = None, limit: int = 100, ) -> Generator[PostModel, None, None]: @@ -354,6 +355,12 @@ def get_posts( query = query.filter(PostRecord.created_at < end) if search_text is not None: query = query.filter(PostRecord.text.like(f"%{search_text}%")) + if search_url is not None: + query = ( + query.join(PostLinkAssociation, PostLinkAssociation.post_id == PostRecord.post_id) + .join(LinkRecord, LinkRecord.link_id == PostLinkAssociation.link_id) + .filter(LinkRecord.url == search_url) + ) if offset is not None: query = query.offset(offset) query = query.limit(limit) @@ -367,6 +374,7 @@ def get_number_of_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, ) -> int: with Session(self.engine) as sess: query = sess.query(PostRecord) @@ -382,6 +390,12 @@ def get_number_of_posts( query = query.filter(PostRecord.created_at < end) if search_text is not None: query = query.filter(PostRecord.text.like(f"%{search_text}%")) + if search_url is not None: + query = ( + query.join(PostLinkAssociation, PostLinkAssociation.post_id == PostRecord.post_id) + .join(LinkRecord, LinkRecord.link_id == PostLinkAssociation.link_id) + .filter(LinkRecord.url == search_url) + ) return query.count() diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 46fb93e..079056a 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List import pytest +from pydantic import HttpUrl from sqlalchemy.engine import Engine from birdxplorer_common.models import ( @@ -41,6 +42,7 @@ def test_get_topic_list( [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], + [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], ], @@ -70,6 +72,7 @@ def test_get_post( [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], + [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], ], )