Skip to content

Commit

Permalink
既存のテストが壊れる部分を修正
Browse files Browse the repository at this point in the history
  • Loading branch information
sushichan044 committed Oct 8, 2024
1 parent 54a284e commit 9d5555d
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 59 deletions.
19 changes: 10 additions & 9 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
from typing import List, Type, Union
from unittest.mock import MagicMock, patch

from dotenv import load_dotenv
from fastapi.testclient import TestClient
from polyfactory import Use
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from pydantic import HttpUrl
from pytest import fixture

from birdxplorer_common.exceptions import UserEnrollmentNotFoundError
from birdxplorer_common.models import (
LanguageIdentifier,
Expand All @@ -27,13 +35,6 @@
PostgresStorageSettings,
)
from birdxplorer_common.storage import Storage
from dotenv import load_dotenv
from fastapi.testclient import TestClient
from polyfactory import Use
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from pydantic import HttpUrl
from pytest import fixture


def gen_random_twitter_timestamp() -> int:
Expand Down Expand Up @@ -272,7 +273,7 @@ def post_samples(
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/",
media_details=None,
media_details=[],
created_at=1154922900000,
like_count=10,
repost_count=20,
Expand All @@ -285,7 +286,7 @@ def post_samples(
x_user_id="1234567890123456783",
x_user=x_user_samples[2],
text="empty",
media_details=None,
media_details=[],
created_at=1154923900000,
like_count=10,
repost_count=20,
Expand Down
9 changes: 2 additions & 7 deletions common/birdxplorer_common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,9 @@
from uuid import UUID

from pydantic import BaseModel as PydanticBaseModel
from pydantic import (
ConfigDict,
GetCoreSchemaHandler,
HttpUrl,
TypeAdapter,
model_validator,
)
from pydantic import ConfigDict
from pydantic import Field as PydanticField
from pydantic import GetCoreSchemaHandler, HttpUrl, TypeAdapter, model_validator
from pydantic.alias_generators import to_camel
from pydantic.main import IncEx
from pydantic_core import core_schema
Expand Down
25 changes: 8 additions & 17 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,22 @@
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship
from sqlalchemy.types import CHAR, DECIMAL, JSON, Integer, String, Uuid

from .models import BinaryBool, LanguageIdentifier
from .models import Link as LinkModel
from .models import LinkId, Media, MediaDetails, MediaType, NonNegativeInt
from .models import Note as NoteModel
from .models import NoteId, NotesClassification, NotesHarmful, ParticipantId
from .models import Post as PostModel
from .models import PostId, SummaryString
from .models import Topic as TopicModel
from .models import (
BinaryBool,
LanguageIdentifier,
LinkId,
Media,
MediaDetails,
MediaType,
NonNegativeInt,
NoteId,
NotesClassification,
NotesHarmful,
ParticipantId,
PostId,
SummaryString,
TopicId,
TopicLabel,
TwitterTimestamp,
UserEnrollment,
UserId,
UserName,
)
from .models import Link as LinkModel
from .models import Note as NoteModel
from .models import Post as PostModel
from .models import Topic as TopicModel
from .models import XUser as XUserModel
from .settings import GlobalSettings

Expand Down
61 changes: 36 additions & 25 deletions common/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
from collections.abc import Generator
from typing import List, Type

from dotenv import load_dotenv
from polyfactory import Use
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from pytest import fixture
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.sql import text

from birdxplorer_common.models import (
Link,
Media,
Expand All @@ -26,16 +37,6 @@
TopicRecord,
XUserRecord,
)
from dotenv import load_dotenv
from polyfactory import Use
from polyfactory.factories.pydantic_factory import ModelFactory
from polyfactory.pytest_plugin import register_fixture
from pytest import fixture
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.sql import text


def gen_random_twitter_timestamp() -> int:
Expand Down Expand Up @@ -316,7 +317,7 @@ def post_samples(
x_user_id="1234567890123456782",
x_user=x_user_samples[1],
text="https://t.co/zzzzzzzzzzz/ https://t.co/wwwwwwwwwww/",
media_details=None,
media_details=[],
created_at=1154922900000,
like_count=10,
repost_count=20,
Expand All @@ -329,7 +330,7 @@ def post_samples(
x_user_id="1234567890123456783",
x_user=x_user_samples[2],
text="empty",
media_details=None,
media_details=[],
created_at=1154923900000,
like_count=10,
repost_count=20,
Expand Down Expand Up @@ -451,6 +452,27 @@ def x_user_records_sample(
yield res


@fixture
def media_records_sample(
media_samples: List[Media],
engine_for_test: Engine,
) -> Generator[List[MediaRecord], None, None]:
res = [
MediaRecord(
media_key=d.media_key,
url=d.url,
type=d.type,
width=d.width,
height=d.height,
)
for d in media_samples
]
with Session(engine_for_test) as sess:
sess.add_all(res)
sess.commit()
yield res


@fixture
def link_records_sample(
link_samples: List[Link],
Expand All @@ -467,23 +489,12 @@ def link_records_sample(
def post_records_sample(
x_user_records_sample: List[XUserRecord],
media_records_sample: List[MediaRecord],
link_records_sample: List[LinkRecord],
link_samples: List[Link],
post_samples: List[Post],
engine_for_test: Engine,
) -> Generator[List[PostRecord], None, None]:
res = [
PostRecord(
post_id=d.post_id,
user_id=d.x_user_id,
text=d.text,
media_details=d.media_details,
created_at=d.created_at,
like_count=d.like_count,
repost_count=d.repost_count,
impression_count=d.impression_count,
)
for d in post_samples
]
res = []
with Session(engine_for_test) as sess:
for post in post_samples:
inst = PostRecord(
Expand Down
2 changes: 1 addition & 1 deletion common/tests/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_get_topic_list(
[dict(search_url=HttpUrl("https://example.com/sh3")), [2, 3]],
[dict(note_ids=[NoteId.from_str("1234567890123456781")]), [0]],
[dict(offset=1, limit=1, search_text="https://t.co/xxxxxxxxxxx/"), [2]],
[dict(with_media=True), [0, 1, 2]],
[dict(with_media=True), [0, 1, 2, 3, 4]],
[dict(post_ids=[PostId.from_str("2234567890123456781")], with_media=False), [0]],
],
)
Expand Down

0 comments on commit 9d5555d

Please sign in to comment.