feat: llama-index code migration! #65

Merged
merged 11 commits on Mar 13, 2024
67 changes: 67 additions & 0 deletions dags/hivemind_discord_etl.py
@@ -0,0 +1,67 @@
import logging
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from dotenv import load_dotenv
from hivemind_etl_helpers.discord_mongo_summary_etl import process_discord_summaries
from hivemind_etl_helpers.discord_mongo_vector_store_etl import (
    process_discord_guild_mongo,
)
from hivemind_etl_helpers.src.utils.mongo_discord_communities import (
    get_all_discord_communities,
)

with DAG(
    dag_id="discord_vector_store_update",
    start_date=datetime(2024, 1, 1),
    schedule_interval="0 2 * * *",
    catchup=False,
) as dag:

    @task
    def get_discord_communities() -> list[str]:
        """
        Retrieve all community IDs that have a Discord platform from the database.
        """
        communities = get_all_discord_communities()
        return communities

    @task
    def start_discord_vectorstore(community_id: str):
        load_dotenv()
        logging.info(f"Working on community {community_id}")
        process_discord_guild_mongo(community_id=community_id)
        logging.info(f"Community {community_id} job finished!")

    communities = get_discord_communities()
    # `start_discord_vectorstore` is dynamically mapped:
    # one task instance is created per community in the list
    start_discord_vectorstore.expand(community_id=communities)


with DAG(
    dag_id="discord_summary_vector_store",
    start_date=datetime(2024, 1, 1),
    schedule_interval="0 2 * * *",
) as dag:

    @task
    def get_mongo_discord_communities() -> list[str]:
        """
        Retrieve all community IDs that have a Discord platform from the database.
        This is the same as `get_discord_communities`; the name differs only
        to satisfy pylint.
        """
        communities = get_all_discord_communities()
        return communities

    @task
    def start_discord_summary_vectorstore(community_id: str):
        load_dotenv()
        logging.info(f"Working on community {community_id}")
        process_discord_summaries(community_id=community_id, verbose=False)
        logging.info(f"Community {community_id} job finished!")

    communities = get_mongo_discord_communities()
    start_discord_summary_vectorstore.expand(community_id=communities)
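Note on the pattern above: both DAGs use Airflow's dynamic task mapping, where `.expand()` fans a single `@task` definition out into one task instance per element of the upstream list at runtime. A minimal self-contained sketch of the mechanism (the DAG id and community IDs below are illustrative, not from this repo):

from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(
    dag_id="expand_demo",  # hypothetical DAG id for illustration
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:

    @task
    def list_communities() -> list[str]:
        # Stand-in for get_all_discord_communities(); IDs are made up
        return ["community_a", "community_b", "community_c"]

    @task
    def process(community_id: str):
        # Each mapped element becomes its own task instance,
        # with independent retries and logs
        print(f"processing {community_id}")

    # Creates three parallel `process` task instances at runtime
    process.expand(community_id=list_communities())

Dynamic task mapping requires Airflow 2.3 or newer; a failure in one community's run is isolated to that task instance instead of failing one monolithic task.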
57 changes: 57 additions & 0 deletions dags/hivemind_discourse_etl.py
@@ -0,0 +1,57 @@
import logging
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from hivemind_etl_helpers.discourse_summary_etl import process_discourse_summary
from hivemind_etl_helpers.discourse_vectorstore_etl import process_discourse_vectorstore
from hivemind_etl_helpers.src.utils.get_communities_data import (
    get_discourse_communities,
)

with DAG(
    dag_id="discourse_vector_store",
    start_date=datetime(2024, 3, 1),
    schedule_interval="0 2 * * *",
) as dag:

    @task
    def process_discourse_community(community_information: dict[str, str | datetime]):
        community_id = community_information["community_id"]
        forum_endpoint = community_information["endpoint"]
        from_date = community_information["from_date"]

        logging.info(f"Starting Discourse ETL | community_id: {community_id}")
        process_discourse_vectorstore(
            community_id=community_id,
            forum_endpoint=forum_endpoint,
            from_starting_date=from_date,
        )

    communities_info = get_discourse_communities()
    process_discourse_community.expand(community_information=communities_info)


with DAG(
    dag_id="discourse_summary_vector_store",
    start_date=datetime(2024, 2, 21),
    schedule_interval="0 2 * * *",
) as dag:

    @task
    def process_discourse_community_summary(
        community_information: dict[str, str | datetime]
    ):
        community_id = community_information["community_id"]
        forum_endpoint = community_information["endpoint"]
        from_date = community_information["from_date"]

        logging.info(f"Starting Discourse summary ETL | community_id: {community_id}")
        process_discourse_summary(
            community_id=community_id,
            forum_endpoint=forum_endpoint,
            from_starting_date=from_date,
        )

    communities_info = get_discourse_communities()
    process_discourse_community_summary.expand(community_information=communities_info)
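The mapped tasks read three keys from each `community_information` dict, so `get_discourse_communities()` presumably returns records shaped like the sketch below (the values are placeholders, not taken from the repository):

from datetime import datetime

# Assumed shape of the list returned by get_discourse_communities();
# only the three keys read by the tasks above are known from this diff.
communities_info: list[dict[str, str | datetime]] = [
    {
        "community_id": "<community-id>",     # placeholder community id
        "endpoint": "forum.example.com",      # placeholder Discourse endpoint
        "from_date": datetime(2023, 1, 1),    # earliest date to ingest from
    },
]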
146 changes: 0 additions & 146 deletions dags/hivemind_etl.py

This file was deleted.

12 changes: 7 additions & 5 deletions dags/hivemind_etl_helpers/discord_mongo_summary_etl.py
@@ -10,7 +10,9 @@
 )
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.sort_summary_docs import sort_summaries_daily
-from llama_index.response_synthesizers import get_response_synthesizer
+from llama_index.core import Settings
+from llama_index.core.response_synthesizers import get_response_synthesizer
+from llama_index.llms.openai import OpenAI
 from tc_hivemind_backend.db.pg_db_utils import setup_db
 from tc_hivemind_backend.db.utils.model_hyperparams import load_model_hyperparams
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
@@ -93,16 +95,16 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:
     node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

-    embed_model = CohereEmbedding()
+    Settings.node_parser = node_parser
+    Settings.embed_model = CohereEmbedding()
+    Settings.chunk_size = chunk_size
+    Settings.llm = OpenAI(model="gpt-3.5-turbo")

     pg_vector.save_documents_in_batches(
         community_id=community_id,
         documents=docs_daily_sorted,
         batch_size=100,
-        node_parser=node_parser,
-        max_request_per_minute=None,
-        embed_model=embed_model,
         embed_dim=embedding_dim,
         request_per_minute=10000,
         deletion_query=deletion_query,
     )
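The hunk above is the heart of this migration: llama-index 0.10 replaced per-call component arguments with a global `Settings` object and split the flat `llama_index` namespace into `llama_index.core` plus per-integration packages such as `llama_index.llms.openai`. A minimal sketch of configuring the new global defaults, using the library's built-in `SentenceSplitter` as a stand-in for this repo's `configure_node_parser` helper (the chunk size is illustrative; the repo loads it via `load_model_hyperparams()`):

from llama_index.core import Settings
from llama_index.core.node_parser import SentenceSplitter
from llama_index.llms.openai import OpenAI

chunk_size = 512  # illustrative value

# Components set once on Settings are picked up by any llama-index code
# that is not given an explicit override
Settings.node_parser = SentenceSplitter(chunk_size=chunk_size)
Settings.chunk_size = chunk_size
Settings.llm = OpenAI(model="gpt-3.5-turbo")
# Settings.embed_model would be set to CohereEmbedding() as in the diff;
# that wrapper lives in tc_hivemind_backend and is not shown here

This is presumably why `node_parser`, `embed_model`, and `max_request_per_minute` disappear from the `save_documents_in_batches` calls: the backend now reads the parser and embedder from `Settings` rather than taking them as arguments.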
22 changes: 6 additions & 16 deletions dags/hivemind_etl_helpers/discord_mongo_vector_store_etl.py
@@ -1,4 +1,3 @@
-import argparse
 import logging
 from datetime import timedelta

@@ -9,6 +8,8 @@
     find_guild_id_by_community_id,
 )
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
+from llama_index.core import Settings
+from llama_index.llms.openai import OpenAI
 from tc_hivemind_backend.db.pg_db_utils import setup_db
 from tc_hivemind_backend.db.utils.model_hyperparams import load_model_hyperparams
 from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
@@ -52,27 +53,16 @@ def process_discord_guild_mongo(community_id: str) -> None:
     node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

-    embed_model = CohereEmbedding()
+    Settings.node_parser = node_parser
+    Settings.embed_model = CohereEmbedding()
+    Settings.chunk_size = chunk_size
+    Settings.llm = OpenAI(model="gpt-3.5-turbo")

     pg_vector.save_documents_in_batches(
         community_id=community_id,
         documents=documents,
         batch_size=100,
-        node_parser=node_parser,
-        max_request_per_minute=None,
-        embed_model=embed_model,
         embed_dim=embedding_dim,
         request_per_minute=10000,
         # max_request_per_day=REQUEST_PER_DAY,
     )
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO)
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "community_id", type=str, help="the Community that the guild is related to"
-    )
-    args = parser.parse_args()
-
-    process_discord_guild_mongo(community_id=args.community_id)
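With the argparse entry point removed, this module is no longer meant to be executed directly; runs go through the Airflow DAGs above. For a one-off manual run, a small hypothetical driver script could stand in for the deleted CLI (the community id is a placeholder):

import logging

from hivemind_etl_helpers.discord_mongo_vector_store_etl import (
    process_discord_guild_mongo,
)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Placeholder id; previously this was supplied via the removed argparse CLI
    process_discord_guild_mongo(community_id="<community_id>")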
14 changes: 8 additions & 6 deletions dags/hivemind_etl_helpers/discourse_summary_etl.py
@@ -10,8 +10,9 @@
 from hivemind_etl_helpers.src.db.discourse.utils.get_forums import get_forum_uuid
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.sort_summary_docs import sort_summaries_daily
-from llama_index import Document
-from llama_index.response_synthesizers import get_response_synthesizer
+from llama_index.core import Document, Settings
+from llama_index.core.response_synthesizers import get_response_synthesizer
+from llama_index.llms.openai import OpenAI
 from neo4j._data import Record
 from tc_hivemind_backend.db.pg_db_utils import setup_db
 from tc_hivemind_backend.db.utils.model_hyperparams import load_model_hyperparams
@@ -133,8 +134,6 @@ def process_forum(
     node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

-    embed_model = CohereEmbedding()
-
     sorted_daily_docs = sort_summaries_daily(
         level1_docs=topic_summary_documents,
         level2_docs=category_summary_documenets,
@@ -145,13 +144,16 @@
         f"{log_prefix} Saving discourse summaries (extracting the embedding and saving)"
     )

+    Settings.node_parser = node_parser
+    Settings.embed_model = CohereEmbedding()
+    Settings.chunk_size = chunk_size
+    Settings.llm = OpenAI(model="gpt-3.5-turbo")
+
     pg_vector.save_documents_in_batches(
         community_id=community_id,
         documents=sorted_daily_docs,
         batch_size=100,
-        node_parser=node_parser,
-        max_request_per_minute=None,
-        embed_model=embed_model,
         embed_dim=embedding_dim,
         request_per_minute=10000,
         deletion_query=deletion_query,
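For completeness, the summary ETL's response synthesizer also moves to the new namespace (see the import hunk above). A minimal sketch of constructing one under llama-index 0.10; the response mode here is illustrative, since this diff does not show the repo's actual call site:

from llama_index.core.response_synthesizers import (
    ResponseMode,
    get_response_synthesizer,
)

# Builds a synthesizer that summarizes chunks bottom-up; it falls back to
# the globally configured Settings.llm unless an llm is passed explicitly
response_synthesizer = get_response_synthesizer(
    response_mode=ResponseMode.TREE_SUMMARIZE
)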