Add support for Weaviate #180
Draft
hsm207 wants to merge 12 commits into treygrainger:main from hsm207:weaviate
+227 −1
Commits (12, all by hsm207)
da66e05 add weaviate container
b0e3370 fix weaviate port mapping
3f7903e update dependencies
b90f340 stub weaviate engine
141d4d5 implement health check
66a57c6 update dependencies
8b40991 implement create_collection
3cb5943 reformat code
6bb6237 implement get collection
f19e617 code cleanup
8554402 implement enable ltr
0c60901 code cleanup
@@ -0,0 +1,82 @@
from abc import ABC, abstractmethod
from xmlrpc.client import boolean
import aips.environment as env
import json
from engines.Collection import Collection


class WeaviateCollection(Collection):
    def __init__(self, name):
        self.name = name

    # @abstractmethod
    def commit(self):
        "Force the collection to commit all uncommitted data into the collection"
        pass

    # @abstractmethod
    def write(self, dataframe):
        "Writes a pyspark dataframe containing documents into the collection"
        pass

    # @abstractmethod
    def add_documents(self, docs, commit=True):
        "Adds a collection of documents into the collection"
        pass

    # @abstractmethod
    def transform_request(self, **search_args):
        "Transforms a generic search request into a native search request"
        pass

    # @abstractmethod
    def transform_response(self, search_response):
        "Transform a native search response into a generic search response"
        pass

    # @abstractmethod
    def native_search(self, request=None):
        "Executes a search against the search engine given a native search request"
        pass

    # @abstractmethod
    def vector_search(self, **search_args):
        "Executes a vector search given a vector search request"
        pass

    # @abstractmethod
    def search_for_random_document(self, query):
        "Searches for a random document matching the query"
        pass

    # @abstractmethod
    def spell_check(self, query, log=False):
        "Execute a spellcheck against the collection"
        pass

    def search(self, **search_args):
        """
        Searches the collection
        :param str query: The main query for the search request
        :param str query_parser: The name of the query parser to use in the search
        :param list of str query_fields: the fields to query against
        :param list of str return_fields: the fields to return on each document
        :param list of tuple of str filters: A list of tuples (field, value) to filter the results by
        :param int limit: The number of results to return
        :param list of tuple of str order_by: A list of tuples (field, ASC/DESC) to order the results by
        :param str rerank_query: A query to rerank the results by
        :param str default_operator: Sets the default operator of the search query (AND/OR)
        :param str min_match: Specifies the minimum matching constraints for matching documents
        :param str query_boosts: A boost query to boost documents at query time
        :param tuple of str index_time_boosts: An index time boost
        :param boolean explain: Enables debugging on the request
        :param boolean log: Enables logging for the query
        :param boolean highlight: Returns results with highlight information (if supported)
        """
        request = self.transform_request(**search_args)
        if "log" in search_args or env.get("PRINT_REQUESTS", False):
            print(json.dumps(request, indent=2))
        search_response = self.native_search(request=request)
        if "log" in search_args:
            print(json.dumps(search_response, indent=2))
        return self.transform_response(search_response)
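The `search` method above chains three hooks (`transform_request`, `native_search`, `transform_response`), all of which are still stubs in this diff. A minimal sketch of that flow, using an in-memory stand-in rather than a Weaviate client. `InMemoryCollection` and its request/response shapes are hypothetical and are not part of the PR or of the Weaviate API:

```python
class InMemoryCollection:
    """Hypothetical stand-in showing the three hooks that search() chains together."""

    def __init__(self, name, docs):
        self.name = name
        self.docs = docs  # list of dicts standing in for indexed documents

    def transform_request(self, **search_args):
        # Generic arguments -> "native" request (shape invented for this sketch).
        return {
            "q": search_args.get("query", "").lower(),
            "fields": search_args.get("query_fields", []),
            "limit": search_args.get("limit", 10),
        }

    def native_search(self, request=None):
        # Naive case-insensitive substring match over the requested fields.
        hits = [
            doc
            for doc in self.docs
            if any(request["q"] in str(doc.get(f, "")).lower() for f in request["fields"])
        ]
        return {"hits": hits[: request["limit"]]}

    def transform_response(self, search_response):
        # "Native" response -> generic response.
        return {"docs": search_response["hits"]}

    def search(self, **search_args):
        request = self.transform_request(**search_args)
        return self.transform_response(self.native_search(request=request))


collection = InMemoryCollection(
    "products",
    [{"title": "Red bicycle"}, {"title": "Blue kettle"}, {"title": "Bicycle pump"}],
)
results = collection.search(query="bicycle", query_fields=["title"], limit=5)
```

Keeping the translation steps separate from `native_search` means the generic `search` entry point stays identical across engines; only the three hooks would need Weaviate-specific bodies.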
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from engines.Engine import Engine
from engines.weaviate.WeaviateCollection import WeaviateCollection
import weaviate


class WeaviateEngine(Engine):
    def __init__(self):
        self.client = weaviate.connect_to_local(port=8090)

    def health_check(self):
        "Checks the state of the search engine returning a boolean"
        return self.client.is_ready()

    def print_status(self, response):
        "Prints the resulting status of a search engine request"
        if response:
            print("Status: Success")

    def create_collection(self, name):
        "Create and initialize the schema for a collection, returns the initialized collection"
        collection_exists = self.client.collections.exists(name)
        if collection_exists:
            print(f"Wiping collection {name}")
            self.client.collections.delete(name)

        print(f"Creating collection {name}")
        collection = self.client.collections.create(name)

        self.apply_schema_for_collection(name)
        self.print_status(collection)
        return collection

    def get_collection(self, name):
        "Returns initialized object for a given collection"
        return WeaviateCollection(name)

    def apply_schema_for_collection(self, collection):
        "Applies the appropriate schema for a given collection"
        print(f"Using auto schema for collection {collection}")

    def enable_ltr(self, collection):
        "Initializes LTR dependencies for a given collection"
        raise NotImplementedError("¯\\_(ツ)_/¯")
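`create_collection` above wipes an existing collection before recreating it. A sketch of that drop-and-recreate flow against a hypothetical in-memory client. `FakeCollections` only mimics the `exists`/`delete`/`create` calls used in the diff and is not the weaviate client; `recreate_collection` is likewise an illustrative name:

```python
class FakeCollections:
    """Hypothetical stand-in for client.collections; tracks names in a set."""

    def __init__(self):
        self.names = set()
        self.log = []  # records the order of calls for inspection

    def exists(self, name):
        return name in self.names

    def delete(self, name):
        self.names.discard(name)
        self.log.append(("delete", name))

    def create(self, name):
        self.names.add(name)
        self.log.append(("create", name))
        return name


def recreate_collection(collections, name):
    # Drop any existing collection first so the caller always starts empty.
    if collections.exists(name):
        collections.delete(name)
    return collections.create(name)


collections = FakeCollections()
recreate_collection(collections, "Products")
recreate_collection(collections, "Products")  # second call wipes, then recreates
```

After both calls, `collections.log` holds one create followed by a delete and a second create, which is the behavior the "Wiping collection" branch is meant to produce.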
Empty file.
@@ -0,0 +1 @@
FROM cr.weaviate.io/semitechnologies/weaviate:1.25.3
@@ -0,0 +1,27 @@
---
services:
  weaviate:
    command:
    - --host
    - 0.0.0.0
    - --port
    - '8080'
    - --scheme
    - http
    image: cr.weaviate.io/semitechnologies/weaviate:1.25.3
    ports:
    - 8090:8080
    - 50051:50051
    volumes:
    - weaviate_data:/var/lib/weaviate
    restart: on-failure:0
    environment:
      QUERY_DEFAULTS_LIMIT: 25
      AUTHENTICATION_ANONYMOUS_ACCESS_ENABLED: 'true'
      PERSISTENCE_DATA_PATH: '/var/lib/weaviate'
      DEFAULT_VECTORIZER_MODULE: 'none'
      ENABLE_MODULES: ''
      CLUSTER_HOSTNAME: 'node1'
volumes:
  weaviate_data:
...
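The compose file maps container port 8080 to host port 8090, which is the port `WeaviateEngine` connects to. A sketch of waiting for the container before running the tests. It assumes Weaviate's `/v1/.well-known/ready` readiness endpoint is reachable at `http://localhost:8090`; the helper names `weaviate_ready` and `wait_until` are hypothetical and not part of the PR:

```python
import time
import urllib.error
import urllib.request


def weaviate_ready(base_url="http://localhost:8090", timeout=2.0):
    """Return True if the readiness endpoint answers with HTTP 200."""
    try:
        with urllib.request.urlopen(f"{base_url}/v1/.well-known/ready", timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or a non-2xx status all count as "not ready".
        return False


def wait_until(probe, attempts=30, delay=1.0, sleep=time.sleep):
    """Call probe() up to `attempts` times, sleeping between failed tries."""
    for attempt in range(attempts):
        if probe():
            return True
        if attempt < attempts - 1:
            sleep(delay)
    return False
```

A test session could call `wait_until(weaviate_ready)` once before the first fixture runs; the `sleep` parameter is injectable so the retry logic can be exercised without real delays.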
@@ -0,0 +1,5 @@
weaviate-client
pytest
IPython
pandas
black
@@ -0,0 +1,40 @@
import uuid
import pytest
import sys

sys.path.insert(0, "/workspaces/ai-powered-search/")

from aips import set_engine, get_engine
from engines.weaviate.WeaviateEngine import WeaviateEngine
from aips import set_engine, get_engine


@pytest.fixture
def weaviate_engine():
    set_engine("weaviate")
    yield get_engine()


def test_health_check(weaviate_engine):
    assert weaviate_engine.health_check() == True


def test_create_and_get_collection(weaviate_engine):
    collection_name = f"TestCollection_{uuid.uuid4().hex}"
    weaviate_engine.create_collection(collection_name)

    assert weaviate_engine.client.collections.exists(collection_name)

    collection = weaviate_engine.get_collection(collection_name)

    assert collection.__class__.__name__ == "WeaviateCollection"
    assert collection.name == collection_name


def test_enable_ltr(weaviate_engine):
    collection_name = f"TestCollection_{uuid.uuid4().hex}"
    weaviate_engine.create_collection(collection_name)

    collection = weaviate_engine.get_collection(collection_name)
    with pytest.raises(NotImplementedError):
        weaviate_engine.enable_ltr(collection)
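The tests above create uniquely named collections and never delete them, so each run leaves collections behind in the `weaviate_data` volume. One possible cleanup pattern is sketched below. `temporary_collection` and `FakeEngine` are hypothetical names; the fake exists only so the sketch runs without a Weaviate server and mirrors just the `create_collection` and `client.collections` calls that appear in the diff:

```python
import uuid
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporary_collection(engine, prefix="TestCollection"):
    """Create a uniquely named collection and delete it when the block exits."""
    name = f"{prefix}_{uuid.uuid4().hex}"
    engine.create_collection(name)
    try:
        yield name
    finally:
        # Same delete call the engine uses when wiping a collection.
        engine.client.collections.delete(name)


class FakeEngine:
    """Hypothetical stand-in so the sketch runs without a Weaviate server."""

    def __init__(self):
        self.names = set()
        self.client = SimpleNamespace(
            collections=SimpleNamespace(
                exists=lambda n: n in self.names,
                delete=self.names.discard,
            )
        )

    def create_collection(self, name):
        self.names.add(name)


engine = FakeEngine()
with temporary_collection(engine) as name:
    existed_inside = engine.client.collections.exists(name)
exists_after = engine.client.collections.exists(name)
```

Because the delete sits in a `finally` clause, the collection is removed even when an assertion inside the `with` block fails.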
This method may go away (TBD). It's currently just a hook to tell the engine to adjust its configuration if needed to support LTR on the given collection. In the case of Weaviate, I assume the LTR model will be running and invoked outside the engine, so you'll either have it always running (in which case `enable_ltr` can be a no-op) OR you can use `enable_ltr` to copy any data/models/config needed into place.
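The two options described in this comment might look roughly like the following. This is a sketch only: the class names and the `model_dir`/`target_dir` layout are hypothetical, since the PR does not define where an external LTR service would read its files from.

```python
import shutil
import tempfile
from pathlib import Path
from types import SimpleNamespace


class NoopLtrEngine:
    """Option 1: the LTR model runs outside the engine, so there is nothing to set up."""

    def enable_ltr(self, collection):
        return None


class CopyingLtrEngine:
    """Option 2: stage model/config files where an external LTR service can read them."""

    def __init__(self, model_dir, target_dir):
        self.model_dir = Path(model_dir)
        self.target_dir = Path(target_dir)

    def enable_ltr(self, collection):
        # One subdirectory per collection; `collection.name` follows WeaviateCollection.
        destination = self.target_dir / collection.name
        shutil.copytree(self.model_dir, destination, dirs_exist_ok=True)
        return destination


# Demo with temporary directories and a stand-in collection object.
with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp, "models")
    source.mkdir()
    (source / "model.json").write_text("{}")
    engine = CopyingLtrEngine(source, Path(tmp, "staged"))
    staged = engine.enable_ltr(SimpleNamespace(name="products"))
    staged_files = sorted(p.name for p in staged.iterdir())
```

Either variant would let `test_enable_ltr` assert on a return value instead of expecting `NotImplementedError`.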