From 70f394778ebb0b9914d6fa084b294f5af000fa7f Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 24 Feb 2024 14:18:54 +0900 Subject: [PATCH] feat(topics): add topics list route (sample data) --- birdxplorer/models.py | 48 ++++++++++++++++++++++++++++++++++++- birdxplorer/routers/data.py | 18 +++++++++++++- pyproject.toml | 3 ++- tests/routers/conftest.py | 17 ++++++++++++- tests/routers/test_data.py | 9 ++++++- 5 files changed, 90 insertions(+), 5 deletions(-) diff --git a/birdxplorer/models.py b/birdxplorer/models.py index 3dbc6b9..319bc93 100644 --- a/birdxplorer/models.py +++ b/birdxplorer/models.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import datetime, timezone from enum import Enum -from typing import Any, Literal, Type, TypeAlias, TypeVar, Union +from typing import Any, Dict, Literal, Type, TypeAlias, TypeVar, Union from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter @@ -128,6 +128,31 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]: return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$") +class NonEmptyStringMixin(BaseString): + @classmethod + def __get_extra_constraint_dict__(cls) -> dict[str, Any]: + return dict(super().__get_extra_constraint_dict__(), min_length=1) + + +class TrimmedStringMixin(BaseString): + @classmethod + def __get_extra_constraint_dict__(cls) -> dict[str, Any]: + return dict(super().__get_extra_constraint_dict__(), strip_whitespace=True) + + +class NonEmptyTrimmedString(TrimmedStringMixin, NonEmptyStringMixin): + """ + >>> NonEmptyTrimmedString.from_str("test") + NonEmptyTrimmedString('test') + >>> NonEmptyTrimmedString.from_str("") + Traceback (most recent call last): + ... + pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] + String should have at least 1 character [type=string_too_short, input_value='', input_type=str] + ... + """ + + class BaseInt(int): """ >>> BaseInt(1) @@ -543,3 +568,24 @@ class Note(BaseModel): harmful: NotesHarmful validation_difficulty: NotesValidationDifficulty summary: str + + +class TopicId(NonNegativeInt): ... + + +class LanguageIdentifier(str, Enum): + EN = "en" + ES = "es" + JA = "ja" + PT = "pt" + DE = "de" + FR = "fr" + + +class TopicLabelString(NonEmptyTrimmedString): ... + + +class Topic(BaseModel): + topic_id: TopicId + label: Dict[LanguageIdentifier, TopicLabelString] + reference_count: NonNegativeInt diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index f61a009..6cd76f5 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -1,9 +1,15 @@ +from typing import List + from fastapi import APIRouter -from ..models import ParticipantId, UserEnrollment +from ..models import BaseModel, ParticipantId, Topic, UserEnrollment from ..storage import Storage +class TopicListResponse(BaseModel): + data: List[Topic] + + def gen_router(storage: Storage) -> APIRouter: router = APIRouter() @@ -14,4 +20,14 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User raise ValueError(f"participant_id={participant_id} not found") return res + @router.get("/topics", response_model=TopicListResponse) + def get_topics() -> TopicListResponse: + return TopicListResponse( + data=[ + Topic(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=12341), + Topic(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1232312342), + Topic(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=3), + ] + ) + return router diff --git a/pyproject.toml b/pyproject.toml index af0a400..1a50959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,8 @@ dependencies = [ "pydantic_settings", "fastapi", "JSON-log-formatter", - "openai" + "openai", + "typing-extensions", ] [project.urls] diff --git a/tests/routers/conftest.py b/tests/routers/conftest.py index 1e1874f..d5e15d3 100644 --- a/tests/routers/conftest.py +++ b/tests/routers/conftest.py @@ -10,7 +10,7 @@ from pytest import fixture from birdxplorer.exceptions import UserEnrollmentNotFoundError -from birdxplorer.models import ParticipantId, TwitterTimestamp, UserEnrollment +from birdxplorer.models import ParticipantId, Topic, TwitterTimestamp, UserEnrollment from birdxplorer.settings import GlobalSettings from birdxplorer.storage import Storage @@ -35,6 +35,11 @@ class UserEnrollmentFactory(ModelFactory[UserEnrollment]): timestamp_of_last_earn_out = Use(gen_random_twitter_timestamp) +@register_fixture(name="topic_factory") +class TopicFactory(ModelFactory[Topic]): + __model__ = Topic + + @fixture def user_enrollment_samples( user_enrollment_factory: UserEnrollmentFactory, @@ -63,3 +68,13 @@ def client(settings_for_test: GlobalSettings, mock_storage: MagicMock) -> Genera with patch("birdxplorer.app.gen_storage", return_value=mock_storage): app = gen_app(settings=settings_for_test) yield TestClient(app) + + +@fixture +def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: + topics = [ + topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=12341), + topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1232312342), + topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=3), + ] + yield topics diff --git a/tests/routers/test_data.py b/tests/routers/test_data.py index 53f34ad..0997cd1 100644 --- a/tests/routers/test_data.py +++ b/tests/routers/test_data.py @@ -2,7 +2,7 @@ from fastapi.testclient import TestClient -from birdxplorer.models import UserEnrollment +from birdxplorer.models import Topic, UserEnrollment def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None: @@ -10,3 +10,10 @@ def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[ assert response.status_code == 200 res_json = response.json() assert res_json["participantId"] == user_enrollment_samples[0].participant_id + + +def test_topics_get(client: TestClient, topic_samples: List[Topic]) -> None: + response = client.get("/api/v1/data/topics") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [d.model_dump(by_alias=True) for d in topic_samples]}