Skip to content

Commit

Permalink
feat(scripts): add migration scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
osoken committed Feb 26, 2024
1 parent 56f021b commit 8e7f6be
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
54 changes: 54 additions & 0 deletions scripts/migrations/convert_data_from_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import csv
import json
import os
from argparse import ArgumentParser
from collections import Counter

def _write_csv(path, fieldnames, rows):
    """Write an iterable of dict *rows* to *path* as a UTF-8 CSV with a header."""
    # newline="" hands line-ending control to the csv module, as its
    # documentation requires; otherwise rows end in "\r\r\n" on Windows.
    with open(path, "w", encoding="utf-8", newline="") as fout:
        writer = csv.DictWriter(fout, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main(argv=None):
    """Split a v1 notes CSV into notes, topics and note-topic association CSVs.

    The input file must provide the columns ``note_id``, ``post_id``,
    ``language``, ``summary``, ``created_at`` and ``topic``, where ``topic``
    is a comma-separated list of topic names.

    Args:
        argv: Command-line arguments (without the program name). ``None``
            makes argparse fall back to ``sys.argv[1:]``.
    """
    parser = ArgumentParser()
    parser.add_argument("notes_file")
    parser.add_argument("output_dir")
    parser.add_argument("--notes-file-name", default="notes.csv")
    parser.add_argument("--topics-file-name", default="topics.csv")
    parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv")
    parser.add_argument("--topic-threshold", type=int, default=5)

    args = parser.parse_args(argv)

    # newline="" keeps line breaks embedded in quoted fields intact.
    with open(args.notes_file, "r", encoding="utf-8", newline="") as fin:
        notes = list(csv.DictReader(fin))
    for note in notes:
        note["topic"] = [name.strip() for name in note["topic"].split(",")]

    topics_with_count = Counter(name for note in notes for name in note["topic"])
    # Ids are positions in the full counter (first-seen order), so they are
    # not necessarily contiguous once rare topics are dropped. Only topics
    # seen strictly more than --topic-threshold times are kept. The empty
    # name, produced by notes with a blank topic cell, is never a real topic;
    # it is filtered here rather than before counting so that the ids of all
    # other topics stay unchanged.
    topic_name_to_id_map = {
        name: idx
        for idx, (name, count) in enumerate(topics_with_count.items())
        if name and count > args.topic_threshold
    }

    # exist_ok avoids the check-then-create race and permits re-runs.
    os.makedirs(args.output_dir, exist_ok=True)

    _write_csv(
        os.path.join(args.output_dir, args.topics_file_name),
        ["topic_id", "label"],
        (
            {"topic_id": topic_id, "label": json.dumps({"ja": name})}
            for name, topic_id in topic_name_to_id_map.items()
        ),
    )

    note_fields = ["note_id", "post_id", "language", "summary", "created_at"]
    _write_csv(
        os.path.join(args.output_dir, args.notes_file_name),
        note_fields,
        # Project each note onto the output columns (drops the "topic" list).
        ({field: note[field] for field in note_fields} for note in notes),
    )

    _write_csv(
        os.path.join(args.output_dir, args.notes_topics_association_file_name),
        ["note_id", "topic_id"],
        (
            {"note_id": note["note_id"], "topic_id": topic_name_to_id_map[name]}
            for note in notes
            for name in note["topic"]
            # Topics below the threshold (and the empty name) have no id.
            if name in topic_name_to_id_map
        ),
    )


if __name__ == "__main__":
    main()
74 changes: 74 additions & 0 deletions scripts/migrations/migrate_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import csv
import json
import os
from argparse import ArgumentParser

from dotenv import load_dotenv
from sqlalchemy.orm import Session

from birdxplorer.logger import get_logger
from birdxplorer.settings import GlobalSettings
from birdxplorer.storage import (
Base,
NoteRecord,
NoteTopicAssociation,
TopicRecord,
gen_storage,
)

if __name__ == "__main__":
    # Load topics, notes and note-topic associations from CSV files in
    # `data_dir` into the configured storage, skipping rows that already
    # exist so the script can be re-run.
    parser = ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("--notes-file-name", default="notes.csv")
    parser.add_argument("--topics-file-name", default="topics.csv")
    parser.add_argument("--notes-topics-association-file-name", default="note_topic.csv")
    # Read .env before GlobalSettings() is constructed.
    load_dotenv()
    args = parser.parse_args()
    settings = GlobalSettings()
    logger = get_logger(level=settings.logger_settings.level)
    storage = gen_storage(settings=settings)

    # Create any tables declared on Base that do not exist yet.
    Base.metadata.create_all(storage.engine)
    with Session(storage.engine) as sess:
        # --- topics ---------------------------------------------------------
        # newline="" lets the csv module handle line breaks inside quoted
        # fields, as its documentation recommends.
        with open(os.path.join(args.data_dir, args.topics_file_name), "r", encoding="utf-8", newline="") as fin:
            for d in csv.DictReader(fin):
                d["topic_id"] = int(d["topic_id"])
                # The label column holds a JSON-encoded object.
                d["label"] = json.loads(d["label"])
                if sess.query(TopicRecord).filter(TopicRecord.topic_id == d["topic_id"]).count() > 0:
                    continue  # a topic with this id is already stored
                sess.add(TopicRecord(topic_id=d["topic_id"], label=d["label"]))
            sess.commit()

        # --- notes ----------------------------------------------------------
        with open(os.path.join(args.data_dir, args.notes_file_name), "r", encoding="utf-8", newline="") as fin:
            for d in csv.DictReader(fin):
                if sess.query(NoteRecord).filter(NoteRecord.note_id == d["note_id"]).count() > 0:
                    continue  # a note with this id is already stored
                sess.add(
                    NoteRecord(
                        note_id=d["note_id"],
                        post_id=d["post_id"],
                        language=d["language"],
                        summary=d["summary"],
                        created_at=d["created_at"],
                    )
                )
            sess.commit()

        # --- note/topic associations ------------------------------------------
        with open(
            os.path.join(args.data_dir, args.notes_topics_association_file_name),
            "r",
            encoding="utf-8",
            newline="",
        ) as fin:
            for d in csv.DictReader(fin):
                # Convert like the topics loop above so the same topic id is
                # passed with the same Python type to both tables.
                # NOTE(review): assumes NoteTopicAssociation.topic_id is an
                # integer column like TopicRecord.topic_id - confirm.
                topic_id = int(d["topic_id"])
                if (
                    sess.query(NoteTopicAssociation)
                    .filter(
                        NoteTopicAssociation.note_id == d["note_id"],
                        NoteTopicAssociation.topic_id == topic_id,
                    )
                    .count()
                    > 0
                ):
                    continue  # this (note, topic) pair is already stored
                sess.add(
                    NoteTopicAssociation(
                        note_id=d["note_id"],
                        topic_id=topic_id,
                    )
                )
            sess.commit()
    logger.info("Migration is done")

0 comments on commit 8e7f6be

Please sign in to comment.