From 00d9110c442421b3a9f860cf4963d5e8a5099561 Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Thu, 18 Jul 2024 17:49:31 +0900 Subject: [PATCH 01/12] feat: add topic seed csv --- etl/seed/topic_seed.csv | 123 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 etl/seed/topic_seed.csv diff --git a/etl/seed/topic_seed.csv b/etl/seed/topic_seed.csv new file mode 100644 index 0000000..246c1a9 --- /dev/null +++ b/etl/seed/topic_seed.csv @@ -0,0 +1,123 @@ +topic +政治 +メディア +人権 +社会 +国際関係 +医療 +エンターテイメント +犯罪 +法律 +経済 +移民 +テロリズム +紛争 +スポーツ +宗教 +科学 +歴史 +教育 +環境 +テクノロジー +健康 +軍事 +文化 +詐欺 +ビジネス +ジェンダー +交通 +福祉 +技術 +戦争 +社会問題 +ソーシャルメディア +言語 +労働 +エネルギー +インターネット +音楽 +家族 +治安 +人種差別 +選挙 +風刺 +陰謀論 +フェイクニュース +ニュース +広告 +人種 +コミュニケーション +映画 +宇宙 +自動車 +地理 +ワクチン +安全保障 +観光 +中東 +国際情勢 +COVID-19 +農業 +感染症 +ファッション +メンタルヘルス +SNS +雇用 +自然 +暴動 +プライバシー +子供 +動物愛護 +生物学 +金融 +美容 +差別 +交通事故 +消費者保護 +災害 +アニメ +気候変動 +食品 +AI +建築 +人間関係 +住宅 +心理学 +国家 +料理 +国際政治 +子育て +民主主義 +偽情報 +アート +都市計画 +性暴力 +投資 +天文学 +テロ +社会運動 +税金 +薬物 +司法 +外交 +言論 +家庭 +腐敗 +戦争犯罪 +憲法 +オンラインセキュリティ +財政 +国際法 +災害 +人道支援 +中国 +性教育 +貧困 +国境管理 +情報操作 +暗号通貨 +環境保護 +イスラエル +欧州連合 EU +データ +フェミニズム \ No newline at end of file From c02c62512b4cc77b049a166e930df54c8f3253ac Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Thu, 18 Jul 2024 17:55:11 +0900 Subject: [PATCH 02/12] feat: add generate topic csv and ai model interface --- etl/.env.example | 5 ++- etl/pyproject.toml | 3 +- .../lib/ai_model/ai_model_interface.py | 13 +++++++ .../lib/ai_model/ai_model_interface_base.py | 4 ++ .../lib/claude/claude_service.py | 7 ++++ .../lib/openapi/open_ai_service.py | 37 +++++++++++++++++++ etl/src/birdxplorer_etl/settings.py | 3 ++ etl/src/birdxplorer_etl/transform.py | 29 +++++++++++++++ 8 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py create mode 100644 etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py create mode 100644 etl/src/birdxplorer_etl/lib/claude/claude_service.py create mode 100644 etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py diff --git a/etl/.env.example b/etl/.env.example index 3d2cf0e..d0e408f 100644 --- a/etl/.env.example +++ b/etl/.env.example @@ -1 +1,4 @@ -X_BEARER_TOKEN= \ No newline at end of file +X_BEARER_TOKEN= +AI_MODEL= +OPENAPI_TOKEN= +CLAUDE_TOKEN= \ No newline at end of file diff --git a/etl/pyproject.toml b/etl/pyproject.toml index cc45a14..41d9b91 100644 --- a/etl/pyproject.toml +++ b/etl/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "requests", "pytest", "prefect", - "stringcase" + "stringcase", + "openai" ] [project.urls] diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py new file mode 100644 index 0000000..dd031c4 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py @@ -0,0 +1,13 @@ +from birdxplorer_etl.settings import AI_MODEL +from birdxplorer_etl.lib.openapi.open_ai_service import OpenAIService +from birdxplorer_etl.lib.claude.claude_service import ClaudeService +from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface + + +def get_ai_service() -> AIModelInterface: + if AI_MODEL == "openai": + return OpenAIService() + elif AI_MODEL == "claude": + return ClaudeService() + else: + raise ValueError(f"Unsupported AI service: {AI_MODEL}") diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py new file mode 100644 index 0000000..bed759d --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py @@ -0,0 +1,4 @@ +class AIModelInterface: + def detect_language(self, text: str) -> str: + raise NotImplementedError("langdetect method not implemented") + diff --git a/etl/src/birdxplorer_etl/lib/claude/claude_service.py b/etl/src/birdxplorer_etl/lib/claude/claude_service.py new file mode 100644 index 0000000..1ecfe26 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/claude/claude_service.py @@ -0,0 +1,7 @@ +from birdxplorer_etl.settings import CLAUDE_TOKEN +from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface + + +class ClaudeService(AIModelInterface): + def __init__(self): + self.api_key = CLAUDE_TOKEN diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py new file mode 100644 index 0000000..721e570 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -0,0 +1,37 @@ +from birdxplorer_etl.settings import OPENAPI_TOKEN +from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface +from birdxplorer_common.storage import LanguageIdentifier +from openai import OpenAI + + +class OpenAIService(AIModelInterface): + def __init__(self): + self.api_key = OPENAPI_TOKEN + self.client = OpenAI( + api_key=self.api_key + ) + + def detect_language(self, text: str) -> str: + prompt = ( + "Detect the language of the following text and return only the language code " + f"from this list: en, es, ja, pt, de, fr. Text: {text}. " + "Respond with only the language code, nothing else." + ) + + response = self.client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ], + max_tokens=30 + ) + message_content = response.choices[0].message.content.strip() + + if message_content not in LanguageIdentifier._value2member_map_: + for code in LanguageIdentifier._value2member_map_: + if code in message_content: + return LanguageIdentifier(code) + raise ValueError(f"Invalid language code received: {message_content}") + else: + return LanguageIdentifier(message_content) diff --git a/etl/src/birdxplorer_etl/settings.py b/etl/src/birdxplorer_etl/settings.py index 30a4e6e..84b50d7 100644 --- a/etl/src/birdxplorer_etl/settings.py +++ b/etl/src/birdxplorer_etl/settings.py @@ -10,3 +10,6 @@ COMMUNITY_NOTE_DAYS_AGO = 3 X_BEARER_TOKEN = os.getenv("X_BEARER_TOKEN") +AI_MODEL = os.getenv("AI_MODEL") +OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN") +CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN") diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index 40cd00b..b1ab7d3 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -2,6 +2,7 @@ from sqlalchemy import select, func from sqlalchemy.orm import Session from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord +from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service import csv import os @@ -107,4 +108,32 @@ def transform_data(db: Session): writer.writerow(user) offset += limit + csv_seed_file_path = './seed/topic_seed.csv' + output_csv_file_path = "./data/transformed/topic.csv" + records = [] + ai_service = get_ai_service() + + if os.path.exists(output_csv_file_path): + return + + with open(csv_seed_file_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for index, row in enumerate(reader): + if 'topic' in row and row['topic']: + topic_id = index + 1 + language_identifier = ai_service.detect_language(row['topic']) + label = {language_identifier: row['topic']} # Assuming the label is in Japanese + record = {"topic_id": topic_id, "label": label} + records.append(record) + + with open(output_csv_file_path, "w", newline='', encoding='utf-8') as file: + fieldnames = ['topic_id', 'label'] + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + for record in records: + writer.writerow({ + 'topic_id': record["topic_id"], + 'label': {k.value: v for k, v in record["label"].items()} + }) + return From b04dfca059ae632ec39b38dbae4efdfe39ae820a Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Wed, 24 Jul 2024 16:42:44 +0900 Subject: [PATCH 03/12] fix: refactor apply review --- .../lib/ai_model/ai_model_interface_base.py | 2 +- .../lib/openapi/open_ai_service.py | 17 +++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py index bed759d..cfb5570 100644 --- a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py @@ -1,4 +1,4 @@ class AIModelInterface: def detect_language(self, text: str) -> str: - raise NotImplementedError("langdetect method not implemented") + raise NotImplementedError("detect_language method not implemented") diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index 721e570..045cc20 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -24,14 +24,19 @@ def detect_language(self, text: str) -> str: {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt} ], + temperature=0.0, + seed=1, max_tokens=30 ) + message_content = response.choices[0].message.content.strip() - if message_content not in LanguageIdentifier._value2member_map_: - for code in LanguageIdentifier._value2member_map_: - if code in message_content: - return LanguageIdentifier(code) - raise ValueError(f"Invalid language code received: {message_content}") - else: + if message_content in LanguageIdentifier._value2member_map_: return LanguageIdentifier(message_content) + + valid_code = next((code for code in LanguageIdentifier._value2member_map_ if code in message_content), None) + + if valid_code: + return LanguageIdentifier(valid_code) + + raise ValueError(f"Invali 3d language code received: {message_content}") \ No newline at end of file From 6af9d6737eac10b2be8b786dce220967a1ce53bb Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Thu, 25 Jul 2024 16:35:16 +0900 Subject: [PATCH 04/12] feat: crete generate_note_topic method --- etl/seed/fewshot_sample.json | 7 +++ .../lib/ai_model/ai_model_interface_base.py | 5 ++ .../lib/openapi/open_ai_service.py | 61 ++++++++++++++++++- etl/src/birdxplorer_etl/transform.py | 38 ++++++++++++ 4 files changed, 110 insertions(+), 1 deletion(-) create mode 100644 etl/seed/fewshot_sample.json diff --git a/etl/seed/fewshot_sample.json b/etl/seed/fewshot_sample.json new file mode 100644 index 0000000..1e8e4bf --- /dev/null +++ b/etl/seed/fewshot_sample.json @@ -0,0 +1,7 @@ +{ + "tweet": "For those that care — 432 hz improves mental clarity, removes emotional blockages, reduces stress and anxiety, better sleep quality, increases creativity & inspiration, and strengthens the immune system. Play it while you sleep & watch these areas improve!", + "note": "There are no placebo controlled studies which support this. There is no evidence that this frequency has different effects from any other arbitrary frequency. https://ask.audio/articles/music-theory-432-hz-tuning-separating-fact-from-fiction", + "topics": [ + "医療", "福祉" + ] +} \ No newline at end of file diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py index cfb5570..34f4984 100644 --- a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py @@ -1,4 +1,9 @@ +from typing import Dict, List + + class AIModelInterface: def detect_language(self, text: str) -> str: raise NotImplementedError("detect_language method not implemented") + def detect_topic(self, note_id: int, note: str) -> Dict[str, List[str]]: + raise NotImplementedError("detect_topic method not implemented") diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index 045cc20..beeeab3 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -2,6 +2,9 @@ from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface from birdxplorer_common.storage import LanguageIdentifier from openai import OpenAI +from typing import Dict, List +import csv +import json class OpenAIService(AIModelInterface): @@ -10,6 +13,18 @@ def __init__(self): self.client = OpenAI( api_key=self.api_key ) + self.topics = self.load_topics('./data/transformed/topic.csv') + + def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]: + topics = {} + with open(topic_csv_file_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + topic_id = int(row['topic_id']) + labels = json.loads(row['label'].replace("'", '"')) + for label in labels.values(): + topics[label] = topic_id + return topics def detect_language(self, text: str) -> str: prompt = ( @@ -39,4 +54,48 @@ def detect_language(self, text: str) -> str: if valid_code: return LanguageIdentifier(valid_code) - raise ValueError(f"Invali 3d language code received: {message_content}") \ No newline at end of file + raise ValueError(f"Invalid language code received: {message_content}") + + def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: + topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()]) + with open('./seed/fewshot_sample.json', newline='', encoding='utf-8') as f: + fewshot_sample = json.load(f) + + prompt = f""" + 以下はコミュニティノートです。 + コミュニティノート: + ``` + {fewshot_sample["note"]} + ``` + このセットに対してのトピックは「{" ".join(fewshot_sample["topics"])}」です。 + これを踏まえて、以下のセットに対して同じ粒度で複数のトピック(少なくとも3つ)を提示してください。 + コミュニティノート: + ``` + {note} + ``` + 以下のトピックは、 + ``` + topic: topic_id + ``` + の形で構成されています。 + こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。 + また指定された情報以外は含めないでください。 + + トピックの例: + {topic_examples} + """ + response = self.client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ], + temperature=0.0, + ) + response_text = response.choices[0].message.content.strip() + response_text = response_text.replace('```json', '').replace('```', '').strip() + try: + return json.loads(response_text) + except json.JSONDecodeError as e: + print(f'Error decoding JSON: {e}') + return {} diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index b1ab7d3..53423c3 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -137,3 +137,41 @@ def transform_data(db: Session): }) return + +def generate_note_topic(): + note_csv_file_path = './data/transformed/note.csv' + output_csv_file_path = './data/transformed/note_topic_association.csv' + ai_service = get_ai_service() + + records = [] + with open(output_csv_file_path, "w", newline='', encoding='utf-8', buffering=1) as file: + fieldnames = ['note_id', 'topic_id'] + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + + with open(note_csv_file_path, newline='', encoding='utf-8') as csvfile: + reader = csv.DictReader(csvfile) + for index, row in enumerate(reader): + note_id = row['note_id'] + summary = row['summary'] + topics_info = ai_service.detect_topic(note_id, summary) + if topics_info: + for topic in topics_info.get('topics', []): + record = {"note_id": note_id, "topic_id": topic} + records.append(record) + + if index % 100 == 0: + for record in records: + writer.writerow({ + 'note_id': record["note_id"], + 'topic_id': record["topic_id"], + }) + records = [] + + for record in records: + writer.writerow({ + 'note_id': record["note_id"], + 'topic_id': record["topic_id"], + }) + + print(f"New CSV file has been created at {output_csv_file_path}") From 165e8776ac736972b48c4d9e3e2fa70d9b224f0e Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:35:05 +0900 Subject: [PATCH 05/12] feat: add detect language to note --- common/birdxplorer_common/models.py | 13 ++ etl/.env.example | 4 +- etl/seed/topic_seed.csv | 187 ++++++------------ .../lib/openapi/open_ai_service.py | 5 +- etl/src/birdxplorer_etl/transform.py | 50 +++-- 5 files changed, 113 insertions(+), 146 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index 5d19baa..a992edd 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -608,6 +608,19 @@ class LanguageIdentifier(str, Enum): PT = "pt" DE = "de" FR = "fr" + FI = "fi" + TR = "tr" + NL = "nl" + HE = "he" + IT = "it" + FA = "fa" + CA = "ca" + AR = "ar" + EL = "el" + SV = "sv" + DA = "da" + RU = "ru" + PL = "pl" class TopicLabelString(NonEmptyTrimmedString): ... diff --git a/etl/.env.example b/etl/.env.example index d0e408f..112d707 100644 --- a/etl/.env.example +++ b/etl/.env.example @@ -1,4 +1,6 @@ X_BEARER_TOKEN= AI_MODEL= OPENAPI_TOKEN= -CLAUDE_TOKEN= \ No newline at end of file +CLAUDE_TOKEN= +TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND=1720900800000 +TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND=1722110400000 \ No newline at end of file diff --git a/etl/seed/topic_seed.csv b/etl/seed/topic_seed.csv index 246c1a9..9c28f8b 100644 --- a/etl/seed/topic_seed.csv +++ b/etl/seed/topic_seed.csv @@ -1,123 +1,64 @@ -topic -政治 -メディア -人権 -社会 -国際関係 -医療 -エンターテイメント -犯罪 -法律 -経済 -移民 -テロリズム -紛争 -スポーツ -宗教 -科学 -歴史 -教育 -環境 -テクノロジー -健康 -軍事 -文化 -詐欺 -ビジネス -ジェンダー -交通 -福祉 -技術 -戦争 -社会問題 -ソーシャルメディア -言語 -労働 -エネルギー -インターネット -音楽 -家族 -治安 -人種差別 -選挙 -風刺 -陰謀論 -フェイクニュース -ニュース -広告 -人種 -コミュニケーション -映画 -宇宙 -自動車 -地理 -ワクチン -安全保障 -観光 -中東 -国際情勢 -COVID-19 -農業 -感染症 -ファッション -メンタルヘルス -SNS -雇用 -自然 -暴動 -プライバシー -子供 -動物愛護 -生物学 -金融 -美容 -差別 -交通事故 -消費者保護 -災害 -アニメ -気候変動 -食品 -AI -建築 -人間関係 -住宅 -心理学 -国家 -料理 -国際政治 -子育て -民主主義 -偽情報 -アート -都市計画 -性暴力 -投資 -天文学 -テロ -社会運動 -税金 -薬物 -司法 -外交 -言論 -家庭 -腐敗 -戦争犯罪 -憲法 -オンラインセキュリティ -財政 -国際法 -災害 -人道支援 -中国 -性教育 -貧困 -国境管理 -情報操作 -暗号通貨 -環境保護 -イスラエル -欧州連合 EU -データ -フェミニズム \ No newline at end of file +en,ja +European Union,欧州連合 +coronavirus,コロナウイルス +Crimea,クリミア +Coup,クーデター +Sanctions,制裁 +Terrorism,テロ +Sovereignty,主権 +Eastern Ukraine,ウクライナ東部 +Syrian War,シリア戦争 +Chemical weapons/attack,化学兵器/攻撃 +Elections,選挙 +Protest,抗議する +WWII,第二次世界大戦 +Manipulated elections/referendum,操作された選挙/国民投票 +Vladimir Putin,ウラジーミル・プーチン +Migration crisis,移民の危機 +Ukrainian disintegration,ウクライナ崩壊 +Nuclear issues,核問題 +Imperialism/colonialism,帝国主義・植民地主義 +Economic difficulties,経済的困難 +vaccination,予防接種 +Biological weapons,生物兵器 +Donald Trump,ドナルド・トランプ +election meddling,選挙介入 +Media,メディア +security threat,セキュリティ上の脅威 +Joe Biden,ジョー・バイデン +Human rights,人権 +Democracy,民主主義 +Propaganda,プロパガンダ +Civil war,内戦 +Freedom of speech,言論の自由 +Military exercise,軍事演習 +LGBT,LGBT +Information war,情報戦 +Genocide,大量虐殺 +Sputnik V,スプートニクV +economy,経済 +War crimes,戦争犯罪 +Intelligence services,諜報機関 +Energy,エネルギー +Occupation,職業 +UN,国連 +migration,移民・移住 +Corruption,腐敗 +laboratory,研究室 +Censorship,検閲 +Refugees,難民 +fake news,フェイクニュース +scam,詐欺 +technology,テクノロジー +welfare,福祉 +mobility,交通 +travel,観光 +fashion,ファッション +mental health,メンタルヘルス +anime,アニメ +AI,AI +climate change,気候変動 +food,食品 +tax,税金 +drugs,薬物 +US presidential election,米国大統領選挙 \ No newline at end of file diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index beeeab3..658afb3 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -41,7 +41,6 @@ def detect_language(self, text: str) -> str: ], temperature=0.0, seed=1, - max_tokens=30 ) message_content = response.choices[0].message.content.strip() @@ -54,7 +53,9 @@ def detect_language(self, text: str) -> str: if valid_code: return LanguageIdentifier(valid_code) - raise ValueError(f"Invalid language code received: {message_content}") + print(f"Invalid language code received: {message_content}") + # raise ValueError(f"Invalid language code received: {message_content}") + return message_content def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()]) diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index 53423c3..d0415fc 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -1,8 +1,10 @@ import logging -from sqlalchemy import select, func +from sqlalchemy import select, func, and_ from sqlalchemy.orm import Session from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service +from birdxplorer_etl.settings import (TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, + TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND) import csv import os @@ -19,24 +21,34 @@ def transform_data(db: Session): offset = 0 limit = 1000 + ai_service = get_ai_service() - num_of_notes = db.query(func.count(RowNoteRecord.note_id)).scalar() - - while offset < num_of_notes: - notes = db.execute( - select( - RowNoteRecord.note_id, RowNoteRecord.row_post_id, RowNoteRecord.summary, RowNoteRecord.created_at_millis + num_of_notes = (db.query(func.count(RowNoteRecord.note_id)) + .filter(and_(RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, + RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND)) + .scalar()) + + with open("./data/transformed/note.csv", "a") as file: + writer = csv.writer(file) + writer.writerow(["note_id", "post_id", "summary", "created_at", "language"]) + + while offset < num_of_notes: + notes = db.execute( + select( + RowNoteRecord.note_id, RowNoteRecord.row_post_id, + RowNoteRecord.summary, RowNoteRecord.created_at_millis + ) + .filter(and_(RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, + RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND)) + .limit(limit) + .offset(offset) ) - .limit(limit) - .offset(offset) - ) - with open("./data/transformed/note.csv", "a") as file: - writer = csv.writer(file) - writer.writerow(["note_id", "post_id", "summary", "created_at"]) for note in notes: - writer.writerow(note) - offset += limit + note_as_list = list(note) + note_as_list.append(ai_service.detect_language(note[2])) + writer.writerow(note_as_list) + offset += limit # Transform row post data and generate post.csv if os.path.exists("./data/transformed/post.csv"): @@ -111,7 +123,6 @@ def transform_data(db: Session): csv_seed_file_path = './seed/topic_seed.csv' output_csv_file_path = "./data/transformed/topic.csv" records = [] - ai_service = get_ai_service() if os.path.exists(output_csv_file_path): return @@ -119,10 +130,9 @@ def transform_data(db: Session): with open(csv_seed_file_path, newline='', encoding='utf-8') as csvfile: reader = csv.DictReader(csvfile) for index, row in enumerate(reader): - if 'topic' in row and row['topic']: + if 'ja' in row and row['ja']: topic_id = index + 1 - language_identifier = ai_service.detect_language(row['topic']) - label = {language_identifier: row['topic']} # Assuming the label is in Japanese + label = {'ja': row['ja'], 'en': row['en']} # Assuming the label is in Japanese record = {"topic_id": topic_id, "label": label} records.append(record) @@ -133,7 +143,7 @@ def transform_data(db: Session): for record in records: writer.writerow({ 'topic_id': record["topic_id"], - 'label': {k.value: v for k, v in record["label"].items()} + 'label': {k: v for k, v in record["label"].items()} }) return From 719343c317212bbc3917b5e649ba5f1dea63dafe Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:08:22 +0900 Subject: [PATCH 06/12] feat: add generate_note_topic --- .../birdxplorer_etl/lib/openapi/open_ai_service.py | 13 ++++++++----- etl/src/birdxplorer_etl/settings.py | 2 ++ etl/src/birdxplorer_etl/transform.py | 6 ++++++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index 658afb3..d97f430 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -22,8 +22,11 @@ def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]: for row in reader: topic_id = int(row['topic_id']) labels = json.loads(row['label'].replace("'", '"')) - for label in labels.values(): - topics[label] = topic_id + # 日本語のラベルのみを使用するように + if 'ja' in labels: + topics[labels['ja']] = topic_id + # for label in labels.values(): + # topics[label] = topic_id return topics def detect_language(self, text: str) -> str: @@ -34,7 +37,7 @@ def detect_language(self, text: str) -> str: ) response = self.client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-4o-mini", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt} @@ -81,12 +84,12 @@ def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: の形で構成されています。 こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。 また指定された情報以外は含めないでください。 - + トピックの例: {topic_examples} """ response = self.client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-4o-mini", messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt} diff --git a/etl/src/birdxplorer_etl/settings.py b/etl/src/birdxplorer_etl/settings.py index 84b50d7..e81325a 100644 --- a/etl/src/birdxplorer_etl/settings.py +++ b/etl/src/birdxplorer_etl/settings.py @@ -13,3 +13,5 @@ AI_MODEL = os.getenv("AI_MODEL") OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN") CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN") +TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND") +TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND") diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index d0415fc..948ece1 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -148,6 +148,7 @@ def transform_data(db: Session): return + def generate_note_topic(): note_csv_file_path = './data/transformed/note.csv' output_csv_file_path = './data/transformed/note_topic_association.csv' @@ -177,6 +178,7 @@ def generate_note_topic(): 'topic_id': record["topic_id"], }) records = [] + print(index) for record in records: writer.writerow({ @@ -185,3 +187,7 @@ def generate_note_topic(): }) print(f"New CSV file has been created at {output_csv_file_path}") + + +if __name__ == "__main__": + generate_note_topic() From 7f51d4700beade5fc0a8178b23d3965891bcee64 Mon Sep 17 00:00:00 2001 From: ayuki_j <19406594+ayuki-joto@users.noreply.github.com> Date: Wed, 14 Aug 2024 16:24:24 +0900 Subject: [PATCH 07/12] feat: use normalize --- etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index d97f430..59f4e79 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -1,6 +1,6 @@ from birdxplorer_etl.settings import OPENAPI_TOKEN from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface -from birdxplorer_common.storage import LanguageIdentifier +from birdxplorer_common.models import LanguageIdentifier from openai import OpenAI from typing import Dict, List import csv @@ -58,7 +58,7 @@ def detect_language(self, text: str) -> str: print(f"Invalid language code received: {message_content}") # raise ValueError(f"Invalid language code received: {message_content}") - return message_content + return LanguageIdentifier.normalize(message_content) def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()]) From 69164a9c0f32e3de819f0a62ffa771f47a0efc7b Mon Sep 17 00:00:00 2001 From: yu23ki14 Date: Fri, 16 Aug 2024 17:20:20 +0900 Subject: [PATCH 08/12] modify for fto --- common/birdxplorer_common/models.py | 4 +- common/tests/conftest.py | 10 +- .../lib/openapi/open_ai_service.py | 28 ++-- etl/src/birdxplorer_etl/transform.py | 147 ++++++++++-------- 4 files changed, 111 insertions(+), 78 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index eaf1837..f32b620 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -134,7 +134,7 @@ class UpToNineteenDigitsDecimalString(BaseString): Traceback (most recent call last): ... pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] - String should match pattern '^[0-9]{1,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str] + String should match pattern '^([0-9]{0}|[0-9]{19})$' [type=string_pattern_mismatch, input_value='test', input_type=str] ... >>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789") UpToNineteenDigitsDecimalString('1234567890123456789') @@ -142,7 +142,7 @@ class UpToNineteenDigitsDecimalString(BaseString): @classmethod def __get_extra_constraint_dict__(cls) -> dict[str, Any]: - return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{1,19}$") + return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{0}|[0-9]{19})$") class NonEmptyStringMixin(BaseString): diff --git a/common/tests/conftest.py b/common/tests/conftest.py index a2b1094..71ab36d 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -109,7 +109,7 @@ def user_enrollment_samples( @fixture def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: topics = [ - topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=3), + topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=4), topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=2), topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1), topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=0), @@ -160,6 +160,14 @@ def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Gener summary="summary5", created_at=1152921604000, ), + note_factory.build( + note_id="1234567890123456786", + post_id="", + topics=[topic_samples[0]], + language="en", + summary="summary6_empty_post_id", + created_at=1152921604000, + ), ] yield notes diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py index 59f4e79..048a4f5 100644 --- a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -5,26 +5,26 @@ from typing import Dict, List import csv import json +import os class OpenAIService(AIModelInterface): def __init__(self): self.api_key = OPENAPI_TOKEN - self.client = OpenAI( - api_key=self.api_key - ) - self.topics = self.load_topics('./data/transformed/topic.csv') + self.client = OpenAI(api_key=self.api_key) + if os.path.exists("./data/transformed/topic.csv"): + self.topics = self.load_topics("./data/transformed/topic.csv") def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]: topics = {} - with open(topic_csv_file_path, newline='', encoding='utf-8') as csvfile: + with open(topic_csv_file_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for row in reader: - topic_id = int(row['topic_id']) - labels = json.loads(row['label'].replace("'", '"')) + topic_id = int(row["topic_id"]) + labels = json.loads(row["label"].replace("'", '"')) # 日本語のラベルのみを使用するように - if 'ja' in labels: - topics[labels['ja']] = topic_id + if "ja" in labels: + topics[labels["ja"]] = topic_id # for label in labels.values(): # topics[label] = topic_id return topics @@ -40,7 +40,7 @@ def detect_language(self, text: str) -> str: model="gpt-4o-mini", messages=[ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ], temperature=0.0, seed=1, @@ -62,7 +62,7 @@ def detect_language(self, text: str) -> str: def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()]) - with open('./seed/fewshot_sample.json', newline='', encoding='utf-8') as f: + with open("./seed/fewshot_sample.json", newline="", encoding="utf-8") as f: fewshot_sample = json.load(f) prompt = f""" @@ -92,14 +92,14 @@ def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: model="gpt-4o-mini", messages=[ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ], temperature=0.0, ) response_text = response.choices[0].message.content.strip() - response_text = response_text.replace('```json', '').replace('```', '').strip() + response_text = response_text.replace("```json", "").replace("```", "").strip() try: return json.loads(response_text) except json.JSONDecodeError as e: - print(f'Error decoding JSON: {e}') + print(f"Error decoding JSON: {e}") return {} diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index 948ece1..acefbc2 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -1,16 +1,19 @@ -import logging -from sqlalchemy import select, func, and_ +from sqlalchemy import select, func, and_, Integer from sqlalchemy.orm import Session from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service -from birdxplorer_etl.settings import (TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, - TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND) +from birdxplorer_etl.settings import ( + TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, + TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, +) import csv import os +from prefect import get_run_logger def transform_data(db: Session): - logging.info("Transforming data") + logger = get_run_logger() + logger.info("Transforming data") if not os.path.exists("./data/transformed"): os.makedirs("./data/transformed") @@ -18,28 +21,42 @@ def transform_data(db: Session): # Transform row note data and generate note.csv if os.path.exists("./data/transformed/note.csv"): os.remove("./data/transformed/note.csv") + with open("./data/transformed/note.csv", "a") as file: + writer = csv.writer(file) + writer.writerow(["note_id", "post_id", "summary", "created_at", "language"]) offset = 0 limit = 1000 ai_service = get_ai_service() - num_of_notes = (db.query(func.count(RowNoteRecord.note_id)) - .filter(and_(RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, - RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND)) - .scalar()) + num_of_notes = ( + db.query(func.count(RowNoteRecord.note_id)) + .filter( + and_( + RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, + RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, + ) + ) + .scalar() + ) with open("./data/transformed/note.csv", "a") as file: - writer = csv.writer(file) - writer.writerow(["note_id", "post_id", "summary", "created_at", "language"]) + logger.info(f"Transforming note data: {num_of_notes}") while offset < num_of_notes: notes = db.execute( select( - RowNoteRecord.note_id, RowNoteRecord.row_post_id, - RowNoteRecord.summary, RowNoteRecord.created_at_millis + RowNoteRecord.note_id, + RowNoteRecord.row_post_id, + RowNoteRecord.summary, + func.cast(RowNoteRecord.created_at_millis, Integer).label("created_at"), + ) + .filter( + and_( + RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, + RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, + ) ) - .filter(and_(RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, - RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND)) .limit(limit) .offset(offset) ) @@ -47,12 +64,18 @@ def transform_data(db: Session): for note in notes: note_as_list = list(note) note_as_list.append(ai_service.detect_language(note[2])) + writer = csv.writer(file) writer.writerow(note_as_list) offset += limit # Transform row post data and generate post.csv + logger.info("Transforming post data") + if os.path.exists("./data/transformed/post.csv"): os.remove("./data/transformed/post.csv") + with open("./data/transformed/post.csv", "a") as file: + writer = csv.writer(file) + writer.writerow(["post_id", "user_id", "text", "created_at", "like_count", "repost_count", "impression_count"]) offset = 0 limit = 1000 @@ -65,27 +88,35 @@ def transform_data(db: Session): RowPostRecord.post_id, RowPostRecord.author_id.label("user_id"), RowPostRecord.text, - RowPostRecord.created_at, - RowPostRecord.like_count, - RowPostRecord.repost_count, - RowPostRecord.impression_count, + func.cast(RowPostRecord.created_at, Integer).label("created_at"), + func.cast(RowPostRecord.like_count, Integer).label("like_count"), + func.cast(RowPostRecord.repost_count, Integer).label("repost_count"), + func.cast(RowPostRecord.impression_count, Integer).label("impression_count"), ) .limit(limit) .offset(offset) ) with open("./data/transformed/post.csv", "a") as file: - writer = csv.writer(file) - writer.writerow( - ["post_id", "user_id", "text", "created_at", "like_count", "repost_count", "impression_count"] - ) for post in posts: + writer = csv.writer(file) writer.writerow(post) offset += limit # Transform row user data and generate user.csv if os.path.exists("./data/transformed/user.csv"): os.remove("./data/transformed/user.csv") + with open("./data/transformed/user.csv", "a") as file: + writer = csv.writer(file) + writer.writerow( + [ + "user_id", + "name", + "profile_image", + "followers_count", + "following_count", + ] + ) offset = 0 limit = 1000 @@ -98,93 +129,87 @@ def transform_data(db: Session): RowUserRecord.user_id, RowUserRecord.user_name.label("name"), RowUserRecord.profile_image_url.label("profile_image"), - RowUserRecord.followers_count, - RowUserRecord.following_count, + func.cast(RowUserRecord.followers_count, Integer).label("followers_count"), + func.cast(RowUserRecord.following_count, Integer).label("following_count"), ) .limit(limit) .offset(offset) ) with open("./data/transformed/user.csv", "a") as file: - writer = csv.writer(file) - writer.writerow( - [ - "user_id", - "name", - "profile_image", - "followers_count", - "following_count", - ] - ) for user in users: + writer = csv.writer(file) writer.writerow(user) offset += limit - csv_seed_file_path = './seed/topic_seed.csv' + csv_seed_file_path = "./seed/topic_seed.csv" output_csv_file_path = "./data/transformed/topic.csv" records = [] if os.path.exists(output_csv_file_path): return - with open(csv_seed_file_path, newline='', encoding='utf-8') as csvfile: + with open(csv_seed_file_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for index, row in enumerate(reader): - if 'ja' in row and row['ja']: + if "ja" in row and row["ja"]: topic_id = index + 1 - label = {'ja': row['ja'], 'en': row['en']} # Assuming the label is in Japanese + label = {"ja": row["ja"], "en": row["en"]} # Assuming the label is in Japanese record = {"topic_id": topic_id, "label": label} records.append(record) - with open(output_csv_file_path, "w", newline='', encoding='utf-8') as file: - fieldnames = ['topic_id', 'label'] + with open(output_csv_file_path, "a", newline="", encoding="utf-8") as file: + fieldnames = ["topic_id", "label"] writer = csv.DictWriter(file, fieldnames=fieldnames) writer.writeheader() for record in records: - writer.writerow({ - 'topic_id': record["topic_id"], - 'label': {k: v for k, v in record["label"].items()} - }) + writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}}) + + generate_note_topic() return def generate_note_topic(): - note_csv_file_path = './data/transformed/note.csv' - output_csv_file_path = './data/transformed/note_topic_association.csv' + note_csv_file_path = "./data/transformed/note.csv" + output_csv_file_path = "./data/transformed/note_topic_association.csv" ai_service = get_ai_service() records = [] - with open(output_csv_file_path, "w", newline='', encoding='utf-8', buffering=1) as file: - fieldnames = ['note_id', 'topic_id'] + with open(output_csv_file_path, "w", newline="", encoding="utf-8", buffering=1) as file: + fieldnames = ["note_id", "topic_id"] writer = csv.DictWriter(file, fieldnames=fieldnames) writer.writeheader() - with open(note_csv_file_path, newline='', encoding='utf-8') as csvfile: + with open(note_csv_file_path, newline="", encoding="utf-8") as csvfile: reader = csv.DictReader(csvfile) for index, row in enumerate(reader): - note_id = row['note_id'] - summary = row['summary'] + note_id = row["note_id"] + summary = row["summary"] topics_info = ai_service.detect_topic(note_id, summary) if topics_info: - for topic in topics_info.get('topics', []): + for topic in topics_info.get("topics", []): record = {"note_id": note_id, "topic_id": topic} records.append(record) if index % 100 == 0: for record in records: - writer.writerow({ - 'note_id': record["note_id"], - 'topic_id': record["topic_id"], - }) + writer.writerow( + { + "note_id": record["note_id"], + "topic_id": record["topic_id"], + } + ) records = [] print(index) for record in records: - writer.writerow({ - 'note_id': record["note_id"], - 'topic_id': record["topic_id"], - }) + writer.writerow( + { + "note_id": record["note_id"], + "topic_id": record["topic_id"], + } + ) print(f"New CSV file has been created at {output_csv_file_path}") From 2712c1ae9d6a6f9cf8ce73b52e51886840fb36de Mon Sep 17 00:00:00 2001 From: yu23ki14 Date: Fri, 16 Aug 2024 17:33:44 +0900 Subject: [PATCH 09/12] fix test code --- common/birdxplorer_common/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index f32b620..fcd6119 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -134,7 +134,7 @@ class UpToNineteenDigitsDecimalString(BaseString): Traceback (most recent call last): ... pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] - String should match pattern '^([0-9]{0}|[0-9]{19})$' [type=string_pattern_mismatch, input_value='test', input_type=str] + String should match pattern '^([0-9]{19}|)$' [type=string_pattern_mismatch, input_value='test', input_type=str] ... >>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789") UpToNineteenDigitsDecimalString('1234567890123456789') @@ -142,7 +142,7 @@ class UpToNineteenDigitsDecimalString(BaseString): @classmethod def __get_extra_constraint_dict__(cls) -> dict[str, Any]: - return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{0}|[0-9]{19})$") + return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{19}|)$") class NonEmptyStringMixin(BaseString): From f61440e043c87f7baab7b24c79b4d7036ab93a96 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 17 Aug 2024 10:38:34 +0900 Subject: [PATCH 10/12] test(api): add test for limit and offset query --- api/tests/routers/test_data.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index ddd29ad..82833d3 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -27,6 +27,13 @@ def test_posts_get(client: TestClient, post_samples: List[Post]) -> None: assert res_json == {"data": [json.loads(d.model_dump_json()) for d in post_samples]} +def test_posts_get_limit_and_offset(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?limit=2&offset=1") + assert response.status_code == 200 + res_json = response.json() + assert res_json == {"data": [json.loads(d.model_dump_json()) for d in post_samples[1:3]]} + + def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Post]) -> None: response = client.get(f"/api/v1/data/posts/?postId={post_samples[0].post_id},{post_samples[2].post_id}") assert response.status_code == 200 From 39595f9b590d37453d5581c1ae32e31c349f7a9b Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 17 Aug 2024 10:39:09 +0900 Subject: [PATCH 11/12] fix: fix by flake8 --- api/birdxplorer_api/routers/data.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/api/birdxplorer_api/routers/data.py b/api/birdxplorer_api/routers/data.py index ee6f1a7..a0a060b 100644 --- a/api/birdxplorer_api/routers/data.py +++ b/api/birdxplorer_api/routers/data.py @@ -96,7 +96,7 @@ def get_posts( created_at_start: Union[None, TwitterTimestamp, str] = Query(default=None), created_at_end: Union[None, TwitterTimestamp, str] = Query(default=None), offset: int = Query(default=0, ge=0), # 確保 offset 是非負的 - limit: int = Query(default=100, gt=0, le=1000) # 確保 limit 在合理範圍內 + limit: int = Query(default=100, gt=0, le=1000), # 確保 limit 在合理範圍內 ) -> PostListResponse: posts = None @@ -120,8 +120,8 @@ def get_posts( posts = list(storage.get_posts()) total_count = len(posts) - paginated_posts = posts[offset:offset + limit] - base_url = str(request.url).split('?')[0] + paginated_posts = posts[offset : offset + limit] + base_url = str(request.url).split("?")[0] next_offset = offset + limit prev_offset = max(offset - limit, 0) next_url = None @@ -131,12 +131,6 @@ def get_posts( if offset > 0: prev_url = f"{base_url}?offset={prev_offset}&limit={limit}" - return PostListResponse( - data=paginated_posts, - meta={ - "next": next_url, - "prev": prev_url - } - ) + return PostListResponse(data=paginated_posts, meta={"next": next_url, "prev": prev_url}) return router From 6a6e3d1060a89567b9a0eab6874696ec7c2afcd4 Mon Sep 17 00:00:00 2001 From: osoken Date: Sat, 17 Aug 2024 12:22:08 +0900 Subject: [PATCH 12/12] fix: closes #93 --- common/birdxplorer_common/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index fcd6119..28e74f0 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -134,15 +134,17 @@ class UpToNineteenDigitsDecimalString(BaseString): Traceback (most recent call last): ... pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] - String should match pattern '^([0-9]{19}|)$' [type=string_pattern_mismatch, input_value='test', input_type=str] + String should match pattern '^([0-9]{18,19}|)$' [type=string_pattern_mismatch, input_value='test', input_type=str] ... >>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789") UpToNineteenDigitsDecimalString('1234567890123456789') + >>> UpToNineteenDigitsDecimalString.from_str("123456789012345678") + UpToNineteenDigitsDecimalString('123456789012345678') """ @classmethod def __get_extra_constraint_dict__(cls) -> dict[str, Any]: - return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{19}|)$") + return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{18,19}|)$") class NonEmptyStringMixin(BaseString):