diff --git a/examples/destination/vector/csv_reader.py b/examples/destination/vector/csv_reader.py new file mode 100644 index 0000000..4691cb3 --- /dev/null +++ b/examples/destination/vector/csv_reader.py @@ -0,0 +1,26 @@ +import pandas as pd +from Crypto.Cipher import AES +from Crypto.Util.Padding import unpad +import zstandard as zstd +from io import StringIO +from logger import log + + +class CSVReaderAESZSTD: + + def read_csv(self, file_path, aes_key, null_string, timestamp_columns): + with open(file_path, 'rb') as encrypted_file: + iv = encrypted_file.read(16) + encrypted_data = encrypted_file.read() + + cipher = AES.new(aes_key, AES.MODE_CBC, iv) + decrypted_data = unpad(cipher.decrypt(encrypted_data), AES.block_size) + + decompressor = zstd.ZstdDecompressor() + + with decompressor.stream_reader(decrypted_data) as reader: + decompressed_data = reader.read() + + data_str = decompressed_data.decode('utf-8') + df = pd.read_csv(StringIO(data_str), na_values=null_string, parse_dates=timestamp_columns) + return df \ No newline at end of file diff --git a/examples/destination/vector/destination.py b/examples/destination/vector/destination.py new file mode 100644 index 0000000..7dab80e --- /dev/null +++ b/examples/destination/vector/destination.py @@ -0,0 +1,41 @@ +from datetime import datetime +from abc import ABC, abstractmethod + +from models.collection import Collection, Row +from typing import Optional, Any, List + +from sdk.common_pb2 import ConfigurationFormResponse + + +class VectorDestination(ABC): + """Interface for vector destinations""" + + @abstractmethod + def configuration_form(self) -> ConfigurationFormResponse: + """configuration_form""" + + @abstractmethod + def test(self, name: str, configuration: dict[str, str]) -> Optional[str]: + """test""" + + @abstractmethod + def create_collection_if_not_exists(self, configuration: dict[str, Any], collection: Collection) -> None: + """create_collection""" + + @abstractmethod + def upsert_rows(self, configuration: dict[str, Any], collection: Collection, rows: List[Row]) -> None: + """upsert_rows""" + + @abstractmethod + def delete_rows(self, configuration: dict[str, Any], collection: Collection, ids: List[str]) -> None: + """delete_rows""" + + @abstractmethod + def truncate(self, configuration: dict[str, Any], collection: Collection, synced_column: str, delete_before: datetime) -> None: + """delete_rows""" + + # Not Ideal but no clear winner ¯\_(ツ)_/¯ + def get_collection_name(self, schema_name: str, table_name: str) -> str: + schema_name = schema_name.replace("_","-") + table_name = table_name.replace("_", "-") + return f"{schema_name}-{table_name}" diff --git a/examples/destination/vector/destinations/weaviate_.py b/examples/destination/vector/destinations/weaviate_.py new file mode 100644 index 0000000..d4ea854 --- /dev/null +++ b/examples/destination/vector/destinations/weaviate_.py @@ -0,0 +1,86 @@ +import uuid +import weaviate +from weaviate.auth import AuthApiKey +from weaviate.classes.query import Filter + +from sdk.common_pb2 import ConfigurationFormResponse, FormField, TextField, ConfigurationTest +from destination import VectorDestination + + +class WeaviateDestination(VectorDestination): + def get_collection_name(self, schema_name, table_name): + return f"{schema_name}_{table_name}" + + def configuration_form(self): + fields = [ + FormField(name="url", label="Weaviate Cluster URL", required=True, text_field=TextField.PlainText), + FormField(name="api_key", label="Weaviate API Key", required=True, text_field=TextField.Password) + ] + tests = [ConfigurationTest(name="connection_test", label="Connecting to Weaviate Cluster")] + return ConfigurationFormResponse(fields=fields, tests=tests) + + def _get_client(self, configuration): + return weaviate.connect_to_wcs( + cluster_url=configuration["url"], + auth_credentials=AuthApiKey(configuration["api_key"]) + ) + + def test(self, name, config): + if name != "connection_test": + raise ValueError(name) + + client = self._get_client(config) + + client.connect() + client.close() + + def create_collection_if_not_exists(self, config, collection): + client = self._get_client(config) + + if not client.collections.exists(collection.name): + print(f"Collection {collection.name} does not exist! Creating!") + client.collections.create(name=collection.name) + + client.close() + + def upsert_rows(self, config, collection, rows): + client = self._get_client(config) + c = client.collections.get(collection.name) + + for row in rows: + _uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(row.id)) + + # TODO: Get these swap column names from config + # TODO: swap column name in framework if other vec dbs have same issue. + id_swap_column = "_fvt_swp_id" + vector_swap_column = "_fvt_swp_vector" + + if "id" in row.payload: + row.payload[id_swap_column] = row.payload.pop("id") + if "vector" in row.payload: + row.payload[vector_swap_column] = row.payload.pop("vector") + + if c.data.exists(_uuid): + c.data.replace(uuid=_uuid, properties=row.payload, vector=row.vector) + else: + c.data.insert(uuid=_uuid, properties=row.payload, vector=row.vector) + + client.close() + + def delete_rows(self, config, collection, ids): + client = self._get_client(config) + c = client.collections.get(collection.name) + + for id in ids: + _uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(id)) + c.data.delete_by_id(uuid=_uuid) + + client.close() + + def truncate(self, config, collection, synced_column, delete_before): + client = self._get_client(config) + c = client.collections.get(collection.name) + + filter = Filter.by_property(synced_column).less_than(delete_before) + c.data.delete_many(where=filter) + diff --git a/examples/destination/vector/embedder.py b/examples/destination/vector/embedder.py new file mode 100644 index 0000000..b9f0161 --- /dev/null +++ b/examples/destination/vector/embedder.py @@ -0,0 +1,31 @@ +from abc import ABC, abstractmethod + +from typing import Optional, Any, List, Tuple + +from sdk.common_pb2 import ConfigurationFormResponse +from models.collection import Metrics + + +class Embedder(ABC): + """Interface for embedders""" + + @abstractmethod + def details(self)-> Tuple[str, str]: + """details -> [id, name]""" + + + @abstractmethod + def configuration_form(self) -> ConfigurationFormResponse: + """configuration_form""" + + @abstractmethod + def metrics(self, configuration: dict[str, str])-> Metrics: + """metrics""" + + @abstractmethod + def test(self, name: str, configuration: dict[str, str]) -> Optional[str]: + """test""" + + @abstractmethod + def embed(self, configuration: dict[str, str], texts: List[str]) -> List[List[float]]: + """embed""" diff --git a/examples/destination/vector/embedders/open_ai.py b/examples/destination/vector/embedders/open_ai.py new file mode 100644 index 0000000..8420a1e --- /dev/null +++ b/examples/destination/vector/embedders/open_ai.py @@ -0,0 +1,46 @@ +from embedder import Embedder +from sdk.common_pb2 import ConfigurationFormResponse, FormField, TextField, DropdownField, ConfigurationTest +from langchain_openai import OpenAIEmbeddings +from models.collection import Metrics, Distance + +MODELS = { + "text-embedding-ada-002": Metrics( + distance=Distance.COSINE, + dimensions=1536 + ) +} + + +class OpenAIEmbedder(Embedder): + + def details(self): + return "open_ai", "OpenAI" + + def configuration_form(self): + models = DropdownField(dropdown_field=list(MODELS.keys())) + fields = [ + FormField(name="api_key", label="OpenAI API Key", required=True, text_field=TextField.Password), + FormField(name="embedding_model", label="OpenAI Embedding Model", required=True, dropdown_field=models), + ] + tests = [ConfigurationTest(name="embedding_test", label="Checking OpenAI Embedding Generation")] + return ConfigurationFormResponse(fields=fields, tests=tests) + + def metrics(self, config) -> Metrics: + return MODELS[config["embedding_model"]] + + def _get_embedding(self, configuration): + api_key = configuration["api_key"] + model = configuration["embedding_model"] + + return OpenAIEmbeddings(api_key=api_key, model=model) + + def test(self, name, configuration): + if name != "embedding_test": + raise ValueError(f'Unknown test : {name}') + + embedding = self._get_embedding(configuration) + embedding.embed_query("foo-bar-biz") + + def embed(self, configuration, texts): + embedding = self._get_embedding(configuration) + return embedding.embed_documents(texts) diff --git a/examples/destination/vector/logger.py b/examples/destination/vector/logger.py new file mode 100644 index 0000000..3388b70 --- /dev/null +++ b/examples/destination/vector/logger.py @@ -0,0 +1,10 @@ +import json + + +def log(msg: str): + m = { + "level": "INFO", + "message": msg, + "message-origin": "sdk_destination" + } + print(json.dumps(m), flush=True) \ No newline at end of file diff --git a/examples/destination/vector/main.py b/examples/destination/vector/main.py new file mode 100644 index 0000000..e3d4c9c --- /dev/null +++ b/examples/destination/vector/main.py @@ -0,0 +1,27 @@ +import grpc +from concurrent import futures +from sdk.destination_sdk_pb2_grpc import add_DestinationServicer_to_server +from destination import VectorDestination +from service import VectorDestinationServicer +import sys + + +def serve(vec_dest: VectorDestination): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + add_DestinationServicer_to_server(VectorDestinationServicer(vec_dest), server) + + if len(sys.argv) == 3 and sys.argv[1] == '--port': + port = int(sys.argv[2]) + else: + port = 50052 + + server.add_insecure_port(f'[::]:{port}') + print(f"Running GRPC Server on {port}") + server.start() + server.wait_for_termination() + + +from destinations.weaviate_ import WeaviateDestination + +if __name__ == '__main__': + serve(WeaviateDestination()) diff --git a/examples/destination/vector/models/collection.py b/examples/destination/vector/models/collection.py new file mode 100644 index 0000000..2bc08c6 --- /dev/null +++ b/examples/destination/vector/models/collection.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import List, Dict, Any +from enum import Enum + + +class Distance(Enum): + COSINE = 1 + DOT = 2 + EUCLIDIAN = 3 + + +@dataclass +class Metrics: + distance: Distance + dimensions: int + + +@dataclass +class Collection: + name: str + metrics: Metrics + + +@dataclass +class Row: + id: str + vector: List[float] + content: str + payload: Dict[str, Any] + + diff --git a/examples/destination/vector/requirements.txt b/examples/destination/vector/requirements.txt new file mode 100644 index 0000000..e939fd8 --- /dev/null +++ b/examples/destination/vector/requirements.txt @@ -0,0 +1,9 @@ +pandas +weaviate-client +langchain +langchain-community +langchain-openai +pycrypto +pycryptodome +zstandard +grpcio \ No newline at end of file diff --git a/examples/destination/vector/service.py b/examples/destination/vector/service.py new file mode 100644 index 0000000..5644421 --- /dev/null +++ b/examples/destination/vector/service.py @@ -0,0 +1,200 @@ +import pandas as pd + +from logger import log +from typing import List + +from destination import VectorDestination +from embedder import Embedder + +from sdk.destination_sdk_pb2_grpc import DestinationServicer +from sdk.common_pb2 import ConfigurationFormRequest, ConfigurationFormResponse, DataType +from sdk.common_pb2 import TestRequest, TestResponse +from sdk.destination_sdk_pb2 import DescribeTableRequest, DescribeTableResponse +from sdk.destination_sdk_pb2 import CreateTableRequest, CreateTableResponse +from sdk.destination_sdk_pb2 import AlterTableRequest, AlterTableResponse +from sdk.destination_sdk_pb2 import TruncateRequest, TruncateResponse +from sdk.destination_sdk_pb2 import WriteBatchRequest, WriteBatchResponse +from sdk.destination_sdk_pb2 import Compression, Encryption + +from models.collection import Collection, Row +from csv_reader import CSVReaderAESZSTD +from embedders.open_ai import OpenAIEmbedder + +EMBEDDERS: List[Embedder] = [OpenAIEmbedder()] + + +class VectorDestinationServicer(DestinationServicer): + vec_dest: VectorDestination + embedders: dict[str, Embedder] + + def __init__(self, vec_dest: VectorDestination): + self.vec_dest = vec_dest + self.embedders = {e.details()[0]: e for e in EMBEDDERS} + + def ConfigurationForm(self, request: ConfigurationFormRequest, context): + log("Called -> ConfigurationForm") + dest_config_form = self.vec_dest.configuration_form() + + combined_fields = [] + for f in dest_config_form.fields: + name = f"vect_dest__{f.name}" + new_f = f.__deepcopy__() + new_f.name = name + combined_fields.append(new_f) + + for _, e in self.embedders.items(): + c = e.configuration_form() + id, name = e.details() + for f in c.fields: + name = f"embedder__{id}__{f.name}" + new_f = f.__deepcopy__() + new_f.name = name + combined_fields.append(new_f) + + combined_tests = [] + for t in dest_config_form.tests: + name = f"vect_dest__{t.name}" + new_t = t.__deepcopy__() + new_t.name = name + combined_tests.append(new_t) + + for _, e in self.embedders.items(): + c = e.configuration_form() + id, name = e.details() + for t in c.tests: + name = f"embedder__{id}__{t.name}" + new_t = t.__deepcopy__() + new_t.name = name + combined_tests.append(new_t) + + combined_form = ConfigurationFormResponse( + schema_selection_supported=False, + table_selection_supported=False, + fields=combined_fields, + tests=combined_tests, + ) + + return combined_form + + def Test(self, request: TestRequest, context): + log(f"Called -> Test({request.name})") + + if request.name.startswith("vect_dest__"): + prefix = "vect_dest__" + test_target = self.vec_dest + elif request.name.startswith("embedder__"): + embedder_id = request.name.split("__")[1] + prefix = f"embedder__{embedder_id}__" + test_target = self.embedders.get(embedder_id) + if not test_target: + raise ValueError(f"Invalid embedder ID: {embedder_id}") + else: + raise ValueError(f"Invalid test {request.name}") + + config = {k.removeprefix(prefix): v for k, v in request.configuration.items() if k.startswith(prefix)} + name = request.name.removeprefix(prefix) + result = test_target.test(name, config) + + return TestResponse(success=bool(result), failure=result if not result else None) + + def DescribeTable(self, request: DescribeTableRequest, context): + log(f"Called -> DescribeTable({request.schema_name}.{request.table_name})") + return DescribeTableResponse(not_found=True) + + def _split_configs(self, config_in): + prefix = "vect_dest__" + config = {k.removeprefix(prefix): v for k, v in config_in.items() if k.startswith(prefix)} + + # TODO: Remove this hardcoding once `visibility` is implemented in SDK + embedder_id = "open_ai" + + prefix = f"embedder__{embedder_id}__" + embedder_config = {k.removeprefix(prefix): v for k, v in config_in.items() if k.startswith(prefix)} + + return config, embedder_id, embedder_config + + def CreateTable(self, request: CreateTableRequest, context): + log(f"Called -> CreateTable({request.schema_name}.{request.table.name})") + + config, embedder_id, embedder_config = self._split_configs(request.configuration) + embedder = self.embedders[embedder_id] + + collection = Collection( + name=self.vec_dest.get_collection_name(request.schema_name, request.table.name), + metrics=embedder.metrics(embedder_config) + ) + self.vec_dest.create_collection_if_not_exists(config, collection) + return CreateTableResponse(success=True) + + def AlterTable(self, request: AlterTableRequest, context): + log(f"Called -> AlterTable({request.schema_name}.{request.table.name})") + return AlterTableResponse(success=True) + + def WriteBatch(self, request: WriteBatchRequest, context): + log(f"Called -> WriteBatch({request.schema_name}.{request.table.name} ({request.csv.encryption}|{request.csv.compression}))") + + if request.csv.compression != Compression.ZSTD: + raise ValueError(f"Unknown compression{request.csv.compression}") + + if request.csv.encryption != Encryption.AES: + raise ValueError(f"Unknown encryption{request.csv.encryption}") + + if request.update_files: + raise NotImplementedError('No support for partial updates yet!') + + config, embedder_id, embedder_config = self._split_configs(request.configuration) + embedder = self.embedders[embedder_id] + + collection = Collection( + name=self.vec_dest.get_collection_name(request.schema_name, request.table.name), + metrics=embedder.metrics(embedder_config) + ) + timestamp_columns = [c.name for c in request.table.columns if c.type == DataType.UTC_DATETIME] + csv_reader = CSVReaderAESZSTD() + + for file in request.replace_files: + df = csv_reader.read_csv(file, request.keys[file], request.csv.null_string, timestamp_columns) + + records = df.to_dict(orient="records") + records = [{k: v for k, v in row.items() if not pd.isna(row[k])} for row in records] + records = [{k: (v.to_pydatetime() if k in timestamp_columns else v) for k, v in row.items()} for row in records] + + ids = [r["id"] for r in records] + documents = [r["document"] for r in records] + + vectors = embedder.embed(embedder_config, documents) + + rows = [Row( + id=ids[i], + vector=vectors[i], + content=documents[i], + payload=records[i], + ) for i in range(len(records))] + + self.vec_dest.upsert_rows(config, collection, rows) + + for file in request.delete_files: + df = csv_reader.read_csv(file, request.keys[file], request.csv.null_string, timestamp_columns) + records = df.to_dict(orient="records") + + ids = [r["id"] for r in records] + + self.vec_dest.delete_rows(config, collection, ids) + + return WriteBatchResponse(success=True) + + def Truncate(self, request: TruncateRequest, context): + log(f"Called -> Truncate({request.schema_name}.{request.table_name})") + + config, embedder_id, embedder_config = self._split_configs(request.configuration) + embedder = self.embedders[embedder_id] + + collection = Collection( + name=self.vec_dest.get_collection_name(request.schema_name, request.table_name), + metrics=embedder.metrics(embedder_config) + ) + + delete_before = request.utc_delete_before.ToDatetime() + self.vec_dest.truncate(config, collection, request.synced_column, delete_before) + + return TruncateResponse(success=True)