diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index e318459..c6b366e 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String -from .models import BinaryBool, LanguageIdentifier, MediaDetails, NonNegativeInt +from .models import BinaryBool, LanguageIdentifier, MediaDetails, NonNegativeInt, XMedia, XMediaType from .models import Note as NoteModel from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId from .models import Post as PostModel @@ -88,6 +88,29 @@ class XUserRecord(Base): following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) +class PostMediaAssociation(Base): + __tablename__ = "post_media" + + post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) + media_key: Mapped[str] = mapped_column(ForeignKey("x_medias.media_key"), primary_key=True) + + # このテーブルにアクセスした時点でほぼ間違いなく MediaRecord も必要なので一気に引っ張る + media: Mapped["MediaRecord"] = relationship(back_populates="post_media_association", lazy="joined") + + +class MediaRecord(Base): + __tablename__ = "x_medias" + + media_key: Mapped[str] = mapped_column(primary_key=True) + + type: Mapped[XMediaType] = mapped_column(nullable=False) + url: Mapped[HttpUrl] = mapped_column(nullable=False) + width: Mapped[NonNegativeInt] = mapped_column(nullable=False) + height: Mapped[NonNegativeInt] = mapped_column(nullable=False) + + post_media_association: Mapped["PostMediaAssociation"] = relationship(back_populates="media") + + class PostRecord(Base): __tablename__ = "posts" @@ -95,7 +118,7 @@ class PostRecord(Base): user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False) user: Mapped[XUserRecord] = relationship() text: Mapped[SummaryString] = mapped_column(nullable=False) - media_details: Mapped[MediaDetails] = mapped_column() + media_details: Mapped[List[PostMediaAssociation]] = relationship() created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False) like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) @@ -179,7 +202,29 @@ def engine(self) -> Engine: return self._engine @classmethod - def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: + def _media_record_to_model(cls, media_record: MediaRecord) -> XMedia: + return XMedia( + media_key=media_record.media_key, + type=media_record.type, + url=media_record.url, + width=media_record.width, + height=media_record.height, + ) + + @classmethod + def _post_record_media_details_to_model(cls, post_record: PostRecord) -> MediaDetails: + if post_record.media_details is None: + return None + if post_record.media_details == []: + return [] + return [cls._media_record_to_model(post_media.media) for post_media in post_record.media_details] + + @classmethod + def _post_record_to_model(cls, post_record: PostRecord, *, with_media: bool) -> PostModel: + # post_record.media_detailsにアクセスしたタイミングでメディア情報を一気に引っ張るクエリが発行される + # media情報がいらない場合はクエリを発行したくないので先にwith_mediaをチェック + media_details = cls._post_record_media_details_to_model(post_record) if with_media else None + return PostModel( post_id=post_record.post_id, x_user_id=post_record.user_id, @@ -191,7 +236,7 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: following_count=post_record.user.following_count, ), text=post_record.text, - media_details=post_record.media_details, + media_details=media_details, created_at=post_record.created_at, like_count=post_record.like_count, repost_count=post_record.repost_count, @@ -277,6 +322,7 @@ def get_posts( search_text: Union[str, None] = None, offset: Union[int, None] = None, limit: int = 100, + with_media: bool = True, ) -> Generator[PostModel, None, None]: with Session(self.engine) as sess: query = sess.query(PostRecord) @@ -296,7 +342,7 @@ def get_posts( query = query.offset(offset) query = query.limit(limit) for post_record in query.all(): - yield self._post_record_to_model(post_record) + yield self._post_record_to_model(post_record, with_media=with_media) def get_number_of_posts( self,