From f466efde40f8083cc8ca26071c1d38fdcd925fae Mon Sep 17 00:00:00 2001 From: sushichan044 Date: Tue, 8 Oct 2024 17:35:57 +0900 Subject: [PATCH] fix: error after rebase --- api/birdxplorer_api/routers/data.py | 7 ++++--- api/tests/conftest.py | 2 +- common/birdxplorer_common/models.py | 11 ++++------- common/birdxplorer_common/storage.py | 25 ++++++++----------------- common/tests/conftest.py | 23 ++++++++++++----------- common/tests/test_storage.py | 2 +- 6 files changed, 30 insertions(+), 40 deletions(-) diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index 820105b..e0ef426 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -1,6 +1,10 @@ from datetime import timezone from typing import List, Union +from dateutil.parser import parse as dateutil_parse +from fastapi import APIRouter, HTTPException, Query, Request +from pydantic import HttpUrl + from birdxplorer_common.models import ( BaseModel, LanguageIdentifier, @@ -16,9 +20,6 @@ UserEnrollment, ) from birdxplorer_common.storage import Storage -from dateutil.parser import parse as dateutil_parse -from fastapi import APIRouter, HTTPException, Query, Request -from pydantic import HttpUrl class TopicListResponse(BaseModel): diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 35572e1..affa005 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -218,7 +218,7 @@ def media_samples(media_factory: MediaFactory) -> Generator[List[Media], None, N @fixture def post_samples( - post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media],link_samples: List[Link] + post_factory: PostFactory, x_user_samples: List[XUser], media_samples: List[Media], link_samples: List[Link] ) -> Generator[List[Post], None, None]: posts = [ post_factory.build( diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index 3e54f8e..908d52a 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -17,17 +17,14 @@ from uuid import UUID from pydantic import BaseModel as PydanticBaseModel -from pydantic import ( - ConfigDict, - GetCoreSchemaHandler, - HttpUrl, - TypeAdapter, - model_validator, -) +from pydantic import ConfigDict from pydantic import Field as PydanticField +from pydantic import GetCoreSchemaHandler, HttpUrl, TypeAdapter, model_validator from pydantic.alias_generators import to_camel +from pydantic.main import IncEx from pydantic_core import core_schema +StrT = TypeVar("StrT", bound="BaseString") IntT = TypeVar("IntT", bound="BaseInt") FloatT = TypeVar("FloatT", bound="BaseFloat") diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py index 4b07f26..6eee4af 100644 --- a/common/birdxplorer_common/storage.py +++ b/common/birdxplorer_common/storage.py @@ -7,20 +7,15 @@ from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid +from .models import BinaryBool, LanguageIdentifier +from .models import Link as LinkModel +from .models import LinkId, Media, MediaDetails, MediaType, NonNegativeInt +from .models import Note as NoteModel +from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId +from .models import Post as PostModel +from .models import PostId, SummaryString +from .models import Topic as TopicModel from .models import ( - BinaryBool, - LanguageIdentifier, - LinkId, - Media, - MediaDetails, - MediaType, - NonNegativeInt, - NoteId, - NotesClassification, - NotesHarmful, - ParticipantId, - PostId, - SummaryString, TopicId, TopicLabel, TwitterTimestamp, @@ -28,10 +23,6 @@ UserId, UserName, ) -from .models import Link as LinkModel -from .models import Note as NoteModel -from .models import Post as PostModel -from .models import Topic as TopicModel from .models import XUser as XUserModel from .settings import GlobalSettings diff --git a/common/tests/conftest.py b/common/tests/conftest.py index e8430ee..78865d7 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -3,6 +3,17 @@ from collections.abc import Generator from typing import List, Type +from dotenv import load_dotenv +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory +from polyfactory.pytest_plugin import register_fixture +from pytest import fixture +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from sqlalchemy.sql import text + from birdxplorer_common.models import ( Link, Media, @@ -26,16 +37,6 @@ TopicRecord, XUserRecord, ) -from dotenv import load_dotenv -from polyfactory import Use -from polyfactory.factories.pydantic_factory import ModelFactory -from polyfactory.pytest_plugin import register_fixture -from pytest import fixture -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session -from sqlalchemy.sql import text def gen_random_twitter_timestamp() -> int: @@ -499,13 +500,13 @@ def post_records_sample( post_id=post.post_id, user_id=post.x_user_id, text=post.text, - media_details=post.media_details, created_at=post.created_at, like_count=post.like_count, repost_count=post.repost_count, impression_count=post.impression_count, ) sess.add(inst) + for link in post.links: post_link_assoc = PostLinkAssociation(link_id=link.link_id, post_id=inst.post_id) sess.add(post_link_assoc) diff --git a/common/tests/test_storage.py b/common/tests/test_storage.py index 07d6fb2..ce854b1 100644 --- a/common/tests/test_storage.py +++ b/common/tests/test_storage.py @@ -45,7 +45,7 @@ def test_get_topic_list( [dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]], [dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]], [dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]], - [dict(with_media=True), [0, 1, 2]], + [dict(with_media=True), [0, 1, 2, 3, 4]], [dict(post_ids=[PostId.from_str("2234567890123456781")], with_media=False), [0]], ], )