diff --git a/api/tests/conftest.py b/api/tests/conftest.py index baec95f..1812c3f 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -14,6 +14,7 @@ from birdxplorer_common.exceptions import UserEnrollmentNotFoundError from birdxplorer_common.models import ( LanguageIdentifier, + Media, Note, NoteId, ParticipantId, @@ -61,6 +62,11 @@ class XUserFactory(ModelFactory[XUser]): __model__ = XUser +@register_fixture(name="media_factory") +class MediaFactory(ModelFactory[Media]): + __model__ = Media + + @register_fixture(name="post_factory") class PostFactory(ModelFactory[Post]): __model__ = Post @@ -160,7 +166,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, @fixture -def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]: +def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]: + yield [ + media_factory.build( + media_key="1234567890123456781", + url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg", + type="photo", + width=100, + height=100, + ), + media_factory.build( + media_key="1234567890123456782", + url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4", + type="video", + width=200, + height=200, + ), + media_factory.build( + media_key="1234567890123456783", + url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif", + type="animated_gif", + width=300, + height=300, + ), + ] + + +@fixture +def post_samples( + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media] +) -> Generator[List[Post], None, None]: posts = [ post_factory.build( post_id="2234567890123456781", @@ -184,7 +219,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて! https://t.co/yyyyyyyyyyy/ #学び #自己啓発""", - media_details=[], + media_details=[media_samples[0]], created_at=1153921700000, like_count=10, repost_count=20, @@ -196,7 +231,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene x_user=x_user_samples[1], text="""\ 次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""", - media_details=[], + media_details=[media_samples[1], media_samples[2]], created_at=1154921800000, like_count=10, repost_count=20, @@ -210,6 +245,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene def mock_storage( user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic], + media_samples: List[Media], post_samples: List[Post], note_samples: List[Note], ) -> Generator[MagicMock, None, None]: @@ -285,9 +321,11 @@ def _get_posts( if offset is not None and gen_count <= offset: continue actual_gen_count += 1 - if not with_media: - post.media_details = [] - yield post + + if with_media is False: + yield post.model_copy(update={"media_details": []}, deep=True) + else: + yield post mock.get_posts.side_effect = _get_posts diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 67160f9..e267d93 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -119,6 +119,40 @@ def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List assert response.status_code == 422 +def test_posts_get_with_media_by_default(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?postId=2234567890123456791") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(post_samples[1].model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + +def test_posts_get_with_media_true(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=true") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(post_samples[1].model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + +def test_posts_get_with_media_false(client: TestClient, post_samples: List[Post]) -> None: + expected_post = post_samples[1].model_copy(update={"media_details": []}) + response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=false") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(expected_post.model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?searchText=https%3A%2F%2Ft.co%2Fxxxxxxxxxxx%2F") assert response.status_code == 200