diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index b29bd57..048033a 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -1,5 +1,6 @@ -from typing import List +from typing import Any, Dict, List +import pytest from sqlalchemy.engine import Engine from birdxplorer_common.models import ( @@ -27,146 +28,35 @@ def test_get_topic_list( assert expected == actual -def test_get_post_list( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - expected = sorted(post_samples, key=lambda x: x.post_id) - actual = sorted(storage.get_posts(), key=lambda x: x.post_id) - assert expected == actual - - -def test_get_posts_by_ids( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - post_ids = [post_samples[i].post_id for i in (0, 2)] - expected = [post_samples[i] for i in (0, 2)] - actual = list(storage.get_posts(post_ids=post_ids)) - assert expected == actual - - -def test_get_posts_by_ids_empty( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - post_ids: List[PostId] = [] - expected: List[Post] = [] - actual = list(storage.get_posts(post_ids=post_ids)) - assert expected == actual - - -def test_get_posts_by_created_at_range( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - start = TwitterTimestamp.from_int(1153921700000) - end = TwitterTimestamp.from_int(1153921800000) - expected = [post_samples[i] for i in (1,)] - actual = list(storage.get_posts(start=start, end=end)) - assert expected == actual - - -def test_get_posts_by_created_at_start( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - topic_records_sample: List[TopicRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - start = TwitterTimestamp.from_int(1153921700000) - expected = [post_samples[i] for i in (1, 2)] - actual = list(storage.get_posts(start=start)) - assert expected == actual - - -def test_get_posts_by_created_at_end( +@pytest.mark.parametrize( + ["filter_args", "expected_indices"], + [ + [dict(), [0, 1, 2]], + [dict(offset=1), [1, 2]], + [dict(limit=1), [0]], + [dict(offset=1, limit=1), [1]], + [dict(post_ids=[PostId.from_str("2234567890123456781"), PostId.from_str("2234567890123456801")]), [0, 2]], + [dict(post_ids=[]), []], + [dict(start=TwitterTimestamp.from_int(1153921700000), end=TwitterTimestamp.from_int(1153921800000)), [1]], + [dict(start=TwitterTimestamp.from_int(1153921700000)), [1, 2]], + [dict(end=TwitterTimestamp.from_int(1153921700000)), [0]], + [dict(search_text="https://t.co/xxxxxxxxxxx/"), [0, 2]], + [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], + [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], + ], +) +def test_get_post( engine_for_test: Engine, post_samples: List[Post], post_records_sample: List[PostRecord], topic_records_sample: List[TopicRecord], note_records_sample: List[NoteRecord], + filter_args: Dict[str, Any], + expected_indices: List[int], ) -> None: storage = Storage(engine=engine_for_test) - end = TwitterTimestamp.from_int(1153921700000) - expected = [post_samples[i] for i in (0,)] - actual = list(storage.get_posts(end=end)) - assert expected == actual - - -def test_search_posts_by_text( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], -) -> None: - storage = Storage(engine=engine_for_test) - search_word = "https://t.co/xxxxxxxxxxx/" - expected = [post_samples[i] for i in (0, 2)] - actual = list(storage.get_posts(search_text=search_word)) - assert actual == expected - - -def test_get_posts_by_note_ids( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - note_ids = [NoteId.from_str("1234567890123456781")] - expected = [post_samples[i] for i in (0,)] - actual = list(storage.get_posts(note_ids=note_ids)) - assert expected == actual - - -def test_get_posts_offset( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], -) -> None: - storage = Storage(engine=engine_for_test) - expected = [post_samples[i] for i in (1, 2)] - actual = list(storage.get_posts(offset=1)) - assert expected == actual - - -def test_get_posts_limit( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], -) -> None: - storage = Storage(engine=engine_for_test) - expected = [post_samples[i] for i in (0,)] - actual = list(storage.get_posts(limit=1)) - assert expected == actual - - -def test_get_posts_offset_and_limit_with_filter( - engine_for_test: Engine, - post_samples: List[Post], - post_records_sample: List[PostRecord], -) -> None: - storage = Storage(engine=engine_for_test) - search_word = "https://t.co/xxxxxxxxxxx/" - expected = [post_samples[i] for i in (2,)] - actual = list(storage.get_posts(search_text=search_word, offset=1, limit=1)) + actual = list(storage.get_posts(**filter_args)) + expected = [post_samples[i] for i in expected_indices] assert expected == actual