From e010a912ed5e9f82c25e04a2a8727d0f7f22682e Mon Sep 17 00:00:00 2001
From: yu23ki14
Date: Fri, 11 Oct 2024 13:49:21 +0900
Subject: [PATCH 1/2] tmp

---
 compose.yml                                |  1 +
 etl/src/birdxplorer_etl/extract.py         | 55 +++++++++++-----------
 etl/src/birdxplorer_etl/lib/sqlite/init.py | 24 ++++++++--
 etl/src/birdxplorer_etl/main.py            |  9 ++--
 etl/src/birdxplorer_etl/transform.py       | 47 +++++++++---------
 init.sql                                   |  1 +
 6 files changed, 77 insertions(+), 60 deletions(-)
 create mode 100644 init.sql

diff --git a/compose.yml b/compose.yml
index df2676f..26dcd32 100644
--- a/compose.yml
+++ b/compose.yml
@@ -17,6 +17,7 @@ services:
       - "5432:5432"
     volumes:
       - postgres_data:/var/lib/postgresql/data
+      - ./init.sql:/docker-entrypoint-initdb.d/init.sql
   app:
     depends_on:
       db:
diff --git a/etl/src/birdxplorer_etl/extract.py b/etl/src/birdxplorer_etl/extract.py
index 2dd0976..6506a53 100644
--- a/etl/src/birdxplorer_etl/extract.py
+++ b/etl/src/birdxplorer_etl/extract.py
@@ -16,16 +16,16 @@
 import settings
 
 
-def extract_data(db: Session):
+def extract_data(sqlite: Session, postgresql: Session):
     logging.info("Downloading community notes data")
 
     # get columns of the user table
-    columns = db.query(RowUserRecord).statement.columns.keys()
+    columns = sqlite.query(RowUserRecord).statement.columns.keys()
     logging.info(columns)
 
     # Fetch the note data and store it in SQLite
     date = datetime.now()
-    latest_note = db.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first()
+    latest_note = sqlite.query(RowNoteRecord).order_by(RowNoteRecord.created_at_millis.desc()).first()
 
     while True:
         if (
@@ -46,20 +46,20 @@ def extract_data(db: Session):
         res = requests.get(note_url)
 
         if res.status_code == 200:
-            # Write res.content into the Note table of db
+            # Write res.content into the Note table in SQLite
             tsv_data = res.content.decode("utf-8").splitlines()
             reader = csv.DictReader(tsv_data, delimiter="\t")
             reader.fieldnames = [stringcase.snakecase(field) for field in reader.fieldnames]
 
             rows_to_add = []
             for index, row in enumerate(reader):
-                if db.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
+                if sqlite.query(RowNoteRecord).filter(RowNoteRecord.note_id == row["note_id"]).first():
                     continue
                 rows_to_add.append(RowNoteRecord(**row))
                 if index % 1000 == 0:
-                    db.bulk_save_objects(rows_to_add)
+                    sqlite.bulk_save_objects(rows_to_add)
                     rows_to_add = []
-            db.bulk_save_objects(rows_to_add)
+            sqlite.bulk_save_objects(rows_to_add)
 
         status_url = f"https://ton.twimg.com/birdwatch-public-data/{dateString}/noteStatusHistory/noteStatusHistory-00000.tsv"
         if settings.USE_DUMMY_DATA:
@@ -78,34 +78,36 @@ def extract_data(db: Session):
                 for key, value in list(row.items()):
                     if value == "":
                         row[key] = None
-                status = db.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
+                status = (
+                    sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).first()
+                )
                 if status is None or status.created_at_millis > int(datetime.now().timestamp() * 1000):
-                    db.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
+                    sqlite.query(RowNoteStatusRecord).filter(RowNoteStatusRecord.note_id == row["note_id"]).delete()
                     rows_to_add.append(RowNoteStatusRecord(**row))
                 if index % 1000 == 0:
-                    db.bulk_save_objects(rows_to_add)
+                    sqlite.bulk_save_objects(rows_to_add)
                     rows_to_add = []
-            db.bulk_save_objects(rows_to_add)
+            sqlite.bulk_save_objects(rows_to_add)
             break
 
         date = date - timedelta(days=1)
 
-    db.commit()
+    sqlite.commit()
 
     # Fetch the tweet data tied to each note
     postExtract_targetNotes = (
-        db.query(RowNoteRecord)
+        sqlite.query(RowNoteRecord)
         .filter(RowNoteRecord.tweet_id != None)
         .filter(RowNoteRecord.created_at_millis >= settings.TARGET_TWITTER_POST_START_UNIX_MILLISECOND)
         .filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND)
         .all()
     )
-    logging.info(len(postExtract_targetNotes))
+    logging.info("Target notes: ", len(postExtract_targetNotes))
 
     for note in postExtract_targetNotes:
         tweet_id = note.tweet_id
 
-        is_tweetExist = db.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first()
+        is_tweetExist = postgresql.query(RowPostRecord).filter(RowPostRecord.post_id == str(tweet_id)).first()
         if is_tweetExist is not None:
             logging.info(f"tweet_id {tweet_id} is already exist")
             note.row_post_id = tweet_id
@@ -120,7 +122,9 @@ def extract_data(db: Session):
         created_at = datetime.strptime(post["data"]["created_at"], "%Y-%m-%dT%H:%M:%S.%fZ")
         created_at_millis = int(created_at.timestamp() * 1000)
 
-        is_userExist = db.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first()
+        is_userExist = (
+            postgresql.query(RowUserRecord).filter(RowUserRecord.user_id == post["data"]["author_id"]).first()
+        )
         logging.info(is_userExist)
         if is_userExist is None:
             user_data = (
@@ -128,7 +132,7 @@ def extract_data(db: Session):
                 if "includes" in post and "users" in post["includes"] and len(post["includes"]["users"]) > 0
                 else {}
             )
-            db_user = RowUserRecord(
+            row_user = RowUserRecord(
                 user_id=post["data"]["author_id"],
                 name=user_data.get("name"),
                 user_name=user_data.get("username"),
@@ -142,7 +146,7 @@ def extract_data(db: Session):
                 location=user_data.get("location", ""),
                 url=user_data.get("url", ""),
             )
-            db.add(db_user)
+            postgresql.add(row_user)
 
         media_data = (
             post["includes"]["media"]
@@ -150,9 +154,7 @@ def extract_data(db: Session):
             else []
         )
 
-        print(media_data)
-
-        db_post = RowPostRecord(
+        row_post = RowPostRecord(
             post_id=post["data"]["id"],
             author_id=post["data"]["author_id"],
             text=post["data"]["text"],
@@ -165,7 +167,7 @@ def extract_data(db: Session):
             reply_count=post["data"]["public_metrics"]["reply_count"],
             lang=post["data"]["lang"],
         )
-        db.add(db_post)
+        postgresql.add(row_post)
 
         media_recs = [
             RowPostMediaRecord(
@@ -178,7 +180,7 @@ def extract_data(db: Session):
             )
             for m in media_data
         ]
-        db.add_all(media_recs)
+        postgresql.add_all(media_recs)
 
         if "entities" in post["data"] and "urls" in post["data"]["entities"]:
             for url in post["data"]["entities"]["urls"]:
@@ -189,12 +191,9 @@ def extract_data(db: Session):
                     expanded_url=url["expanded_url"] if url["expanded_url"] else None,
                     unwound_url=url["unwound_url"] if url["unwound_url"] else None,
                 )
-                db.add(post_url)
+                postgresql.add(post_url)
 
         note.row_post_id = tweet_id
-        db.commit()
+        postgresql.commit()
         continue
 
-    # select note from db, get relation tweet and user data
-    note = db.query(RowNoteRecord).filter(RowNoteRecord.tweet_id == "1797617478950170784").first()
-
     return
diff --git a/etl/src/birdxplorer_etl/lib/sqlite/init.py b/etl/src/birdxplorer_etl/lib/sqlite/init.py
index 023da41..5ca7124 100644
--- a/etl/src/birdxplorer_etl/lib/sqlite/init.py
+++ b/etl/src/birdxplorer_etl/lib/sqlite/init.py
@@ -14,7 +14,7 @@
 )
 
 
-def init_db():
+def init_sqlite():
     # ToDo: the db file needs to be stored externally, e.g. on S3.
     db_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "data", "note.db"))
     logging.info(f"Initializing database at {db_path}")
@@ -25,12 +25,28 @@ def init_db():
 
     if not inspect(engine).has_table("row_notes"):
         logging.info("Creating table note")
         RowNoteRecord.metadata.create_all(engine)
-    if not inspect(engine).has_table("row_posts"):
-        logging.info("Creating table post")
-        RowPostRecord.metadata.create_all(engine)
     if not inspect(engine).has_table("row_note_status"):
         logging.info("Creating table note_status")
         RowNoteStatusRecord.metadata.create_all(engine)
+
+    Session = sessionmaker(bind=engine)
+
+    return Session()
+
+
+def init_postgresql():
+    db_host = os.getenv("DB_HOST", "localhost")
+    db_port = os.getenv("DB_PORT", "5432")
+    db_user = os.getenv("DB_USER", "birdxplorer")
+    db_pass = os.getenv("DB_PASS", "birdxplorer")
+    db_name = os.getenv("DB_NAME", "etl")
+
+    logging.info(f"Initializing database at {db_host}:{db_port}/{db_name}")
+    engine = create_engine(f"postgresql://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}")
+
+    if not inspect(engine).has_table("row_posts"):
+        logging.info("Creating table post")
+        RowPostRecord.metadata.create_all(engine)
     if not inspect(engine).has_table("row_users"):
         logging.info("Creating table user")
         RowUserRecord.metadata.create_all(engine)
diff --git a/etl/src/birdxplorer_etl/main.py b/etl/src/birdxplorer_etl/main.py
index 5c5c50f..89a9c5e 100644
--- a/etl/src/birdxplorer_etl/main.py
+++ b/etl/src/birdxplorer_etl/main.py
@@ -1,10 +1,11 @@
-from lib.sqlite.init import init_db
+from lib.sqlite.init import init_sqlite, init_postgresql
 from extract import extract_data
 from load import load_data
 from transform import transform_data
 
 if __name__ == "__main__":
-    db = init_db()
-    extract_data(db)
-    transform_data(db)
+    sqlite = init_sqlite()
+    postgresql = init_postgresql()
+    extract_data(sqlite, postgresql)
+    transform_data(sqlite, postgresql)
     load_data()
diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py
index 9fdbb6f..c5134e9 100644
--- a/etl/src/birdxplorer_etl/transform.py
+++ b/etl/src/birdxplorer_etl/transform.py
@@ -6,7 +6,6 @@
 from pathlib import Path
 from typing import Generator
 
-from prefect import get_run_logger
 from sqlalchemy import Integer, and_, func, select
 from sqlalchemy.orm import Session
 
@@ -25,7 +24,7 @@
 )
 
 
-def transform_data(db: Session):
+def transform_data(sqlite: Session, postgresql: Session):
 
     logging.info("Transforming data")
 
@@ -44,7 +43,7 @@ def transform_data(db: Session):
     ai_service = get_ai_service()
 
     num_of_notes = (
-        db.query(func.count(RowNoteRecord.note_id))
+        sqlite.query(func.count(RowNoteRecord.note_id))
         .filter(
             and_(
                 RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND,
@@ -58,7 +57,7 @@ def transform_data(db: Session):
     logging.info(f"Transforming note data: {num_of_notes}")
 
     while offset < num_of_notes:
-        notes = db.execute(
+        notes = sqlite.execute(
             select(
                 RowNoteRecord.note_id,
                 RowNoteRecord.row_post_id,
@@ -90,10 +89,10 @@ def transform_data(db: Session):
     offset = 0
     limit = 1000
 
-    num_of_posts = db.query(func.count(RowPostRecord.post_id)).scalar()
+    num_of_posts = postgresql.query(func.count(RowPostRecord.post_id)).scalar()
 
     while offset < num_of_posts:
-        posts = db.execute(
+        posts = postgresql.execute(
             select(
                 RowPostRecord.post_id,
                 RowPostRecord.author_id.label("user_id"),
@@ -131,10 +130,10 @@ def transform_data(db: Session):
    offset = 0
    limit = 1000
 
-    num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar()
+    num_of_users = postgresql.query(func.count(RowUserRecord.user_id)).scalar()
 
    while offset < num_of_users:
-        users = db.execute(
+        users = postgresql.execute(
             select(
                 RowUserRecord.user_id,
                 RowUserRecord.user_name.label("name"),
@@ -153,8 +152,8 @@ def transform_data(db: Session):
         offset += limit
 
     # Transform row post embed link
-    write_media_csv(db)
-    generate_post_link(db)
+    write_media_csv(postgresql)
+    generate_post_link(postgresql)
 
     # Transform row post embed url data and generate post_embed_url.csv
     csv_seed_file_path = "./seed/topic_seed.csv"
@@ -180,12 +179,12 @@ def transform_data(db: Session):
         for record in records:
             writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}})
 
-    generate_note_topic(db)
+    generate_note_topic(sqlite)
 
     return
 
 
-def write_media_csv(db: Session) -> None:
+def write_media_csv(postgresql: Session) -> None:
     media_csv_path = Path("./data/transformed/media.csv")
     post_media_association_csv_path = Path("./data/transformed/post_media_association.csv")
 
@@ -205,7 +204,7 @@ def write_media_csv(db: Session) -> None:
         assoc_writer = csv.DictWriter(assoc_csv, fieldnames=assoc_fields)
         assoc_writer.writeheader()
 
-        for m in _iterate_media(db):
+        for m in _iterate_media(postgresql):
             media_writer.writerow(
                 {
                     "media_key": m.media_key,
@@ -219,17 +218,17 @@ def write_media_csv(db: Session) -> None:
             assoc_writer.writerow({"post_id": m.post_id, "media_key": m.media_key})
 
 
-def _iterate_media(db: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]:
+def _iterate_media(postgresql: Session, limit: int = 1000) -> Generator[RowPostMediaRecord, None, None]:
     offset = 0
-    total_media: int = db.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0
+    total_media: int = postgresql.query(func.count(RowPostMediaRecord.media_key)).scalar() or 0
 
     while offset < total_media:
-        yield from db.query(RowPostMediaRecord).limit(limit).offset(offset)
+        yield from postgresql.query(RowPostMediaRecord).limit(limit).offset(offset)
         offset += limit
 
 
-def generate_post_link(db: Session):
+def generate_post_link(postgresql: Session):
     link_csv_file_path = "./data/transformed/post_link.csv"
     association_csv_file_path = "./data/transformed/post_link_association.csv"
 
@@ -249,15 +248,15 @@ def generate_post_link(db: Session):
     offset = 0
     limit = 1000
 
-    num_of_links = db.query(func.count(RowPostEmbedURLRecord.post_id)).scalar()
+    num_of_links = postgresql.query(func.count(RowPostEmbedURLRecord.post_id)).scalar()
 
     records = []
 
     while offset < num_of_links:
-        links = db.query(RowPostEmbedURLRecord).limit(limit).offset(offset)
+        links = postgresql.query(RowPostEmbedURLRecord).limit(limit).offset(offset)
         for link in links:
             random.seed(link.unwound_url)
-            link_id = uuid.UUID(int=random.getrandbits(128))
+            link_id = uuid.UUID(int=random.getransqliteits(128))
             is_link_exist = next((record for record in records if record["link_id"] == link_id), None)
             if is_link_exist is None:
                 with open(link_csv_file_path, "a", newline="", encoding="utf-8") as file:
@@ -273,7 +272,7 @@ def generate_post_link(db: Session):
         offset += limit
 
 
-def generate_note_topic(db: Session):
+def generate_note_topic(sqlite: Session):
     output_csv_file_path = "./data/transformed/note_topic_association.csv"
     ai_service = get_ai_service()
 
@@ -289,10 +288,10 @@ def generate_note_topic(db: Session):
     offset = 0
     limit = 1000
 
-    num_of_users = db.query(func.count(RowUserRecord.user_id)).scalar()
+    num_of_notes = sqlite.query(func.count(RowNoteRecord.row_post_id)).scalar()
 
-    while offset < num_of_users:
-        topicEstimationTargetNotes = db.execute(
+    while offset < num_of_notes:
+        topicEstimationTargetNotes = sqlite.execute(
             select(RowNoteRecord.note_id, RowNoteRecord.row_post_id, RowNoteRecord.summary)
             .filter(
                 and_(
diff --git a/init.sql b/init.sql
new file mode 100644
index 0000000..8b36c6d
--- /dev/null
+++ b/init.sql
@@ -0,0 +1 @@
+CREATE DATABASE etl;
\ No newline at end of file

From afc80279c00bc8757d37755877c2d524c39ae304 Mon Sep 17 00:00:00 2001
From: yu23ki14
Date: Fri, 11 Oct 2024 15:28:04 +0900
Subject: [PATCH 2/2] use postgresql for user

---
 common/birdxplorer_common/storage.py       | 2 +-
 compose.yml                                | 1 -
 etl/src/birdxplorer_etl/extract.py         | 3 ++-
 etl/src/birdxplorer_etl/lib/sqlite/init.py | 4 ++--
 etl/src/birdxplorer_etl/main.py            | 3 +++
 etl/src/birdxplorer_etl/transform.py       | 6 +++---
 init.sql                                   | 1 -
 7 files changed, 11 insertions(+), 9 deletions(-)
 delete mode 100644 init.sql

diff --git a/common/birdxplorer_common/storage.py b/common/birdxplorer_common/storage.py
index a8da554..6c34ed0 100644
--- a/common/birdxplorer_common/storage.py
+++ b/common/birdxplorer_common/storage.py
@@ -254,7 +254,7 @@ class RowUserRecord(Base):
     followers_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
     following_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
     tweet_count: Mapped[NonNegativeInt] = mapped_column(nullable=False)
-    verified: Mapped[BinaryBool] = mapped_column(nullable=False)
+    verified: Mapped[bool] = mapped_column(nullable=False)
     verified_type: Mapped[String] = mapped_column(nullable=False)
     location: Mapped[String] = mapped_column(nullable=False)
     url: Mapped[String] = mapped_column(nullable=False)
diff --git a/compose.yml b/compose.yml
index 26dcd32..df2676f 100644
--- a/compose.yml
+++ b/compose.yml
@@ -17,7 +17,6 @@ services:
       - "5432:5432"
     volumes:
       - postgres_data:/var/lib/postgresql/data
-      - ./init.sql:/docker-entrypoint-initdb.d/init.sql
   app:
     depends_on:
       db:
diff --git a/etl/src/birdxplorer_etl/extract.py b/etl/src/birdxplorer_etl/extract.py
index 6506a53..5a276a4 100644
--- a/etl/src/birdxplorer_etl/extract.py
+++ b/etl/src/birdxplorer_etl/extract.py
@@ -103,7 +103,7 @@ def extract_data(sqlite: Session, postgresql: Session):
         .filter(RowNoteRecord.created_at_millis <= settings.TARGET_TWITTER_POST_END_UNIX_MILLISECOND)
         .all()
     )
-    logging.info("Target notes: ", len(postExtract_targetNotes))
+    logging.info(f"Target notes: {len(postExtract_targetNotes)}")
 
     for note in postExtract_targetNotes:
         tweet_id = note.tweet_id
@@ -168,6 +168,7 @@ def extract_data(sqlite: Session, postgresql: Session):
             lang=post["data"]["lang"],
         )
         postgresql.add(row_post)
+        postgresql.commit()
 
         media_recs = [
             RowPostMediaRecord(
diff --git a/etl/src/birdxplorer_etl/lib/sqlite/init.py b/etl/src/birdxplorer_etl/lib/sqlite/init.py
index 5ca7124..21c082a 100644
--- a/etl/src/birdxplorer_etl/lib/sqlite/init.py
+++ b/etl/src/birdxplorer_etl/lib/sqlite/init.py
@@ -37,9 +37,9 @@ def init_sqlite():
 def init_postgresql():
     db_host = os.getenv("DB_HOST", "localhost")
     db_port = os.getenv("DB_PORT", "5432")
-    db_user = os.getenv("DB_USER", "birdxplorer")
+    db_user = os.getenv("DB_USER", "postgres")
     db_pass = os.getenv("DB_PASS", "birdxplorer")
-    db_name = os.getenv("DB_NAME", "etl")
+    db_name = os.getenv("DB_NAME", "postgres")
 
     logging.info(f"Initializing database at {db_host}:{db_port}/{db_name}")
     engine = create_engine(f"postgresql://{db_user}:{db_pass}@{db_host}:{db_port}/{db_name}")
diff --git a/etl/src/birdxplorer_etl/main.py b/etl/src/birdxplorer_etl/main.py
index 89a9c5e..b484e39 100644
--- a/etl/src/birdxplorer_etl/main.py
+++ b/etl/src/birdxplorer_etl/main.py
@@ -2,6 +2,9 @@
 from extract import extract_data
 from load import load_data
 from transform import transform_data
+import logging
+
+logging.basicConfig(level=logging.INFO)
 
 if __name__ == "__main__":
     sqlite = init_sqlite()
diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py
index c5134e9..ca450d3 100644
--- a/etl/src/birdxplorer_etl/transform.py
+++ b/etl/src/birdxplorer_etl/transform.py
@@ -6,7 +6,7 @@
 from pathlib import Path
 from typing import Generator
 
-from sqlalchemy import Integer, and_, func, select
+from sqlalchemy import Integer, Numeric, and_, func, select
 from sqlalchemy.orm import Session
 
 from birdxplorer_common.storage import (
@@ -97,7 +97,7 @@ def transform_data(sqlite: Session, postgresql: Session):
             RowPostRecord.post_id,
             RowPostRecord.author_id.label("user_id"),
             RowPostRecord.text,
-            func.cast(RowPostRecord.created_at, Integer).label("created_at"),
+            func.cast(RowPostRecord.created_at, Numeric).label("created_at"),
             func.cast(RowPostRecord.like_count, Integer).label("like_count"),
             func.cast(RowPostRecord.repost_count, Integer).label("repost_count"),
             func.cast(RowPostRecord.impression_count, Integer).label("impression_count"),
@@ -256,7 +256,7 @@ def generate_post_link(postgresql: Session):
 
         for link in links:
             random.seed(link.unwound_url)
-            link_id = uuid.UUID(int=random.getransqliteits(128))
+            link_id = uuid.UUID(int=random.getrandbits(128))
             is_link_exist = next((record for record in records if record["link_id"] == link_id), None)
             if is_link_exist is None:
                 with open(link_csv_file_path, "a", newline="", encoding="utf-8") as file:
diff --git a/init.sql b/init.sql
deleted file mode 100644
index 8b36c6d..0000000
--- a/init.sql
+++ /dev/null
@@ -1 +0,0 @@
-CREATE DATABASE etl;
\ No newline at end of file
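
A note on the environment wiring these patches rely on: init_postgresql() reads its
connection settings from the DB_HOST, DB_PORT, DB_USER, DB_PASS and DB_NAME environment
variables, but the series never shows where those are set. The fragment below is only a
sketch of how the app service in compose.yml might supply them; every value is an
assumption that mirrors the os.getenv() defaults after PATCH 2/2, and DB_HOST assumes
the PostgreSQL container keeps the service name "db" already used in compose.yml.

    # Hypothetical compose.yml fragment; values mirror the getenv() defaults
    # in init_postgresql() and are not part of the patches themselves.
    services:
      app:
        environment:
          DB_HOST: db            # assumed: the postgres service name in compose.yml
          DB_PORT: "5432"
          DB_USER: postgres
          DB_PASS: birdxplorer
          DB_NAME: postgres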