diff --git a/common/tests/conftest.py b/common/tests/conftest.py index 3e0795a..e8430ee 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -3,19 +3,9 @@ from collections.abc import Generator from typing import List, Type -from dotenv import load_dotenv -from polyfactory import Use -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.pytest_plugin import register_fixture -from pytest import fixture -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session -from sqlalchemy.sql import text - from birdxplorer_common.models import ( Link, + Media, Note, Post, Topic, @@ -27,13 +17,25 @@ from birdxplorer_common.storage import ( Base, LinkRecord, + MediaRecord, NoteRecord, NoteTopicAssociation, PostLinkAssociation, + PostMediaAssociation, PostRecord, TopicRecord, XUserRecord, ) +from dotenv import load_dotenv +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pytest import fixture +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from sqlalchemy.sql import text def gen_random_twitter_timestamp() -> int: @@ -97,6 +99,11 @@ class XUserFactory(ModelFactory[XUser]): __model__ = XUser +@register_fixture(name="media_factory") +class MediaFactory(ModelFactory[Media]): + __model__ = Media + + @register_fixture(name="post_factory") class PostFactory(ModelFactory[Post]): __model__ = Post @@ -225,9 +232,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, yield x_users +@fixture +def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]: + yield [ + media_factory.build( + media_key="1234567890123456781", + url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg", + type="photo", + width=100, + height=100, + ), + media_factory.build( + media_key="1234567890123456782", + url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4", + type="video", + width=200, + height=200, + ), + media_factory.build( + media_key="1234567890123456783", + url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif", + type="animated_gif", + width=300, + height=300, + ), + ] + + @fixture def post_samples( - post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link] + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media], link_samples: List[Link] ) -> Generator[List[Post], None, None]: posts = [ post_factory.build( @@ -255,7 +289,7 @@ def post_samples( このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて! https://t.co/yyyyyyyyyyy/ #学び #自己啓発""", - media_details=[], + media_details=[media_samples[0]], created_at=1153921700000, like_count=10, repost_count=20, @@ -268,8 +302,8 @@ def post_samples( x_user_id="1234567890123456782", x_user=x_user_samples[1], text="""\ -次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""", - media_details=[], +次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""", + media_details=[media_samples[1], media_samples[2]], created_at=1154921800000, like_count=10, repost_count=20, @@ -419,8 +453,8 @@ def x_user_records_sample( @fixture def link_records_sample( - link_samples: List[Link], engine_for_test: Engine, + link_samples: List[Link], ) -> Generator[List[LinkRecord], None, None]: res = [LinkRecord(link_id=d.link_id, url=d.url) for d in link_samples] with Session(engine_for_test) as sess: @@ -429,9 +463,31 @@ def link_records_sample( yield res +@fixture +def media_records_sample( + engine_for_test: Engine, + media_samples: List[Media], +) -> Generator[List[MediaRecord], None, None]: + res = [ + MediaRecord( + media_key=d.media_key, + type=d.type, + url=d.url, + width=d.width, + height=d.height, + ) + for d in media_samples + ] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + @fixture def post_records_sample( x_user_records_sample: List[XUserRecord], + media_records_sample: List[MediaRecord], link_records_sample: List[LinkRecord], post_samples: List[Post], engine_for_test: Engine, @@ -451,9 +507,15 @@ def post_records_sample( ) sess.add(inst) for link in post.links: - assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) - sess.add(assoc) - inst.links.append(assoc) + post_link_assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) + sess.add(post_link_assoc) + inst.links.append(post_link_assoc) + + for media in post.media_details: + post_media_assoc = PostMediaAssociation(media_key=media.media_key, post_id=inst.post_id) + sess.add(post_media_assoc) + inst.media_details.append(post_media_assoc) + res.append(inst) sess.commit() yield res diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 079056a..07d6fb2 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -45,6 +45,8 @@ def test_get_topic_list( [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], + [dict(with_media=True), [0, 1, 2]], + [dict(post_ids=[PostId.from_str("2234567890123456781")], with_media=False), [0]], ], ) def test_get_post(