From a467be009103917ef15bb59216f6f6e0155be672 Mon Sep 17 00:00:00 2001
From: Mohammad Amin
Date: Thu, 28 Dec 2023 16:25:50 +0330
Subject: [PATCH] feat: initializing the bot!

---
 .dockerignore                                 |   8 +
 .github/workflows/production.yml              |  12 ++
 .github/workflows/start.staging.yml           |   9 ++
 .gitignore                                    |   2 +
 Dockerfile                                    |  12 ++
 celery_app/__init__.py                        |   0
 celery_app/job_send.py                        |  30 ++++
 celery_app/server.py                          |   5 +
 celery_app/tasks.py                           |  27 ++++
 discord_query.py                              | 147 ++++++++++++++++++
 docker-compose.example.yml                    |  14 ++
 docker-compose.test.yml                       |  71 +++++++++
 docker-entrypoint.sh                          |   3 +
 requirements.txt                              |  20 +++
 retrievers/__init__.py                        |   0
 retrievers/forum_summary_retriever.py         |  74 +++++++++
 retrievers/process_dates.py                   |  39 +++++
 retrievers/summary_retriever_base.py          |  72 +++++++++
 retrievers/utils/__init__.py                  |   0
 retrievers/utils/load_hyperparams.py          |  34 ++++
 tests/__init__.py                             |   0
 tests/integration/__init__.py                 |   0
 tests/unit/__init__.py                        |   0
 tests/unit/test_discord_summary_retriever.py  |  84 ++++++++++
 .../test_load_retriever_hyperparameters.py    |  73 +++++++++
 ...st_process_dates_forum_retriever_search.py |  42 +++++
 tests/unit/test_summary_retriever_base.py     |  30 ++++
 worker.py                                     |  39 +++++
 28 files changed, 847 insertions(+)
 create mode 100644 .dockerignore
 create mode 100644 .github/workflows/production.yml
 create mode 100644 .github/workflows/start.staging.yml
 create mode 100644 Dockerfile
 create mode 100644 celery_app/__init__.py
 create mode 100644 celery_app/job_send.py
 create mode 100644 celery_app/server.py
 create mode 100644 celery_app/tasks.py
 create mode 100644 discord_query.py
 create mode 100644 docker-compose.example.yml
 create mode 100644 docker-compose.test.yml
 create mode 100644 docker-entrypoint.sh
 create mode 100644 requirements.txt
 create mode 100644 retrievers/__init__.py
 create mode 100644 retrievers/forum_summary_retriever.py
 create mode 100644 retrievers/process_dates.py
 create mode 100644 retrievers/summary_retriever_base.py
 create mode 100644 retrievers/utils/__init__.py
 create mode 100644 retrievers/utils/load_hyperparams.py
 create mode 100644 tests/__init__.py
 create mode 100644 tests/integration/__init__.py
 create mode 100644 tests/unit/__init__.py
 create mode 100644 tests/unit/test_discord_summary_retriever.py
 create mode 100644 tests/unit/test_load_retriever_hyperparameters.py
 create mode 100644 tests/unit/test_process_dates_forum_retriever_search.py
 create mode 100644 tests/unit/test_summary_retriever_base.py
 create mode 100644 worker.py

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..dc71603
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,8 @@
+.github/
+
+.coverage/
+.coverage
+coverage
+
+venv/
+.env
diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml
new file mode 100644
index 0000000..a1be27b
--- /dev/null
+++ b/.github/workflows/production.yml
@@ -0,0 +1,12 @@
+name: Production CI/CD Pipeline
+
+on:
+  push:
+    branches:
+      - main
+
+jobs:
+  ci:
+    uses: TogetherCrew/operations/.github/workflows/ci.yml@main
+    secrets:
+      CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }}
\ No newline at end of file
diff --git a/.github/workflows/start.staging.yml b/.github/workflows/start.staging.yml
new file mode 100644
index 0000000..842e3bd
--- /dev/null
+++ b/.github/workflows/start.staging.yml
@@ -0,0 +1,9 @@
+name: Staging CI/CD Pipeline
+
+on: pull_request
+
+jobs:
+  ci:
+    uses: TogetherCrew/operations/.github/workflows/ci.yml@main
+    secrets:
+      CC_TEST_REPORTER_ID: ${{ secrets.CC_TEST_REPORTER_ID }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 68bc17f..1cd6533 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,3 +158,5 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+hivemind-bot-env/*
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e2734a2
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,12 @@
+# It's recommended that we use `bullseye` for Python (alpine isn't suitable as it conflicts with numpy)
+FROM python:3.11-bullseye AS base
+WORKDIR /project
+COPY . .
+RUN pip3 install -r requirements.txt
+
+FROM base AS test
+RUN chmod +x docker-entrypoint.sh
+CMD ["./docker-entrypoint.sh"]
+
+FROM base AS prod
+CMD ["python3", "-m", "celery", "-A", "celery_app.server", "worker", "-l", "INFO"]
diff --git a/celery_app/__init__.py b/celery_app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/celery_app/job_send.py b/celery_app/job_send.py
new file mode 100644
index 0000000..c1f59ae
--- /dev/null
+++ b/celery_app/job_send.py
@@ -0,0 +1,30 @@
+from tc_messageBroker import RabbitMQ
+from tc_messageBroker.rabbit_mq.event import Event
+from tc_messageBroker.rabbit_mq.queue import Queue
+
+
+def job_send(broker_url, port, username, password, res):
+    rabbit_mq = RabbitMQ(
+        broker_url=broker_url, port=port, username=username, password=password
+    )
+
+    content = {
+        "uuid": "d99a1490-fba6-11ed-b9a9-0d29e7612dp8",
+        "data": f"some results {res}",
+    }
+
+    rabbit_mq.connect(Queue.DISCORD_ANALYZER)
+    rabbit_mq.publish(
+        queue_name=Queue.DISCORD_ANALYZER,
+        event=Event.DISCORD_BOT.FETCH,
+        content=content,
+    )
+
+
+if __name__ == "__main__":
+    # TODO: read from .env
+    broker_url = "localhost"
+    port = 5672
+    username = "root"
+    password = "pass"
+    job_send(broker_url, port, username, password, "CALLED FROM __main__")
diff --git a/celery_app/server.py b/celery_app/server.py
new file mode 100644
index 0000000..c1c44d4
--- /dev/null
+++ b/celery_app/server.py
@@ -0,0 +1,5 @@
+from celery import Celery
+
+# TODO: read from .env
+app = Celery("celery_app/tasks", broker="pyamqp://root:pass@localhost//")
+app.autodiscover_tasks(["celery_app"])
diff --git a/celery_app/tasks.py b/celery_app/tasks.py
new file mode 100644
index 0000000..73652a7
--- /dev/null
+++ b/celery_app/tasks.py
@@ -0,0 +1,27 @@
+from celery_app.server import app
+from celery_app.job_send import job_send
+
+# TODO: Write tasks that match our requirements
+
+
+@app.task
+def add(x, y):
+    broker_url = "localhost"
+    port = 5672
+    username = "root"
+    password = "pass"
+
+    res = x + y
+    job_send(broker_url, port, username, password, res)
+
+    return res
+
+
+@app.task
+def mul(x, y):
+    return x * y
+
+
+@app.task
+def xsum(numbers):
+    return sum(numbers)
diff --git a/discord_query.py b/discord_query.py
new file mode 100644
index 0000000..6bb6fb8
--- /dev/null
+++ b/discord_query.py
@@ -0,0 +1,147 @@
+from retrievers.forum_summary_retriever import (
+    ForumBasedSummaryRetriever,
+)
+from retrievers.process_dates import process_dates
+from retrievers.utils.load_hyperparams import load_hyperparams
+from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
+from tc_hivemind_backend.pg_vector_access import PGVectorAccess
+from llama_index import QueryBundle
+from llama_index.vector_stores import ExactMatchFilter, FilterCondition, MetadataFilters
+
+
+def query_discord(
+    community_id: str,
+    query: str,
+    thread_names: list[str],
+    channel_names: list[str],
+    days: list[str],
+    similarity_top_k: int | None = None,
+) -> str:
+    """
+    query the discord database using the given filters
+    and give an answer to the given query using the LLM
+
+    Parameters
+    ------------
+    community_id : str
+        the discord community data to query
+    query : str
+        the query (question) of the user
+    thread_names : list[str]
+        the given threads to search for
+    channel_names : list[str]
+        the given channels to search for
+    days : list[str]
+        the given days to search for
+    similarity_top_k : int | None
+        the k similar results to use when querying the data
+        if `None` will load from `.env` file
+
+    Returns
+    ---------
+    response : str
+        the LLM response given the query
+    """
+    if similarity_top_k is None:
+        _, similarity_top_k, _ = load_hyperparams()
+
+    table_name = "discord"
+    dbname = f"community_{community_id}"
+
+    pg_vector = PGVectorAccess(table_name=table_name, dbname=dbname)
+
+    index = pg_vector.load_index()
+
+    thread_filters: list[ExactMatchFilter] = []
+    channel_filters: list[ExactMatchFilter] = []
+    day_filters: list[ExactMatchFilter] = []
+
+    for channel in channel_names:
+        channel_updated = channel.replace("'", "''")
+        channel_filters.append(ExactMatchFilter(key="channel", value=channel_updated))
+
+    for thread in thread_names:
+        thread_updated = thread.replace("'", "''")
+        thread_filters.append(ExactMatchFilter(key="thread", value=thread_updated))
+
+    for day in days:
+        day_filters.append(ExactMatchFilter(key="date", value=day))
+
+    all_filters: list[ExactMatchFilter] = []
+    all_filters.extend(thread_filters)
+    all_filters.extend(channel_filters)
+    all_filters.extend(day_filters)
+
+    filters = MetadataFilters(filters=all_filters, condition=FilterCondition.OR)
+
+    query_engine = index.as_query_engine(
+        filters=filters, similarity_top_k=similarity_top_k
+    )
+
+    query_bundle = QueryBundle(
+        query_str=query, embedding=CohereEmbedding().get_text_embedding(text=query)
+    )
+    response = query_engine.query(query_bundle)
+
+    return response.response
+
+
+def query_discord_auto_filter(
+    community_id: str,
+    query: str,
+    similarity_top_k: int | None = None,
+    d: int | None = None,
+) -> str:
+    """
+    get the query results and do the filtering automatically.
+    By automatically we mean it would first query the summaries
+    to get the metadata filters
+
+    Parameters
+    -----------
+    community_id : str
+        the discord community data to query
+    query : str
+        the query (question) of the user
+    similarity_top_k : int | None
+        the value for the initial summary search
+        to get the `k1` count similar nodes
+        if `None`, then would read from `.env`
+    d : int | None
+        this would make the secondary search (`query_discord`)
+        run on the range `metadata.date - d` to `metadata.date + d`
+
+
+    Returns
+    ---------
+    response : str
+        the LLM response given the query
+    """
+    table_name = "discord_summary"
+    dbname = f"community_{community_id}"
+
+    if d is None:
+        _, _, d = load_hyperparams()
+    if similarity_top_k is None:
+        similarity_top_k, _, _ = load_hyperparams()
+
+    discord_retriever = ForumBasedSummaryRetriever(table_name=table_name, dbname=dbname)
+
+    channels, threads, dates = discord_retriever.retrieve_metadata(
+        query=query,
+        metadata_group1_key="channel",
+        metadata_group2_key="thread",
+        metadata_date_key="date",
+        similarity_top_k=similarity_top_k,
+    )
+
+    dates_modified = process_dates(dates, d)
+
+    response = query_discord(
+        community_id=community_id,
+        query=query,
+        thread_names=threads,
+        channel_names=channels,
+        days=dates_modified,
+    )
+    return response
diff --git a/docker-compose.example.yml b/docker-compose.example.yml
new file mode 100644
index 0000000..0eccf2e
--- /dev/null
+++ b/docker-compose.example.yml
@@ -0,0 +1,14 @@
+version: "3.9"
+
+services:
+  server:
+    build:
+      context: .
+      target: prod
+      dockerfile: Dockerfile
+  worker:
+    build:
+      context: .
+      target: prod
+      dockerfile: Dockerfile
+    command: python3 worker.py
diff --git a/docker-compose.test.yml b/docker-compose.test.yml
new file mode 100644
index 0000000..97fbcea
--- /dev/null
+++ b/docker-compose.test.yml
@@ -0,0 +1,71 @@
+version: "3.9"
+
+services:
+  app:
+    build:
+      context: .
+      target: test
+      dockerfile: Dockerfile
+    environment:
+      - PORT=3000
+      - MONGODB_HOST=mongo
+      - MONGODB_PORT=27017
+      - MONGODB_USER=root
+      - MONGODB_PASS=pass
+      - NEO4J_PROTOCOL=bolt
+      - NEO4J_HOST=neo4j
+      - NEO4J_PORT=7687
+      - NEO4J_USER=neo4j
+      - NEO4J_PASSWORD=password
+      - NEO4J_DB=neo4j
+      - POSTGRES_HOST=postgres
+      - POSTGRES_USER=root
+      - POSTGRES_PASS=pass
+      - POSTGRES_PORT=5432
+      - CHUNK_SIZE=512
+      - EMBEDDING_DIM=1024
+      - K1_RETRIEVER_SEARCH=20
+      - K2_RETRIEVER_SEARCH=5
+      - D_RETRIEVER_SEARCH=7
+    volumes:
+      - ./coverage:/project/coverage
+    depends_on:
+      neo4j:
+        condition: service_healthy
+      mongo:
+        condition: service_healthy
+      postgres:
+        condition: service_healthy
+  neo4j:
+    image: "neo4j:5.9.0"
+    environment:
+      - NEO4J_AUTH=neo4j/password
+      - NEO4J_PLUGINS=["apoc", "graph-data-science"]
+      - NEO4J_dbms_security_procedures_unrestricted=apoc.*,gds.*
+    healthcheck:
+      test: ["CMD", "wget", "http://localhost:7474"]
+      interval: 1m30s
+      timeout: 10s
+      retries: 2
+      start_period: 40s
+  mongo:
+    image: "mongo:6.0.8"
+    environment:
+      - MONGO_INITDB_ROOT_USERNAME=root
+      - MONGO_INITDB_ROOT_PASSWORD=pass
+    healthcheck:
+      test: echo 'db.stats().ok' | mongosh localhost:27017/test --quiet
+      interval: 60s
+      timeout: 10s
+      retries: 2
+      start_period: 40s
+  postgres:
+    image: "ankane/pgvector"
+    environment:
+      - POSTGRES_USER=root
+      - POSTGRES_PASSWORD=pass
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready"]
+      interval: 10s
+      timeout: 5s
+      retries: 5
diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh
new file mode 100644
index 0000000..5127573
--- /dev/null
+++ b/docker-entrypoint.sh
@@ -0,0 +1,3 @@
+#!/usr/bin/env bash
+python3 -m coverage run --omit=tests/* -m pytest .
+python3 -m coverage lcov -o coverage/lcov.info
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c7886f5
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+numpy
+llama-index>=0.9.21, <1.0.0
+pymongo
+python-dotenv
+pgvector
+asyncpg
+psycopg2-binary
+sqlalchemy[asyncio]
+async-sqlalchemy
+python-pptx
+tc-neo4j-lib
+google-api-python-client
+unstructured
+cohere
+neo4j>=5.14.1, <6.0.0
+coverage>=7.3.3, <8.0.0
+pytest>=7.4.3, <8.0.0
+python-dotenv==1.0.0
+tc_hivemind_backend==1.0.0
+celery>=5.3.6, <6.0.0
diff --git a/retrievers/__init__.py b/retrievers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/retrievers/forum_summary_retriever.py b/retrievers/forum_summary_retriever.py
new file mode 100644
index 0000000..58b4d3e
--- /dev/null
+++ b/retrievers/forum_summary_retriever.py
@@ -0,0 +1,74 @@
+from retrievers.summary_retriever_base import BaseSummarySearch
+from tc_hivemind_backend.embeddings.cohere import CohereEmbedding
+
+from llama_index.embeddings import BaseEmbedding
+
+
+class ForumBasedSummaryRetriever(BaseSummarySearch):
+    def __init__(
+        self,
+        table_name: str,
+        dbname: str,
+        embedding_model: BaseEmbedding | CohereEmbedding = CohereEmbedding(),
+    ) -> None:
+        """
+        the class for forum-based data like discord and discourse
+        by default CohereEmbedding will be used.
+ """ + super().__init__(table_name, dbname, embedding_model=embedding_model) + + def retreive_metadata( + self, + query: str, + metadata_group1_key: str, + metadata_group2_key: str, + metadata_date_key: str, + similarity_top_k: int = 20, + ) -> tuple[set[str], set[str], set[str]]: + """ + retrieve the metadata information of the similar nodes with the query + + Parameters + ----------- + query : str + the user query to process + metadata_group1_key : str + the conversations grouping type 1 + in discord can be `channel`, and in discourse can be `category` + metadata_group2_key : str + the conversations grouping type 2 + in discord can be `thread`, and in discourse can be `topic` + metadata_date_key : str + the daily metadata saved key + similarity_top_k : int + the top k nodes to get as the retriever. + default is set as 20 + + + Returns + --------- + group1_data : set[str] + the similar summary nodes having the group1_data. + can be an empty set meaning no similar thread + conversations for it was available. + group2_data : set[str] + the similar summary nodes having the group2_data. + can be an empty set meaning no similar channel + conversations for it was available. + dates : set[str] + the similar daily conversations to the given query + """ + nodes = self.get_similar_nodes(query=query, similarity_top_k=similarity_top_k) + + group1_data: set[str] = set() + dates: set[str] = set() + group2_data: set[str] = set() + + for node in nodes: + if node.metadata[metadata_group1_key]: + group1_data.add(node.metadata[metadata_group1_key]) + if node.metadata[metadata_group2_key]: + group2_data.add(node.metadata[metadata_group2_key]) + dates.add(node.metadata[metadata_date_key]) + + return group1_data, group2_data, dates diff --git a/retrievers/process_dates.py b/retrievers/process_dates.py new file mode 100644 index 0000000..dba3217 --- /dev/null +++ b/retrievers/process_dates.py @@ -0,0 +1,39 @@ +import logging +from datetime import timedelta + +from dateutil import parser + + +def process_dates(dates: list[str], d: int) -> list[str]: + """ + process the dates to be from `date - d` to `date + d` + + Parameters + ------------ + dates : list[str] + the list of dates given + d : int + to update the `dates` list to have `-d` and `+d` days + + + Returns + ---------- + dates_modified : list[str] + days added to it + """ + dates_modified: list[str] = [] + if dates != []: + lowest_date = min(parser.parse(date) for date in dates) + greatest_date = max(parser.parse(date) for date in dates) + + delta_days = timedelta(days=d) + + # the date condition + dt = lowest_date - delta_days + while dt <= greatest_date + delta_days: + dates_modified.append(dt.strftime("%Y-%m-%d")) + dt += timedelta(days=1) + else: + logging.warning("No dates given!") + + return dates_modified diff --git a/retrievers/summary_retriever_base.py b/retrievers/summary_retriever_base.py new file mode 100644 index 0000000..8cedca8 --- /dev/null +++ b/retrievers/summary_retriever_base.py @@ -0,0 +1,72 @@ +from tc_hivemind_backend.embeddings.cohere import CohereEmbedding + +from tc_hivemind_backend.pg_vector_access import PGVectorAccess +from llama_index import VectorStoreIndex +from llama_index.embeddings import BaseEmbedding +from llama_index.indices.query.schema import QueryBundle +from llama_index.schema import NodeWithScore + + +class BaseSummarySearch: + def __init__( + self, + table_name: str, + dbname: str, + embedding_model: BaseEmbedding = CohereEmbedding(), + ) -> None: + """ + initialize the base summary search class + + In this 
+        for available saved nodes under postgresql.
+        The index over the given table is loaded once
+        at construction time via `_setup_index`
+        using `PGVectorAccess`.
+
+        Parameters
+        -------------
+        table_name : str
+            the table that summary data is saved in
+            *Note:* don't include the `data_` prefix of the table,
+            since llama_index would add that itself.
+        dbname : str
+            the database name to access
+        embedding_model : llama_index.embeddings.BaseEmbedding
+            the embedding model to use for doing embedding on the query string
+            default would be CohereEmbedding that we've written
+        """
+        self.index = self._setup_index(table_name, dbname)
+        self.embedding_model = embedding_model
+
+    def get_similar_nodes(
+        self, query: str, similarity_top_k: int = 20
+    ) -> list[NodeWithScore]:
+        """
+        get k similar nodes to the query.
+        Note: this function would get the embedding
+        for the query to do the similarity search.
+
+        Parameters
+        ------------
+        query : str
+            the user query to process
+        similarity_top_k : int
+            the top k nodes to get from the retriever.
+            default is set as 20
+        """
+        retriever = self.index.as_retriever(similarity_top_k=similarity_top_k)
+
+        query_embedding = self.embedding_model.get_text_embedding(text=query)
+
+        query_bundle = QueryBundle(query_str=query, embedding=query_embedding)
+        nodes = retriever._retrieve(query_bundle)
+
+        return nodes
+
+    def _setup_index(self, table_name: str, dbname: str) -> VectorStoreIndex:
+        """
+        setup the llama_index VectorStoreIndex
+        """
+        pg_vector_access = PGVectorAccess(table_name=table_name, dbname=dbname)
+        index = pg_vector_access.load_index()
+        return index
diff --git a/retrievers/utils/__init__.py b/retrievers/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/retrievers/utils/load_hyperparams.py b/retrievers/utils/load_hyperparams.py
new file mode 100644
index 0000000..98db6ce
--- /dev/null
+++ b/retrievers/utils/load_hyperparams.py
@@ -0,0 +1,34 @@
+import os
+
+from dotenv import load_dotenv
+
+
+def load_hyperparams() -> tuple[int, int, int]:
+    """
+    load the k1, k2, and d hyperparams that are used for retrievers
+
+    Returns
+    ---------
+    k1 : int
+        the value for the first summary search
+        to get the `k1` count similar nodes
+    k2 : int
+        the value for the secondary raw search
+        to get the `k2` count similar nodes
+    d : int
+        the before and after day interval
+    """
+    load_dotenv()
+
+    k1 = os.getenv("K1_RETRIEVER_SEARCH")
+    k2 = os.getenv("K2_RETRIEVER_SEARCH")
+    d = os.getenv("D_RETRIEVER_SEARCH")
+
+    if k1 is None:
+        raise ValueError("No `K1_RETRIEVER_SEARCH` available in .env file!")
+    if k2 is None:
+        raise ValueError("No `K2_RETRIEVER_SEARCH` available in .env file!")
+    if d is None:
+        raise ValueError("No `D_RETRIEVER_SEARCH` available in .env file!")
+
+    return int(k1), int(k2), int(d)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/unit/test_discord_summary_retriever.py b/tests/unit/test_discord_summary_retriever.py
new file mode 100644
index 0000000..915742c
--- /dev/null
+++ b/tests/unit/test_discord_summary_retriever.py
@@ -0,0 +1,84 @@
+from datetime import timedelta
+from functools import partial
+from unittest import TestCase
+from unittest.mock import MagicMock
+
+from retrievers.forum_summary_retriever import (
+    ForumBasedSummaryRetriever,
+)
+from dateutil import parser
+from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex
+
+
+class TestDiscordSummaryRetriever(TestCase):
+    def test_initialize_class(self):
+        ForumBasedSummaryRetriever._setup_index = MagicMock()
+        documents: list[Document] = []
+        all_dates: list[str] = []
+
+        for i in range(30):
+            date = parser.parse("2023-08-01") + timedelta(days=i)
+            doc_date = date.strftime("%Y-%m-%d")
+            doc = Document(
+                text="SAMPLESAMPLESAMPLE",
+                metadata={
+                    "thread": f"thread{i % 5}",
+                    "channel": f"channel{i % 3}",
+                    "date": doc_date,
+                },
+            )
+            all_dates.append(doc_date)
+            documents.append(doc)
+
+        mock_embedding_model = partial(MockEmbedding, embed_dim=1024)
+
+        service_context = ServiceContext.from_defaults(
+            llm=None, embed_model=mock_embedding_model()
+        )
+        ForumBasedSummaryRetriever._setup_index.return_value = (
+            VectorStoreIndex.from_documents(
+                documents=documents, service_context=service_context
+            )
+        )
+
+        base_summary_search = ForumBasedSummaryRetriever(
+            table_name="sample",
+            dbname="sample",
+            embedding_model=mock_embedding_model(),
+        )
+        channels, threads, dates = base_summary_search.retrieve_metadata(
+            query="what is samplesample?",
+            similarity_top_k=5,
+            metadata_group1_key="channel",
+            metadata_group2_key="thread",
+            metadata_date_key="date",
+        )
+        self.assertIsInstance(threads, set)
+        self.assertIsInstance(channels, set)
+        self.assertIsInstance(dates, set)
+
+        self.assertTrue(
+            threads.issubset(
+                set(
+                    [
+                        "thread0",
+                        "thread1",
+                        "thread2",
+                        "thread3",
+                        "thread4",
+                    ]
+                )
+            )
+        )
+        self.assertTrue(
+            channels.issubset(
+                set(
+                    [
+                        "channel0",
+                        "channel1",
+                        "channel2",
+                    ]
+                )
+            )
+        )
+        self.assertTrue(dates.issubset(all_dates))
diff --git a/tests/unit/test_load_retriever_hyperparameters.py b/tests/unit/test_load_retriever_hyperparameters.py
new file mode 100644
index 0000000..eadcbdc
--- /dev/null
+++ b/tests/unit/test_load_retriever_hyperparameters.py
@@ -0,0 +1,73 @@
+import unittest
+from unittest.mock import patch
+
+from retrievers.utils.load_hyperparams import load_hyperparams
+
+
+class TestLoadHyperparams(unittest.TestCase):
+    @patch("os.getenv")
+    def test_valid_hyperparams(self, mock_getenv):
+        mock_getenv.side_effect = lambda x: {
+            "K1_RETRIEVER_SEARCH": "10",
+            "K2_RETRIEVER_SEARCH": "20",
+            "D_RETRIEVER_SEARCH": "30",
+        }.get(x)
+        result = load_hyperparams()
+        self.assertEqual(result, (10, 20, 30))
+
+    @patch("os.getenv")
+    def test_missing_k1(self, mock_getenv):
+        mock_getenv.side_effect = lambda x: {
+            "K2_RETRIEVER_SEARCH": "20",
+            "D_RETRIEVER_SEARCH": "30",
+        }.get(x)
+        with self.assertRaises(ValueError):
+            load_hyperparams()
+
+    @patch("os.getenv")
+    def test_missing_k2(self, mock_getenv):
+        mock_getenv.side_effect = lambda x: {
+            "K1_RETRIEVER_SEARCH": "10",
+            "D_RETRIEVER_SEARCH": "30",
+        }.get(x)
+        with self.assertRaises(ValueError):
+            load_hyperparams()
+
+    @patch("os.getenv")
+    def test_missing_d(self, mock_getenv):
+        mock_getenv.side_effect = lambda x: {
+            "K1_RETRIEVER_SEARCH": "10",
+            "K2_RETRIEVER_SEARCH": "20",
+        }.get(x)
+        with self.assertRaises(ValueError):
+            load_hyperparams()
+
+    @patch("os.getenv")
+    def test_invalid_k1(self, mock_getenv):
+        mock_getenv.side_effect = lambda x: {
+            "K1_RETRIEVER_SEARCH": "invalid",
+            "K2_RETRIEVER_SEARCH": "20",
+            "D_RETRIEVER_SEARCH": "30",
+        }.get(x)
+        with self.assertRaises(ValueError):
+            load_hyperparams()
+
+    @patch("os.getenv")
+    def test_invalid_k2(self, mock_getenv):
+        mock_getenv.side_effect = lambda x: {
"K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "invalid", + "D_RETRIEVER_SEARCH": "30", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() + + @patch("os.getenv") + def test_invalid_d(self, mock_getenv): + mock_getenv.side_effect = lambda x: { + "K1_RETRIEVER_SEARCH": "10", + "K2_RETRIEVER_SEARCH": "20", + "D_RETRIEVER_SEARCH": "invalid", + }.get(x) + with self.assertRaises(ValueError): + load_hyperparams() diff --git a/tests/unit/test_process_dates_forum_retriever_search.py b/tests/unit/test_process_dates_forum_retriever_search.py new file mode 100644 index 0000000..f580c82 --- /dev/null +++ b/tests/unit/test_process_dates_forum_retriever_search.py @@ -0,0 +1,42 @@ +import unittest + +from retrievers.process_dates import process_dates + + +class TestProcessDates(unittest.TestCase): + def test_process_dates_with_valid_input(self): + # Test with a valid input + input_dates = ["2023-01-01", "2023-01-03", "2023-01-05"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_empty_input(self): + # Test with an empty input + input_dates = [] + d = 2 + expected_output = [] + self.assertEqual(process_dates(input_dates, d), expected_output) + + def test_process_dates_with_single_date(self): + # Test with a single date in the input + input_dates = ["2023-01-01"] + d = 2 + expected_output = [ + "2022-12-30", + "2022-12-31", + "2023-01-01", + "2023-01-02", + "2023-01-03", + ] + self.assertEqual(process_dates(input_dates, d), expected_output) diff --git a/tests/unit/test_summary_retriever_base.py b/tests/unit/test_summary_retriever_base.py new file mode 100644 index 0000000..f630d94 --- /dev/null +++ b/tests/unit/test_summary_retriever_base.py @@ -0,0 +1,30 @@ +from functools import partial +from unittest import TestCase +from unittest.mock import MagicMock + +from retrievers.summary_retriever_base import BaseSummarySearch +from llama_index import Document, MockEmbedding, ServiceContext, VectorStoreIndex +from llama_index.schema import NodeWithScore + + +class TestSummaryRetrieverBase(TestCase): + def test_initialize_class(self): + BaseSummarySearch._setup_index = MagicMock() + doc = Document(text="SAMPLESAMPLESAMPLE") + mock_embedding_model = partial(MockEmbedding, embed_dim=1024) + + service_context = ServiceContext.from_defaults( + llm=None, embed_model=mock_embedding_model() + ) + BaseSummarySearch._setup_index.return_value = VectorStoreIndex.from_documents( + documents=[doc], service_context=service_context + ) + + base_summary_search = BaseSummarySearch( + table_name="sample", + dbname="sample", + embedding_model=mock_embedding_model(), + ) + nodes = base_summary_search.get_similar_nodes(query="what is samplesample?") + self.assertIsInstance(nodes, list) + self.assertIsInstance(nodes[0], NodeWithScore) diff --git a/worker.py b/worker.py new file mode 100644 index 0000000..dcea119 --- /dev/null +++ b/worker.py @@ -0,0 +1,39 @@ +from tc_messageBroker import RabbitMQ +from tc_messageBroker.rabbit_mq.event import Event +from tc_messageBroker.rabbit_mq.queue import Queue + +from celery_app.tasks import add + + +# TODO: Update according to our requirements +def do_something(recieved_data): + message = f"Calculation Results:" + print(message) + print(f"recieved_data: {recieved_data}") + add.delay(20, 14) + + +def job_recieve(broker_url, port, username, 
+    rabbit_mq = RabbitMQ(
+        broker_url=broker_url, port=port, username=username, password=password
+    )
+
+    # TODO: Update according to our requirements
+    rabbit_mq.on_event(Event.HIVEMIND.INTERACTION_CREATED, do_something)
+    rabbit_mq.connect(Queue.HIVEMIND)
+    rabbit_mq.consume(Queue.HIVEMIND)
+
+    if rabbit_mq.channel is not None:
+        rabbit_mq.channel.start_consuming()
+    else:
+        print("Connection to broker was not successful!")
+
+
+if __name__ == "__main__":
+    # TODO: read from .env
+    broker_url = "localhost"
+    port = 5672
+    username = "root"
+    password = "pass"
+
+    job_receive(broker_url, port, username, password)
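---

Usage note (a minimal sketch, not part of the patch): the retrieval flow introduced above can be driven end to end through `query_discord_auto_filter`, which first searches the `discord_summary` table for matching channels, threads, and dates, then runs the filtered `query_discord` search over the raw messages. The snippet assumes a populated `community_<id>` pgvector database, Cohere credentials, and the `K1_RETRIEVER_SEARCH`, `K2_RETRIEVER_SEARCH`, and `D_RETRIEVER_SEARCH` values present in `.env`; the community id shown is a placeholder.

    from discord_query import query_discord_auto_filter

    # placeholder community id; a real deployment would receive this from
    # the message broker payload consumed in worker.py
    answer = query_discord_auto_filter(
        community_id="1234",
        query="What did the community discuss about onboarding?",
        # similarity_top_k and d are omitted, so they fall back to the
        # K1_RETRIEVER_SEARCH / D_RETRIEVER_SEARCH values loaded from .env
    )
    print(answer)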