Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/issue 40 add topics endpoint #46

Merged
merged 13 commits into from
Mar 11, 2024
10 changes: 10 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@ permissions:
jobs:
test-check:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:15.4
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: birdxplorer
POSTGRES_DB: postgres
ports:
- 5432:5432

timeout-minutes: 5
steps:
- name: Checkout
Expand Down
95 changes: 91 additions & 4 deletions birdxplorer/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Literal, Type, TypeAlias, TypeVar, Union
from typing import Any, Dict, List, Literal, Type, TypeAlias, TypeVar, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter
from pydantic import ConfigDict, GetCoreSchemaHandler, TypeAdapter
from pydantic.alias_generators import to_camel
from pydantic_core import core_schema

Expand Down Expand Up @@ -128,6 +128,48 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$")


class NineToNineteenDigitsDecimalString(BaseString):
    """
    >>> NineToNineteenDigitsDecimalString.from_str("test")
    Traceback (most recent call last):
    ...
    pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
    String should match pattern '^[0-9]{9,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
    ...
    >>> NineToNineteenDigitsDecimalString.from_str("1234567890123456789")
    NineToNineteenDigitsDecimalString('1234567890123456789')
    """

    @classmethod
    def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
        # Restrict to 9-19 decimal digits on top of the base constraints.
        constraints = dict(super().__get_extra_constraint_dict__())
        constraints["pattern"] = r"^[0-9]{9,19}$"
        return constraints


class NonEmptyStringMixin(BaseString):
    """Mixin adding a minimum-length-1 constraint (rejects empty strings)."""

    @classmethod
    def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
        return {**super().__get_extra_constraint_dict__(), "min_length": 1}


class TrimmedStringMixin(BaseString):
    """Mixin that strips surrounding whitespace before validation."""

    @classmethod
    def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
        return {**super().__get_extra_constraint_dict__(), "strip_whitespace": True}


class NonEmptyTrimmedString(TrimmedStringMixin, NonEmptyStringMixin):
    """
    A string that is stripped of surrounding whitespace and must be non-empty.

    >>> NonEmptyTrimmedString.from_str("test")
    NonEmptyTrimmedString('test')
    >>> NonEmptyTrimmedString.from_str("")
    Traceback (most recent call last):
    ...
    pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
    String should have at least 1 character [type=string_too_short, input_value='', input_type=str]
    ...
    """


class BaseInt(int):
"""
>>> BaseInt(1)
Expand Down Expand Up @@ -519,11 +561,18 @@ class NotesValidationDifficulty(str, Enum):
empty = ""


class Note(BaseModel):
class TweetId(NineToNineteenDigitsDecimalString): ...


class NoteData(BaseModel):
"""
This is for validating the original data from notes.csv.
"""

note_id: NoteId
note_author_participant_id: ParticipantId
created_at_millis: TwitterTimestamp
tweet_id: str = Field(pattern=r"^[0-9]{9,19}$")
tweet_id: TweetId
believable: NotesBelievable
misleading_other: BinaryBool
misleading_factual_error: BinaryBool
Expand All @@ -543,3 +592,41 @@ class Note(BaseModel):
harmful: NotesHarmful
validation_difficulty: NotesValidationDifficulty
summary: str


class TopicId(NonNegativeInt): ...


class LanguageIdentifier(str, Enum):
    """Supported language codes (ISO 639-1 two-letter codes)."""

    EN = "en"
    ES = "es"
    JA = "ja"
    PT = "pt"
    DE = "de"
    FR = "fr"


# Human-readable label of a topic in a single language.
class TopicLabelString(NonEmptyTrimmedString): ...


# Multilingual topic label: maps a language code to the label in that language.
TopicLabel: TypeAlias = Dict[LanguageIdentifier, TopicLabelString]


class Topic(BaseModel):
    """A note topic with multilingual labels and the number of notes referencing it."""

    # from_attributes lets pydantic build instances directly from ORM rows.
    model_config = ConfigDict(from_attributes=True)

    topic_id: TopicId
    label: TopicLabel
    reference_count: NonNegativeInt  # how many notes reference this topic


class SummaryString(NonEmptyTrimmedString): ...


class Note(BaseModel):
    """A community note together with the topics assigned to it."""

    note_id: NoteId
    post_id: TweetId  # the tweet/post the note is attached to
    language: LanguageIdentifier
    topics: List[Topic]
    summary: SummaryString
    created_at: TwitterTimestamp
12 changes: 11 additions & 1 deletion birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import List

from fastapi import APIRouter

from ..models import ParticipantId, UserEnrollment
from ..models import BaseModel, ParticipantId, Topic, UserEnrollment
from ..storage import Storage


class TopicListResponse(BaseModel):
    """Response envelope for the topics endpoint: a list of topics under ``data``."""

    data: List[Topic]


def gen_router(storage: Storage) -> APIRouter:
router = APIRouter()

Expand All @@ -14,4 +20,8 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User
raise ValueError(f"participant_id={participant_id} not found")
return res

@router.get("/topics", response_model=TopicListResponse)
def get_topics() -> TopicListResponse:
    # Drain the storage generator into a concrete list for serialization.
    return TopicListResponse(data=list(storage.get_topics()))

return router
4 changes: 2 additions & 2 deletions birdxplorer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ class PostgresStorageSettings(BaseSettings):

@computed_field  # type: ignore[misc]
@property
def sqlalchemy_database_url(self) -> str:
    """Build the PostgreSQL connection URL from the individual settings.

    Returns the URL as a plain string (what SQLAlchemy's ``create_engine``
    expects) after validating it through pydantic's ``PostgresDsn``.
    Credentials are percent-encoded so reserved URL characters (not just
    ``@``, which the previous version handled) cannot corrupt the URL.
    """
    from urllib.parse import quote  # stdlib; local import keeps the change self-contained

    return PostgresDsn(
        url=f"postgresql://{quote(self.username, safe='')}:"
        f"{quote(self.password, safe='')}@{self.host}:{self.port}/{self.database}"
    ).unicode_string()


class GlobalSettings(BaseSettings):
Expand Down
76 changes: 74 additions & 2 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,83 @@
from .models import ParticipantId, UserEnrollment
from typing import Generator, List

from sqlalchemy import ForeignKey, create_engine, func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import DECIMAL, JSON, Integer, String

from .models import LanguageIdentifier, NoteId, ParticipantId, SummaryString
from .models import Topic as TopicModel
from .models import TopicId, TopicLabel, TweetId, TwitterTimestamp, UserEnrollment
from .settings import GlobalSettings


class Base(DeclarativeBase):
    """Declarative base mapping the project's domain types to SQLAlchemy column types."""

    # Lets mapped_column() infer the column type from the Mapped[...] annotation.
    type_annotation_map = {
        TopicId: Integer,
        TopicLabel: JSON,
        NoteId: String,
        ParticipantId: String,
        TweetId: String,
        LanguageIdentifier: String,
        TwitterTimestamp: DECIMAL,
        SummaryString: String,
    }


class NoteTopicAssociation(Base):
    """Association table for the many-to-many relation between notes and topics."""

    __tablename__ = "note_topic"

    # Composite primary key: one row per (note, topic) pair.
    note_id: Mapped[NoteId] = mapped_column(ForeignKey("notes.note_id"), primary_key=True)
    topic_id: Mapped[TopicId] = mapped_column(ForeignKey("topics.topic_id"), primary_key=True)
    topic: Mapped["TopicRecord"] = relationship()


class NoteRecord(Base):
    """ORM row for a community note."""

    __tablename__ = "notes"

    note_id: Mapped[NoteId] = mapped_column(primary_key=True)
    post_id: Mapped[TweetId] = mapped_column(nullable=False)
    # Topics attached to this note via the note_topic association table.
    topics: Mapped[List[NoteTopicAssociation]] = relationship()
    language: Mapped[LanguageIdentifier] = mapped_column(nullable=False)
    summary: Mapped[SummaryString] = mapped_column(nullable=False)
    created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False)


class TopicRecord(Base):
    """ORM row for a topic; the label column holds the multilingual JSON mapping."""

    __tablename__ = "topics"

    topic_id: Mapped[TopicId] = mapped_column(primary_key=True)
    label: Mapped[TopicLabel] = mapped_column(nullable=False)


class Storage:
    """Data-access layer backed by a SQLAlchemy engine."""

    def __init__(self, engine: Engine) -> None:
        self._engine = engine

    @property
    def engine(self) -> Engine:
        # The underlying SQLAlchemy engine (read-only accessor).
        return self._engine

    def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment:
        # Not yet implemented against the database.
        raise NotImplementedError

    def get_topics(self) -> Generator[TopicModel, None, None]:
        """Yield every topic together with the number of notes that reference it.

        Topics with no associated notes are still yielded, with a
        reference_count of 0.
        """
        with Session(self.engine) as sess:
            # Per-topic count of note-topic association rows.
            subq = (
                select(NoteTopicAssociation.topic_id, func.count().label("reference_count"))
                .group_by(NoteTopicAssociation.topic_id)
                .subquery()
            )
            # LEFT OUTER JOIN keeps topics without notes; their count is NULL.
            for topic_record, reference_count in (
                sess.query(TopicRecord, subq.c.reference_count)
                .outerjoin(subq, TopicRecord.topic_id == subq.c.topic_id)
                .all()
            ):
                yield TopicModel(
                    topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0
                )


def gen_storage(settings: GlobalSettings) -> Storage:
    """Construct a Storage connected to the database configured in *settings*.

    Fixes the leftover ``return Storage()`` line, which called the
    constructor without its required ``engine`` argument and made the
    lines after it unreachable.
    """
    engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
    return Storage(engine=engine)
2 changes: 1 addition & 1 deletion compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
container_name: postgres_container
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: birdxplorer
POSTGRES_PASSWORD: ${BX_STORAGE_SETTINGS__PASSWORD}
POSTGRES_DB: postgres
ports:
- '5432:5432'
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"pydantic_settings",
"fastapi",
"JSON-log-formatter",
"openai"
"openai",
]

[project.urls]
Expand Down
54 changes: 54 additions & 0 deletions scripts/migrations/convert_data_from_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Convert a v1 notes CSV into normalized notes / topics / note_topic CSV files."""
import csv
import json
import os
from argparse import ArgumentParser
from collections import Counter

if __name__ == "__main__":
    parser = ArgumentParser(description="Convert v1 notes data into notes, topics and note_topic CSV files.")
    parser.add_argument("notes_file")
    parser.add_argument("output_dir")
    parser.add_argument("--notes-file-name", default="notes.csv")
    parser.add_argument("--topics-file-name", default="topics.csv")
    parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv")
    parser.add_argument("--topic-threshold", type=int, default=5)

    args = parser.parse_args()

    with open(args.notes_file, "r", encoding="utf-8") as fin:
        notes = list(csv.DictReader(fin))
    # The v1 "topic" column is a comma-separated list; split it per note.
    for d in notes:
        d["topic"] = [t.strip() for t in d["topic"].split(",")]
    # Keep only topics occurring strictly more often than the threshold,
    # assigning ids in first-seen order.
    topics_with_count = Counter(t for d in notes for t in d["topic"])
    topic_name_to_id_map = {t: i for i, (t, c) in enumerate(topics_with_count.items()) if c > args.topic_threshold}

    # exist_ok avoids the check-then-create race of the previous exists()/makedirs() pair.
    os.makedirs(args.output_dir, exist_ok=True)

    # newline="" on all csv outputs, as required by the csv module docs
    # (prevents spurious blank rows on Windows).
    with open(os.path.join(args.output_dir, args.topics_file_name), "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=["topic_id", "label"])
        writer.writeheader()
        for topic, topic_id in topic_name_to_id_map.items():
            # Labels are stored as a JSON object keyed by language code.
            writer.writerow({"topic_id": topic_id, "label": json.dumps({"ja": topic})})

    with open(os.path.join(args.output_dir, args.notes_file_name), "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=["note_id", "post_id", "language", "summary", "created_at"])
        writer.writeheader()
        for d in notes:
            writer.writerow(
                {
                    "note_id": d["note_id"],
                    "post_id": d["post_id"],
                    "language": d["language"],
                    "summary": d["summary"],
                    "created_at": d["created_at"],
                }
            )

    with open(os.path.join(args.output_dir, args.notes_topics_association_file_name), "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=["note_id", "topic_id"])
        writer.writeheader()
        for d in notes:
            # Topics below the threshold were dropped; skip their associations.
            for t in d["topic"]:
                if t in topic_name_to_id_map:
                    writer.writerow({"note_id": d["note_id"], "topic_id": topic_name_to_id_map[t]})
Loading
Loading