Skip to content

Commit

Permalink
Merge pull request #102 from codeforjapan/feat/issue-99-improve-posts
Browse files (browse the repository at this point in the history)
Feat/issue 99 improve posts
  • Loading branch information
yu23ki14 authored Sep 2, 2024
2 parents 6889587 + 9421fca commit 260f73b
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 190 deletions.
52 changes: 25 additions & 27 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,32 +102,30 @@ def get_posts(
limit: int = Query(default=100, gt=0, le=1000),
search_text: Union[None, str] = Query(default=None),
) -> PostListResponse:
posts = None

if post_id is not None:
posts = list(storage.get_posts_by_ids(post_ids=post_id))
elif note_id is not None:
posts = list(storage.get_posts_by_note_ids(note_ids=note_id))
elif created_at_from is not None:
if created_at_to is not None:
posts = list(
storage.get_posts_by_created_at_range(
start=ensure_twitter_timestamp(created_at_from),
end=ensure_twitter_timestamp(created_at_to),
)
)
else:
posts = list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_from)))
elif created_at_to is not None:
posts = list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_to)))
elif search_text is not None and len(search_text) > 0:
posts = list(storage.search_posts_by_text(search_text))
else:
posts = list(storage.get_posts())

total_count = len(posts)
paginated_posts = posts[offset : offset + limit]
for post in paginated_posts:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)
posts = list(
storage.get_posts(
post_ids=post_id,
note_ids=note_id,
start=created_at_from,
end=created_at_to,
search_text=search_text,
offset=offset,
limit=limit,
)
)
total_count = storage.get_number_of_posts(
post_ids=post_id,
note_ids=note_id,
start=created_at_from,
end=created_at_to,
search_text=search_text,
)

for post in posts:
post.link = HttpUrl(f"https://x.com/{post.x_user.name}/status/{post.post_id}")

base_url = str(request.url).split("?")[0]
Expand All @@ -140,6 +138,6 @@ def get_posts(
if offset > 0:
prev_url = f"{base_url}?offset={prev_offset}&limit={limit}"

return PostListResponse(data=paginated_posts, meta=PaginationMeta(next=next_url, prev=prev_url))
return PostListResponse(data=posts, meta=PaginationMeta(next=next_url, prev=prev_url))

return router
90 changes: 40 additions & 50 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,60 +254,50 @@ def _get_notes(
mock.get_topics.side_effect = _get_topics
mock.get_notes.side_effect = _get_notes

def _get_posts() -> Generator[Post, None, None]:
    """Yield every entry of ``post_samples`` in its existing order."""
    for sample in post_samples:
        yield sample

mock.get_posts.side_effect = _get_posts

def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:
    """Yield the first sample post matching each requested ID.

    IDs are processed in the order given; an ID with no matching entry in
    ``post_samples`` produces nothing.
    """
    for wanted_id in post_ids:
        # Only the first sample with this ID is used, mirroring a lookup by key.
        found = next(
            (candidate for candidate in post_samples if candidate.post_id == wanted_id),
            None,
        )
        if found is not None:
            yield found

mock.get_posts_by_ids.side_effect = _get_posts_by_ids

def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]:
    """Yield each sample post referenced by at least one of the given notes.

    A post is yielded at most once, in ``post_samples`` order, when some entry
    of ``note_samples`` has a ``note_id`` in ``note_ids`` and the same
    ``post_id`` as the post.
    """
    for candidate in post_samples:
        has_matching_note = any(
            note.note_id in note_ids and candidate.post_id == note.post_id
            for note in note_samples
        )
        if has_matching_note:
            yield candidate

mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids

def _get_posts_by_created_at_range(
    start: TwitterTimestamp,
    end: TwitterTimestamp,
) -> Generator[Post, None, None]:
    """Yield sample posts whose ``created_at`` lies in ``[start, end)``."""
    yield from (
        candidate
        for candidate in post_samples
        if start <= candidate.created_at < end
    )

mock.get_posts_by_created_at_range.side_effect = _get_posts_by_created_at_range

def _get_posts_by_created_at_start(start: TwitterTimestamp) -> Generator[Post, None, None]:
    """Yield sample posts whose ``created_at`` is at or after ``start``."""
    yield from (
        candidate for candidate in post_samples if start <= candidate.created_at
    )

mock.get_posts_by_created_at_start.side_effect = _get_posts_by_created_at_start

def _get_posts_by_created_at_end(
end: TwitterTimestamp,
def _get_posts(
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
offset: Union[int, None] = None,
limit: Union[int, None] = None,
) -> Generator[Post, None, None]:
for post in post_samples:
if post.created_at < end:
yield post
gen_count = 0
actual_gen_count = 0
for idx, post in enumerate(post_samples):
if limit is not None and actual_gen_count >= limit:
break
if post_ids is not None and post.post_id not in post_ids:
continue
if note_ids is not None and not any(
note.note_id in note_ids and note.post_id == post.post_id for note in note_samples
):
continue
if start is not None and post.created_at < start:
continue
if end is not None and post.created_at >= end:
continue
if search_text is not None and search_text not in post.text:
continue
gen_count += 1
if offset is not None and gen_count <= offset:
continue
actual_gen_count += 1
yield post

mock.get_posts_by_created_at_end.side_effect = _get_posts_by_created_at_end
mock.get_posts.side_effect = _get_posts

def _search_posts_by_text(search_text: str) -> Generator[Post, None, None]:
    """Yield sample posts whose ``text`` contains ``search_text`` as a substring."""
    for candidate in post_samples:
        if search_text not in candidate.text:
            continue
        yield candidate
def _get_number_of_posts(
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
) -> int:
    """Return how many posts ``_get_posts`` produces for the given filters.

    The filters are forwarded positionally; no offset or limit is passed.
    """
    matching = _get_posts(post_ids, note_ids, start, end, search_text)
    return sum(1 for _ in matching)

mock.search_posts_by_text.side_effect = _search_posts_by_text
mock.get_number_of_posts.side_effect = _get_number_of_posts

yield mock

Expand Down
84 changes: 48 additions & 36 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,47 +268,59 @@ def get_notes(
created_at=note_record.created_at,
)

def get_posts(self) -> Generator[PostModel, None, None]:
    """Yield every stored post, converted with ``_post_record_to_model``.

    All ``PostRecord`` rows are fetched inside a ``Session`` bound to
    ``self.engine`` before the first model is yielded.
    """
    with Session(self.engine) as db:
        records = db.query(PostRecord).all()
        for record in records:
            yield self._post_record_to_model(record)

def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
    """Yield the stored posts whose ``post_id`` is one of ``post_ids``.

    Each matching ``PostRecord`` is converted with ``_post_record_to_model``.
    """
    with Session(self.engine) as db:
        matching = db.query(PostRecord).filter(PostRecord.post_id.in_(post_ids))
        for record in matching.all():
            yield self._post_record_to_model(record)

def get_posts_by_created_at_range(
self, start: TwitterTimestamp, end: TwitterTimestamp
def get_posts(
self,
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
offset: Union[int, None] = None,
limit: int = 100,
) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at.between(start, end)).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
    """Yield stored posts whose ``created_at`` is at or after ``start``.

    Each matching ``PostRecord`` is converted with ``_post_record_to_model``.
    """
    with Session(self.engine) as db:
        matching = db.query(PostRecord).filter(PostRecord.created_at >= start)
        for record in matching.all():
            yield self._post_record_to_model(record)

def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostModel, None, None]:
    """Yield stored posts whose ``created_at`` is strictly before ``end``.

    Each matching ``PostRecord`` is converted with ``_post_record_to_model``.
    """
    with Session(self.engine) as db:
        matching = db.query(PostRecord).filter(PostRecord.created_at < end)
        for record in matching.all():
            yield self._post_record_to_model(record)

def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]:
query = (
select(PostRecord)
.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id)
.where(NoteRecord.note_id.in_(note_ids))
)
with Session(self.engine) as sess:
for post_record in sess.execute(query).scalars().all():
query = sess.query(PostRecord)
if post_ids is not None:
query = query.filter(PostRecord.post_id.in_(post_ids))
if note_ids is not None:
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
query = query.filter(PostRecord.created_at < end)
if search_text is not None:
query = query.filter(PostRecord.text.like(f"%{search_text}%"))
if offset is not None:
query = query.offset(offset)
query = query.limit(limit)
for post_record in query.all():
yield self._post_record_to_model(post_record)

def search_posts_by_text(self, search_word: str) -> Generator[PostModel, None, None]:
def get_number_of_posts(
self,
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
) -> int:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.text.like(f"%{search_word}%")):
yield self._post_record_to_model(post_record)
query = sess.query(PostRecord)
if post_ids is not None:
query = query.filter(PostRecord.post_id.in_(post_ids))
if note_ids is not None:
query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter(
NoteRecord.note_id.in_(note_ids)
)
if start is not None:
query = query.filter(PostRecord.created_at >= start)
if end is not None:
query = query.filter(PostRecord.created_at < end)
if search_text is not None:
query = query.filter(PostRecord.text.like(f"%{search_text}%"))
return query.count()


def gen_storage(settings: GlobalSettings) -> Storage:
Expand Down
Loading

0 comments on commit 260f73b

Please sign in to comment.