diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py
index 784d58d..37201ab 100644
--- a/birdxplorer/routers/data.py
+++ b/birdxplorer/routers/data.py
@@ -1,8 +1,18 @@
+from datetime import timezone
 from typing import List, Union
 
-from fastapi import APIRouter, Query
+from dateutil.parser import parse as dateutil_parse
+from fastapi import APIRouter, HTTPException, Query
 
-from ..models import BaseModel, ParticipantId, Post, PostId, Topic, UserEnrollment
+from ..models import (
+    BaseModel,
+    ParticipantId,
+    Post,
+    PostId,
+    Topic,
+    TwitterTimestamp,
+    UserEnrollment,
+)
 from ..storage import Storage
 
 
@@ -14,6 +24,34 @@ class PostListResponse(BaseModel):
     data: List[Post]
 
 
+def str_to_twitter_timestamp(s: str) -> TwitterTimestamp:
+    """Convert a textual timestamp into a ``TwitterTimestamp``.
+
+    ``s`` is first parsed as a date/time by ``dateutil``; a result without
+    timezone information is treated as UTC. If parsing fails and ``s`` is made
+    up solely of ASCII digits, it is taken as epoch milliseconds instead.
+
+    Raises ``ValueError`` (``dateutil``'s ``ParserError`` is a subclass) or
+    ``OverflowError`` when ``s`` cannot be interpreted either way.
+    """
+    try:
+        tmp = dateutil_parse(s)
+    except (ValueError, OverflowError):
+        # Query parameters arrive as text, so a raw millisecond timestamp
+        # shows up here as a digit string that dateutil could not turn into a date.
+        if not (s.isascii() and s.isdigit()):
+            raise
+        return TwitterTimestamp.from_int(int(s))
+    if tmp.tzinfo is None:
+        tmp = tmp.replace(tzinfo=timezone.utc)
+    return TwitterTimestamp.from_int(int(tmp.timestamp() * 1000))
+
+
+def ensure_twitter_timestamp(t: Union[str, TwitterTimestamp]) -> TwitterTimestamp:
+    """Return ``t`` unchanged unless it is a ``str``, which is converted."""
+    return str_to_twitter_timestamp(t) if isinstance(t, str) else t
+
+
 def gen_router(storage: Storage) -> APIRouter:
     router = APIRouter()
 
@@ -29,9 +67,25 @@ def get_topics() -> TopicListResponse:
         return TopicListResponse(data=list(storage.get_topics()))
 
     @router.get("/posts", response_model=PostListResponse)
-    def get_posts(post_id: Union[List[PostId], None] = Query(default=None)) -> PostListResponse:
+    def get_posts(
+        post_id: Union[List[PostId], None] = Query(default=None),
+        created_at_start: Union[None, str, TwitterTimestamp] = Query(default=None),
+        created_at_end: Union[None, str, TwitterTimestamp] = Query(default=None),
+    ) -> PostListResponse:
         if post_id is not None:
             return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id)))
+        try:
+            start = None if created_at_start is None else ensure_twitter_timestamp(created_at_start)
+            end = None if created_at_end is None else ensure_twitter_timestamp(created_at_end)
+        except (ValueError, OverflowError) as e:
+            # An unparsable bound is a client error, not a server fault.
+            raise HTTPException(status_code=422, detail=f"invalid created_at bound: {e}") from e
+        if start is not None and end is not None:
+            return PostListResponse(data=list(storage.get_posts_by_created_at_range(start=start, end=end)))
+        if start is not None:
+            return PostListResponse(data=list(storage.get_posts_by_created_at_start(start=start)))
+        if end is not None:
+            return PostListResponse(data=list(storage.get_posts_by_created_at_end(end=end)))
         return PostListResponse(data=list(storage.get_posts()))
 
     return router
diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py
index cedc055..0de45bb 100644
--- a/birdxplorer/storage.py
+++ b/birdxplorer/storage.py
@@ -156,6 +156,17 @@ def get_posts(self) -> Generator[PostModel, None, None]:
     def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
         raise NotImplementedError
 
+    def get_posts_by_created_at_range(
+        self, start: TwitterTimestamp, end: TwitterTimestamp
+    ) -> Generator[PostModel, None, None]:
+        raise NotImplementedError
+
+    def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
+        raise NotImplementedError
+
+    def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostModel, None, None]:
+        raise NotImplementedError
+
 
 def gen_storage(settings: GlobalSettings) -> Storage:
     engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
diff --git a/tests/conftest.py b/tests/conftest.py
index 11152e9..8890d9f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -145,6 +145,27 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:
 
     mock.get_posts_by_ids.side_effect = _get_posts_by_ids
 
+    def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]:
+        for post in post_samples:
+            if start <= post.created_at <= end:
+                yield post
+
+    mock.get_posts_by_created_at_range.side_effect = _get_posts_by_created_at_range
+
+    def _get_posts_by_created_at_start(start: TwitterTimestamp) -> Generator[Post, None, None]:
+        for post in post_samples:
+            if start <= post.created_at:
+                yield post
+
+    mock.get_posts_by_created_at_start.side_effect = _get_posts_by_created_at_start
+
+    def _get_posts_by_created_at_end(end: TwitterTimestamp) -> Generator[Post, None, None]:
+        for post in post_samples:
+            if post.created_at <= end:
+                yield post
+
+    mock.get_posts_by_created_at_end.side_effect = _get_posts_by_created_at_end
+
     yield mock
 
 
@@ -263,7 +284,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
             x_user=x_user_samples[0],
             text="text12",
             media_details=None,
-            created_at=1152921700000,
+            created_at=1153921700000,
             like_count=10,
             repost_count=20,
             impression_count=30,
@@ -274,7 +295,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
             x_user=x_user_samples[1],
             text="text21",
             media_details=None,
-            created_at=1152921800000,
+            created_at=1154921800000,
             like_count=10,
             repost_count=20,
             impression_count=30,
diff --git a/tests/routers/test_data.py b/tests/routers/test_data.py
index 98a1af4..f55093d 100644
--- a/tests/routers/test_data.py
+++ b/tests/routers/test_data.py
@@ -34,3 +34,36 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos
     assert res_json == {
         "data": [json.loads(post_samples[0].model_dump_json()), json.loads(post_samples[2].model_dump_json())]
     }
+
+
+def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None:
+    response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59")
+    assert response.status_code == 200
+    res_json = response.json()
+    assert res_json == {"data": [json.loads(post_samples[1].model_dump_json())]}
+
+
+def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples: List[Post]) -> None:
+    response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00")
+    assert response.status_code == 200
+    res_json = response.json()
+    assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]}
+
+
+def test_posts_get_has_created_at_filter_end(client: TestClient, post_samples: List[Post]) -> None:
+    response = client.get("/api/v1/data/posts/?createdAtEnd=2006-7-30 00:00:00")
+    assert response.status_code == 200
+    res_json = response.json()
+    assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (0, 1)]}
+
+
+def test_posts_get_created_at_filter_accepts_epoch_millis(client: TestClient, post_samples: List[Post]) -> None:
+    response = client.get("/api/v1/data/posts/?createdAtStart=1153921700000")
+    assert response.status_code == 200
+    res_json = response.json()
+    assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]}
+
+
+def test_posts_get_created_at_filter_rejects_unparsable_value(client: TestClient) -> None:
+    response = client.get("/api/v1/data/posts/?createdAtStart=invalid")
+    assert response.status_code == 422