diff --git a/scripts/migrations/convert_data_from_v1.py b/scripts/migrations/convert_data_from_v1.py new file mode 100644 index 0000000..566ad43 --- /dev/null +++ b/scripts/migrations/convert_data_from_v1.py @@ -0,0 +1,54 @@ +import csv +import json +import os +from argparse import ArgumentParser +from collections import Counter + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("notes_file") + parser.add_argument("output_dir") + parser.add_argument("--notes-file-name", default="notes.csv") + parser.add_argument("--topics-file-name", default="topics.csv") + parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv") + parser.add_argument("--topic-threshold", type=int, default=5) + + args = parser.parse_args() + + with open(args.notes_file, "r", encoding="utf-8") as fin: + notes = list(csv.DictReader(fin)) + for d in notes: + d["topic"] = [t.strip() for t in d["topic"].split(",")] + topics_with_count = Counter(t for d in notes for t in d["topic"]) + topic_name_to_id_map = {t: i for i, (t, c) in enumerate(topics_with_count.items()) if c > args.topic_threshold} + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + with open(os.path.join(args.output_dir, args.topics_file_name), "w", encoding="utf-8") as fout: + writer = csv.DictWriter(fout, fieldnames=["topic_id", "label"]) + writer.writeheader() + for topic, topic_id in topic_name_to_id_map.items(): + writer.writerow({"topic_id": topic_id, "label": json.dumps({"ja": topic})}) + + with open(os.path.join(args.output_dir, args.notes_file_name), "w", encoding="utf-8") as fout: + writer = csv.DictWriter(fout, fieldnames=["note_id", "post_id", "language", "summary", "created_at"]) + writer.writeheader() + for d in notes: + writer.writerow( + { + "note_id": d["note_id"], + "post_id": d["post_id"], + "language": d["language"], + "summary": d["summary"], + "created_at": d["created_at"], + } + ) + + with open(os.path.join(args.output_dir, args.notes_topics_association_file_name), "w", encoding="utf-8") as fout: + writer = csv.DictWriter(fout, fieldnames=["note_id", "topic_id"]) + writer.writeheader() + for d in notes: + for t in d["topic"]: + if t in topic_name_to_id_map: + writer.writerow({"note_id": d["note_id"], "topic_id": topic_name_to_id_map[t]}) diff --git a/scripts/migrations/migrate_all.py b/scripts/migrations/migrate_all.py new file mode 100644 index 0000000..0c25d72 --- /dev/null +++ b/scripts/migrations/migrate_all.py @@ -0,0 +1,74 @@ +import csv +import json +import os +from argparse import ArgumentParser + +from dotenv import load_dotenv +from sqlalchemy.orm import Session + +from birdxplorer.logger import get_logger +from birdxplorer.settings import GlobalSettings +from birdxplorer.storage import ( + Base, + NoteRecord, + NoteTopicAssociation, + TopicRecord, + gen_storage, +) + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("data_dir") + parser.add_argument("--notes-file-name", default="notes.csv") + parser.add_argument("--topics-file-name", default="topics.csv") + parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv") + load_dotenv() + args = parser.parse_args() + settings = GlobalSettings() + logger = get_logger(level=settings.logger_settings.level) + storage = gen_storage(settings=settings) + + Base.metadata.create_all(storage.engine) + with Session(storage.engine) as sess: + with open(os.path.join(args.data_dir, args.topics_file_name), "r", encoding="utf-8") as fin: + for d in csv.DictReader(fin): + d["topic_id"] = int(d["topic_id"]) + d["label"] = json.loads(d["label"]) + if sess.query(TopicRecord).filter(TopicRecord.topic_id == d["topic_id"]).count() > 0: + continue + sess.add(TopicRecord(topic_id=d["topic_id"], label=d["label"])) + sess.commit() + with open(os.path.join(args.data_dir, args.notes_file_name), "r", encoding="utf-8") as fin: + for d in csv.DictReader(fin): + if sess.query(NoteRecord).filter(NoteRecord.note_id == d["note_id"]).count() > 0: + continue + sess.add( + NoteRecord( + note_id=d["note_id"], + post_id=d["post_id"], + language=d["language"], + summary=d["summary"], + created_at=d["created_at"], + ) + ) + sess.commit() + with open(os.path.join(args.data_dir, args.notes_topics_association_file_name), "r", encoding="utf-8") as fin: + for d in csv.DictReader(fin): + if ( + sess.query(NoteTopicAssociation) + .filter( + NoteTopicAssociation.note_id == d["note_id"], + NoteTopicAssociation.topic_id == d["topic_id"], + ) + .count() + > 0 + ): + continue + sess.add( + NoteTopicAssociation( + note_id=d["note_id"], + topic_id=d["topic_id"], + ) + ) + sess.commit() + logger.info("Migration is done")