diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 2398df8..7bdd935 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -92,11 +92,14 @@ def get_notes( @router.get("/posts", response_model=PostListResponse) def get_posts( post_id: Union[List[PostId], None] = Query(default=None), + note_id: Union[List[NoteId], None] = Query(default=None), created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None), created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None), ) -> PostListResponse: if post_id is not None: return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id))) + if note_id is not None: + return PostListResponse(data=list(storage.get_posts_by_note_ids(note_ids=note_id))) if created_at_start is not None: if created_at_end is not None: return PostListResponse( diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 4fcc8e0..7e8d128 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -262,6 +262,15 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]: mock.get_posts_by_ids.side_effect = _get_posts_by_ids + def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]: + for post in post_samples: + for note in note_samples: + if note.note_id in note_ids and post.post_id == note.post_id: + yield post + break + + mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids + def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]: for post in post_samples: if start <= post.created_at < end: diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 8bf271c..59fef49 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -39,6 +39,13 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos } +def test_posts_get_has_note_id_filter(client: TestClient, post_samples: List[Post], note_samples: List[Note]) -> None: + response = client.get(f"/api/v1/data/posts/?noteId={','.join([n.note_id for n in note_samples])}") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[0].model_dump_json())]} + + def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59") assert response.status_code == 200 diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index cd07b7a..ad77eca 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -296,6 +296,16 @@ def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostMo for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all(): yield self._post_record_to_model(post_record) + def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]: + query = ( + select(PostRecord) + .join(NoteRecord, NoteRecord.post_id == PostRecord.post_id) + .where(NoteRecord.note_id.in_(note_ids)) + ) + with Session(self.engine) as sess: + for post_record in sess.execute(query).scalars().all(): + yield self._post_record_to_model(post_record) + def gen_storage(settings: GlobalSettings) -> Storage: engine = create_engine(settings.storage_settings.sqlalchemy_database_url)