Merge pull request #121 from codeforjapan/infra/etl-docker
Infra/etl docker
yu23ki14 authored Oct 11, 2024
2 parents c1f6f8f + afc8027 commit b147a78
Showing 5 changed files with 81 additions and 62 deletions.
2 changes: 1 addition & 1 deletion common/birdxplorer_common/storage.py
@@ -254,7 +254,7 @@ class RowUserRecord(Base):
followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
tweet_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
verified: Mapped[BinaryBool] = mapped_column(nullable=False)
verified: Mapped[bool] = mapped_column(nullable=False)
verified_type: Mapped[String] = mapped_column(nullable=False)
location: Mapped[String] = mapped_column(nullable=False)
url: Mapped[String] = mapped_column(nullable=False)
56 changes: 28 additions & 28 deletions etl/src/birdxplorer_etl/extract.py
@@ -16,16 +16,16 @@
import settings


def extract_data(db: Session):
def extract_data(sqlite: Session, postgresql: Session):
logging.info("Downloading community notes data")

# get columns of post table
columns = db.query(RowUserRecord).statement.columns.keys()
columns = sqlite.query(RowUserRecord).statement.columns.keys()
logging.info(columns)

# Fetch the note data and save it to SQLite
date = datetime.now()
latest_note = db.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first()
latest_note = sqlite.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first()

while True:
if (
@@ -46,20 +46,20 @@ def extract_data(db: Session):
res = requests.get(note_url)

if res.status_code == 200:
# Store res.content in the Note table in db
# Store res.content in the Note table in sqlite
tsv_data = res.content.decode("utf-8").splitlines()
reader = csv.DictReader(tsv_data, delimiter="\t")
reader.fieldnames = [stringcase.snakecase(field) for field in reader.fieldnames]

rows_to_add = []
for index, row in enumerate(reader):
if db.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
if sqlite.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
continue
rows_to_add.append(RowNoteRecord(**row))
if index % 1000 == 0:
db.bulk_save_objects(rows_to_add)
sqlite.bulk_save_objects(rows_to_add)
rows_to_add = []
db.bulk_save_objects(rows_to_add)
sqlite.bulk_save_objects(rows_to_add)

status_url = f"https://ton.twimg.com/birdwatch-public-data/{dateString}/noteStatusHistory/noteStatusHistory-00000.tsv"
if settings.USE_DUMMY_DATA:
@@ -78,34 +78,36 @@ def extract_data(db: Session):
for key, value in list(row.items()):
if value == "":
row[key] = None
status = db.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
status = (
sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
)
if status is None or status.created_at_millis > int(datetime.now().timestamp() * 1000):
db.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
rows_to_add.append(RowNoteStatusRecord(**row))
if index % 1000 == 0:
db.bulk_save_objects(rows_to_add)
sqlite.bulk_save_objects(rows_to_add)
rows_to_add = []
db.bulk_save_objects(rows_to_add)
sqlite.bulk_save_objects(rows_to_add)

break

date = date - timedelta(days=1)

db.commit()
sqlite.commit()

# Fetch the tweet data linked to each note
postExtract_targetNotes = (
db.query(RowNoteRecord)
sqlite.query(RowNoteRecord)
.filter(RowNoteRecord.tweet_id != None)
.filter(RowNoteRecord.created_at_millis >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND)
.filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND)
.all()
)
logging.info(len(postExtract_targetNotes))
logging.info(f"Target notes: {len(postExtract_targetNotes)}")
for note in postExtract_targetNotes:
tweet_id = note.tweet_id

is_tweetExist = db.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first()
is_tweetExist = postgresql.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first()
if is_tweetExist is not None:
logging.info(f"tweet_id {tweet_id} is already exist")
note.row_post_id = tweet_id
@@ -120,15 +122,17 @@ def extract_data(db: Session):
created_at = datetime.strptime(post["data"]["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ")
created_at_millis = int(created_at.timestamp() * 1000)

is_userExist = db.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first()
is_userExist = (
postgresql.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first()
)
logging.info(is_userExist)
if is_userExist is None:
user_data = (
post["includes"]["users"][0]
if "includes" in post and "users" in post["includes"] and len(post["includes"]["users"]) > 0
else {}
)
db_user = RowUserRecord(
row_user = RowUserRecord(
user_id=post["data"]["author_id"],
name=user_data.get("name"),
user_name=user_data.get("username"),
@@ -142,17 +146,15 @@ def extract_data(db: Session):
location=user_data.get("location", ""),
url=user_data.get("url", ""),
)
db.add(db_user)
postgresql.add(row_user)

media_data = (
post["includes"]["media"]
if "includes" in post and "media" in post["includes"] and len(post["includes"]["media"]) > 0
else []
)

print(media_data)

db_post = RowPostRecord(
row_post = RowPostRecord(
post_id=post["data"]["id"],
author_id=post["data"]["author_id"],
text=post["data"]["text"],
@@ -165,7 +167,8 @@ def extract_data(db: Session):
reply_count=post["data"]["public_metrics"]["reply_count"],
lang=post["data"]["lang"],
)
db.add(db_post)
postgresql.add(row_post)
postgresql.commit()

media_recs = [
RowPostMediaRecord(
@@ -178,7 +181,7 @@ def extract_data(db: Session):
)
for m in media_data
]
db.add_all(media_recs)
postgresql.add_all(media_recs)

if "entities" in post["data"] and "urls" in post["data"]["entities"]:
for url in post["data"]["entities"]["urls"]:
@@ -189,12 +192,9 @@ def extract_data(db: Session):
expanded_url=url["expanded_url"] if url["expanded_url"] else None,
unwound_url=url["unwound_url"] if url["unwound_url"] else None,
)
db.add(post_url)
postgresql.add(post_url)
note.row_post_id = tweet_id
db.commit()
postgresql.commit()
continue

# select note from db, get relation tweet and user data
note = db.query(RowNoteRecord).filter(RowNoteRecord.tweet_id == "1797617478950170784").first()

return
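
Note on the new signature: after this change, community-notes data downloaded from the public Birdwatch TSV dumps is staged in SQLite, while post, user, media, and embedded-URL records fetched from the Twitter/X API are written to PostgreSQL. A minimal usage sketch of the two-session signature follows; the session variable names are illustrative, and main.py below does the same wiring.

from lib.sqlite.init import init_sqlite, init_postgresql
from extract import extract_data

# Notes and note statuses are staged in the local SQLite database;
# posts, users, media and embedded URLs are written to PostgreSQL.
sqlite_session = init_sqlite()
postgresql_session = init_postgresql()
extract_data(sqlite_session, postgresql_session)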
24 changes: 20 additions & 4 deletions etl/src/birdxplorer_etl/lib/sqlite/init.py
@@ -14,7 +14,7 @@
)


def init_db():
def init_sqlite():
# ToDo: the db file needs to be stored externally, e.g. on S3.
db_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "data", "note.db"))
logging.info(f"Initializing database at {db_path}")
@@ -25,12 +25,28 @@ def init_db():
if not inspect(engine).has_table("row_notes"):
logging.info("Creating table note")
RowNoteRecord.metadata.create_all(engine)
if not inspect(engine).has_table("row_posts"):
logging.info("Creating table post")
RowPostRecord.metadata.create_all(engine)
if not inspect(engine).has_table("row_note_status"):
logging.info("Creating table note_status")
RowNoteStatusRecord.metadata.create_all(engine)

Session = sessionmaker(bind=engine)

return Session()


def init_postgresql():
db_host = os.getenv("DB_HOST", "localhost")
db_port = os.getenv("DB_PORT", "5432")
db_user = os.getenv("DB_USER", "postgres")
db_pass = os.getenv("DB_PASS", "birdxplorer")
db_name = os.getenv("DB_NAME", "postgres")

logging.info(f"Initializing database at {db_host}:{db_port}/{db_name}")
engine = create_engine(f"postgresql://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}")

if not inspect(engine).has_table("row_posts"):
logging.info("Creating table post")
RowPostRecord.metadata.create_all(engine)
if not inspect(engine).has_table("row_users"):
logging.info("Creating table user")
RowUserRecord.metadata.create_all(engine)
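
The new init_postgresql() builds its connection URL from the DB_HOST, DB_PORT, DB_USER, DB_PASS, and DB_NAME environment variables, with the defaults shown above, so a Dockerized setup only needs to override those variables. A hypothetical sketch follows; the host name "birdxplorer-db" is an assumed compose service name, not something defined in this commit.

import os

# Assumed docker-compose service name; the other values match the defaults above.
os.environ["DB_HOST"] = "birdxplorer-db"
os.environ["DB_PORT"] = "5432"
os.environ["DB_USER"] = "postgres"
os.environ["DB_PASS"] = "birdxplorer"
os.environ["DB_NAME"] = "postgres"

from lib.sqlite.init import init_postgresql

# Resulting URL: postgresql://postgres:birdxplorer@birdxplorer-db:5432/postgres
postgresql = init_postgresql()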
12 changes: 8 additions & 4 deletions etl/src/birdxplorer_etl/main.py
@@ -1,10 +1,14 @@
from lib.sqlite.init import init_db
from lib.sqlite.init import init_sqlite, init_postgresql
from extract import extract_data
from load import load_data
from transform import transform_data
import logging

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
db = init_db()
extract_data(db)
transform_data(db)
sqlite = init_sqlite()
postgresql = init_postgresql()
extract_data(sqlite, postgresql)
transform_data(sqlite, postgresql)
load_data()
49 changes: 24 additions & 25 deletions etl/src/birdxplorer_etl/transform.py
@@ -6,8 +6,7 @@
from pathlib import Path
from typing import Generator

from prefect import get_run_logger
from sqlalchemy import Integer, and_, func, select
from sqlalchemy import Integer, Numeric, and_, func, select
from sqlalchemy.orm import Session

from birdxplorer_common.storage import (
@@ -25,7 +24,7 @@
)


def transform_data(db: Session):
def transform_data(sqlite: Session, postgresql: Session):

logging.info("Transforming data")

@@ -44,7 +43,7 @@ def transform_data(db: Session):
ai_service = get_ai_service()

num_of_notes = (
db.query(func.count(RowNoteRecord.note_id))
sqlite.query(func.count(RowNoteRecord.note_id))
.filter(
and_(
RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
@@ -58,7 +57,7 @@ def transform_data(db: Session):

logging.info(f"Transforming note data: {num_of_notes}")
while offset < num_of_notes:
notes = db.execute(
notes = sqlite.execute(
select(
RowNoteRecord.note_id,
RowNoteRecord.row_post_id,
@@ -90,15 +89,15 @@ def transform_data(db: Session):
offset = 0
limit = 1000

num_of_posts = db.query(func.count(RowPostRecord.post_id)).scalar()
num_of_posts = postgresql.query(func.count(RowPostRecord.post_id)).scalar()

while offset < num_of_posts:
posts = db.execute(
posts = postgresql.execute(
select(
RowPostRecord.post_id,
RowPostRecord.author_id.label("user_id"),
RowPostRecord.text,
func.cast(RowPostRecord.created_at, Integer).label("created_at"),
func.cast(RowPostRecord.created_at, Numeric).label("created_at"),
func.cast(RowPostRecord.like_count, Integer).label("like_count"),
func.cast(RowPostRecord.repost_count, Integer).label("repost_count"),
func.cast(RowPostRecord.impression_count, Integer).label("impression_count"),
@@ -131,10 +130,10 @@ def transform_data(db: Session):
offset = 0
limit = 1000

num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar()
num_of_users = postgresql.query(func.count(RowUserRecord.user_id)).scalar()

while offset < num_of_users:
users = db.execute(
users = postgresql.execute(
select(
RowUserRecord.user_id,
RowUserRecord.user_name.label("name"),
@@ -153,8 +152,8 @@ def transform_data(db: Session):
offset += limit

# Transform row post embed link
write_media_csv(db)
generate_post_link(db)
write_media_csv(postgresql)
generate_post_link(postgresql)

# Transform row post embed url data and generate post_embed_url.csv
csv_seed_file_path = "./seed/topic_seed.csv"
@@ -180,12 +179,12 @@ def transform_data(db: Session):
for record in records:
writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})

generate_note_topic(db)
generate_note_topic(sqlite)

return


def write_media_csv(db: Session) -> None:
def write_media_csv(postgresql: Session) -> None:
media_csv_path = Path("./data/transformed/media.csv")
post_media_association_csv_path = Path("./data/transformed/post_media_association.csv")

@@ -205,7 +204,7 @@ def write_media_csv(db: Session) -> None:
assoc_writer = csv.DictWriter(assoc_csv, fieldnames=assoc_fields)
assoc_writer.writeheader()

for m in _iterate_media(db):
for m in _iterate_media(postgresql):
media_writer.writerow(
{
"media_key": m.media_key,
@@ -219,17 +218,17 @@ def write_media_csv(db: Session) -> None:
assoc_writer.writerow({"post_id": m.post_id, "media_key": m.media_key})


def _iterate_media(db: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]:
def _iterate_media(postgresql: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]:
offset = 0
total_media: int = db.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0
total_media: int = postgresql.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0

while offset < total_media:
yield from db.query(RowPostMediaRecord).limit(limit).offset(offset)
yield from postgresql.query(RowPostMediaRecord).limit(limit).offset(offset)

offset += limit


def generate_post_link(db: Session):
def generate_post_link(postgresql: Session):
link_csv_file_path = "./data/transformed/post_link.csv"
association_csv_file_path = "./data/transformed/post_link_association.csv"

@@ -249,11 +248,11 @@ def generate_post_link(db: Session):

offset = 0
limit = 1000
num_of_links = db.query(func.count(RowPostEmbedURLRecord.post_id)).scalar()
num_of_links = postgresql.query(func.count(RowPostEmbedURLRecord.post_id)).scalar()

records = []
while offset < num_of_links:
links = db.query(RowPostEmbedURLRecord).limit(limit).offset(offset)
links = postgresql.query(RowPostEmbedURLRecord).limit(limit).offset(offset)

for link in links:
random.seed(link.unwound_url)
@@ -273,7 +272,7 @@ def generate_post_link(db: Session):
offset += limit


def generate_note_topic(db: Session):
def generate_note_topic(sqlite: Session):
output_csv_file_path = "./data/transformed/note_topic_association.csv"
ai_service = get_ai_service()

@@ -289,10 +288,10 @@ def generate_note_topic(db: Session):
offset = 0
limit = 1000

num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar()
num_of_notes = sqlite.query(func.count(RowNoteRecord.row_post_id)).scalar()

while offset < num_of_users:
topicEstimationTargetNotes = db.execute(
while offset < num_of_notes:
topicEstimationTargetNotes = sqlite.execute(
select(RowNoteRecord.note_id, RowNoteRecord.row_post_id, RowNoteRecord.summary)
.filter(
and_(
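
Most of these loops share the same offset/limit pagination idiom: count the rows once, then page through the table 1000 at a time. The edit to generate_note_topic above fixes its loop bound, which previously counted users while iterating notes. A generic sketch of the pattern, assuming an arbitrary SQLAlchemy model and primary-key column (iterate_in_batches is not a function in this repository):

from sqlalchemy import func
from sqlalchemy.orm import Session

def iterate_in_batches(session: Session, model, pk_column, limit: int = 1000):
    # Count against the same table that is iterated, then page through it.
    offset = 0
    total = session.query(func.count(pk_column)).scalar() or 0
    while offset < total:
        yield from session.query(model).limit(limit).offset(offset)
        offset += limit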
