feat: Added the embedding_dim and chunk_size as env variables!
amindadgar committed Dec 20, 2023
1 parent fafe587 commit c96d92d
Showing 9 changed files with 97 additions and 21 deletions.
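
In short: the ETL scripts previously hard-coded the chunk size (256 for the summary pipelines, 512 for the vector-store pipelines) and the embedding dimension (1024). This commit reads both from the environment through a new load_model_hyperparams() helper, so a single CHUNK_SIZE value now drives both pipeline families. A sketch of the two entries the loader expects in the environment or a .env file — the values below are illustrative, not defaults shipped by the commit:

    # .env — both variables are required; the loader raises ValueError if either is absent
    CHUNK_SIZE=512       # illustrative; summaries previously used 256, vector stores 512
    EMBEDDING_DIM=1024   # illustrative; matches the previously hard-coded dimension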
12 changes: 6 additions & 6 deletions dags/hivemind_etl_helpers/discord_mongo_summary_etl.py
@@ -7,6 +7,7 @@
 )
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
 from hivemind_etl_helpers.src.utils.pg_db_utils import setup_db
 from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess
 from llama_index.response_synthesizers import get_response_synthesizer
@@ -27,8 +28,8 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:
         verbose the process of summarization or not
         if `True` the summarization process will be printed out
         default is `False`
     """
+    chunk_size, embedding_dim = load_model_hyperparams()
     guild_id = find_guild_id_by_community_id(community_id)
     logging.info(f"COMMUNITYID: {community_id}, GUILDID: {guild_id}")
     table_name = "discord_summary"
@@ -62,11 +63,10 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:

     logging.info("Getting the summaries embedding and saving within database!")

-    node_parser = configure_node_parser(chunk_size=256)
+    node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

     embed_model = CohereEmbedding()
-    embed_dim = 1024

     # saving thread summaries
     pg_vector.save_documents_in_batches(
@@ -76,7 +76,7 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
     )
@@ -88,7 +88,7 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
     )
@@ -100,7 +100,7 @@ def process_discord_summaries(community_id: str, verbose: bool = False) -> None:
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
     )
7 changes: 4 additions & 3 deletions dags/hivemind_etl_helpers/discord_mongo_vector_store_etl.py
@@ -10,6 +10,7 @@
 )
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
 from hivemind_etl_helpers.src.utils.pg_db_utils import setup_db
 from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess

@@ -24,6 +25,7 @@ def process_discord_guild_mongo(community_id: str) -> None:
     community_id : str
         the community id to create or use its database
     """
+    chunk_size, embedding_dim = load_model_hyperparams()
     guild_id = find_guild_id_by_community_id(community_id)
     logging.info(f"COMMUNITYID: {community_id}, GUILDID: {guild_id}")
     table_name = "discord"
@@ -47,11 +49,10 @@
         from_date += timedelta(seconds=1)

     documents = discord_raw_to_docuemnts(guild_id=guild_id, from_date=from_date)
-    node_parser = configure_node_parser(chunk_size=512)
+    node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

     embed_model = CohereEmbedding()
-    embed_dim = 1024

     pg_vector.save_documents_in_batches(
         community_id=community_id,
@@ -60,7 +61,7 @@ def process_discord_guild_mongo(community_id: str) -> None:
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
         # max_request_per_day=REQUEST_PER_DAY,
     )
11 changes: 6 additions & 5 deletions dags/hivemind_etl_helpers/discourse_summary_etl.py
@@ -10,6 +10,7 @@
 from hivemind_etl_helpers.src.db.discourse.utils.get_forums import get_forums
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
 from hivemind_etl_helpers.src.utils.pg_db_utils import setup_db
 from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess
 from llama_index import Document
@@ -75,6 +76,7 @@ def process_forum(
     forum_endpoint : str
         the DiscourseForum endpoint for document checking
     """
+    chunk_size, embedding_dim = load_model_hyperparams()
     table_name = "discourse_summary"

     latest_date_query = f"""
@@ -113,11 +115,10 @@

     logging.info("Getting the summaries embedding and saving within database!")

-    node_parser = configure_node_parser(chunk_size=256)
+    node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

     embed_model = CohereEmbedding()
-    embed_dim = 1024

     logging.info(
         f"{log_prefix} Saving the topic summaries (and extracting the embedding to save)"
@@ -130,7 +131,7 @@
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
     )
@@ -145,7 +146,7 @@
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
     )
@@ -160,7 +161,7 @@
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         request_per_minute=10000,
     )
     else:
7 changes: 4 additions & 3 deletions dags/hivemind_etl_helpers/discourse_vectorstore_etl.py
@@ -8,6 +8,7 @@
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.check_documents import check_documents
 from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
 from hivemind_etl_helpers.src.utils.pg_db_utils import setup_db
 from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess

@@ -71,6 +72,7 @@ def process_forum(
     forum_endpoint : str
         the DiscourseForum endpoint for document checking
     """
+    chunk_size, embedding_dim = load_model_hyperparams()
     table_name = "discourse"

     latest_date_query = f"""
@@ -90,7 +92,7 @@
     )
     documents = fetch_discourse_documents(forum_id=forum_id, from_date=from_date)

-    node_parser = configure_node_parser(chunk_size=512)
+    node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

     documents, doc_file_ids_to_delete = check_documents(
@@ -112,7 +114,6 @@
     """

     embed_model = CohereEmbedding()
-    embed_dim = 1024

     pg_vector.save_documents_in_batches(
         community_id=community_id,
@@ -121,7 +122,7 @@
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
         doc_file_ids_to_delete=deletion_query,
     )
7 changes: 4 additions & 3 deletions dags/hivemind_etl_helpers/gdrive_etl.py
@@ -9,6 +9,7 @@
 from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
 from hivemind_etl_helpers.src.utils.check_documents import check_documents
 from hivemind_etl_helpers.src.utils.cohere_embedding import CohereEmbedding
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
 from hivemind_etl_helpers.src.utils.pg_vector_access import PGVectorAccess


@@ -32,6 +33,7 @@ def process_gdrive(
     Note: One of `folder_id` or `file_ids` should be given.
     """
+    chunk_size, embedding_dim = load_model_hyperparams()
     table_name = "gdrive"
     dbname = f"community_{community_id}"

@@ -51,7 +53,7 @@
     except TypeError as exp:
         logging.info(f"No documents retrieved from gdrive! exp: {exp}")

-    node_parser = configure_node_parser(chunk_size=256)
+    node_parser = configure_node_parser(chunk_size=chunk_size)
     pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)

     documents, doc_file_ids_to_delete = check_documents(
@@ -67,7 +69,6 @@
     # TODO: Delete the files with id `doc_file_ids_to_delete`

     embed_model = CohereEmbedding()
-    embed_dim = 1024

     pg_vector.save_documents_in_batches(
         community_id=community_id,
@@ -76,7 +77,7 @@
         node_parser=node_parser,
         max_request_per_minute=None,
         embed_model=embed_model,
-        embed_dim=embed_dim,
+        embed_dim=embedding_dim,
     )
2 changes: 1 addition & 1 deletion dags/hivemind_etl_helpers/src/document_node_parser.py
@@ -7,7 +7,7 @@


 def configure_node_parser(
-    chunk_size: int = 256, chunk_overlap: int = 20, **kwargs
+    chunk_size: int, chunk_overlap: int = 20, **kwargs
 ) -> SimpleNodeParser:
     """
     Create SimpleNodeParser from documents
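
With the default removed from configure_node_parser, chunk_size is now a required argument, so callers fail fast instead of silently falling back to 256. A minimal sketch of the new calling convention (the import paths match this commit):

    from hivemind_etl_helpers.src.document_node_parser import configure_node_parser
    from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams

    chunk_size, _embedding_dim = load_model_hyperparams()  # reads CHUNK_SIZE / EMBEDDING_DIM
    node_parser = configure_node_parser(chunk_size=chunk_size)
    # calling configure_node_parser() with no chunk_size now raises TypeError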
30 changes: 30 additions & 0 deletions dags/hivemind_etl_helpers/src/utils/load_llm_params.py
@@ -0,0 +1,30 @@
+import os
+from dotenv import load_dotenv
+
+
+def load_model_hyperparams() -> tuple[int, int]:
+    """
+    load the llm and embedding model hyperparameters (the input parameters)
+
+    Returns
+    ---------
+    chunk_size : int
+        the chunk size to chunk the data
+    embedding_dim : int
+        the embedding dimension
+    """
+    load_dotenv()
+
+    chunk_size = os.getenv("CHUNK_SIZE")
+    if chunk_size is None:
+        raise ValueError("Chunk size is not given in env")
+    else:
+        chunk_size = int(chunk_size)
+
+    embedding_dim = os.getenv("EMBEDDING_DIM")
+    if embedding_dim is None:
+        raise ValueError("Embedding dimension size is not given in env")
+    else:
+        embedding_dim = int(embedding_dim)
+
+    return chunk_size, embedding_dim
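
A minimal usage sketch for the helper, assuming the two variables are present in the process environment (load_dotenv() also picks up a .env file and, by default, does not override variables that are already set):

    import os

    # illustrative values; real deployments define these in .env or the container environment
    os.environ["CHUNK_SIZE"] = "512"
    os.environ["EMBEDDING_DIM"] = "1024"

    from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams

    chunk_size, embedding_dim = load_model_hyperparams()
    assert (chunk_size, embedding_dim) == (512, 1024)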
2 changes: 2 additions & 0 deletions dags/hivemind_etl_helpers/tests/… (file path truncated in this capture)

@@ -3,13 +3,15 @@

 from hivemind_etl_helpers.src.db.discord.discord_summary import DiscordSummary
 from hivemind_etl_helpers.src.utils.mongo import MongoSingleton
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
 from llama_index import Document, MockEmbedding, ServiceContext
 from llama_index.llms import MockLLM


 class TestDiscordGroupedDataPreparation(TestCase):
     def setUp(self):
         self.mock_llm = MockLLM()
+        chunk_size, embedding_dim = load_model_hyperparams()
         self.service_context = ServiceContext.from_defaults(
             llm=MockLLM(), chunk_size=256, embed_model=MockEmbedding(embed_dim=1024)
         )
40 changes: 40 additions & 0 deletions dags/hivemind_etl_helpers/tests/unit/test_load_model_params.py
@@ -0,0 +1,40 @@
+import os
+import unittest
+
+from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams
+
+
+class TestLoadModelHyperparams(unittest.TestCase):
+    def setUp(self):
+        # Set up environment variables for testing
+        os.environ["CHUNK_SIZE"] = "128"
+        os.environ["EMBEDDING_DIM"] = "256"
+
+    def tearDown(self):
+        # Clean up environment variables after testing
+        del os.environ["CHUNK_SIZE"]
+        del os.environ["EMBEDDING_DIM"]
+
+    def test_load_model_hyperparams_success(self):
+        # Test when environment variables are set correctly
+        chunk_size, embedding_dim = load_model_hyperparams()
+        self.assertEqual(chunk_size, 128)
+        self.assertEqual(embedding_dim, 256)
+
+    def test_load_model_hyperparams_invalid_chunk_size(self):
+        # Test when CHUNK_SIZE environment variable is not a valid integer
+        os.environ["CHUNK_SIZE"] = "invalid"
+        with self.assertRaises(ValueError) as context:
+            load_model_hyperparams()
+        self.assertEqual(
+            str(context.exception), "invalid literal for int() with base 10: 'invalid'"
+        )
+
+    def test_load_model_hyperparams_invalid_embedding_dim(self):
+        # Test when EMBEDDING_DIM environment variable is not a valid integer
+        os.environ["EMBEDDING_DIM"] = "invalid"
+        with self.assertRaises(ValueError) as context:
+            load_model_hyperparams()
+        self.assertEqual(
+            str(context.exception), "invalid literal for int() with base 10: 'invalid'"
+        )
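
The suite covers the success path and non-integer values, but not the missing-variable branches that raise the loader's own ValueError. A hypothetical extra case, not part of this commit (it assumes no .env file re-supplies CHUNK_SIZE, since load_model_hyperparams calls load_dotenv internally):

    def test_load_model_hyperparams_missing_chunk_size(self):
        # Hypothetical: an unset CHUNK_SIZE should hit the loader's explicit check
        os.environ.pop("CHUNK_SIZE", None)
        with self.assertRaises(ValueError) as context:
            load_model_hyperparams()
        self.assertEqual(str(context.exception), "Chunk size is not given in env")
        os.environ["CHUNK_SIZE"] = "128"  # restore so tearDown's del succeeds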
