diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index c347976..1f6a7ac 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -7,7 +7,9 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String -from .models import BinaryBool, LanguageIdentifier, LinkId, MediaDetails, NonNegativeInt +from .models import BinaryBool, LanguageIdentifier +from .models import Link as LinkModel +from .models import LinkId, MediaDetails, NonNegativeInt from .models import Note as NoteModel from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId from .models import Post as PostModel @@ -97,6 +99,14 @@ class LinkRecord(Base): short_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True) +class PostLinkAssociation(Base): + __tablename__ = "post_link" + + post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) + link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True) + link: Mapped[LinkRecord] = relationship() + + class PostRecord(Base): __tablename__ = "posts" @@ -109,14 +119,7 @@ class PostRecord(Base): like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) - - -class PostLinkAssociation(Base): - __tablename__ = "post_link" - - post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True) - link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True) - link: Mapped["LinkRecord"] = relationship() + links: Mapped[List[PostLinkAssociation]] = relationship() class RowNoteRecord(Base): @@ -213,6 +216,10 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel: like_count=post_record.like_count, repost_count=post_record.repost_count, impression_count=post_record.impression_count, + links=[ + LinkModel(link_id=link.link_id, canonical_url=link.link.canonical_url, short_url=link.link.short_url) + for link in post_record.links + ], ) def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment: diff --git a/common/tests/conftest.py b/common/tests/conftest.py index 1acc7aa..8d9ffd4 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -26,8 +26,10 @@ from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings from birdxplorer_common.storage import ( Base, + LinkRecord, NoteRecord, NoteTopicAssociation, + PostLinkAssociation, PostRecord, TopicRecord, XUserRecord, @@ -409,26 +411,43 @@ def x_user_records_sample( yield res +@fixture +def link_records_sample( + link_samples: List[Link], + engine_for_test: Engine, +) -> Generator[List[LinkRecord], None, None]: + res = [LinkRecord(link_id=d.link_id, canonical_url=d.canonical_url, short_url=d.short_url) for d in link_samples] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + @fixture def post_records_sample( x_user_records_sample: List[XUserRecord], + link_records_sample: List[LinkRecord], post_samples: List[Post], engine_for_test: Engine, ) -> Generator[List[PostRecord], None, None]: - res = [ - PostRecord( - post_id=d.post_id, - user_id=d.x_user_id, - text=d.text, - media_details=d.media_details, - created_at=d.created_at, - like_count=d.like_count, - repost_count=d.repost_count, - impression_count=d.impression_count, - ) - for d in post_samples - ] + res: List[PostRecord] = [] with Session(engine_for_test) as sess: - sess.add_all(res) + for post in post_samples: + inst = PostRecord( + post_id=post.post_id, + user_id=post.x_user_id, + text=post.text, + media_details=post.media_details, + created_at=post.created_at, + like_count=post.like_count, + repost_count=post.repost_count, + impression_count=post.impression_count, + ) + sess.add(inst) + for link in post.links: + assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) + sess.add(assoc) + inst.links.append(assoc) + res.append(inst) sess.commit() yield res diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 7943b9a..64a2141 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -31,14 +31,14 @@ def test_get_topic_list( @pytest.mark.parametrize( ["filter_args", "expected_indices"], [ - [dict(), [0, 1, 2, 3]], - [dict(offset=1), [1, 2, 3]], + [dict(), [0, 1, 2, 3, 4]], + [dict(offset=1), [1, 2, 3, 4]], [dict(limit=1), [0]], [dict(offset=1, limit=1), [1]], [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], [dict(post_ids=[]), []], [dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]], - [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]], + [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], @@ -63,11 +63,11 @@ def test_get_post( @pytest.mark.parametrize( ["filter_args", "expected_indices"], [ - [dict(), [0, 1, 2, 3]], + [dict(), [0, 1, 2, 3, 4]], [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], [dict(post_ids=[]), []], [dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]], - [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3]], + [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]], [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],