From 0a9f6987977a5aba35ab91d8a6049a411c2efefc Mon Sep 17 00:00:00 2001 From: osoken Date: Tue, 17 Sep 2024 13:45:25 +0900 Subject: [PATCH 1/8] fix: fix type error --- common/birdxplorer_common/models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index 72052dd..0e89a52 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -6,9 +6,9 @@ from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter from pydantic.alias_generators import to_camel +from pydantic.main import IncEx from pydantic_core import core_schema -IncEx: TypeAlias = "set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None" StrT = TypeVar("StrT", bound="BaseString") IntT = TypeVar("IntT", bound="BaseInt") FloatT = TypeVar("FloatT", bound="BaseFloat") @@ -467,8 +467,8 @@ def model_dump_json( self, *, indent: int | None = None, - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, context: Dict[str, Any] | None = None, by_alias: bool = True, exclude_unset: bool = False, From 39288b3e727e5451d5340228c7dfd9c2b0a12ad3 Mon Sep 17 00:00:00 2001 From: osoken Date: Tue, 17 Sep 2024 22:30:45 +0900 Subject: [PATCH 2/8] chore: add ulid --- common/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/common/pyproject.toml b/common/pyproject.toml index f29efbc..69b2d92 100644 --- a/common/pyproject.toml +++ b/common/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "sqlalchemy", "pydantic_settings", "JSON-log-formatter", + "ulid-py", ] [project.urls] From 2c8db37fc0ddfc8bd464c3e71807bbe1ba72532a Mon Sep 17 00:00:00 2001 From: osoken Date: Fri, 27 Sep 2024 22:50:05 +0900 Subject: [PATCH 3/8] feat: add Link model --- common/birdxplorer_common/models.py | 21 ++++++++++++++++ common/tests/conftest.py | 39 +++++++++++++++++++++++++++-- common/tests/test_storage.py | 8 +++--- 3 files changed, 62 insertions(+), 6 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index 0e89a52..bbe0566 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -677,6 +677,26 @@ class XUser(BaseModel): MediaDetails: TypeAlias = List[HttpUrl] | None +class LinkId(NonNegativeInt): + """ + >>> LinkId.from_int(1) + LinkId(1) + """ + + pass + + +class Link(BaseModel): + """ + >>> Link.model_validate_json('{"linkId": 1, "canonicalUrl": "https://example.com", "shortUrl": "https://example.com/short"}') + Link(link_id=LinkId(1), canonical_url=Url('https://example.com/'), short_url=Url('https://example.com/short')) + """ # noqa: E501 + + link_id: LinkId + canonical_url: HttpUrl + short_url: HttpUrl + + class Post(BaseModel): post_id: PostId link: Optional[HttpUrl] = None @@ -688,6 +708,7 @@ class Post(BaseModel): like_count: NonNegativeInt repost_count: NonNegativeInt impression_count: NonNegativeInt + links: List[Link] = [] class PaginationMeta(BaseModel): diff --git a/common/tests/conftest.py b/common/tests/conftest.py index a8c048b..6fe2285 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -15,6 +15,7 @@ from sqlalchemy.sql import text from birdxplorer_common.models import ( + Link, Note, Post, Topic, @@ -99,6 +100,11 @@ class PostFactory(ModelFactory[Post]): __model__ = Post +@register_fixture(name="link_factory") +class LinkFactory(ModelFactory[Link]): + __model__ = Link + + @fixture def user_enrollment_samples( user_enrollment_factory: UserEnrollmentFactory, @@ -117,6 +123,17 @@ def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, N yield topics +@fixture +def link_samples(link_factory: LinkFactory) -> Generator[List[Link], None, None]: + links = [ + link_factory.build(link_id=0, canonical_url="https://t.co/xxxxxxxxxxx/", short_url="https://example.com/sh0"), + link_factory.build(link_id=1, canonical_url="https://t.co/yyyyyyyyyyy/", short_url="https://example.com/sh1"), + link_factory.build(link_id=2, canonical_url="https://t.co/zzzzzzzzzzz/", short_url="https://example.com/sh2"), + link_factory.build(link_id=3, canonical_url="https://t.co/wwwwwwwwwww/", short_url="https://example.com/sh3"), + ] + yield links + + @fixture def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]: notes = [ @@ -201,7 +218,9 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, @fixture -def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]: +def post_samples( + post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link] +) -> Generator[List[Post], None, None]: posts = [ post_factory.build( post_id="2234567890123456781", @@ -217,6 +236,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene like_count=10, repost_count=20, impression_count=30, + links=[link_samples[0]], ), post_factory.build( post_id="2234567890123456791", @@ -232,6 +252,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene like_count=10, repost_count=20, impression_count=30, + links=[link_samples[1]], ), post_factory.build( post_id="2234567890123456801", @@ -239,12 +260,26 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene x_user_id="1234567890123456782", x_user=x_user_samples[1], text="""\ -次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""", +次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""", media_details=None, created_at=1154921800000, like_count=10, repost_count=20, impression_count=30, + links=[link_samples[0], link_samples[3]], + ), + post_factory.build( + post_id="2234567890123456811", + link=None, + x_user_id="1234567890123456782", + x_user=x_user_samples[1], + text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/", + media_details=None, + created_at=1154922900000, + like_count=10, + repost_count=20, + impression_count=30, + links=[link_samples[2], link_samples[3]], ), ] yield posts diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index cc3638d..7943b9a 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -31,8 +31,8 @@ def test_get_topic_list( @pytest.mark.parametrize( ["filter_args", "expected_indices"], [ - [dict(), [0, 1, 2]], - [dict(offset=1), [1, 2]], + [dict(), [0, 1, 2, 3]], + [dict(offset=1), [1, 2, 3]], [dict(limit=1), [0]], [dict(offset=1, limit=1), [1]], [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], @@ -63,11 +63,11 @@ def test_get_post( @pytest.mark.parametrize( ["filter_args", "expected_indices"], [ - [dict(), [0, 1, 2]], + [dict(), [0, 1, 2, 3]], [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], [dict(post_ids=[]), []], [dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]], - [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]], + [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], From 960f7b555ac7b761e1c9a86a94b058a089a68ea6 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 28 Sep 2024 23:55:03 +0900 Subject: [PATCH 4/8] test: add more test samples --- common/tests/conftest.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/common/tests/conftest.py b/common/tests/conftest.py index 6fe2285..1acc7aa 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -281,6 +281,19 @@ def post_samples( impression_count=30, links=[link_samples[2], link_samples[3]], ), + post_factory.build( + post_id="2234567890123456821", + link=None, + x_user_id="1234567890123456783", + x_user=x_user_samples[2], + text="empty", + media_details=None, + created_at=1154923900000, + like_count=10, + repost_count=20, + impression_count=30, + links=[], + ), ] yield posts From 6876d74a5ce0d327f0a3b2b4cad014fdb7cdca6b Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 28 Sep 2024 23:59:58 +0900 Subject: [PATCH 5/8] feat: add some tables --- common/birdxplorer_common/storage.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index e318459..c347976 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String -from .models import BinaryBool, LanguageIdentifier, MediaDetails, NonNegativeInt +from .models import BinaryBool, LanguageIdentifier, LinkId, MediaDetails, NonNegativeInt from .models import Note as NoteModel from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId from .models import Post as PostModel @@ -34,6 +34,7 @@ def adapt_pydantic_http_url(url: AnyUrl) -> AsIs: class Base(DeclarativeBase): type_annotation_map = { + LinkId: Integer, TopicId: Integer, TopicLabel: JSON, NoteId: String, @@ -88,6 +89,14 @@ class XUserRecord(Base): following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) +class LinkRecord(Base): + __tablename__ = "links" + + link_id: Mapped[LinkId] = mapped_column(primary_key=True) + canonical_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) + short_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) + + class PostRecord(Base): __tablename__ = "posts" @@ -102,6 +111,14 @@ class PostRecord(Base): impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) +class PostLinkAssociation(Base): + __tablename__ = "post_link" + + post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) + link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True) + link: Mapped["LinkRecord"] = relationship() + + class RowNoteRecord(Base): __tablename__ = "row_notes" From 3fff12c9b2a6d662ca654a676d6847ed870f0c56 Mon Sep 17 00:00:00 2001 From: osoken Date: Sun, 29 Sep 2024 00:22:39 +0900 Subject: [PATCH 6/8] feat: add Link table model --- common/birdxplorer_common/storage.py | 25 +++++++++------ common/tests/conftest.py | 47 +++++++++++++++++++--------- common/tests/test_storage.py | 10 +++--- 3 files changed, 54 insertions(+), 28 deletions(-) diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index c347976..1f6a7ac 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -7,7 +7,9 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String -from .models import BinaryBool, LanguageIdentifier, LinkId, MediaDetails, NonNegativeInt +from .models import BinaryBool, LanguageIdentifier +from .models import Link as LinkModel +from .models import LinkId, MediaDetails, NonNegativeInt from .models import Note as NoteModel from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId from .models import Post as PostModel @@ -97,6 +99,14 @@ class LinkRecord(Base): short_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) +class PostLinkAssociation(Base): + __tablename__ = "post_link" + + post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) + link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True) + link: Mapped[LinkRecord] = relationship() + + class PostRecord(Base): __tablename__ = "posts" @@ -109,14 +119,7 @@ class PostRecord(Base): like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) - - -class PostLinkAssociation(Base): - __tablename__ = "post_link" - - post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) - link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True) - link: Mapped["LinkRecord"] = relationship() + links: Mapped[List[PostLinkAssociation]] = relationship() class RowNoteRecord(Base): @@ -213,6 +216,10 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: like_count=post_record.like_count, repost_count=post_record.repost_count, impression_count=post_record.impression_count, + links=[ + LinkModel(link_id=link.link_id, canonical_url=link.link.canonical_url, short_url=link.link.short_url) + for link in post_record.links + ], ) def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment: diff --git a/common/tests/conftest.py b/common/tests/conftest.py index 1acc7aa..8d9ffd4 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -26,8 +26,10 @@ from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings from birdxplorer_common.storage import ( Base, + LinkRecord, NoteRecord, NoteTopicAssociation, + PostLinkAssociation, PostRecord, TopicRecord, XUserRecord, @@ -409,26 +411,43 @@ def x_user_records_sample( yield res +@fixture +def link_records_sample( + link_samples: List[Link], + engine_for_test: Engine, +) -> Generator[List[LinkRecord], None, None]: + res = [LinkRecord(link_id=d.link_id, canonical_url=d.canonical_url, short_url=d.short_url) for d in link_samples] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + @fixture def post_records_sample( x_user_records_sample: List[XUserRecord], + link_records_sample: List[LinkRecord], post_samples: List[Post], engine_for_test: Engine, ) -> Generator[List[PostRecord], None, None]: - res = [ - PostRecord( - post_id=d.post_id, - user_id=d.x_user_id, - text=d.text, - media_details=d.media_details, - created_at=d.created_at, - like_count=d.like_count, - repost_count=d.repost_count, - impression_count=d.impression_count, - ) - for d in post_samples - ] + res: List[PostRecord] = [] with Session(engine_for_test) as sess: - sess.add_all(res) + for post in post_samples: + inst = PostRecord( + post_id=post.post_id, + user_id=post.x_user_id, + text=post.text, + media_details=post.media_details, + created_at=post.created_at, + like_count=post.like_count, + repost_count=post.repost_count, + impression_count=post.impression_count, + ) + sess.add(inst) + for link in post.links: + assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) + sess.add(assoc) + inst.links.append(assoc) + res.append(inst) sess.commit() yield res diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 7943b9a..64a2141 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -31,14 +31,14 @@ def test_get_topic_list( @pytest.mark.parametrize( ["filter_args", "expected_indices"], [ - [dict(), [0, 1, 2, 3]], - [dict(offset=1), [1, 2, 3]], + [dict(), [0, 1, 2, 3, 4]], + [dict(offset=1), [1, 2, 3, 4]], [dict(limit=1), [0]], [dict(offset=1, limit=1), [1]], [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], [dict(post_ids=[]), []], [dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]], - [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]], + [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], @@ -63,11 +63,11 @@ def test_get_post( @pytest.mark.parametrize( ["filter_args", "expected_indices"], [ - [dict(), [0, 1, 2, 3]], + [dict(), [0, 1, 2, 3, 4]], [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], [dict(post_ids=[]), []], [dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]], - [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3]], + [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], From 115fb1e883167d43f91c03f036c8cd92d08c7cd5 Mon Sep 17 00:00:00 2001 From: osoken Date: Sun, 6 Oct 2024 14:52:45 +0900 Subject: [PATCH 7/8] feat: Change LinkId field from int to UUID and generate LinkId deterministically from URL --- common/birdxplorer_common/models.py | 68 ++++++++++++++++++++++++---- common/birdxplorer_common/storage.py | 12 ++--- common/tests/conftest.py | 10 ++-- 3 files changed, 68 insertions(+), 22 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index bbe0566..d0b57b1 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -1,10 +1,18 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from enum import Enum +from random import Random from typing import Any, Dict, List, Literal, Optional, Type, TypeAlias, TypeVar, Union +from uuid import UUID from pydantic import BaseModel as PydanticBaseModel -from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter +from pydantic import ( + ConfigDict, + GetCoreSchemaHandler, + HttpUrl, + TypeAdapter, + model_validator, +) from pydantic.alias_generators import to_camel from pydantic.main import IncEx from pydantic_core import core_schema @@ -677,24 +685,66 @@ class XUser(BaseModel): MediaDetails: TypeAlias = List[HttpUrl] | None -class LinkId(NonNegativeInt): +class LinkId(UUID): """ - >>> LinkId.from_int(1) - LinkId(1) + >>> LinkId("53dc4ed6-fc9b-54ef-1afa-90f1125098c5") + LinkId('53dc4ed6-fc9b-54ef-1afa-90f1125098c5') + >>> LinkId(UUID("53dc4ed6-fc9b-54ef-1afa-90f1125098c5")) + LinkId('53dc4ed6-fc9b-54ef-1afa-90f1125098c5') """ - pass + def __init__( + self, + hex: str | None = None, + int: int | None = None, + ) -> None: + if isinstance(hex, UUID): + hex = str(hex) + super().__init__(hex, int=int) + + @classmethod + def from_url(cls, url: HttpUrl) -> "LinkId": + """ + >>> LinkId.from_url("https://example.com/") + LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6') + """ + random_number_generator = Random() + random_number_generator.seed(str(url).encode("utf-8")) + return LinkId(int=random_number_generator.getrandbits(128)) + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function( + cls.validate, + serialization=core_schema.plain_serializer_function_ser_schema(cls.serialize, when_used="json"), + ) + + @classmethod + def validate(cls, v: Any) -> "LinkId": + return cls(v) + + def serialize(self) -> str: + return str(self) class Link(BaseModel): """ - >>> Link.model_validate_json('{"linkId": 1, "canonicalUrl": "https://example.com", "shortUrl": "https://example.com/short"}') - Link(link_id=LinkId(1), canonical_url=Url('https://example.com/'), short_url=Url('https://example.com/short')) + >>> Link.model_validate_json('{"linkId": "d5d15194-6574-0c01-8f6f-15abd72b2cf6", "url": "https://example.com"}') + Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/')) + >>> Link(url="https://example.com/") + Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/')) + >>> Link(link_id=UUID("d5d15194-6574-0c01-8f6f-15abd72b2cf6"), url="https://example.com/") + Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/')) """ # noqa: E501 link_id: LinkId - canonical_url: HttpUrl - short_url: HttpUrl + url: HttpUrl + + @model_validator(mode="before") + def validate_link_id(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "link_id" not in values: + values["link_id"] = LinkId.from_url(values["url"]) + return values class Post(BaseModel): diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 1f6a7ac..2acab19 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -5,7 +5,7 @@ from sqlalchemy import ForeignKey, create_engine, func, select from sqlalchemy.engine import Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship -from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String +from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid from .models import BinaryBool, LanguageIdentifier from .models import Link as LinkModel @@ -36,7 +36,7 @@ def adapt_pydantic_http_url(url: AnyUrl) -> AsIs: class Base(DeclarativeBase): type_annotation_map = { - LinkId: Integer, + LinkId: Uuid, TopicId: Integer, TopicLabel: JSON, NoteId: String, @@ -95,8 +95,7 @@ class LinkRecord(Base): __tablename__ = "links" link_id: Mapped[LinkId] = mapped_column(primary_key=True) - canonical_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) - short_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) + url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) class PostLinkAssociation(Base): @@ -216,10 +215,7 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: like_count=post_record.like_count, repost_count=post_record.repost_count, impression_count=post_record.impression_count, - links=[ - LinkModel(link_id=link.link_id, canonical_url=link.link.canonical_url, short_url=link.link.short_url) - for link in post_record.links - ], + links=[LinkModel(link_id=link.link_id, url=link.link.url) for link in post_record.links], ) def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment: diff --git a/common/tests/conftest.py b/common/tests/conftest.py index 8d9ffd4..13fd8a2 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -128,10 +128,10 @@ def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, N @fixture def link_samples(link_factory: LinkFactory) -> Generator[List[Link], None, None]: links = [ - link_factory.build(link_id=0, canonical_url="https://t.co/xxxxxxxxxxx/", short_url="https://example.com/sh0"), - link_factory.build(link_id=1, canonical_url="https://t.co/yyyyyyyyyyy/", short_url="https://example.com/sh1"), - link_factory.build(link_id=2, canonical_url="https://t.co/zzzzzzzzzzz/", short_url="https://example.com/sh2"), - link_factory.build(link_id=3, canonical_url="https://t.co/wwwwwwwwwww/", short_url="https://example.com/sh3"), + link_factory.build(link_id="9f56ee4a-6b36-b79c-d6ca-67865e54bbd5", url="https://example.com/sh0"), + link_factory.build(link_id="f5b0ac79-20fe-9718-4a40-6030bb62d156", url="https://example.com/sh1"), + link_factory.build(link_id="76a0ac4a-a20c-b1f4-1906-d00e2e8f8bf8", url="https://example.com/sh2"), + link_factory.build(link_id="6c352be8-eca3-0d96-55bf-a9bbef1c0fc2", url="https://example.com/sh3"), ] yield links @@ -416,7 +416,7 @@ def link_records_sample( link_samples: List[Link], engine_for_test: Engine, ) -> Generator[List[LinkRecord], None, None]: - res = [LinkRecord(link_id=d.link_id, canonical_url=d.canonical_url, short_url=d.short_url) for d in link_samples] + res = [LinkRecord(link_id=d.link_id, url=d.url) for d in link_samples] with Session(engine_for_test) as sess: sess.add_all(res) sess.commit() From ddb0846e0a9b28fc25c1075436efd8831466920d Mon Sep 17 00:00:00 2001 From: osoken Date: Sun, 6 Oct 2024 15:12:44 +0900 Subject: [PATCH 8/8] fix: Add test data and update tests accordingly --- api/tests/conftest.py | 55 ++++++++++++++++++++++++++++++++-- api/tests/routers/test_data.py | 9 ++++-- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 97a4deb..3b8e677 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -14,6 +14,7 @@ from birdxplorer_common.exceptions import UserEnrollmentNotFoundError from birdxplorer_common.models import ( LanguageIdentifier, + Link, Note, NoteId, ParticipantId, @@ -66,6 +67,11 @@ class PostFactory(ModelFactory[Post]): __model__ = Post +@register_fixture(name="link_factory") +class LinkFactory(ModelFactory[Link]): + __model__ = Link + + @fixture def user_enrollment_samples( user_enrollment_factory: UserEnrollmentFactory, @@ -84,6 +90,17 @@ def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, N yield topics +@fixture +def link_samples(link_factory: LinkFactory) -> Generator[List[Link], None, None]: + links = [ + link_factory.build(link_id="9f56ee4a-6b36-b79c-d6ca-67865e54bbd5", url="https://example.com/sh0"), + link_factory.build(link_id="f5b0ac79-20fe-9718-4a40-6030bb62d156", url="https://example.com/sh1"), + link_factory.build(link_id="76a0ac4a-a20c-b1f4-1906-d00e2e8f8bf8", url="https://example.com/sh2"), + link_factory.build(link_id="6c352be8-eca3-0d96-55bf-a9bbef1c0fc2", url="https://example.com/sh3"), + ] + yield links + + @fixture def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]: notes = [ @@ -160,10 +177,13 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, @fixture -def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]: +def post_samples( + post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link] +) -> Generator[List[Post], None, None]: posts = [ post_factory.build( post_id="2234567890123456781", + link=None, x_user_id="1234567890123456781", x_user=x_user_samples[0], text="""\ @@ -175,9 +195,11 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene like_count=10, repost_count=20, impression_count=30, + links=[link_samples[0]], ), post_factory.build( post_id="2234567890123456791", + link=None, x_user_id="1234567890123456781", x_user=x_user_samples[0], text="""\ @@ -189,18 +211,47 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene like_count=10, repost_count=20, impression_count=30, + links=[link_samples[1]], ), post_factory.build( post_id="2234567890123456801", + link=None, x_user_id="1234567890123456782", x_user=x_user_samples[1], text="""\ -次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""", +次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""", media_details=None, created_at=1154921800000, like_count=10, repost_count=20, impression_count=30, + links=[link_samples[0], link_samples[3]], + ), + post_factory.build( + post_id="2234567890123456811", + link=None, + x_user_id="1234567890123456782", + x_user=x_user_samples[1], + text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/", + media_details=None, + created_at=1154922900000, + like_count=10, + repost_count=20, + impression_count=30, + links=[link_samples[2], link_samples[3]], + ), + post_factory.build( + post_id="2234567890123456821", + link=None, + x_user_id="1234567890123456783", + x_user=x_user_samples[2], + text="empty", + media_details=None, + created_at=1154923900000, + like_count=10, + repost_count=20, + impression_count=30, + links=[], ), ] yield posts diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 67160f9..6ba8f7a 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -36,7 +36,10 @@ def test_posts_get_limit_and_offset(client: TestClient, post_samples: List[Post] res_json = response.json() assert res_json == { "data": [json.loads(d.model_dump_json()) for d in post_samples[1:3]], - "meta": {"next": None, "prev": "http://testserver/api/v1/data/posts?offset=0&limit=2"}, + "meta": { + "next": "http://testserver/api/v1/data/posts?offset=3&limit=2", + "prev": "http://testserver/api/v1/data/posts?offset=0&limit=2", + }, } @@ -72,7 +75,7 @@ def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples: assert response.status_code == 200 res_json = response.json() assert res_json == { - "data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)], + "data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2, 3, 4)], "meta": {"next": None, "prev": None}, } @@ -99,7 +102,7 @@ def test_posts_get_created_at_start_filter_accepts_integer(client: TestClient, p assert response.status_code == 200 res_json = response.json() assert res_json == { - "data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)], + "data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2, 3, 4)], "meta": {"next": None, "prev": None}, }