From 2ab6317388f48dc694015ecc568b7bdb23cf77d3 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Sun, 10 Mar 2024 23:50:07 +0900 Subject: [PATCH 01/19] WIP: Add notes endpoint Done - add /notes router - add get_notes func in Storage class Undone - allow multiple topic_id and post_id filters in get_notes() - write tests for the endpoint --- birdxplorer/routers/data.py | 17 +++++++++++++++++ birdxplorer/storage.py | 15 +++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index 20327a2..eb96aa2 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -6,6 +6,7 @@ from ..models import ( BaseModel, + Note, ParticipantId, Post, PostId, @@ -20,6 +21,10 @@ class TopicListResponse(BaseModel): data: List[Topic] +class NoteListResponse(BaseModel): + data: List[Note] + + class PostListResponse(BaseModel): data: List[Post] @@ -56,6 +61,18 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User def get_topics() -> TopicListResponse: return TopicListResponse(data=list(storage.get_topics())) + @router.get("/notes", response_model=NoteListResponse) + def get_notes( + created_at_from: Union[int, None] = None, + created_at_to: Union[int, None] = None, + topic_id: Union[str, None] = None, + post_id: Union[str, None] = None, + language: Union[str, None] = None, + ) -> NoteListResponse: + return NoteListResponse( + data=list(storage.get_notes(created_at_from, created_at_to, topic_id, post_id, language)) + ) + @router.get("/posts", response_model=PostListResponse) def get_posts( post_id: Union[List[PostId], None] = Query(default=None), diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index caeea09..296b9f1 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -151,6 +151,21 @@ def get_topics(self) -> Generator[TopicModel, None, None]: yield TopicModel( topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0 ) + 
def get_notes(self, created_at_from, created_at_to, topic_id, post_id, language) -> Generator[NoteRecord, None, None]: + with Session(self.engine) as sess: + query = sess.query(NoteRecord) + if created_at_from: + query = query.filter(NoteRecord.created_at >= created_at_from) + if created_at_to: + query = query.filter(NoteRecord.created_at <= created_at_to) + if topic_id: + query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id == topic_id) + if post_id: + query = query.filter(NoteRecord.post_id == post_id) + if language: + query = query.filter(NoteRecord.language == language) + for note_record in query.all(): + yield note_record def get_posts(self) -> Generator[PostModel, None, None]: with Session(self.engine) as sess: From 593cee1f494dcc8e26a5655bd51552007fecfda3 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Mon, 25 Mar 2024 15:55:47 +0900 Subject: [PATCH 02/19] make note_id, topic_id and post_id plural --- birdxplorer/routers/data.py | 32 ++++++++++++++++++++++++-------- birdxplorer/storage.py | 15 ++++++++++----- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index eb96aa2..69e29a6 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -7,10 +7,12 @@ from ..models import ( BaseModel, Note, + NoteId, ParticipantId, Post, PostId, Topic, + TweetId, TwitterTimestamp, UserEnrollment, ) @@ -63,15 +65,29 @@ def get_topics() -> TopicListResponse: @router.get("/notes", response_model=NoteListResponse) def get_notes( - created_at_from: Union[int, None] = None, - created_at_to: Union[int, None] = None, - topic_id: Union[str, None] = None, - post_id: Union[str, None] = None, - language: Union[str, None] = None, + note_id: Union[List[NoteId], None] = Query(default=None), + created_at_from: Union[None, int, str] = Query(default=None), + created_at_to: Union[None, int, str] = Query(default=None), + topic_id: Union[List[str], None] = Query(default=None), + 
post_id: Union[List[TweetId], None] = Query(default=None), + language: Union[str, None] = Query(default=None), ) -> NoteListResponse: - return NoteListResponse( - data=list(storage.get_notes(created_at_from, created_at_to, topic_id, post_id, language)) - ) + filters = {} + + if note_id is not None: + filters["note_ids"] = note_id + if created_at_from is not None: + filters["created_at_from"] = created_at_from + if created_at_to is not None: + filters["created_at_to"] = created_at_to + if topic_id is not None: + filters["topic_ids"] = topic_id + if post_id is not None: + filters["post_ids"] = post_id + if language is not None: + filters["language"] = language + + return NoteListResponse(data=list(storage.get_notes(**filters))) @router.get("/posts", response_model=PostListResponse) def get_posts( diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index 296b9f1..dc48cff 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -151,17 +151,22 @@ def get_topics(self) -> Generator[TopicModel, None, None]: yield TopicModel( topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0 ) - def get_notes(self, created_at_from, created_at_to, topic_id, post_id, language) -> Generator[NoteRecord, None, None]: + + def get_notes( + self, note_id, created_at_from, created_at_to, topic_ids, post_ids, language + ) -> Generator[NoteRecord, None, None]: with Session(self.engine) as sess: query = sess.query(NoteRecord) + if note_id: + query = query.filter(NoteRecord.note_id.in_(note_id)) if created_at_from: query = query.filter(NoteRecord.created_at >= created_at_from) if created_at_to: query = query.filter(NoteRecord.created_at <= created_at_to) - if topic_id: - query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id == topic_id) - if post_id: - query = query.filter(NoteRecord.post_id == post_id) + if topic_ids: + query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id.in_(topic_ids)) + 
if post_ids: + query = query.filter(NoteRecord.post_id.in_(post_ids)) if language: query = query.filter(NoteRecord.language == language) for note_record in query.all(): From 85723cc9f8a05b7923cc5148c9dc6657e0f308fd Mon Sep 17 00:00:00 2001 From: kota-yata Date: Thu, 4 Apr 2024 12:50:56 +0900 Subject: [PATCH 03/19] add tests for get_notes --- birdxplorer/routers/data.py | 18 ++--- birdxplorer/storage.py | 14 ++-- tests/test_storage.py | 136 +++++++++++++++++++++++++++++++++++- 3 files changed, 151 insertions(+), 17 deletions(-) diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index 69e29a6..5e2ad06 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -65,25 +65,25 @@ def get_topics() -> TopicListResponse: @router.get("/notes", response_model=NoteListResponse) def get_notes( - note_id: Union[List[NoteId], None] = Query(default=None), + note_ids: Union[List[NoteId], None] = Query(default=None), created_at_from: Union[None, int, str] = Query(default=None), created_at_to: Union[None, int, str] = Query(default=None), - topic_id: Union[List[str], None] = Query(default=None), - post_id: Union[List[TweetId], None] = Query(default=None), + topic_ids: Union[List[str], None] = Query(default=None), + post_ids: Union[List[TweetId], None] = Query(default=None), language: Union[str, None] = Query(default=None), ) -> NoteListResponse: filters = {} - if note_id is not None: - filters["note_ids"] = note_id + if note_ids is not None: + filters["note_ids"] = note_ids if created_at_from is not None: filters["created_at_from"] = created_at_from if created_at_to is not None: filters["created_at_to"] = created_at_to - if topic_id is not None: - filters["topic_ids"] = topic_id - if post_id is not None: - filters["post_ids"] = post_id + if topic_ids is not None: + filters["topic_ids"] = topic_ids + if post_ids is not None: + filters["post_ids"] = post_ids if language is not None: filters["language"] = language diff --git a/birdxplorer/storage.py 
b/birdxplorer/storage.py index dc48cff..20d487b 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -153,21 +153,21 @@ def get_topics(self) -> Generator[TopicModel, None, None]: ) def get_notes( - self, note_id, created_at_from, created_at_to, topic_ids, post_ids, language + self, note_id=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None ) -> Generator[NoteRecord, None, None]: with Session(self.engine) as sess: query = sess.query(NoteRecord) - if note_id: + if note_id is not None: query = query.filter(NoteRecord.note_id.in_(note_id)) - if created_at_from: + if created_at_from is not None: query = query.filter(NoteRecord.created_at >= created_at_from) - if created_at_to: + if created_at_to is not None: query = query.filter(NoteRecord.created_at <= created_at_to) - if topic_ids: + if topic_ids is not None: query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id.in_(topic_ids)) - if post_ids: + if post_ids is not None: query = query.filter(NoteRecord.post_id.in_(post_ids)) - if language: + if language is not None: query = query.filter(NoteRecord.language == language) for note_record in query.all(): yield note_record diff --git a/tests/test_storage.py b/tests/test_storage.py index 2951208..37ca90c 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,8 +1,9 @@ +from datetime import datetime, timedelta from typing import List from sqlalchemy.engine import Engine -from birdxplorer.models import Post, PostId, Topic, TwitterTimestamp +from birdxplorer.models import Note, Post, PostId, Topic, TwitterTimestamp from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord @@ -100,3 +101,136 @@ def test_get_posts_by_created_at_end( expected = [post_samples[i] for i in (0,)] actual = list(storage.get_posts_by_created_at_end(end)) assert expected == actual + + +def test_get_notes_by_ids( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: 
List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + note_ids = [note_samples[i].note_id for i in (0, 2)] + expected = [note_samples[i] for i in (0, 2)] + actual = list(storage.get_notes(note_id=note_ids)) + assert expected == actual + + +def test_get_notes_by_ids_empty( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + note_ids: List[int] = [] + expected: List[Note] = [] + actual = list(storage.get_notes(note_id=note_ids)) + assert expected == actual + + +def test_get_notes_by_created_at_range( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + start = datetime.now() - timedelta(days=2) + end = datetime.now() - timedelta(days=1) + expected = [note for note in note_samples if start <= note.created_at <= end] + actual = list(storage.get_notes(created_at_from=start, created_at_to=end)) + assert expected == actual + + +def test_get_notes_by_created_at_start( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + start = datetime.now() - timedelta(days=1) + expected = [note for note in note_samples if note.created_at >= start] + actual = list(storage.get_notes(created_at_from=start)) + assert expected == actual + + +def test_get_notes_by_created_at_end( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + end = datetime.now() - timedelta(days=1) + expected = [note for note in note_samples if note.created_at <= end] + actual = list(storage.get_notes(created_at_to=end)) + assert expected == actual + + +def test_get_notes_by_topic_ids( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + 
storage = Storage(engine=engine_for_test) + topic_ids = [1, 2] + expected = [note for note in note_samples if note.topic_id in topic_ids] + actual = list(storage.get_notes(topic_ids=topic_ids)) + assert expected == actual + + +def test_get_notes_by_topic_ids_empty( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + topic_ids: List[int] = [] + expected: List[Note] = [] + actual = list(storage.get_notes(topic_ids=topic_ids)) + assert expected == actual + + +def test_get_notes_by_post_ids( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + post_ids = [1, 2] + expected = [note for note in note_samples if note.post_id in post_ids] + actual = list(storage.get_notes(post_ids=post_ids)) + assert expected == actual + + +def test_get_notes_by_post_ids_empty( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + post_ids: List[int] = [] + expected: List[Note] = [] + actual = list(storage.get_notes(post_ids=post_ids)) + assert expected == actual + + +def test_get_notes_by_language( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + language = "en" + expected = [note for note in note_samples if note.language == language] + actual = list(storage.get_notes(language=language)) + assert expected == actual + + +def test_get_notes_by_language_empty( + engine_for_test: Engine, + note_samples: List[Note], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + language = "" + expected: List[Note] = [] + actual = list(storage.get_notes(language=language)) + assert expected == actual From da5c01ce3edac02c4bb8668e1ebdcfac02be9357 Mon Sep 17 
00:00:00 2001 From: kota-yata Date: Thu, 4 Apr 2024 13:00:55 +0900 Subject: [PATCH 04/19] test rename & modify yielding type in storage.py --- birdxplorer/storage.py | 19 +++++++++++-------- tests/test_storage.py | 4 ++-- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index 20d487b..37ebbab 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -7,13 +7,9 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import DECIMAL, JSON, Integer, String -from .models import ( - LanguageIdentifier, - MediaDetails, - NonNegativeInt, - NoteId, - ParticipantId, -) +from .models import LanguageIdentifier, MediaDetails, NonNegativeInt +from .models import Note as NoteModel +from .models import NoteId, ParticipantId from .models import Post as PostModel from .models import PostId, SummaryString from .models import Topic as TopicModel @@ -170,7 +166,14 @@ def get_notes( if language is not None: query = query.filter(NoteRecord.language == language) for note_record in query.all(): - yield note_record + yield NoteModel( + note_id=note_record.note_id, + post_id=note_record.post_id, + topics=[topic.topic for topic in note_record.topics], + language=note_record.language, + summary=note_record.summary, + created_at=note_record.created_at, + ) def get_posts(self) -> Generator[PostModel, None, None]: with Session(self.engine) as sess: diff --git a/tests/test_storage.py b/tests/test_storage.py index 37ca90c..a729f8a 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -140,7 +140,7 @@ def test_get_notes_by_created_at_range( assert expected == actual -def test_get_notes_by_created_at_start( +def test_get_notes_by_created_at_from( engine_for_test: Engine, note_samples: List[Note], note_records_sample: List[NoteRecord], @@ -152,7 +152,7 @@ def test_get_notes_by_created_at_start( assert expected == actual -def test_get_notes_by_created_at_end( 
+def test_get_notes_by_created_at_to( engine_for_test: Engine, note_samples: List[Note], note_records_sample: List[NoteRecord], From 9fd724ef9be440004ef2da7bc591c221a2234e01 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Thu, 4 Apr 2024 22:45:05 +0900 Subject: [PATCH 05/19] add date test --- birdxplorer/storage.py | 20 +++++++++--- tests/routers/test_data.py | 65 +++++++++++++++++++++++++++++++++++++- tests/test_storage.py | 30 +++++++++--------- 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index 37ebbab..4ebdf01 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -149,12 +149,12 @@ def get_topics(self) -> Generator[TopicModel, None, None]: ) def get_notes( - self, note_id=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None - ) -> Generator[NoteRecord, None, None]: + self, note_ids=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None + ) -> Generator[NoteModel, None, None]: with Session(self.engine) as sess: query = sess.query(NoteRecord) - if note_id is not None: - query = query.filter(NoteRecord.note_id.in_(note_id)) + if note_ids is not None: + query = query.filter(NoteRecord.note_id.in_(note_ids)) if created_at_from is not None: query = query.filter(NoteRecord.created_at >= created_at_from) if created_at_to is not None: @@ -169,7 +169,17 @@ def get_notes( yield NoteModel( note_id=note_record.note_id, post_id=note_record.post_id, - topics=[topic.topic for topic in note_record.topics], + topics=[ + TopicModel( + topic_id=topic.topic_id, + label=topic.topic.label, + reference_count=sess.query(func.count(NoteTopicAssociation.note_id)) + .filter(NoteTopicAssociation.topic_id == topic.topic_id) + .scalar() + or 0, + ) + for topic in note_record.topics + ], language=note_record.language, summary=note_record.summary, created_at=note_record.created_at, diff --git a/tests/routers/test_data.py 
b/tests/routers/test_data.py index 21ef02e..96f42c7 100644 --- a/tests/routers/test_data.py +++ b/tests/routers/test_data.py @@ -3,7 +3,7 @@ from fastapi.testclient import TestClient -from birdxplorer.models import Post, Topic, UserEnrollment +from birdxplorer.models import Note, Post, Topic, UserEnrollment def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None: @@ -81,3 +81,66 @@ def test_posts_get_created_at_end_filter_accepts_integer(client: TestClient, pos def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List[Post]) -> None: response = client.get("/api/v1/data/posts/?createdAtStart=1153921700&createdAtEnd=1153921700") assert response.status_code == 422 + + +def test_notes_get(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(d.model_dump_json()) for d in note_samples]} + + +def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None: + response = client.get(f"/api/v1/data/notes/?noteId={note_samples[0].note_id},{note_samples[2].note_id}") + assert response.status_code == 200 + res_json = response.json() + assert res_json == { + "data": [json.loads(note_samples[0].model_dump_json()), json.loads(note_samples[2].model_dump_json())] + } + + +def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=2006-7-25 00:00:00&createdAtTo=2006-7-30 23:59:59") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[1].model_dump_json())]} + + +def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=2006-7-25 00:00:00") + assert 
response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2)]} + + +def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtTo=2006-7-30 00:00:00") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1)]} + + +def test_notes_get_created_at_range_filter_accepts_integer(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=1153921700000&createdAtTo=1154921800000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[1].model_dump_json())]} + + +def test_notes_get_created_at_from_filter_accepts_integer(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=1153921700000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2)]} + + +def test_notes_get_created_at_to_filter_accepts_integer(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtTo=1154921800000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1)]} + + +def test_notes_get_timestamp_out_of_range(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=1153921700&createdAtTo=1153921700") + assert response.status_code == 422 diff --git a/tests/test_storage.py b/tests/test_storage.py index a729f8a..2da873b 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -3,7 +3,7 @@ from sqlalchemy.engine import Engine 
-from birdxplorer.models import Note, Post, PostId, Topic, TwitterTimestamp +from birdxplorer.models import Note, Post, PostId, Topic, TweetId, TwitterTimestamp from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord @@ -111,7 +111,7 @@ def test_get_notes_by_ids( storage = Storage(engine=engine_for_test) note_ids = [note_samples[i].note_id for i in (0, 2)] expected = [note_samples[i] for i in (0, 2)] - actual = list(storage.get_notes(note_id=note_ids)) + actual = list(storage.get_notes(note_ids=note_ids)) assert expected == actual @@ -123,7 +123,7 @@ def test_get_notes_by_ids_empty( storage = Storage(engine=engine_for_test) note_ids: List[int] = [] expected: List[Note] = [] - actual = list(storage.get_notes(note_id=note_ids)) + actual = list(storage.get_notes(note_ids=note_ids)) assert expected == actual @@ -133,10 +133,10 @@ def test_get_notes_by_created_at_range( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - start = datetime.now() - timedelta(days=2) - end = datetime.now() - timedelta(days=1) - expected = [note for note in note_samples if start <= note.created_at <= end] - actual = list(storage.get_notes(created_at_from=start, created_at_to=end)) + created_at_from = TwitterTimestamp.from_int(1152921602000) + created_at_to = TwitterTimestamp.from_int(1152921603000) + expected = [note for note in note_samples if created_at_from <= note.created_at <= created_at_to] + actual = list(storage.get_notes(created_at_from=created_at_from, created_at_to=created_at_to)) assert expected == actual @@ -146,9 +146,9 @@ def test_get_notes_by_created_at_from( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - start = datetime.now() - timedelta(days=1) - expected = [note for note in note_samples if note.created_at >= start] - actual = list(storage.get_notes(created_at_from=start)) + created_at_from = TwitterTimestamp.from_int(1152921602000) + expected = [note for note in 
note_samples if note.created_at >= created_at_from] + actual = list(storage.get_notes(created_at_from=created_at_from)) assert expected == actual @@ -158,9 +158,9 @@ def test_get_notes_by_created_at_to( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - end = datetime.now() - timedelta(days=1) - expected = [note for note in note_samples if note.created_at <= end] - actual = list(storage.get_notes(created_at_to=end)) + created_at_to = TwitterTimestamp.from_int(1152921603000) + expected = [note for note in note_samples if note.created_at <= created_at_to] + actual = list(storage.get_notes(created_at_to=created_at_to)) assert expected == actual @@ -171,7 +171,7 @@ def test_get_notes_by_topic_ids( ) -> None: storage = Storage(engine=engine_for_test) topic_ids = [1, 2] - expected = [note for note in note_samples if note.topic_id in topic_ids] + expected = [note for note in note_samples if note.topics == topic_ids] actual = list(storage.get_notes(topic_ids=topic_ids)) assert expected == actual @@ -194,7 +194,7 @@ def test_get_notes_by_post_ids( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - post_ids = [1, 2] + post_ids = [TweetId.from_str("1"), TweetId.from_str("2")] expected = [note for note in note_samples if note.post_id in post_ids] actual = list(storage.get_notes(post_ids=post_ids)) assert expected == actual From e2481fa17a3d2570be5941d31038a407afe27245 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Thu, 11 Apr 2024 20:42:18 +0900 Subject: [PATCH 06/19] test_get_notes_by_topic_ids passed --- birdxplorer/storage.py | 10 +++++++++- tests/test_storage.py | 7 ++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index 4ebdf01..e3c13fe 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -160,7 +160,15 @@ def get_notes( if created_at_to is not None: query = query.filter(NoteRecord.created_at <= 
created_at_to) if topic_ids is not None: - query = query.join(NoteTopicAssociation).filter(NoteTopicAssociation.topic_id.in_(topic_ids)) + # 同じトピックIDを持つノートを取得するためのサブクエリ + # とりあえずANDを実装 + subq = ( + select(NoteTopicAssociation.note_id) + .group_by(NoteTopicAssociation.note_id) + .having(func.array_agg(NoteTopicAssociation.topic_id) == topic_ids) + .subquery() + ) + query = query.join(subq, NoteRecord.note_id == subq.c.note_id) if post_ids is not None: query = query.filter(NoteRecord.post_id.in_(post_ids)) if language is not None: diff --git a/tests/test_storage.py b/tests/test_storage.py index 2da873b..e76caf9 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -170,9 +170,10 @@ def test_get_notes_by_topic_ids( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - topic_ids = [1, 2] - expected = [note for note in note_samples if note.topics == topic_ids] - actual = list(storage.get_notes(topic_ids=topic_ids)) + topics = note_samples[0].topics + topic_ids = [0] + expected = sorted([note for note in note_samples if note.topics == topics], key=lambda note: note.note_id) + actual = sorted(list(storage.get_notes(topic_ids=topic_ids)), key=lambda note: note.note_id) assert expected == actual From c60d827ec4c5322ea96f90d1249d32a3ba1da8ca Mon Sep 17 00:00:00 2001 From: kota-yata Date: Thu, 11 Apr 2024 20:51:18 +0900 Subject: [PATCH 07/19] test_get_notes_by_post_ids passed --- tests/test_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_storage.py b/tests/test_storage.py index e76caf9..0d9cc1a 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -195,7 +195,7 @@ def test_get_notes_by_post_ids( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - post_ids = [TweetId.from_str("1"), TweetId.from_str("2")] + post_ids = [TweetId.from_str("2234567890123456781"), TweetId.from_str("2234567890123456782")] expected = [note for note in 
note_samples if note.post_id in post_ids] actual = list(storage.get_notes(post_ids=post_ids)) assert expected == actual From cf0d95bcf7e123d4f29c9d2cc86d485f06bde496 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Fri, 12 Apr 2024 14:51:43 +0900 Subject: [PATCH 08/19] test_notes_get passed --- birdxplorer/routers/data.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index 5e2ad06..db9644d 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -65,11 +65,11 @@ def get_topics() -> TopicListResponse: @router.get("/notes", response_model=NoteListResponse) def get_notes( - note_ids: Union[List[NoteId], None] = Query(default=None), - created_at_from: Union[None, int, str] = Query(default=None), - created_at_to: Union[None, int, str] = Query(default=None), - topic_ids: Union[List[str], None] = Query(default=None), - post_ids: Union[List[TweetId], None] = Query(default=None), + note_ids: Union[List[NoteId], None] = Query(default=None, alias="noteIds"), + created_at_from: Union[None, int] = Query(default=None, alias="createdAtFrom"), + created_at_to: Union[None, int] = Query(default=None, alias="createdAtTo"), + topic_ids: Union[List[str], None] = Query(default=None, alias="topicIds"), + post_ids: Union[List[TweetId], None] = Query(default=None, alias="postIds"), language: Union[str, None] = Query(default=None), ) -> NoteListResponse: filters = {} @@ -77,9 +77,9 @@ def get_notes( if note_ids is not None: filters["note_ids"] = note_ids if created_at_from is not None: - filters["created_at_from"] = created_at_from + filters["created_at_from"] = TwitterTimestamp.from_int(created_at_from) if created_at_to is not None: - filters["created_at_to"] = created_at_to + filters["created_at_to"] = TwitterTimestamp.from_int(created_at_to) if topic_ids is not None: filters["topic_ids"] = topic_ids if post_ids is not None: From ee3420ddd9aeb81500055edc8f0508b76dd2fc73 Mon Sep 
17 00:00:00 2001 From: kota-yata Date: Fri, 12 Apr 2024 14:52:16 +0900 Subject: [PATCH 09/19] test_notes_get passed --- tests/conftest.py | 24 ++++++++++++++++++++++- tests/routers/test_data.py | 39 +++++++------------------------------- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fe3ece0..2474fb1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -114,7 +114,10 @@ def user_enrollment_samples( @fixture def mock_storage( - user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic], post_samples: List[Post] + user_enrollment_samples: List[UserEnrollment], + topic_samples: List[Topic], + post_samples: List[Post], + note_samples: List[Note], ) -> Generator[MagicMock, None, None]: mock = MagicMock(spec=Storage) @@ -129,7 +132,26 @@ def _get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> Use def _get_topics() -> Generator[Topic, None, None]: yield from topic_samples + def _get_notes( + note_ids=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None + ) -> Generator[Note, None, None]: + for note in note_samples: + if note_ids is not None and note.note_id not in note_ids: + continue + if created_at_from is not None and note.created_at < created_at_from: + continue + if created_at_to is not None and note.created_at > created_at_to: + continue + if topic_ids is not None and not set(topic_ids).issubset({topic.topic_id for topic in note.topics}): + continue + if post_ids is not None and note.post_id not in post_ids: + continue + if language is not None and note.language != language: + continue + yield note + mock.get_topics.side_effect = _get_topics + mock.get_notes.side_effect = _get_notes def _get_posts() -> Generator[Post, None, None]: yield from post_samples diff --git a/tests/routers/test_data.py b/tests/routers/test_data.py index 96f42c7..47ce0e5 100644 --- a/tests/routers/test_data.py +++ b/tests/routers/test_data.py @@ -91,56 
+91,31 @@ def test_notes_get(client: TestClient, note_samples: List[Note]) -> None: def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None: - response = client.get(f"/api/v1/data/notes/?noteId={note_samples[0].note_id},{note_samples[2].note_id}") + response = client.get(f"/api/v1/data/notes/?noteIds={note_samples[0].note_id}&noteIds={note_samples[2].note_id}") assert response.status_code == 200 res_json = response.json() assert res_json == { "data": [json.loads(note_samples[0].model_dump_json()), json.loads(note_samples[2].model_dump_json())] } - def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=2006-7-25 00:00:00&createdAtTo=2006-7-30 23:59:59") + response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[1].model_dump_json())]} + assert res_json == {"data": [json.loads(note_samples[1].model_dump_json()) for i in (1, 2, 3)]} def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=2006-7-25 00:00:00") + response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2)]} + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)]} def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtTo=2006-7-30 00:00:00") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1)]} - 
- -def test_notes_get_created_at_range_filter_accepts_integer(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=1153921700000&createdAtTo=1154921800000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[1].model_dump_json())]} - - -def test_notes_get_created_at_from_filter_accepts_integer(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=1153921700000") + response = client.get("/api/v1/data/notes/?createdAtTo=1152921603000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2)]} + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]} -def test_notes_get_created_at_to_filter_accepts_integer(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtTo=1154921800000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1)]} - - -def test_notes_get_timestamp_out_of_range(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=1153921700&createdAtTo=1153921700") - assert response.status_code == 422 From 83e9b8c701f92e8651a9ebefa91495617c138577 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Fri, 12 Apr 2024 15:04:49 +0900 Subject: [PATCH 10/19] test_notes_get_has_created_at_filter_from_and_to passed --- tests/routers/test_data.py | 5 ++--- tests/test_storage.py | 1 - 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/routers/test_data.py b/tests/routers/test_data.py index 47ce0e5..8f719f7 100644 --- a/tests/routers/test_data.py +++ b/tests/routers/test_data.py @@ -98,11 +98,12 @@ def 
test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Not "data": [json.loads(note_samples[0].model_dump_json()), json.loads(note_samples[2].model_dump_json())] } + def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None: response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000") assert response.status_code == 200 res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[1].model_dump_json()) for i in (1, 2, 3)]} + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)]} def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None: @@ -117,5 +118,3 @@ def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: Li assert response.status_code == 200 res_json = response.json() assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]} - - diff --git a/tests/test_storage.py b/tests/test_storage.py index 0d9cc1a..4865e54 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,4 +1,3 @@ -from datetime import datetime, timedelta from typing import List from sqlalchemy.engine import Engine From 16b2c68249291827772ff5826113daa029c38011 Mon Sep 17 00:00:00 2001 From: kota-yata Date: Mon, 15 Apr 2024 17:07:43 +0900 Subject: [PATCH 11/19] passed all tests written --- birdxplorer/routers/data.py | 38 ++++++++++++++++++------------------- birdxplorer/storage.py | 10 ++++++++-- tests/conftest.py | 13 +++++++++++-- tests/test_storage.py | 34 ++++++++++++++++----------------- 4 files changed, 53 insertions(+), 42 deletions(-) diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index db9644d..7200623 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -6,12 +6,14 @@ from ..models import ( BaseModel, + LanguageIdentifier, Note, NoteId, ParticipantId, Post, 
PostId, Topic, + TopicId, TweetId, TwitterTimestamp, UserEnrollment, @@ -66,28 +68,24 @@ def get_topics() -> TopicListResponse: @router.get("/notes", response_model=NoteListResponse) def get_notes( note_ids: Union[List[NoteId], None] = Query(default=None, alias="noteIds"), - created_at_from: Union[None, int] = Query(default=None, alias="createdAtFrom"), - created_at_to: Union[None, int] = Query(default=None, alias="createdAtTo"), - topic_ids: Union[List[str], None] = Query(default=None, alias="topicIds"), + created_at_from: Union[None, TwitterTimestamp] = Query(default=None, alias="createdAtFrom"), + created_at_to: Union[None, TwitterTimestamp] = Query(default=None, alias="createdAtTo"), + topic_ids: Union[List[TopicId], None] = Query(default=None, alias="topicIds"), post_ids: Union[List[TweetId], None] = Query(default=None, alias="postIds"), - language: Union[str, None] = Query(default=None), + language: Union[LanguageIdentifier, None] = Query(default=None), ) -> NoteListResponse: - filters = {} - - if note_ids is not None: - filters["note_ids"] = note_ids - if created_at_from is not None: - filters["created_at_from"] = TwitterTimestamp.from_int(created_at_from) - if created_at_to is not None: - filters["created_at_to"] = TwitterTimestamp.from_int(created_at_to) - if topic_ids is not None: - filters["topic_ids"] = topic_ids - if post_ids is not None: - filters["post_ids"] = post_ids - if language is not None: - filters["language"] = language - - return NoteListResponse(data=list(storage.get_notes(**filters))) + return NoteListResponse( + data=list( + storage.get_notes( + note_ids=note_ids, + created_at_from=created_at_from, + created_at_to=created_at_to, + topic_ids=topic_ids, + post_ids=post_ids, + language=language, + ) + ) + ) @router.get("/posts", response_model=PostListResponse) def get_posts( diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index e3c13fe..2c9d2fe 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -1,4 +1,4 @@ 
-from typing import Generator, List +from typing import Generator, List, Union from psycopg2.extensions import AsIs, register_adapter from pydantic import AnyUrl, HttpUrl @@ -149,7 +149,13 @@ def get_topics(self) -> Generator[TopicModel, None, None]: ) def get_notes( - self, note_ids=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None + self, + note_ids: Union[List[NoteId], None] = None, + created_at_from: Union[None, TwitterTimestamp] = None, + created_at_to: Union[None, TwitterTimestamp] = None, + topic_ids: Union[List[TopicId], None] = None, + post_ids: Union[List[TweetId], None] = None, + language: Union[LanguageIdentifier, None] = None, ) -> Generator[NoteModel, None, None]: with Session(self.engine) as sess: query = sess.query(NoteRecord) diff --git a/tests/conftest.py b/tests/conftest.py index 2474fb1..29f0f09 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import os import random from collections.abc import Generator -from typing import List, Type +from typing import List, Type, Union from unittest.mock import MagicMock, patch from dotenv import load_dotenv @@ -18,11 +18,15 @@ from birdxplorer.exceptions import UserEnrollmentNotFoundError from birdxplorer.models import ( + LanguageIdentifier, Note, + NoteId, ParticipantId, Post, PostId, Topic, + TopicId, + TweetId, TwitterTimestamp, UserEnrollment, XUser, @@ -133,7 +137,12 @@ def _get_topics() -> Generator[Topic, None, None]: yield from topic_samples def _get_notes( - note_ids=None, created_at_from=None, created_at_to=None, topic_ids=None, post_ids=None, language=None + note_ids: Union[List[NoteId], None] = None, + created_at_from: Union[None, TwitterTimestamp] = None, + created_at_to: Union[None, TwitterTimestamp] = None, + topic_ids: Union[List[TopicId], None] = None, + post_ids: Union[List[TweetId], None] = None, + language: Union[LanguageIdentifier, None] = None, ) -> Generator[Note, None, None]: for note in note_samples: if note_ids is not 
None and note.note_id not in note_ids: diff --git a/tests/test_storage.py b/tests/test_storage.py index 4865e54..5cbcf20 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -2,7 +2,17 @@ from sqlalchemy.engine import Engine -from birdxplorer.models import Note, Post, PostId, Topic, TweetId, TwitterTimestamp +from birdxplorer.models import ( + LanguageIdentifier, + Note, + NoteId, + Post, + PostId, + Topic, + TopicId, + TweetId, + TwitterTimestamp, +) from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord @@ -120,7 +130,7 @@ def test_get_notes_by_ids_empty( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - note_ids: List[int] = [] + note_ids: List[NoteId] = [] expected: List[Note] = [] actual = list(storage.get_notes(note_ids=note_ids)) assert expected == actual @@ -170,7 +180,7 @@ def test_get_notes_by_topic_ids( ) -> None: storage = Storage(engine=engine_for_test) topics = note_samples[0].topics - topic_ids = [0] + topic_ids: List[TopicId] = [TopicId.from_int(0)] expected = sorted([note for note in note_samples if note.topics == topics], key=lambda note: note.note_id) actual = sorted(list(storage.get_notes(topic_ids=topic_ids)), key=lambda note: note.note_id) assert expected == actual @@ -182,7 +192,7 @@ def test_get_notes_by_topic_ids_empty( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - topic_ids: List[int] = [] + topic_ids: List[TopicId] = [] expected: List[Note] = [] actual = list(storage.get_notes(topic_ids=topic_ids)) assert expected == actual @@ -206,7 +216,7 @@ def test_get_notes_by_post_ids_empty( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - post_ids: List[int] = [] + post_ids: List[TweetId] = [] expected: List[Note] = [] actual = list(storage.get_notes(post_ids=post_ids)) assert expected == actual @@ -218,19 +228,7 @@ def test_get_notes_by_language( note_records_sample: 
List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - language = "en" + language = LanguageIdentifier("en") expected = [note for note in note_samples if note.language == language] actual = list(storage.get_notes(language=language)) assert expected == actual - - -def test_get_notes_by_language_empty( - engine_for_test: Engine, - note_samples: List[Note], - note_records_sample: List[NoteRecord], -) -> None: - storage = Storage(engine=engine_for_test) - language = "" - expected: List[Note] = [] - actual = list(storage.get_notes(language=language)) - assert expected == actual From 279ccb3a088e0450e0929950be7459a44d6cca3f Mon Sep 17 00:00:00 2001 From: osoken Date: Mon, 22 Apr 2024 00:46:47 +0900 Subject: [PATCH 12/19] fix(router): remove redundant aliases --- birdxplorer/routers/data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index 7200623..df95afd 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -67,11 +67,11 @@ def get_topics() -> TopicListResponse: @router.get("/notes", response_model=NoteListResponse) def get_notes( - note_ids: Union[List[NoteId], None] = Query(default=None, alias="noteIds"), - created_at_from: Union[None, TwitterTimestamp] = Query(default=None, alias="createdAtFrom"), - created_at_to: Union[None, TwitterTimestamp] = Query(default=None, alias="createdAtTo"), - topic_ids: Union[List[TopicId], None] = Query(default=None, alias="topicIds"), - post_ids: Union[List[TweetId], None] = Query(default=None, alias="postIds"), + note_ids: Union[List[NoteId], None] = Query(default=None), + created_at_from: Union[None, TwitterTimestamp] = Query(default=None), + created_at_to: Union[None, TwitterTimestamp] = Query(default=None), + topic_ids: Union[List[TopicId], None] = Query(default=None), + post_ids: Union[List[TweetId], None] = Query(default=None), language: Union[LanguageIdentifier, None] = Query(default=None), ) -> 
NoteListResponse: return NoteListResponse( From 79dbdb599cb68021977a924ab907dc55913d25db Mon Sep 17 00:00:00 2001 From: osoken Date: Fri, 3 May 2024 22:59:15 +0900 Subject: [PATCH 13/19] feat(common): add common package --- .vscode/settings.json.example | 20 ++- birdxplorer/app.py | 46 ------- birdxplorer/main.py | 11 -- birdxplorer/routers/__init__.py | 0 birdxplorer/routers/data.py | 117 ----------------- birdxplorer/routers/system.py | 13 -- .../birdxplorer_common}/__init__.py | 0 .../birdxplorer_common}/exceptions.py | 0 .../birdxplorer_common}/logger.py | 0 .../birdxplorer_common}/models.py | 0 .../birdxplorer_common}/py.typed | 0 .../birdxplorer_common}/settings.py | 0 .../birdxplorer_common}/storage.py | 0 pyproject.toml => common/pyproject.toml | 27 ++-- .../stubs}/json_log_formatter.pyi | 0 {tests => common/tests}/conftest.py | 111 +--------------- {tests => common/tests}/test_data_model.py | 2 +- {tests => common/tests}/test_logger.py | 4 +- common/tests/test_settings.py | 25 ++++ {tests => common/tests}/test_storage.py | 14 +- tests/routers/__init__.py | 0 tests/routers/test_data.py | 120 ------------------ tests/routers/test_system.py | 7 - tests/test_app.py | 15 --- tests/test_birdxplorer.py | 5 - tests/test_main.py | 13 -- tests/test_settings.py | 16 --- 27 files changed, 69 insertions(+), 497 deletions(-) delete mode 100644 birdxplorer/app.py delete mode 100644 birdxplorer/main.py delete mode 100644 birdxplorer/routers/__init__.py delete mode 100644 birdxplorer/routers/data.py delete mode 100644 birdxplorer/routers/system.py rename {birdxplorer => common/birdxplorer_common}/__init__.py (100%) rename {birdxplorer => common/birdxplorer_common}/exceptions.py (100%) rename {birdxplorer => common/birdxplorer_common}/logger.py (100%) rename {birdxplorer => common/birdxplorer_common}/models.py (100%) rename {birdxplorer => common/birdxplorer_common}/py.typed (100%) rename {birdxplorer => common/birdxplorer_common}/settings.py (100%) rename {birdxplorer => 
common/birdxplorer_common}/storage.py (100%) rename pyproject.toml => common/pyproject.toml (71%) rename {stubs => common/stubs}/json_log_formatter.pyi (100%) rename {tests => common/tests}/conftest.py (73%) rename {tests => common/tests}/test_data_model.py (96%) rename {tests => common/tests}/test_logger.py (71%) create mode 100644 common/tests/test_settings.py rename {tests => common/tests}/test_storage.py (95%) delete mode 100644 tests/routers/__init__.py delete mode 100644 tests/routers/test_data.py delete mode 100644 tests/routers/test_system.py delete mode 100644 tests/test_app.py delete mode 100644 tests/test_birdxplorer.py delete mode 100644 tests/test_main.py delete mode 100644 tests/test_settings.py diff --git a/.vscode/settings.json.example b/.vscode/settings.json.example index b018960..0aaaa5b 100644 --- a/.vscode/settings.json.example +++ b/.vscode/settings.json.example @@ -14,7 +14,7 @@ "editor.formatOnSave": true, "[python]": { "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" }, "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true @@ -23,9 +23,21 @@ ".venv/lib/python3.10/site-packages" ], "python.defaultInterpreterPath": ".venv/bin/python", - "python.analysis.stubPath": "stubs", + "python.analysis.stubPath": "common/stubs", + "black-formatter.args": [ + "--config", + "common/pyproject.toml" + ], "flake8.path": [ ".venv/bin/pflake8" ], - "python.formatting.provider": "none", -} + "flake8.args": [ + "--config", + "common/pyproject.toml" + ], + "isort.args": [ + "--settings-path", + "common/pyproject.toml" + ], + "python.formatting.provider": "none" +} \ No newline at end of file diff --git a/birdxplorer/app.py b/birdxplorer/app.py deleted file mode 100644 index 8c93ba7..0000000 --- a/birdxplorer/app.py +++ /dev/null @@ -1,46 +0,0 @@ -import csv -import io -from urllib.parse import parse_qs as parse_query_string -from urllib.parse import urlencode as encode_query_string 
- -from fastapi import FastAPI -from pydantic.alias_generators import to_snake -from starlette.types import ASGIApp, Receive, Scope, Send - -from .logger import get_logger -from .routers.data import gen_router as gen_data_router -from .routers.system import gen_router as gen_system_router -from .settings import GlobalSettings -from .storage import gen_storage - - -class QueryStringFlatteningMiddleware: - def __init__(self, app: ASGIApp) -> None: - self._app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - query_string = scope.get("query_string") - if not isinstance(query_string, bytes): - query_string = b"" - query_string = query_string.decode("utf-8") - if scope["type"] == "http" and query_string: - parsed = parse_query_string(query_string) - flattened = {} - for name, values in parsed.items(): - flattened[to_snake(name)] = [c for value in values for r in csv.reader(io.StringIO(value)) for c in r] - - scope["query_string"] = encode_query_string(flattened, doseq=True).encode("utf-8") - - await self._app(scope, receive, send) - else: - await self._app(scope, receive, send) - - -def gen_app(settings: GlobalSettings) -> FastAPI: - _ = get_logger(level=settings.logger_settings.level) - storage = gen_storage(settings=settings) - app = FastAPI() - app.add_middleware(QueryStringFlatteningMiddleware) - app.include_router(gen_system_router(), prefix="/api/v1/system") - app.include_router(gen_data_router(storage=storage), prefix="/api/v1/data") - return app diff --git a/birdxplorer/main.py b/birdxplorer/main.py deleted file mode 100644 index baa7c77..0000000 --- a/birdxplorer/main.py +++ /dev/null @@ -1,11 +0,0 @@ -from fastapi import FastAPI - -from .app import gen_app -from .settings import GlobalSettings - - -def main() -> FastAPI: - return gen_app(settings=GlobalSettings()) - - -app = main() diff --git a/birdxplorer/routers/__init__.py b/birdxplorer/routers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git 
a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py deleted file mode 100644 index df95afd..0000000 --- a/birdxplorer/routers/data.py +++ /dev/null @@ -1,117 +0,0 @@ -from datetime import timezone -from typing import List, Union - -from dateutil.parser import parse as dateutil_parse -from fastapi import APIRouter, HTTPException, Query - -from ..models import ( - BaseModel, - LanguageIdentifier, - Note, - NoteId, - ParticipantId, - Post, - PostId, - Topic, - TopicId, - TweetId, - TwitterTimestamp, - UserEnrollment, -) -from ..storage import Storage - - -class TopicListResponse(BaseModel): - data: List[Topic] - - -class NoteListResponse(BaseModel): - data: List[Note] - - -class PostListResponse(BaseModel): - data: List[Post] - - -def str_to_twitter_timestamp(s: str) -> TwitterTimestamp: - try: - return TwitterTimestamp.from_int(int(s)) - except ValueError: - pass - try: - tmp = dateutil_parse(s) - if tmp.tzinfo is None: - tmp = tmp.replace(tzinfo=timezone.utc) - return TwitterTimestamp.from_int(int(tmp.timestamp() * 1000)) - except ValueError: - raise HTTPException(status_code=422, detail=f"Invalid TwitterTimestamp string: {s}") - - -def ensure_twitter_timestamp(t: Union[str, TwitterTimestamp]) -> TwitterTimestamp: - return str_to_twitter_timestamp(t) if isinstance(t, str) else t - - -def gen_router(storage: Storage) -> APIRouter: - router = APIRouter() - - @router.get("/user-enrollments/{participant_id}", response_model=UserEnrollment) - def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> UserEnrollment: - res = storage.get_user_enrollment_by_participant_id(participant_id=participant_id) - if res is None: - raise ValueError(f"participant_id={participant_id} not found") - return res - - @router.get("/topics", response_model=TopicListResponse) - def get_topics() -> TopicListResponse: - return TopicListResponse(data=list(storage.get_topics())) - - @router.get("/notes", response_model=NoteListResponse) - def get_notes( - note_ids: 
Union[List[NoteId], None] = Query(default=None), - created_at_from: Union[None, TwitterTimestamp] = Query(default=None), - created_at_to: Union[None, TwitterTimestamp] = Query(default=None), - topic_ids: Union[List[TopicId], None] = Query(default=None), - post_ids: Union[List[TweetId], None] = Query(default=None), - language: Union[LanguageIdentifier, None] = Query(default=None), - ) -> NoteListResponse: - return NoteListResponse( - data=list( - storage.get_notes( - note_ids=note_ids, - created_at_from=created_at_from, - created_at_to=created_at_to, - topic_ids=topic_ids, - post_ids=post_ids, - language=language, - ) - ) - ) - - @router.get("/posts", response_model=PostListResponse) - def get_posts( - post_id: Union[List[PostId], None] = Query(default=None), - created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None), - created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None), - ) -> PostListResponse: - if post_id is not None: - return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id))) - if created_at_start is not None: - if created_at_end is not None: - return PostListResponse( - data=list( - storage.get_posts_by_created_at_range( - start=ensure_twitter_timestamp(created_at_start), - end=ensure_twitter_timestamp(created_at_end), - ) - ) - ) - return PostListResponse( - data=list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_start))) - ) - if created_at_end is not None: - return PostListResponse( - data=list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_end))) - ) - return PostListResponse(data=list(storage.get_posts())) - - return router diff --git a/birdxplorer/routers/system.py b/birdxplorer/routers/system.py deleted file mode 100644 index 3108d4a..0000000 --- a/birdxplorer/routers/system.py +++ /dev/null @@ -1,13 +0,0 @@ -from fastapi import APIRouter - -from ..models import Message - - -def gen_router() -> APIRouter: - router = APIRouter() - - 
@router.get("/ping", response_model=Message) - async def ping() -> Message: - return Message(message="pong") - - return router diff --git a/birdxplorer/__init__.py b/common/birdxplorer_common/__init__.py similarity index 100% rename from birdxplorer/__init__.py rename to common/birdxplorer_common/__init__.py diff --git a/birdxplorer/exceptions.py b/common/birdxplorer_common/exceptions.py similarity index 100% rename from birdxplorer/exceptions.py rename to common/birdxplorer_common/exceptions.py diff --git a/birdxplorer/logger.py b/common/birdxplorer_common/logger.py similarity index 100% rename from birdxplorer/logger.py rename to common/birdxplorer_common/logger.py diff --git a/birdxplorer/models.py b/common/birdxplorer_common/models.py similarity index 100% rename from birdxplorer/models.py rename to common/birdxplorer_common/models.py diff --git a/birdxplorer/py.typed b/common/birdxplorer_common/py.typed similarity index 100% rename from birdxplorer/py.typed rename to common/birdxplorer_common/py.typed diff --git a/birdxplorer/settings.py b/common/birdxplorer_common/settings.py similarity index 100% rename from birdxplorer/settings.py rename to common/birdxplorer_common/settings.py diff --git a/birdxplorer/storage.py b/common/birdxplorer_common/storage.py similarity index 100% rename from birdxplorer/storage.py rename to common/birdxplorer_common/storage.py diff --git a/pyproject.toml b/common/pyproject.toml similarity index 71% rename from pyproject.toml rename to common/pyproject.toml index abcf924..2cc9ae9 100644 --- a/pyproject.toml +++ b/common/pyproject.toml @@ -4,16 +4,16 @@ requires = ["flit_core >=3.8.0,<4"] [project] -name = "birdxplorer" -description = "birdxplorer is a tool to help you read and get insights from your documents." 
+name = "birdxplorer_common" +description = "Common library for BirdXplorer" authors = [ {name = "osoken"}, ] dynamic = [ "version", ] -readme = "README.md" -license = {file = "LICENSE"} +readme = "../README.md" +license = {file = "../LICENSE"} requires-python = ">=3.10" classifiers = [ @@ -29,16 +29,14 @@ dependencies = [ "python-dateutil", "sqlalchemy", "pydantic_settings", - "fastapi", "JSON-log-formatter", - "openai", ] [project.urls] Source = "https://github.com/codeforjapan/BirdXplorer" [tool.setuptools] -packages=["birdxplorer"] +packages=["birdxplorer_common"] [tool.setuptools.package-data] birdxplorer = ["py.typed"] @@ -58,9 +56,7 @@ dev=[ "types-python-dateutil", "psycopg2-binary", "factory_boy", - "uvicorn", "polyfactory", - "httpx", "types-psycopg2", ] prod=[ @@ -69,11 +65,10 @@ prod=[ [tool.pytest.ini_options] -addopts = ["-sv", "--doctest-modules", "--cov=birdxplorer", "--cov-report=xml", "--cov-report=term-missing"] -testpaths = ["tests", "birdxplorer"] +addopts = ["-sv", "--doctest-modules", "--cov=birdxplorer_common", "--cov-report=xml", "--cov-report=term-missing"] +testpaths = ["tests", "birdxplorer_common"] filterwarnings = [ "error", - "ignore:The \\'app\\' shortcut is now deprecated. 
Use the explicit style \\'transport=WSGITransport\\(app=\\.\\.\\.\\)\\' instead\\.", ] [tool.black] @@ -110,10 +105,10 @@ legacy_tox_ini = """ deps = -e .[dev] commands = - black birdxplorer tests - isort birdxplorer tests + black birdxplorer_common tests + isort birdxplorer_common tests pytest - pflake8 birdxplorer/ tests/ stubs/ - mypy birdxplorer --strict + pflake8 birdxplorer_common/ tests/ stubs/ + mypy birdxplorer_common --strict mypy tests --strict """ diff --git a/stubs/json_log_formatter.pyi b/common/stubs/json_log_formatter.pyi similarity index 100% rename from stubs/json_log_formatter.pyi rename to common/stubs/json_log_formatter.pyi diff --git a/tests/conftest.py b/common/tests/conftest.py similarity index 73% rename from tests/conftest.py rename to common/tests/conftest.py index 29f0f09..a2b1094 100644 --- a/tests/conftest.py +++ b/common/tests/conftest.py @@ -1,11 +1,9 @@ import os import random from collections.abc import Generator -from typing import List, Type, Union -from unittest.mock import MagicMock, patch +from typing import List, Type from dotenv import load_dotenv -from fastapi.testclient import TestClient from polyfactory import Use from polyfactory.factories.pydantic_factory import ModelFactory from polyfactory.pytest_plugin import register_fixture @@ -16,28 +14,20 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import text -from birdxplorer.exceptions import UserEnrollmentNotFoundError -from birdxplorer.models import ( - LanguageIdentifier, +from birdxplorer_common.models import ( Note, - NoteId, - ParticipantId, Post, - PostId, Topic, - TopicId, - TweetId, TwitterTimestamp, UserEnrollment, XUser, ) -from birdxplorer.settings import GlobalSettings, PostgresStorageSettings -from birdxplorer.storage import ( +from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings +from birdxplorer_common.storage import ( Base, NoteRecord, NoteTopicAssociation, PostRecord, - Storage, TopicRecord, XUserRecord, ) @@ -116,99 
+106,6 @@ def user_enrollment_samples( yield [user_enrollment_factory.build() for _ in range(3)] -@fixture -def mock_storage( - user_enrollment_samples: List[UserEnrollment], - topic_samples: List[Topic], - post_samples: List[Post], - note_samples: List[Note], -) -> Generator[MagicMock, None, None]: - mock = MagicMock(spec=Storage) - - def _get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> UserEnrollment: - x = {d.participant_id: d for d in user_enrollment_samples}.get(participant_id) - if x is None: - raise UserEnrollmentNotFoundError(participant_id=participant_id) - return x - - mock.get_user_enrollment_by_participant_id.side_effect = _get_user_enrollment_by_participant_id - - def _get_topics() -> Generator[Topic, None, None]: - yield from topic_samples - - def _get_notes( - note_ids: Union[List[NoteId], None] = None, - created_at_from: Union[None, TwitterTimestamp] = None, - created_at_to: Union[None, TwitterTimestamp] = None, - topic_ids: Union[List[TopicId], None] = None, - post_ids: Union[List[TweetId], None] = None, - language: Union[LanguageIdentifier, None] = None, - ) -> Generator[Note, None, None]: - for note in note_samples: - if note_ids is not None and note.note_id not in note_ids: - continue - if created_at_from is not None and note.created_at < created_at_from: - continue - if created_at_to is not None and note.created_at > created_at_to: - continue - if topic_ids is not None and not set(topic_ids).issubset({topic.topic_id for topic in note.topics}): - continue - if post_ids is not None and note.post_id not in post_ids: - continue - if language is not None and note.language != language: - continue - yield note - - mock.get_topics.side_effect = _get_topics - mock.get_notes.side_effect = _get_notes - - def _get_posts() -> Generator[Post, None, None]: - yield from post_samples - - mock.get_posts.side_effect = _get_posts - - def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]: - for i in post_ids: - for 
post in post_samples: - if post.post_id == i: - yield post - break - - mock.get_posts_by_ids.side_effect = _get_posts_by_ids - - def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]: - for post in post_samples: - if start <= post.created_at < end: - yield post - - mock.get_posts_by_created_at_range.side_effect = _get_posts_by_created_at_range - - def _get_posts_by_created_at_start(start: TwitterTimestamp) -> Generator[Post, None, None]: - for post in post_samples: - if start <= post.created_at: - yield post - - mock.get_posts_by_created_at_start.side_effect = _get_posts_by_created_at_start - - def _get_posts_by_created_at_end(end: TwitterTimestamp) -> Generator[Post, None, None]: - for post in post_samples: - if post.created_at < end: - yield post - - mock.get_posts_by_created_at_end.side_effect = _get_posts_by_created_at_end - - yield mock - - -@fixture -def client(settings_for_test: GlobalSettings, mock_storage: MagicMock) -> Generator[TestClient, None, None]: - from birdxplorer.app import gen_app - - with patch("birdxplorer.app.gen_storage", return_value=mock_storage): - app = gen_app(settings=settings_for_test) - yield TestClient(app) - - @fixture def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: topics = [ diff --git a/tests/test_data_model.py b/common/tests/test_data_model.py similarity index 96% rename from tests/test_data_model.py rename to common/tests/test_data_model.py index a57e765..d3f77f0 100644 --- a/tests/test_data_model.py +++ b/common/tests/test_data_model.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping -from birdxplorer.models import NoteData, UserEnrollment +from birdxplorer_common.models import NoteData, UserEnrollment class BaseDataModelTester(ABC): diff --git a/tests/test_logger.py b/common/tests/test_logger.py similarity index 71% rename from tests/test_logger.py rename to common/tests/test_logger.py 
index 126931c..6922e75 100644 --- a/tests/test_logger.py +++ b/common/tests/test_logger.py @@ -1,6 +1,6 @@ from pytest import LogCaptureFixture -from birdxplorer.logger import get_logger +from birdxplorer_common.logger import get_logger def test_logger_is_a_child_of_root_logger(caplog: LogCaptureFixture) -> None: @@ -8,5 +8,5 @@ def test_logger_is_a_child_of_root_logger(caplog: LogCaptureFixture) -> None: with caplog.at_level("INFO"): logger.info("test") assert len(caplog.records) == 1 - assert caplog.records[0].name == "birdxplorer.logger" + assert caplog.records[0].name == "birdxplorer_common.logger" assert caplog.records[0].message == "test" diff --git a/common/tests/test_settings.py b/common/tests/test_settings.py new file mode 100644 index 0000000..41424f4 --- /dev/null +++ b/common/tests/test_settings.py @@ -0,0 +1,25 @@ +import os + +from pytest_mock import MockerFixture + +from birdxplorer_common.settings import GlobalSettings + + +def test_settings_read_from_env_var(mocker: MockerFixture) -> None: + mocker.patch.dict( + os.environ, + {"BX_LOGGER_SETTINGS__LEVEL": "99", "BX_STORAGE_SETTINGS__PASSWORD": "s0m6S+ron9P@55w0rd"}, + clear=True, + ) + settings = GlobalSettings() + assert settings.logger_settings.level == 99 + + +def test_settings_default(mocker: MockerFixture) -> None: + mocker.patch.dict( + os.environ, + {"BX_STORAGE_SETTINGS__PASSWORD": "s0m6S+ron9P@55w0rd"}, + ) + + settings = GlobalSettings() + assert settings.logger_settings.level == 20 diff --git a/tests/test_storage.py b/common/tests/test_storage.py similarity index 95% rename from tests/test_storage.py rename to common/tests/test_storage.py index 5cbcf20..3b74975 100644 --- a/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -2,7 +2,7 @@ from sqlalchemy.engine import Engine -from birdxplorer.models import ( +from birdxplorer_common.models import ( LanguageIdentifier, Note, NoteId, @@ -13,7 +13,7 @@ TweetId, TwitterTimestamp, ) -from birdxplorer.storage import NoteRecord, 
PostRecord, Storage, TopicRecord +from birdxplorer_common.storage import NoteRecord, PostRecord, Storage, TopicRecord def test_get_topic_list( @@ -181,7 +181,10 @@ def test_get_notes_by_topic_ids( storage = Storage(engine=engine_for_test) topics = note_samples[0].topics topic_ids: List[TopicId] = [TopicId.from_int(0)] - expected = sorted([note for note in note_samples if note.topics == topics], key=lambda note: note.note_id) + expected = sorted( + [note for note in note_samples if note.topics == topics], + key=lambda note: note.note_id, + ) actual = sorted(list(storage.get_notes(topic_ids=topic_ids)), key=lambda note: note.note_id) assert expected == actual @@ -204,7 +207,10 @@ def test_get_notes_by_post_ids( note_records_sample: List[NoteRecord], ) -> None: storage = Storage(engine=engine_for_test) - post_ids = [TweetId.from_str("2234567890123456781"), TweetId.from_str("2234567890123456782")] + post_ids = [ + TweetId.from_str("2234567890123456781"), + TweetId.from_str("2234567890123456782"), + ] expected = [note for note in note_samples if note.post_id in post_ids] actual = list(storage.get_notes(post_ids=post_ids)) assert expected == actual diff --git a/tests/routers/__init__.py b/tests/routers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/routers/test_data.py b/tests/routers/test_data.py deleted file mode 100644 index 8f719f7..0000000 --- a/tests/routers/test_data.py +++ /dev/null @@ -1,120 +0,0 @@ -import json -from typing import List - -from fastapi.testclient import TestClient - -from birdxplorer.models import Note, Post, Topic, UserEnrollment - - -def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None: - response = client.get(f"/api/v1/data/user-enrollments/{user_enrollment_samples[0].participant_id}") - assert response.status_code == 200 - res_json = response.json() - assert res_json["participantId"] == user_enrollment_samples[0].participant_id - - -def 
test_topics_get(client: TestClient, topic_samples: List[Topic]) -> None: - response = client.get("/api/v1/data/topics") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [d.model_dump(by_alias=True) for d in topic_samples]} - - -def test_posts_get(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(d.model_dump_json()) for d in post_samples]} - - -def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Post]) -> None: - response = client.get(f"/api/v1/data/posts/?postId={post_samples[0].post_id},{post_samples[2].post_id}") - assert response.status_code == 200 - res_json = response.json() - assert res_json == { - "data": [json.loads(post_samples[0].model_dump_json()), json.loads(post_samples[2].model_dump_json())] - } - - -def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(post_samples[1].model_dump_json())]} - - -def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]} - - -def test_posts_get_has_created_at_filter_end(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtEnd=2006-7-30 00:00:00") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) 
for i in (0, 1)]} - - -def test_posts_get_created_at_range_filter_accepts_integer(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtStart=1153921700000&createdAtEnd=1154921800000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(post_samples[1].model_dump_json())]} - - -def test_posts_get_created_at_start_filter_accepts_integer(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtStart=1153921700000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]} - - -def test_posts_get_created_at_end_filter_accepts_integer(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtEnd=1154921800000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (0, 1)]} - - -def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List[Post]) -> None: - response = client.get("/api/v1/data/posts/?createdAtStart=1153921700&createdAtEnd=1153921700") - assert response.status_code == 422 - - -def test_notes_get(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(d.model_dump_json()) for d in note_samples]} - - -def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None: - response = client.get(f"/api/v1/data/notes/?noteIds={note_samples[0].note_id}&noteIds={note_samples[2].note_id}") - assert response.status_code == 200 - res_json = response.json() - assert res_json == { - "data": [json.loads(note_samples[0].model_dump_json()), 
json.loads(note_samples[2].model_dump_json())] - } - - -def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)]} - - -def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)]} - - -def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None: - response = client.get("/api/v1/data/notes/?createdAtTo=1152921603000") - assert response.status_code == 200 - res_json = response.json() - assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]} diff --git a/tests/routers/test_system.py b/tests/routers/test_system.py deleted file mode 100644 index 36751a7..0000000 --- a/tests/routers/test_system.py +++ /dev/null @@ -1,7 +0,0 @@ -from fastapi.testclient import TestClient - - -def test_ping(client: TestClient) -> None: - response = client.get("/api/v1/system/ping") - assert response.status_code == 200 - assert response.json() == {"message": "pong"} diff --git a/tests/test_app.py b/tests/test_app.py deleted file mode 100644 index e87e423..0000000 --- a/tests/test_app.py +++ /dev/null @@ -1,15 +0,0 @@ -from pytest_mock import MockerFixture - -from birdxplorer.app import gen_app -from birdxplorer.settings import GlobalSettings - - -def test_gen_app(mocker: MockerFixture, default_settings: GlobalSettings) -> None: - FastAPI = mocker.patch("birdxplorer.app.FastAPI") - get_logger = 
mocker.patch("birdxplorer.app.get_logger") - expected = FastAPI.return_value - - actual = gen_app(settings=default_settings) - - assert actual == expected - get_logger.assert_called_once_with(level=default_settings.logger_settings.level) diff --git a/tests/test_birdxplorer.py b/tests/test_birdxplorer.py deleted file mode 100644 index 8e6dd87..0000000 --- a/tests/test_birdxplorer.py +++ /dev/null @@ -1,5 +0,0 @@ -import birdxplorer - - -def test_birdxplorer_has_version() -> None: - assert birdxplorer.__version__ is not None diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 87dc313..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,13 +0,0 @@ -from pytest_mock import MockerFixture - - -def test_main_returns_app(mocker: MockerFixture) -> None: - gen_app = mocker.patch("birdxplorer.main.gen_app") - GlobalSettings = mocker.patch("birdxplorer.main.GlobalSettings") - from birdxplorer.main import main - - expected = gen_app.return_value - actual = main() - GlobalSettings.assert_called_once_with() - gen_app.assert_called_once_with(settings=GlobalSettings.return_value) - assert actual == expected diff --git a/tests/test_settings.py b/tests/test_settings.py deleted file mode 100644 index da77138..0000000 --- a/tests/test_settings.py +++ /dev/null @@ -1,16 +0,0 @@ -import os - -from pytest_mock import MockerFixture - -from birdxplorer.settings import GlobalSettings - - -def test_settings_read_from_env_var(mocker: MockerFixture) -> None: - mocker.patch.dict(os.environ, {"BX_LOGGER_SETTINGS__LEVEL": "99"}, clear=True) - settings = GlobalSettings() - assert settings.logger_settings.level == 99 - - -def test_settings_default() -> None: - settings = GlobalSettings() - assert settings.logger_settings.level == 20 From bd8dd4bcdfa0fe5fdd08218a926ba391a989b211 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 4 May 2024 18:37:19 +0900 Subject: [PATCH 14/19] feat(api): add api package --- api/birdxplorer_api/__init__.py | 1 + 
api/birdxplorer_api/app.py | 46 ++++ api/birdxplorer_api/main.py | 11 + api/birdxplorer_api/py.typed | 0 api/birdxplorer_api/routers/__init__.py | 0 api/birdxplorer_api/routers/data.py | 118 ++++++++ api/birdxplorer_api/routers/system.py | 12 + api/pyproject.toml | 113 ++++++++ api/tests/conftest.py | 347 ++++++++++++++++++++++++ api/tests/routers/__init__.py | 0 api/tests/routers/test_data.py | 125 +++++++++ api/tests/routers/test_system.py | 7 + api/tests/test_app.py | 15 + api/tests/test_main.py | 13 + api/tests/test_package.py | 9 + 15 files changed, 817 insertions(+) create mode 100644 api/birdxplorer_api/__init__.py create mode 100644 api/birdxplorer_api/app.py create mode 100644 api/birdxplorer_api/main.py create mode 100644 api/birdxplorer_api/py.typed create mode 100644 api/birdxplorer_api/routers/__init__.py create mode 100644 api/birdxplorer_api/routers/data.py create mode 100644 api/birdxplorer_api/routers/system.py create mode 100644 api/pyproject.toml create mode 100644 api/tests/conftest.py create mode 100644 api/tests/routers/__init__.py create mode 100644 api/tests/routers/test_data.py create mode 100644 api/tests/routers/test_system.py create mode 100644 api/tests/test_app.py create mode 100644 api/tests/test_main.py create mode 100644 api/tests/test_package.py diff --git a/api/birdxplorer_api/__init__.py b/api/birdxplorer_api/__init__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/api/birdxplorer_api/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/api/birdxplorer_api/app.py b/api/birdxplorer_api/app.py new file mode 100644 index 0000000..cf72e28 --- /dev/null +++ b/api/birdxplorer_api/app.py @@ -0,0 +1,46 @@ +import csv +import io +from urllib.parse import parse_qs as parse_query_string +from urllib.parse import urlencode as encode_query_string + +from birdxplorer_common.logger import get_logger +from birdxplorer_common.settings import GlobalSettings +from birdxplorer_common.storage import gen_storage +from 
fastapi import FastAPI +from pydantic.alias_generators import to_snake +from starlette.types import ASGIApp, Receive, Scope, Send + +from .routers.data import gen_router as gen_data_router +from .routers.system import gen_router as gen_system_router + + +class QueryStringFlatteningMiddleware: + def __init__(self, app: ASGIApp) -> None: + self._app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + query_string = scope.get("query_string") + if not isinstance(query_string, bytes): + query_string = b"" + query_string = query_string.decode("utf-8") + if scope["type"] == "http" and query_string: + parsed = parse_query_string(query_string) + flattened = {} + for name, values in parsed.items(): + flattened[to_snake(name)] = [c for value in values for r in csv.reader(io.StringIO(value)) for c in r] + + scope["query_string"] = encode_query_string(flattened, doseq=True).encode("utf-8") + + await self._app(scope, receive, send) + else: + await self._app(scope, receive, send) + + +def gen_app(settings: GlobalSettings) -> FastAPI: + _ = get_logger(level=settings.logger_settings.level) + storage = gen_storage(settings=settings) + app = FastAPI() + app.add_middleware(QueryStringFlatteningMiddleware) + app.include_router(gen_system_router(), prefix="/api/v1/system") + app.include_router(gen_data_router(storage=storage), prefix="/api/v1/data") + return app diff --git a/api/birdxplorer_api/main.py b/api/birdxplorer_api/main.py new file mode 100644 index 0000000..88793c6 --- /dev/null +++ b/api/birdxplorer_api/main.py @@ -0,0 +1,11 @@ +from birdxplorer_common.settings import GlobalSettings +from fastapi import FastAPI + +from .app import gen_app + + +def main() -> FastAPI: + return gen_app(settings=GlobalSettings()) + + +app = main() diff --git a/api/birdxplorer_api/py.typed b/api/birdxplorer_api/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/api/birdxplorer_api/routers/__init__.py b/api/birdxplorer_api/routers/__init__.py new 
file mode 100644 index 0000000..e69de29 diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py new file mode 100644 index 0000000..b4ceeb2 --- /dev/null +++ b/api/birdxplorer_api/routers/data.py @@ -0,0 +1,118 @@ +from datetime import timezone +from typing import List, Union + +from birdxplorer_common.models import ( + BaseModel, + LanguageIdentifier, + Note, + NoteId, + ParticipantId, + Post, + PostId, + Topic, + TopicId, + TweetId, + TwitterTimestamp, + UserEnrollment, +) +from birdxplorer_common.storage import Storage +from dateutil.parser import parse as dateutil_parse +from fastapi import APIRouter, HTTPException, Query + + +class TopicListResponse(BaseModel): + data: List[Topic] + + +class NoteListResponse(BaseModel): + data: List[Note] + + +class PostListResponse(BaseModel): + data: List[Post] + + +def str_to_twitter_timestamp(s: str) -> TwitterTimestamp: + try: + return TwitterTimestamp.from_int(int(s)) + except ValueError: + pass + try: + tmp = dateutil_parse(s) + if tmp.tzinfo is None: + tmp = tmp.replace(tzinfo=timezone.utc) + return TwitterTimestamp.from_int(int(tmp.timestamp() * 1000)) + except ValueError: + raise HTTPException(status_code=422, detail=f"Invalid TwitterTimestamp string: {s}") + + +def ensure_twitter_timestamp(t: Union[str, TwitterTimestamp]) -> TwitterTimestamp: + return str_to_twitter_timestamp(t) if isinstance(t, str) else t + + +def gen_router(storage: Storage) -> APIRouter: + router = APIRouter() + + @router.get("/user-enrollments/{participant_id}", response_model=UserEnrollment) + def get_user_enrollment_by_participant_id( + participant_id: ParticipantId, + ) -> UserEnrollment: + res = storage.get_user_enrollment_by_participant_id(participant_id=participant_id) + if res is None: + raise ValueError(f"participant_id={participant_id} not found") + return res + + @router.get("/topics", response_model=TopicListResponse) + def get_topics() -> TopicListResponse: + return 
TopicListResponse(data=list(storage.get_topics())) + + @router.get("/notes", response_model=NoteListResponse) + def get_notes( + note_ids: Union[List[NoteId], None] = Query(default=None), + created_at_from: Union[None, TwitterTimestamp] = Query(default=None), + created_at_to: Union[None, TwitterTimestamp] = Query(default=None), + topic_ids: Union[List[TopicId], None] = Query(default=None), + post_ids: Union[List[TweetId], None] = Query(default=None), + language: Union[LanguageIdentifier, None] = Query(default=None), + ) -> NoteListResponse: + return NoteListResponse( + data=list( + storage.get_notes( + note_ids=note_ids, + created_at_from=created_at_from, + created_at_to=created_at_to, + topic_ids=topic_ids, + post_ids=post_ids, + language=language, + ) + ) + ) + + @router.get("/posts", response_model=PostListResponse) + def get_posts( + post_id: Union[List[PostId], None] = Query(default=None), + created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None), + created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None), + ) -> PostListResponse: + if post_id is not None: + return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id))) + if created_at_start is not None: + if created_at_end is not None: + return PostListResponse( + data=list( + storage.get_posts_by_created_at_range( + start=ensure_twitter_timestamp(created_at_start), + end=ensure_twitter_timestamp(created_at_end), + ) + ) + ) + return PostListResponse( + data=list(storage.get_posts_by_created_at_start(start=ensure_twitter_timestamp(created_at_start))) + ) + if created_at_end is not None: + return PostListResponse( + data=list(storage.get_posts_by_created_at_end(end=ensure_twitter_timestamp(created_at_end))) + ) + return PostListResponse(data=list(storage.get_posts())) + + return router diff --git a/api/birdxplorer_api/routers/system.py b/api/birdxplorer_api/routers/system.py new file mode 100644 index 0000000..87595a8 --- /dev/null +++ 
b/api/birdxplorer_api/routers/system.py @@ -0,0 +1,12 @@ +from birdxplorer_common.models import Message +from fastapi import APIRouter + + +def gen_router() -> APIRouter: + router = APIRouter() + + @router.get("/ping", response_model=Message) + async def ping() -> Message: + return Message(message="pong") + + return router diff --git a/api/pyproject.toml b/api/pyproject.toml new file mode 100644 index 0000000..fff11ec --- /dev/null +++ b/api/pyproject.toml @@ -0,0 +1,113 @@ +[build-system] +build-backend = "flit_core.buildapi" +requires = ["flit_core >=3.8.0,<4"] + + +[project] +name = "birdxplorer_api" +description = "The Web API for BirdXplorer project." +authors = [ + {name = "osoken"}, +] +dynamic = [ + "version", +] +readme = "../README.md" +license = {file = "../LICENSE"} +requires-python = ">=3.10" + +classifiers = [ + "Development Status :: 3 - Alpha", + "Natural Language :: Japanese", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", +] + +dependencies = [ + "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@feature/issue-53-divide-python-packages#subdirectory=common", + "fastapi", + "python-dateutil", + "pydantic", + "starlette", + "python-dotenv", +] + +[project.urls] +Source = "https://github.com/codeforjapan/BirdXplorer" + +[tool.setuptools] +packages=["birdxplorer"] + +[tool.setuptools.package-data] +birdxplorer = ["py.typed"] + +[project.optional-dependencies] +dev=[ + "black", + "flake8", + "pyproject-flake8", + "pytest", + "mypy", + "tox", + "isort", + "pytest-mock", + "pytest-cov", + "freezegun", + "types-python-dateutil", + "psycopg2-binary", + "factory_boy", + "uvicorn", + "polyfactory", + "httpx", +] +prod=[ +] + +[tool.pytest.ini_options] +addopts = ["-sv", "--doctest-modules", "--ignore-glob=birdxplorer_api/main.py", "--cov=birdxplorer_api", "--cov-report=xml", "--cov-report=term-missing"] +testpaths = ["tests", "birdxplorer_api"] +filterwarnings = [ 
+ "error", + "ignore:The \\'app\\' shortcut is now deprecated. Use the explicit style \\'transport=WSGITransport\\(app=\\.\\.\\.\\)\\' instead\\.", +] + +[tool.black] +line-length = 120 +target-version = ['py310'] + +[tool.flake8] +max-line-length = 120 +extend-ignore = "E203,E701" + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +plugins = ["pydantic.mypy"] + +[tool.pydantic.mypy] +init_typed = true + +[tool.isort] +profile = "black" + +[tool.tox] +legacy_tox_ini = """ + [tox] + skipsdist = true + envlist = py310 + + [testenv] + setenv = + VIRTUALENV_PIP = 24.0 + deps = + -e .[dev] + commands = + black birdxplorer_api tests + isort birdxplorer_api tests + pytest + pflake8 birdxplorer_api/ tests/ + mypy birdxplorer_api --strict + mypy tests --strict +""" diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000..18063aa --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,347 @@ +import os +import random +from collections.abc import Generator +from typing import List, Type, Union +from unittest.mock import MagicMock, patch + +from birdxplorer_common.exceptions import UserEnrollmentNotFoundError +from birdxplorer_common.models import ( + LanguageIdentifier, + Note, + NoteId, + ParticipantId, + Post, + PostId, + Topic, + TopicId, + TweetId, + TwitterTimestamp, + UserEnrollment, + XUser, +) +from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings +from birdxplorer_common.storage import Storage +from dotenv import load_dotenv +from fastapi.testclient import TestClient +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pytest import fixture + + +def gen_random_twitter_timestamp() -> int: + return random.randint(TwitterTimestamp.min_value(), TwitterTimestamp.max_value()) + + +@register_fixture(name="user_enrollment_factory") +class 
UserEnrollmentFactory(ModelFactory[UserEnrollment]): + __model__ = UserEnrollment + + participant_id = Use(lambda: "".join(random.choices("0123456789ABCDEF", k=64))) + timestamp_of_last_state_change = Use(gen_random_twitter_timestamp) + timestamp_of_last_earn_out = Use(gen_random_twitter_timestamp) + + +@register_fixture(name="note_factory") +class NoteFactory(ModelFactory[Note]): + __model__ = Note + + +@register_fixture(name="topic_factory") +class TopicFactory(ModelFactory[Topic]): + __model__ = Topic + + +@register_fixture(name="x_user_factory") +class XUserFactory(ModelFactory[XUser]): + __model__ = XUser + + +@register_fixture(name="post_factory") +class PostFactory(ModelFactory[Post]): + __model__ = Post + + +@fixture +def user_enrollment_samples( + user_enrollment_factory: UserEnrollmentFactory, +) -> Generator[List[UserEnrollment], None, None]: + yield [user_enrollment_factory.build() for _ in range(3)] + + +@fixture +def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: + topics = [ + topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=3), + topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=2), + topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1), + topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=0), + ] + yield topics + + +@fixture +def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Generator[List[Note], None, None]: + notes = [ + note_factory.build( + note_id="1234567890123456781", + post_id="2234567890123456781", + topics=[topic_samples[0]], + language="ja", + summary="要約文1", + created_at=1152921600000, + ), + note_factory.build( + note_id="1234567890123456782", + post_id="2234567890123456782", + topics=[], + language="en", + summary="summary2", + created_at=1152921601000, + ), + note_factory.build( + note_id="1234567890123456783", + 
post_id="2234567890123456783", + topics=[topic_samples[1]], + language="en", + summary="summary3", + created_at=1152921602000, + ), + note_factory.build( + note_id="1234567890123456784", + post_id="2234567890123456784", + topics=[topic_samples[0], topic_samples[1], topic_samples[2]], + language="en", + summary="summary4", + created_at=1152921603000, + ), + note_factory.build( + note_id="1234567890123456785", + post_id="2234567890123456785", + topics=[topic_samples[0]], + language="en", + summary="summary5", + created_at=1152921604000, + ), + ] + yield notes + + +@fixture +def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, None]: + x_users = [ + x_user_factory.build( + user_id="1234567890123456781", + name="User1", + profile_image_url="https://pbs.twimg.com/profile_images/1468001914302390XXX/xxxxXXXX_normal.jpg", + followers_count=100, + following_count=200, + ), + x_user_factory.build( + user_id="1234567890123456782", + name="User2", + profile_image_url="https://pbs.twimg.com/profile_images/1468001914302390YYY/yyyyYYYY_normal.jpg", + followers_count=300, + following_count=400, + ), + x_user_factory.build( + user_id="1234567890123456783", + name="User3", + profile_image_url="https://pbs.twimg.com/profile_images/1468001914302390ZZZ/zzzzZZZZ_normal.jpg", + followers_count=300, + following_count=400, + ), + ] + yield x_users + + +@fixture +def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]: + posts = [ + post_factory.build( + post_id="2234567890123456781", + x_user_id="1234567890123456781", + x_user=x_user_samples[0], + text="text11", + media_details=None, + created_at=1152921600000, + like_count=10, + repost_count=20, + impression_count=30, + ), + post_factory.build( + post_id="2234567890123456791", + x_user_id="1234567890123456781", + x_user=x_user_samples[0], + text="text12", + media_details=None, + created_at=1153921700000, + like_count=10, + repost_count=20, + 
impression_count=30, + ), + post_factory.build( + post_id="2234567890123456801", + x_user_id="1234567890123456782", + x_user=x_user_samples[1], + text="text21", + media_details=None, + created_at=1154921800000, + like_count=10, + repost_count=20, + impression_count=30, + ), + ] + yield posts + + +@fixture +def mock_storage( + user_enrollment_samples: List[UserEnrollment], + topic_samples: List[Topic], + post_samples: List[Post], + note_samples: List[Note], +) -> Generator[MagicMock, None, None]: + mock = MagicMock(spec=Storage) + + def _get_user_enrollment_by_participant_id( + participant_id: ParticipantId, + ) -> UserEnrollment: + x = {d.participant_id: d for d in user_enrollment_samples}.get(participant_id) + if x is None: + raise UserEnrollmentNotFoundError(participant_id=participant_id) + return x + + mock.get_user_enrollment_by_participant_id.side_effect = _get_user_enrollment_by_participant_id + + def _get_topics() -> Generator[Topic, None, None]: + yield from topic_samples + + def _get_notes( + note_ids: Union[List[NoteId], None] = None, + created_at_from: Union[None, TwitterTimestamp] = None, + created_at_to: Union[None, TwitterTimestamp] = None, + topic_ids: Union[List[TopicId], None] = None, + post_ids: Union[List[TweetId], None] = None, + language: Union[LanguageIdentifier, None] = None, + ) -> Generator[Note, None, None]: + for note in note_samples: + if note_ids is not None and note.note_id not in note_ids: + continue + if created_at_from is not None and note.created_at < created_at_from: + continue + if created_at_to is not None and note.created_at > created_at_to: + continue + if topic_ids is not None and not set(topic_ids).issubset({topic.topic_id for topic in note.topics}): + continue + if post_ids is not None and note.post_id not in post_ids: + continue + if language is not None and note.language != language: + continue + yield note + + mock.get_topics.side_effect = _get_topics + mock.get_notes.side_effect = _get_notes + + def _get_posts() -> 
Generator[Post, None, None]: + yield from post_samples + + mock.get_posts.side_effect = _get_posts + + def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]: + for i in post_ids: + for post in post_samples: + if post.post_id == i: + yield post + break + + mock.get_posts_by_ids.side_effect = _get_posts_by_ids + + def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]: + for post in post_samples: + if start <= post.created_at < end: + yield post + + mock.get_posts_by_created_at_range.side_effect = _get_posts_by_created_at_range + + def _get_posts_by_created_at_start( + start: TwitterTimestamp, + ) -> Generator[Post, None, None]: + for post in post_samples: + if start <= post.created_at: + yield post + + mock.get_posts_by_created_at_start.side_effect = _get_posts_by_created_at_start + + def _get_posts_by_created_at_end( + end: TwitterTimestamp, + ) -> Generator[Post, None, None]: + for post in post_samples: + if post.created_at < end: + yield post + + mock.get_posts_by_created_at_end.side_effect = _get_posts_by_created_at_end + + yield mock + + +TEST_DATABASE_NAME = "bx_test" + + +@fixture +def load_dotenv_fixture() -> None: + load_dotenv() + + +@fixture +def postgres_storage_settings_factory( + load_dotenv_fixture: None, +) -> Type[ModelFactory[PostgresStorageSettings]]: + class PostgresStorageSettingsFactory(ModelFactory[PostgresStorageSettings]): + __model__ = PostgresStorageSettings + + host = "localhost" + username = "postgres" + port = 5432 + database = "postgres" + password = os.environ["BX_STORAGE_SETTINGS__PASSWORD"] + + return PostgresStorageSettingsFactory + + +@fixture +def global_settings_factory( + postgres_storage_settings_factory: Type[ModelFactory[PostgresStorageSettings]], +) -> Type[ModelFactory[GlobalSettings]]: + class GlobalSettingsFactory(ModelFactory[GlobalSettings]): + __model__ = GlobalSettings + + storage_settings = postgres_storage_settings_factory.build() + + 
return GlobalSettingsFactory + + +@fixture +def settings_for_test( + global_settings_factory: Type[ModelFactory[GlobalSettings]], + postgres_storage_settings_factory: Type[ModelFactory[PostgresStorageSettings]], +) -> Generator[GlobalSettings, None, None]: + yield global_settings_factory.build( + storage_settings=postgres_storage_settings_factory.build(database=TEST_DATABASE_NAME) + ) + + +@fixture +def client(settings_for_test: GlobalSettings, mock_storage: MagicMock) -> Generator[TestClient, None, None]: + from birdxplorer_api.app import gen_app + + with patch("birdxplorer_api.app.gen_storage", return_value=mock_storage): + app = gen_app(settings=settings_for_test) + yield TestClient(app) + + +@fixture +def default_settings( + global_settings_factory: Type[ModelFactory[GlobalSettings]], +) -> Generator[GlobalSettings, None, None]: + yield global_settings_factory.build() diff --git a/api/tests/routers/__init__.py b/api/tests/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py new file mode 100644 index 0000000..c36cd8e --- /dev/null +++ b/api/tests/routers/test_data.py @@ -0,0 +1,125 @@ +import json +from typing import List + +from birdxplorer_common.models import Note, Post, Topic, UserEnrollment +from fastapi.testclient import TestClient + + +def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None: + response = client.get(f"/api/v1/data/user-enrollments/{user_enrollment_samples[0].participant_id}") + assert response.status_code == 200 + res_json = response.json() + assert res_json["participantId"] == user_enrollment_samples[0].participant_id + + +def test_topics_get(client: TestClient, topic_samples: List[Topic]) -> None: + response = client.get("/api/v1/data/topics") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [d.model_dump(by_alias=True) for d in topic_samples]} + + +def 
test_posts_get(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(d.model_dump_json()) for d in post_samples]} + + +def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Post]) -> None: + response = client.get(f"/api/v1/data/posts/?postId={post_samples[0].post_id},{post_samples[2].post_id}") + assert response.status_code == 200 + res_json = response.json() + assert res_json == { + "data": [ + json.loads(post_samples[0].model_dump_json()), + json.loads(post_samples[2].model_dump_json()), + ] + } + + +def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[1].model_dump_json())]} + + +def test_posts_get_has_created_at_filter_start(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]} + + +def test_posts_get_has_created_at_filter_end(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtEnd=2006-7-30 00:00:00") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (0, 1)]} + + +def test_posts_get_created_at_range_filter_accepts_integer(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtStart=1153921700000&createdAtEnd=1154921800000") + assert response.status_code 
== 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[1].model_dump_json())]} + + +def test_posts_get_created_at_start_filter_accepts_integer(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtStart=1153921700000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (1, 2)]} + + +def test_posts_get_created_at_end_filter_accepts_integer(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtEnd=1154921800000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(post_samples[i].model_dump_json()) for i in (0, 1)]} + + +def test_posts_get_timestamp_out_of_range(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?createdAtStart=1153921700&createdAtEnd=1153921700") + assert response.status_code == 422 + + +def test_notes_get(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(d.model_dump_json()) for d in note_samples]} + + +def test_notes_get_has_note_id_filter(client: TestClient, note_samples: List[Note]) -> None: + response = client.get(f"/api/v1/data/notes/?noteIds={note_samples[0].note_id}&noteIds={note_samples[2].note_id}") + assert response.status_code == 200 + res_json = response.json() + assert res_json == { + "data": [ + json.loads(note_samples[0].model_dump_json()), + json.loads(note_samples[2].model_dump_json()), + ] + } + + +def test_notes_get_has_created_at_filter_from_and_to(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000&createdAtTo=1152921603000") + assert response.status_code 
== 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3)]} + + +def test_notes_get_has_created_at_filter_from(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtFrom=1152921601000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (1, 2, 3, 4)]} + + +def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: List[Note]) -> None: + response = client.get("/api/v1/data/notes/?createdAtTo=1152921603000") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]} diff --git a/api/tests/routers/test_system.py b/api/tests/routers/test_system.py new file mode 100644 index 0000000..36751a7 --- /dev/null +++ b/api/tests/routers/test_system.py @@ -0,0 +1,7 @@ +from fastapi.testclient import TestClient + + +def test_ping(client: TestClient) -> None: + response = client.get("/api/v1/system/ping") + assert response.status_code == 200 + assert response.json() == {"message": "pong"} diff --git a/api/tests/test_app.py b/api/tests/test_app.py new file mode 100644 index 0000000..7fd0b52 --- /dev/null +++ b/api/tests/test_app.py @@ -0,0 +1,15 @@ +from birdxplorer_common.settings import GlobalSettings +from pytest_mock import MockerFixture + +from birdxplorer_api.app import gen_app + + +def test_gen_app(mocker: MockerFixture, default_settings: GlobalSettings) -> None: + FastAPI = mocker.patch("birdxplorer_api.app.FastAPI") + get_logger = mocker.patch("birdxplorer_api.app.get_logger") + expected = FastAPI.return_value + + actual = gen_app(settings=default_settings) + + assert actual == expected + get_logger.assert_called_once_with(level=default_settings.logger_settings.level) diff --git a/api/tests/test_main.py 
b/api/tests/test_main.py new file mode 100644 index 0000000..0443346 --- /dev/null +++ b/api/tests/test_main.py @@ -0,0 +1,13 @@ +from pytest_mock import MockerFixture + + +def test_main_returns_app(mocker: MockerFixture) -> None: + gen_app = mocker.patch("birdxplorer_api.main.gen_app") + GlobalSettings = mocker.patch("birdxplorer_api.main.GlobalSettings") + from birdxplorer_api.main import main + + expected = gen_app.return_value + actual = main() + GlobalSettings.assert_called_once_with() + gen_app.assert_called_once_with(settings=GlobalSettings.return_value) + assert actual == expected diff --git a/api/tests/test_package.py b/api/tests/test_package.py new file mode 100644 index 0000000..43ef207 --- /dev/null +++ b/api/tests/test_package.py @@ -0,0 +1,9 @@ +import re + +import birdxplorer_api + + +def test_package_has_version() -> None: + assert hasattr(birdxplorer_api, "__version__") + assert isinstance(birdxplorer_api.__version__, str) + assert re.match(r"^\d+\.\d+\.\d+$", birdxplorer_api.__version__) From b2fa857ef45866338151ec46c49f8061ad8ddec4 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 4 May 2024 22:07:42 +0900 Subject: [PATCH 15/19] fix(isort): fix isort settings --- api/birdxplorer_api/app.py | 7 ++++--- api/birdxplorer_api/main.py | 3 ++- api/birdxplorer_api/routers/data.py | 5 +++-- api/birdxplorer_api/routers/system.py | 3 ++- api/pyproject.toml | 1 + api/tests/conftest.py | 13 +++++++------ api/tests/routers/test_data.py | 3 ++- api/tests/test_app.py | 2 +- common/pyproject.toml | 1 + scripts/migrations/migrate_all.py | 14 +++++++++----- 10 files changed, 32 insertions(+), 20 deletions(-) diff --git a/api/birdxplorer_api/app.py b/api/birdxplorer_api/app.py index cf72e28..38a6e86 100644 --- a/api/birdxplorer_api/app.py +++ b/api/birdxplorer_api/app.py @@ -3,13 +3,14 @@ from urllib.parse import parse_qs as parse_query_string from urllib.parse import urlencode as encode_query_string -from birdxplorer_common.logger import get_logger -from 
birdxplorer_common.settings import GlobalSettings -from birdxplorer_common.storage import gen_storage from fastapi import FastAPI from pydantic.alias_generators import to_snake from starlette.types import ASGIApp, Receive, Scope, Send +from birdxplorer_common.logger import get_logger +from birdxplorer_common.settings import GlobalSettings +from birdxplorer_common.storage import gen_storage + from .routers.data import gen_router as gen_data_router from .routers.system import gen_router as gen_system_router diff --git a/api/birdxplorer_api/main.py b/api/birdxplorer_api/main.py index 88793c6..0cd201d 100644 --- a/api/birdxplorer_api/main.py +++ b/api/birdxplorer_api/main.py @@ -1,6 +1,7 @@ -from birdxplorer_common.settings import GlobalSettings from fastapi import FastAPI +from birdxplorer_common.settings import GlobalSettings + from .app import gen_app diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index b4ceeb2..2398df8 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -1,6 +1,9 @@ from datetime import timezone from typing import List, Union +from dateutil.parser import parse as dateutil_parse +from fastapi import APIRouter, HTTPException, Query + from birdxplorer_common.models import ( BaseModel, LanguageIdentifier, @@ -16,8 +19,6 @@ UserEnrollment, ) from birdxplorer_common.storage import Storage -from dateutil.parser import parse as dateutil_parse -from fastapi import APIRouter, HTTPException, Query class TopicListResponse(BaseModel): diff --git a/api/birdxplorer_api/routers/system.py b/api/birdxplorer_api/routers/system.py index 87595a8..00d8394 100644 --- a/api/birdxplorer_api/routers/system.py +++ b/api/birdxplorer_api/routers/system.py @@ -1,6 +1,7 @@ -from birdxplorer_common.models import Message from fastapi import APIRouter +from birdxplorer_common.models import Message + def gen_router() -> APIRouter: router = APIRouter() diff --git a/api/pyproject.toml 
b/api/pyproject.toml index fff11ec..9af7082 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -91,6 +91,7 @@ init_typed = true [tool.isort] profile = "black" +known_first_party = "birdxplorer_api,birdxplorer_common" [tool.tox] legacy_tox_ini = """ diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 18063aa..093d718 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -4,6 +4,13 @@ from typing import List, Type, Union from unittest.mock import MagicMock, patch +from dotenv import load_dotenv +from fastapi.testclient import TestClient +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pytest import fixture + from birdxplorer_common.exceptions import UserEnrollmentNotFoundError from birdxplorer_common.models import ( LanguageIdentifier, @@ -21,12 +28,6 @@ ) from birdxplorer_common.settings import GlobalSettings, PostgresStorageSettings from birdxplorer_common.storage import Storage -from dotenv import load_dotenv -from fastapi.testclient import TestClient -from polyfactory import Use -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.pytest_plugin import register_fixture -from pytest import fixture def gen_random_twitter_timestamp() -> int: diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index c36cd8e..8bf271c 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -1,9 +1,10 @@ import json from typing import List -from birdxplorer_common.models import Note, Post, Topic, UserEnrollment from fastapi.testclient import TestClient +from birdxplorer_common.models import Note, Post, Topic, UserEnrollment + def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None: response = client.get(f"/api/v1/data/user-enrollments/{user_enrollment_samples[0].participant_id}") diff --git a/api/tests/test_app.py 
b/api/tests/test_app.py index 7fd0b52..c54e8c1 100644 --- a/api/tests/test_app.py +++ b/api/tests/test_app.py @@ -1,7 +1,7 @@ -from birdxplorer_common.settings import GlobalSettings from pytest_mock import MockerFixture from birdxplorer_api.app import gen_app +from birdxplorer_common.settings import GlobalSettings def test_gen_app(mocker: MockerFixture, default_settings: GlobalSettings) -> None: diff --git a/common/pyproject.toml b/common/pyproject.toml index 2cc9ae9..7bd2dc2 100644 --- a/common/pyproject.toml +++ b/common/pyproject.toml @@ -91,6 +91,7 @@ init_typed = true [tool.isort] profile = "black" +known_first_party = "birdxplorer_api,birdxplorer_common" [tool.tox] legacy_tox_ini = """ diff --git a/scripts/migrations/migrate_all.py b/scripts/migrations/migrate_all.py index 140182d..8d6b879 100644 --- a/scripts/migrations/migrate_all.py +++ b/scripts/migrations/migrate_all.py @@ -6,9 +6,9 @@ from dotenv import load_dotenv from sqlalchemy.orm import Session -from birdxplorer.logger import get_logger -from birdxplorer.settings import GlobalSettings -from birdxplorer.storage import ( +from birdxplorer_common.logger import get_logger +from birdxplorer_common.settings import GlobalSettings +from birdxplorer_common.storage import ( Base, NoteRecord, NoteTopicAssociation, @@ -57,7 +57,11 @@ ) ) sess.commit() - with open(os.path.join(args.data_dir, args.notes_topics_association_file_name), "r", encoding="utf-8") as fin: + with open( + os.path.join(args.data_dir, args.notes_topics_association_file_name), + "r", + encoding="utf-8", + ) as fin: for d in csv.DictReader(fin): if ( sess.query(NoteTopicAssociation) @@ -109,7 +113,7 @@ post_id=d["post_id"], user_id=d["user_id"], text=d["text"], - media_details=json.loads(d["media_details"]) if len(d["media_details"]) > 0 else None, + media_details=(json.loads(d["media_details"]) if len(d["media_details"]) > 0 else None), created_at=d["created_at"], like_count=d["like_count"], repost_count=d["repost_count"], From 
ad1dcb45582803872b4c1da1920ed9f59b0eae60 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 4 May 2024 22:13:29 +0900 Subject: [PATCH 16/19] feat(etl): add ETL --- api/pyproject.toml | 2 +- common/pyproject.toml | 2 +- etl/birdxplorer_etl/__init__.py | 1 + etl/birdxplorer_etl/py.typed | 0 etl/pyproject.toml | 111 ++++++++++++++++++++++++++++++++ etl/tests/test_package.py | 9 +++ 6 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 etl/birdxplorer_etl/__init__.py create mode 100644 etl/birdxplorer_etl/py.typed create mode 100644 etl/pyproject.toml create mode 100644 etl/tests/test_package.py diff --git a/api/pyproject.toml b/api/pyproject.toml index 9af7082..7503043 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -91,7 +91,7 @@ init_typed = true [tool.isort] profile = "black" -known_first_party = "birdxplorer_api,birdxplorer_common" +known_first_party = "birdxplorer_api,birdxplorer_common,birdxplorer_etl" [tool.tox] legacy_tox_ini = """ diff --git a/common/pyproject.toml b/common/pyproject.toml index 7bd2dc2..f29efbc 100644 --- a/common/pyproject.toml +++ b/common/pyproject.toml @@ -91,7 +91,7 @@ init_typed = true [tool.isort] profile = "black" -known_first_party = "birdxplorer_api,birdxplorer_common" +known_first_party = "birdxplorer_api,birdxplorer_common,birdxplorer_etl" [tool.tox] legacy_tox_ini = """ diff --git a/etl/birdxplorer_etl/__init__.py b/etl/birdxplorer_etl/__init__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/etl/birdxplorer_etl/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/etl/birdxplorer_etl/py.typed b/etl/birdxplorer_etl/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/etl/pyproject.toml b/etl/pyproject.toml new file mode 100644 index 0000000..602c45b --- /dev/null +++ b/etl/pyproject.toml @@ -0,0 +1,111 @@ +[build-system] +build-backend = "flit_core.buildapi" +requires = ["flit_core >=3.8.0,<4"] + + +[project] +name = "birdxplorer_etl" +description = "ETL module for 
BirdXplorer" +authors = [ + {name = "osoken"}, +] +dynamic = [ + "version", +] +readme = "../README.md" +license = {file = "../LICENSE"} +requires-python = ">=3.10" + +classifiers = [ + "Development Status :: 3 - Alpha", + "Natural Language :: Japanese", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3.10", +] + +dependencies = [ + "birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@feature/issue-53-divide-python-packages#subdirectory=common", +] + +[project.urls] +Source = "https://github.com/codeforjapan/BirdXplorer" + +[tool.setuptools] +packages=["birdxplorer_etl"] + +[tool.setuptools.package-data] +birdxplorer = ["py.typed"] + +[project.optional-dependencies] +dev=[ + "black", + "flake8", + "pyproject-flake8", + "pytest", + "mypy", + "tox", + "isort", + "pytest-mock", + "pytest-cov", + "freezegun", + "types-python-dateutil", + "psycopg2-binary", + "factory_boy", + "polyfactory", + "types-psycopg2", +] +prod=[ + "psycopg2" +] + + +[tool.pytest.ini_options] +addopts = ["-sv", "--doctest-modules", "--cov=birdxplorer_etl", "--cov-report=xml", "--cov-report=term-missing"] +testpaths = ["tests", "birdxplorer_etl"] +filterwarnings = [ + "error", +] + +[tool.black] +line-length = 120 +target-version = ['py310'] + +[tool.flake8] +max-line-length = 120 +extend-ignore = "E203,E701" + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +plugins = ["pydantic.mypy"] +mypy_path = "stubs/" + +[tool.pydantic.mypy] +init_typed = true + +[tool.isort] +profile = "black" +known_first_party = "birdxplorer_api,birdxplorer_common,birdxplorer_etl" + +[tool.tox] +legacy_tox_ini = """ + [tox] + skipsdist = true + envlist = py310 + + [testenv] + setenv = + VIRTUALENV_PIP = 24.0 + DATA_DIR = {env:BX_DATA_DIR} + deps = + -e .[dev] + commands = + black birdxplorer_etl tests + isort birdxplorer_etl tests + pytest + pflake8 birdxplorer_etl/ tests/ + mypy 
birdxplorer_etl --strict + mypy tests --strict +""" diff --git a/etl/tests/test_package.py b/etl/tests/test_package.py new file mode 100644 index 0000000..2caeadc --- /dev/null +++ b/etl/tests/test_package.py @@ -0,0 +1,9 @@ +import re + +import birdxplorer_etl + + +def test_birdxplorer_etl_has_version() -> None: + assert hasattr(birdxplorer_etl, "__version__") + assert isinstance(birdxplorer_etl.__version__, str) + assert re.match(r"^\d+\.\d+\.\d+$", birdxplorer_etl.__version__) From f26894ac931cb6f5cd65fc4acd3120c2a13202e3 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 4 May 2024 22:21:46 +0900 Subject: [PATCH 17/19] ci: fix common test --- .github/workflows/test.yml | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index faeb770..8173dba 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ permissions: packages: read jobs: - test-check: + common-test-check: runs-on: ubuntu-latest services: postgres: @@ -22,15 +22,16 @@ jobs: steps: - name: Checkout uses: actions/checkout@v4 - - name: Setup python 3.12 + - name: Setup python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.12 + python-version: 3.10 cache: pip - cache-dependency-path: pyproject.toml + cache-dependency-path: common/pyproject.toml - name: dependency install - run: pip install -e ".[dev]" + run: pip install -e "./common[dev]" - name: copy env - run: cp .env.example .env + run: cp .env.example common/.env - name: test + working-directory: common run: tox From 5bc1ec4610c2cf2f3fc5be7edc8eebe87c971f83 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 4 May 2024 22:23:22 +0900 Subject: [PATCH 18/19] ci: fix python version string --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8173dba..3c78355 100644 --- a/.github/workflows/test.yml +++ 
b/.github/workflows/test.yml @@ -25,7 +25,7 @@ jobs: - name: Setup python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.10 + python-version: "3.10" cache: pip cache-dependency-path: common/pyproject.toml - name: dependency install From 3615ed9287e8d59ffb4dc8b8d6888b15a4bcf8a4 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 4 May 2024 22:26:08 +0900 Subject: [PATCH 19/19] ci: add test-check for api and etl --- .github/workflows/test.yml | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3c78355..8b31f5a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,3 +35,43 @@ jobs: - name: test working-directory: common run: tox + + api-test-check: + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup python 3.10 + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: pip + cache-dependency-path: api/pyproject.toml + - name: dependency install + run: pip install -e "./api[dev]" + - name: copy env + run: cp .env.example api/.env + - name: test + working-directory: api + run: tox + + etl-test-check: + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup python 3.10 + uses: actions/setup-python@v4 + with: + python-version: "3.10" + cache: pip + cache-dependency-path: etl/pyproject.toml + - name: dependency install + run: pip install -e "./etl[dev]" + - name: copy env + run: cp .env.example etl/.env + - name: test + working-directory: etl + run: tox