diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 09c5f2e..e51e13d 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -104,6 +104,7 @@ def get_posts( limit: int = Query(default=100, gt=0, le=1000), search_text: Union[None, str] = Query(default=None), search_url: Union[None, HttpUrl] = Query(default=None), + media: bool = Query(default=True), ) -> PostListResponse: if created_at_from is not None and isinstance(created_at_from, str): created_at_from = ensure_twitter_timestamp(created_at_from) @@ -119,6 +120,7 @@ def get_posts( search_url=search_url, offset=offset, limit=limit, + with_media=media, ) ) total_count = storage.get_number_of_posts( @@ -130,9 +132,6 @@ def get_posts( search_url=search_url, ) - for post in posts: - post.link = HttpUrl(f"https://x.com/{post.x_user.name}/status/{post.post_id}") - base_url = str(request.url).split("?")[0] next_offset = offset + limit prev_offset = max(offset - limit, 0) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 141c9f2..affa005 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -17,6 +17,7 @@ LanguageIdentifier, Link, LinkId, + Media, Note, NoteId, ParticipantId, @@ -64,6 +65,11 @@ class XUserFactory(ModelFactory[XUser]): __model__ = XUser +@register_fixture(name="media_factory") +class MediaFactory(ModelFactory[Media]): + __model__ = Media + + @register_fixture(name="post_factory") class PostFactory(ModelFactory[Post]): __model__ = Post @@ -183,9 +189,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, yield x_users +@fixture +def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]: + yield [ + media_factory.build( + media_key="1234567890123456781", + url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg", + type="photo", + width=100, + height=100, + ), + media_factory.build( + media_key="1234567890123456782", + url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4", + type="video", + width=200, + height=200, + ), + media_factory.build( + media_key="1234567890123456783", + url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif", + type="animated_gif", + width=300, + height=300, + ), + ] + + @fixture def post_samples( - post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link] + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media], link_samples: List[Link] ) -> Generator[List[Post], None, None]: posts = [ post_factory.build( @@ -197,7 +230,7 @@ def post_samples( 新しいプロジェクトがついに公開されました!詳細はこちら👉 https://t.co/xxxxxxxxxxx/ #プロジェクト #新発売 #Tech""", - media_details=None, + media_details=[], created_at=1152921600000, like_count=10, repost_count=20, @@ -213,7 +246,7 @@ def post_samples( このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて! https://t.co/yyyyyyyyyyy/ #学び #自己啓発""", - media_details=None, + media_details=[media_samples[0]], created_at=1153921700000, like_count=10, repost_count=20, @@ -227,7 +260,7 @@ def post_samples( x_user=x_user_samples[1], text="""\ 次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""", - media_details=None, + media_details=[], created_at=1154921800000, like_count=10, repost_count=20, @@ -240,7 +273,7 @@ def post_samples( x_user_id="1234567890123456782", x_user=x_user_samples[1], text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/", - media_details=None, + media_details=[], created_at=1154922900000, like_count=10, repost_count=20, @@ -253,7 +286,7 @@ def post_samples( x_user_id="1234567890123456783", x_user=x_user_samples[2], text="empty", - media_details=None, + media_details=[], created_at=1154923900000, like_count=10, repost_count=20, @@ -268,6 +301,7 @@ def post_samples( def mock_storage( user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic], + media_samples: List[Media], post_samples: List[Post], note_samples: List[Note], link_samples: List[Link], @@ -325,6 +359,7 @@ def _get_posts( search_url: Union[HttpUrl, None] = None, offset: Union[int, None] = None, limit: Union[int, None] = None, + with_media: bool = True, ) -> Generator[Post, None, None]: gen_count = 0 actual_gen_count = 0 @@ -354,7 +389,11 @@ def _get_posts( if offset is not None and gen_count <= offset: continue actual_gen_count += 1 - yield post + + if with_media is False: + yield post.model_copy(update={"media_details": []}, deep=True) + else: + yield post mock.get_posts.side_effect = _get_posts diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index e35f6e4..e971098 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -122,6 +122,40 @@ def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List assert response.status_code == 422 +def test_posts_get_with_media_by_default(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?postId=2234567890123456791") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(post_samples[1].model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + +def test_posts_get_with_media_true(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=true") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(post_samples[1].model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + +def test_posts_get_with_media_false(client: TestClient, post_samples: List[Post]) -> None: + expected_post = post_samples[1].model_copy(update={"media_details": []}) + response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=false") + + assert response.status_code == 200 + res_json_default = response.json() + assert res_json_default == { + "data": [json.loads(expected_post.model_dump_json())], + "meta": {"next": None, "prev": None}, + } + + def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?searchText=https%3A%2F%2Ft.co%2Fxxxxxxxxxxx%2F") assert response.status_code == 200 diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index ec33593..25b1e4b 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -2,17 +2,24 @@ from datetime import datetime, timezone from enum import Enum from random import Random -from typing import Any, Dict, List, Literal, Optional, Type, TypeAlias, TypeVar, Union +from typing import ( + Annotated, + Any, + Dict, + List, + Literal, + Optional, + Type, + TypeAlias, + TypeVar, + Union, +) from uuid import UUID from pydantic import BaseModel as PydanticBaseModel -from pydantic import ( - ConfigDict, - GetCoreSchemaHandler, - HttpUrl, - TypeAdapter, - model_validator, -) +from pydantic import ConfigDict +from pydantic import Field as PydanticField +from pydantic import GetCoreSchemaHandler, HttpUrl, TypeAdapter, model_validator, computed_field from pydantic.alias_generators import to_camel from pydantic.main import IncEx from pydantic_core import core_schema @@ -683,7 +690,20 @@ class XUser(BaseModel): following_count: NonNegativeInt -MediaDetails: TypeAlias = List[HttpUrl] | None +# ref: https://developer.x.com/en/docs/x-api/data-dictionary/object-model/media +MediaType: TypeAlias = Literal["photo", "video", "animated_gif"] + + +class Media(BaseModel): + media_key: str + + type: MediaType + url: HttpUrl + width: NonNegativeInt + height: NonNegativeInt + + +MediaDetails: TypeAlias = List[Media] class LinkId(UUID): @@ -750,17 +770,39 @@ def validate_link_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: class Post(BaseModel): post_id: PostId - link: Optional[HttpUrl] = None x_user_id: UserId x_user: XUser text: str - media_details: MediaDetails = None + media_details: Annotated[MediaDetails, PydanticField(default_factory=lambda: [])] created_at: TwitterTimestamp like_count: NonNegativeInt repost_count: NonNegativeInt impression_count: NonNegativeInt links: List[Link] = [] + @property + @computed_field + def link(self) -> HttpUrl: + """ + PostのX上でのURLを返す。 + + Examples + -------- + >>> post = Post(post_id="1234567890123456789", + x_user_id="1234567890123456789", + x_user=XUser(user_id="1234567890123456789", + name="test", + profile_image="https://x.com/test"), + text="test", + created_at=1288834974657, + like_count=1, + repost_count=1, + impression_count=1) + >>> post.link + HttpUrl('https://x.com/test/status/1234567890123456789') + """ + return HttpUrl(f"https://x.com/{self.x_user.name}/status/{self.post_id}") + class PaginationMeta(BaseModel): next: Optional[HttpUrl] = None diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 2b7868b..6eee4af 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -9,7 +9,7 @@ from .models import BinaryBool, LanguageIdentifier from .models import Link as LinkModel -from .models import LinkId, MediaDetails, NonNegativeInt +from .models import LinkId, Media, MediaDetails, MediaType, NonNegativeInt from .models import Note as NoteModel from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId from .models import Post as PostModel @@ -107,6 +107,29 @@ class PostLinkAssociation(Base): link: Mapped[LinkRecord] = relationship() +class PostMediaAssociation(Base): + __tablename__ = "post_media" + + post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) + media_key: Mapped[str] = mapped_column(ForeignKey("media.media_key"), primary_key=True) + + # このテーブルにアクセスした時点でほぼ間違いなく MediaRecord も必要なので一気に引っ張る + media: Mapped["MediaRecord"] = relationship(back_populates="post_media_association", lazy="joined") + + +class MediaRecord(Base): + __tablename__ = "media" + + media_key: Mapped[str] = mapped_column(primary_key=True) + + type: Mapped[MediaType] = mapped_column(nullable=False) + url: Mapped[HttpUrl] = mapped_column(nullable=False) + width: Mapped[NonNegativeInt] = mapped_column(nullable=False) + height: Mapped[NonNegativeInt] = mapped_column(nullable=False) + + post_media_association: Mapped["PostMediaAssociation"] = relationship(back_populates="media") + + class PostRecord(Base): __tablename__ = "posts" @@ -114,7 +137,7 @@ class PostRecord(Base): user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False) user: Mapped[XUserRecord] = relationship() text: Mapped[SummaryString] = mapped_column(nullable=False) - media_details: Mapped[MediaDetails] = mapped_column() + media_details: Mapped[List[PostMediaAssociation]] = relationship() created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False) like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) @@ -236,7 +259,27 @@ def engine(self) -> Engine: return self._engine @classmethod - def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: + def _media_record_to_model(cls, media_record: MediaRecord) -> Media: + return Media( + media_key=media_record.media_key, + type=media_record.type, + url=media_record.url, + width=media_record.width, + height=media_record.height, + ) + + @classmethod + def _post_record_media_details_to_model(cls, post_record: PostRecord) -> MediaDetails: + if post_record.media_details == []: + return [] + return [cls._media_record_to_model(post_media.media) for post_media in post_record.media_details] + + @classmethod + def _post_record_to_model(cls, post_record: PostRecord, *, with_media: bool) -> PostModel: + # post_record.media_detailsにアクセスしたタイミングでメディア情報を一気に引っ張るクエリが発行される + # media情報がいらない場合はクエリを発行したくないので先にwith_mediaをチェック + media_details = cls._post_record_media_details_to_model(post_record) if with_media else [] + return PostModel( post_id=post_record.post_id, x_user_id=post_record.user_id, @@ -248,7 +291,7 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: following_count=post_record.user.following_count, ), text=post_record.text, - media_details=post_record.media_details, + media_details=media_details, created_at=post_record.created_at, like_count=post_record.like_count, repost_count=post_record.repost_count, @@ -340,6 +383,7 @@ def get_posts( search_url: Union[HttpUrl, None] = None, offset: Union[int, None] = None, limit: int = 100, + with_media: bool = True, ) -> Generator[PostModel, None, None]: with Session(self.engine) as sess: query = sess.query(PostRecord) @@ -365,7 +409,7 @@ def get_posts( query = query.offset(offset) query = query.limit(limit) for post_record in query.all(): - yield self._post_record_to_model(post_record) + yield self._post_record_to_model(post_record, with_media=with_media) def get_number_of_posts( self, diff --git a/common/tests/conftest.py b/common/tests/conftest.py index ab3707b..c972acd 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -16,6 +16,7 @@ from birdxplorer_common.models import ( Link, + Media, Note, Post, Topic, @@ -27,9 +28,11 @@ from birdxplorer_common.storage import ( Base, LinkRecord, + MediaRecord, NoteRecord, NoteTopicAssociation, PostLinkAssociation, + PostMediaAssociation, PostRecord, TopicRecord, XUserRecord, @@ -97,6 +100,11 @@ class XUserFactory(ModelFactory[XUser]): __model__ = XUser +@register_fixture(name="media_factory") +class MediaFactory(ModelFactory[Media]): + __model__ = Media + + @register_fixture(name="post_factory") class PostFactory(ModelFactory[Post]): __model__ = Post @@ -225,21 +233,47 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, yield x_users +@fixture +def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]: + yield [ + media_factory.build( + media_key="1234567890123456781", + url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg", + type="photo", + width=100, + height=100, + ), + media_factory.build( + media_key="1234567890123456782", + url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4", + type="video", + width=200, + height=200, + ), + media_factory.build( + media_key="1234567890123456783", + url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif", + type="animated_gif", + width=300, + height=300, + ), + ] + + @fixture def post_samples( - post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link] + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media], link_samples: List[Link] ) -> Generator[List[Post], None, None]: posts = [ post_factory.build( post_id="2234567890123456781", - link=None, x_user_id="1234567890123456781", x_user=x_user_samples[0], text="""\ 新しいプロジェクトがついに公開されました!詳細はこちら👉 https://t.co/xxxxxxxxxxx/ #プロジェクト #新発売 #Tech""", - media_details=None, + media_details=[], created_at=1152921600000, like_count=10, repost_count=20, @@ -248,14 +282,13 @@ def post_samples( ), post_factory.build( post_id="2234567890123456791", - link=None, x_user_id="1234567890123456781", x_user=x_user_samples[0], text="""\ このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて! https://t.co/yyyyyyyyyyy/ #学び #自己啓発""", - media_details=None, + media_details=[media_samples[0]], created_at=1153921700000, like_count=10, repost_count=20, @@ -264,12 +297,11 @@ def post_samples( ), post_factory.build( post_id="2234567890123456801", - link=None, x_user_id="1234567890123456782", x_user=x_user_samples[1], text="""\ -次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""", - media_details=None, +次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""", + media_details=[media_samples[1], media_samples[2]], created_at=1154921800000, like_count=10, repost_count=20, @@ -282,7 +314,7 @@ def post_samples( x_user_id="1234567890123456782", x_user=x_user_samples[1], text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/", - media_details=None, + media_details=[], created_at=1154922900000, like_count=10, repost_count=20, @@ -295,7 +327,7 @@ def post_samples( x_user_id="1234567890123456783", x_user=x_user_samples[2], text="empty", - media_details=None, + media_details=[], created_at=1154923900000, like_count=10, repost_count=20, @@ -419,8 +451,8 @@ def x_user_records_sample( @fixture def link_records_sample( - link_samples: List[Link], engine_for_test: Engine, + link_samples: List[Link], ) -> Generator[List[LinkRecord], None, None]: res = [LinkRecord(link_id=d.link_id, url=d.url) for d in link_samples] with Session(engine_for_test) as sess: @@ -429,9 +461,31 @@ def link_records_sample( yield res +@fixture +def media_records_sample( + engine_for_test: Engine, + media_samples: List[Media], +) -> Generator[List[MediaRecord], None, None]: + res = [ + MediaRecord( + media_key=d.media_key, + type=d.type, + url=d.url, + width=d.width, + height=d.height, + ) + for d in media_samples + ] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + @fixture def post_records_sample( x_user_records_sample: List[XUserRecord], + media_records_sample: List[MediaRecord], link_records_sample: List[LinkRecord], post_samples: List[Post], engine_for_test: Engine, @@ -443,17 +497,23 @@ def post_records_sample( post_id=post.post_id, user_id=post.x_user_id, text=post.text, - media_details=post.media_details, created_at=post.created_at, like_count=post.like_count, repost_count=post.repost_count, impression_count=post.impression_count, ) sess.add(inst) + for link in post.links: - assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) - sess.add(assoc) - inst.links.append(assoc) + post_link_assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) + sess.add(post_link_assoc) + inst.links.append(post_link_assoc) + + for media in post.media_details: + post_media_assoc = PostMediaAssociation(media_key=media.media_key, post_id=inst.post_id) + sess.add(post_media_assoc) + inst.media_details.append(post_media_assoc) + res.append(inst) sess.commit() yield res diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 079056a..ce854b1 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -45,6 +45,8 @@ def test_get_topic_list( [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], + [dict(with_media=True), [0, 1, 2, 3, 4]], + [dict(post_ids=[PostId.from_str("2234567890123456781")], with_media=False), [0]], ], ) def test_get_post(