Skip to content

Commit

Permalink
Merge pull request #50 from codeforjapan/feature/issue-41-add-posts-e…
Browse files Browse the repository at this point in the history
…ndpoint

feat(app): add posts route
  • Loading branch information
yu23ki14 authored Apr 15, 2024
2 parents 886e0dc + 6a84835 commit bc3fd5c
Show file tree
Hide file tree
Showing 10 changed files with 666 additions and 18 deletions.
30 changes: 30 additions & 0 deletions birdxplorer/app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import csv
import io
from urllib.parse import parse_qs as parse_query_string
from urllib.parse import urlencode as encode_query_string

from fastapi import FastAPI
from pydantic.alias_generators import to_snake
from starlette.types import ASGIApp, Receive, Scope, Send

from .logger import get_logger
from .routers.data import gen_router as gen_data_router
Expand All @@ -7,10 +14,33 @@
from .storage import gen_storage


class QueryStringFlatteningMiddleware:
def __init__(self, app: ASGIApp) -> None:
self._app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
query_string = scope.get("query_string")
if not isinstance(query_string, bytes):
query_string = b""
query_string = query_string.decode("utf-8")
if scope["type"] == "http" and query_string:
parsed = parse_query_string(query_string)
flattened = {}
for name, values in parsed.items():
flattened[to_snake(name)] = [c for value in values for r in csv.reader(io.StringIO(value)) for c in r]

scope["query_string"] = encode_query_string(flattened, doseq=True).encode("utf-8")

await self._app(scope, receive, send)
else:
await self._app(scope, receive, send)


def gen_app(settings: GlobalSettings) -> FastAPI:
_ = get_logger(level=settings.logger_settings.level)
storage = gen_storage(settings=settings)
app = FastAPI()
app.add_middleware(QueryStringFlatteningMiddleware)
app.include_router(gen_system_router(), prefix="/api/v1/system")
app.include_router(gen_data_router(storage=storage), prefix="/api/v1/data")
return app
54 changes: 45 additions & 9 deletions birdxplorer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Literal, Type, TypeAlias, TypeVar, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, GetCoreSchemaHandler, TypeAdapter
from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter
from pydantic.alias_generators import to_camel
from pydantic_core import core_schema

Expand Down Expand Up @@ -128,21 +128,21 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$")


class NineToNineteenDigitsDecimalString(BaseString):
class UpToNineteenDigitsDecimalString(BaseString):
"""
>>> NineToNineteenDigitsDecimalString.from_str("test")
>>> UpToNineteenDigitsDecimalString.from_str("test")
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
String should match pattern '^[0-9]{9,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
String should match pattern '^[0-9]{1,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
...
>>> NineToNineteenDigitsDecimalString.from_str("1234567890123456789")
NineToNineteenDigitsDecimalString('1234567890123456789')
>>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789")
UpToNineteenDigitsDecimalString('1234567890123456789')
"""

@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{9,19}$")
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{1,19}$")


class NonEmptyStringMixin(BaseString):
Expand Down Expand Up @@ -467,23 +467,27 @@ def model_dump_json(
indent: int | None = None,
include: IncEx = None,
exclude: IncEx = None,
context: Dict[str, Any] | None = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
round_trip: bool = False,
warnings: bool = True,
warnings: bool | Literal["none"] | Literal["warn"] | Literal["error"] = True,
serialize_as_any: bool = False,
) -> str:
return super(BaseModel, self).model_dump_json(
indent=indent,
include=include,
exclude=exclude,
context=context,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
round_trip=round_trip,
warnings=warnings,
serialize_as_any=serialize_as_any,
)


Expand Down Expand Up @@ -561,7 +565,7 @@ class NotesValidationDifficulty(str, Enum):
empty = ""


class TweetId(NineToNineteenDigitsDecimalString): ...
class TweetId(UpToNineteenDigitsDecimalString): ...


class NoteData(BaseModel):
Expand Down Expand Up @@ -630,3 +634,35 @@ class Note(BaseModel):
topics: List[Topic]
summary: SummaryString
created_at: TwitterTimestamp


class UserId(UpToNineteenDigitsDecimalString): ...


class UserName(NonEmptyTrimmedString): ...


class XUser(BaseModel):
user_id: UserId
name: UserName
profile_image: HttpUrl
followers_count: NonNegativeInt
following_count: NonNegativeInt


class PostId(UpToNineteenDigitsDecimalString): ...


MediaDetails: TypeAlias = List[HttpUrl] | None


class Post(BaseModel):
post_id: PostId
x_user_id: UserId
x_user: XUser
text: str
media_details: MediaDetails = None
created_at: TwitterTimestamp
like_count: NonNegativeInt
repost_count: NonNegativeInt
impression_count: NonNegativeInt
65 changes: 62 additions & 3 deletions birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,47 @@
from typing import List
from datetime import timezone
from typing import List, Union

from fastapi import APIRouter
from dateutil.parser import parse as dateutil_parse
from fastapi import APIRouter, HTTPException, Query

from ..models import BaseModel, ParticipantId, Topic, UserEnrollment
from ..models import (
BaseModel,
ParticipantId,
Post,
PostId,
Topic,
TwitterTimestamp,
UserEnrollment,
)
from ..storage import Storage


class TopicListResponse(BaseModel):
data: List[Topic]


class PostListResponse(BaseModel):
data: List[Post]


def str_to_twitter_timestamp(s: str) -> TwitterTimestamp:
try:
return TwitterTimestamp.from_int(int(s))
except ValueError:
pass
try:
tmp = dateutil_parse(s)
if tmp.tzinfo is None:
tmp = tmp.replace(tzinfo=timezone.utc)
return TwitterTimestamp.from_int(int(tmp.timestamp() * 1000))
except ValueError:
raise HTTPException(status_code=422, detail=f"Invalid TwitterTimestamp string: {s}")


def ensure_twitter_timestamp(t: Union[str, TwitterTimestamp]) -> TwitterTimestamp:
return str_to_twitter_timestamp(t) if isinstance(t, str) else t


def gen_router(storage: Storage) -> APIRouter:
router = APIRouter()

Expand All @@ -24,4 +56,31 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User
def get_topics() -> TopicListResponse:
return TopicListResponse(data=list(storage.get_topics()))

@router.get("/posts", response_model=PostListResponse)
def get_posts(
post_id: Union[List[PostId], None] = Query(default=None),
created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None),
created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None),
) -> PostListResponse:
if post_id is not None:
return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id)))
if created_at_start is not None:
if created_at_end is not None:
return PostListResponse(
data=list(
storage.get_posts_by_created_at_range(
start=ensure_twitter_timestamp(created_at_start),
end=ensure_twitter_timestamp(created_at_end),
)
)
)
return PostListResponse(
data=list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_start)))
)
if created_at_end is not None:
return PostListResponse(
data=list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_end)))
)
return PostListResponse(data=list(storage.get_posts()))

return router
106 changes: 104 additions & 2 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,42 @@
from typing import Generator, List

from psycopg2.extensions import AsIs, register_adapter
from pydantic import AnyUrl, HttpUrl
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import DECIMAL, JSON, Integer, String

from .models import LanguageIdentifier, NoteId, ParticipantId, SummaryString
from .models import (
LanguageIdentifier,
MediaDetails,
NonNegativeInt,
NoteId,
ParticipantId,
)
from .models import Post as PostModel
from .models import PostId, SummaryString
from .models import Topic as TopicModel
from .models import TopicId, TopicLabel, TweetId, TwitterTimestamp, UserEnrollment
from .models import (
TopicId,
TopicLabel,
TweetId,
TwitterTimestamp,
UserEnrollment,
UserId,
UserName,
)
from .models import XUser as XUserModel
from .settings import GlobalSettings


def adapt_pydantic_http_url(url: AnyUrl) -> AsIs:
return AsIs(repr(str(url)))


register_adapter(AnyUrl, adapt_pydantic_http_url)


class Base(DeclarativeBase):
type_annotation_map = {
TopicId: Integer,
Expand All @@ -21,6 +47,11 @@ class Base(DeclarativeBase):
LanguageIdentifier: String,
TwitterTimestamp: DECIMAL,
SummaryString: String,
UserId: String,
UserName: String,
HttpUrl: String,
NonNegativeInt: DECIMAL,
MediaDetails: JSON,
}


Expand Down Expand Up @@ -50,6 +81,30 @@ class TopicRecord(Base):
label: Mapped[TopicLabel] = mapped_column(nullable=False)


class XUserRecord(Base):
__tablename__ = "x_users"

user_id: Mapped[UserId] = mapped_column(primary_key=True)
name: Mapped[UserName] = mapped_column(nullable=False)
profile_image: Mapped[HttpUrl] = mapped_column(nullable=False)
followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class PostRecord(Base):
__tablename__ = "posts"

post_id: Mapped[TweetId] = mapped_column(primary_key=True)
user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False)
user: Mapped[XUserRecord] = relationship()
text: Mapped[SummaryString] = mapped_column(nullable=False)
media_details: Mapped[MediaDetails] = mapped_column()
created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False)
like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class Storage:
def __init__(self, engine: Engine) -> None:
self._engine = engine
Expand All @@ -58,6 +113,26 @@ def __init__(self, engine: Engine) -> None:
def engine(self) -> Engine:
return self._engine

@classmethod
def _post_record_to_model(cls, post_record: PostRecord) -> PostModel:
return PostModel(
post_id=post_record.post_id,
x_user_id=post_record.user_id,
x_user=XUserModel(
user_id=post_record.user.user_id,
name=post_record.user.name,
profile_image=post_record.user.profile_image,
followers_count=post_record.user.followers_count,
following_count=post_record.user.following_count,
),
text=post_record.text,
media_details=post_record.media_details,
created_at=post_record.created_at,
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
)

def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
raise NotImplementedError

Expand All @@ -77,6 +152,33 @@ def get_topics(self) -> Generator[TopicModel, None, None]:
topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0
)

def get_posts(self) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).all():
yield self._post_record_to_model(post_record)

def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.post_id.in_(post_ids)).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_range(
self, start: TwitterTimestamp, end: TwitterTimestamp
) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at.between(start, end)).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at >= start).all():
yield self._post_record_to_model(post_record)

def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all():
yield self._post_record_to_model(post_record)


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dev=[
"uvicorn",
"polyfactory",
"httpx",
"types-psycopg2",
]
prod=[
"psycopg2"
Expand Down
Loading

0 comments on commit bc3fd5c

Please sign in to comment.