From df585ffd9ec1a5d50080e5ff278ed7ca45535a49 Mon Sep 17 00:00:00 2001 From: osoken Date: Sun, 24 Mar 2024 21:34:20 +0900 Subject: [PATCH] feat(storage): implement get_posts_by_created_at_range --- birdxplorer/storage.py | 60 ++++++++++++++++++------------------------ tests/test_storage.py | 17 +++++++++++- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index a3fb917..f7c2c36 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -113,6 +113,26 @@ def __init__(self, engine: Engine) -> None: def engine(self) -> Engine: return self._engine + @classmethod + def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: + return PostModel( + post_id=post_record.post_id, + x_user_id=post_record.user_id, + x_user=XUserModel( + user_id=post_record.user.user_id, + name=post_record.user.name, + profile_image=post_record.user.profile_image, + followers_count=post_record.user.followers_count, + following_count=post_record.user.following_count, + ), + text=post_record.text, + media_details=post_record.media_details, + created_at=post_record.created_at, + like_count=post_record.like_count, + repost_count=post_record.repost_count, + impression_count=post_record.impression_count, + ) + def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment: raise NotImplementedError @@ -135,49 +155,19 @@ def get_topics(self) -> Generator[TopicModel, None, None]: def get_posts(self) -> Generator[PostModel, None, None]: with Session(self.engine) as sess: for post_record in sess.query(PostRecord).all(): - yield PostModel( - post_id=post_record.post_id, - x_user_id=post_record.user_id, - x_user=XUserModel( - user_id=post_record.user.user_id, - name=post_record.user.name, - profile_image=post_record.user.profile_image, - followers_count=post_record.user.followers_count, - following_count=post_record.user.following_count, - ), - text=post_record.text, - media_details=post_record.media_details, - created_at=post_record.created_at, - like_count=post_record.like_count, - repost_count=post_record.repost_count, - impression_count=post_record.impression_count, - ) + yield self._post_record_to_model(post_record) def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]: with Session(self.engine) as sess: for post_record in sess.query(PostRecord).filter(PostRecord.post_id.in_(post_ids)).all(): - yield PostModel( - post_id=post_record.post_id, - x_user_id=post_record.user_id, - x_user=XUserModel( - user_id=post_record.user.user_id, - name=post_record.user.name, - profile_image=post_record.user.profile_image, - followers_count=post_record.user.followers_count, - following_count=post_record.user.following_count, - ), - text=post_record.text, - media_details=post_record.media_details, - created_at=post_record.created_at, - like_count=post_record.like_count, - repost_count=post_record.repost_count, - impression_count=post_record.impression_count, - ) + yield self._post_record_to_model(post_record) def get_posts_by_created_at_range( self, start: TwitterTimestamp, end: TwitterTimestamp ) -> Generator[PostModel, None, None]: - raise NotImplementedError + with Session(self.engine) as sess: + for post_record in sess.query(PostRecord).filter(PostRecord.created_at.between(start, end)).all(): + yield self._post_record_to_model(post_record) def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]: raise NotImplementedError diff --git a/tests/test_storage.py b/tests/test_storage.py index 383f6d0..53f4c61 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -2,7 +2,7 @@ from sqlalchemy.engine import Engine -from birdxplorer.models import Post, PostId, Topic +from birdxplorer.models import Post, PostId, Topic, TwitterTimestamp from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord @@ -57,3 +57,18 @@ def test_get_posts_by_ids_empty( expected: List[Post] = [] actual = list(storage.get_posts_by_ids(post_ids)) assert expected == actual + + +def test_get_posts_by_created_at_range( + engine_for_test: Engine, + post_samples: List[Post], + post_records_sample: List[PostRecord], + topic_records_sample: List[TopicRecord], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + start = TwitterTimestamp.from_int(1153921700000) + end = TwitterTimestamp.from_int(1153921800000) + expected = [post_samples[i] for i in (1,)] + actual = list(storage.get_posts_by_created_at_range(start, end)) + assert expected == actual