From 3ca5c3615880fd557d50c63d2feb10decda77f51 Mon Sep 17 00:00:00 2001 From: yu23ki14 Date: Fri, 1 Nov 2024 11:10:20 +0900 Subject: [PATCH 1/2] add pagenation to notes endpoint --- api/birdxplorer_api/routers/data.py | 52 ++++++++++++++++++++++------ api/tests/conftest.py | 20 ++++++++++- api/tests/routers/test_data.py | 26 ++++++++++---- common/birdxplorer_common/storage.py | 50 +++++++++++++++++++++++--- 4 files changed, 126 insertions(+), 22 deletions(-) diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index e51e13d..2955df7 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -28,6 +28,7 @@ class TopicListResponse(BaseModel): class NoteListResponse(BaseModel): data: List[Note] + meta: PaginationMeta class PostListResponse(BaseModel): @@ -71,27 +72,56 @@ def get_topics() -> TopicListResponse: @router.get("/notes", response_model=NoteListResponse) def get_notes( + request: Request, note_ids: Union[List[NoteId], None] = Query(default=None), created_at_from: Union[None, TwitterTimestamp] = Query(default=None), created_at_to: Union[None, TwitterTimestamp] = Query(default=None), + offset: int = Query(default=0, ge=0), + limit: int = Query(default=100, gt=0, le=1000), topic_ids: Union[List[TopicId], None] = Query(default=None), post_ids: Union[List[PostId], None] = Query(default=None), current_status: Union[None, List[str]] = Query(default=None), language: Union[LanguageIdentifier, None] = Query(default=None), ) -> NoteListResponse: - return NoteListResponse( - data=list( - storage.get_notes( - note_ids=note_ids, - created_at_from=created_at_from, - created_at_to=created_at_to, - topic_ids=topic_ids, - post_ids=post_ids, - current_status=current_status, - language=language, - ) + if created_at_from is not None and isinstance(created_at_from, str): + created_at_from = ensure_twitter_timestamp(created_at_from) + if created_at_to is not None and isinstance(created_at_to, str): + created_at_to = ensure_twitter_timestamp(created_at_to) + + notes = list( + storage.get_notes( + note_ids=note_ids, + created_at_from=created_at_from, + created_at_to=created_at_to, + topic_ids=topic_ids, + post_ids=post_ids, + current_status=current_status, + language=language, + offset=offset, + limit=limit, ) ) + total_count = storage.get_number_of_notes( + note_ids=note_ids, + created_at_from=created_at_from, + created_at_to=created_at_to, + topic_ids=topic_ids, + post_ids=post_ids, + current_status=current_status, + language=language, + ) + + baseurl = str(request.url).split("?")[0] + next_offset = offset + limit + prev_offset = max(offset - limit, 0) + next_url = None + if next_offset < total_count: + next_url = f"{baseurl}?offset={next_offset}&limit={limit}" + prev_url = None + if offset > 0: + prev_url = f"{baseurl}?offset={prev_offset}&limit={limit}" + + return NoteListResponse(data=notes, meta=PaginationMeta(next=next_url, prev=prev_url)) @router.get("/posts", response_model=PostListResponse) def get_posts( diff --git a/api/tests/conftest.py b/api/tests/conftest.py index affa005..cdb8301 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -321,6 +321,8 @@ def _get_user_enrollment_by_participant_id( def _get_topics() -> Generator[Topic, None, None]: yield from topic_samples + mock.get_topics.side_effect = _get_topics + def _get_notes( note_ids: Union[List[NoteId], None] = None, created_at_from: Union[None, TwitterTimestamp] = None, @@ -329,6 +331,8 @@ def _get_notes( post_ids: Union[List[PostId], None] = None, current_status: Union[None, List[str]] = None, language: Union[LanguageIdentifier, None] = None, + offset: Union[int, None] = None, + limit: Union[int, None] = None, ) -> Generator[Note, None, None]: for note in note_samples: if note_ids is not None and note.note_id not in note_ids: @@ -347,9 +351,23 @@ def _get_notes( continue yield note - mock.get_topics.side_effect = _get_topics mock.get_notes.side_effect = _get_notes + def _get_number_of_notes( + note_ids: Union[List[NoteId], None] = None, + created_at_from: Union[None, TwitterTimestamp] = None, + created_at_to: Union[None, TwitterTimestamp] = None, + topic_ids: Union[List[TopicId], None] = None, + post_ids: Union[List[PostId], None] = None, + current_status: Union[None, List[str]] = None, + language: Union[LanguageIdentifier, None] = None, + ) -> int: + return len( + list(_get_notes(note_ids, created_at_from, created_at_to, topic_ids, post_ids, current_status, language)) + ) + + mock.get_number_of_notes.side_effect = _get_number_of_notes + def _get_posts( post_ids: Union[List[PostId], None] = None, note_ids: Union[List[NoteId], None] = None, diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index e971098..c9dfbff 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -180,7 +180,10 @@ def test_notes_get(client: TestClient, note_samples: List[Note]) -> None: response = client.get("/api/v1/data/notes") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(d.model_dump_json()) for d in note_samples]} + assert res_json == { + "data": [json.loads(d.model_dump_json()) for d in note_samples], + "meta": {"next": None, "prev": None}, + } def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None: @@ -191,7 +194,8 @@ def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Not "data": [ json.loads(note_samples[0].model_dump_json()), json.loads(note_samples[2].model_dump_json()), - ] + ], + "meta": {"next": None, "prev": None}, } @@ -199,21 +203,30 @@ def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_sa response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)]} + assert res_json == { + "data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)], + "meta": {"next": None, "prev": None}, + } def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None: response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)]} + assert res_json == { + "data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)], + "meta": {"next": None, "prev": None}, + } def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None: response = client.get("/api/v1/data/notes/?createdAtTo=1152921603000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]} + assert res_json == { + "data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)], + "meta": {"next": None, "prev": None}, + } def test_notes_get_has_topic_id_filter(client: TestClient, note_samples: List[Note]) -> None: @@ -222,5 +235,6 @@ def test_notes_get_has_topic_id_filter(client: TestClient, note_samples: List[No assert response.status_code == 200 res_json = response.json() assert res_json == { - "data": [json.loads(correct_notes[i].model_dump_json()) for i in range(correct_notes.__len__())] + "data": [json.loads(correct_notes[i].model_dump_json()) for i in range(correct_notes.__len__())], + "meta": {"next": None, "prev": None}, } diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index df77fc4..7aaf51e 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -338,6 +338,8 @@ def get_notes( post_ids: Union[List[PostId], None] = None, current_status: Union[None, List[str]] = None, language: Union[LanguageIdentifier, None] = None, + offset: Union[int, None] = None, + limit: int = 100, ) -> Generator[NoteModel, None, None]: with Session(self.engine) as sess: query = sess.query(NoteRecord) @@ -363,6 +365,9 @@ def get_notes( query = query.filter(NoteRecord.language == language) if current_status is not None: query = query.filter(NoteRecord.current_status.in_(current_status)) + if offset is not None: + query = query.offset(offset) + query = query.limit(limit) for note_record in query.all(): yield NoteModel( note_id=note_record.note_id, @@ -371,10 +376,11 @@ def get_notes( TopicModel( topic_id=topic.topic_id, label=topic.topic.label, - reference_count=sess.query(func.count(NoteTopicAssociation.note_id)) - .filter(NoteTopicAssociation.topic_id == topic.topic_id) - .scalar() - or 0, + reference_count=0, + # reference_count=sess.query(func.count(NoteTopicAssociation.note_id)) + # .filter(NoteTopicAssociation.topic_id == topic.topic_id) + # .scalar() + # or 0, ) for topic in note_record.topics ], @@ -384,6 +390,42 @@ def get_notes( created_at=note_record.created_at, ) + def get_number_of_notes( + self, + note_ids: Union[List[NoteId], None] = None, + created_at_from: Union[None, TwitterTimestamp] = None, + created_at_to: Union[None, TwitterTimestamp] = None, + topic_ids: Union[List[TopicId], None] = None, + post_ids: Union[List[PostId], None] = None, + current_status: Union[None, List[str]] = None, + language: Union[LanguageIdentifier, None] = None, + ) -> int: + with Session(self.engine) as sess: + query = sess.query(NoteRecord) + if note_ids is not None: + query = query.filter(NoteRecord.note_id.in_(note_ids)) + if created_at_from is not None: + query = query.filter(NoteRecord.created_at >= created_at_from) + if created_at_to is not None: + query = query.filter(NoteRecord.created_at <= created_at_to) + if topic_ids is not None: + # 同じトピックIDを持つノートを取得するためのサブクエリ + # とりあえずANDを実装 + subq = ( + select(NoteTopicAssociation.note_id) + .group_by(NoteTopicAssociation.note_id) + .having(func.bool_or(NoteTopicAssociation.topic_id.in_(topic_ids))) + .subquery() + ) + query = query.join(subq, NoteRecord.note_id == subq.c.note_id) + if post_ids is not None: + query = query.filter(NoteRecord.post_id.in_(post_ids)) + if language is not None: + query = query.filter(NoteRecord.language == language) + if current_status is not None: + query = query.filter(NoteRecord.current_status.in_(current_status)) + return query.count() + def get_posts( self, post_ids: Union[List[PostId], None] = None, From 8d3f32f118c411ee568f539bb40f0df7249a02a3 Mon Sep 17 00:00:00 2001 From: yu23ki14 Date: Fri, 1 Nov 2024 11:19:17 +0900 Subject: [PATCH 2/2] fix test --- common/tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/common/tests/conftest.py b/common/tests/conftest.py index c972acd..bc28cfb 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -146,6 +146,7 @@ def link_samples(link_factory: LinkFactory) -> Generator[List[Link], None, None] @fixture def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]: + topic_samples = [t.model_copy(update={"reference_count": 0}) for t in topic_samples] notes = [ note_factory.build( note_id="1234567890123456781",