From 6af9d6737eac10b2be8b786dce220967a1ce53bb Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Thu, 25 Jul 2024 16:35:16 +0900 Subject: [PATCH] feat: create generate_note_topic method --- etl/seed/fewshot_sample.json | 7 +++ .../lib/ai_model/ai_model_interface_base.py | 5 ++ .../lib/openapi/open_ai_service.py | 61 ++++++++++++++++++- etl/src/birdxplorer_etl/transform.py | 38 ++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 etl/seed/fewshot_sample.json diff --git a/etl/seed/fewshot_sample.json b/etl/seed/fewshot_sample.json new file mode 100644 index 0000000..1e8e4bf --- /dev/null +++ b/etl/seed/fewshot_sample.json @@ -0,0 +1,7 @@ +{ + "tweet": "For those that care — 432 hz improves mental clarity, removes emotional blockages, reduces stress and anxiety, better sleep quality, increases creativity & inspiration, and strengthens the immune system. Play it while you sleep & watch these areas improve!", + "note": "There are no placebo controlled studies which support this. There is no evidence that this frequency has different effects from any other arbitrary frequency. 
https://ask.audio/articles/music-theory-432-hz-tuning-separating-fact-from-fiction", + "topics": [ + "医療", "福祉" + ] +} \ No newline at end of file diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py index cfb5570..34f4984 100644 --- a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py @@ -1,4 +1,9 @@ +from typing import Dict, List + + class AIModelInterface: def detect_language(self, text: str) -> str: raise NotImplementedError("detect_language method not implemented") + def detect_topic(self, note_id: int, note: str) -> Dict[str, List[str]]: + raise NotImplementedError("detect_topic method not implemented") diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index 045cc20..beeeab3 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -2,6 +2,9 @@ from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface from birdxplorer_common.storage import LanguageIdentifier from openai import OpenAI +from typing import Dict, List +import csv +import json class OpenAIService(AIModelInterface): @@ -10,6 +13,18 @@ def __init__(self): self.client = OpenAI( api_key=self.api_key ) + self.topics = self.load_topics('./data/transformed/topic.csv') + + def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]: + topics = {} + with open(topic_csv_file_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + topic_id = int(row['topic_id']) + labels = json.loads(row['label'].replace("'", '"')) + for label in labels.values(): + topics[label] = topic_id + return topics def detect_language(self, text: str) -> str: prompt = ( @@ -39,4 +54,48 @@ def detect_language(self, text: str) -> str: if 
valid_code: return LanguageIdentifier(valid_code) - raise ValueError(f"Invali 3d language code received: {message_content}") \ No newline at end of file + raise ValueError(f"Invalid language code received: {message_content}") + + def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: + topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()]) + with open('./seed/fewshot_sample.json', newline='', encoding='utf-8') as f: + fewshot_sample = json.load(f) + + prompt = f""" + 以下はコミュニティノートです。 + コミュニティノート: + ``` + {fewshot_sample["note"]} + ``` + このセットに対してのトピックは「{" ".join(fewshot_sample["topics"])}」です。 + これを踏まえて、以下のセットに対して同じ粒度で複数のトピック(少なくとも3つ)を提示してください。 + コミュニティノート: + ``` + {note} + ``` + 以下のトピックは、 + ``` + topic: topic_id + ``` + の形で構成されています。 + こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。 + また指定された情報以外は含めないでください。 + + トピックの例: + {topic_examples} + """ + response = self.client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ], + temperature=0.0, + ) + response_text = response.choices[0].message.content.strip() + response_text = response_text.replace('```json', '').replace('```', '').strip() + try: + return json.loads(response_text) + except json.JSONDecodeError as e: + print(f'Error decoding JSON: {e}') + return {} diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index b1ab7d3..53423c3 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -137,3 +137,41 @@ def transform_data(db: Session): }) return + +def generate_note_topic(): + note_csv_file_path = './data/transformed/note.csv' + output_csv_file_path = './data/transformed/note_topic_association.csv' + ai_service = get_ai_service() + + records = [] + with open(output_csv_file_path, "w", newline='', encoding='utf-8', buffering=1) as file: + fieldnames = 
['note_id', 'topic_id'] + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + + with open(note_csv_file_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for index, row in enumerate(reader): + note_id = row['note_id'] + summary = row['summary'] + topics_info = ai_service.detect_topic(note_id, summary) + if topics_info: + for topic in topics_info.get('topics', []): + record = {"note_id": note_id, "topic_id": topic} + records.append(record) + + if index % 100 == 0: + for record in records: + writer.writerow({ + 'note_id': record["note_id"], + 'topic_id': record["topic_id"], + }) + records = [] + + for record in records: + writer.writerow({ + 'note_id': record["note_id"], + 'topic_id': record["topic_id"], + }) + + print(f"New CSV file has been created at {output_csv_file_path}")