Skip to content

Commit

Permalink
refactor: parametrize the filter parameters into unified test
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Aug 18, 2024
1 parent 8fefe9a commit 48c2967
Showing 1 changed file with 24 additions and 134 deletions.
158 changes: 24 additions & 134 deletions common/tests/test_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List
from typing import Any, Dict, List

import pytest
from sqlalchemy.engine import Engine

from birdxplorer_common.models import (
Expand Down Expand Up @@ -27,146 +28,35 @@ def test_get_topic_list(
assert expected == actual


def test_get_post_list(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
expected = sorted(post_samples, key=lambda x: x.post_id)
actual = sorted(storage.get_posts(), key=lambda x: x.post_id)
assert expected == actual


def test_get_posts_by_ids(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
post_ids = [post_samples[i].post_id for i in (0, 2)]
expected = [post_samples[i] for i in (0, 2)]
actual = list(storage.get_posts(post_ids=post_ids))
assert expected == actual


def test_get_posts_by_ids_empty(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
post_ids: List[PostId] = []
expected: List[Post] = []
actual = list(storage.get_posts(post_ids=post_ids))
assert expected == actual


def test_get_posts_by_created_at_range(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
start = TwitterTimestamp.from_int(1153921700000)
end = TwitterTimestamp.from_int(1153921800000)
expected = [post_samples[i] for i in (1,)]
actual = list(storage.get_posts(start=start, end=end))
assert expected == actual


def test_get_posts_by_created_at_start(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
start = TwitterTimestamp.from_int(1153921700000)
expected = [post_samples[i] for i in (1, 2)]
actual = list(storage.get_posts(start=start))
assert expected == actual


def test_get_posts_by_created_at_end(
@pytest.mark.parametrize(
["filter_args", "expected_indices"],
[
[dict(), [0, 1, 2]],
[dict(offset=1), [1, 2]],
[dict(limit=1), [0]],
[dict(offset=1, limit=1), [1]],
[dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]],
[dict(post_ids=[]), []],
[dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]],
[dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]],
[dict(end=TwitterTimestamp.from_int(1153921700000)), [0]],
[dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
[dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]],
],
)
def test_get_post(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
topic_records_sample: List[TopicRecord],
note_records_sample: List[NoteRecord],
filter_args: Dict[str, Any],
expected_indices: List[int],
) -> None:
storage = Storage(engine=engine_for_test)
end = TwitterTimestamp.from_int(1153921700000)
expected = [post_samples[i] for i in (0,)]
actual = list(storage.get_posts(end=end))
assert expected == actual


def test_search_posts_by_text(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
) -> None:
storage = Storage(engine=engine_for_test)
search_word = "https://t.co/xxxxxxxxxxx/"
expected = [post_samples[i] for i in (0, 2)]
actual = list(storage.get_posts(search_text=search_word))
assert actual == expected


def test_get_posts_by_note_ids(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
note_records_sample: List[NoteRecord],
) -> None:
storage = Storage(engine=engine_for_test)
note_ids = [NoteId.from_str("1234567890123456781")]
expected = [post_samples[i] for i in (0,)]
actual = list(storage.get_posts(note_ids=note_ids))
assert expected == actual


def test_get_posts_offset(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
) -> None:
storage = Storage(engine=engine_for_test)
expected = [post_samples[i] for i in (1, 2)]
actual = list(storage.get_posts(offset=1))
assert expected == actual


def test_get_posts_limit(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
) -> None:
storage = Storage(engine=engine_for_test)
expected = [post_samples[i] for i in (0,)]
actual = list(storage.get_posts(limit=1))
assert expected == actual


def test_get_posts_offset_and_limit_with_filter(
engine_for_test: Engine,
post_samples: List[Post],
post_records_sample: List[PostRecord],
) -> None:
storage = Storage(engine=engine_for_test)
search_word = "https://t.co/xxxxxxxxxxx/"
expected = [post_samples[i] for i in (2,)]
actual = list(storage.get_posts(search_text=search_word, offset=1, limit=1))
actual = list(storage.get_posts(**filter_args))
expected = [post_samples[i] for i in expected_indices]
assert expected == actual


Expand Down

0 comments on commit 48c2967

Please sign in to comment.