Skip to content

Commit

Permalink
feat: add generate_note_topic
Browse files: browse the repository at this point in the history
  • Loading branch information
ayuki-joto committed Aug 14, 2024
1 parent 165e877 commit 719343c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
13 changes: 8 additions & 5 deletions etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]:
for row in reader:
topic_id = int(row['topic_id'])
labels = json.loads(row['label'].replace("'", '"'))
for label in labels.values():
topics[label] = topic_id
# Use only the Japanese ('ja') label of each topic as the lookup key.
if 'ja' in labels:
topics[labels['ja']] = topic_id
# for label in labels.values():
# topics[label] = topic_id
return topics

def detect_language(self, text: str) -> str:
Expand All @@ -34,7 +37,7 @@ def detect_language(self, text: str) -> str:
)

response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
Expand Down Expand Up @@ -81,12 +84,12 @@ def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]:
の形で構成されています。
こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。
また指定された情報以外は含めないでください。
トピックの例:
{topic_examples}
"""
response = self.client.chat.completions.create(
model="gpt-3.5-turbo",
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt}
Expand Down
2 changes: 2 additions & 0 deletions etl/src/birdxplorer_etl/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
AI_MODEL = os.getenv("AI_MODEL")
OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN")
CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN")
TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND")
TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND")
6 changes: 6 additions & 0 deletions etl/src/birdxplorer_etl/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def transform_data(db: Session):

return


def generate_note_topic():
note_csv_file_path = './data/transformed/note.csv'
output_csv_file_path = './data/transformed/note_topic_association.csv'
Expand Down Expand Up @@ -177,6 +178,7 @@ def generate_note_topic():
'topic_id': record["topic_id"],
})
records = []
print(index)

for record in records:
writer.writerow({
Expand All @@ -185,3 +187,7 @@ def generate_note_topic():
})

print(f"New CSV file has been created at {output_csv_file_path}")


if __name__ == "__main__":
generate_note_topic()

0 comments on commit 719343c

Please sign in to comment.