Skip to content

Commit

Permalink
feat(storage): implement get_posts_by_created_at_range
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Mar 24, 2024
1 parent 243b871 commit df585ff
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 36 deletions.
60 changes: 25 additions & 35 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ def __init__(self, engine: Engine) -> None:
def engine(self) -> Engine:
return self._engine

@classmethod
def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
return PostModel(
post_id=post_record.post_id,
x_user_id=post_record.user_id,
x_user=XUserModel(
user_id=post_record.user.user_id,
name=post_record.user.name,
profile_image=post_record.user.profile_image,
followers_count=post_record.user.followers_count,
following_count=post_record.user.following_count,
),
text=post_record.text,
media_details=post_record.media_details,
created_at=post_record.created_at,
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
)

def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
raise NotImplementedError

Expand All @@ -135,49 +155,19 @@ def get_topics(self) -> Generator[TopicModel, None, None]:
def get_posts(self) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).all():
yield PostModel(
post_id=post_record.post_id,
x_user_id=post_record.user_id,
x_user=XUserModel(
user_id=post_record.user.user_id,
name=post_record.user.name,
profile_image=post_record.user.profile_image,
followers_count=post_record.user.followers_count,
following_count=post_record.user.following_count,
),
text=post_record.text,
media_details=post_record.media_details,
created_at=post_record.created_at,
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
)
yield self._post_record_to_model(post_record)

def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.post_id.in_(post_ids)).all():
yield PostModel(
post_id=post_record.post_id,
x_user_id=post_record.user_id,
x_user=XUserModel(
user_id=post_record.user.user_id,
name=post_record.user.name,
profile_image=post_record.user.profile_image,
followers_count=post_record.user.followers_count,
following_count=post_record.user.following_count,
),
text=post_record.text,
media_details=post_record.media_details,
created_at=post_record.created_at,
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
)
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_range(
self, start: TwitterTimestamp, end: TwitterTimestamp
) -> Generator[PostModel, None, None]:
raise NotImplementedError
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at.between(start, end)).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
raise NotImplementedError
Expand Down
17 changes: 16 additions & 1 deletion tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy.engine import Engine

from birdxplorer.models import Post, PostId, Topic
from birdxplorer.models import Post, PostId, Topic, TwitterTimestamp
from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord


Expand Down Expand Up @@ -57,3 +57,18 @@ def test_get_posts_by_ids_empty(
expected: List[Post] = []
actual = list(storage.get_posts_by_ids(post_ids))
assert expected == actual


def test_get_posts_by_created_at_range(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
start = TwitterTimestamp.from_int(1153921700000)
end = TwitterTimestamp.from_int(1153921800000)
expected = [post_samples[i] for i in (1,)]
actual = list(storage.get_posts_by_created_at_range(start, end))
assert expected == actual

0 comments on commit df585ff

Please sign in to comment.