Skip to content

Commit

Permalink
Use sparse vectors for SQLite edges, instead of edge-per-row. This matches behavior in-memory, and speeds up some operations by reducing the number of DB ops.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 664937502
  • Loading branch information
sdenton4 authored and copybara-github committed Aug 19, 2024
1 parent 518f694 commit 5fb0209
Showing 1 changed file with 35 additions and 33 deletions.
68 changes: 35 additions & 33 deletions chirp/projects/hoplite/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ def deserialize_embedding(self, serialized_embedding: bytes) -> np.ndarray:
dtype=np.dtype(self.embedding_dtype).newbyteorder('<'),
)

def serialize_edges(self, edges: np.ndarray) -> bytes:
  """Encodes target embedding ids as little-endian int64 bytes.

  Args:
    edges: Target embedding ids. Any array-like is accepted (ndarray, list,
      tuple); values are cast to int64.

  Returns:
    The ids packed as consecutive little-endian 64-bit integers, which is
    the layout `deserialize_edges` reads back.
  """
  little_endian_i64 = np.dtype(np.int64).newbyteorder('<')
  # np.asarray is a no-op for ndarrays and lets plain sequences through,
  # where calling `.astype` directly would raise AttributeError.
  return np.asarray(edges).astype(little_endian_i64).tobytes()

def deserialize_edges(self, serialized_edges: bytes) -> np.ndarray:
  """Decodes a little-endian int64 byte string into an array of ids.

  Args:
    serialized_edges: Bytes holding consecutive little-endian 64-bit ints.

  Returns:
    A 1-D integer array interpreting the given buffer via `np.frombuffer`.
  """
  edge_dtype = np.dtype(np.int64).newbyteorder('<')
  return np.frombuffer(serialized_edges, dtype=edge_dtype)

def setup(self, index=True):
cursor = self._get_cursor()
# Create embedding sources table
Expand Down Expand Up @@ -107,10 +116,10 @@ def setup(self, index=True):
CREATE TABLE IF NOT EXISTS hoplite_edges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
source_embedding_id INTEGER NOT NULL,
target_embedding_id INTEGER NOT NULL,
FOREIGN KEY (source_embedding_id) REFERENCES embeddings(id),
FOREIGN KEY (target_embedding_id) REFERENCES embeddings(id)
)""")
target_embedding_ids BLOB NOT NULL,
FOREIGN KEY (source_embedding_id) REFERENCES embeddings(id)
);
""")

# Create hoplite_labels table.
cursor.execute("""
Expand All @@ -137,14 +146,7 @@ def setup(self, index=True):
CREATE INDEX IF NOT EXISTS embedding_source ON hoplite_embeddings (source_idx);
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_source_embedding ON hoplite_edges (source_embedding_id);
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_target_embedding ON hoplite_edges (target_embedding_id);
""")
cursor.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS
idx_edge ON hoplite_edges (source_embedding_id, target_embedding_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_source_embedding ON hoplite_edges (source_embedding_id);
""")
cursor.execute("""
CREATE INDEX IF NOT EXISTS idx_label ON hoplite_labels (embedding_id, label);
Expand Down Expand Up @@ -203,23 +205,28 @@ def get_dataset_names(self) -> tuple[str, ...]:
cursor.execute("""SELECT DISTINCT dataset FROM hoplite_sources;""")
return tuple(c[0] for c in cursor.fetchall())

def insert_edge(self, x_id: int, y_id: int):
def insert_edges(
    self, x_id: int, y_ids: np.ndarray, replace: bool = False
) -> None:
  """Stores the outgoing edges of `x_id` as one serialized row.

  Args:
    x_id: Source embedding id.
    y_ids: Target embedding ids.
    replace: If True, `y_ids` overwrites whatever is stored for `x_id`. If
      False, `y_ids` is merged with the targets already stored, and the
      sorted, de-duplicated union is written back.
  """
  cursor = self._get_cursor()
  # Force int64 before merging: if either side is an untyped empty array
  # (float64), np.concatenate promotes everything to float64, which cannot
  # represent ids above 2**53 exactly.
  y_ids = np.asarray(y_ids, dtype=np.int64)
  if not replace:
    existing = np.asarray(self.get_edges(x_id), dtype=np.int64)
    y_ids = np.unique(np.concatenate([existing, y_ids], axis=0))
  cursor.execute(
      """
      REPLACE INTO hoplite_edges (source_embedding_id, target_embedding_ids)
      VALUES (?, ?);
      """,
      (int(x_id), self.serialize_edges(y_ids)),
  )

def insert_edge(self, x_id: int, y_id: int):
  """Adds a single edge from `x_id` to `y_id`.

  Delegates to `insert_edges` with a one-element array and its default
  `replace` setting.
  """
  single_target = np.array([y_id])
  self.insert_edges(x_id, single_target)

def delete_edge(self, x_id: int, y_id: int):
  """Removes the edge from `x_id` to `y_id`, if present.

  Reads the stored targets for `x_id`, filters out `y_id`, and writes the
  remainder back with `replace=True`.

  Args:
    x_id: Source embedding id.
    y_id: Target embedding id to drop from the source's edge list.
  """
  # The per-row DELETE on `target_embedding_id` is gone: that column does
  # not exist in the BLOB-per-source layout, so no cursor is needed here.
  existing = self.get_edges(x_id)
  remaining = existing[existing != y_id]
  self.insert_edges(x_id, remaining, replace=True)

def delete_edges(self, x_id: int):
cursor = self._get_cursor()
Expand Down Expand Up @@ -290,13 +297,6 @@ def count_embeddings(self) -> int:
def embedding_dimension(self) -> int:
  """Returns the value of this instance's `embedding_dim` attribute."""
  dim = self.embedding_dim
  return dim

def count_edges(self) -> int:
  """Counts the rows in the `hoplite_edges` table.

  NOTE(review): where each row stores a serialized array of targets (the
  `target_embedding_ids` BLOB column), this is the number of source
  embeddings with a stored edge list, not the number of individual edges.
  Confirm which of the two callers expect.

  Returns:
    The result of `SELECT COUNT(*)` over `hoplite_edges`.
  """
  cursor = self._get_cursor()
  cursor.execute('SELECT COUNT(*) FROM hoplite_edges;')
  # COUNT(*) always yields exactly one single-column row.
  (row_count,) = cursor.fetchone()
  return row_count

def get_embedding(self, embedding_id: int):
cursor = self._get_cursor()
cursor.execute(
Expand Down Expand Up @@ -389,16 +389,18 @@ def get_embeddings_by_source(

def get_edges(self, embedding_id: int) -> np.ndarray:
  """Returns the target embedding ids stored for `embedding_id`.

  Args:
    embedding_id: Source embedding id to look up.

  Returns:
    The deserialized array of target ids from the first matching row, or
    an empty int64 array when no row exists for the source.
  """
  query = (
      'SELECT hoplite_edges.target_embedding_ids FROM hoplite_edges '
      'WHERE hoplite_edges.source_embedding_id = ?;'
  )
  cursor = self._get_cursor()
  cursor.execute(query, (int(embedding_id),))
  row = cursor.fetchone()
  if row is None:
    # A bare np.array([]) is float64; keep the integer dtype of the
    # non-empty path so callers never see ids change type.
    return np.array([], dtype=np.int64)
  return self.deserialize_edges(row[0])

def insert_label(self, label: interface.Label):
if label.type is None:
Expand Down

0 comments on commit 5fb0209

Please sign in to comment.