diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 4fcc8e0..7e8d128 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -262,6 +262,15 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]: mock.get_posts_by_ids.side_effect = _get_posts_by_ids + def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]: + for post in post_samples: + for note in note_samples: + if note.note_id in note_ids and post.post_id == note.post_id: + yield post + break + + mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids + def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]: for post in post_samples: if start <= post.created_at < end: diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 8bf271c..59fef49 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -39,6 +39,13 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos } +def test_posts_get_has_note_id_filter(client: TestClient, post_samples: List[Post], note_samples: List[Note]) -> None: + response = client.get(f"/api/v1/data/posts/?noteId={','.join([n.note_id for n in note_samples])}") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[0].model_dump_json())]} + + def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59") assert response.status_code == 200