Skip to content

Commit

Permalink
Merge pull request #118 from codeforjapan/issue-108-post-media
Browse files Browse the repository at this point in the history
X APIからPostを取得する際にMedia情報を保存する
  • Loading branch information
yu23ki14 authored Oct 11, 2024
2 parents 22da0ca + 5229c67 commit 970e6db
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 13 deletions.
12 changes: 12 additions & 0 deletions common/birdxplorer_common/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ class RowPostRecord(Base):
user: Mapped["RowUserRecord"] = relationship("RowUserRecord", back_populates="row_post")


class RowPostMediaRecord(Base):
    """Raw media attachment fetched from the X (Twitter) API, one row per media item."""

    __tablename__ = "row_post_media"

    # Media identifier assigned by the X API; primary key for this table.
    media_key: Mapped[String] = mapped_column(primary_key=True)

    # Direct URL of the media asset.
    url: Mapped[String] = mapped_column(nullable=False)
    # Media kind — NOTE(review): presumably the X API values (photo/video/animated_gif); confirm against MediaType.
    type: Mapped[MediaType] = mapped_column(nullable=False)
    # Pixel dimensions as reported by the API.
    width: Mapped[NonNegativeInt] = mapped_column(nullable=False)
    height: Mapped[NonNegativeInt] = mapped_column(nullable=False)

    # Owning post; links back to row_posts.post_id.
    post_id: Mapped[PostId] = mapped_column(ForeignKey("row_posts.post_id"), nullable=False)

class RowPostEmbedURLRecord(Base):
__tablename__ = "row_post_embed_urls"

Expand Down
23 changes: 19 additions & 4 deletions etl/src/birdxplorer_etl/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from lib.x.postlookup import lookup
from birdxplorer_common.storage import (
RowNoteRecord,
RowPostMediaRecord,
RowPostRecord,
RowUserRecord,
RowNoteStatusRecord,
Expand Down Expand Up @@ -145,16 +146,17 @@ def extract_data(db: Session):
db.add(db_user)

media_data = (
post["includes"]["media"][0]
post["includes"]["media"]
if "includes" in post and "media" in post["includes"] and len(post["includes"]["media"]) > 0
else {}
else [{}]
)

db_post = RowPostRecord(
post_id=post["data"]["id"],
author_id=post["data"]["author_id"],
text=post["data"]["text"],
media_type=media_data.get("type", ""),
media_url=media_data.get("url", ""),
media_type=media_data[0].get("type", ""),
media_url=media_data[0].get("url", ""),
created_at=created_at_millis,
like_count=post["data"]["public_metrics"]["like_count"],
repost_count=post["data"]["public_metrics"]["retweet_count"],
Expand All @@ -166,6 +168,19 @@ def extract_data(db: Session):
)
db.add(db_post)

media_recs = [
RowPostMediaRecord(
media_key=m["media_key"],
type=m["type"],
url=m["url"],
width=m["width"],
height=m["height"],
post_id=post["data"]["id"],
)
for m in media_data
]
db.add_all(media_recs)

if "entities" in post["data"] and "urls" in post["data"]["entities"]:
for url in post["data"]["entities"]["urls"]:
if "unwound_url" in url:
Expand Down
68 changes: 59 additions & 9 deletions etl/src/birdxplorer_etl/transform.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,27 @@
from sqlalchemy import select, func, and_, Integer
import csv
import os
import random
import uuid
from pathlib import Path
from typing import Generator

from prefect import get_run_logger
from sqlalchemy import Integer, and_, func, select
from sqlalchemy.orm import Session

from birdxplorer_common.storage import (
RowNoteRecord,
RowPostRecord,
RowUserRecord,
RowNoteStatusRecord,
RowPostEmbedURLRecord,
RowPostMediaRecord,
RowPostRecord,
RowUserRecord,
)
from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service
from birdxplorer_etl.settings import (
TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND,
)
import csv
import os
from prefect import get_run_logger
import uuid
import random


def transform_data(db: Session):
Expand Down Expand Up @@ -147,6 +152,7 @@ def transform_data(db: Session):
offset += limit

# Transform row post embed link
write_media_csv(db)
generate_post_link(db)

# Transform row post embed url data and generate post_embed_url.csv
Expand Down Expand Up @@ -178,6 +184,50 @@ def transform_data(db: Session):
return


def write_media_csv(db: Session) -> None:
    """Dump all RowPostMediaRecord rows to CSV under ./data/transformed.

    Produces two files, overwriting any previous run's output:
      - media.csv: one row per media item
        (media_key, type, url, width, height, post_id)
      - post_media_association.csv: (post_id, media_key) pairs linking
        posts to their media.

    Args:
        db: open SQLAlchemy session used to page through the media rows.
    """
    out_dir = Path("./data/transformed")
    # Make sure the target directory exists so open() cannot fail on a fresh checkout.
    out_dir.mkdir(parents=True, exist_ok=True)
    media_csv_path = out_dir / "media.csv"
    post_media_association_csv_path = out_dir / "post_media_association.csv"

    # "w" mode truncates any previous output, replacing the earlier
    # exists()-check + unlink() + append-mode dance with a single open.
    with (
        media_csv_path.open("w", newline="", encoding="utf-8") as media_csv,
        post_media_association_csv_path.open("w", newline="", encoding="utf-8") as assoc_csv,
    ):
        media_fields = ["media_key", "type", "url", "width", "height", "post_id"]
        media_writer = csv.DictWriter(media_csv, fieldnames=media_fields)
        media_writer.writeheader()
        assoc_fields = ["post_id", "media_key"]
        assoc_writer = csv.DictWriter(assoc_csv, fieldnames=assoc_fields)
        assoc_writer.writeheader()

        for m in _iterate_media(db):
            media_writer.writerow(
                {
                    "media_key": m.media_key,
                    "type": m.type,
                    "url": m.url,
                    "width": m.width,
                    "height": m.height,
                    "post_id": m.post_id,
                }
            )
            assoc_writer.writerow({"post_id": m.post_id, "media_key": m.media_key})


def _iterate_media(db: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]:
    """Yield every RowPostMediaRecord, fetched in pages of ``limit`` rows.

    Ordering by the primary key makes the LIMIT/OFFSET pagination
    deterministic: without an ORDER BY, the database may return rows in a
    different order on each query, so pages can skip or repeat rows.

    Args:
        db: open SQLAlchemy session.
        limit: page size for each query.
    """
    offset = 0
    # Snapshot the row count once; new rows inserted mid-iteration are not picked up.
    total_media: int = db.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0

    while offset < total_media:
        yield from (
            db.query(RowPostMediaRecord)
            .order_by(RowPostMediaRecord.media_key)
            .limit(limit)
            .offset(offset)
        )
        offset += limit


def generate_post_link(db: Session):
link_csv_file_path = "./data/transformed/post_link.csv"
association_csv_file_path = "./data/transformed/post_link_association.csv"
Expand Down

0 comments on commit 970e6db

Please sign in to comment.