From b416793f936c1a3fbfbbaf161d080f6717b8b71a Mon Sep 17 00:00:00 2001 From: osoken Date: Mon, 26 Feb 2024 03:09:23 +0900 Subject: [PATCH] feat(app): add topics route --- birdxplorer/models.py | 51 +++++++- birdxplorer/settings.py | 4 +- birdxplorer/storage.py | 77 ++++++++++- compose.yml | 2 +- tests/conftest.py | 269 +++++++++++++++++++++++++++++++++++++- tests/routers/conftest.py | 88 ------------- tests/test_app.py | 10 +- tests/test_data_model.py | 4 +- tests/test_storage.py | 18 +++ 9 files changed, 409 insertions(+), 114 deletions(-) delete mode 100644 tests/routers/conftest.py create mode 100644 tests/test_storage.py diff --git a/birdxplorer/models.py b/birdxplorer/models.py index 319bc93..2844b01 100644 --- a/birdxplorer/models.py +++ b/birdxplorer/models.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from enum import Enum -from typing import Any, Dict, Literal, Type, TypeAlias, TypeVar, Union +from typing import Any, Dict, List, Literal, Type, TypeAlias, TypeVar, Union from pydantic import BaseModel as PydanticBaseModel -from pydantic import ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter +from pydantic import ConfigDict, GetCoreSchemaHandler, TypeAdapter from pydantic.alias_generators import to_camel from pydantic_core import core_schema @@ -128,6 +128,23 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]: return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$") +class NineToNineteenDigitsDecimalString(BaseString): + """ + >>> NineToNineteenDigitsDecimalString.from_str("test") + Traceback (most recent call last): + ... + pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] + String should match pattern '^[0-9]{9,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str] + ... + >>> NineToNineteenDigitsDecimalString.from_str("1234567890123456789") + NineToNineteenDigitsDecimalString('1234567890123456789') + """ + + @classmethod + def __get_extra_constraint_dict__(cls) -> dict[str, Any]: + return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{9,19}$") + + class NonEmptyStringMixin(BaseString): @classmethod def __get_extra_constraint_dict__(cls) -> dict[str, Any]: @@ -544,11 +561,18 @@ class NotesValidationDifficulty(str, Enum): empty = "" -class Note(BaseModel): +class TweetId(NineToNineteenDigitsDecimalString): ... + + +class NoteData(BaseModel): + """ + This is for validating the original data from notes.csv. + """ + note_id: NoteId note_author_participant_id: ParticipantId created_at_millis: TwitterTimestamp - tweet_id: str = Field(pattern=r"^[0-9]{9,19}$") + tweet_id: TweetId believable: NotesBelievable misleading_other: BinaryBool misleading_factual_error: BinaryBool @@ -585,7 +609,24 @@ class LanguageIdentifier(str, Enum): class TopicLabelString(NonEmptyTrimmedString): ... +TopicLabel: TypeAlias = Dict[LanguageIdentifier, TopicLabelString] + + class Topic(BaseModel): + model_config = ConfigDict(from_attributes=True) + topic_id: TopicId - label: Dict[LanguageIdentifier, TopicLabelString] + label: TopicLabel reference_count: NonNegativeInt + + +class SummaryString(NonEmptyTrimmedString): ... + + +class Note(BaseModel): + note_id: NoteId + post_id: TweetId + language: LanguageIdentifier + topics: List[Topic] + summary: SummaryString + created_at: TwitterTimestamp diff --git a/birdxplorer/settings.py b/birdxplorer/settings.py index ad71cd0..701a890 100644 --- a/birdxplorer/settings.py +++ b/birdxplorer/settings.py @@ -20,11 +20,11 @@ class PostgresStorageSettings(BaseSettings): @computed_field # type: ignore[misc] @property - def sqlalchemy_database_url(self) -> PostgresDsn: + def sqlalchemy_database_url(self) -> str: return PostgresDsn( url=f"postgresql://{self.username}:" f"{self.password.replace('@', '%40')}@{self.host}:{self.port}/{self.database}" - ) + ).unicode_string() class GlobalSettings(BaseSettings): diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index e7fa7c9..42cd536 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -1,16 +1,83 @@ -from typing import Generator +from typing import Generator, List -from .models import ParticipantId, Topic, UserEnrollment +from sqlalchemy import ForeignKey, create_engine, func, select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship +from sqlalchemy.types import DECIMAL, JSON, Integer, String + +from .models import LanguageIdentifier, NoteId, ParticipantId, SummaryString +from .models import Topic as TopicModel +from .models import TopicId, TopicLabel, TweetId, TwitterTimestamp, UserEnrollment from .settings import GlobalSettings +class Base(DeclarativeBase): + type_annotation_map = { + TopicId: Integer, + TopicLabel: JSON, + NoteId: String, + ParticipantId: String, + TweetId: String, + LanguageIdentifier: String, + TwitterTimestamp: DECIMAL, + SummaryString: String, + } + + +class NoteTopicAssociation(Base): + __tablename__ = "note_topic" + + note_id: Mapped[NoteId] = mapped_column(ForeignKey("notes.note_id"), primary_key=True) + topic_id: Mapped[TopicId] = mapped_column(ForeignKey("topics.topic_id"), primary_key=True) + topic: Mapped["TopicRecord"] = relationship() + + +class NoteRecord(Base): + __tablename__ = "notes" + + note_id: Mapped[NoteId] = mapped_column(primary_key=True) + post_id: Mapped[TweetId] = mapped_column(nullable=False) + topics: Mapped[List[NoteTopicAssociation]] = relationship() + language: Mapped[LanguageIdentifier] = mapped_column(nullable=False) + summary: Mapped[SummaryString] = mapped_column(nullable=False) + created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False) + + +class TopicRecord(Base): + __tablename__ = "topics" + + topic_id: Mapped[TopicId] = mapped_column(primary_key=True) + label: Mapped[TopicLabel] = mapped_column(nullable=False) + + class Storage: + def __init__(self, engine: Engine) -> None: + self._engine = engine + + @property + def engine(self) -> Engine: + return self._engine + def get_user_enrollment_by_participant_id(self, participant_id: ParticipantId) -> UserEnrollment: raise NotImplementedError - def get_topics(self) -> Generator[Topic, None, None]: - raise NotImplementedError + def get_topics(self) -> Generator[TopicModel, None, None]: + with Session(self.engine) as sess: + subq = ( + select(NoteTopicAssociation.topic_id, func.count().label("reference_count")) + .group_by(NoteTopicAssociation.topic_id) + .subquery() + ) + for topic_record, reference_count in ( + sess.query(TopicRecord, subq.c.reference_count) + .outerjoin(subq, TopicRecord.topic_id == subq.c.topic_id) + .all() + ): + yield TopicModel( + topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0 + ) def gen_storage(settings: GlobalSettings) -> Storage: - return Storage() + engine = create_engine(settings.storage_settings.sqlalchemy_database_url) + return Storage(engine=engine) diff --git a/compose.yml b/compose.yml index 77e110d..5d3bd88 100644 --- a/compose.yml +++ b/compose.yml @@ -6,7 +6,7 @@ services: container_name: postgres_container environment: POSTGRES_USER: postgres - POSTGRES_PASSWORD: birdxplorer + POSTGRES_PASSWORD: ${BX_STORAGE_SETTINGS__PASSWORD} POSTGRES_DB: postgres ports: - '5432:5432' diff --git a/tests/conftest.py b/tests/conftest.py index a1bb82f..f13eeeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,276 @@ +import os +import random from collections.abc import Generator -from typing import Type +from typing import List, Type +from unittest.mock import MagicMock, patch +from dotenv import load_dotenv +from fastapi.testclient import TestClient +from polyfactory import Use from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture from pytest import fixture +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from sqlalchemy.sql import text -from birdxplorer.settings import GlobalSettings +from birdxplorer.exceptions import UserEnrollmentNotFoundError +from birdxplorer.models import ( + Note, + ParticipantId, + Topic, + TwitterTimestamp, + UserEnrollment, +) +from birdxplorer.settings import GlobalSettings, PostgresStorageSettings +from birdxplorer.storage import ( + Base, + NoteRecord, + NoteTopicAssociation, + Storage, + TopicRecord, +) + + +def gen_random_twitter_timestamp() -> int: + return random.randint(TwitterTimestamp.min_value(), TwitterTimestamp.max_value()) + + +@fixture +def load_dotenv_fixture() -> None: + load_dotenv() + + +@fixture +def postgres_storage_settings_factory( + load_dotenv_fixture: None, +) -> Type[ModelFactory[PostgresStorageSettings]]: + class PostgresStorageSettingsFactory(ModelFactory[PostgresStorageSettings]): + __model__ = PostgresStorageSettings + + host = "localhost" + username = "postgres" + port = 5432 + database = "postgres" + password = os.environ["BX_STORAGE_SETTINGS__PASSWORD"] + + return PostgresStorageSettingsFactory @fixture -def global_settings_factory_class() -> Generator[Type[ModelFactory[GlobalSettings]], None, None]: +def global_settings_factory( + postgres_storage_settings_factory: Type[ModelFactory[PostgresStorageSettings]], +) -> Type[ModelFactory[GlobalSettings]]: class GlobalSettingsFactory(ModelFactory[GlobalSettings]): __model__ = GlobalSettings - yield GlobalSettingsFactory + storage_settings = postgres_storage_settings_factory.build() + + return GlobalSettingsFactory + + +@register_fixture(name="user_enrollment_factory") +class UserEnrollmentFactory(ModelFactory[UserEnrollment]): + __model__ = UserEnrollment + + participant_id = Use(lambda: "".join(random.choices("0123456789ABCDEF", k=64))) + timestamp_of_last_state_change = Use(gen_random_twitter_timestamp) + timestamp_of_last_earn_out = Use(gen_random_twitter_timestamp) + + +@register_fixture(name="note_factory") +class NoteFactory(ModelFactory[Note]): + __model__ = Note + + +@register_fixture(name="topic_factory") +class TopicFactory(ModelFactory[Topic]): + __model__ = Topic + + +@fixture +def user_enrollment_samples( + user_enrollment_factory: UserEnrollmentFactory, +) -> Generator[List[UserEnrollment], None, None]: + yield [user_enrollment_factory.build() for _ in range(3)] + + +@fixture +def mock_storage( + user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic] +) -> Generator[MagicMock, None, None]: + mock = MagicMock(spec=Storage) + + def _get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> UserEnrollment: + x = {d.participant_id: d for d in user_enrollment_samples}.get(participant_id) + if x is None: + raise UserEnrollmentNotFoundError(participant_id=participant_id) + return x + + mock.get_user_enrollment_by_participant_id.side_effect = _get_user_enrollment_by_participant_id + + def _get_topics() -> Generator[Topic, None, None]: + yield from topic_samples + + mock.get_topics.side_effect = _get_topics + + yield mock + + +@fixture +def client(settings_for_test: GlobalSettings, mock_storage: MagicMock) -> Generator[TestClient, None, None]: + from birdxplorer.app import gen_app + + with patch("birdxplorer.app.gen_storage", return_value=mock_storage): + app = gen_app(settings=settings_for_test) + yield TestClient(app) + + +@fixture +def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: + topics = [ + topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=3), + topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=2), + topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1), + topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=0), + ] + yield topics + + +@fixture +def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]: + notes = [ + note_factory.build( + note_id="1234567890123456781", + post_id="2234567890123456781", + topics=[topic_samples[0]], + language="ja", + summary="要約文1", + created_at=1152921600000, + ), + note_factory.build( + note_id="1234567890123456782", + post_id="2234567890123456782", + topics=[], + language="en", + summary="summary2", + created_at=1152921601000, + ), + note_factory.build( + note_id="1234567890123456783", + post_id="2234567890123456783", + topics=[topic_samples[1]], + language="en", + summary="summary3", + created_at=1152921602000, + ), + note_factory.build( + note_id="1234567890123456784", + post_id="2234567890123456784", + topics=[topic_samples[0], topic_samples[1], topic_samples[2]], + language="en", + summary="summary4", + created_at=1152921603000, + ), + note_factory.build( + note_id="1234567890123456785", + post_id="2234567890123456785", + topics=[topic_samples[0]], + language="en", + summary="summary5", + created_at=1152921604000, + ), + ] + yield notes + + +TEST_DATABASE_NAME = "bx_test" + + +@fixture +def default_settings( + global_settings_factory: Type[ModelFactory[GlobalSettings]], +) -> Generator[GlobalSettings, None, None]: + yield global_settings_factory.build() + + +@fixture +def settings_for_test( + global_settings_factory: Type[ModelFactory[GlobalSettings]], + postgres_storage_settings_factory: Type[ModelFactory[PostgresStorageSettings]], +) -> Generator[GlobalSettings, None, None]: + yield global_settings_factory.build( + storage_settings=postgres_storage_settings_factory.build(database=TEST_DATABASE_NAME) + ) + + +@fixture +def engine_for_test( + default_settings: GlobalSettings, settings_for_test: GlobalSettings +) -> Generator[Engine, None, None]: + default_engine = create_engine(default_settings.storage_settings.sqlalchemy_database_url) + with default_engine.connect() as conn: + conn.execute(text("COMMIT")) + try: + conn.execute(text(f"DROP DATABASE {TEST_DATABASE_NAME}")) + except SQLAlchemyError: + pass + + with default_engine.connect() as conn: + conn.execute(text("COMMIT")) + conn.execute(text(f"CREATE DATABASE {TEST_DATABASE_NAME}")) + + engine = create_engine(settings_for_test.storage_settings.sqlalchemy_database_url) + + Base.metadata.create_all(engine) + + yield engine + + engine.dispose() + del engine + + with default_engine.connect() as conn: + conn.execute(text("COMMIT")) + conn.execute(text(f"DROP DATABASE {TEST_DATABASE_NAME}")) + + default_engine.dispose() + + +@fixture +def topic_records_sample( + engine_for_test: Engine, + topic_samples: List[TopicRecord], +) -> Generator[List[TopicRecord], None, None]: + res = [TopicRecord(topic_id=d.topic_id, label=d.label) for d in topic_samples] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + +@fixture +def note_records_sample( + note_samples: List[NoteRecord], + topic_records_sample: List[TopicRecord], + engine_for_test: Engine, +) -> Generator[List[NoteRecord], None, None]: + res: List[NoteRecord] = [] + with Session(engine_for_test) as sess: + for note in note_samples: + inst = NoteRecord( + note_id=note.note_id, + post_id=note.post_id, + language=note.language, + summary=note.summary, + created_at=note.created_at, + ) + sess.add(inst) + for topic in note.topics: + assoc = NoteTopicAssociation(topic_id=topic.topic_id, note_id=inst.note_id) + sess.add(assoc) + inst.topics.append(assoc) + res.append(inst) + sess.commit() + yield res diff --git a/tests/routers/conftest.py b/tests/routers/conftest.py deleted file mode 100644 index 1aff30a..0000000 --- a/tests/routers/conftest.py +++ /dev/null @@ -1,88 +0,0 @@ -import random -from collections.abc import Generator -from typing import List, Type -from unittest.mock import MagicMock, patch - -from fastapi.testclient import TestClient -from polyfactory import Use -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.pytest_plugin import register_fixture -from pytest import fixture - -from birdxplorer.exceptions import UserEnrollmentNotFoundError -from birdxplorer.models import ParticipantId, Topic, TwitterTimestamp, UserEnrollment -from birdxplorer.settings import GlobalSettings -from birdxplorer.storage import Storage - - -def gen_random_twitter_timestamp() -> int: - return random.randint(TwitterTimestamp.min_value(), TwitterTimestamp.max_value()) - - -@fixture -def settings_for_test( - global_settings_factory_class: Type[ModelFactory[GlobalSettings]], -) -> Generator[GlobalSettings, None, None]: - yield global_settings_factory_class.build() - - -@register_fixture(name="user_enrollment_factory") -class UserEnrollmentFactory(ModelFactory[UserEnrollment]): - __model__ = UserEnrollment - - participant_id = Use(lambda: "".join(random.choices("0123456789ABCDEF", k=64))) - timestamp_of_last_state_change = Use(gen_random_twitter_timestamp) - timestamp_of_last_earn_out = Use(gen_random_twitter_timestamp) - - -@register_fixture(name="topic_factory") -class TopicFactory(ModelFactory[Topic]): - __model__ = Topic - - -@fixture -def user_enrollment_samples( - user_enrollment_factory: UserEnrollmentFactory, -) -> Generator[List[UserEnrollment], None, None]: - yield [user_enrollment_factory.build() for _ in range(3)] - - -@fixture -def mock_storage( - user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic] -) -> Generator[MagicMock, None, None]: - mock = MagicMock(spec=Storage) - - def _get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> UserEnrollment: - x = {d.participant_id: d for d in user_enrollment_samples}.get(participant_id) - if x is None: - raise UserEnrollmentNotFoundError(participant_id=participant_id) - return x - - mock.get_user_enrollment_by_participant_id.side_effect = _get_user_enrollment_by_participant_id - - def _get_topics() -> Generator[Topic, None, None]: - yield from topic_samples - - mock.get_topics.side_effect = _get_topics - - yield mock - - -@fixture -def client(settings_for_test: GlobalSettings, mock_storage: MagicMock) -> Generator[TestClient, None, None]: - from birdxplorer.app import gen_app - - with patch("birdxplorer.app.gen_storage", return_value=mock_storage): - app = gen_app(settings=settings_for_test) - yield TestClient(app) - - -@fixture -def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: - topics = [ - topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=12341), - topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1232312342), - topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=3), - ] - yield topics diff --git a/tests/test_app.py b/tests/test_app.py index 47044de..e87e423 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,19 +1,15 @@ -from typing import Type - -from polyfactory.factories.pydantic_factory import ModelFactory from pytest_mock import MockerFixture from birdxplorer.app import gen_app from birdxplorer.settings import GlobalSettings -def test_gen_app(mocker: MockerFixture, global_settings_factory_class: Type[ModelFactory[GlobalSettings]]) -> None: +def test_gen_app(mocker: MockerFixture, default_settings: GlobalSettings) -> None: FastAPI = mocker.patch("birdxplorer.app.FastAPI") - settings = global_settings_factory_class.build() get_logger = mocker.patch("birdxplorer.app.get_logger") expected = FastAPI.return_value - actual = gen_app(settings=settings) + actual = gen_app(settings=default_settings) assert actual == expected - get_logger.assert_called_once_with(level=settings.logger_settings.level) + get_logger.assert_called_once_with(level=default_settings.logger_settings.level) diff --git a/tests/test_data_model.py b/tests/test_data_model.py index afa57a0..a57e765 100644 --- a/tests/test_data_model.py +++ b/tests/test_data_model.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping -from birdxplorer.models import Note, UserEnrollment +from birdxplorer.models import NoteData, UserEnrollment class BaseDataModelTester(ABC): @@ -52,7 +52,7 @@ def __init__(self) -> None: super(NoteTester, self).__init__(re.compile(r"notes-[0-9]{5}.tsv")) def validate(self, row: Mapping[str, str]) -> None: - _ = Note.model_validate(row) + _ = NoteData.model_validate(row) def test_notes() -> None: diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..f3b380d --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,18 @@ +from typing import List + +from sqlalchemy.engine import Engine + +from birdxplorer.models import Topic +from birdxplorer.storage import NoteRecord, Storage, TopicRecord + + +def test_get_topic_list( + engine_for_test: Engine, + topic_samples: List[Topic], + topic_records_sample: List[TopicRecord], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + expected = sorted(topic_samples, key=lambda x: x.topic_id) + actual = sorted(storage.get_topics(), key=lambda x: x.topic_id) + assert expected == actual