Skip to content

Commit

Permalink
API側にテストを追加
Browse files Browse the repository at this point in the history
  • Loading branch information
sushichan044 committed Oct 2, 2024
1 parent bd88dc9 commit 8b72330
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 6 deletions.
50 changes: 44 additions & 6 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from birdxplorer_common.exceptions import UserEnrollmentNotFoundError
from birdxplorer_common.models import (
LanguageIdentifier,
Media,
Note,
NoteId,
ParticipantId,
Expand Down Expand Up @@ -61,6 +62,11 @@ class XUserFactory(ModelFactory[XUser]):
__model__ = XUser


@register_fixture(name="media_factory")
class MediaFactory(ModelFactory[Media]):
__model__ = Media


@register_fixture(name="post_factory")
class PostFactory(ModelFactory[Post]):
__model__ = Post
Expand Down Expand Up @@ -160,7 +166,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None,


@fixture
def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]:
def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]:
yield [
media_factory.build(
media_key="1234567890123456781",
url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg",
type="photo",
width=100,
height=100,
),
media_factory.build(
media_key="1234567890123456782",
url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4",
type="video",
width=200,
height=200,
),
media_factory.build(
media_key="1234567890123456783",
url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif",
type="animated_gif",
width=300,
height=300,
),
]


@fixture
def post_samples(
post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media]
) -> Generator[List[Post], None, None]:
posts = [
post_factory.build(
post_id="2234567890123456781",
Expand All @@ -184,7 +219,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて!
https://t.co/yyyyyyyyyyy/ #学び #自己啓発""",
media_details=[],
media_details=[media_samples[0]],
created_at=1153921700000,
like_count=10,
repost_count=20,
Expand All @@ -196,7 +231,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
x_user=x_user_samples[1],
text="""\
次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""",
media_details=[],
media_details=[media_samples[1], media_samples[2]],
created_at=1154921800000,
like_count=10,
repost_count=20,
Expand All @@ -210,6 +245,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
def mock_storage(
user_enrollment_samples: List[UserEnrollment],
topic_samples: List[Topic],
media_samples: List[Media],
post_samples: List[Post],
note_samples: List[Note],
) -> Generator[MagicMock, None, None]:
Expand Down Expand Up @@ -285,9 +321,11 @@ def _get_posts(
if offset is not None and gen_count <= offset:
continue
actual_gen_count += 1
if not with_media:
post.media_details = []
yield post

if with_media is False:
yield post.model_copy(update={"media_details": []}, deep=True)
else:
yield post

mock.get_posts.side_effect = _get_posts

Expand Down
34 changes: 34 additions & 0 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,40 @@ def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List
assert response.status_code == 422


def test_posts_get_with_media_by_default(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?postId=2234567890123456791")

assert response.status_code == 200
res_json_default = response.json()
assert res_json_default == {
"data": [json.loads(post_samples[1].model_dump_json())],
"meta": {"next": None, "prev": None},
}


def test_posts_get_with_media_true(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=true")

assert response.status_code == 200
res_json_default = response.json()
assert res_json_default == {
"data": [json.loads(post_samples[1].model_dump_json())],
"meta": {"next": None, "prev": None},
}


def test_posts_get_with_media_false(client: TestClient, post_samples: List[Post]) -> None:
expected_post = post_samples[1].model_copy(update={"media_details": []})
response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=false")

assert response.status_code == 200
res_json_default = response.json()
assert res_json_default == {
"data": [json.loads(expected_post.model_dump_json())],
"meta": {"next": None, "prev": None},
}


def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?searchText=https%3A%2F%2Ft.co%2Fxxxxxxxxxxx%2F")
assert response.status_code == 200
Expand Down

0 comments on commit 8b72330

Please sign in to comment.