diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py
index a8da554..6c34ed0 100644
--- a/common/birdxplorer_common/storage.py
+++ b/common/birdxplorer_common/storage.py
@@ -254,7 +254,7 @@ class RowUserRecord(Base):
     followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
     following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
     tweet_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
-    verified: Mapped[BinaryBool] = mapped_column(nullable=False)
+    verified: Mapped[bool] = mapped_column(nullable=False)
     verified_type: Mapped[String] = mapped_column(nullable=False)
     location: Mapped[String] = mapped_column(nullable=False)
     url: Mapped[String] = mapped_column(nullable=False)
diff --git a/etl/src/birdxplorer_etl/extract.py b/etl/src/birdxplorer_etl/extract.py
index 2dd0976..5a276a4 100644
--- a/etl/src/birdxplorer_etl/extract.py
+++ b/etl/src/birdxplorer_etl/extract.py
@@ -16,16 +16,16 @@ import settings
 
 
-def extract_data(db: Session):
+def extract_data(sqlite: Session, postgresql: Session):
     logging.info("Downloading community notes data")
 
     # get columns of post table
-    columns = db.query(RowUserRecord).statement.columns.keys()
+    columns = sqlite.query(RowUserRecord).statement.columns.keys()
     logging.info(columns)
 
     # Fetch note data and store it in SQLite
     date = datetime.now()
-    latest_note = db.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first()
+    latest_note = sqlite.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first()
 
     while True:
         if (
@@ -46,20 +46,20 @@ def extract_data(db: Session):
         res = requests.get(note_url)
 
         if res.status_code == 200:
-            # Save res.content to the Note table in db
+            # Save res.content to the Note table in sqlite
             tsv_data = res.content.decode("utf-8").splitlines()
             reader = csv.DictReader(tsv_data, delimiter="\t")
             reader.fieldnames = [stringcase.snakecase(field) for field in reader.fieldnames]
 
             rows_to_add = []
             for index, row in enumerate(reader):
-                if db.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
+                if sqlite.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
                     continue
                 rows_to_add.append(RowNoteRecord(**row))
                 if index % 1000 == 0:
-                    db.bulk_save_objects(rows_to_add)
+                    sqlite.bulk_save_objects(rows_to_add)
                     rows_to_add = []
-            db.bulk_save_objects(rows_to_add)
+            sqlite.bulk_save_objects(rows_to_add)
 
         status_url = f"https://ton.twimg.com/birdwatch-public-data/{dateString}/noteStatusHistory/noteStatusHistory-00000.tsv"
         if settings.USE_DUMMY_DATA:
@@ -78,34 +78,36 @@ def extract_data(db: Session):
                 for key, value in list(row.items()):
                     if value == "":
                         row[key] = None
-                status = db.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
+                status = (
+                    sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
+                )
                 if status is None or status.created_at_millis > int(datetime.now().timestamp() * 1000):
-                    db.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
+                    sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
                     rows_to_add.append(RowNoteStatusRecord(**row))
                 if index % 1000 == 0:
-                    db.bulk_save_objects(rows_to_add)
+                    sqlite.bulk_save_objects(rows_to_add)
                     rows_to_add = []
-            db.bulk_save_objects(rows_to_add)
+            sqlite.bulk_save_objects(rows_to_add)
             break
 
         date = date - timedelta(days=1)
 
-    db.commit()
+    sqlite.commit()
 
     # Fetch the tweet data tied to each note
     postExtract_targetNotes = (
-        db.query(RowNoteRecord)
+        sqlite.query(RowNoteRecord)
         .filter(RowNoteRecord.tweet_id != None)
         .filter(RowNoteRecord.created_at_millis >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND)
         .filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND)
         .all()
     )
-    logging.info(len(postExtract_targetNotes))
+    logging.info(f"Target notes: {len(postExtract_targetNotes)}")
 
     for note in postExtract_targetNotes:
         tweet_id = note.tweet_id
 
-        is_tweetExist = db.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first()
+        is_tweetExist = postgresql.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first()
         if is_tweetExist is not None:
             logging.info(f"tweet_id {tweet_id} already exists")
             note.row_post_id = tweet_id
@@ -120,7 +122,9 @@ def extract_data(db: Session):
         created_at = datetime.strptime(post["data"]["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ")
         created_at_millis = int(created_at.timestamp() * 1000)
 
-        is_userExist = db.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first()
+        is_userExist = (
+            postgresql.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first()
+        )
         logging.info(is_userExist)
         if is_userExist is None:
             user_data = (
@@ -128,7 +132,7 @@ def extract_data(db: Session):
                 if "includes" in post and "users" in post["includes"] and len(post["includes"]["users"]) > 0
                 else {}
             )
-            db_user = RowUserRecord(
+            row_user = RowUserRecord(
                 user_id=post["data"]["author_id"],
                 name=user_data.get("name"),
                 user_name=user_data.get("username"),
@@ -142,7 +146,7 @@ def extract_data(db: Session):
                 location=user_data.get("location", ""),
                 url=user_data.get("url", ""),
             )
-            db.add(db_user)
+            postgresql.add(row_user)
 
         media_data = (
             post["includes"]["media"]
@@ -150,9 +154,7 @@ def extract_data(db: Session):
             else []
         )
 
-        print(media_data)
-
-        db_post = RowPostRecord(
+        row_post = RowPostRecord(
             post_id=post["data"]["id"],
             author_id=post["data"]["author_id"],
             text=post["data"]["text"],
@@ -165,7 +167,8 @@ def extract_data(db: Session):
             reply_count=post["data"]["public_metrics"]["reply_count"],
             lang=post["data"]["lang"],
         )
-        db.add(db_post)
+        postgresql.add(row_post)
+        postgresql.commit()
 
         media_recs = [
             RowPostMediaRecord(
@@ -178,7 +181,7 @@ def extract_data(db: Session):
             )
             for m in media_data
         ]
-        db.add_all(media_recs)
+        postgresql.add_all(media_recs)
 
         if "entities" in post["data"] and "urls" in post["data"]["entities"]:
             for url in post["data"]["entities"]["urls"]:
@@ -189,12 +192,9 @@ def extract_data(db: Session):
                     expanded_url=url["expanded_url"] if url["expanded_url"] else None,
                     unwound_url=url["unwound_url"] if url["unwound_url"] else None,
                 )
-                db.add(post_url)
+                postgresql.add(post_url)
 
         note.row_post_id = tweet_id
-        db.commit()
+        postgresql.commit()
         continue
 
-    # select note from db, get relation tweet and user data
-    note = db.query(RowNoteRecord).filter(RowNoteRecord.tweet_id == "1797617478950170784").first()
-
     return
diff --git a/etl/src/birdxplorer_etl/lib/sqlite/init.py b/etl/src/birdxplorer_etl/lib/sqlite/init.py
index 023da41..21c082a 100644
--- a/etl/src/birdxplorer_etl/lib/sqlite/init.py
+++ b/etl/src/birdxplorer_etl/lib/sqlite/init.py
@@ -14,7 +14,7 @@
 )
 
 
-def init_db():
+def init_sqlite():
     # ToDo: the db file needs to be stored externally, e.g. on S3.
     db_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "data", "note.db"))
     logging.info(f"Initializing database at {db_path}")
@@ -25,12 +25,28 @@ def init_db():
     if not inspect(engine).has_table("row_notes"):
         logging.info("Creating table note")
RowNoteRecord.metadata.create_all(engine) - if not inspect(engine).has_table("row_posts"): - logging.info("Creating table post") - RowPostRecord.metadata.create_all(engine) if not inspect(engine).has_table("row_note_status"): logging.info("Creating table note_status") RowNoteStatusRecord.metadata.create_all(engine) + + Session = sessionmaker(bind=engine) + + return Session() + + +def init_postgresql(): + db_host = os.getenv("DB_HOST", "localhost") + db_port = os.getenv("DB_PORT", "5432") + db_user = os.getenv("DB_USER", "postgres") + db_pass = os.getenv("DB_PASS", "birdxplorer") + db_name = os.getenv("DB_NAME", "postgres") + + logging.info(f"Initializing database at {db_host}:{db_port}/{db_name}") + engine = create_engine(f"postgresql://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}") + + if not inspect(engine).has_table("row_posts"): + logging.info("Creating table post") + RowPostRecord.metadata.create_all(engine) if not inspect(engine).has_table("row_users"): logging.info("Creating table user") RowUserRecord.metadata.create_all(engine) diff --git a/etl/src/birdxplorer_etl/main.py b/etl/src/birdxplorer_etl/main.py index 5c5c50f..b484e39 100644 --- a/etl/src/birdxplorer_etl/main.py +++ b/etl/src/birdxplorer_etl/main.py @@ -1,10 +1,14 @@ -from lib.sqlite.init import init_db +from lib.sqlite.init import init_sqlite, init_postgresql from extract import extract_data from load import load_data from transform import transform_data +import logging + +logging.basicConfig(level=logging.INFO) if __name__ == "__main__": - db = init_db() - extract_data(db) - transform_data(db) + sqlite = init_sqlite() + postgresql = init_postgresql() + extract_data(sqlite, postgresql) + transform_data(sqlite, postgresql) load_data() diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index 9fdbb6f..ca450d3 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -6,8 +6,7 @@ from pathlib import Path from typing import Generator -from prefect import get_run_logger -from sqlalchemy import Integer, and_, func, select +from sqlalchemy import Integer, Numeric, and_, func, select from sqlalchemy.orm import Session from birdxplorer_common.storage import ( @@ -25,7 +24,7 @@ ) -def transform_data(db: Session): +def transform_data(sqlite: Session, postgresql: Session): logging.info("Transforming data") @@ -44,7 +43,7 @@ def transform_data(db: Session): ai_service = get_ai_service() num_of_notes = ( - db.query(func.count(RowNoteRecord.note_id)) + sqlite.query(func.count(RowNoteRecord.note_id)) .filter( and_( RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, @@ -58,7 +57,7 @@ def transform_data(db: Session): logging.info(f"Transforming note data: {num_of_notes}") while offset < num_of_notes: - notes = db.execute( + notes = sqlite.execute( select( RowNoteRecord.note_id, RowNoteRecord.row_post_id, @@ -90,15 +89,15 @@ def transform_data(db: Session): offset = 0 limit = 1000 - num_of_posts = db.query(func.count(RowPostRecord.post_id)).scalar() + num_of_posts = postgresql.query(func.count(RowPostRecord.post_id)).scalar() while offset < num_of_posts: - posts = db.execute( + posts = postgresql.execute( select( RowPostRecord.post_id, RowPostRecord.author_id.label("user_id"), RowPostRecord.text, - func.cast(RowPostRecord.created_at, Integer).label("created_at"), + func.cast(RowPostRecord.created_at, Numeric).label("created_at"), func.cast(RowPostRecord.like_count, Integer).label("like_count"), 
func.cast(RowPostRecord.repost_count, Integer).label("repost_count"), func.cast(RowPostRecord.impression_count, Integer).label("impression_count"), @@ -131,10 +130,10 @@ def transform_data(db: Session): offset = 0 limit = 1000 - num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar() + num_of_users = postgresql.query(func.count(RowUserRecord.user_id)).scalar() while offset < num_of_users: - users = db.execute( + users = postgresql.execute( select( RowUserRecord.user_id, RowUserRecord.user_name.label("name"), @@ -153,8 +152,8 @@ def transform_data(db: Session): offset += limit # Transform row post embed link - write_media_csv(db) - generate_post_link(db) + write_media_csv(postgresql) + generate_post_link(postgresql) # Transform row post embed url data and generate post_embed_url.csv csv_seed_file_path = "./seed/topic_seed.csv" @@ -180,12 +179,12 @@ def transform_data(db: Session): for record in records: writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}}) - generate_note_topic(db) + generate_note_topic(sqlite) return -def write_media_csv(db: Session) -> None: +def write_media_csv(postgresql: Session) -> None: media_csv_path = Path("./data/transformed/media.csv") post_media_association_csv_path = Path("./data/transformed/post_media_association.csv") @@ -205,7 +204,7 @@ def write_media_csv(db: Session) -> None: assoc_writer = csv.DictWriter(assoc_csv, fieldnames=assoc_fields) assoc_writer.writeheader() - for m in _iterate_media(db): + for m in _iterate_media(postgresql): media_writer.writerow( { "media_key": m.media_key, @@ -219,17 +218,17 @@ def write_media_csv(db: Session) -> None: assoc_writer.writerow({"post_id": m.post_id, "media_key": m.media_key}) -def _iterate_media(db: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]: +def _iterate_media(postgresql: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]: offset = 0 - total_media: int = db.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0 + total_media: int = postgresql.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0 while offset < total_media: - yield from db.query(RowPostMediaRecord).limit(limit).offset(offset) + yield from postgresql.query(RowPostMediaRecord).limit(limit).offset(offset) offset += limit -def generate_post_link(db: Session): +def generate_post_link(postgresql: Session): link_csv_file_path = "./data/transformed/post_link.csv" association_csv_file_path = "./data/transformed/post_link_association.csv" @@ -249,11 +248,11 @@ def generate_post_link(db: Session): offset = 0 limit = 1000 - num_of_links = db.query(func.count(RowPostEmbedURLRecord.post_id)).scalar() + num_of_links = postgresql.query(func.count(RowPostEmbedURLRecord.post_id)).scalar() records = [] while offset < num_of_links: - links = db.query(RowPostEmbedURLRecord).limit(limit).offset(offset) + links = postgresql.query(RowPostEmbedURLRecord).limit(limit).offset(offset) for link in links: random.seed(link.unwound_url) @@ -273,7 +272,7 @@ def generate_post_link(db: Session): offset += limit -def generate_note_topic(db: Session): +def generate_note_topic(sqlite: Session): output_csv_file_path = "./data/transformed/note_topic_association.csv" ai_service = get_ai_service() @@ -289,10 +288,10 @@ def generate_note_topic(db: Session): offset = 0 limit = 1000 - num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar() + num_of_notes = sqlite.query(func.count(RowNoteRecord.row_post_id)).scalar() - while offset < num_of_users: - 
topicEstimationTargetNotes = db.execute( + while offset < num_of_notes: + topicEstimationTargetNotes = sqlite.execute( select(RowNoteRecord.note_id, RowNoteRecord.row_post_id, RowNoteRecord.summary) .filter( and_(