-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathvectordb.py
151 lines (115 loc) · 4.23 KB
/
vectordb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import io
import sqlite3
from typing import Any, List, Optional, Tuple

import numpy as np
def adapt_array(arr):
    """Serialize a NumPy array to ``.npy`` bytes wrapped as an SQLite binary value.

    Adapted from http://stackoverflow.com/a/31312102/190597
    """
    buffer = io.BytesIO()
    np.save(buffer, arr)  # noqa
    # getvalue() returns the whole buffer regardless of the current position.
    return sqlite3.Binary(buffer.getvalue())
def convert_array(text):
    """Rebuild a NumPy array from ``.npy`` bytes such as those written by ``adapt_array``.

    ``text`` is the raw bytes value handed over by sqlite3 (the name is kept
    for compatibility; it is not a ``str``).

    Adapted from https://stackoverflow.com/a/18622264
    """
    # A freshly constructed BytesIO already starts at position 0.
    return np.load(io.BytesIO(text))  # noqa
def euclidean_distance(point1, point2):
    """Return the straight-line (L2) distance between two NumPy arrays.

    Raises:
        ValueError: if ``point1`` and ``point2`` do not have the same shape.
    """
    shapes_match = point1.shape == point2.shape
    if not shapes_match:
        raise ValueError("Input points must have the same shape.")
    difference = point1 - point2
    return np.linalg.norm(difference)
def get_nearest_neighbor(train, test_row, num_neighbors: int = 1):
    """Return the ``num_neighbors`` rows of ``train`` closest to ``test_row``.

    Distances are computed with ``euclidean_distance``. Rows are returned in
    order of increasing distance; rows at equal distance keep their original
    order in ``train`` because the sort is stable.

    Args:
        train: iterable of NumPy arrays to search.
        test_row: query array; must have the same shape as each row of ``train``.
        num_neighbors: maximum number of rows to return. If it exceeds the
            number of rows in ``train``, every row is returned; a value <= 0
            yields an empty list.

    Returns:
        List of the nearest rows (the original objects from ``train``).

    Raises:
        ValueError: propagated from ``euclidean_distance`` on a shape mismatch.
    """
    scored = [(row, euclidean_distance(test_row, row)) for row in train]
    scored.sort(key=lambda pair: pair[1])
    # Slice instead of indexing so asking for more neighbours than there are
    # rows (including an empty ``train``) no longer raises IndexError; clamp at
    # 0 so a negative count still yields [] rather than "all but the last k".
    return [row for row, _ in scored[: max(num_neighbors, 0)]]
# Module-wide registration: np.ndarray values bound as SQL parameters are
# serialized by adapt_array into .npy bytes and stored as BLOBs (not TEXT).
sqlite3.register_adapter(np.ndarray, adapt_array)
# Values read from columns declared with type "array" are decoded back into
# np.ndarray by convert_array. This only applies to connections opened with
# detect_types=sqlite3.PARSE_DECLTYPES (as SQLiteDB does).
sqlite3.register_converter("array", convert_array)
class SQLiteDB:
    """Minimal wrapper around one ``sqlite3`` connection and a shared cursor.

    The connection is opened with ``detect_types=sqlite3.PARSE_DECLTYPES`` so
    that columns whose declared type has a registered converter are decoded
    automatically when selected.

    NOTE(review): ``table_name``, column definitions and ``condition`` are
    interpolated directly into the SQL text. Only pass trusted,
    developer-controlled strings; bind untrusted values through the
    ``params`` argument of ``_query_data`` instead.
    """

    def __init__(self, database: str = ":memory:"):
        """Open ``database`` (an in-memory database by default) and create a cursor."""
        self.conn = sqlite3.connect(database, detect_types=sqlite3.PARSE_DECLTYPES)
        self.cur = self.conn.cursor()

    def _create_table(self, table_name: str, columns: List[Tuple[str, str]]):
        """Create a table with the given name and columns, then commit.

        Columns should be a list of ``(name, type)`` tuples, for example
        ``[("id", "INTEGER PRIMARY KEY"), ("name", "TEXT")]``.
        """
        column_defs = ", ".join(f"{column[0]} {column[1]}" for column in columns)
        self.cur.execute(f"CREATE TABLE {table_name} ({column_defs})")
        self.conn.commit()

    def _insert_data(self, table_name: str, data: List[Tuple[Any, ...]]):
        """Insert rows into the table, then commit.

        Data should be a list of tuples with one value per column, for example
        ``[(1, "Alice"), (2, "Bob")]``. The placeholder count is taken from the
        first row, so every row must have the same length. An empty list is a
        no-op.
        """
        if len(data) == 0:
            # Nothing to insert; without this guard data[0] raises IndexError.
            return
        placeholders = ", ".join(["?"] * len(data[0]))
        self.cur.executemany(f"INSERT INTO {table_name} VALUES ({placeholders})", data)
        self.conn.commit()

    def _query_data(
        self,
        table_name: str,
        condition: Optional[str] = None,
        params: Tuple[Any, ...] = (),
    ) -> List[Tuple]:
        """Return all rows of ``table_name``, optionally filtered.

        Args:
            table_name: table to read from (must be a trusted identifier).
            condition: optional SQL expression appended after ``WHERE``, for
                example ``"name = 'Alice'"`` or, together with ``params``,
                ``"name = ?"``.
            params: values bound to ``?`` placeholders in ``condition``; use
                this for any value not fully under the caller's control.

        Returns:
            List of row tuples as returned by ``cursor.fetchall()``.
        """
        sql = f"SELECT * FROM {table_name}"
        if condition:
            sql += f" WHERE {condition}"
        self.cur.execute(sql, params)
        return self.cur.fetchall()

    def _close(self):
        """Close the underlying database connection."""
        self.conn.close()
class VectorDB(SQLiteDB):
    """Brute-force vector store backed by an SQLite table with one ``array`` column.

    Vectors are written through the module's registered ``np.ndarray`` adapter
    and read back through the ``"array"`` converter. Searches load every
    stored vector and rank them by Euclidean distance; there is no index.
    """

    def __init__(self, collection_name: str, database: str = ":memory:"):
        """Initialize a VectorDB instance for storing and querying vectors.

        Args:
            collection_name: name of the table holding the vectors. It is
                interpolated into SQL, so it must be a trusted identifier.
            database: SQLite database to open; defaults to an in-memory
                database, which is discarded when the connection closes.
        """
        super().__init__(database)
        self.collection_name = collection_name

    def create(self):
        """Create the collection table with a single column ``arr`` of declared type ``array``.

        The ``array`` declared type is what makes sqlite3 (opened with
        PARSE_DECLTYPES) route stored values through the ``"array"`` converter.
        """
        columns = [("arr", "array")]
        self._create_table(self.collection_name, columns)

    def insert(self, vectors: List[np.ndarray]):
        """Insert vectors (NumPy arrays) into the collection; empty input is a no-op."""
        # Build the rows first so lists, generators and 2-D arrays all work,
        # then skip the insert entirely when there is nothing to store.
        rows = [(vector,) for vector in vectors]
        if not rows:
            return
        self._insert_data(self.collection_name, rows)

    def search(self, query: np.ndarray, num_results: int = 1):
        """Return up to ``num_results`` stored vectors nearest to ``query``.

        Every stored vector is loaded and compared against ``query``, so the
        cost grows with the collection size. Returns fewer than
        ``num_results`` vectors (possibly none) when the collection is smaller.
        """
        rows = self._query_data(self.collection_name)
        vectors = [row[0] for row in rows]
        # Clamp so an empty or small collection returns what it has instead of
        # raising IndexError inside the neighbour lookup.
        return get_nearest_neighbor(vectors, query, min(num_results, len(vectors)))