Skip to content

Commit

Permalink
add vecdb example
Browse files Browse the repository at this point in the history
  • Loading branch information
abhijeethp committed Jun 5, 2024
1 parent 08da2f8 commit 6085452
Show file tree
Hide file tree
Showing 10 changed files with 507 additions and 0 deletions.
26 changes: 26 additions & 0 deletions examples/destination/vector/csv_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pandas as pd
from Crypto.Cipher import AES
from Crypto.Util.Padding import unpad
import zstandard as zstd
from io import StringIO
from logger import log


class CSVReaderAESZSTD:

def read_csv(self, file_path, aes_key, null_string, timestamp_columns):
with open(file_path, 'rb') as encrypted_file:
iv = encrypted_file.read(16)
encrypted_data = encrypted_file.read()

cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted_data = unpad(cipher.decrypt(encrypted_data), AES.block_size)

decompressor = zstd.ZstdDecompressor()

with decompressor.stream_reader(decrypted_data) as reader:
decompressed_data = reader.read()

data_str = decompressed_data.decode('utf-8')
df = pd.read_csv(StringIO(data_str), na_values=null_string, parse_dates=timestamp_columns)
return df
41 changes: 41 additions & 0 deletions examples/destination/vector/destination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from datetime import datetime
from abc import ABC, abstractmethod

from models.collection import Collection, Row
from typing import Optional, Any, List

from sdk.common_pb2 import ConfigurationFormResponse


class VectorDestination(ABC):
"""Interface for vector destinations"""

@abstractmethod
def configuration_form(self) -> ConfigurationFormResponse:
"""configuration_form"""

@abstractmethod
def test(self, name: str, configuration: dict[str, str]) -> Optional[str]:
"""test"""

@abstractmethod
def create_collection_if_not_exists(self, configuration: dict[str, Any], collection: Collection) -> None:
"""create_collection"""

@abstractmethod
def upsert_rows(self, configuration: dict[str, Any], collection: Collection, rows: List[Row]) -> None:
"""upsert_rows"""

@abstractmethod
def delete_rows(self, configuration: dict[str, Any], collection: Collection, ids: List[str]) -> None:
"""delete_rows"""

@abstractmethod
def truncate(self, configuration: dict[str, Any], collection: Collection, synced_column: str, delete_before: datetime) -> None:
"""delete_rows"""

# Not Ideal but no clear winner ¯\_(ツ)_/¯
def get_collection_name(self, schema_name: str, table_name: str) -> str:
schema_name = schema_name.replace("_","-")
table_name = table_name.replace("_", "-")
return f"{schema_name}-{table_name}"
86 changes: 86 additions & 0 deletions examples/destination/vector/destinations/weaviate_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import uuid
import weaviate
from weaviate.auth import AuthApiKey
from weaviate.classes.query import Filter

from sdk.common_pb2 import ConfigurationFormResponse, FormField, TextField, ConfigurationTest
from destination import VectorDestination


class WeaviateDestination(VectorDestination):
def get_collection_name(self, schema_name, table_name):
return f"{schema_name}_{table_name}"

def configuration_form(self):
fields = [
FormField(name="url", label="Weaviate Cluster URL", required=True, text_field=TextField.PlainText),
FormField(name="api_key", label="Weaviate API Key", required=True, text_field=TextField.Password)
]
tests = [ConfigurationTest(name="connection_test", label="Connecting to Weaviate Cluster")]
return ConfigurationFormResponse(fields=fields, tests=tests)

def _get_client(self, configuration):
return weaviate.connect_to_wcs(
cluster_url=configuration["url"],
auth_credentials=AuthApiKey(configuration["api_key"])
)

def test(self, name, config):
if name != "connection_test":
raise ValueError(name)

client = self._get_client(config)

client.connect()
client.close()

def create_collection_if_not_exists(self, config, collection):
client = self._get_client(config)

if not client.collections.exists(collection.name):
print(f"Collection {collection.name} does not exist! Creating!")
client.collections.create(name=collection.name)

client.close()

def upsert_rows(self, config, collection, rows):
client = self._get_client(config)
c = client.collections.get(collection.name)

for row in rows:
_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(row.id))

# TODO: Get these swap column names from config
# TODO: swap column name in framework if other vec dbs have same issue.
id_swap_column = "_fvt_swp_id"
vector_swap_column = "_fvt_swp_vector"

if "id" in row.payload:
row.payload[id_swap_column] = row.payload.pop("id")
if "vector" in row.payload:
row.payload[vector_swap_column] = row.payload.pop("vector")

if c.data.exists(_uuid):
c.data.replace(uuid=_uuid, properties=row.payload, vector=row.vector)
else:
c.data.insert(uuid=_uuid, properties=row.payload, vector=row.vector)

client.close()

def delete_rows(self, config, collection, ids):
client = self._get_client(config)
c = client.collections.get(collection.name)

for id in ids:
_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, str(id))
c.data.delete_by_id(uuid=_uuid)

client.close()

def truncate(self, config, collection, synced_column, delete_before):
client = self._get_client(config)
c = client.collections.get(collection.name)

filter = Filter.by_property(synced_column).less_than(delete_before)
c.data.delete_many(where=filter)

31 changes: 31 additions & 0 deletions examples/destination/vector/embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod

from typing import Optional, Any, List, Tuple

from sdk.common_pb2 import ConfigurationFormResponse
from models.collection import Metrics


class Embedder(ABC):
"""Interface for embedders"""

@abstractmethod
def details(self)-> Tuple[str, str]:
"""details -> [id, name]"""


@abstractmethod
def configuration_form(self) -> ConfigurationFormResponse:
"""configuration_form"""

@abstractmethod
def metrics(self, configuration: dict[str, str])-> Metrics:
"""metrics"""

@abstractmethod
def test(self, name: str, configuration: dict[str, str]) -> Optional[str]:
"""test"""

@abstractmethod
def embed(self, configuration: dict[str, str], texts: List[str]) -> List[List[float]]:
"""embed"""
46 changes: 46 additions & 0 deletions examples/destination/vector/embedders/open_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from embedder import Embedder
from sdk.common_pb2 import ConfigurationFormResponse, FormField, TextField, DropdownField, ConfigurationTest
from langchain_openai import OpenAIEmbeddings
from models.collection import Metrics, Distance

MODELS = {
"text-embedding-ada-002": Metrics(
distance=Distance.COSINE,
dimensions=1536
)
}


class OpenAIEmbedder(Embedder):

def details(self):
return "open_ai", "OpenAI"

def configuration_form(self):
models = DropdownField(dropdown_field=list(MODELS.keys()))
fields = [
FormField(name="api_key", label="OpenAI API Key", required=True, text_field=TextField.Password),
FormField(name="embedding_model", label="OpenAI Embedding Model", required=True, dropdown_field=models),
]
tests = [ConfigurationTest(name="embedding_test", label="Checking OpenAI Embedding Generation")]
return ConfigurationFormResponse(fields=fields, tests=tests)

def metrics(self, config) -> Metrics:
return MODELS[config["embedding_model"]]

def _get_embedding(self, configuration):
api_key = configuration["api_key"]
model = configuration["embedding_model"]

return OpenAIEmbeddings(api_key=api_key, model=model)

def test(self, name, configuration):
if name != "embedding_test":
raise ValueError(f'Unknown test : {name}')

embedding = self._get_embedding(configuration)
embedding.embed_query("foo-bar-biz")

def embed(self, configuration, texts):
embedding = self._get_embedding(configuration)
return embedding.embed_documents(texts)
10 changes: 10 additions & 0 deletions examples/destination/vector/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import json


def log(msg: str):
m = {
"level": "INFO",
"message": msg,
"message-origin": "sdk_destination"
}
print(json.dumps(m), flush=True)
27 changes: 27 additions & 0 deletions examples/destination/vector/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import grpc
from concurrent import futures
from sdk.destination_sdk_pb2_grpc import add_DestinationServicer_to_server
from destination import VectorDestination
from service import VectorDestinationServicer
import sys


def serve(vec_dest: VectorDestination):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
add_DestinationServicer_to_server(VectorDestinationServicer(vec_dest), server)

if len(sys.argv) == 3 and sys.argv[1] == '--port':
port = int(sys.argv[2])
else:
port = 50052

server.add_insecure_port(f'[::]:{port}')
print(f"Running GRPC Server on {port}")
server.start()
server.wait_for_termination()


from destinations.weaviate_ import WeaviateDestination

if __name__ == '__main__':
serve(WeaviateDestination())
31 changes: 31 additions & 0 deletions examples/destination/vector/models/collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import List, Dict, Any
from enum import Enum


class Distance(Enum):
COSINE = 1
DOT = 2
EUCLIDIAN = 3


@dataclass
class Metrics:
distance: Distance
dimensions: int


@dataclass
class Collection:
name: str
metrics: Metrics


@dataclass
class Row:
id: str
vector: List[float]
content: str
payload: Dict[str, Any]


9 changes: 9 additions & 0 deletions examples/destination/vector/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
pandas
weaviate-client
langchain
langchain-community
langchain-openai
pycrypto
pycryptodome
zstandard
grpcio
Loading

0 comments on commit 6085452

Please sign in to comment.