diff --git a/birdxplorer/models.py b/birdxplorer/models.py index 2844b01..b10df9d 100644 --- a/birdxplorer/models.py +++ b/birdxplorer/models.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Literal, Type, TypeAlias, TypeVar, Union from pydantic import BaseModel as PydanticBaseModel -from pydantic import ConfigDict, GetCoreSchemaHandler, TypeAdapter +from pydantic import ConfigDict, GetCoreSchemaHandler, HttpUrl, TypeAdapter from pydantic.alias_generators import to_camel from pydantic_core import core_schema @@ -128,21 +128,21 @@ def __get_extra_constraint_dict__(cls) -> dict[str, Any]: return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9A-F]{64}$") -class NineToNineteenDigitsDecimalString(BaseString): +class UpToNineteenDigitsDecimalString(BaseString): """ - >>> NineToNineteenDigitsDecimalString.from_str("test") + >>> UpToNineteenDigitsDecimalString.from_str("test") Traceback (most recent call last): ... pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] - String should match pattern '^[0-9]{9,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str] + String should match pattern '^[0-9]{1,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str] ... - >>> NineToNineteenDigitsDecimalString.from_str("1234567890123456789") - NineToNineteenDigitsDecimalString('1234567890123456789') + >>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789") + UpToNineteenDigitsDecimalString('1234567890123456789') """ @classmethod def __get_extra_constraint_dict__(cls) -> dict[str, Any]: - return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{9,19}$") + return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{1,19}$") class NonEmptyStringMixin(BaseString): @@ -561,7 +561,7 @@ class NotesValidationDifficulty(str, Enum): empty = "" -class TweetId(NineToNineteenDigitsDecimalString): ... +class TweetId(UpToNineteenDigitsDecimalString): ... class NoteData(BaseModel): @@ -630,3 +630,35 @@ class Note(BaseModel): topics: List[Topic] summary: SummaryString created_at: TwitterTimestamp + + +class UserId(UpToNineteenDigitsDecimalString): ... + + +class UserName(NonEmptyTrimmedString): ... + + +class XUser(BaseModel): + user_id: UserId + name: UserName + profile_image: HttpUrl + followers_count: NonNegativeInt + following_count: NonNegativeInt + + +class PostId(UpToNineteenDigitsDecimalString): ... + + +MediaDetails: TypeAlias = List[HttpUrl] | None + + +class Post(BaseModel): + post_id: PostId + x_user_id: UserId + x_user: XUser + text: str + media_details: MediaDetails = None + created_at: TwitterTimestamp + like_count: NonNegativeInt + repost_count: NonNegativeInt + impression_count: NonNegativeInt diff --git a/birdxplorer/routers/data.py b/birdxplorer/routers/data.py index a3faad2..055a838 100644 --- a/birdxplorer/routers/data.py +++ b/birdxplorer/routers/data.py @@ -2,7 +2,7 @@ from fastapi import APIRouter -from ..models import BaseModel, ParticipantId, Topic, UserEnrollment +from ..models import BaseModel, ParticipantId, Post, Topic, UserEnrollment from ..storage import Storage @@ -10,6 +10,10 @@ class TopicListResponse(BaseModel): data: List[Topic] +class PostListResponse(BaseModel): + data: List[Post] + + def gen_router(storage: Storage) -> APIRouter: router = APIRouter() @@ -24,4 +28,8 @@ def get_user_enrollment_by_participant_id(participant_id: ParticipantId) -> User def get_topics() -> TopicListResponse: return TopicListResponse(data=list(storage.get_topics())) + @router.get("/posts", response_model=PostListResponse) + def get_posts() -> PostListResponse: + return PostListResponse(data=list(storage.get_posts())) + return router diff --git a/birdxplorer/storage.py b/birdxplorer/storage.py index 42cd536..c320d9e 100644 --- a/birdxplorer/storage.py +++ b/birdxplorer/storage.py @@ -1,16 +1,42 @@ from typing import Generator, List +from psycopg2.extensions import AsIs, register_adapter +from pydantic import AnyUrl, HttpUrl from sqlalchemy import ForeignKey, create_engine, func, select from sqlalchemy.engine import Engine from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import DECIMAL, JSON, Integer, String -from .models import LanguageIdentifier, NoteId, ParticipantId, SummaryString +from .models import ( + LanguageIdentifier, + MediaDetails, + NonNegativeInt, + NoteId, + ParticipantId, +) +from .models import Post as PostModel +from .models import SummaryString from .models import Topic as TopicModel -from .models import TopicId, TopicLabel, TweetId, TwitterTimestamp, UserEnrollment +from .models import ( + TopicId, + TopicLabel, + TweetId, + TwitterTimestamp, + UserEnrollment, + UserId, + UserName, +) +from .models import XUser as XUserModel from .settings import GlobalSettings +def adapt_pydantic_http_url(url: AnyUrl) -> AsIs: + return AsIs(repr(str(url))) + + +register_adapter(AnyUrl, adapt_pydantic_http_url) + + class Base(DeclarativeBase): type_annotation_map = { TopicId: Integer, @@ -21,6 +47,11 @@ class Base(DeclarativeBase): LanguageIdentifier: String, TwitterTimestamp: DECIMAL, SummaryString: String, + UserId: String, + UserName: String, + HttpUrl: String, + NonNegativeInt: DECIMAL, + MediaDetails: JSON, } @@ -50,6 +81,30 @@ class TopicRecord(Base): label: Mapped[TopicLabel] = mapped_column(nullable=False) +class XUserRecord(Base): + __tablename__ = "x_users" + + user_id: Mapped[UserId] = mapped_column(primary_key=True) + name: Mapped[UserName] = mapped_column(nullable=False) + profile_image: Mapped[HttpUrl] = mapped_column(nullable=False) + followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + + +class PostRecord(Base): + __tablename__ = "posts" + + post_id: Mapped[TweetId] = mapped_column(primary_key=True) + user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False) + user: Mapped[XUserRecord] = relationship() + text: Mapped[SummaryString] = mapped_column(nullable=False) + media_details: Mapped[MediaDetails] = mapped_column() + created_at: Mapped[TwitterTimestamp] = mapped_column(nullable=False) + like_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + repost_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + impression_count: Mapped[NonNegativeInt] = mapped_column(nullable=False) + + class Storage: def __init__(self, engine: Engine) -> None: self._engine = engine @@ -77,6 +132,27 @@ def get_topics(self) -> Generator[TopicModel, None, None]: topic_id=topic_record.topic_id, label=topic_record.label, reference_count=reference_count or 0 ) + def get_posts(self) -> Generator[PostModel, None, None]: + with Session(self.engine) as sess: + for post_record in sess.query(PostRecord).all(): + yield PostModel( + post_id=post_record.post_id, + x_user_id=post_record.user_id, + x_user=XUserModel( + user_id=post_record.user.user_id, + name=post_record.user.name, + profile_image=post_record.user.profile_image, + followers_count=post_record.user.followers_count, + following_count=post_record.user.following_count, + ), + text=post_record.text, + media_details=post_record.media_details, + created_at=post_record.created_at, + like_count=post_record.like_count, + repost_count=post_record.repost_count, + impression_count=post_record.impression_count, + ) + def gen_storage(settings: GlobalSettings) -> Storage: engine = create_engine(settings.storage_settings.sqlalchemy_database_url) diff --git a/pyproject.toml b/pyproject.toml index 0e42de8..abcf924 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ dev=[ "uvicorn", "polyfactory", "httpx", + "types-psycopg2", ] prod=[ "psycopg2" diff --git a/scripts/migrations/convert_data_from_v1.py b/scripts/migrations/convert_data_from_v1.py index 566ad43..4364961 100644 --- a/scripts/migrations/convert_data_from_v1.py +++ b/scripts/migrations/convert_data_from_v1.py @@ -7,11 +7,15 @@ if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("notes_file") + parser.add_argument("posts_file") + parser.add_argument("x_users_file") parser.add_argument("output_dir") parser.add_argument("--notes-file-name", default="notes.csv") parser.add_argument("--topics-file-name", default="topics.csv") parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv") parser.add_argument("--topic-threshold", type=int, default=5) + parser.add_argument("--posts-file-name", default="posts.csv") + parser.add_argument("--x-users-file-name", default="x_users.csv") args = parser.parse_args() @@ -19,9 +23,21 @@ notes = list(csv.DictReader(fin)) for d in notes: d["topic"] = [t.strip() for t in d["topic"].split(",")] + topics_with_count = Counter(t for d in notes for t in d["topic"]) topic_name_to_id_map = {t: i for i, (t, c) in enumerate(topics_with_count.items()) if c > args.topic_threshold} + with open(args.posts_file, "r", encoding="utf-8") as fin: + posts = list(csv.DictReader(fin)) + for d in posts: + d["media_details"] = None if len(d["media_details"]) == 0 else json.loads(d["media_details"]) + + with open(args.x_users_file, "r", encoding="utf-8") as fin: + x_users = list(csv.DictReader(fin)) + for d in x_users: + d["followers_count"] = int(d["followers_count"]) + d["following_count"] = int(d["following_count"]) + if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) @@ -52,3 +68,49 @@ for t in d["topic"]: if t in topic_name_to_id_map: writer.writerow({"note_id": d["note_id"], "topic_id": topic_name_to_id_map[t]}) + + with open(os.path.join(args.output_dir, args.posts_file_name), "w", encoding="utf-8") as fout: + writer = csv.DictWriter( + fout, + fieldnames=[ + "post_id", + "user_id", + "text", + "media_details", + "created_at", + "like_count", + "repost_count", + "impression_count", + ], + ) + writer.writeheader() + for d in posts: + writer.writerow( + { + "post_id": d["post_id"], + "user_id": d["user_id"], + "text": d["text"], + "media_details": json.dumps(d["media_details"]), + "created_at": 1288834974657, + "like_count": 0, + "repost_count": 0, + "impression_count": 0, + } + ) + + with open(os.path.join(args.output_dir, args.x_users_file_name), "w", encoding="utf-8") as fout: + writer = csv.DictWriter( + fout, + fieldnames=["user_id", "name", "profile_image", "followers_count", "following_count"], + ) + writer.writeheader() + for d in x_users: + writer.writerow( + { + "user_id": d["user_id"], + "name": d["name"], + "profile_image": d["profile_image"], + "followers_count": d["followers_count"], + "following_count": d["following_count"], + } + ) diff --git a/scripts/migrations/migrate_all.py b/scripts/migrations/migrate_all.py index 0c25d72..140182d 100644 --- a/scripts/migrations/migrate_all.py +++ b/scripts/migrations/migrate_all.py @@ -12,7 +12,9 @@ Base, NoteRecord, NoteTopicAssociation, + PostRecord, TopicRecord, + XUserRecord, gen_storage, ) @@ -22,6 +24,9 @@ parser.add_argument("--notes-file-name", default="notes.csv") parser.add_argument("--topics-file-name", default="topics.csv") parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv") + parser.add_argument("--posts-file-name", default="posts.csv") + parser.add_argument("--x-users-file-name", default="x_users.csv") + parser.add_argument("--limit-number-of-post-rows", type=int, default=None) load_dotenv() args = parser.parse_args() settings = GlobalSettings() @@ -71,4 +76,45 @@ ) ) sess.commit() + with open(os.path.join(args.data_dir, args.x_users_file_name), "r", encoding="utf-8") as fin: + for d in csv.DictReader(fin): + d["followers_count"] = int(d["followers_count"]) + d["following_count"] = int(d["following_count"]) + if sess.query(XUserRecord).filter(XUserRecord.user_id == d["user_id"]).count() > 0: + continue + sess.add( + XUserRecord( + user_id=d["user_id"], + name=d["name"], + profile_image=d["profile_image"], + followers_count=d["followers_count"], + following_count=d["following_count"], + ) + ) + sess.commit() + with open(os.path.join(args.data_dir, args.posts_file_name), "r", encoding="utf-8") as fin: + for d in csv.DictReader(fin): + if ( + args.limit_number_of_post_rows is not None + and sess.query(PostRecord).count() >= args.limit_number_of_post_rows + ): + break + d["like_count"] = int(d["like_count"]) + d["repost_count"] = int(d["repost_count"]) + d["impression_count"] = int(d["impression_count"]) + if sess.query(PostRecord).filter(PostRecord.post_id == d["post_id"]).count() > 0: + continue + sess.add( + PostRecord( + post_id=d["post_id"], + user_id=d["user_id"], + text=d["text"], + media_details=json.loads(d["media_details"]) if len(d["media_details"]) > 0 else None, + created_at=d["created_at"], + like_count=d["like_count"], + repost_count=d["repost_count"], + impression_count=d["impression_count"], + ) + ) + sess.commit() logger.info("Migration is done") diff --git a/tests/conftest.py b/tests/conftest.py index f13eeeb..adf6df0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,17 +20,21 @@ from birdxplorer.models import ( Note, ParticipantId, + Post, Topic, TwitterTimestamp, UserEnrollment, + XUser, ) from birdxplorer.settings import GlobalSettings, PostgresStorageSettings from birdxplorer.storage import ( Base, NoteRecord, NoteTopicAssociation, + PostRecord, Storage, TopicRecord, + XUserRecord, ) @@ -90,6 +94,16 @@ class TopicFactory(ModelFactory[Topic]): __model__ = Topic +@register_fixture(name="x_user_factory") +class XUserFactory(ModelFactory[XUser]): + __model__ = XUser + + +@register_fixture(name="post_factory") +class PostFactory(ModelFactory[Post]): + __model__ = Post + + @fixture def user_enrollment_samples( user_enrollment_factory: UserEnrollmentFactory, @@ -99,7 +113,7 @@ def user_enrollment_samples( @fixture def mock_storage( - user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic] + user_enrollment_samples: List[UserEnrollment], topic_samples: List[Topic], post_samples: List[Post] ) -> Generator[MagicMock, None, None]: mock = MagicMock(spec=Storage) @@ -116,6 +130,11 @@ def _get_topics() -> Generator[Topic, None, None]: mock.get_topics.side_effect = _get_topics + def _get_posts() -> Generator[Post, None, None]: + yield from post_samples + + mock.get_posts.side_effect = _get_posts + yield mock @@ -186,6 +205,74 @@ def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Gener yield notes +@fixture +def x_user_samples(x_user_factory: XUserFactory) -> Generator[List[XUser], None, None]: + x_users = [ + x_user_factory.build( + user_id="1234567890123456781", + name="User1", + profile_image_url="https://pbs.twimg.com/profile_images/1468001914302390XXX/xxxxXXXX_normal.jpg", + followers_count=100, + following_count=200, + ), + x_user_factory.build( + user_id="1234567890123456782", + name="User2", + profile_image_url="https://pbs.twimg.com/profile_images/1468001914302390YYY/yyyyYYYY_normal.jpg", + followers_count=300, + following_count=400, + ), + x_user_factory.build( + user_id="1234567890123456783", + name="User3", + profile_image_url="https://pbs.twimg.com/profile_images/1468001914302390ZZZ/zzzzZZZZ_normal.jpg", + followers_count=300, + following_count=400, + ), + ] + yield x_users + + +@fixture +def post_samples(post_factory: PostFactory, x_user_samples: List[XUser]) -> Generator[List[Post], None, None]: + posts = [ + post_factory.build( + post_id="2234567890123456781", + x_user_id="1234567890123456781", + x_user=x_user_samples[0], + text="text11", + media_details=None, + created_at=1152921600000, + like_count=10, + repost_count=20, + impression_count=30, + ), + post_factory.build( + post_id="2234567890123456791", + x_user_id="1234567890123456781", + x_user=x_user_samples[0], + text="text12", + media_details=None, + created_at=1152921700000, + like_count=10, + repost_count=20, + impression_count=30, + ), + post_factory.build( + post_id="2234567890123456801", + x_user_id="1234567890123456782", + x_user=x_user_samples[1], + text="text21", + media_details=None, + created_at=1152921800000, + like_count=10, + repost_count=20, + impression_count=30, + ), + ] + yield posts + + TEST_DATABASE_NAME = "bx_test" @@ -274,3 +361,49 @@ def note_records_sample( res.append(inst) sess.commit() yield res + + +@fixture +def x_user_records_sample( + x_user_samples: List[XUser], + engine_for_test: Engine, +) -> Generator[List[XUserRecord], None, None]: + res = [ + XUserRecord( + user_id=d.user_id, + name=d.name, + profile_image=d.profile_image, + followers_count=d.followers_count, + following_count=d.following_count, + ) + for d in x_user_samples + ] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res + + +@fixture +def post_records_sample( + x_user_records_sample: List[XUserRecord], + post_samples: List[Post], + engine_for_test: Engine, +) -> Generator[List[PostRecord], None, None]: + res = [ + PostRecord( + post_id=d.post_id, + user_id=d.x_user_id, + text=d.text, + media_details=d.media_details, + created_at=d.created_at, + like_count=d.like_count, + repost_count=d.repost_count, + impression_count=d.impression_count, + ) + for d in post_samples + ] + with Session(engine_for_test) as sess: + sess.add_all(res) + sess.commit() + yield res diff --git a/tests/routers/test_data.py b/tests/routers/test_data.py index 0997cd1..9895fcf 100644 --- a/tests/routers/test_data.py +++ b/tests/routers/test_data.py @@ -1,8 +1,9 @@ +import json from typing import List from fastapi.testclient import TestClient -from birdxplorer.models import Topic, UserEnrollment +from birdxplorer.models import Post, Topic, UserEnrollment def test_user_enrollments_get(client: TestClient, user_enrollment_samples: List[UserEnrollment]) -> None: @@ -17,3 +18,10 @@ def test_topics_get(client: TestClient, topic_samples: List[Topic]) -> None: assert response.status_code == 200 res_json = response.json() assert res_json == {"data": [d.model_dump(by_alias=True) for d in topic_samples]} + + +def test_posts_get(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(d.model_dump_json()) for d in post_samples]} diff --git a/tests/test_storage.py b/tests/test_storage.py index f3b380d..c8ab652 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -2,8 +2,8 @@ from sqlalchemy.engine import Engine -from birdxplorer.models import Topic -from birdxplorer.storage import NoteRecord, Storage, TopicRecord +from birdxplorer.models import Post, Topic +from birdxplorer.storage import NoteRecord, PostRecord, Storage, TopicRecord def test_get_topic_list( @@ -16,3 +16,16 @@ def test_get_topic_list( expected = sorted(topic_samples, key=lambda x: x.topic_id) actual = sorted(storage.get_topics(), key=lambda x: x.topic_id) assert expected == actual + + +def test_get_post_list( + engine_for_test: Engine, + post_samples: List[Post], + post_records_sample: List[PostRecord], + topic_records_sample: List[TopicRecord], + note_records_sample: List[NoteRecord], +) -> None: + storage = Storage(engine=engine_for_test) + expected = sorted(post_samples, key=lambda x: x.post_id) + actual = sorted(storage.get_posts(), key=lambda x: x.post_id) + assert expected == actual