Skip to content

Commit

Permalink
Merge pull request #82 from codeforjapan/issue-51
Browse files Browse the repository at this point in the history
/v1/data/posts に NoteId による絞り込みを追加する
  • Loading branch information
yu23ki14 authored Aug 10, 2024
2 parents 5ac6bba + 776d0d1 commit cb337dc
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
3 changes: 3 additions & 0 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ def get_notes(
@router.get("/posts", response_model=PostListResponse)
def get_posts(
post_id: Union[List[PostId], None] = Query(default=None),
note_id: Union[List[NoteId], None] = Query(default=None),
created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None),
created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None),
) -> PostListResponse:
if post_id is not None:
return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id)))
if note_id is not None:
return PostListResponse(data=list(storage.get_posts_by_note_ids(note_ids=note_id)))
if created_at_start is not None:
if created_at_end is not None:
return PostListResponse(
Expand Down
9 changes: 9 additions & 0 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,15 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:

mock.get_posts_by_ids.side_effect = _get_posts_by_ids

def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]:
for post in post_samples:
for note in note_samples:
if note.note_id in note_ids and post.post_id == note.post_id:
yield post
break

mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids

def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]:
for post in post_samples:
if start <= post.created_at < end:
Expand Down
7 changes: 7 additions & 0 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos
}


def test_posts_get_has_note_id_filter(client: TestClient, post_samples: List[Post], note_samples: List[Note]) -> None:
response = client.get(f"/api/v1/data/posts/?noteId={','.join([n.note_id for n in note_samples])}")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(post_samples[0].model_dump_json())]}


def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59")
assert response.status_code == 200
Expand Down
10 changes: 10 additions & 0 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,16 @@ def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostMo
for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all():
yield self._post_record_to_model(post_record)

def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]:
query = (
select(PostRecord)
.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id)
.where(NoteRecord.note_id.in_(note_ids))
)
with Session(self.engine) as sess:
for post_record in sess.execute(query).scalars().all():
yield self._post_record_to_model(post_record)


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down

0 comments on commit cb337dc

Please sign in to comment.