Skip to content

Commit

Permalink
feat(topics): add topics list route (sample data)
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Feb 24, 2024
1 parent 73262ef commit 70f3947
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 5 deletions.
48 changes: 47 additions & 1 deletion birdxplorer/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Literal, Type, TypeAlias, TypeVar, Union
from typing import Any, Dict, Literal, Type, TypeAlias, TypeVar, Union

from pydantic import BaseModel as PydanticBaseModel
from pydantic import ConfigDict, Field, GetCoreSchemaHandler, TypeAdapter
Expand Down Expand Up @@ -128,6 +128,31 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$")


class NonEmptyStringMixin(BaseString):
@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), min_length=1)


class TrimmedStringMixin(BaseString):
@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), strip_whitespace=True)


class NonEmptyTrimmedString(TrimmedStringMixin, NonEmptyStringMixin):
"""
>>> NonEmptyTrimmedString.from_str("test")
NonEmptyTrimmedString('test')
>>> NonEmptyTrimmedString.from_str("")
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
String should have at least 1 character [type=string_too_short, input_value='', input_type=str]
...
"""


class BaseInt(int):
"""
>>> BaseInt(1)
Expand Down Expand Up @@ -543,3 +568,24 @@ class Note(BaseModel):
harmful: NotesHarmful
validation_difficulty: NotesValidationDifficulty
summary: str


class TopicId(NonNegativeInt): ...


class LanguageIdentifier(str, Enum):
EN = "en"
ES = "es"
JA = "ja"
PT = "pt"
DE = "de"
FR = "fr"


class TopicLabelString(NonEmptyTrimmedString): ...


class Topic(BaseModel):
topic_id: TopicId
label: Dict[LanguageIdentifier, TopicLabelString]
reference_count: NonNegativeInt
18 changes: 17 additions & 1 deletion birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import List

from fastapi import APIRouter

from ..models import ParticipantId, UserEnrollment
from ..models import BaseModel, ParticipantId, Topic, UserEnrollment
from ..storage import Storage


class TopicListResponse(BaseModel):
data: List[Topic]


def gen_router(storage: Storage) -> APIRouter:
router = APIRouter()

Expand All @@ -14,4 +20,14 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User
raise ValueError(f"participant_id={participant_id} not found")
return res

@router.get("/topics", response_model=TopicListResponse)
def get_topics() -> TopicListResponse:
return TopicListResponse(
data=[
Topic(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=12341),
Topic(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1232312342),
Topic(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=3),
]
)

return router
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ dependencies = [
"pydantic_settings",
"fastapi",
"JSON-log-formatter",
"openai"
"openai",
"typing-extensions",
]

[project.urls]
Expand Down
17 changes: 16 additions & 1 deletion tests/routers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pytest import fixture

from birdxplorer.exceptions import UserEnrollmentNotFoundError
from birdxplorer.models import ParticipantId, TwitterTimestamp, UserEnrollment
from birdxplorer.models import ParticipantId, Topic, TwitterTimestamp, UserEnrollment
from birdxplorer.settings import GlobalSettings
from birdxplorer.storage import Storage

Expand All @@ -35,6 +35,11 @@ class UserEnrollmentFactory(ModelFactory[UserEnrollment]):
timestamp_of_last_earn_out = Use(gen_random_twitter_timestamp)


@register_fixture(name="topic_factory")
class TopicFactory(ModelFactory[Topic]):
__model__ = Topic


@fixture
def user_enrollment_samples(
user_enrollment_factory: UserEnrollmentFactory,
Expand Down Expand Up @@ -63,3 +68,13 @@ def client(settings_for_test: GlobalSettings, mock_storage: MagicMock) -> Genera
with patch("birdxplorer.app.gen_storage", return_value=mock_storage):
app = gen_app(settings=settings_for_test)
yield TestClient(app)


@fixture
def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]:
topics = [
topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=12341),
topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1232312342),
topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=3),
]
yield topics
9 changes: 8 additions & 1 deletion tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,18 @@

from fastapi.testclient import TestClient

from birdxplorer.models import UserEnrollment
from birdxplorer.models import Topic, UserEnrollment


def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None:
response = client.get(f"/api/v1/data/user-enrollments/{user_enrollment_samples[0].participant_id}")
assert response.status_code == 200
res_json = response.json()
assert res_json["participantId"] == user_enrollment_samples[0].participant_id


def test_topics_get(client: TestClient, topic_samples: List[Topic]) -> None:
response = client.get("/api/v1/data/topics")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [d.model_dump(by_alias=True) for d in topic_samples]}

0 comments on commit 70f3947

Please sign in to comment.