diff --git a/api/tests/routers/test_data.py b/api/tests/routers/test_data.py index 86356ab..2fe9ac6 100644 --- a/api/tests/routers/test_data.py +++ b/api/tests/routers/test_data.py @@ -30,6 +30,16 @@ def test_posts_get(client: TestClient, post_samples: List[Post]) -> None: } +def test_posts_get_limit_and_offset(client: TestClient, post_samples: List[Post]) -> None: + response = client.get("/api/v1/data/posts/?limit=2&offset=1") + assert response.status_code == 200 + res_json = response.json() + assert res_json == { + "data": [json.loads(d.model_dump_json()) for d in post_samples[1:3]], + "meta": {"next": None, "prev": "http://testserver/api/v1/data/posts?offset=0&limit=2"}, + } + + def test_posts_get_has_post_id_filter(client: TestClient, post_samples: List[Post]) -> None: response = client.get(f"/api/v1/data/posts/?postId={post_samples[0].post_id},{post_samples[2].post_id}") assert response.status_code == 200 diff --git a/common/birdxplorer_common/models.py b/common/birdxplorer_common/models.py index b5c2f87..06db325 100644 --- a/common/birdxplorer_common/models.py +++ b/common/birdxplorer_common/models.py @@ -134,15 +134,17 @@ class UpToNineteenDigitsDecimalString(BaseString): Traceback (most recent call last): ... pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str] - String should match pattern '^[0-9]{1,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str] + String should match pattern '^([0-9]{18,19}|)$' [type=string_pattern_mismatch, input_value='test', input_type=str] ... >>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789") UpToNineteenDigitsDecimalString('1234567890123456789') + >>> UpToNineteenDigitsDecimalString.from_str("123456789012345678") + UpToNineteenDigitsDecimalString('123456789012345678') """ @classmethod def __get_extra_constraint_dict__(cls) -> dict[str, Any]: - return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{1,19}$") + return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{18,19}|)$") class NonEmptyStringMixin(BaseString): @@ -608,6 +610,19 @@ class LanguageIdentifier(str, Enum): PT = "pt" DE = "de" FR = "fr" + FI = "fi" + TR = "tr" + NL = "nl" + HE = "he" + IT = "it" + FA = "fa" + CA = "ca" + AR = "ar" + EL = "el" + SV = "sv" + DA = "da" + RU = "ru" + PL = "pl" OTHER = "other" @classmethod diff --git a/common/tests/conftest.py b/common/tests/conftest.py index a2b1094..71ab36d 100644 --- a/common/tests/conftest.py +++ b/common/tests/conftest.py @@ -109,7 +109,7 @@ def user_enrollment_samples( @fixture def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]: topics = [ - topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=3), + topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=4), topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=2), topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1), topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=0), @@ -160,6 +160,14 @@ def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Gener summary="summary5", created_at=1152921604000, ), + note_factory.build( + note_id="1234567890123456786", + post_id="", + topics=[topic_samples[0]], + language="en", + summary="summary6_empty_post_id", + created_at=1152921604000, + ), ] yield notes diff --git a/etl/.env.example b/etl/.env.example index 3d2cf0e..112d707 100644 --- a/etl/.env.example +++ b/etl/.env.example @@ -1 +1,6 @@ -X_BEARER_TOKEN= \ No newline at end of file +X_BEARER_TOKEN= +AI_MODEL= +OPENAPI_TOKEN= +CLAUDE_TOKEN= +TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND=1720900800000 +TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND=1722110400000 \ No newline at end of file diff --git a/etl/pyproject.toml b/etl/pyproject.toml index cc45a14..41d9b91 100644 --- a/etl/pyproject.toml +++ b/etl/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "requests", "pytest", "prefect", - "stringcase" + "stringcase", + "openai" ] [project.urls] diff --git a/etl/seed/fewshot_sample.json b/etl/seed/fewshot_sample.json new file mode 100644 index 0000000..1e8e4bf --- /dev/null +++ b/etl/seed/fewshot_sample.json @@ -0,0 +1,7 @@ +{ + "tweet": "For those that care — 432 hz improves mental clarity, removes emotional blockages, reduces stress and anxiety, better sleep quality, increases creativity & inspiration, and strengthens the immune system. Play it while you sleep & watch these areas improve!", + "note": "There are no placebo controlled studies which support this. There is no evidence that this frequency has different effects from any other arbitrary frequency. https://ask.audio/articles/music-theory-432-hz-tuning-separating-fact-from-fiction", + "topics": [ + "医療", "福祉" + ] +} \ No newline at end of file diff --git a/etl/seed/topic_seed.csv b/etl/seed/topic_seed.csv new file mode 100644 index 0000000..9c28f8b --- /dev/null +++ b/etl/seed/topic_seed.csv @@ -0,0 +1,64 @@ +en,ja +European Union,欧州連合 +coronavirus,コロナウイルス +Crimea,クリミア +Coup,クーデター +Sanctions,制裁 +Terrorism,テロ +Sovereignty,主権 +Eastern Ukraine,ウクライナ東部 +Syrian War,シリア戦争 +Chemical weapons/attack,化学兵器/攻撃 +Elections,選挙 +Protest,抗議する +WWII,第二次世界大戦 +Manipulated elections/referendum,操作された選挙/国民投票 +Vladimir Putin,ウラジーミル・プーチン +Migration crisis,移民の危機 +Ukrainian disintegration,ウクライナ崩壊 +Nuclear issues,核問題 +Imperialism/colonialism,帝国主義・植民地主義 +Economic difficulties,経済的困難 +vaccination,予防接種 +Biological weapons,生物兵器 +Donald Trump,ドナルド・トランプ +election meddling,選挙介入 +Media,メディア +security threat,セキュリティ上の脅威 +Joe Biden,ジョー・バイデン +Human rights,人権 +Democracy,民主主義 +Propaganda,プロパガンダ +Civil war,内戦 +Freedom of speech,言論の自由 +Military exercise,軍事演習 +LGBT,LGBT +Information war,情報戦 +Genocide,大量虐殺 +Sputnik V,スプートニクV +economy,経済 +War crimes,戦争犯罪 +Intelligence services,諜報機関 +Energy,エネルギー +Occupation,職業 +UN,国連 +migration,移民・移住 +Corruption,腐敗 +laboratory,研究室 +Censorship,検閲 +Refugees,難民 +fake news,フェイクニュース +scam,詐欺 +technology,テクノロジー +welfare,福祉 +mobility,交通 +travel,観光 +fashion,ファッション +mental health,メンタルヘルス +anime,アニメ +AI,AI +climate change,気候変動 +food,食品 +tax,税金 +drugs,薬物 +US presidential election,米国大統領選挙 \ No newline at end of file diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py new file mode 100644 index 0000000..dd031c4 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py @@ -0,0 +1,13 @@ +from birdxplorer_etl.settings import AI_MODEL +from birdxplorer_etl.lib.openapi.open_ai_service import OpenAIService +from birdxplorer_etl.lib.claude.claude_service import ClaudeService +from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface + + +def get_ai_service() -> AIModelInterface: + if AI_MODEL == "openai": + return OpenAIService() + elif AI_MODEL == "claude": + return ClaudeService() + else: + raise ValueError(f"Unsupported AI service: {AI_MODEL}") diff --git a/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py new file mode 100644 index 0000000..34f4984 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface_base.py @@ -0,0 +1,9 @@ +from typing import Dict, List + + +class AIModelInterface: + def detect_language(self, text: str) -> str: + raise NotImplementedError("detect_language method not implemented") + + def detect_topic(self, note_id: int, note: str) -> Dict[str, List[str]]: + raise NotImplementedError("detect_topic method not implemented") diff --git a/etl/src/birdxplorer_etl/lib/claude/claude_service.py b/etl/src/birdxplorer_etl/lib/claude/claude_service.py new file mode 100644 index 0000000..1ecfe26 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/claude/claude_service.py @@ -0,0 +1,7 @@ +from birdxplorer_etl.settings import CLAUDE_TOKEN +from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface + + +class ClaudeService(AIModelInterface): + def __init__(self): + self.api_key = CLAUDE_TOKEN diff --git a/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py new file mode 100644 index 0000000..048a4f5 --- /dev/null +++ b/etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py @@ -0,0 +1,105 @@ +from birdxplorer_etl.settings import OPENAPI_TOKEN +from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface +from birdxplorer_common.models import LanguageIdentifier +from openai import OpenAI +from typing import Dict, List +import csv +import json +import os + + +class OpenAIService(AIModelInterface): + def __init__(self): + self.api_key = OPENAPI_TOKEN + self.client = OpenAI(api_key=self.api_key) + if os.path.exists("./data/transformed/topic.csv"): + self.topics = self.load_topics("./data/transformed/topic.csv") + + def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]: + topics = {} + with open(topic_csv_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + topic_id = int(row["topic_id"]) + labels = json.loads(row["label"].replace("'", '"')) + # 日本語のラベルのみを使用するように + if "ja" in labels: + topics[labels["ja"]] = topic_id + # for label in labels.values(): + # topics[label] = topic_id + return topics + + def detect_language(self, text: str) -> str: + prompt = ( + "Detect the language of the following text and return only the language code " + f"from this list: en, es, ja, pt, de, fr. Text: {text}. " + "Respond with only the language code, nothing else." + ) + + response = self.client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + temperature=0.0, + seed=1, + ) + + message_content = response.choices[0].message.content.strip() + + if message_content in LanguageIdentifier._value2member_map_: + return LanguageIdentifier(message_content) + + valid_code = next((code for code in LanguageIdentifier._value2member_map_ if code in message_content), None) + + if valid_code: + return LanguageIdentifier(valid_code) + + print(f"Invalid language code received: {message_content}") + # raise ValueError(f"Invalid language code received: {message_content}") + return LanguageIdentifier.normalize(message_content) + + def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]: + topic_examples = "\n".join([f"{key}: {value}" for key, value in self.topics.items()]) + with open("./seed/fewshot_sample.json", newline="", encoding="utf-8") as f: + fewshot_sample = json.load(f) + + prompt = f""" + 以下はコミュニティノートです。 + コミュニティノート: + ``` + {fewshot_sample["note"]} + ``` + このセットに対してのトピックは「{" ".join(fewshot_sample["topics"])}」です。 + これを踏まえて、以下のセットに対して同じ粒度で複数のトピック(少なくとも3つ)を提示してください。 + コミュニティノート: + ``` + {note} + ``` + 以下のトピックは、 + ``` + topic: topic_id + ``` + の形で構成されています。 + こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。 + また指定された情報以外は含めないでください。 + + トピックの例: + {topic_examples} + """ + response = self.client.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + temperature=0.0, + ) + response_text = response.choices[0].message.content.strip() + response_text = response_text.replace("```json", "").replace("```", "").strip() + try: + return json.loads(response_text) + except json.JSONDecodeError as e: + print(f"Error decoding JSON: {e}") + return {} diff --git a/etl/src/birdxplorer_etl/settings.py b/etl/src/birdxplorer_etl/settings.py index 7a12c70..d6e9f2e 100644 --- a/etl/src/birdxplorer_etl/settings.py +++ b/etl/src/birdxplorer_etl/settings.py @@ -12,3 +12,8 @@ COMMUNITY_NOTE_DAYS_AGO = int(os.getenv("COMMUNITY_NOTE_DAYS_AGO", "3")) X_BEARER_TOKEN = os.getenv("X_BEARER_TOKEN") +AI_MODEL = os.getenv("AI_MODEL") +OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN") +CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN") +TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND") +TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND") diff --git a/etl/src/birdxplorer_etl/transform.py b/etl/src/birdxplorer_etl/transform.py index 40cd00b..acefbc2 100644 --- a/etl/src/birdxplorer_etl/transform.py +++ b/etl/src/birdxplorer_etl/transform.py @@ -1,13 +1,19 @@ -import logging -from sqlalchemy import select, func +from sqlalchemy import select, func, and_, Integer from sqlalchemy.orm import Session from birdxplorer_common.storage import RowNoteRecord, RowPostRecord, RowUserRecord +from birdxplorer_etl.lib.ai_model.ai_model_interface import get_ai_service +from birdxplorer_etl.settings import ( + TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, + TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, +) import csv import os +from prefect import get_run_logger def transform_data(db: Session): - logging.info("Transforming data") + logger = get_run_logger() + logger.info("Transforming data") if not os.path.exists("./data/transformed"): os.makedirs("./data/transformed") @@ -15,31 +21,61 @@ def transform_data(db: Session): # Transform row note data and generate note.csv if os.path.exists("./data/transformed/note.csv"): os.remove("./data/transformed/note.csv") + with open("./data/transformed/note.csv", "a") as file: + writer = csv.writer(file) + writer.writerow(["note_id", "post_id", "summary", "created_at", "language"]) offset = 0 limit = 1000 - - num_of_notes = db.query(func.count(RowNoteRecord.note_id)).scalar() - - while offset < num_of_notes: - notes = db.execute( - select( - RowNoteRecord.note_id, RowNoteRecord.row_post_id, RowNoteRecord.summary, RowNoteRecord.created_at_millis + ai_service = get_ai_service() + + num_of_notes = ( + db.query(func.count(RowNoteRecord.note_id)) + .filter( + and_( + RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, + RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, ) - .limit(limit) - .offset(offset) ) + .scalar() + ) + + with open("./data/transformed/note.csv", "a") as file: + + logger.info(f"Transforming note data: {num_of_notes}") + while offset < num_of_notes: + notes = db.execute( + select( + RowNoteRecord.note_id, + RowNoteRecord.row_post_id, + RowNoteRecord.summary, + func.cast(RowNoteRecord.created_at_millis, Integer).label("created_at"), + ) + .filter( + and_( + RowNoteRecord.created_at_millis <= TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND, + RowNoteRecord.created_at_millis >= TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND, + ) + ) + .limit(limit) + .offset(offset) + ) - with open("./data/transformed/note.csv", "a") as file: - writer = csv.writer(file) - writer.writerow(["note_id", "post_id", "summary", "created_at"]) for note in notes: - writer.writerow(note) - offset += limit + note_as_list = list(note) + note_as_list.append(ai_service.detect_language(note[2])) + writer = csv.writer(file) + writer.writerow(note_as_list) + offset += limit # Transform row post data and generate post.csv + logger.info("Transforming post data") + if os.path.exists("./data/transformed/post.csv"): os.remove("./data/transformed/post.csv") + with open("./data/transformed/post.csv", "a") as file: + writer = csv.writer(file) + writer.writerow(["post_id", "user_id", "text", "created_at", "like_count", "repost_count", "impression_count"]) offset = 0 limit = 1000 @@ -52,27 +88,35 @@ def transform_data(db: Session): RowPostRecord.post_id, RowPostRecord.author_id.label("user_id"), RowPostRecord.text, - RowPostRecord.created_at, - RowPostRecord.like_count, - RowPostRecord.repost_count, - RowPostRecord.impression_count, + func.cast(RowPostRecord.created_at, Integer).label("created_at"), + func.cast(RowPostRecord.like_count, Integer).label("like_count"), + func.cast(RowPostRecord.repost_count, Integer).label("repost_count"), + func.cast(RowPostRecord.impression_count, Integer).label("impression_count"), ) .limit(limit) .offset(offset) ) with open("./data/transformed/post.csv", "a") as file: - writer = csv.writer(file) - writer.writerow( - ["post_id", "user_id", "text", "created_at", "like_count", "repost_count", "impression_count"] - ) for post in posts: + writer = csv.writer(file) writer.writerow(post) offset += limit # Transform row user data and generate user.csv if os.path.exists("./data/transformed/user.csv"): os.remove("./data/transformed/user.csv") + with open("./data/transformed/user.csv", "a") as file: + writer = csv.writer(file) + writer.writerow( + [ + "user_id", + "name", + "profile_image", + "followers_count", + "following_count", + ] + ) offset = 0 limit = 1000 @@ -85,26 +129,90 @@ def transform_data(db: Session): RowUserRecord.user_id, RowUserRecord.user_name.label("name"), RowUserRecord.profile_image_url.label("profile_image"), - RowUserRecord.followers_count, - RowUserRecord.following_count, + func.cast(RowUserRecord.followers_count, Integer).label("followers_count"), + func.cast(RowUserRecord.following_count, Integer).label("following_count"), ) .limit(limit) .offset(offset) ) with open("./data/transformed/user.csv", "a") as file: - writer = csv.writer(file) - writer.writerow( - [ - "user_id", - "name", - "profile_image", - "followers_count", - "following_count", - ] - ) for user in users: + writer = csv.writer(file) writer.writerow(user) offset += limit + csv_seed_file_path = "./seed/topic_seed.csv" + output_csv_file_path = "./data/transformed/topic.csv" + records = [] + + if os.path.exists(output_csv_file_path): + return + + with open(csv_seed_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for index, row in enumerate(reader): + if "ja" in row and row["ja"]: + topic_id = index + 1 + label = {"ja": row["ja"], "en": row["en"]} # Assuming the label is in Japanese + record = {"topic_id": topic_id, "label": label} + records.append(record) + + with open(output_csv_file_path, "a", newline="", encoding="utf-8") as file: + fieldnames = ["topic_id", "label"] + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + for record in records: + writer.writerow({"topic_id": record["topic_id"], "label": {k: v for k, v in record["label"].items()}}) + + generate_note_topic() + return + + +def generate_note_topic(): + note_csv_file_path = "./data/transformed/note.csv" + output_csv_file_path = "./data/transformed/note_topic_association.csv" + ai_service = get_ai_service() + + records = [] + with open(output_csv_file_path, "w", newline="", encoding="utf-8", buffering=1) as file: + fieldnames = ["note_id", "topic_id"] + writer = csv.DictWriter(file, fieldnames=fieldnames) + writer.writeheader() + + with open(note_csv_file_path, newline="", encoding="utf-8") as csvfile: + reader = csv.DictReader(csvfile) + for index, row in enumerate(reader): + note_id = row["note_id"] + summary = row["summary"] + topics_info = ai_service.detect_topic(note_id, summary) + if topics_info: + for topic in topics_info.get("topics", []): + record = {"note_id": note_id, "topic_id": topic} + records.append(record) + + if index % 100 == 0: + for record in records: + writer.writerow( + { + "note_id": record["note_id"], + "topic_id": record["topic_id"], + } + ) + records = [] + print(index) + + for record in records: + writer.writerow( + { + "note_id": record["note_id"], + "topic_id": record["topic_id"], + } + ) + + print(f"New CSV file has been created at {output_csv_file_path}") + + +if __name__ == "__main__": + generate_note_topic()