diff --git a/aips/__init__.py b/aips/__init__.py index 507794ef..59b28f51 100644 --- a/aips/__init__.py +++ b/aips/__init__.py @@ -9,7 +9,9 @@ import re import requests -engine_type_map = {"SOLR": SolrEngine()} +from engines.weaviate.WeaviateEngine import WeaviateEngine + +engine_type_map = {"SOLR": SolrEngine(), "WEAVIATE":WeaviateEngine()} def get_engine(): return engine_type_map[environment.get("AIPS_SEARCH_ENGINE", "SOLR")] diff --git a/docker-compose.yml b/docker-compose.yml index 8d14dd75..7b3938b0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,6 +15,7 @@ services: - 2345:2345 # Search Webserver depends_on: - solr + - weaviate networks: - solr-network restart: unless-stopped @@ -47,6 +48,27 @@ services: restart: unless-stopped depends_on: - zookeeper + + weaviate: + build: + context: ./engines/weaviate/build/ + dockerfile: Dockerfile + container_name: aips-weaviate + ports: + - 8090:8080 + - 50051:50051 + volumes: + - weaviate_data:/var/lib/weaviate + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + ENABLE_MODULES: '' + CLUSTER_HOSTNAME: 'node1' + networks: + - solr-network zookeeper: image: zookeeper:3.5.8 @@ -62,3 +84,6 @@ services: networks: zk-solr: solr-network: + +volumes: + weaviate_data: diff --git a/engines/weaviate/WeaviateCollection.py b/engines/weaviate/WeaviateCollection.py new file mode 100644 index 00000000..d8ced0f5 --- /dev/null +++ b/engines/weaviate/WeaviateCollection.py @@ -0,0 +1,82 @@ +from abc import ABC, abstractmethod +from xmlrpc.client import boolean +import aips.environment as env +import json +from engines.Collection import Collection + + +class WeaviateCollection(Collection): + def __init__(self, name): + self.name = name + + # @abstractmethod + def commit(self): + "Force the collection to commit all uncommited data into the collection" + pass + + # @abstractmethod + def write(self, dataframe): + "Writes a pyspark dataframe containing documents into the collection" + pass + + # @abstractmethod + def add_documents(self, docs, commit=True): + "Adds a collection of documents into the collection" + pass + + # @abstractmethod + def transform_request(self, **search_args): + "Transforms a generic search request into a native search request" + pass + + # @abstractmethod + def transform_response(self, search_response): + "Transform a native search response into a generic search response" + pass + + # @abstractmethod + def native_search(self, request=None): + "Executes a search against the search engine given a native search request" + pass + + # @abstractmethod + def vector_search(self, **search_args): + "Executes a vector search given a vector search request" + pass + + # @abstractmethod + def search_for_random_document(self, query): + "Searches for a random document matching the query" + pass + + # @abstractmethod + def spell_check(self, query, log=False): + "Execute a spellcheck against the collection" + pass + + def search(self, **search_args): + """ + Searches the collection + :param str query: The main query for the search request + :param str query_parser: The name of the query parser to use in the search + :param list of str query_fields: the fields to query against + :param list of str return_fields: the fields to return on each document + :param list of tuple of str filters: A list of tuples (field, value) to filter the results by + :param int limit: The number of results to return + :param list of tuple of str order_by: A list of tuples (field, ASC/DESC) to order the results by + :param str rerank_query: A query to rerank the results by + :param str default_operator: Sets the default operator of the search query (AND/OR) + :param str min_match: Specificies the minimum matching constraints for matching documents + :param str query_boosts: A boost query to boost documents at query time + :param tuple of str index_time_boosts: An index time boost + :param boolean explain: Enables debugging on the request + :param boolean log: Enables logging for the query + :param boolean highlight: Returns results with highlight information (if supported) + """ + request = self.transform_request(**search_args) + if "log" in search_args or env.get("PRINT_REQUESTS", False): + print(json.dumps(request, indent=2)) + search_response = self.native_search(request=request) + if "log" in search_args: + print(json.dumps(search_response, indent=2)) + return self.transform_response(search_response) diff --git a/engines/weaviate/WeaviateEngine.py b/engines/weaviate/WeaviateEngine.py new file mode 100644 index 00000000..38c89fa1 --- /dev/null +++ b/engines/weaviate/WeaviateEngine.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from engines.Engine import Engine +from engines.weaviate.WeaviateCollection import WeaviateCollection +import weaviate + + +class WeaviateEngine(Engine): + def __init__(self): + self.client = weaviate.connect_to_local(port=8090) + + def health_check(self): + "Checks the state of the search engine returning a boolean" + return self.client.is_ready() + + def print_status(self, response): + "Prints the resulting status of a search engine request" + if response: + print("Status: Success") + + def create_collection(self, name): + "Create and initialize the schema for a collection, returns the initialized collection" + collection_exists = self.client.collections.exists(name) + if collection_exists: + print(f"Wiping collection {name}") + self.client.collections.delete(name) + + print(f"Creating collection {name}") + collection = self.client.collections.create(name) + + self.apply_schema_for_collection(name) + self.print_status(collection) + return collection + + def get_collection(self, name): + "Returns initialized object for a given collection" + return WeaviateCollection(name) + + def apply_schema_for_collection(self, collection): + "Applies the appriorate schema for a given collection" + print(f"Using auto schema for collection {collection}") + + def enable_ltr(self, collection): + "Initializes LTR dependencies for a given collection" + raise NotImplementedError("¯\\_(ツ)_/¯") diff --git a/engines/weaviate/__init__.py b/engines/weaviate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/engines/weaviate/build/Dockerfile b/engines/weaviate/build/Dockerfile new file mode 100644 index 00000000..0fca63dc --- /dev/null +++ b/engines/weaviate/build/Dockerfile @@ -0,0 +1 @@ +FROM cr.weaviate.io/semitechnologies/weaviate:1.25.3 \ No newline at end of file diff --git a/engines/weaviate/build/docker-compose.yml b/engines/weaviate/build/docker-compose.yml new file mode 100644 index 00000000..200f311e --- /dev/null +++ b/engines/weaviate/build/docker-compose.yml @@ -0,0 +1,27 @@ +--- +services: + weaviate: + command: + - --host + - 0.0.0.0 + - --port + - '8080' + - --scheme + - http + image: cr.weaviate.io/semitechnologies/weaviate:1.25.3 + ports: + - 8090:8080 + - 50051:50051 + volumes: + - weaviate_data:/var/lib/weaviate + restart: on-failure:0 + environment: + QUERY_DEFAULTS_LIMIT: 25 + AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true' + PERSISTENCE_DATA_PATH: '/var/lib/weaviate' + DEFAULT_VECTORIZER_MODULE: 'none' + ENABLE_MODULES: '' + CLUSTER_HOSTNAME: 'node1' +volumes: + weaviate_data: +... \ No newline at end of file diff --git a/engines/weaviate/requirements.txt b/engines/weaviate/requirements.txt new file mode 100644 index 00000000..ac95bf1f --- /dev/null +++ b/engines/weaviate/requirements.txt @@ -0,0 +1,5 @@ +weaviate-client +pytest +IPython +pandas +black \ No newline at end of file diff --git a/engines/weaviate/tests/test_WeaviateEngine.py b/engines/weaviate/tests/test_WeaviateEngine.py new file mode 100644 index 00000000..3a33db4c --- /dev/null +++ b/engines/weaviate/tests/test_WeaviateEngine.py @@ -0,0 +1,40 @@ +import uuid +import pytest +import sys + +sys.path.insert(0, "/workspaces/ai-powered-search/") + +from aips import set_engine, get_engine +from engines.weaviate.WeaviateEngine import WeaviateEngine +from aips import set_engine, get_engine + + +@pytest.fixture +def weaviate_engine(): + set_engine("weaviate") + yield get_engine() + + +def test_health_check(weaviate_engine): + assert weaviate_engine.health_check() == True + + +def test_create_and_get_collection(weaviate_engine): + collection_name = f"TestCollection_{uuid.uuid4().hex}" + weaviate_engine.create_collection(collection_name) + + assert weaviate_engine.client.collections.exists(collection_name) + + collection = weaviate_engine.get_collection(collection_name) + + assert collection.__class__.__name__ == "WeaviateCollection" + assert collection.name == collection_name + + +def test_enable_ltr(weaviate_engine): + collection_name = f"TestCollection_{uuid.uuid4().hex}" + weaviate_engine.create_collection(collection_name) + + collection = weaviate_engine.get_collection(collection_name) + with pytest.raises(NotImplementedError): + weaviate_engine.enable_ltr(collection)