Merge pull request #65 from TogetherCrew/feat/upgrade-llama-index-dependency

feat: llama-index code migration!
cyri113 authored Mar 13, 2024
2 parents 971bc2d + 43cdc7b commit 76eef06
Showing 59 changed files with 360 additions and 390 deletions.
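
This change set follows the llama-index 0.10 restructuring: imports move from the monolithic `llama_index` package to `llama_index.core` plus separate integration packages, and per-call parameters such as `embed_model` and `node_parser` are replaced by the global `Settings` object. A minimal before/after sketch of that pattern — the model and chunk-size values here are illustrative only; the ETL modules configure theirs from `load_model_hyperparams` and the project's `CohereEmbedding`:

# Before (llama-index < 0.10): a single monolithic package, models passed per call
# from llama_index import Document
# from llama_index.response_synthesizers import get_response_synthesizer

# After (llama-index >= 0.10): core package plus per-integration packages
from llama_index.core import Document, Settings
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.llms.openai import OpenAI  # provided by llama-index-llms-openai

# Configuration now lives on the global Settings object rather than being
# passed to each call (llm, embed_model, node_parser, chunk_size, ...).
Settings.llm = OpenAI(model="gpt-3.5-turbo")
Settings.chunk_size = 512  # illustrative value only
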
67 changes: 67 additions & 0 deletions dags/hivemind_discord_etl.py
@@ -0,0 +1,67 @@
import logging
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from dotenv import load_dotenv
from hivemind_etl_helpers.discord_mongo_summary_etl import process_discord_summaries
from hivemind_etl_helpers.discord_mongo_vector_store_etl import (
    process_discord_guild_mongo,
)
from hivemind_etl_helpers.src.utils.mongo_discord_communities import (
    get_all_discord_communities,
)

with DAG(
dag_id="discord_vector_store_update",
start_date=datetime(2024, 1, 1),
schedule_interval="0 2 * * *",
catchup=False,
) as dag:

    @task
    def get_discord_communities() -> list[str]:
        """
        Get all communities that have Discord from the database.
        """
        communities = get_all_discord_communities()
        return communities

    @task
    def start_discord_vectorstore(community_id: str):
        load_dotenv()
        logging.info(f"Working on community, {community_id}")
        process_discord_guild_mongo(community_id=community_id)
        logging.info(f"Community {community_id} Job finished!")

    communities = get_discord_communities()
    # `start_discord_vectorstore` is dynamically mapped over the returned list,
    # running once per community (see the sketch after this file's diff)
    start_discord_vectorstore.expand(community_id=communities)


with DAG(
dag_id="discord_summary_vector_store",
start_date=datetime(2024, 1, 1),
schedule_interval="0 2 * * *",
) as dag:

    @task
    def get_mongo_discord_communities() -> list[str]:
        """
        Get all communities that have Discord from the database.
        This function is identical to `get_discord_communities`;
        only the name differs, to satisfy pylint.
        """
        communities = get_all_discord_communities()
        return communities

    @task
    def start_discord_summary_vectorstore(community_id: str):
        load_dotenv()
        logging.info(f"Working on community, {community_id}")
        process_discord_summaries(community_id=community_id, verbose=False)
        logging.info(f"Community {community_id} Job finished!")

    communities = get_mongo_discord_communities()
    start_discord_summary_vectorstore.expand(community_id=communities)
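
Both DAGs above rely on Airflow dynamic task mapping: calling `.expand()` on a `@task`-decorated function creates one mapped task instance per element of the list returned by the upstream task. A minimal sketch of that mechanism, using hypothetical DAG and task names:

from datetime import datetime

from airflow import DAG
from airflow.decorators import task

with DAG(
    dag_id="dynamic_mapping_example",  # hypothetical DAG, for illustration only
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
    catchup=False,
):

    @task
    def list_items() -> list[str]:
        # the upstream task returns the collection to map over
        return ["community_a", "community_b", "community_c"]

    @task
    def process(item: str):
        # runs once per element; here three mapped task instances are created
        print(f"processing {item}")

    process.expand(item=list_items())
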
57 changes: 57 additions & 0 deletions dags/hivemind_discourse_etl.py
@@ -0,0 +1,57 @@
import logging
from datetime import datetime

from airflow import DAG
from airflow.decorators import task
from hivemind_etl_helpers.discourse_summary_etl import process_discourse_summary
from hivemind_etl_helpers.discourse_vectorstore_etl import process_discourse_vectorstore
from hivemind_etl_helpers.src.utils.get_communities_data import (
    get_discourse_communities,
)

with DAG(
dag_id="discourse_vector_store",
start_date=datetime(2024, 3, 1),
schedule_interval="0 2 * * *",
) as dag:

    @task
    def process_discourse_community(community_information: dict[str, str | datetime]):
        community_id = community_information["community_id"]
        forum_endpoint = community_information["endpoint"]
        from_date = community_information["from_date"]

        logging.info(f"Starting Discourse ETL | community_id: {community_id}")
        process_discourse_vectorstore(
            community_id=community_id,
            forum_endpoint=forum_endpoint,
            from_starting_date=from_date,
        )

    communities_info = get_discourse_communities()
    process_discourse_community.expand(community_information=communities_info)


with DAG(
dag_id="discourse_summary_vector_store",
start_date=datetime(2024, 2, 21),
schedule_interval="0 2 * * *",
) as dag:

    @task
    def process_discourse_community_summary(
        community_information: dict[str, str | datetime]
    ):
        community_id = community_information["community_id"]
        forum_endpoint = community_information["endpoint"]
        from_date = community_information["from_date"]

        logging.info(f"Starting Discourse ETL | community_id: {community_id}")
        process_discourse_summary(
            community_id=community_id,
            forum_endpoint=forum_endpoint,
            from_starting_date=from_date,
        )

    communities_info = get_discourse_communities()
    process_discourse_community_summary.expand(community_information=communities_info)
146 changes: 0 additions & 146 deletions dags/hivemind_etl.py

This file was deleted.

12 changes: 7 additions & 5 deletions dags/hivemind_etl_helpers/discord_mongo_summary_etl.py
@@ -10,7 +10,9 @@
)
from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
from hivemind_etl_helpers.src.utils.sort_summary_docs import sort_summaries_daily
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.core import Settings
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.llms.openai import OpenAI
from tc_hivemind_backend.db.pg_db_utils import setup_db
from tc_hivemind_backend.db.utils.model_hyperparams import load_model_hyperparams
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
@@ -93,16 +95,16 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:
    node_parser = configure_node_parser(chunk_size=chunk_size)
    pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

    embed_model = CohereEmbedding()
    Settings.node_parser = node_parser
    Settings.embed_model = CohereEmbedding()
    Settings.chunk_size = chunk_size
    Settings.llm = OpenAI(model="gpt-3.5-turbo")

    pg_vector.save_documents_in_batches(
        community_id=community_id,
        documents=docs_daily_sorted,
        batch_size=100,
        node_parser=node_parser,
        max_request_per_minute=None,
        embed_model=embed_model,
        embed_dim=embedding_dim,
        request_per_minute=10000,
        deletion_query=deletion_query,
    )
22 changes: 6 additions & 16 deletions dags/hivemind_etl_helpers/discord_mongo_vector_store_etl.py
@@ -1,4 +1,3 @@
import argparse
import logging
from datetime import timedelta

@@ -9,6 +8,8 @@
    find_guild_id_by_community_id,
)
from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from tc_hivemind_backend.db.pg_db_utils import setup_db
from tc_hivemind_backend.db.utils.model_hyperparams import load_model_hyperparams
from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
@@ -52,27 +53,16 @@ def process_discord_guild_mongo(community_id: str) -> None:
    node_parser = configure_node_parser(chunk_size=chunk_size)
    pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

    embed_model = CohereEmbedding()
    Settings.node_parser = node_parser
    Settings.embed_model = CohereEmbedding()
    Settings.chunk_size = chunk_size
    Settings.llm = OpenAI(model="gpt-3.5-turbo")

    pg_vector.save_documents_in_batches(
        community_id=community_id,
        documents=documents,
        batch_size=100,
        node_parser=node_parser,
        max_request_per_minute=None,
        embed_model=embed_model,
        embed_dim=embedding_dim,
        request_per_minute=10000,
        # max_request_per_day=REQUEST_PER_DAY,
    )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "community_id", type=str, help="the Community that the guild is related to"
    )
    args = parser.parse_args()

    process_discord_guild_mongo(community_id=args.community_id)
14 changes: 8 additions & 6 deletions dags/hivemind_etl_helpers/discourse_summary_etl.py
@@ -10,8 +10,9 @@
from hivemind_etl_helpers.src.db.discourse.utils.get_forums import get_forum_uuid
from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
from hivemind_etl_helpers.src.utils.sort_summary_docs import sort_summaries_daily
from llama_index import Document
from llama_index.response_synthesizers import get_response_synthesizer
from llama_index.core import Document, Settings
from llama_index.core.response_synthesizers import get_response_synthesizer
from llama_index.llms.openai import OpenAI
from neo4j._data import Record
from tc_hivemind_backend.db.pg_db_utils import setup_db
from tc_hivemind_backend.db.utils.model_hyperparams import load_model_hyperparams
@@ -133,8 +134,6 @@ def process_forum(
    node_parser = configure_node_parser(chunk_size=chunk_size)
    pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

    embed_model = CohereEmbedding()

    sorted_daily_docs = sort_summaries_daily(
        level1_docs=topic_summary_documents,
        level2_docs=category_summary_documenets,
@@ -145,13 +144,16 @@
f"{log_prefix} Saving discourse summaries (extracting the embedding and saving)"
)

    Settings.node_parser = node_parser
    Settings.embed_model = CohereEmbedding()
    Settings.chunk_size = chunk_size
    Settings.llm = OpenAI(model="gpt-3.5-turbo")

    pg_vector.save_documents_in_batches(
        community_id=community_id,
        documents=sorted_daily_docs,
        batch_size=100,
        node_parser=node_parser,
        max_request_per_minute=None,
        embed_model=embed_model,
        embed_dim=embedding_dim,
        request_per_minute=10000,
        deletion_query=deletion_query,
