Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/issue 99 improve posts #102

Merged
merged 19 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 25 additions & 27 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,32 +102,30 @@ def get_posts(
limit: int = Query(default=100, gt=0, le=1000),
search_text: Union[None, str] = Query(default=None),
) -> PostListResponse:
posts = None

if post_id is not None:
posts = list(storage.get_posts_by_ids(post_ids=post_id))
elif note_id is not None:
posts = list(storage.get_posts_by_note_ids(note_ids=note_id))
elif created_at_from is not None:
if created_at_to is not None:
posts = list(
storage.get_posts_by_created_at_range(
start=ensure_twitter_timestamp(created_at_from),
end=ensure_twitter_timestamp(created_at_to),
)
)
else:
posts = list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_from)))
elif created_at_to is not None:
posts = list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_to)))
elif search_text is not None and len(search_text) > 0:
posts = list(storage.search_posts_by_text(search_text))
else:
posts = list(storage.get_posts())

total_count = len(posts)
paginated_posts = posts[offset : offset + limit]
for post in paginated_posts:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)
posts = list(
storage.get_posts(
post_ids=post_id,
note_ids=note_id,
start=created_at_from,
end=created_at_to,
search_text=search_text,
offset=offset,
limit=limit,
)
)
total_count = storage.get_number_of_posts(
post_ids=post_id,
note_ids=note_id,
start=created_at_from,
end=created_at_to,
search_text=search_text,
)

for post in posts:
post.link = HttpUrl(f"https://x.com/{post.x_user.name}/status/{post.post_id}")

base_url = str(request.url).split("?")[0]
Expand All @@ -140,6 +138,6 @@ def get_posts(
if offset > 0:
prev_url = f"{base_url}?offset={prev_offset}&limit={limit}"

return PostListResponse(data=paginated_posts, meta=PaginationMeta(next=next_url, prev=prev_url))
return PostListResponse(data=posts, meta=PaginationMeta(next=next_url, prev=prev_url))

return router
90 changes: 40 additions & 50 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,60 +254,50 @@ def _get_notes(
mock.get_topics.side_effect = _get_topics
mock.get_notes.side_effect = _get_notes

def _get_posts() -> Generator[Post, None, None]:
    """Mock for ``get_posts``: emit every entry of ``post_samples`` in order."""
    for sample in post_samples:
        yield sample

mock.get_posts.side_effect = _get_posts

def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:
    """Mock lookup by id.

    For each requested id, in the order given, emit the first entry of
    ``post_samples`` whose ``post_id`` equals it. Ids without a matching
    sample produce nothing.
    """
    for wanted_id in post_ids:
        hit = next((sample for sample in post_samples if sample.post_id == wanted_id), None)
        if hit is not None:
            yield hit

mock.get_posts_by_ids.side_effect = _get_posts_by_ids

def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]:
    """Mock lookup by note id.

    Walks ``post_samples`` in order and emits a post when at least one entry of
    ``note_samples`` has its ``note_id`` in ``note_ids`` and points at that post.
    Each post is emitted at most once.
    """
    for sample in post_samples:
        referenced = any(
            note.note_id in note_ids and sample.post_id == note.post_id for note in note_samples
        )
        if referenced:
            yield sample

mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids

def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]:
    """Mock range filter: emit samples with ``start <= created_at < end`` (end exclusive)."""
    yield from (sample for sample in post_samples if start <= sample.created_at < end)

mock.get_posts_by_created_at_range.side_effect = _get_posts_by_created_at_range

def _get_posts_by_created_at_start(start: TwitterTimestamp) -> Generator[Post, None, None]:
    """Mock lower-bound filter: emit samples whose ``created_at`` is at or after ``start``."""
    for sample in post_samples:
        if not (start <= sample.created_at):
            continue
        yield sample

mock.get_posts_by_created_at_start.side_effect = _get_posts_by_created_at_start

def _get_posts_by_created_at_end(
    end: TwitterTimestamp,
) -> Generator[Post, None, None]:
    """Mock upper-bound filter: emit samples whose ``created_at`` is strictly before ``end``."""
    for post in post_samples:
        if post.created_at < end:
            yield post

def _get_posts(
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
    offset: Union[int, None] = None,
    limit: Union[int, None] = None,
) -> Generator[Post, None, None]:
    """Mock for the unified ``get_posts`` query.

    Walks ``post_samples`` in order and applies every filter that is not
    ``None``; all supplied filters must hold for a post to match.

    Args:
        post_ids: keep posts whose ``post_id`` is in this list.
        note_ids: keep posts referenced by a ``note_samples`` entry whose
            ``note_id`` is in this list.
        start: keep posts with ``created_at >= start``.
        end: keep posts with ``created_at < end``.
        search_text: keep posts whose ``text`` contains this substring.
        offset: number of matching posts to skip before emitting.
        limit: maximum number of posts to emit.

    Yields:
        The matching sample posts, after ``offset`` and ``limit`` are applied.
    """
    matched = 0  # posts that passed every filter so far; compared against ``offset``
    emitted = 0  # posts actually yielded; compared against ``limit``
    for post in post_samples:
        if limit is not None and emitted >= limit:
            break
        if post_ids is not None and post.post_id not in post_ids:
            continue
        if note_ids is not None and not any(
            note.note_id in note_ids and note.post_id == post.post_id for note in note_samples
        ):
            continue
        if start is not None and post.created_at < start:
            continue
        if end is not None and post.created_at >= end:
            continue
        if search_text is not None and search_text not in post.text:
            continue
        matched += 1
        # The first ``offset`` matches are counted but not emitted.
        if offset is not None and matched <= offset:
            continue
        emitted += 1
        yield post

mock.get_posts_by_created_at_end.side_effect = _get_posts_by_created_at_end
mock.get_posts.side_effect = _get_posts

def _search_posts_by_text(search_text: str) -> Generator[Post, None, None]:
    """Mock text search: emit samples whose ``text`` contains ``search_text`` as a substring."""
    yield from filter(lambda sample: search_text in sample.text, post_samples)
def _get_number_of_posts(
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
) -> int:
    """Mock for ``get_number_of_posts``.

    Runs ``_get_posts`` with the same filters and no offset/limit, and returns
    how many posts it produces.
    """
    matching = _get_posts(post_ids, note_ids, start, end, search_text)
    return sum(1 for _ in matching)

mock.search_posts_by_text.side_effect = _search_posts_by_text
mock.get_number_of_posts.side_effect = _get_number_of_posts

yield mock

Expand Down
84 changes: 48 additions & 36 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,47 +268,59 @@ def get_notes(
created_at=note_record.created_at,
)

def get_posts(self) -> Generator[PostModel, None, None]:
    """Yield a ``PostModel`` for every ``PostRecord`` row.

    The session is opened for this call only and stays open while the
    generator is being consumed.
    """
    with Session(self.engine) as sess:
        records = sess.query(PostRecord).all()
        for record in records:
            yield self._post_record_to_model(record)

def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
    """Yield a ``PostModel`` for each ``PostRecord`` whose ``post_id`` is in ``post_ids``."""
    with Session(self.engine) as sess:
        id_filter = PostRecord.post_id.in_(post_ids)
        matching = sess.query(PostRecord).filter(id_filter).all()
        yield from (self._post_record_to_model(record) for record in matching)

def get_posts_by_created_at_range(
    self, start: TwitterTimestamp, end: TwitterTimestamp
) -> Generator[PostModel, None, None]:
    """Yield posts whose ``created_at`` lies between ``start`` and ``end``.

    Uses SQL ``BETWEEN``, so both bounds are inclusive.
    """
    with Session(self.engine) as sess:
        for post_record in sess.query(PostRecord).filter(PostRecord.created_at.between(start, end)).all():
            yield self._post_record_to_model(post_record)

def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
    """Yield posts with ``created_at >= start``."""
    with Session(self.engine) as sess:
        for post_record in sess.query(PostRecord).filter(PostRecord.created_at >= start).all():
            yield self._post_record_to_model(post_record)

def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostModel, None, None]:
    """Yield posts with ``created_at < end``."""
    with Session(self.engine) as sess:
        for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all():
            yield self._post_record_to_model(post_record)

def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]:
    """Yield the posts joined to notes whose ``note_id`` is in ``note_ids``."""
    query = (
        select(PostRecord)
        .join(NoteRecord, NoteRecord.post_id == PostRecord.post_id)
        .where(NoteRecord.note_id.in_(note_ids))
    )
    with Session(self.engine) as sess:
        for post_record in sess.execute(query).scalars().all():
            yield self._post_record_to_model(post_record)

def get_posts(
    self,
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
    offset: Union[int, None] = None,
    limit: int = 100,
) -> Generator[PostModel, None, None]:
    """Yield one page of posts matching every supplied filter.

    Filters left as ``None`` are not applied; the rest are combined with AND.

    Args:
        post_ids: keep posts whose ``post_id`` is in this list.
        note_ids: keep posts referenced by at least one note whose
            ``note_id`` is in this list.
        start: keep posts with ``created_at >= start``.
        end: keep posts with ``created_at < end``.
        search_text: keep posts whose ``text`` contains this literal substring.
        offset: number of matching rows to skip; ``None`` skips none.
        limit: maximum number of rows to return.

    Yields:
        ``PostModel`` objects converted from the matching ``PostRecord`` rows,
        ordered by ``post_id``.
    """
    with Session(self.engine) as sess:
        query = sess.query(PostRecord)
        if post_ids is not None:
            query = query.filter(PostRecord.post_id.in_(post_ids))
        if note_ids is not None:
            # Semi-join instead of JOIN: a post referenced by several matching
            # notes must appear once, not once per note.
            noted_post_ids = select(NoteRecord.post_id).where(NoteRecord.note_id.in_(note_ids))
            query = query.filter(PostRecord.post_id.in_(noted_post_ids))
        if start is not None:
            query = query.filter(PostRecord.created_at >= start)
        if end is not None:
            query = query.filter(PostRecord.created_at < end)
        if search_text is not None:
            # autoescape makes "%" and "_" in user input match literally
            # rather than act as LIKE wildcards.
            query = query.filter(PostRecord.text.contains(search_text, autoescape=True))
        # OFFSET/LIMIT without ORDER BY leaves row order unspecified, so
        # successive pages could overlap or skip rows.
        query = query.order_by(PostRecord.post_id)
        if offset is not None:
            query = query.offset(offset)
        query = query.limit(limit)
        for post_record in query.all():
            yield self._post_record_to_model(post_record)

def search_posts_by_text(self, search_word: str) -> Generator[PostModel, None, None]:
    """Yield posts whose ``text`` matches the LIKE pattern ``%search_word%``."""
    with Session(self.engine) as sess:
        for post_record in sess.query(PostRecord).filter(PostRecord.text.like(f"%{search_word}%")):
            yield self._post_record_to_model(post_record)

def get_number_of_posts(
    self,
    post_ids: Union[List[PostId], None] = None,
    note_ids: Union[List[NoteId], None] = None,
    start: Union[TwitterTimestamp, None] = None,
    end: Union[TwitterTimestamp, None] = None,
    search_text: Union[str, None] = None,
) -> int:
    """Return how many posts match every supplied filter.

    Filters left as ``None`` are not applied; the rest are combined with AND.
    No offset or limit is involved, so this is the total across all pages.

    Args:
        post_ids: count posts whose ``post_id`` is in this list.
        note_ids: count posts referenced by at least one note whose
            ``note_id`` is in this list.
        start: count posts with ``created_at >= start``.
        end: count posts with ``created_at < end``.
        search_text: count posts whose ``text`` contains this literal substring.

    Returns:
        The number of distinct matching ``PostRecord`` rows.
    """
    with Session(self.engine) as sess:
        query = sess.query(PostRecord)
        if post_ids is not None:
            query = query.filter(PostRecord.post_id.in_(post_ids))
        if note_ids is not None:
            # Semi-join instead of JOIN: a post referenced by several matching
            # notes must be counted once, not once per note.
            noted_post_ids = select(NoteRecord.post_id).where(NoteRecord.note_id.in_(note_ids))
            query = query.filter(PostRecord.post_id.in_(noted_post_ids))
        if start is not None:
            query = query.filter(PostRecord.created_at >= start)
        if end is not None:
            query = query.filter(PostRecord.created_at < end)
        if search_text is not None:
            # autoescape makes "%" and "_" in user input match literally
            # rather than act as LIKE wildcards.
            query = query.filter(PostRecord.text.contains(search_text, autoescape=True))
        return query.count()


def gen_storage(settings: GlobalSettings) -> Storage:
Expand Down
Loading
Loading