From 98a8dc2d0bc748c34c0f9b1fc193133046bc100b Mon Sep 17 00:00:00 2001 From: sushi-chaaaan Date: Wed, 2 Oct 2024 15:19:45 +0900 Subject: [PATCH] =?UTF-8?q?API=E5=81=B4=E3=81=AB=E3=83=86=E3=82=B9?= =?UTF-8?q?=E3=83=88=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/tests/conftest.py | 46 ++++++++++++++++++++++++++++++---- api/tests/routers/test_data.py | 34 +++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 2d8ea9c..35572e1 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -17,6 +17,7 @@ LanguageIdentifier, Link, LinkId, + Media, Note, NoteId, ParticipantId, @@ -64,6 +65,11 @@ class XUserFactory(ModelFactory[XUser]): __model__ = XUser +@register_fixture(name="media_factory") +class MediaFactory(ModelFactory[Media]): + __model__ = Media + + @register_fixture(name="post_factory") class PostFactory(ModelFactory[Post]): __model__ = Post @@ -183,9 +189,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, yield x_users +@fixture +def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]: + yield [ + media_factory.build( + media_key="1234567890123456781", + url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg", + type="photo", + width=100, + height=100, + ), + media_factory.build( + media_key="1234567890123456782", + url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4", + type="video", + width=200, + height=200, + ), + media_factory.build( + media_key="1234567890123456783", + url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif", + type="animated_gif", + width=300, + height=300, + ), + ] + + @fixture def post_samples( - post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link] + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media],link_samples: List[Link] ) -> Generator[List[Post], None, None]: posts = [ post_factory.build( @@ -213,7 +246,7 @@ def post_samples( このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて! https://t.co/yyyyyyyyyyy/ #学び #自己啓発""", - media_details=[], + media_details=[media_samples[0]], created_at=1153921700000, like_count=10, repost_count=20, @@ -268,6 +301,7 @@ def post_samples( def mock_storage( user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic], + media_samples: List[Media], post_samples: List[Post], note_samples: List[Note], link_samples: List[Link], @@ -355,9 +389,11 @@ def _get_posts( if offset is not None and gen_count <= offset: continue actual_gen_count += 1 - if not with_media: - post.media_details = [] - yield post + + if with_media is False: + yield post.model_copy(update={"media_details": []}, deep=True) + else: + yield post mock.get_posts.side_effect = _get_posts diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index e35f6e4..e971098 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -122,6 +122,40 @@ def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List assert response.status_code == 422 +def test_posts_get_with_media_by_default(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?postId=2234567890123456791") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(post_samples[1].model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + +def test_posts_get_with_media_true(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=true") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(post_samples[1].model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + +def test_posts_get_with_media_false(client: TestClient, post_samples: List[Post]) -> None: + expected_post = post_samples[1].model_copy(update={"media_details": []}) + response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=false") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(expected_post.model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?searchText=https%3A%2F%2Ft.co%2Fxxxxxxxxxxx%2F") assert response.status_code == 200