Skip to content

Commit

Permalink
Merge pull request #117 from codeforjapan/feat/issue-105-posts-url-se…
Browse files Browse the repository at this point in the history
…arch

Feat/issue 105 posts url search
  • Loading branch information
yu23ki14 authored Oct 6, 2024
2 parents 40e69b0 + 1e04737 commit 3758cb1
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 1 deletion.
3 changes: 3 additions & 0 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def get_posts(
offset: int = Query(default=0, ge=0),
limit: int = Query(default=100, gt=0, le=1000),
search_text: Union[None, str] = Query(default=None),
search_url: Union[None, HttpUrl] = Query(default=None),
) -> PostListResponse:
if created_at_from is not None and isinstance(created_at_from, str):
created_at_from = ensure_twitter_timestamp(created_at_from)
Expand All @@ -115,6 +116,7 @@ def get_posts(
start=created_at_from,
end=created_at_to,
search_text=search_text,
search_url=search_url,
offset=offset,
limit=limit,
)
Expand All @@ -125,6 +127,7 @@ def get_posts(
start=created_at_from,
end=created_at_to,
search_text=search_text,
search_url=search_url,
)

for post in posts:
Expand Down
14 changes: 13 additions & 1 deletion api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
from polyfactory import Use
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from pydantic import HttpUrl
from pytest import fixture

from birdxplorer_common.exceptions import UserEnrollmentNotFoundError
from birdxplorer_common.models import (
LanguageIdentifier,
Link,
LinkId,
Note,
NoteId,
ParticipantId,
Expand Down Expand Up @@ -268,6 +270,7 @@ def mock_storage(
topic_samples: List[Topic],
post_samples: List[Post],
note_samples: List[Note],
link_samples: List[Link],
) -> Generator[MagicMock, None, None]:
mock = MagicMock(spec=Storage)

Expand Down Expand Up @@ -319,11 +322,17 @@ def _get_posts(
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
search_url: Union[HttpUrl, None] = None,
offset: Union[int, None] = None,
limit: Union[int, None] = None,
) -> Generator[Post, None, None]:
gen_count = 0
actual_gen_count = 0
url_id: LinkId | None = None
if search_url is not None:
url_candidates = [link.link_id for link in link_samples if link.url == search_url]
if len(url_candidates) > 0:
url_id = url_candidates[0]
for idx, post in enumerate(post_samples):
if limit is not None and actual_gen_count >= limit:
break
Expand All @@ -339,6 +348,8 @@ def _get_posts(
continue
if search_text is not None and search_text not in post.text:
continue
if search_url is not None and url_id not in [link.link_id for link in post.links]:
continue
gen_count += 1
if offset is not None and gen_count <= offset:
continue
Expand All @@ -353,8 +364,9 @@ def _get_number_of_posts(
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
search_url: Union[HttpUrl, None] = None,
) -> int:
return len(list(_get_posts(post_ids, note_ids, start, end, search_text)))
return len(list(_get_posts(post_ids, note_ids, start, end, search_text, search_url)))

mock.get_number_of_posts.side_effect = _get_number_of_posts

Expand Down
10 changes: 10 additions & 0 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ def test_posts_search_by_text(client: TestClient, post_samples: List[Post]) -> N
}


def test_posts_search_by_url(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?searchUrl=https%3A%2F%2Fexample.com%2Fsh3")
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(post_samples[i].model_dump_json()) for i in (2, 3)],
"meta": {"next": None, "prev": None},
}


def test_notes_get(client: TestClient, note_samples: List[Note]) -> None:
response = client.get("/api/v1/data/notes")
assert response.status_code == 200
Expand Down
14 changes: 14 additions & 0 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def get_posts(
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
search_url: Union[HttpUrl, None] = None,
offset: Union[int, None] = None,
limit: int = 100,
) -> Generator[PostModel, None, None]:
Expand All @@ -354,6 +355,12 @@ def get_posts(
query = query.filter(PostRecord.created_at < end)
if search_text is not None:
query = query.filter(PostRecord.text.like(f"%{search_text}%"))
if search_url is not None:
query = (
query.join(PostLinkAssociation, PostLinkAssociation.post_id == PostRecord.post_id)
.join(LinkRecord, LinkRecord.link_id == PostLinkAssociation.link_id)
.filter(LinkRecord.url == search_url)
)
if offset is not None:
query = query.offset(offset)
query = query.limit(limit)
Expand All @@ -367,6 +374,7 @@ def get_number_of_posts(
start: Union[TwitterTimestamp, None] = None,
end: Union[TwitterTimestamp, None] = None,
search_text: Union[str, None] = None,
search_url: Union[HttpUrl, None] = None,
) -> int:
with Session(self.engine) as sess:
query = sess.query(PostRecord)
Expand All @@ -382,6 +390,12 @@ def get_number_of_posts(
query = query.filter(PostRecord.created_at < end)
if search_text is not None:
query = query.filter(PostRecord.text.like(f"%{search_text}%"))
if search_url is not None:
query = (
query.join(PostLinkAssociation, PostLinkAssociation.post_id == PostRecord.post_id)
.join(LinkRecord, LinkRecord.link_id == PostLinkAssociation.link_id)
.filter(LinkRecord.url == search_url)
)
return query.count()


Expand Down
3 changes: 3 additions & 0 deletions common/tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List

import pytest
from pydantic import HttpUrl
from sqlalchemy.engine import Engine

from birdxplorer_common.models import (
Expand Down Expand Up @@ -41,6 +42,7 @@ def test_get_topic_list(
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]],
[dict(end=TwitterTimestamp.from_int(1153921700000)), [0]],
[dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]],
[dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
[dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]],
],
Expand Down Expand Up @@ -70,6 +72,7 @@ def test_get_post(
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2, 3, 4]],
[dict(end=TwitterTimestamp.from_int(1153921700000)), [0]],
[dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]],
[dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
],
)
Expand Down

0 comments on commit 3758cb1

Please sign in to comment.