diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 2acab19..9e23943 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -295,6 +295,7 @@ def get_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, offset: Union[int, None] = None, limit: int = 100, ) -> Generator[PostModel, None, None]: @@ -312,6 +313,12 @@ def get_posts( query = query.filter(PostRecord.created_at < end) if search_text is not None: query = query.filter(PostRecord.text.like(f"%{search_text}%")) + if search_url is not None: + query = ( + query.join(PostLinkAssociation, PostLinkAssociation.post_id == PostRecord.post_id) + .join(LinkRecord, LinkRecord.link_id == PostLinkAssociation.link_id) + .filter(LinkRecord.url == search_url) + ) if offset is not None: query = query.offset(offset) query = query.limit(limit) @@ -325,6 +332,7 @@ def get_number_of_posts( start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, + search_url: Union[HttpUrl, None] = None, ) -> int: with Session(self.engine) as sess: query = sess.query(PostRecord) @@ -340,6 +348,12 @@ def get_number_of_posts( query = query.filter(PostRecord.created_at < end) if search_text is not None: query = query.filter(PostRecord.text.like(f"%{search_text}%")) + if search_url is not None: + query = ( + query.join(PostLinkAssociation, PostLinkAssociation.post_id == PostRecord.post_id) + .join(LinkRecord, LinkRecord.link_id == PostLinkAssociation.link_id) + .filter(LinkRecord.url == search_url) + ) return query.count() diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 64a2141..c678a52 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -1,6 +1,7 @@ from typing import Any, Dict, List import pytest +from pydantic import HttpUrl from sqlalchemy.engine import Engine from birdxplorer_common.models import ( @@ -41,6 +42,7 @@ def test_get_topic_list( [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], + [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], ], @@ -70,6 +72,7 @@ def test_get_post( [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], + [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], ], )