Skip to content

Commit

Permalink
feat: create generate_note_topic method
Browse files Browse the repository at this point in the history
  • Loading branch information
ayuki-joto committed Jul 25, 2024
1 parent b04dfca commit 6af9d67
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 1 deletion.
7 changes: 7 additions & 0 deletions etl/seed/fewshot_sample.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"tweet": "For those that care — 432 hz improves mental clarity, removes emotional blockages, reduces stress and anxiety, better sleep quality, increases creativity & inspiration, and strengthens the immune system. Play it while you sleep & watch these areas improve!",
"note": "There are no placebo controlled studies which support this. There is no evidence that this frequency has different effects from any other arbitrary frequency. https://ask.audio/articles/music-theory-432-hz-tuning-separating-fact-from-fiction",
"topics": [
"医療", "福祉"
]
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import Dict, List


class AIModelInterface:
    """Base interface for AI-backed text analysis services.

    Subclasses override every method; the defaults here only raise
    ``NotImplementedError`` so a missing override fails loudly.
    """

    def detect_language(self, text: str) -> str:
        """Return the language identifier detected for ``text``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("detect_language method not implemented")

    def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]:
        """Return the topics detected for the note text ``note``.

        Implementations are expected to return a mapping of the form
        ``{"topics": [topic_id, ...]}`` where each ``topic_id`` is numeric,
        which is why the value type is ``List[int]``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("detect_topic method not implemented")
61 changes: 60 additions & 1 deletion etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface
from birdxplorer_common.storage import LanguageIdentifier
from openai import OpenAI
from typing import Dict, List
import csv
import json


class OpenAIService(AIModelInterface):
Expand All @@ -10,6 +13,18 @@ def __init__(self):
self.client = OpenAI(
api_key=self.api_key
)
self.topics = self.load_topics('./data/transformed/topic.csv')

def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]:
    """Build a ``label -> topic_id`` lookup from the topic CSV.

    The CSV must have a ``topic_id`` column (integer) and a ``label`` column
    holding a serialized mapping whose values are the topic labels
    (presumably keyed by language code -- verify against the file that
    produces topic.csv).  The mapping may be serialized either as JSON or as
    a Python ``dict`` repr (single-quoted).

    Args:
        topic_csv_file_path: Path to the topic CSV file.

    Returns:
        Dict mapping every label string to its integer topic id.  If the
        same label appears for several topics, the last row wins.

    Raises:
        ValueError: if ``topic_id`` is not an integer or ``label`` cannot be
            parsed into a mapping.
        KeyError: if a required column is missing.
    """
    # Local import: only needed for the Python-repr fallback below.
    import ast

    topics: Dict[str, int] = {}
    with open(topic_csv_file_path, newline='', encoding='utf-8') as csvfile:
        for row in csv.DictReader(csvfile):
            topic_id = int(row['topic_id'])
            raw_label = row['label']
            try:
                labels = json.loads(raw_label)
            except json.JSONDecodeError:
                # Blindly swapping ' for " breaks labels that contain an
                # apostrophe; literal_eval parses the dict repr safely.
                try:
                    labels = ast.literal_eval(raw_label)
                except (ValueError, SyntaxError) as err:
                    raise ValueError(
                        f"Unparseable label for topic_id {topic_id}: {raw_label!r}"
                    ) from err
            if not isinstance(labels, dict):
                raise ValueError(
                    f"Label for topic_id {topic_id} is not a mapping: {raw_label!r}"
                )
            for label in labels.values():
                topics[label] = topic_id
    return topics

def detect_language(self, text: str) -> str:
prompt = (
Expand Down Expand Up @@ -39,4 +54,48 @@ def detect_language(self, text: str) -> str:
if valid_code:
return LanguageIdentifier(valid_code)

raise ValueError(f"Invali 3d language code received: {message_content}")
raise ValueError(f"Invalid language code received: {message_content}")

def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]:
topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()])
with open('./seed/fewshot_sample.json', newline='', encoding='utf-8') as f:
fewshot_sample = json.load(f)

prompt = f"""
以下はコミュニティノートです。
コミュニティノート:
```
{fewshot_sample["note"]}
```
このセットに対してのトピックは「{" ".join(fewshot_sample["topics"])}」です。
これを踏まえて、以下のセットに対して同じ粒度で複数のトピック(少なくとも3つ)を提示してください。
コミュニティノート:
```
{note}
```
以下のトピックは、
```
topic: topic_id
```
の形で構成されています。
こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。
また指定された情報以外は含めないでください。
トピックの例:
{topic_examples}
"""
response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
],
temperature=0.0,
)
response_text = response.choices[0].message.content.strip()
response_text = response_text.replace('```json', '').replace('```', '').strip()
try:
return json.loads(response_text)
except json.JSONDecodeError as e:
print(f'Error decoding JSON: {e}')
return {}
38 changes: 38 additions & 0 deletions etl/src/birdxplorer_etl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,41 @@ def transform_data(db: Session):
})

return

def generate_note_topic():
    """Write the note -> topic association CSV using the AI service.

    Reads ``./data/transformed/note.csv`` (columns ``note_id`` and
    ``summary`` are used), asks the AI service for the topics of each
    summary, and writes one ``note_id,topic_id`` row per detected topic to
    ``./data/transformed/note_topic_association.csv``.

    Raises:
        FileNotFoundError: if the note CSV does not exist.  The input is
            opened first so that an existing output file is not truncated
            in that case.
    """
    note_csv_file_path = './data/transformed/note.csv'
    output_csv_file_path = './data/transformed/note_topic_association.csv'
    ai_service = get_ai_service()

    with open(note_csv_file_path, newline='', encoding='utf-8') as csvfile:
        reader = csv.DictReader(csvfile)
        # buffering=1 selects line buffering, so each row reaches the file as
        # soon as it is written; no in-memory batching is needed and a crash
        # part-way through keeps every row produced so far.
        with open(output_csv_file_path, "w", newline='', encoding='utf-8', buffering=1) as file:
            writer = csv.DictWriter(file, fieldnames=['note_id', 'topic_id'])
            writer.writeheader()

            for row in reader:
                note_id = row['note_id']
                topics_info = ai_service.detect_topic(note_id, row['summary'])
                if not isinstance(topics_info, dict):
                    # The service returns {} (or something unusable) on failure.
                    continue
                # `or []` also covers an explicit null under "topics".
                for topic in topics_info.get('topics') or []:
                    writer.writerow({'note_id': note_id, 'topic_id': topic})

    print(f"New CSV file has been created at {output_csv_file_path}")

0 comments on commit 6af9d67

Please sign in to comment.