Skip to content

Commit

Permalink
Merge pull request #47 from codeforjapan/feature/issue-39-add-notes-e…
Browse files Browse the repository at this point in the history
…ndpoint

Feature/issue 39 add notes endpoint
  • Loading branch information
osoken authored Apr 24, 2024
2 parents bc3fd5c + 279ccb3 commit 3db9d8e
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 12 deletions.
31 changes: 31 additions & 0 deletions birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@

from ..models import (
BaseModel,
LanguageIdentifier,
Note,
NoteId,
ParticipantId,
Post,
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
UserEnrollment,
)
Expand All @@ -20,6 +25,10 @@ class TopicListResponse(BaseModel):
data: List[Topic]


class NoteListResponse(BaseModel):
data: List[Note]


class PostListResponse(BaseModel):
data: List[Post]

Expand Down Expand Up @@ -56,6 +65,28 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User
def get_topics() -> TopicListResponse:
return TopicListResponse(data=list(storage.get_topics()))

@router.get("/notes", response_model=NoteListResponse)
def get_notes(
note_ids: Union[List[NoteId], None] = Query(default=None),
created_at_from: Union[None, TwitterTimestamp] = Query(default=None),
created_at_to: Union[None, TwitterTimestamp] = Query(default=None),
topic_ids: Union[List[TopicId], None] = Query(default=None),
post_ids: Union[List[TweetId], None] = Query(default=None),
language: Union[LanguageIdentifier, None] = Query(default=None),
) -> NoteListResponse:
return NoteListResponse(
data=list(
storage.get_notes(
note_ids=note_ids,
created_at_from=created_at_from,
created_at_to=created_at_to,
topic_ids=topic_ids,
post_ids=post_ids,
language=language,
)
)
)

@router.get("/posts", response_model=PostListResponse)
def get_posts(
post_id: Union[List[PostId], None] = Query(default=None),
Expand Down
63 changes: 55 additions & 8 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator, List
from typing import Generator, List, Union

from psycopg2.extensions import AsIs, register_adapter
from pydantic import AnyUrl, HttpUrl
Expand All @@ -7,13 +7,9 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import DECIMAL, JSON, Integer, String

from .models import (
LanguageIdentifier,
MediaDetails,
NonNegativeInt,
NoteId,
ParticipantId,
)
from .models import LanguageIdentifier, MediaDetails, NonNegativeInt
from .models import Note as NoteModel
from .models import NoteId, ParticipantId
from .models import Post as PostModel
from .models import PostId, SummaryString
from .models import Topic as TopicModel
Expand Down Expand Up @@ -152,6 +148,57 @@ def get_topics(self) -> Generator[TopicModel, None, None]:
topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0
)

def get_notes(
self,
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[TweetId], None] = None,
language: Union[LanguageIdentifier, None] = None,
) -> Generator[NoteModel, None, None]:
with Session(self.engine) as sess:
query = sess.query(NoteRecord)
if note_ids is not None:
query = query.filter(NoteRecord.note_id.in_(note_ids))
if created_at_from is not None:
query = query.filter(NoteRecord.created_at >= created_at_from)
if created_at_to is not None:
query = query.filter(NoteRecord.created_at <= created_at_to)
if topic_ids is not None:
# 同じトピックIDを持つノートを取得するためのサブクエリ
# とりあえずANDを実装
subq = (
select(NoteTopicAssociation.note_id)
.group_by(NoteTopicAssociation.note_id)
.having(func.array_agg(NoteTopicAssociation.topic_id) == topic_ids)
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
if post_ids is not None:
query = query.filter(NoteRecord.post_id.in_(post_ids))
if language is not None:
query = query.filter(NoteRecord.language == language)
for note_record in query.all():
yield NoteModel(
note_id=note_record.note_id,
post_id=note_record.post_id,
topics=[
TopicModel(
topic_id=topic.topic_id,
label=topic.topic.label,
reference_count=sess.query(func.count(NoteTopicAssociation.note_id))
.filter(NoteTopicAssociation.topic_id == topic.topic_id)
.scalar()
or 0,
)
for topic in note_record.topics
],
language=note_record.language,
summary=note_record.summary,
created_at=note_record.created_at,
)

def get_posts(self) -> Generator[PostModel, None, None]:
with Session(self.engine) as sess:
for post_record in sess.query(PostRecord).all():
Expand Down
35 changes: 33 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import random
from collections.abc import Generator
from typing import List, Type
from typing import List, Type, Union
from unittest.mock import MagicMock, patch

from dotenv import load_dotenv
Expand All @@ -18,11 +18,15 @@

from birdxplorer.exceptions import UserEnrollmentNotFoundError
from birdxplorer.models import (
LanguageIdentifier,
Note,
NoteId,
ParticipantId,
Post,
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
UserEnrollment,
XUser,
Expand Down Expand Up @@ -114,7 +118,10 @@ def user_enrollment_samples(

@fixture
def mock_storage(
user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic], post_samples: List[Post]
user_enrollment_samples: List[UserEnrollment],
topic_samples: List[Topic],
post_samples: List[Post],
note_samples: List[Note],
) -> Generator[MagicMock, None, None]:
mock = MagicMock(spec=Storage)

Expand All @@ -129,7 +136,31 @@ def _get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> Use
def _get_topics() -> Generator[Topic, None, None]:
yield from topic_samples

def _get_notes(
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[TweetId], None] = None,
language: Union[LanguageIdentifier, None] = None,
) -> Generator[Note, None, None]:
for note in note_samples:
if note_ids is not None and note.note_id not in note_ids:
continue
if created_at_from is not None and note.created_at < created_at_from:
continue
if created_at_to is not None and note.created_at > created_at_to:
continue
if topic_ids is not None and not set(topic_ids).issubset({topic.topic_id for topic in note.topics}):
continue
if post_ids is not None and note.post_id not in post_ids:
continue
if language is not None and note.language != language:
continue
yield note

mock.get_topics.side_effect = _get_topics
mock.get_notes.side_effect = _get_notes

def _get_posts() -> Generator[Post, None, None]:
yield from post_samples
Expand Down
39 changes: 38 additions & 1 deletion tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from fastapi.testclient import TestClient

from birdxplorer.models import Post, Topic, UserEnrollment
from birdxplorer.models import Note, Post, Topic, UserEnrollment


def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None:
Expand Down Expand Up @@ -81,3 +81,40 @@ def test_posts_get_created_at_end_filter_accepts_integer(client: TestClient, pos
def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?createdAtStart=1153921700&createdAtEnd=1153921700")
assert response.status_code == 422


def test_notes_get(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(d.model_dump_json()) for d in note_samples]}


def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None:
response = client.get(f"/api/v1/data/notes/?noteIds={note_samples[0].note_id}&noteIds={note_samples[2].note_id}")
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(note_samples[0].model_dump_json()), json.loads(note_samples[2].model_dump_json())]
}


def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)]}


def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)]}


def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes/?createdAtTo=1152921603000")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]}
Loading

0 comments on commit 3db9d8e

Please sign in to comment.