Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/issue 104 url data model #116

Merged
merged 8 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 53 additions & 2 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from birdxplorer_common.exceptions import UserEnrollmentNotFoundError
from birdxplorer_common.models import (
LanguageIdentifier,
Link,
Note,
NoteId,
ParticipantId,
Expand Down Expand Up @@ -66,6 +67,11 @@ class PostFactory(ModelFactory[Post]):
__model__ = Post


@register_fixture(name="link_factory")
class LinkFactory(ModelFactory[Link]):
    """Factory for ``Link`` models, registered as the ``link_factory`` fixture."""

    __model__ = Link


@fixture
def user_enrollment_samples(
user_enrollment_factory: UserEnrollmentFactory,
Expand All @@ -84,6 +90,17 @@ def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, N
yield topics


@fixture
def link_samples(link_factory: LinkFactory) -> Generator[List[Link], None, None]:
    """Yield four sample links with fixed ids and ``example.com`` URLs.

    The links are built in the order listed, so index ``i`` of the yielded
    list always carries the URL ending in ``sh{i}``.
    """
    id_url_pairs = (
        ("9f56ee4a-6b36-b79c-d6ca-67865e54bbd5", "https://example.com/sh0"),
        ("f5b0ac79-20fe-9718-4a40-6030bb62d156", "https://example.com/sh1"),
        ("76a0ac4a-a20c-b1f4-1906-d00e2e8f8bf8", "https://example.com/sh2"),
        ("6c352be8-eca3-0d96-55bf-a9bbef1c0fc2", "https://example.com/sh3"),
    )
    yield [
        link_factory.build(link_id=sample_id, url=sample_url)
        for sample_id, sample_url in id_url_pairs
    ]


@fixture
def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]:
notes = [
Expand Down Expand Up @@ -160,10 +177,13 @@ def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None,


@fixture
def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]:
def post_samples(
post_factory: PostFactory, x_user_samples: List[XUser], link_samples: List[Link]
) -> Generator[List[Post], None, None]:
posts = [
post_factory.build(
post_id="2234567890123456781",
link=None,
x_user_id="1234567890123456781",
x_user=x_user_samples[0],
text="""\
Expand All @@ -175,9 +195,11 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[0]],
),
post_factory.build(
post_id="2234567890123456791",
link=None,
x_user_id="1234567890123456781",
x_user=x_user_samples[0],
text="""\
Expand All @@ -189,18 +211,47 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[1]],
),
post_factory.build(
post_id="2234567890123456801",
link=None,
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="""\
次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ #旅行 #バケーション""",
次の休暇はここに決めた!🌴🏖️ 見てみて~ https://t.co/xxxxxxxxxxx/ https://t.co/wwwwwwwwwww/ #旅行 #バケーション""",
media_details=None,
created_at=1154921800000,
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[0], link_samples[3]],
),
post_factory.build(
post_id="2234567890123456811",
link=None,
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/",
media_details=None,
created_at=1154922900000,
like_count=10,
repost_count=20,
impression_count=30,
links=[link_samples[2], link_samples[3]],
),
post_factory.build(
post_id="2234567890123456821",
link=None,
x_user_id="1234567890123456783",
x_user=x_user_samples[2],
text="empty",
media_details=None,
created_at=1154923900000,
like_count=10,
repost_count=20,
impression_count=30,
links=[],
),
]
yield posts
Expand Down
9 changes: 6 additions & 3 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ def test_posts_get_limit_and_offset(client: TestClient, post_samples: List[Post]
res_json = response.json()
assert res_json == {
"data": [json.loads(d.model_dump_json()) for d in post_samples[1:3]],
"meta": {"next": None, "prev": "http://testserver/api/v1/data/posts?offset=0&limit=2"},
"meta": {
"next": "http://testserver/api/v1/data/posts?offset=3&limit=2",
"prev": "http://testserver/api/v1/data/posts?offset=0&limit=2",
},
}


Expand Down Expand Up @@ -72,7 +75,7 @@ def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples:
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)],
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2, 3, 4)],
"meta": {"next": None, "prev": None},
}

Expand All @@ -99,7 +102,7 @@ def test_posts_get_created_at_start_filter_accepts_integer(client: TestClient, p
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)],
"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2, 3, 4)],
"meta": {"next": None, "prev": None},
}

Expand Down
79 changes: 75 additions & 4 deletions common/birdxplorer_common/models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from random import Random
from typing import Any, Dict, List, Literal, Optional, Type, TypeAlias, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter
from pydantic import (
ConfigDict,
GetCoreSchemaHandler,
HttpUrl,
TypeAdapter,
model_validator,
)
from pydantic.alias_generators import to_camel
from pydantic.main import IncEx
from pydantic_core import core_schema

IncEx: TypeAlias = "set[int] | set[str] | dict[int, IncEx] | dict[str, IncEx] | None"
StrT = TypeVar("StrT", bound="BaseString")
IntT = TypeVar("IntT", bound="BaseInt")
FloatT = TypeVar("FloatT", bound="BaseFloat")
Expand Down Expand Up @@ -467,8 +475,8 @@ def model_dump_json(
self,
*,
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
include: IncEx | None = None,
exclude: IncEx | None = None,
context: Dict[str, Any] | None = None,
by_alias: bool = True,
exclude_unset: bool = False,
Expand Down Expand Up @@ -677,6 +685,68 @@ class XUser(BaseModel):
MediaDetails: TypeAlias = List[HttpUrl] | None


class LinkId(UUID):
"""
>>> LinkId("53dc4ed6-fc9b-54ef-1afa-90f1125098c5")
LinkId('53dc4ed6-fc9b-54ef-1afa-90f1125098c5')
>>> LinkId(UUID("53dc4ed6-fc9b-54ef-1afa-90f1125098c5"))
LinkId('53dc4ed6-fc9b-54ef-1afa-90f1125098c5')
"""

def __init__(
self,
hex: str | None = None,
int: int | None = None,
) -> None:
if isinstance(hex, UUID):
hex = str(hex)
super().__init__(hex, int=int)

@classmethod
def from_url(cls, url: HttpUrl) -> "LinkId":
"""
>>> LinkId.from_url("https://example.com/")
LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6')
"""
random_number_generator = Random()
random_number_generator.seed(str(url).encode("utf-8"))
return LinkId(int=random_number_generator.getrandbits(128))

@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.no_info_plain_validator_function(
cls.validate,
serialization=core_schema.plain_serializer_function_ser_schema(cls.serialize, when_used="json"),
)

@classmethod
def validate(cls, v: Any) -> "LinkId":
return cls(v)

def serialize(self) -> str:
return str(self)


class Link(BaseModel):
    """
    A URL together with its identifier.

    When no id is supplied, ``link_id`` is derived from ``url`` with
    ``LinkId.from_url``.

    >>> Link.model_validate_json('{"linkId": "d5d15194-6574-0c01-8f6f-15abd72b2cf6", "url": "https://example.com"}')
    Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/'))
    >>> Link(url="https://example.com/")
    Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/'))
    >>> Link(link_id=UUID("d5d15194-6574-0c01-8f6f-15abd72b2cf6"), url="https://example.com/")
    Link(link_id=LinkId('d5d15194-6574-0c01-8f6f-15abd72b2cf6'), url=Url('https://example.com/'))
    """  # noqa: E501

    link_id: LinkId
    url: HttpUrl

    @model_validator(mode="before")
    @classmethod
    def validate_link_id(cls, values: Any) -> Any:
        """
        Fill in ``link_id`` from ``url`` when the input does not carry an id.

        The input is returned untouched when it is not a dict, when it already
        has an id (under the field name or the camelCase alias ``linkId`` used
        in the JSON doctest above), or when ``url`` is missing. In the last
        case normal field validation reports the missing fields instead of a
        bare ``KeyError`` being raised here.
        """
        if not isinstance(values, dict):
            return values
        if "link_id" in values or "linkId" in values or "url" not in values:
            return values
        # Build a new dict so the caller's input is not mutated.
        return {**values, "link_id": LinkId.from_url(values["url"])}


class Post(BaseModel):
post_id: PostId
link: Optional[HttpUrl] = None
Expand All @@ -688,6 +758,7 @@ class Post(BaseModel):
like_count: NonNegativeInt
repost_count: NonNegativeInt
impression_count: NonNegativeInt
links: List[Link] = []


class PaginationMeta(BaseModel):
Expand Down
24 changes: 22 additions & 2 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid

from .models import BinaryBool, LanguageIdentifier, MediaDetails, NonNegativeInt
from .models import BinaryBool, LanguageIdentifier
from .models import Link as LinkModel
from .models import LinkId, MediaDetails, NonNegativeInt
from .models import Note as NoteModel
from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId
from .models import Post as PostModel
Expand All @@ -34,6 +36,7 @@ def adapt_pydantic_http_url(url: AnyUrl) -> AsIs:

class Base(DeclarativeBase):
type_annotation_map = {
LinkId: Uuid,
TopicId: Integer,
TopicLabel: JSON,
NoteId: String,
Expand Down Expand Up @@ -88,6 +91,21 @@ class XUserRecord(Base):
following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class LinkRecord(Base):
    """ORM record for the ``links`` table, keyed by ``link_id``."""

    __tablename__ = "links"

    link_id: Mapped[LinkId] = mapped_column(primary_key=True)
    # Required and indexed.
    url: Mapped[HttpUrl] = mapped_column(nullable=False, index=True)


class PostLinkAssociation(Base):
    """Association record in the ``post_link`` table pairing a post with a link."""

    __tablename__ = "post_link"

    # Both columns are foreign keys and together form the composite primary key.
    post_id: Mapped[PostId] = mapped_column(ForeignKey("posts.post_id"), primary_key=True)
    link_id: Mapped[LinkId] = mapped_column(ForeignKey("links.link_id"), primary_key=True)
    # The related LinkRecord; no explicit join condition is given, so it is
    # resolved from the foreign key to ``links``.
    link: Mapped[LinkRecord] = relationship()


class PostRecord(Base):
__tablename__ = "posts"

Expand All @@ -100,6 +118,7 @@ class PostRecord(Base):
like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
links: Mapped[List[PostLinkAssociation]] = relationship()


class RowNoteRecord(Base):
Expand Down Expand Up @@ -196,6 +215,7 @@ def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
links=[LinkModel(link_id=link.link_id, url=link.link.url) for link in post_record.links],
)

def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
Expand Down
1 change: 1 addition & 0 deletions common/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"sqlalchemy",
"pydantic_settings",
"JSON-log-formatter",
"ulid-py",
]

[project.urls]
Expand Down
Loading
Loading