Skip to content

Commit

Permalink
Merge pull request #78 from codeforjapan/feat/add-topic-seed
Browse files Browse the repository at this point in the history
Feat/add topic seed
  • Loading branch information
yu23ki14 authored Aug 16, 2024
2 parents e3c7e37 + 2712c1a commit ed3dbd8
Show file tree
Hide file tree
Showing 12 changed files with 387 additions and 42 deletions.
17 changes: 15 additions & 2 deletions common/birdxplorer_common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,15 @@ class UpToNineteenDigitsDecimalString(BaseString):
Traceback (most recent call last):
...
pydantic_core._pydantic_core.ValidationError: 1 validation error for function-after[validate(), constrained-str]
String should match pattern '^[0-9]{1,19}$' [type=string_pattern_mismatch, input_value='test', input_type=str]
String should match pattern '^([0-9]{19}|)$' [type=string_pattern_mismatch, input_value='test', input_type=str]
...
>>> UpToNineteenDigitsDecimalString.from_str("1234567890123456789")
UpToNineteenDigitsDecimalString('1234567890123456789')
"""

@classmethod
def __get_extra_constraint_dict__(cls) -> dict[str, Any]:
return dict(super().__get_extra_constraint_dict__(), pattern=r"^[0-9]{1,19}$")
return dict(super().__get_extra_constraint_dict__(), pattern=r"^([0-9]{19}|)$")


class NonEmptyStringMixin(BaseString):
Expand Down Expand Up @@ -608,6 +608,19 @@ class LanguageIdentifier(str, Enum):
PT = "pt"
DE = "de"
FR = "fr"
FI = "fi"
TR = "tr"
NL = "nl"
HE = "he"
IT = "it"
FA = "fa"
CA = "ca"
AR = "ar"
EL = "el"
SV = "sv"
DA = "da"
RU = "ru"
PL = "pl"
OTHER = "other"

@classmethod
Expand Down
10 changes: 9 additions & 1 deletion common/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def user_enrollment_samples(
@fixture
def topic_samples(topic_factory: TopicFactory) -> Generator[List[Topic], None, None]:
topics = [
topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=3),
topic_factory.build(topic_id=0, label={"en": "topic0", "ja": "トピック0"}, reference_count=4),
topic_factory.build(topic_id=1, label={"en": "topic1", "ja": "トピック1"}, reference_count=2),
topic_factory.build(topic_id=2, label={"en": "topic2", "ja": "トピック2"}, reference_count=1),
topic_factory.build(topic_id=3, label={"en": "topic3", "ja": "トピック3"}, reference_count=0),
Expand Down Expand Up @@ -160,6 +160,14 @@ def note_samples(note_factory: NoteFactory, topic_samples: List[Topic]) -> Gener
summary="summary5",
created_at=1152921604000,
),
note_factory.build(
note_id="1234567890123456786",
post_id="",
topics=[topic_samples[0]],
language="en",
summary="summary6_empty_post_id",
created_at=1152921604000,
),
]
yield notes

Expand Down
7 changes: 6 additions & 1 deletion etl/.env.example
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
X_BEARER_TOKEN=
X_BEARER_TOKEN=
AI_MODEL=
OPENAPI_TOKEN=
CLAUDE_TOKEN=
TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND=1720900800000
TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND=1722110400000
3 changes: 2 additions & 1 deletion etl/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"requests",
"pytest",
"prefect",
"stringcase"
"stringcase",
"openai"
]

[project.urls]
Expand Down
7 changes: 7 additions & 0 deletions etl/seed/fewshot_sample.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"tweet": "For those that care — 432 hz improves mental clarity, removes emotional blockages, reduces stress and anxiety, better sleep quality, increases creativity & inspiration, and strengthens the immune system. Play it while you sleep & watch these areas improve!",
"note": "There are no placebo controlled studies which support this. There is no evidence that this frequency has different effects from any other arbitrary frequency. https://ask.audio/articles/music-theory-432-hz-tuning-separating-fact-from-fiction",
"topics": [
"医療", "福祉"
]
}
64 changes: 64 additions & 0 deletions etl/seed/topic_seed.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
en,ja
European Union,欧州連合
coronavirus,コロナウイルス
Crimea,クリミア
Coup,クーデター
Sanctions,制裁
Terrorism,テロ
Sovereignty,主権
Eastern Ukraine,ウクライナ東部
Syrian War,シリア戦争
Chemical weapons/attack,化学兵器/攻撃
Elections,選挙
Protest,抗議
WWII,第二次世界大戦
Manipulated elections/referendum,操作された選挙/国民投票
Vladimir Putin,ウラジーミル・プーチン
Migration crisis,移民の危機
Ukrainian disintegration,ウクライナ崩壊
Nuclear issues,核問題
Imperialism/colonialism,帝国主義・植民地主義
Economic difficulties,経済的困難
vaccination,予防接種
Biological weapons,生物兵器
Donald Trump,ドナルド・トランプ
election meddling,選挙介入
Media,メディア
security threat,セキュリティ上の脅威
Joe Biden,ジョー・バイデン
Human rights,人権
Democracy,民主主義
Propaganda,プロパガンダ
Civil war,内戦
Freedom of speech,言論の自由
Military exercise,軍事演習
LGBT,LGBT
Information war,情報戦
Genocide,大量虐殺
Sputnik V,スプートニクV
economy,経済
War crimes,戦争犯罪
Intelligence services,諜報機関
Energy,エネルギー
Occupation,占領
UN,国連
migration,移民・移住
Corruption,腐敗
laboratory,研究室
Censorship,検閲
Refugees,難民
fake news,フェイクニュース
scam,詐欺
technology,テクノロジー
welfare,福祉
mobility,交通
travel,観光
fashion,ファッション
mental health,メンタルヘルス
anime,アニメ
AI,AI
climate change,気候変動
food,食品
tax,税金
drugs,薬物
US presidential election,米国大統領選挙
13 changes: 13 additions & 0 deletions etl/src/birdxplorer_etl/lib/ai_model/ai_model_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from birdxplorer_etl.settings import AI_MODEL
from birdxplorer_etl.lib.openapi.open_ai_service import OpenAIService
from birdxplorer_etl.lib.claude.claude_service import ClaudeService
from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface


def get_ai_service() -> AIModelInterface:
    """Build the AI service selected by the ``AI_MODEL`` setting.

    Returns:
        An ``OpenAIService`` when ``AI_MODEL`` is ``"openai"``, or a
        ``ClaudeService`` when it is ``"claude"``.

    Raises:
        ValueError: If ``AI_MODEL`` holds any other value.
    """
    # Map each supported setting value to its service class; only the
    # selected class is instantiated.
    service_classes = {"openai": OpenAIService, "claude": ClaudeService}
    service_class = service_classes.get(AI_MODEL)
    if service_class is None:
        raise ValueError(f"Unsupported AI service: {AI_MODEL}")
    return service_class()
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Dict, List


class AIModelInterface:
    """Base interface for AI-backed language and topic detection services.

    Both methods raise ``NotImplementedError`` here; concrete services are
    expected to override them. The class is deliberately left instantiable
    (not an ABC) so that subclasses which do not yet override every method
    can still be constructed.
    """

    def detect_language(self, text: str) -> str:
        """Return the language code detected for ``text``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("detect_language method not implemented")

    def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]:
        """Return the topics detected for the note text ``note``.

        The return annotation uses integer topic ids so that it agrees with
        the concrete ``OpenAIService.detect_topic`` implementation.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("detect_topic method not implemented")
7 changes: 7 additions & 0 deletions etl/src/birdxplorer_etl/lib/claude/claude_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from birdxplorer_etl.settings import CLAUDE_TOKEN
from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface


class ClaudeService(AIModelInterface):
    """Claude-backed AI service.

    Only the API key is stored; no detection method is overridden here, so
    ``detect_language`` and ``detect_topic`` are inherited unchanged from
    ``AIModelInterface``.
    """

    def __init__(self) -> None:
        """Store the ``CLAUDE_TOKEN`` setting as ``self.api_key``."""
        self.api_key = CLAUDE_TOKEN
105 changes: 105 additions & 0 deletions etl/src/birdxplorer_etl/lib/openapi/open_ai_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from birdxplorer_etl.settings import OPENAPI_TOKEN
from birdxplorer_etl.lib.ai_model.ai_model_interface_base import AIModelInterface
from birdxplorer_common.models import LanguageIdentifier
from openai import OpenAI
from typing import Dict, List
import csv
import json
import os


class OpenAIService(AIModelInterface):
    """OpenAI-backed implementation of ``AIModelInterface``.

    Uses the ``gpt-4o-mini`` chat completion model to detect the language of
    a text and to assign topic ids to a community note.

    Attributes set by ``__init__``:
        api_key: The ``OPENAPI_TOKEN`` setting.
        client: ``OpenAI`` client created with that key.
        topics: Mapping of Japanese topic label -> topic id, loaded from
            ``./data/transformed/topic.csv`` when that file exists, otherwise
            an empty dict.
    """

    def __init__(self):
        self.api_key = OPENAPI_TOKEN
        self.client = OpenAI(api_key=self.api_key)
        # Always define ``topics`` so that detect_topic() does not fail with
        # AttributeError when the transformed topic CSV is not present.
        self.topics: Dict[str, int] = {}
        if os.path.exists("./data/transformed/topic.csv"):
            self.topics = self.load_topics("./data/transformed/topic.csv")

    @staticmethod
    def _parse_label(raw_label: str) -> Dict[str, str]:
        """Parse one ``label`` cell into a dict.

        Accepts strict JSON as well as a Python-literal dict (single-quoted
        keys and values). Blindly swapping ``'`` for ``"`` corrupts any label
        that itself contains an apostrophe or a double quote, so that
        conversion is kept only as a last resort for inputs the two proper
        parsers both reject.

        Raises:
            json.JSONDecodeError: If the cell cannot be parsed by any means.
        """
        import ast  # local import: keeps this block independent of the file header

        try:
            return json.loads(raw_label)
        except json.JSONDecodeError:
            pass
        try:
            return ast.literal_eval(raw_label)
        except (ValueError, SyntaxError, TypeError):
            # Legacy behaviour, e.g. single-quoted text containing JSON-only
            # tokens such as ``null``.
            return json.loads(raw_label.replace("'", '"'))

    def load_topics(self, topic_csv_file_path: str) -> Dict[str, int]:
        """Load the topic CSV into a ``{japanese_label: topic_id}`` mapping.

        Args:
            topic_csv_file_path: Path of a CSV file with ``topic_id`` and
                ``label`` columns, where ``label`` holds a serialized dict
                keyed by language code.

        Returns:
            Mapping of the ``"ja"`` label of each row to its integer topic
            id. Rows without a ``"ja"`` label are skipped.
        """
        topics = {}
        with open(topic_csv_file_path, newline="", encoding="utf-8") as csvfile:
            for row in csv.DictReader(csvfile):
                labels = self._parse_label(row["label"])
                # Only the Japanese label is indexed; labels in other
                # languages are intentionally ignored.
                if "ja" in labels:
                    topics[labels["ja"]] = int(row["topic_id"])
        return topics

    def detect_language(self, text: str) -> str:
        """Ask the model for the language code of ``text``.

        Resolution order for the model's reply:
          1. the reply is exactly a known ``LanguageIdentifier`` value;
          2. the reply is a single word wrapped in punctuation/quotes or in a
             different case (e.g. ``"EN"``, ``"`ja`"``, ``"en."``);
          3. a known code appears in the reply as a whole word;
          4. otherwise the reply is handed to ``LanguageIdentifier.normalize``.

        Returns:
            A ``LanguageIdentifier`` member for cases 1-3, or whatever
            ``LanguageIdentifier.normalize`` returns for case 4.
        """
        import re  # local import: keeps this block independent of the file header

        # NOTE(review): the prompt advertises only six codes although
        # LanguageIdentifier defines more — confirm whether that is intended.
        prompt = (
            "Detect the language of the following text and return only the language code "
            f"from this list: en, es, ja, pt, de, fr. Text: {text}. "
            "Respond with only the language code, nothing else."
        )

        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,
            seed=1,
        )

        # ``content`` may be None (e.g. on a refusal); treat it as empty text.
        message_content = (response.choices[0].message.content or "").strip()

        known_codes = LanguageIdentifier._value2member_map_
        if message_content in known_codes:
            return LanguageIdentifier(message_content)

        # A lone code that only differs by case or surrounding punctuation.
        lone_word = re.fullmatch(r"\W*([A-Za-z]+)\W*", message_content)
        if lone_word and lone_word.group(1).lower() in known_codes:
            return LanguageIdentifier(lone_word.group(1).lower())

        # Whole-word search, in enum order. A plain substring test would
        # match codes hidden inside ordinary words ("french" contains "en",
        # "The" contains "he") and report the wrong language.
        for code in known_codes:
            if re.search(rf"(?<![A-Za-z]){re.escape(code)}(?![A-Za-z])", message_content):
                return LanguageIdentifier(code)

        print(f"Invalid language code received: {message_content}")
        return LanguageIdentifier.normalize(message_content)

    def detect_topic(self, note_id: int, note: str) -> Dict[str, List[int]]:
        """Ask the model which known topics apply to a community note.

        Args:
            note_id: Identifier of the note (not used by this implementation).
            note: Text of the community note.

        Returns:
            The JSON object parsed from the model's reply — the prompt asks
            for a ``"topics"`` key holding a list of numeric topic ids. An
            empty dict is returned when the reply is not valid JSON or is not
            a JSON object.
        """
        topic_examples = "\n".join(f"{key}: {value}" for key, value in self.topics.items())
        with open("./seed/fewshot_sample.json", newline="", encoding="utf-8") as f:
            fewshot_sample = json.load(f)

        prompt = f"""
以下はコミュニティノートです。
コミュニティノート:
```
{fewshot_sample["note"]}
```
このセットに対してのトピックは「{" ".join(fewshot_sample["topics"])}」です。
これを踏まえて、以下のセットに対して同じ粒度で複数のトピック(少なくとも3つ)を提示してください。
コミュニティノート:
```
{note}
```
以下のトピックは、
```
topic: topic_id
```
の形で構成されています。
こちらを使用して関連するものを推測してください。形式はJSONで、キーをtopicsとして値に必ず数字のtopic_idを配列で格納してください。
また指定された情報以外は含めないでください。
トピックの例:
{topic_examples}
"""
        response = self.client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.0,
        )
        # ``content`` may be None; an empty string then fails JSON decoding
        # below and yields the documented empty result.
        response_text = (response.choices[0].message.content or "").strip()
        # Drop a Markdown code fence the model may have wrapped around the JSON.
        response_text = response_text.replace("```json", "").replace("```", "").strip()
        try:
            parsed = json.loads(response_text)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON: {e}")
            return {}
        if not isinstance(parsed, dict):
            # Valid JSON but not an object (e.g. a bare list); callers expect a dict.
            print(f"Unexpected JSON structure for topics: {response_text}")
            return {}
        return parsed
5 changes: 5 additions & 0 deletions etl/src/birdxplorer_etl/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@
COMMUNITY_NOTE_DAYS_AGO = int(os.getenv("COMMUNITY_NOTE_DAYS_AGO", "3"))

X_BEARER_TOKEN = os.getenv("X_BEARER_TOKEN")
# Name of the AI backend to use. get_ai_service() accepts "openai" or
# "claude" and raises ValueError for any other value (including None when
# the variable is unset).
AI_MODEL = os.getenv("AI_MODEL")
# API key that OpenAIService passes to the OpenAI client.
OPENAPI_TOKEN = os.getenv("OPENAPI_TOKEN")
# API key stored by ClaudeService.
CLAUDE_TOKEN = os.getenv("CLAUDE_TOKEN")
# Presumably the start/end of the note window used for topic estimation
# (.env.example supplies Unix-millisecond values) — confirm against callers.
# Kept as raw environment strings (None when unset); no int conversion here.
TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_START_UNIX_MILLISECOND")
TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND = os.getenv("TARGET_NOTE_ESTIMATE_TOPIC_END_UNIX_MILLISECOND")
Loading

0 comments on commit ed3dbd8

Please sign in to comment.