Skip to content

Commit

Permalink
passed all tests written
Browse files Browse the repository at this point in the history
  • Loading branch information
kota-yata authored and osoken committed Apr 21, 2024
1 parent 83e9b8c commit 16b2c68
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 42 deletions.
38 changes: 18 additions & 20 deletions birdxplorer/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

from ..models import (
BaseModel,
LanguageIdentifier,
Note,
NoteId,
ParticipantId,
Post,
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
UserEnrollment,
Expand Down Expand Up @@ -66,28 +68,24 @@ def get_topics() -> TopicListResponse:
@router.get("/notes", response_model=NoteListResponse)
def get_notes(
note_ids: Union[List[NoteId], None] = Query(default=None, alias="noteIds"),
created_at_from: Union[None, int] = Query(default=None, alias="createdAtFrom"),
created_at_to: Union[None, int] = Query(default=None, alias="createdAtTo"),
topic_ids: Union[List[str], None] = Query(default=None, alias="topicIds"),
created_at_from: Union[None, TwitterTimestamp] = Query(default=None, alias="createdAtFrom"),
created_at_to: Union[None, TwitterTimestamp] = Query(default=None, alias="createdAtTo"),
topic_ids: Union[List[TopicId], None] = Query(default=None, alias="topicIds"),
post_ids: Union[List[TweetId], None] = Query(default=None, alias="postIds"),
language: Union[str, None] = Query(default=None),
language: Union[LanguageIdentifier, None] = Query(default=None),
) -> NoteListResponse:
filters = {}

if note_ids is not None:
filters["note_ids"] = note_ids
if created_at_from is not None:
filters["created_at_from"] = TwitterTimestamp.from_int(created_at_from)
if created_at_to is not None:
filters["created_at_to"] = TwitterTimestamp.from_int(created_at_to)
if topic_ids is not None:
filters["topic_ids"] = topic_ids
if post_ids is not None:
filters["post_ids"] = post_ids
if language is not None:
filters["language"] = language

return NoteListResponse(data=list(storage.get_notes(**filters)))
return NoteListResponse(
data=list(
storage.get_notes(
note_ids=note_ids,
created_at_from=created_at_from,
created_at_to=created_at_to,
topic_ids=topic_ids,
post_ids=post_ids,
language=language,
)
)
)

@router.get("/posts", response_model=PostListResponse)
def get_posts(
Expand Down
10 changes: 8 additions & 2 deletions birdxplorer/storage.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator, List
from typing import Generator, List, Union

from psycopg2.extensions import AsIs, register_adapter
from pydantic import AnyUrl, HttpUrl
Expand Down Expand Up @@ -149,7 +149,13 @@ def get_topics(self) -> Generator[TopicModel, None, None]:
)

def get_notes(
self, note_ids=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None
self,
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[TweetId], None] = None,
language: Union[LanguageIdentifier, None] = None,
) -> Generator[NoteModel, None, None]:
with Session(self.engine) as sess:
query = sess.query(NoteRecord)
Expand Down
13 changes: 11 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import random
from collections.abc import Generator
from typing import List, Type
from typing import List, Type, Union
from unittest.mock import MagicMock, patch

from dotenv import load_dotenv
Expand All @@ -18,11 +18,15 @@

from birdxplorer.exceptions import UserEnrollmentNotFoundError
from birdxplorer.models import (
LanguageIdentifier,
Note,
NoteId,
ParticipantId,
Post,
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
UserEnrollment,
XUser,
Expand Down Expand Up @@ -133,7 +137,12 @@ def _get_topics() -> Generator[Topic, None, None]:
yield from topic_samples

def _get_notes(
note_ids=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None
note_ids: Union[List[NoteId], None] = None,
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[TweetId], None] = None,
language: Union[LanguageIdentifier, None] = None,
) -> Generator[Note, None, None]:
for note in note_samples:
if note_ids is not None and note.note_id not in note_ids:
Expand Down
34 changes: 16 additions & 18 deletions tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

from sqlalchemy.engine import Engine

from birdxplorer.models import Note, Post, PostId, Topic, TweetId, TwitterTimestamp
from birdxplorer.models import (
LanguageIdentifier,
Note,
NoteId,
Post,
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
)
from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord


Expand Down Expand Up @@ -120,7 +130,7 @@ def test_get_notes_by_ids_empty(
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
note_ids: List[int] = []
note_ids: List[NoteId] = []
expected: List[Note] = []
actual = list(storage.get_notes(note_ids=note_ids))
assert expected == actual
Expand Down Expand Up @@ -170,7 +180,7 @@ def test_get_notes_by_topic_ids(
) -> None:
storage = Storage(engine=engine_for_test)
topics = note_samples[0].topics
topic_ids = [0]
topic_ids: List[TopicId] = [TopicId.from_int(0)]
expected = sorted([note for note in note_samples if note.topics == topics], key=lambda note: note.note_id)
actual = sorted(list(storage.get_notes(topic_ids=topic_ids)), key=lambda note: note.note_id)
assert expected == actual
Expand All @@ -182,7 +192,7 @@ def test_get_notes_by_topic_ids_empty(
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
topic_ids: List[int] = []
topic_ids: List[TopicId] = []
expected: List[Note] = []
actual = list(storage.get_notes(topic_ids=topic_ids))
assert expected == actual
Expand All @@ -206,7 +216,7 @@ def test_get_notes_by_post_ids_empty(
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
post_ids: List[int] = []
post_ids: List[TweetId] = []
expected: List[Note] = []
actual = list(storage.get_notes(post_ids=post_ids))
assert expected == actual
Expand All @@ -218,19 +228,7 @@ def test_get_notes_by_language(
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
language = "en"
language = LanguageIdentifier("en")
expected = [note for note in note_samples if note.language == language]
actual = list(storage.get_notes(language=language))
assert expected == actual


def test_get_notes_by_language_empty(
engine_for_test: Engine,
note_samples: List[Note],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
language = ""
expected: List[Note] = []
actual = list(storage.get_notes(language=language))
assert expected == actual

0 comments on commit 16b2c68

Please sign in to comment.