Skip to content

Commit

Permalink
feat: add Link table model
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Sep 28, 2024
1 parent 6876d74 commit 3fff12c
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 28 deletions.
25 changes: 16 additions & 9 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String

from .models import BinaryBool, LanguageIdentifier, LinkId, MediaDetails, NonNegativeInt
from .models import BinaryBool, LanguageIdentifier
from .models import Link as LinkModel
from .models import LinkId, MediaDetails, NonNegativeInt
from .models import Note as NoteModel
from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId
from .models import Post as PostModel
Expand Down Expand Up @@ -97,6 +99,14 @@ class LinkRecord(Base):
short_url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True)


class PostLinkAssociation(Base):
__tablename__ = "post_link"

post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True)
link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True)
link: Mapped[LinkRecord] = relationship()


class PostRecord(Base):
__tablename__ = "posts"

Expand All @@ -109,14 +119,7 @@ class PostRecord(Base):
like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class PostLinkAssociation(Base):
__tablename__ = "post_link"

post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True)
link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True)
link: Mapped["LinkRecord"] = relationship()
links: Mapped[List[PostLinkAssociation]] = relationship()


class RowNoteRecord(Base):
Expand Down Expand Up @@ -213,6 +216,10 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
links=[
LinkModel(link_id=link.link_id, canonical_url=link.link.canonical_url, short_url=link.link.short_url)
for link in post_record.links
],
)

def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
Expand Down
47 changes: 33 additions & 14 deletions common/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings
from birdxplorer_common.storage import (
Base,
LinkRecord,
NoteRecord,
NoteTopicAssociation,
PostLinkAssociation,
PostRecord,
TopicRecord,
XUserRecord,
Expand Down Expand Up @@ -409,26 +411,43 @@ def x_user_records_sample(
yield res


@fixture
def link_records_sample(
link_samples: List[Link],
engine_for_test: Engine,
) -> Generator[List[LinkRecord], None, None]:
res = [LinkRecord(link_id=d.link_id, canonical_url=d.canonical_url, short_url=d.short_url) for d in link_samples]
with Session(engine_for_test) as sess:
sess.add_all(res)
sess.commit()
yield res


@fixture
def post_records_sample(
x_user_records_sample: List[XUserRecord],
link_records_sample: List[LinkRecord],
post_samples: List[Post],
engine_for_test: Engine,
) -> Generator[List[PostRecord], None, None]:
res = [
PostRecord(
post_id=d.post_id,
user_id=d.x_user_id,
text=d.text,
media_details=d.media_details,
created_at=d.created_at,
like_count=d.like_count,
repost_count=d.repost_count,
impression_count=d.impression_count,
)
for d in post_samples
]
res: List[PostRecord] = []
with Session(engine_for_test) as sess:
sess.add_all(res)
for post in post_samples:
inst = PostRecord(
post_id=post.post_id,
user_id=post.x_user_id,
text=post.text,
media_details=post.media_details,
created_at=post.created_at,
like_count=post.like_count,
repost_count=post.repost_count,
impression_count=post.impression_count,
)
sess.add(inst)
for link in post.links:
assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id)
sess.add(assoc)
inst.links.append(assoc)
res.append(inst)
sess.commit()
yield res
10 changes: 5 additions & 5 deletions common/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def test_get_topic_list(
@pytest.mark.parametrize(
["filter_args", "expected_indices"],
[
[dict(), [0, 1, 2, 3]],
[dict(offset=1), [1, 2, 3]],
[dict(), [0, 1, 2, 3, 4]],
[dict(offset=1), [1, 2, 3, 4]],
[dict(limit=1), [0]],
[dict(offset=1, limit=1), [1]],
[dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]],
[dict(post_ids=[]), []],
[dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]],
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]],
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]],
[dict(end=TwitterTimestamp.from_int(1153921700000)), [0]],
[dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
Expand All @@ -63,11 +63,11 @@ def test_get_post(
@pytest.mark.parametrize(
["filter_args", "expected_indices"],
[
[dict(), [0, 1, 2, 3]],
[dict(), [0, 1, 2, 3, 4]],
[dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]],
[dict(post_ids=[]), []],
[dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]],
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3]],
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]],
[dict(end=TwitterTimestamp.from_int(1153921700000)), [0]],
[dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
Expand Down

0 comments on commit 3fff12c

Please sign in to comment.