Skip to content

Commit

Permalink
Merge branch 'main' of github.com:codeforjapan/BirdXplorer into feat/…
Browse files Browse the repository at this point in the history
…add-topic-seed

# Conflicts:
#	common/birdxplorer_common/models.py
  • Loading branch information
ayuki-joto committed Aug 14, 2024
2 parents 719343c + e3c7e37 commit 3135420
Show file tree
Hide file tree
Showing 15 changed files with 296 additions and 84 deletions.
52 changes: 52 additions & 0 deletions api/Dockerfile.dev
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
ARG PYTHON_VERSION_CODE=3.10
ARG ENVIRONMENT="dev"
# ENVIRONMENT: dev or prod, refer to project.optional-dependencies in pyproject.toml

FROM python:${PYTHON_VERSION_CODE}-bookworm as builder
ARG PYTHON_VERSION_CODE
ARG ENVIRONMENT

WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1

COPY api/pyproject.toml api/README.md ./
COPY api/birdxplorer_api/__init__.py ./birdxplorer_api/

RUN if [ "${ENVIRONMENT}" = "prod" ]; then \
apt-get update && apt-get install -y --no-install-recommends \
postgresql-client-15 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; \
fi

RUN python -m pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -e ".[${ENVIRONMENT}]"

COPY ../common ./common
RUN if [ "${ENVIRONMENT}" = "dev" ]; then \
pip install -e ./common; \
fi

FROM python:${PYTHON_VERSION_CODE}-slim-bookworm as runner
ARG PYTHON_VERSION_CODE
ARG ENVIRONMENT

WORKDIR /app

RUN if [ "${ENVIRONMENT}" = "prod" ]; then \
apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; \
fi

RUN groupadd -r app && useradd -r -g app app
RUN chown -R app:app /app
USER app

COPY --from=builder /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages
COPY --chown=app:app api ./
COPY ../common ./common

ENTRYPOINT ["python", "-m", "uvicorn", "birdxplorer_api.main:app", "--host", "0.0.0.0"]
46 changes: 46 additions & 0 deletions api/Dockerfile.prd
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
ARG PYTHON_VERSION_CODE=3.10
ARG ENVIRONMENT="prod"
# ENVIRONMENT: dev or prod, refer to project.optional-dependencies in pyproject.toml

FROM python:${PYTHON_VERSION_CODE}-bookworm as builder
ARG PYTHON_VERSION_CODE
ARG ENVIRONMENT

WORKDIR /app
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1

COPY pyproject.toml README.md ./
COPY birdxplorer_api/__init__.py ./birdxplorer_api/

RUN if [ "${ENVIRONMENT}" = "prod" ]; then \
apt-get update && apt-get install -y --no-install-recommends \
postgresql-client-15 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; \
fi

RUN python -m pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -e ".[${ENVIRONMENT}]"

FROM python:${PYTHON_VERSION_CODE}-slim-bookworm as runner
ARG PYTHON_VERSION_CODE
ARG ENVIRONMENT

WORKDIR /app

RUN if [ "${ENVIRONMENT}" = "prod" ]; then \
apt-get update && apt-get install -y --no-install-recommends \
libpq5 \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*; \
fi

RUN groupadd -r app && useradd -r -g app app
RUN chown -R app:app /app
USER app

COPY --from=builder /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages /usr/local/lib/python${PYTHON_VERSION_CODE}/site-packages
COPY --chown=app:app . ./

ENTRYPOINT ["python", "-m", "gunicorn", "birdxplorer_api.main:app", "-k", "uvicorn.workers.UvicornWorker", "-w", "1", "--bind", "0.0.0.0:10000"]
6 changes: 4 additions & 2 deletions api/birdxplorer_api/routers/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
UserEnrollment,
)
Expand Down Expand Up @@ -73,7 +72,7 @@ def get_notes(
created_at_from: Union[None, TwitterTimestamp] = Query(default=None),
created_at_to: Union[None, TwitterTimestamp] = Query(default=None),
topic_ids: Union[List[TopicId], None] = Query(default=None),
post_ids: Union[List[TweetId], None] = Query(default=None),
post_ids: Union[List[PostId], None] = Query(default=None),
language: Union[LanguageIdentifier, None] = Query(default=None),
) -> NoteListResponse:
return NoteListResponse(
Expand All @@ -92,11 +91,14 @@ def get_notes(
@router.get("/posts", response_model=PostListResponse)
def get_posts(
post_id: Union[List[PostId], None] = Query(default=None),
note_id: Union[List[NoteId], None] = Query(default=None),
created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None),
created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None),
) -> PostListResponse:
if post_id is not None:
return PostListResponse(data=list(storage.get_posts_by_ids(post_ids=post_id)))
if note_id is not None:
return PostListResponse(data=list(storage.get_posts_by_note_ids(note_ids=note_id)))
if created_at_start is not None:
if created_at_end is not None:
return PostListResponse(
Expand Down
3 changes: 2 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"pydantic",
"starlette",
"python-dotenv",
"uvicorn[standard]",
]

[project.urls]
Expand All @@ -56,13 +57,13 @@ dev=[
"types-python-dateutil",
"psycopg2-binary",
"factory_boy",
"uvicorn",
"polyfactory",
"httpx",
]
prod=[
"birdxplorer_common @ git+https://github.com/codeforjapan/BirdXplorer.git@main#subdirectory=common",
"psycopg2",
"gunicorn",
]

[tool.pytest.ini_options]
Expand Down
12 changes: 10 additions & 2 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
PostId,
Topic,
TopicId,
TweetId,
TwitterTimestamp,
UserEnrollment,
XUser,
Expand Down Expand Up @@ -227,7 +226,7 @@ def _get_notes(
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[TweetId], None] = None,
post_ids: Union[List[PostId], None] = None,
language: Union[LanguageIdentifier, None] = None,
) -> Generator[Note, None, None]:
for note in note_samples:
Expand Down Expand Up @@ -262,6 +261,15 @@ def _get_posts_by_ids(post_ids: List[PostId]) -> Generator[Post, None, None]:

mock.get_posts_by_ids.side_effect = _get_posts_by_ids

def _get_posts_by_note_ids(note_ids: List[NoteId]) -> Generator[Post, None, None]:
for post in post_samples:
for note in note_samples:
if note.note_id in note_ids and post.post_id == note.post_id:
yield post
break

mock.get_posts_by_note_ids.side_effect = _get_posts_by_note_ids

def _get_posts_by_created_at_range(start: TwitterTimestamp, end: TwitterTimestamp) -> Generator[Post, None, None]:
for post in post_samples:
if start <= post.created_at < end:
Expand Down
17 changes: 17 additions & 0 deletions api/tests/routers/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Pos
}


def test_posts_get_has_note_id_filter(client: TestClient, post_samples: List[Post], note_samples: List[Note]) -> None:
response = client.get(f"/api/v1/data/posts/?noteId={','.join([n.note_id for n in note_samples])}")
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(post_samples[0].model_dump_json())]}


def test_posts_get_has_created_at_filter_start_and_end(client: TestClient, post_samples: List[Post]) -> None:
response = client.get("/api/v1/data/posts/?createdAtStart=2006-7-25 00:00:00&createdAtEnd=2006-7-30 23:59:59")
assert response.status_code == 200
Expand Down Expand Up @@ -124,3 +131,13 @@ def test_notes_get_has_created_at_filter_to(client: TestClient, note_samples: Li
assert response.status_code == 200
res_json = response.json()
assert res_json == {"data": [json.loads(note_samples[i].model_dump_json()) for i in (0, 1, 2, 3)]}


def test_notes_get_has_topic_id_filter(client: TestClient, note_samples: List[Note]) -> None:
correct_notes = [note for note in note_samples if note_samples[0].topics[0] in note.topics]
response = client.get(f"/api/v1/data/notes/?topicIds={note_samples[0].topics[0].topic_id.serialize()}")
assert response.status_code == 200
res_json = response.json()
assert res_json == {
"data": [json.loads(correct_notes[i].model_dump_json()) for i in range(correct_notes.__len__())]
}
18 changes: 12 additions & 6 deletions common/birdxplorer_common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ class NotesValidationDifficulty(str, Enum):
empty = ""


class TweetId(UpToNineteenDigitsDecimalString): ...
class PostId(UpToNineteenDigitsDecimalString): ...


class NoteData(BaseModel):
Expand All @@ -576,7 +576,7 @@ class NoteData(BaseModel):
note_id: NoteId
note_author_participant_id: ParticipantId
created_at_millis: TwitterTimestamp
tweet_id: TweetId
tweet_id: PostId
believable: NotesBelievable
misleading_other: BinaryBool
misleading_factual_error: BinaryBool
Expand Down Expand Up @@ -621,6 +621,15 @@ class LanguageIdentifier(str, Enum):
DA = "da"
RU = "ru"
PL = "pl"
OTHER = "other"

@classmethod
def normalize(cls, value: str, default: str = "other") -> str:
try:
cls(value)
return value
except ValueError:
return default


class TopicLabelString(NonEmptyTrimmedString): ...
Expand All @@ -642,7 +651,7 @@ class SummaryString(NonEmptyTrimmedString): ...

class Note(BaseModel):
note_id: NoteId
post_id: TweetId
post_id: PostId
language: LanguageIdentifier
topics: List[Topic]
summary: SummaryString
Expand All @@ -663,9 +672,6 @@ class XUser(BaseModel):
following_count: NonNegativeInt


class PostId(UpToNineteenDigitsDecimalString): ...


MediaDetails: TypeAlias = List[HttpUrl] | None


Expand Down
2 changes: 1 addition & 1 deletion common/birdxplorer_common/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class PostgresStorageSettings(BaseSettings):
port: int = 5432
database: str = "postgres"

@computed_field # type: ignore[misc]
@computed_field # type: ignore[prop-decorator]
@property
def sqlalchemy_database_url(self) -> str:
return PostgresDsn(
Expand Down
29 changes: 19 additions & 10 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from .models import (
TopicId,
TopicLabel,
TweetId,
TwitterTimestamp,
UserEnrollment,
UserId,
Expand All @@ -39,7 +38,7 @@ class Base(DeclarativeBase):
TopicLabel: JSON,
NoteId: String,
ParticipantId: String,
TweetId: String,
PostId: String,
LanguageIdentifier: String,
TwitterTimestamp: DECIMAL,
SummaryString: String,
Expand All @@ -65,7 +64,7 @@ class NoteRecord(Base):
__tablename__ = "notes"

note_id: Mapped[NoteId] = mapped_column(primary_key=True)
post_id: Mapped[TweetId] = mapped_column(nullable=False)
post_id: Mapped[PostId] = mapped_column(nullable=False)
topics: Mapped[List[NoteTopicAssociation]] = relationship()
language: Mapped[LanguageIdentifier] = mapped_column(nullable=False)
summary: Mapped[SummaryString] = mapped_column(nullable=False)
Expand All @@ -92,7 +91,7 @@ class XUserRecord(Base):
class PostRecord(Base):
__tablename__ = "posts"

post_id: Mapped[TweetId] = mapped_column(primary_key=True)
post_id: Mapped[PostId] = mapped_column(primary_key=True)
user_id: Mapped[UserId] = mapped_column(ForeignKey("x_users.user_id"), nullable=False)
user: Mapped[XUserRecord] = relationship()
text: Mapped[SummaryString] = mapped_column(nullable=False)
Expand All @@ -109,7 +108,7 @@ class RowNoteRecord(Base):
note_id: Mapped[NoteId] = mapped_column(primary_key=True)
note_author_participant_id: Mapped[ParticipantId] = mapped_column(nullable=False)
created_at_millis: Mapped[TwitterTimestamp] = mapped_column(nullable=False)
tweet_id: Mapped[TweetId] = mapped_column(nullable=False)
tweet_id: Mapped[PostId] = mapped_column(nullable=False)
believable: Mapped[BinaryBool] = mapped_column(nullable=False)
misleading_other: Mapped[BinaryBool] = mapped_column(nullable=False)
misleading_factual_error: Mapped[BinaryBool] = mapped_column(nullable=False)
Expand All @@ -129,14 +128,14 @@ class RowNoteRecord(Base):
harmful: Mapped[NotesHarmful] = mapped_column(nullable=False)
validation_difficulty: Mapped[SummaryString] = mapped_column(nullable=False)
summary: Mapped[SummaryString] = mapped_column(nullable=False)
row_post_id: Mapped[TweetId] = mapped_column(ForeignKey("row_posts.post_id"), nullable=True)
row_post_id: Mapped[PostId] = mapped_column(ForeignKey("row_posts.post_id"), nullable=True)
row_post: Mapped["RowPostRecord"] = relationship("RowPostRecord", back_populates="row_notes")


class RowPostRecord(Base):
__tablename__ = "row_posts"

post_id: Mapped[TweetId] = mapped_column(primary_key=True)
post_id: Mapped[PostId] = mapped_column(primary_key=True)
author_id: Mapped[UserId] = mapped_column(ForeignKey("row_users.user_id"), nullable=False)
text: Mapped[SummaryString] = mapped_column(nullable=False)
media_type: Mapped[String] = mapped_column(nullable=True)
Expand Down Expand Up @@ -224,7 +223,7 @@ def get_notes(
created_at_from: Union[None, TwitterTimestamp] = None,
created_at_to: Union[None, TwitterTimestamp] = None,
topic_ids: Union[List[TopicId], None] = None,
post_ids: Union[List[TweetId], None] = None,
post_ids: Union[List[PostId], None] = None,
language: Union[LanguageIdentifier, None] = None,
) -> Generator[NoteModel, None, None]:
with Session(self.engine) as sess:
Expand All @@ -241,7 +240,7 @@ def get_notes(
subq = (
select(NoteTopicAssociation.note_id)
.group_by(NoteTopicAssociation.note_id)
.having(func.array_agg(NoteTopicAssociation.topic_id) == topic_ids)
.having(func.bool_or(NoteTopicAssociation.topic_id.in_(topic_ids)))
.subquery()
)
query = query.join(subq, NoteRecord.note_id == subq.c.note_id)
Expand All @@ -264,7 +263,7 @@ def get_notes(
)
for topic in note_record.topics
],
language=note_record.language,
language=LanguageIdentifier.normalize(note_record.language),
summary=note_record.summary,
created_at=note_record.created_at,
)
Expand Down Expand Up @@ -296,6 +295,16 @@ def get_posts_by_created_at_end(self, end: TwitterTimestamp) -> Generator[PostMo
for post_record in sess.query(PostRecord).filter(PostRecord.created_at < end).all():
yield self._post_record_to_model(post_record)

def get_posts_by_note_ids(self, note_ids: List[NoteId]) -> Generator[PostModel, None, None]:
query = (
select(PostRecord)
.join(NoteRecord, NoteRecord.post_id == PostRecord.post_id)
.where(NoteRecord.note_id.in_(note_ids))
)
with Session(self.engine) as sess:
for post_record in sess.execute(query).scalars().all():
yield self._post_record_to_model(post_record)


def gen_storage(settings: GlobalSettings) -> Storage:
engine = create_engine(settings.storage_settings.sqlalchemy_database_url)
Expand Down
Loading

0 comments on commit 3135420

Please sign in to comment.