-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added the embedding_dim and chunk_size as env variables!
- Loading branch information
1 parent
fafe587
commit c96d92d
Showing
9 changed files
with
97 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import os | ||
from dotenv import load_dotenv | ||
|
||
|
||
def load_model_hyperparams() -> tuple[int, int]:
    """
    Read the chunking and embedding hyperparameters from the environment.

    `load_dotenv()` is invoked first, then the `CHUNK_SIZE` and
    `EMBEDDING_DIM` environment variables are read and converted with `int()`.

    Returns
    ---------
    chunk_size : int
        the chunk size used to chunk the data
    embedding_dim : int
        the embedding dimension

    Raises
    ---------
    ValueError
        if either variable is missing from the environment, or if its
        value cannot be parsed by `int()` (the built-in conversion error
        is propagated unchanged)
    """
    load_dotenv()

    # CHUNK_SIZE is validated and converted before EMBEDDING_DIM is looked at,
    # so a problem with the chunk size is always the first one reported.
    raw_chunk_size = os.environ.get("CHUNK_SIZE")
    if raw_chunk_size is None:
        raise ValueError("Chunk size is not given in env")
    parsed_chunk_size = int(raw_chunk_size)

    raw_embedding_dim = os.environ.get("EMBEDDING_DIM")
    if raw_embedding_dim is None:
        raise ValueError("Embedding dimension size is not given in env")
    parsed_embedding_dim = int(raw_embedding_dim)

    return parsed_chunk_size, parsed_embedding_dim
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
dags/hivemind_etl_helpers/tests/unit/test_load_model_params.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
import unittest | ||
|
||
from hivemind_etl_helpers.src.utils.load_llm_params import load_model_hyperparams | ||
|
||
|
||
class TestLoadModelHyperparams(unittest.TestCase):
    """Unit tests for `load_model_hyperparams`, driven through environment variables."""

    # Variables overridden for every test, with the values installed by setUp
    _TEST_ENV = {"CHUNK_SIZE": "128", "EMBEDDING_DIM": "256"}

    def setUp(self):
        # Remember what the process already had (None == not set) so that
        # tearDown can put the environment back exactly as it found it.
        self._saved_env = {name: os.environ.get(name) for name in self._TEST_ENV}
        os.environ.update(self._TEST_ENV)

    def tearDown(self):
        # Restore rather than blindly delete: deleting would wipe values that
        # existed before the test run and leak that change into other tests.
        for name, previous in self._saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    def test_load_model_hyperparams_success(self):
        # Test when environment variables are set correctly
        chunk_size, embedding_dim = load_model_hyperparams()
        self.assertEqual(chunk_size, 128)
        self.assertEqual(embedding_dim, 256)

    def test_load_model_hyperparams_invalid_chunk_size(self):
        # Test when CHUNK_SIZE environment variable is not a valid integer
        os.environ["CHUNK_SIZE"] = "invalid"
        with self.assertRaises(ValueError) as context:
            load_model_hyperparams()
        # Checked after the `with` block so the assertion actually executes
        self.assertEqual(
            str(context.exception), "invalid literal for int() with base 10: 'invalid'"
        )

    def test_load_model_hyperparams_invalid_embedding_dim(self):
        # Test when EMBEDDING_DIM environment variable is not a valid integer
        os.environ["EMBEDDING_DIM"] = "invalid"
        with self.assertRaises(ValueError) as context:
            load_model_hyperparams()
        # Checked after the `with` block so the assertion actually executes
        self.assertEqual(
            str(context.exception), "invalid literal for int() with base 10: 'invalid'"
        )