diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index fbf32ee..6c48dac 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -26,6 +26,7 @@ TopicId, TwitterTimestamp, UserEnrollment, + UserId, ) from birdxplorer_common.storage import Storage @@ -255,6 +256,7 @@ def get_posts( request: Request, post_ids: Union[List[PostId], None] = Query(default=None), note_ids: Union[List[NoteId], None] = Query(default=None), + user_ids: Union[List[UserId], None] = Query(default=None), created_at_from: Union[None, TwitterTimestamp, str] = Query( default=None, **V1DataPostsDocs.params["created_at_from"] ), @@ -275,6 +277,7 @@ def get_posts( storage.get_posts( post_ids=post_ids, note_ids=note_ids, + user_ids=user_ids, start=created_at_from, end=created_at_to, search_text=search_text, @@ -288,6 +291,7 @@ def get_posts( total_count = storage.get_number_of_posts( post_ids=post_ids, note_ids=note_ids, + user_ids=user_ids, start=created_at_from, end=created_at_to, search_text=search_text, diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 36a2df7..7b4530a 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -379,6 +379,7 @@ def _get_number_of_notes( def _get_posts( post_ids: Union[List[PostId], None] = None, note_ids: Union[List[NoteId], None] = None, + user_ids: Union[List[str], None] = None, start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, @@ -403,6 +404,8 @@ def _get_posts( note.note_id in note_ids and note.post_id == post.post_id for note in note_samples ): continue + if user_ids is not None and post.x_user_id not in user_ids: + continue if start is not None and post.created_at < start: continue if end is not None and post.created_at >= end: @@ -426,12 +429,13 @@ def _get_posts( def _get_number_of_posts( post_ids: Union[List[PostId], None] = None, note_ids: Union[List[NoteId], None] = None, + user_ids: Union[List[str], None] = None, start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, search_url: Union[HttpUrl, None] = None, ) -> int: - return len(list(_get_posts(post_ids, note_ids, start, end, search_text, search_url))) + return len(list(_get_posts(post_ids, note_ids, user_ids, start, end, search_text, search_url))) mock.get_number_of_posts.side_effect = _get_number_of_posts diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 8f575df..ab59332 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -436,6 +436,7 @@ def get_posts( self, post_ids: Union[List[PostId], None] = None, note_ids: Union[List[NoteId], None] = None, + user_ids: Union[List[UserId], None] = None, start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, @@ -452,6 +453,8 @@ def get_posts( query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter( NoteRecord.note_id.in_(note_ids) ) + if user_ids is not None: + query = query.filter(PostRecord.user_id.in_(user_ids)) if start is not None: query = query.filter(PostRecord.created_at >= start) if end is not None: @@ -474,6 +477,7 @@ def get_number_of_posts( self, post_ids: Union[List[PostId], None] = None, note_ids: Union[List[NoteId], None] = None, + user_ids: Union[List[UserId], None] = None, start: Union[TwitterTimestamp, None] = None, end: Union[TwitterTimestamp, None] = None, search_text: Union[str, None] = None, @@ -487,6 +491,8 @@ def get_number_of_posts( query = query.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id).filter( NoteRecord.note_id.in_(note_ids) ) + if user_ids is not None: + query = query.filter(PostRecord.user_id.in_(user_ids)) if start is not None: query = query.filter(PostRecord.created_at >= start) if end is not None: