Skip to content

Commit

Permalink
Merge pull request #110 from codeforjapan/issue-77-post-media
Browse files Browse the repository at this point in the history
Postにメディア情報を紐づける
  • Loading branch information
yu23ki14 authored Oct 8, 2024
2 parents 3758cb1 + faa7668 commit 22da0ca
Show file tree
Hide file tree
Showing 7 changed files with 261 additions and 41 deletions.
5 changes: 2 additions & 3 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def get_posts(
limit: int = Query(default=100, gt=0, le=1000),
search_text: Union[None, str] = Query(default=None),
search_url: Union[None, HttpUrl] = Query(default=None),
media: bool = Query(default=True),
) -> PostListResponse:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
Expand All @@ -119,6 +120,7 @@ def get_posts(
search_url=search_url,
offset=offset,
limit=limit,
with_media=media,
)
)
total_count = storage.get_number_of_posts(
Expand All @@ -130,9 +132,6 @@ def get_posts(
search_url=search_url,
)

for post in posts:
post.link = HttpUrl(f"https://x.com/{post.x_user.name}/status/{post.post_id}")

base_url = str(request.url).split("?")[0]
next_offset = offset + limit
prev_offset = max(offset - limit, 0)
Expand Down
53 changes: 46 additions & 7 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
LanguageIdentifier,
Link,
LinkId,
Media,
Note,
NoteId,
ParticipantId,
Expand Down Expand Up @@ -64,6 +65,11 @@ class XUserFactory(ModelFactory[XUser]):
__model__ = XUser


@register_fixture(name="media_factory")
class MediaFactory(ModelFactory[Media]):
__model__ = Media


@register_fixture(name="post_factory")
class PostFactory(ModelFactory[Post]):
__model__ = Post
Expand Down Expand Up @@ -183,9 +189,36 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None,
yield x_users


@fixture
def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, None]:
yield [
media_factory.build(
media_key="1234567890123456781",
url="https://pbs.twimg.com/media/xxxxxxxxxxxxxxx.jpg",
type="photo",
width=100,
height=100,
),
media_factory.build(
media_key="1234567890123456782",
url="https://pbs.twimg.com/media/yyyyyyyyyyyyyyy.mp4",
type="video",
width=200,
height=200,
),
media_factory.build(
media_key="1234567890123456783",
url="https://pbs.twimg.com/media/zzzzzzzzzzzzzzz.gif",
type="animated_gif",
width=300,
height=300,
),
]


@fixture
def post_samples(
post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link]
post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media], link_samples: List[Link]
) -> Generator[List[Post], None, None]:
posts = [
post_factory.build(
Expand All @@ -197,7 +230,7 @@ def post_samples(
新しいプロジェクトがついに公開されました!詳細はこちら👉
https://t.co/xxxxxxxxxxx/ #プロジェクト #新発売 #Tech""",
media_details=None,
media_details=[],
created_at=1152921600000,
like_count=10,
repost_count=20,
Expand All @@ -213,7 +246,7 @@ def post_samples(
このブログ記事、めちゃくちゃ参考になった!🔥 チェックしてみて!
https://t.co/yyyyyyyyyyy/ #学び #自己啓発""",
media_details=None,
media_details=[media_samples[0]],
created_at=1153921700000,
like_count=10,
repost_count=20,
Expand All @@ -227,7 +260,7 @@ def post_samples(
x_user=x_user_samples[1],
text="""\
次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""",
media_details=None,
media_details=[],
created_at=1154921800000,
like_count=10,
repost_count=20,
Expand All @@ -240,7 +273,7 @@ def post_samples(
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/",
media_details=None,
media_details=[],
created_at=1154922900000,
like_count=10,
repost_count=20,
Expand All @@ -253,7 +286,7 @@ def post_samples(
x_user_id="1234567890123456783",
x_user=x_user_samples[2],
text="empty",
media_details=None,
media_details=[],
created_at=1154923900000,
like_count=10,
repost_count=20,
Expand All @@ -268,6 +301,7 @@ def post_samples(
def mock_storage(
user_enrollment_samples: List[UserEnrollment],
topic_samples: List[Topic],
media_samples: List[Media],
post_samples: List[Post],
note_samples: List[Note],
link_samples: List[Link],
Expand Down Expand Up @@ -325,6 +359,7 @@ def _get_posts(
search_url: Union[HttpUrl, None] = None,
offset: Union[int, None] = None,
limit: Union[int, None] = None,
with_media: bool = True,
) -> Generator[Post, None, None]:
gen_count = 0
actual_gen_count = 0
Expand Down Expand Up @@ -354,7 +389,11 @@ def _get_posts(
if offset is not None and gen_count <= offset:
continue
actual_gen_count += 1
yield post

if with_media is False:
yield post.model_copy(update={"media_details": []}, deep=True)
else:
yield post

mock.get_posts.side_effect = _get_posts

Expand Down
34 changes: 34 additions & 0 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,40 @@ def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List
assert response.status_code == 422


def test_posts_get_with_media_by_default(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?postId=2234567890123456791")

assert response.status_code == 200
res_json_default = response.json()
assert res_json_default == {
"data": [json.loads(post_samples[1].model_dump_json())],
"meta": {"next": None, "prev": None},
}


def test_posts_get_with_media_true(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=true")

assert response.status_code == 200
res_json_default = response.json()
assert res_json_default == {
"data": [json.loads(post_samples[1].model_dump_json())],
"meta": {"next": None, "prev": None},
}


def test_posts_get_with_media_false(client: TestClient, post_samples: List[Post]) -> None:
expected_post = post_samples[1].model_copy(update={"media_details": []})
response = client.get("/api/v1/data/posts/?postId=2234567890123456791&media=false")

assert response.status_code == 200
res_json_default = response.json()
assert res_json_default == {
"data": [json.loads(expected_post.model_dump_json())],
"meta": {"next": None, "prev": None},
}


def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?searchText=https%3A%2F%2Ft.co%2Fxxxxxxxxxxx%2F")
assert response.status_code == 200
Expand Down
64 changes: 53 additions & 11 deletions common/birdxplorer_common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@
from datetime import datetime, timezone
from enum import Enum
from random import Random
from typing import Any, Dict, List, Literal, Optional, Type, TypeAlias, TypeVar, Union
from typing import (
Annotated,
Any,
Dict,
List,
Literal,
Optional,
Type,
TypeAlias,
TypeVar,
Union,
)
from uuid import UUID

from pydantic import BaseModel as PydanticBaseModel
from pydantic import (
ConfigDict,
GetCoreSchemaHandler,
HttpUrl,
TypeAdapter,
model_validator,
)
from pydantic import ConfigDict
from pydantic import Field as PydanticField
from pydantic import GetCoreSchemaHandler, HttpUrl, TypeAdapter, model_validator, computed_field
from pydantic.alias_generators import to_camel
from pydantic.main import IncEx
from pydantic_core import core_schema
Expand Down Expand Up @@ -683,7 +690,20 @@ class XUser(BaseModel):
following_count: NonNegativeInt


MediaDetails: TypeAlias = List[HttpUrl] | None
# ref: https://developer.x.com/en/docs/x-api/data-dictionary/object-model/media
MediaType: TypeAlias = Literal["photo", "video", "animated_gif"]


class Media(BaseModel):
media_key: str

type: MediaType
url: HttpUrl
width: NonNegativeInt
height: NonNegativeInt


MediaDetails: TypeAlias = List[Media]


class LinkId(UUID):
Expand Down Expand Up @@ -750,17 +770,39 @@ def validate_link_id(cls, values: Dict[str, Any]) -> Dict[str, Any]:

class Post(BaseModel):
post_id: PostId
link: Optional[HttpUrl] = None
x_user_id: UserId
x_user: XUser
text: str
media_details: MediaDetails = None
media_details: Annotated[MediaDetails, PydanticField(default_factory=lambda: [])]
created_at: TwitterTimestamp
like_count: NonNegativeInt
repost_count: NonNegativeInt
impression_count: NonNegativeInt
links: List[Link] = []

@property
@computed_field
def link(self) -> HttpUrl:
"""
PostのX上でのURLを返す。
Examples
--------
>>> post = Post(post_id="1234567890123456789",
x_user_id="1234567890123456789",
x_user=XUser(user_id="1234567890123456789",
name="test",
profile_image="https://x.com/test"),
text="test",
created_at=1288834974657,
like_count=1,
repost_count=1,
impression_count=1)
>>> post.link
HttpUrl('https://x.com/test/status/1234567890123456789')
"""
return HttpUrl(f"https://x.com/{self.x_user.name}/status/{self.post_id}")


class PaginationMeta(BaseModel):
next: Optional[HttpUrl] = None
Expand Down
54 changes: 49 additions & 5 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .models import BinaryBool, LanguageIdentifier
from .models import Link as LinkModel
from .models import LinkId, MediaDetails, NonNegativeInt
from .models import LinkId, Media, MediaDetails, MediaType, NonNegativeInt
from .models import Note as NoteModel
from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId
from .models import Post as PostModel
Expand Down Expand Up @@ -107,14 +107,37 @@ class PostLinkAssociation(Base):
link: Mapped[LinkRecord] = relationship()


class PostMediaAssociation(Base):
__tablename__ = "post_media"

post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True)
media_key: Mapped[str] = mapped_column(ForeignKey("media.media_key"), primary_key=True)

# このテーブルにアクセスした時点でほぼ間違いなく MediaRecord も必要なので一気に引っ張る
media: Mapped["MediaRecord"] = relationship(back_populates="post_media_association", lazy="joined")


class MediaRecord(Base):
__tablename__ = "media"

media_key: Mapped[str] = mapped_column(primary_key=True)

type: Mapped[MediaType] = mapped_column(nullable=False)
url: Mapped[HttpUrl] = mapped_column(nullable=False)
width: Mapped[NonNegativeInt] = mapped_column(nullable=False)
height: Mapped[NonNegativeInt] = mapped_column(nullable=False)

post_media_association: Mapped["PostMediaAssociation"] = relationship(back_populates="media")


class PostRecord(Base):
__tablename__ = "posts"

post_id: Mapped[PostId] = mapped_column(primary_key=True)
user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False)
user: Mapped[XUserRecord] = relationship()
text: Mapped[SummaryString] = mapped_column(nullable=False)
media_details: Mapped[MediaDetails] = mapped_column()
media_details: Mapped[List[PostMediaAssociation]] = relationship()
created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False)
like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
Expand Down Expand Up @@ -236,7 +259,27 @@ def engine(self) -> Engine:
return self._engine

@classmethod
def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
def _media_record_to_model(cls, media_record: MediaRecord) -> Media:
return Media(
media_key=media_record.media_key,
type=media_record.type,
url=media_record.url,
width=media_record.width,
height=media_record.height,
)

@classmethod
def _post_record_media_details_to_model(cls, post_record: PostRecord) -> MediaDetails:
if post_record.media_details == []:
return []
return [cls._media_record_to_model(post_media.media) for post_media in post_record.media_details]

@classmethod
def _post_record_to_model(cls, post_record: PostRecord, *, with_media: bool) -> PostModel:
# post_record.media_detailsにアクセスしたタイミングでメディア情報を一気に引っ張るクエリが発行される
# media情報がいらない場合はクエリを発行したくないので先にwith_mediaをチェック
media_details = cls._post_record_media_details_to_model(post_record) if with_media else []

return PostModel(
post_id=post_record.post_id,
x_user_id=post_record.user_id,
Expand All @@ -248,7 +291,7 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
following_count=post_record.user.following_count,
),
text=post_record.text,
media_details=post_record.media_details,
media_details=media_details,
created_at=post_record.created_at,
like_count=post_record.like_count,
repost_count=post_record.repost_count,
Expand Down Expand Up @@ -340,6 +383,7 @@ def get_posts(
search_url: Union[HttpUrl, None] = None,
offset: Union[int, None] = None,
limit: int = 100,
with_media: bool = True,
) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
query = sess.query(PostRecord)
Expand All @@ -365,7 +409,7 @@ def get_posts(
query = query.offset(offset)
query = query.limit(limit)
for post_record in query.all():
yield self._post_record_to_model(post_record)
yield self._post_record_to_model(post_record, with_media=with_media)

def get_number_of_posts(
self,
Expand Down
Loading

0 comments on commit 22da0ca

Please sign in to comment.