diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 7fcd49a..75ed648 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -102,32 +102,30 @@ def get_posts( limit: int = Query(default=100, gt=0, le=1000), search_text: Union[None, str] = Query(default=None), ) -> PostListResponse: - posts = None - - if post_id is not None: - posts = list(storage.get_posts_by_ids(post_ids=post_id)) - elif note_id is not None: - posts = list(storage.get_posts_by_note_ids(note_ids=note_id)) - elif created_at_from is not None: - if created_at_to is not None: - posts = list( - storage.get_posts_by_created_at_range( - start=ensure_twitter_timestamp(created_at_from), - end=ensure_twitter_timestamp(created_at_to), - ) - ) - else: - posts = list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_from))) - elif created_at_to is not None: - posts = list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_to))) - elif search_text is not None and len(search_text) > 0: - posts = list(storage.search_posts_by_text(search_text)) - else: - posts = list(storage.get_posts()) - - total_count = len(posts) - paginated_posts = posts[offset : offset + limit] - for post in paginated_posts: + if created_at_from is not None and isinstance(created_at_from, str): + created_at_from = ensure_twitter_timestamp(created_at_from) + if created_at_to is not None and isinstance(created_at_to, str): + created_at_to = ensure_twitter_timestamp(created_at_to) + posts = list( + storage.get_posts( + post_ids=post_id, + note_ids=note_id, + start=created_at_from, + end=created_at_to, + search_text=search_text, + offset=offset, + limit=limit, + ) + ) + total_count = storage.get_number_of_posts( + post_ids=post_id, + note_ids=note_id, + start=created_at_from, + end=created_at_to, + search_text=search_text, + ) + + for post in posts: post.link = 
# NOTE: `post_samples`, `note_samples` and `mock` are not defined in this
# block; they are resolved from the enclosing scope.
def _get_posts(
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
    offset: Union[int, None] = None,
    limit: Union[int, None] = None,
) -> Generator[Post, None, None]:
    """In-memory stand-in for ``Storage.get_posts`` backed by ``post_samples``.

    Every filter that is not ``None`` must match (logical AND):

    * ``post_ids``    -- the post's ``post_id`` is in the list.
    * ``note_ids``    -- at least one entry of ``note_samples`` has a listed
      ``note_id`` and points at the post.
    * ``start``       -- ``post.created_at >= start`` (inclusive).
    * ``end``         -- ``post.created_at < end`` (exclusive).
    * ``search_text`` -- plain substring test against ``post.text``.

    Pagination is applied after filtering: the first ``offset`` matches are
    skipped and at most ``limit`` posts are yielded. ``None`` disables the
    respective bound.
    """
    # Function-scope import keeps this block self-contained.
    from itertools import islice

    def _matches(post: Post) -> bool:
        # Guard clauses: reject as soon as one active filter fails.
        if post_ids is not None and post.post_id not in post_ids:
            return False
        if note_ids is not None and not any(
            note.note_id in note_ids and note.post_id == post.post_id for note in note_samples
        ):
            return False
        if start is not None and post.created_at < start:
            return False
        if end is not None and post.created_at >= end:
            return False
        if search_text is not None and search_text not in post.text:
            return False
        return True

    # islice rejects negative bounds, so clamp them. A negative offset skips
    # nothing and a non-positive limit yields nothing.
    skip = 0 if offset is None else max(offset, 0)
    stop = None if limit is None else skip + max(limit, 0)
    yield from islice(filter(_matches, post_samples), skip, stop)


mock.get_posts.side_effect = _get_posts


def _get_number_of_posts(
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
) -> int:
    """In-memory stand-in for ``Storage.get_number_of_posts``.

    Counts the posts ``_get_posts`` yields for the same filters with no
    offset and no limit, so the total always agrees with the listing.
    """
    matching = _get_posts(
        post_ids=post_ids,
        note_ids=note_ids,
        start=start,
        end=end,
        search_text=search_text,
    )
    # Count lazily instead of building a list only to take its length.
    return sum(1 for _ in matching)


mock.get_number_of_posts.side_effect = _get_number_of_posts
def get_posts(
    self,
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
    offset: Union[int, None] = None,
    limit: int = 100,
) -> Generator[PostModel, None, None]:
    """Yield one page of posts that satisfy every supplied filter.

    All filters are optional and combined with AND; ``None`` means
    "do not filter on this".

    Args:
        post_ids: Keep only posts whose ``post_id`` is in this list. An empty
            list is passed to ``IN`` as-is (it is not treated as "no filter").
        note_ids: Keep only posts referenced by at least one note whose
            ``note_id`` is in this list.
        start: Inclusive lower bound on ``created_at``.
        end: Exclusive upper bound on ``created_at``.
        search_text: Literal substring that ``text`` must contain.
        offset: Number of matching rows to skip; ``None`` skips nothing.
        limit: Maximum number of rows to return.

    Yields:
        The matching posts converted with ``_post_record_to_model``, ordered
        by ``created_at`` and then ``post_id``.
    """
    with Session(self.engine) as sess:
        query = sess.query(PostRecord)
        if post_ids is not None:
            query = query.filter(PostRecord.post_id.in_(post_ids))
        if note_ids is not None:
            # IN-subquery rather than JOIN: a join emits one row per matching
            # note, so a post with several requested notes would occupy
            # several LIMIT/OFFSET slots.
            query = query.filter(
                PostRecord.post_id.in_(
                    select(NoteRecord.post_id).where(NoteRecord.note_id.in_(note_ids))
                )
            )
        if start is not None:
            query = query.filter(PostRecord.created_at >= start)
        if end is not None:
            query = query.filter(PostRecord.created_at < end)
        if search_text is not None:
            # autoescape makes "%" and "_" in the user's text match literally
            # instead of acting as LIKE wildcards.
            query = query.filter(PostRecord.text.contains(search_text, autoescape=True))
        # OFFSET/LIMIT are only meaningful over a deterministic order; without
        # ORDER BY the database may return pages that overlap or skip rows.
        query = query.order_by(PostRecord.created_at, PostRecord.post_id)
        if offset is not None:
            query = query.offset(offset)
        query = query.limit(limit)
        for post_record in query.all():
            yield self._post_record_to_model(post_record)


def get_number_of_posts(
    self,
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
) -> int:
    """Return how many posts satisfy the given filters, ignoring pagination.

    The filters have the same meaning as in ``get_posts`` and are built the
    same way, so the returned total matches what paging through
    ``get_posts`` would produce.

    Args:
        post_ids: Keep only posts whose ``post_id`` is in this list.
        note_ids: Keep only posts referenced by a note with a listed id.
        start: Inclusive lower bound on ``created_at``.
        end: Exclusive upper bound on ``created_at``.
        search_text: Literal substring that ``text`` must contain.

    Returns:
        The number of matching post rows.
    """
    with Session(self.engine) as sess:
        query = sess.query(PostRecord)
        if post_ids is not None:
            query = query.filter(PostRecord.post_id.in_(post_ids))
        if note_ids is not None:
            # IN-subquery keeps this a count of posts; a JOIN would count one
            # row per matching note and over-report the total.
            query = query.filter(
                PostRecord.post_id.in_(
                    select(NoteRecord.post_id).where(NoteRecord.note_id.in_(note_ids))
                )
            )
        if start is not None:
            query = query.filter(PostRecord.created_at >= start)
        if end is not None:
            query = query.filter(PostRecord.created_at < end)
        if search_text is not None:
            # Same literal-substring semantics as get_posts.
            query = query.filter(PostRecord.text.contains(search_text, autoescape=True))
        return query.count()
b/common/tests/test_storage.py @@ -1,5 +1,6 @@ -from typing import List +from typing import Any, Dict, List +import pytest from sqlalchemy.engine import Engine from birdxplorer_common.models import ( @@ -27,102 +28,66 @@ def test_get_topic_list( assert expected == actual -def test_get_post_list( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - expected = sorted(post_samples, key=lambda x: x.post_id) - actual = sorted(storage.get_posts(), key=lambda x: x.post_id) - assert expected == actual - - -def test_get_posts_by_ids( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - post_ids = [post_samples[i].post_id for i in (0, 2)] - expected = [post_samples[i] for i in (0, 2)] - actual = list(storage.get_posts_by_ids(post_ids)) - assert expected == actual - - -def test_get_posts_by_ids_empty( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - post_ids: List[PostId] = [] - expected: List[Post] = [] - actual = list(storage.get_posts_by_ids(post_ids)) - assert expected == actual - - -def test_get_posts_by_created_at_range( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - start = TwitterTimestamp.from_int(1153921700000) - end = TwitterTimestamp.from_int(1153921800000) - expected = [post_samples[i] for i in (1,)] - actual = 
@pytest.mark.parametrize(
    "filter_args,expected_indices",
    [
        ({}, [0, 1, 2]),
        ({"offset": 1}, [1, 2]),
        ({"limit": 1}, [0]),
        ({"offset": 1, "limit": 1}, [1]),
        (
            {
                "post_ids": [
                    PostId.from_str("2234567890123456781"),
                    PostId.from_str("2234567890123456801"),
                ]
            },
            [0, 2],
        ),
        ({"post_ids": []}, []),
        (
            {
                "start": TwitterTimestamp.from_int(1153921700000),
                "end": TwitterTimestamp.from_int(1153921800000),
            },
            [1],
        ),
        ({"start": TwitterTimestamp.from_int(1153921700000)}, [1, 2]),
        ({"end": TwitterTimestamp.from_int(1153921700000)}, [0]),
        ({"search_text": "https://t.co/xxxxxxxxxxx/"}, [0, 2]),
        ({"note_ids": [NoteId.from_str("1234567890123456781")]}, [0]),
        ({"offset": 1, "limit": 1, "search_text": "https://t.co/xxxxxxxxxxx/"}, [2]),
    ],
)
def test_get_post(
    engine_for_test: Engine,
    post_samples: List[Post],
    post_records_sample: List[PostRecord],
    topic_records_sample: List[TopicRecord],
    note_records_sample: List[NoteRecord],
    filter_args: Dict[str, Any],
    expected_indices: List[int],
) -> None:
    """``Storage.get_posts`` returns exactly the expected sample posts, in order.

    ``filter_args`` is forwarded as keyword arguments; ``expected_indices``
    are positions in ``post_samples``.
    """
    # NOTE(review): the *_records_sample fixtures are unused in the body and
    # are presumably requested only to populate the test database -- confirm
    # in conftest.
    expected = [post_samples[idx] for idx in expected_indices]
    actual = list(Storage(engine=engine_for_test).get_posts(**filter_args))
    assert expected == actual


@pytest.mark.parametrize(
    "filter_args,expected_indices",
    [
        ({}, [0, 1, 2]),
        (
            {
                "post_ids": [
                    PostId.from_str("2234567890123456781"),
                    PostId.from_str("2234567890123456801"),
                ]
            },
            [0, 2],
        ),
        ({"post_ids": []}, []),
        (
            {
                "start": TwitterTimestamp.from_int(1153921700000),
                "end": TwitterTimestamp.from_int(1153921800000),
            },
            [1],
        ),
        ({"start": TwitterTimestamp.from_int(1153921700000)}, [1, 2]),
        ({"end": TwitterTimestamp.from_int(1153921700000)}, [0]),
        ({"search_text": "https://t.co/xxxxxxxxxxx/"}, [0, 2]),
        ({"note_ids": [NoteId.from_str("1234567890123456781")]}, [0]),
    ],
)
def test_get_number_of_posts(
    engine_for_test: Engine,
    post_samples: List[Post],
    post_records_sample: List[PostRecord],
    topic_records_sample: List[TopicRecord],
    note_records_sample: List[NoteRecord],
    filter_args: Dict[str, Any],
    expected_indices: List[int],
) -> None:
    """``Storage.get_number_of_posts`` equals the number of expected matches.

    Only the length of ``expected_indices`` is used here; the indices mirror
    the cases of ``test_get_post`` for readability.
    """
    expected = len(expected_indices)
    actual = Storage(engine=engine_for_test).get_number_of_posts(**filter_args)
    assert expected == actual