Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add algo rabitq to streaming #318

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions neurips23/streaming/rabbithole/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Image for the rabbithole streaming entry: starts from the rabbithole
# image (tag pg17), adds Python tooling and the benchmark runner, and
# launches PostgreSQL alongside the runner from a generated entrypoint script.
FROM kemingy/rabbithole:pg17

# https://github.com/tensorchord/rabbithole

# System packages: pip, a compiler toolchain, git, and the axel/wget downloaders.
RUN apt-get update \
&& apt-get install -y python3-pip build-essential git axel wget
# Fetch the AzCopy v10 archive, unpack it with its top-level directory renamed
# to azcopy_folder (the --transform expression), and copy the binary to /usr/bin.
RUN wget https://aka.ms/downloadazcopy-v10-linux && \
mv downloadazcopy-v10-linux azcopy.tgz && \
tar xzf azcopy.tgz --transform 's!^[^/]\+\($\|/\)!azcopy_folder\1!' && \
cp azcopy_folder/azcopy /usr/bin

WORKDIR /home/app
COPY requirements_py3.10.txt .

# --break-system-packages: install into the system interpreter rather than a
# virtual environment.
RUN python3 -m pip install --break-system-packages -r requirements_py3.10.txt
# Client libraries imported by rabbithole.py (psycopg and pgvector).
RUN python3 -m pip install --break-system-packages psycopg[binary] pgvector

COPY run_algorithm.py .

# PostgreSQL user name and password exposed as environment variables.
ENV POSTGRES_PASSWORD=postgres
ENV POSTGRES_USER=postgres

# Generate entrypoint.sh, which (as the postgres user) initialises a database
# cluster, starts the server in the background with rabbithole.so preloaded,
# waits five seconds, and then runs run_algorithm.py with the container's
# arguments. The script is made executable for its owner.
# NOTE(review): the fixed "sleep 5" assumes the server accepts connections
# within five seconds -- confirm, or poll for readiness instead.
RUN printf '#!/bin/bash\n\
runuser -u postgres -- initdb \n\
runuser -u postgres -- postgres -c shared_preload_libraries=rabbithole.so &\n\
sleep 5\n\
python3 -u run_algorithm.py "$@"' > entrypoint.sh \
&& chmod u+x entrypoint.sh

ENTRYPOINT ["/home/app/entrypoint.sh"]
28 changes: 28 additions & 0 deletions neurips23/streaming/rabbithole/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Run configuration for the rabbithole streaming entry, keyed by dataset name.
# `module` / `constructor` name the RabbitHole class in rabbithole.py. Each
# `args` entry carries the "nlist" key and each `query-args` entry carries the
# "probe" key, which are the keys that class reads.
random-xs:
rabbithole:
docker-tag: neurips23-streaming-rabbithole
module: neurips23.streaming.rabbithole.rabbithole
constructor: RabbitHole
base-args: ["@metric"]
run-groups:
base:
args: |
[{"nlist":16}]
query-args: |
[{"probe": 3}]
# Same algorithm entry with nlist=16384 and three probe values listed.
msturing-30M-clustered:
rabbithole:
docker-tag: neurips23-streaming-rabbithole
module: neurips23.streaming.rabbithole.rabbithole
constructor: RabbitHole
base-args: ["@metric"]
run-groups:
base:
args: |
[{"nlist":16384}]
query-args: |
[
{"probe": 300},
{"probe": 500},
{"probe": 1000}
]
79 changes: 79 additions & 0 deletions neurips23/streaming/rabbithole/rabbithole.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import psycopg
import numpy as np
from pgvector.psycopg import register_vector

from neurips23.streaming.base import BaseStreamingANN

# Benchmark metric name -> operator class named in the CREATE INDEX statement.
DISTANCE_METRICS = dict(
    euclidean="vector_l2_ops",
    angular="vector_cosine_ops",
    ip="vector_ip_ops",
)
# Benchmark dtype name -> column type used for the embedding column.
TYPE_MAP = dict(float32="vector", float16="halfvec")


class RabbitHole(BaseStreamingANN):
    """Streaming ANN wrapper around a PostgreSQL table indexed by rabbithole.

    Vectors are stored in the ``ann`` table (columns ``id`` and ``emb``) and
    every operation is issued over the single psycopg connection opened in
    ``__init__``.
    """

    # Distance operator that corresponds to each operator class, so the
    # ORDER BY in ``query`` ranks by the same metric the index was built for.
    _DISTANCE_OPERATORS = {
        "vector_l2_ops": "<->",
        "vector_cosine_ops": "<=>",
        "vector_ip_ops": "<#>",
    }

    # Filler for result slots for which the database returned no row.
    _MISSING_ID = np.iinfo(np.uint32).max

    def __init__(self, metric, index_params):
        """Connect to PostgreSQL and make sure both extensions exist.

        Args:
            metric: Benchmark metric name; translated through
                ``DISTANCE_METRICS`` (``None`` when the name is unknown).
            index_params: Mapping of build parameters; only ``"nlist"`` is read.
        """
        self.name = "rabbithole"
        self.nlist = index_params.get("nlist")
        self.metric = DISTANCE_METRICS.get(metric)
        # Filled in by setup(); defined here so __str__ works before setup().
        self.dtype = None
        self.ndims = None
        # Same user/password as POSTGRES_USER / POSTGRES_PASSWORD in the
        # Dockerfile; the server is started inside the same container.
        self.conn = psycopg.connect(
            "postgresql://postgres:postgres@127.0.0.1:5432/"
        )
        self.conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        self.conn.execute("CREATE EXTENSION IF NOT EXISTS rabbithole")
        register_vector(self.conn)

    def setup(self, dtype, max_pts, ndims) -> None:
        """Create the ``ann`` table and a rabbithole index on its ``emb`` column.

        Args:
            dtype: Benchmark dtype name; mapped through ``TYPE_MAP`` and
                falling back to ``"vector"`` for unlisted names.
            max_pts: Stored as ``self.max_vectors``; not otherwise used here.
            ndims: Dimensionality used in the column type declaration.
        """
        self.dtype = TYPE_MAP.get(dtype, "vector")
        self.max_vectors = max_pts
        self.ndims = ndims

        # Residual quantization is switched on only for L2; spherical
        # centroids only for the other metrics.
        is_l2 = self.metric == "vector_l2_ops"
        residual = "true" if is_l2 else "false"
        spherical = "false" if is_l2 else "true"
        self.config = f"""
residual_quantization = {residual}
[build.internal]
lists = {self.nlist}
spherical_centroids = {spherical}
"""
        self.conn.execute(
            f"CREATE TABLE IF NOT EXISTS ann (id SERIAL PRIMARY KEY, emb {self.dtype}({self.ndims}))"
        )
        self.conn.execute(
            f"CREATE INDEX ON ann USING rabbithole (emb {self.metric}) WITH (options=$${self.config}$$)"
        )

    def set_query_arguments(self, query_args):
        """Store the query arguments and apply ``"probe"`` to the session.

        Args:
            query_args: Mapping of search parameters; only ``"probe"`` is read.
                A missing or falsy value leaves the session setting untouched.
        """
        self.query_args = query_args
        self.probe = query_args.get("probe")
        if self.probe:
            self.conn.execute(f"SET rabbithole.probes = {self.probe}")

    def insert(self, X, ids):
        """Bulk-load vectors ``X`` under the given ``ids`` with binary COPY.

        Args:
            X: Iterable of vectors, paired positionally with ``ids``.
            ids: Iterable of integer identifiers.
        """
        # NOTE(review): the COPY column type is always "vector", even when
        # setup() selected "halfvec" for the column -- confirm this is intended.
        with self.conn.cursor() as cur:
            with cur.copy(
                "COPY ann (id, emb) FROM STDIN WITH (FORMAT BINARY)"
            ) as copy:
                copy.set_types(("integer", "vector"))
                for i, vec in zip(ids, X):
                    # Plain int so the row does not depend on numpy-scalar
                    # adaptation being available in the driver.
                    copy.write_row((int(i), vec))

    def delete(self, ids):
        """Delete every row whose ``id`` appears in ``ids``."""
        self.conn.execute(
            "DELETE FROM ann WHERE id = ANY(%s)", ([int(i) for i in ids],)
        )

    def replace(self, dataset):
        """Delegate to the base class ``fit`` with ``dataset``."""
        return super().fit(dataset)

    def query(self, X, k):
        """Run a top-``k`` search for each vector in ``X``.

        Results are written to ``self.res``, a ``(len(X), k)`` uint32 array;
        nothing is returned. Slots for which the database returned no row
        hold ``_MISSING_ID`` instead of uninitialised memory.

        Args:
            X: Sequence of query vectors.
            k: Number of neighbours requested per query.
        """
        # Rank with the operator that matches the index's operator class;
        # fall back to L2 when the metric is not one of the known classes.
        operator = self._DISTANCE_OPERATORS.get(self.metric, "<->")
        sql = f"SELECT id FROM ann ORDER BY emb {operator} %s LIMIT %s"

        self.res = np.full((len(X), k), self._MISSING_ID, dtype="uint32")
        for i, x in enumerate(X):
            rows = self.conn.execute(sql, (x, k)).fetchall()
            for j, (neighbor_id,) in enumerate(rows):
                self.res[i, j] = neighbor_id

    def range_query(self, X, radius):
        """Range search is not supported by this wrapper."""
        raise NotImplementedError

    def __str__(self):
        """Return a short description including nlist, dimension and type."""
        return f"RabbitHole(nlist={self.nlist},dim={self.ndims},type={self.dtype})"