Skip to content

Commit

Permalink
feat(app): add posts route
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Mar 23, 2024
1 parent 8be38c7 commit 911efa1
Show file tree
Hide file tree
Showing 9 changed files with 394 additions and 15 deletions.
48 changes: 40 additions & 8 deletions birdxplorer/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Dict, List, Literal, Type, TypeAlias, TypeVar, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, GetCoreSchemaHandler, TypeAdapter
from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter
from pydantic.alias_generators import to_camel
from pydantic_core import core_schema

Expand Down Expand Up @@ -128,21 +128,21 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$")


class NineToNineteenDigitsDecimalString(BaseString):
class UpToNineteenDigitsDecimalString(BaseString):
"""
>>> NineToNineteenDigitsDecimalString.from_str("test")
>>> UpToNineteenDigitsDecimalString.from_str("test")
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
String should match pattern '^[0-9]{9,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
String should match pattern '^[0-9]{1,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
...
>>> NineToNineteenDigitsDecimalString.from_str("1234567890123456789")
NineToNineteenDigitsDecimalString('1234567890123456789')
>>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789")
UpToNineteenDigitsDecimalString('1234567890123456789')
"""

@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{9,19}$")
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{1,19}$")


class NonEmptyStringMixin(BaseString):
Expand Down Expand Up @@ -561,7 +561,7 @@ class NotesValidationDifficulty(str, Enum):
empty = ""


class TweetId(NineToNineteenDigitsDecimalString): ...
class TweetId(UpToNineteenDigitsDecimalString): ...


class NoteData(BaseModel):
Expand Down Expand Up @@ -630,3 +630,35 @@ class Note(BaseModel):
topics: List[Topic]
summary: SummaryString
created_at: TwitterTimestamp


class UserId(UpToNineteenDigitsDecimalString): ...


class UserName(NonEmptyTrimmedString): ...


class XUser(BaseModel):
user_id: UserId
name: UserName
profile_image: HttpUrl
followers_count: NonNegativeInt
following_count: NonNegativeInt


class PostId(UpToNineteenDigitsDecimalString): ...


MediaDetails: TypeAlias = List[HttpUrl] | None


class Post(BaseModel):
post_id: PostId
x_user_id: UserId
x_user: XUser
text: str
media_details: MediaDetails = None
created_at: TwitterTimestamp
like_count: NonNegativeInt
repost_count: NonNegativeInt
impression_count: NonNegativeInt
10 changes: 9 additions & 1 deletion birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from fastapi import APIRouter

from ..models import BaseModel, ParticipantId, Topic, UserEnrollment
from ..models import BaseModel, ParticipantId, Post, Topic, UserEnrollment
from ..storage import Storage


class TopicListResponse(BaseModel):
data: List[Topic]


class PostListResponse(BaseModel):
data: List[Post]


def gen_router(storage: Storage) -> APIRouter:
router = APIRouter()

Expand All @@ -24,4 +28,8 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User
def get_topics() -> TopicListResponse:
return TopicListResponse(data=list(storage.get_topics()))

@router.get("/posts", response_model=PostListResponse)
def get_posts() -> PostListResponse:
return PostListResponse(data=list(storage.get_posts()))

return router
80 changes: 78 additions & 2 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,42 @@
from typing import Generator, List

from psycopg2.extensions import AsIs, register_adapter
from pydantic import AnyUrl, HttpUrl
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import DECIMAL, JSON, Integer, String

from .models import LanguageIdentifier, NoteId, ParticipantId, SummaryString
from .models import (
LanguageIdentifier,
MediaDetails,
NonNegativeInt,
NoteId,
ParticipantId,
)
from .models import Post as PostModel
from .models import SummaryString
from .models import Topic as TopicModel
from .models import TopicId, TopicLabel, TweetId, TwitterTimestamp, UserEnrollment
from .models import (
TopicId,
TopicLabel,
TweetId,
TwitterTimestamp,
UserEnrollment,
UserId,
UserName,
)
from .models import XUser as XUserModel
from .settings import GlobalSettings


def adapt_pydantic_http_url(url: AnyUrl) -> AsIs:
return AsIs(repr(str(url)))


register_adapter(AnyUrl, adapt_pydantic_http_url)


class Base(DeclarativeBase):
type_annotation_map = {
TopicId: Integer,
Expand All @@ -21,6 +47,11 @@ class Base(DeclarativeBase):
LanguageIdentifier: String,
TwitterTimestamp: DECIMAL,
SummaryString: String,
UserId: String,
UserName: String,
HttpUrl: String,
NonNegativeInt: DECIMAL,
MediaDetails: JSON,
}


Expand Down Expand Up @@ -50,6 +81,30 @@ class TopicRecord(Base):
label: Mapped[TopicLabel] = mapped_column(nullable=False)


class XUserRecord(Base):
__tablename__ = "x_users"

user_id: Mapped[UserId] = mapped_column(primary_key=True)
name: Mapped[UserName] = mapped_column(nullable=False)
profile_image: Mapped[HttpUrl] = mapped_column(nullable=False)
followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class PostRecord(Base):
__tablename__ = "posts"

post_id: Mapped[TweetId] = mapped_column(primary_key=True)
user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False)
user: Mapped[XUserRecord] = relationship()
text: Mapped[SummaryString] = mapped_column(nullable=False)
media_details: Mapped[MediaDetails] = mapped_column()
created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False)
like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)


class Storage:
def __init__(self, engine: Engine) -> None:
self._engine = engine
Expand Down Expand Up @@ -77,6 +132,27 @@ def get_topics(self) -> Generator[TopicModel, None, None]:
topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0
)

def get_posts(self) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).all():
yield PostModel(
post_id=post_record.post_id,
x_user_id=post_record.user_id,
x_user=XUserModel(
user_id=post_record.user.user_id,
name=post_record.user.name,
profile_image=post_record.user.profile_image,
followers_count=post_record.user.followers_count,
following_count=post_record.user.following_count,
),
text=post_record.text,
media_details=post_record.media_details,
created_at=post_record.created_at,
like_count=post_record.like_count,
repost_count=post_record.repost_count,
impression_count=post_record.impression_count,
)


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ dev=[
"uvicorn",
"polyfactory",
"httpx",
"types-psycopg2",
]
prod=[
"psycopg2"
Expand Down
62 changes: 62 additions & 0 deletions scripts/migrations/convert_data_from_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,37 @@
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("notes_file")
parser.add_argument("posts_file")
parser.add_argument("x_users_file")
parser.add_argument("output_dir")
parser.add_argument("--notes-file-name", default="notes.csv")
parser.add_argument("--topics-file-name", default="topics.csv")
parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv")
parser.add_argument("--topic-threshold", type=int, default=5)
parser.add_argument("--posts-file-name", default="posts.csv")
parser.add_argument("--x-users-file-name", default="x_users.csv")

args = parser.parse_args()

with open(args.notes_file, "r", encoding="utf-8") as fin:
notes = list(csv.DictReader(fin))
for d in notes:
d["topic"] = [t.strip() for t in d["topic"].split(",")]

topics_with_count = Counter(t for d in notes for t in d["topic"])
topic_name_to_id_map = {t: i for i, (t, c) in enumerate(topics_with_count.items()) if c > args.topic_threshold}

with open(args.posts_file, "r", encoding="utf-8") as fin:
posts = list(csv.DictReader(fin))
for d in posts:
d["media_details"] = None if len(d["media_details"]) == 0 else json.loads(d["media_details"])

with open(args.x_users_file, "r", encoding="utf-8") as fin:
x_users = list(csv.DictReader(fin))
for d in x_users:
d["followers_count"] = int(d["followers_count"])
d["following_count"] = int(d["following_count"])

if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)

Expand Down Expand Up @@ -52,3 +68,49 @@
for t in d["topic"]:
if t in topic_name_to_id_map:
writer.writerow({"note_id": d["note_id"], "topic_id": topic_name_to_id_map[t]})

with open(os.path.join(args.output_dir, args.posts_file_name), "w", encoding="utf-8") as fout:
writer = csv.DictWriter(
fout,
fieldnames=[
"post_id",
"user_id",
"text",
"media_details",
"created_at",
"like_count",
"repost_count",
"impression_count",
],
)
writer.writeheader()
for d in posts:
writer.writerow(
{
"post_id": d["post_id"],
"user_id": d["user_id"],
"text": d["text"],
"media_details": json.dumps(d["media_details"]),
"created_at": 1288834974657,
"like_count": 0,
"repost_count": 0,
"impression_count": 0,
}
)

with open(os.path.join(args.output_dir, args.x_users_file_name), "w", encoding="utf-8") as fout:
writer = csv.DictWriter(
fout,
fieldnames=["user_id", "name", "profile_image", "followers_count", "following_count"],
)
writer.writeheader()
for d in x_users:
writer.writerow(
{
"user_id": d["user_id"],
"name": d["name"],
"profile_image": d["profile_image"],
"followers_count": d["followers_count"],
"following_count": d["following_count"],
}
)
46 changes: 46 additions & 0 deletions scripts/migrations/migrate_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
Base,
NoteRecord,
NoteTopicAssociation,
PostRecord,
TopicRecord,
XUserRecord,
gen_storage,
)

Expand All @@ -22,6 +24,9 @@
parser.add_argument("--notes-file-name", default="notes.csv")
parser.add_argument("--topics-file-name", default="topics.csv")
parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv")
parser.add_argument("--posts-file-name", default="posts.csv")
parser.add_argument("--x-users-file-name", default="x_users.csv")
parser.add_argument("--limit-number-of-post-rows", type=int, default=None)
load_dotenv()
args = parser.parse_args()
settings = GlobalSettings()
Expand Down Expand Up @@ -71,4 +76,45 @@
)
)
sess.commit()
with open(os.path.join(args.data_dir, args.x_users_file_name), "r", encoding="utf-8") as fin:
for d in csv.DictReader(fin):
d["followers_count"] = int(d["followers_count"])
d["following_count"] = int(d["following_count"])
if sess.query(XUserRecord).filter(XUserRecord.user_id == d["user_id"]).count() > 0:
continue
sess.add(
XUserRecord(
user_id=d["user_id"],
name=d["name"],
profile_image=d["profile_image"],
followers_count=d["followers_count"],
following_count=d["following_count"],
)
)
sess.commit()
with open(os.path.join(args.data_dir, args.posts_file_name), "r", encoding="utf-8") as fin:
for d in csv.DictReader(fin):
if (
args.limit_number_of_post_rows is not None
and sess.query(PostRecord).count() >= args.limit_number_of_post_rows
):
break
d["like_count"] = int(d["like_count"])
d["repost_count"] = int(d["repost_count"])
d["impression_count"] = int(d["impression_count"])
if sess.query(PostRecord).filter(PostRecord.post_id == d["post_id"]).count() > 0:
continue
sess.add(
PostRecord(
post_id=d["post_id"],
user_id=d["user_id"],
text=d["text"],
media_details=json.loads(d["media_details"]) if len(d["media_details"]) > 0 else None,
created_at=d["created_at"],
like_count=d["like_count"],
repost_count=d["repost_count"],
impression_count=d["impression_count"],
)
)
sess.commit()
logger.info("Migration is done")
Loading

0 comments on commit 911efa1

Please sign in to comment.