diff --git a/common/tests/conftest.py b/common/tests/conftest.py index 4705d75..13990e5 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -15,6 +15,7 @@ from sqlalchemy.sql import text from birdxplorer_common.models import ( + Media, Note, Post, Topic, @@ -25,8 +26,10 @@ from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings from birdxplorer_common.storage import ( Base, + MediaRecord, NoteRecord, NoteTopicAssociation, + PostMediaAssociation, PostRecord, TopicRecord, XUserRecord, @@ -94,6 +97,11 @@ class XUserFactory(ModelFactory[XUser]): __model__ = XUser +@register_fixture(name="media_factory") +class MediaFactory(ModelFactory[Media]): + __model__ = Media + + @register_fixture(name="post_factory") class PostFactory(ModelFactory[Post]): __model__ = Post @@ -201,7 +209,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, @fixture -def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]: +def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]: + yield [ + media_factory.build( + media_key="1234567890123456781", + url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg", + type="photo", + width=100, + height=100, + ), + media_factory.build( + media_key="1234567890123456782", + url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4", + type="video", + width=200, + height=200, + ), + media_factory.build( + media_key="1234567890123456783", + url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif", + type="animated_gif", + width=300, + height=300, + ), + ] + + +@fixture +def post_samples( + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media] +) -> Generator[List[Post], None, None]: posts = [ post_factory.build( post_id="2234567890123456781", @@ -227,7 +264,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて! https://t.co/yyyyyyyyyyy/ #学び #自己啓発""", - media_details=[], + media_details=[media_samples[0]], created_at=1153921700000, like_count=10, repost_count=20, @@ -240,7 +277,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene x_user=x_user_samples[1], text="""\ 次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""", - media_details=[], + media_details=[media_samples[1], media_samples[2]], created_at=1154921800000, like_count=10, repost_count=20, @@ -362,25 +399,53 @@ def x_user_records_sample( @fixture -def post_records_sample( - x_user_records_sample: List[XUserRecord], - post_samples: List[Post], +def media_records_sample( engine_for_test: Engine, -) -> Generator[List[PostRecord], None, None]: + media_samples: List[Media], +) -> Generator[List[MediaRecord], None, None]: res = [ - PostRecord( - post_id=d.post_id, - user_id=d.x_user_id, - text=d.text, - media_details=d.media_details, - created_at=d.created_at, - like_count=d.like_count, - repost_count=d.repost_count, - impression_count=d.impression_count, + MediaRecord( + media_key=d.media_key, + type=d.type, + url=d.url, + width=d.width, + height=d.height, ) - for d in post_samples + for d in media_samples ] with Session(engine_for_test) as sess: sess.add_all(res) sess.commit() yield res + + +@fixture +def post_records_sample( + x_user_records_sample: List[XUserRecord], + post_samples: List[Post], + media_records_sample: List[MediaRecord], + engine_for_test: Engine, +) -> Generator[List[PostRecord], None, None]: + res: List[PostRecord] = [] + with Session(engine_for_test) as sess: + for post in post_samples: + inst = PostRecord( + post_id=post.post_id, + user_id=post.x_user_id, + text=post.text, + created_at=post.created_at, + like_count=post.like_count, + repost_count=post.repost_count, + impression_count=post.impression_count, + ) + sess.add(inst) + for media in post.media_details: + assoc = PostMediaAssociation( + post_id=inst.post_id, + media_key=media.media_key, + ) + sess.add(assoc) + inst.media_details.append(assoc) + res.append(inst) + sess.commit() + yield res diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index cc3638d..f6ba041 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -43,6 +43,8 @@ def test_get_topic_list( [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], + [dict(with_media=True), [0, 1, 2]], + [dict(post_ids=[PostId.from_str("2234567890123456781")], with_media=False), [0]], ], ) def test_get_post(