Skip to content

Commit

Permalink
feat(app): add topics route
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Feb 25, 2024
1 parent a2d1134 commit b416793
Show file tree
Hide file tree
Showing 9 changed files with 409 additions and 114 deletions.
51 changes: 46 additions & 5 deletions birdxplorer/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, Literal, Type, TypeAlias, TypeVar, Union
from typing import Any, Dict, List, Literal, Type, TypeAlias, TypeVar, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter
from pydantic import ConfigDict, GetCoreSchemaHandler, TypeAdapter
from pydantic.alias_generators import to_camel
from pydantic_core import core_schema

Expand Down Expand Up @@ -128,6 +128,23 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$")


class NineToNineteenDigitsDecimalString(BaseString):
"""
>>> NineToNineteenDigitsDecimalString.from_str("test")
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
String should match pattern '^[0-9]{9,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
...
>>> NineToNineteenDigitsDecimalString.from_str("1234567890123456789")
NineToNineteenDigitsDecimalString('1234567890123456789')
"""

@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{9,19}$")


class NonEmptyStringMixin(BaseString):
@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
Expand Down Expand Up @@ -544,11 +561,18 @@ class NotesValidationDifficulty(str, Enum):
empty = ""


class Note(BaseModel):
class TweetId(NineToNineteenDigitsDecimalString): ...


class NoteData(BaseModel):
"""
This is for validating the original data from notes.csv.
"""

note_id: NoteId
note_author_participant_id: ParticipantId
created_at_millis: TwitterTimestamp
tweet_id: str = Field(pattern=r"^[0-9]{9,19}$")
tweet_id: TweetId
believable: NotesBelievable
misleading_other: BinaryBool
misleading_factual_error: BinaryBool
Expand Down Expand Up @@ -585,7 +609,24 @@ class LanguageIdentifier(str, Enum):
class TopicLabelString(NonEmptyTrimmedString): ...


TopicLabel: TypeAlias = Dict[LanguageIdentifier, TopicLabelString]


class Topic(BaseModel):
model_config = ConfigDict(from_attributes=True)

topic_id: TopicId
label: Dict[LanguageIdentifier, TopicLabelString]
label: TopicLabel
reference_count: NonNegativeInt


class SummaryString(NonEmptyTrimmedString): ...


class Note(BaseModel):
note_id: NoteId
post_id: TweetId
language: LanguageIdentifier
topics: List[Topic]
summary: SummaryString
created_at: TwitterTimestamp
4 changes: 2 additions & 2 deletions birdxplorer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class PostgresStorageSettings(BaseSettings):

@computed_field # type: ignore[misc]
@property
def sqlalchemy_database_url(self) -> PostgresDsn:
def sqlalchemy_database_url(self) -> str:
return PostgresDsn(
url=f"postgresql://{self.username}:"
f"{self.password.replace('@', '%40')}@{self.host}:{self.port}/{self.database}"
)
).unicode_string()


class GlobalSettings(BaseSettings):
Expand Down
77 changes: 72 additions & 5 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,83 @@
from typing import Generator
from typing import Generator, List

from .models import ParticipantId, Topic, UserEnrollment
from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import DECIMAL, JSON, Integer, String

from .models import LanguageIdentifier, NoteId, ParticipantId, SummaryString
from .models import Topic as TopicModel
from .models import TopicId, TopicLabel, TweetId, TwitterTimestamp, UserEnrollment
from .settings import GlobalSettings


class Base(DeclarativeBase):
type_annotation_map = {
TopicId: Integer,
TopicLabel: JSON,
NoteId: String,
ParticipantId: String,
TweetId: String,
LanguageIdentifier: String,
TwitterTimestamp: DECIMAL,
SummaryString: String,
}


class NoteTopicAssociation(Base):
__tablename__ = "note_topic"

note_id: Mapped[NoteId] = mapped_column(ForeignKey("notes.note_id"), primary_key=True)
topic_id: Mapped[TopicId] = mapped_column(ForeignKey("topics.topic_id"), primary_key=True)
topic: Mapped["TopicRecord"] = relationship()


class NoteRecord(Base):
__tablename__ = "notes"

note_id: Mapped[NoteId] = mapped_column(primary_key=True)
post_id: Mapped[TweetId] = mapped_column(nullable=False)
topics: Mapped[List[NoteTopicAssociation]] = relationship()
language: Mapped[LanguageIdentifier] = mapped_column(nullable=False)
summary: Mapped[SummaryString] = mapped_column(nullable=False)
created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False)


class TopicRecord(Base):
__tablename__ = "topics"

topic_id: Mapped[TopicId] = mapped_column(primary_key=True)
label: Mapped[TopicLabel] = mapped_column(nullable=False)


class Storage:
def __init__(self, engine: Engine) -> None:
self._engine = engine

@property
def engine(self) -> Engine:
return self._engine

def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
raise NotImplementedError

def get_topics(self) -> Generator[Topic, None, None]:
raise NotImplementedError
def get_topics(self) -> Generator[TopicModel, None, None]:
with Session(self.engine) as sess:
subq = (
select(NoteTopicAssociation.topic_id, func.count().label("reference_count"))
.group_by(NoteTopicAssociation.topic_id)
.subquery()
)
for topic_record, reference_count in (
sess.query(TopicRecord, subq.c.reference_count)
.outerjoin(subq, TopicRecord.topic_id == subq.c.topic_id)
.all()
):
yield TopicModel(
topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0
)


def gen_storage(settings: GlobalSettings) -> Storage:
return Storage()
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
return Storage(engine=engine)
2 changes: 1 addition & 1 deletion compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
container_name: postgres_container
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: birdxplorer
POSTGRES_PASSWORD: ${BX_STORAGE_SETTINGS__PASSWORD}
POSTGRES_DB: postgres
ports:
- '5432:5432'
Expand Down
Loading

0 comments on commit b416793

Please sign in to comment.