Skip to content

Commit

Permalink
feat(app): add created_at range query for GET posts
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Mar 24, 2024
1 parent 088454f commit 3bb1e6c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 4 deletions.
46 changes: 44 additions & 2 deletions birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from datetime import timezone
from typing import List, Union

from dateutil.parser import parse as dateutil_parse
from fastapi import APIRouter, Query

from ..models import BaseModel, ParticipantId, Post, PostId, Topic, UserEnrollment
from ..models import (
BaseModel,
ParticipantId,
Post,
PostId,
Topic,
TwitterTimestamp,
UserEnrollment,
)
from ..storage import Storage


Expand All @@ -14,6 +24,17 @@ class PostListResponse(BaseModel):
data: List[Post]


def str_to_twitter_timestamp(s: str) -> TwitterTimestamp:
tmp = dateutil_parse(s)
if tmp.tzinfo is None:
tmp = tmp.replace(tzinfo=timezone.utc)
return TwitterTimestamp.from_int(int(tmp.timestamp() * 1000))


def ensure_twitter_timestamp(t: Union[str, TwitterTimestamp]) -> TwitterTimestamp:
return str_to_twitter_timestamp(t) if isinstance(t, str) else t


def gen_router(storage: Storage) -> APIRouter:
router = APIRouter()

Expand All @@ -29,9 +50,30 @@ def get_topics() -> TopicListResponse:
return TopicListResponse(data=list(storage.get_topics()))

@router.get("/posts", response_model=PostListResponse)
def get_posts(post_id: Union[List[PostId], None] = Query(default=None)) -> PostListResponse:
def get_posts(
post_id: Union[List[PostId], None] = Query(default=None),
created_at_start: Union[None, str, TwitterTimestamp] = Query(default=None),
created_at_end: Union[None, str, TwitterTimestamp] = Query(default=None),
) -> PostListResponse:
if post_id is not None:
return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id)))
if created_at_start is not None:
if created_at_end is not None:
return PostListResponse(
data=list(
storage.get_posts_by_created_at_range(
start=ensure_twitter_timestamp(created_at_start),
end=ensure_twitter_timestamp(created_at_end),
)
)
)
return PostListResponse(
data=list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_start)))
)
if created_at_end is not None:
return PostListResponse(
data=list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_end)))
)
return PostListResponse(data=list(storage.get_posts()))

return router
11 changes: 11 additions & 0 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ def get_posts(self) -> Generator[PostModel, None, None]:
def get_posts_by_ids(self, post_ids: List[PostId]) -> Generator[PostModel, None, None]:
raise NotImplementedError

def get_posts_by_created_at_range(
self, start: TwitterTimestamp, end: TwitterTimestamp
) -> Generator[PostModel, None, None]:
raise NotImplementedError

def get_posts_by_created_at_start(self, start: TwitterTimestamp) -> Generator[PostModel, None, None]:
raise NotImplementedError

def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostModel, None, None]:
raise NotImplementedError


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down
25 changes: 23 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,27 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:

mock.get_posts_by_ids.side_effect = _get_posts_by_ids

def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]:
for post in post_samples:
if start <= post.created_at <= end:
yield post

mock.get_posts_by_created_at_range.side_effect = _get_posts_by_created_at_range

def _get_posts_by_created_at_start(start: TwitterTimestamp) -> Generator[Post, None, None]:
for post in post_samples:
if start <= post.created_at:
yield post

mock.get_posts_by_created_at_start.side_effect = _get_posts_by_created_at_start

def _get_posts_by_created_at_end(end: TwitterTimestamp) -> Generator[Post, None, None]:
for post in post_samples:
if post.created_at <= end:
yield post

mock.get_posts_by_created_at_end.side_effect = _get_posts_by_created_at_end

yield mock


Expand Down Expand Up @@ -263,7 +284,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
x_user=x_user_samples[0],
text="text12",
media_details=None,
created_at=1152921700000,
created_at=1153921700000,
like_count=10,
repost_count=20,
impression_count=30,
Expand All @@ -274,7 +295,7 @@ def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Gene
x_user=x_user_samples[1],
text="text21",
media_details=None,
created_at=1152921800000,
created_at=1154921800000,
like_count=10,
repost_count=20,
impression_count=30,
Expand Down
21 changes: 21 additions & 0 deletions tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,24 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos
assert res_json == {
"data": [json.loads(post_samples[0].model_dump_json()), json.loads(post_samples[2].model_dump_json())]
}


def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(post_samples[1].model_dump_json())]}


def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]}


def test_posts_get_has_created_at_filter_end(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?createdAtEnd=2006-7-30 00:00:00")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (0, 1)]}

0 comments on commit 3bb1e6c

Please sign in to comment.