Skip to content

Commit

Permalink
add pagenation to notes endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
yu23ki14 authored and sushichan044 committed Nov 1, 2024
1 parent 79dbfb0 commit 6d0db1f
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 22 deletions.
52 changes: 41 additions & 11 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class TopicListResponse(BaseModel):

class NoteListResponse(BaseModel):
data: NoteListWithExamples
meta: PaginationMeta


class PostListResponse(BaseModel):
Expand Down Expand Up @@ -141,29 +142,58 @@ def get_topics() -> TopicListResponse:

@router.get("/notes", description=V1DataNotesDocs.description, response_model=NoteListResponse)
def get_notes(
request: Request,
note_ids: Union[List[NoteId], None] = Query(default=None, **V1DataNotesDocs.params["note_ids"]),
created_at_from: Union[None, TwitterTimestamp] = Query(
default=None, **V1DataNotesDocs.params["created_at_from"]
),
created_at_to: Union[None, TwitterTimestamp] = Query(default=None, **V1DataNotesDocs.params["created_at_to"]),
offset: int = Query(default=0, ge=0),
limit: int = Query(default=100, gt=0, le=1000),
topic_ids: Union[List[TopicId], None] = Query(default=None, **V1DataNotesDocs.params["topic_ids"]),
post_ids: Union[List[PostId], None] = Query(default=None, **V1DataNotesDocs.params["post_ids"]),
current_status: Union[None, List[str]] = Query(default=None, **V1DataNotesDocs.params["current_status"]),
language: Union[LanguageIdentifier, None] = Query(default=None, **V1DataNotesDocs.params["language"]),
) -> NoteListResponse:
return NoteListResponse(
data=list(
storage.get_notes(
note_ids=note_ids,
created_at_from=created_at_from,
created_at_to=created_at_to,
topic_ids=topic_ids,
post_ids=post_ids,
current_status=current_status,
language=language,
)
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
if created_at_to is not None and isinstance(created_at_to, str):
created_at_to = ensure_twitter_timestamp(created_at_to)

notes = list(
storage.get_notes(
note_ids=note_ids,
created_at_from=created_at_from,
created_at_to=created_at_to,
topic_ids=topic_ids,
post_ids=post_ids,
current_status=current_status,
language=language,
offset=offset,
limit=limit,
)
)
total_count = storage.get_number_of_notes(
note_ids=note_ids,
created_at_from=created_at_from,
created_at_to=created_at_to,
topic_ids=topic_ids,
post_ids=post_ids,
current_status=current_status,
language=language,
)

baseurl = str(request.url).split("?")[0]
next_offset = offset + limit
prev_offset = max(offset - limit, 0)
next_url = None
if next_offset < total_count:
next_url = f"{baseurl}?offset={next_offset}&limit={limit}"
prev_url = None
if offset > 0:
prev_url = f"{baseurl}?offset={prev_offset}&limit={limit}"

return NoteListResponse(data=notes, meta=PaginationMeta(next=next_url, prev=prev_url))

@router.get("/posts", description=V1DataPostsDocs.description, response_model=PostListResponse)
def get_posts(
Expand Down
20 changes: 19 additions & 1 deletion api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ def _get_user_enrollment_by_participant_id(
def _get_topics() -> Generator[Topic, None, None]:
yield from topic_samples

mock.get_topics.side_effect = _get_topics

def _get_notes(
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
Expand All @@ -329,6 +331,8 @@ def _get_notes(
post_ids: Union[List[PostId], None] = None,
current_status: Union[None, List[str]] = None,
language: Union[LanguageIdentifier, None] = None,
offset: Union[int, None] = None,
limit: Union[int, None] = None,
) -> Generator[Note, None, None]:
for note in note_samples:
if note_ids is not None and note.note_id not in note_ids:
Expand All @@ -347,9 +351,23 @@ def _get_notes(
continue
yield note

mock.get_topics.side_effect = _get_topics
mock.get_notes.side_effect = _get_notes

def _get_number_of_notes(
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[PostId], None] = None,
current_status: Union[None, List[str]] = None,
language: Union[LanguageIdentifier, None] = None,
) -> int:
return len(
list(_get_notes(note_ids, created_at_from, created_at_to, topic_ids, post_ids, current_status, language))
)

mock.get_number_of_notes.side_effect = _get_number_of_notes

def _get_posts(
post_ids: Union[List[PostId], None] = None,
note_ids: Union[List[NoteId], None] = None,
Expand Down
26 changes: 20 additions & 6 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def test_notes_get(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(d.model_dump_json()) for d in note_samples]}
assert res_json == {
"data": [json.loads(d.model_dump_json()) for d in note_samples],
"meta": {"next": None, "prev": None},
}


def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None:
Expand All @@ -191,29 +194,39 @@ def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Not
"data": [
json.loads(note_samples[0].model_dump_json()),
json.loads(note_samples[2].model_dump_json()),
]
],
"meta": {"next": None, "prev": None},
}


def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)]}
assert res_json == {
"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)],
"meta": {"next": None, "prev": None},
}


def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)]}
assert res_json == {
"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)],
"meta": {"next": None, "prev": None},
}


def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes/?createdAtTo=1152921603000")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]}
assert res_json == {
"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)],
"meta": {"next": None, "prev": None},
}


def test_notes_get_has_topic_id_filter(client: TestClient, note_samples: List[Note]) -> None:
Expand All @@ -222,5 +235,6 @@ def test_notes_get_has_topic_id_filter(client: TestClient, note_samples: List[No
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(correct_notes[i].model_dump_json()) for i in range(correct_notes.__len__())]
"data": [json.loads(correct_notes[i].model_dump_json()) for i in range(correct_notes.__len__())],
"meta": {"next": None, "prev": None},
}
50 changes: 46 additions & 4 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ def get_notes(
post_ids: Union[List[PostId], None] = None,
current_status: Union[None, List[str]] = None,
language: Union[LanguageIdentifier, None] = None,
offset: Union[int, None] = None,
limit: int = 100,
) -> Generator[NoteModel, None, None]:
with Session(self.engine) as sess:
query = sess.query(NoteRecord)
Expand All @@ -363,6 +365,9 @@ def get_notes(
query = query.filter(NoteRecord.language == language)
if current_status is not None:
query = query.filter(NoteRecord.current_status.in_(current_status))
if offset is not None:
query = query.offset(offset)
query = query.limit(limit)
for note_record in query.all():
yield NoteModel(
note_id=note_record.note_id,
Expand All @@ -371,10 +376,11 @@ def get_notes(
TopicModel(
topic_id=topic.topic_id,
label=topic.topic.label,
reference_count=sess.query(func.count(NoteTopicAssociation.note_id))
.filter(NoteTopicAssociation.topic_id == topic.topic_id)
.scalar()
or 0,
reference_count=0,
# reference_count=sess.query(func.count(NoteTopicAssociation.note_id))
# .filter(NoteTopicAssociation.topic_id == topic.topic_id)
# .scalar()
# or 0,
)
for topic in note_record.topics
],
Expand All @@ -384,6 +390,42 @@ def get_notes(
created_at=note_record.created_at,
)

def get_number_of_notes(
self,
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[PostId], None] = None,
current_status: Union[None, List[str]] = None,
language: Union[LanguageIdentifier, None] = None,
) -> int:
with Session(self.engine) as sess:
query = sess.query(NoteRecord)
if note_ids is not None:
query = query.filter(NoteRecord.note_id.in_(note_ids))
if created_at_from is not None:
query = query.filter(NoteRecord.created_at >= created_at_from)
if created_at_to is not None:
query = query.filter(NoteRecord.created_at <= created_at_to)
if topic_ids is not None:
# 同じトピックIDを持つノートを取得するためのサブクエリ
# とりあえずANDを実装
subq = (
select(NoteTopicAssociation.note_id)
.group_by(NoteTopicAssociation.note_id)
.having(func.bool_or(NoteTopicAssociation.topic_id.in_(topic_ids)))
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
if post_ids is not None:
query = query.filter(NoteRecord.post_id.in_(post_ids))
if language is not None:
query = query.filter(NoteRecord.language == language)
if current_status is not None:
query = query.filter(NoteRecord.current_status.in_(current_status))
return query.count()

def get_posts(
self,
post_ids: Union[List[PostId], None] = None,
Expand Down

0 comments on commit 6d0db1f

Please sign in to comment.