Skip to content

Commit

Permalink
Automatically handle embedding dimension in hoplite sqlite database, …
Browse files Browse the repository at this point in the history
…and fix a bug when retrieving embeddings by source.

PiperOrigin-RevId: 678362101
  • Loading branch information
sdenton4 authored and copybara-github committed Sep 24, 2024
1 parent 3b72c67 commit b3c43bf
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 25 deletions.
64 changes: 42 additions & 22 deletions chirp/projects/hoplite/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import collections
from collections.abc import Sequence
import dataclasses
import functools
import json
import sqlite3
from typing import Any
Expand All @@ -44,35 +45,37 @@ class SQLiteGraphSearchDB(interface.GraphSearchDBInterface):
def create(
    cls,
    db_path: str,
    embedding_dim: int | None = None,
    embedding_dtype: type[Any] = np.float16,
):
  """Connect to (or create) the SQLite DB at db_path.

  Args:
    db_path: Path of the SQLite database file.
    embedding_dim: Dimension of stored embeddings. If None, the dimension is
      inferred from an embedding already present in the database.
    embedding_dtype: Numpy dtype used to (de)serialize embeddings.

  Returns:
    A SQLiteGraphSearchDB wrapping the opened connection.

  Raises:
    ValueError: If embedding_dim is None and no embedding is available to
      infer the dimension from (empty or freshly-created database).
  """
  db = sqlite3.connect(db_path)
  cursor = db.cursor()
  cursor.execute('PRAGMA journal_mode=WAL;')  # Enable WAL mode.
  db.commit()
  if embedding_dim is None:
    # Get an embedding from the DB to check its dimension.
    try:
      cursor.execute("""SELECT embedding FROM hoplite_embeddings LIMIT 1;""")
      row = cursor.fetchone()
    except sqlite3.OperationalError as exc:
      # A brand-new DB has no hoplite_embeddings table yet, so the SELECT
      # itself fails; surface the same ValueError as the empty-table case.
      raise ValueError(
          'Must specify embedding dimension for empty databases.'
      ) from exc
    if row is None:
      raise ValueError('Must specify embedding dimension for empty databases.')
    embedding = deserialize_embedding(row[0], embedding_dtype)
    embedding_dim = embedding.shape[-1]

  return SQLiteGraphSearchDB(db, db_path, embedding_dim, embedding_dtype)

def thread_split(self):
  """Get a new DB connection instance for use in a separate thread.

  SQLite connections cannot safely be shared across threads, so each worker
  thread should call this to obtain its own connection to the same DB file.
  """
  # NOTE: create()'s second positional parameter is embedding_dim, so calling
  # create(db_path, embedding_dtype) would bind the dtype to embedding_dim.
  # Pass the known embedding_dim explicitly to avoid that mis-binding (and to
  # skip re-inferring the dimension from the DB).
  return self.create(self.db_path, self.embedding_dim, self.embedding_dtype)

def _get_cursor(self) -> sqlite3.Cursor:
  """Return a cached cursor for this connection, creating it on first use."""
  cursor = self._cursor
  if cursor is None:
    cursor = self.db.cursor()
    self._cursor = cursor
  return cursor

def serialize_embedding(self, embedding: np.ndarray) -> bytes:
return embedding.astype(
np.dtype(self.embedding_dtype).newbyteorder('<')
).tobytes()

def deserialize_embedding(self, serialized_embedding: bytes) -> np.ndarray:
return np.frombuffer(
serialized_embedding,
dtype=np.dtype(self.embedding_dtype).newbyteorder('<'),
)

def serialize_edges(self, edges: np.ndarray) -> bytes:
return edges.astype(np.dtype(np.int64).newbyteorder('<')).tobytes()

Expand Down Expand Up @@ -163,7 +166,7 @@ def get_embedding_ids(self) -> np.ndarray:

def get_one_embedding_id(self) -> int:
  """Return the id of an arbitrary embedding in the database."""
  cursor = self._get_cursor()
  # LIMIT 1 avoids materializing the full id list just to grab a single row.
  cursor.execute("""SELECT id FROM hoplite_embeddings LIMIT 1;""")
  return int(cursor.fetchone()[0])

def insert_metadata(self, key: str, value: config_dict.ConfigDict) -> None:
Expand Down Expand Up @@ -273,11 +276,11 @@ def insert_embedding(
if embedding.shape[-1] != self.embedding_dim:
raise ValueError('Incorrect embedding dimension.')
cursor = self._get_cursor()
embedding_bytes = self.serialize_embedding(embedding)
embedding_bytes = serialize_embedding(embedding, self.embedding_dtype)
source_id = self._get_source_id(
source.dataset_name, source.source_id, insert=True
)
offset_bytes = self.serialize_embedding(source.offsets)
offset_bytes = serialize_embedding(source.offsets, self.embedding_dtype)
cursor.execute(
"""
INSERT INTO hoplite_embeddings (embedding, source_idx, offsets) VALUES (?, ?, ?);
Expand Down Expand Up @@ -306,7 +309,7 @@ def get_embedding(self, embedding_id: int):
(int(embedding_id),),
)
embedding = cursor.fetchall()[0][0]
return self.deserialize_embedding(embedding)
return deserialize_embedding(embedding, self.embedding_dtype)

def get_embedding_source(
self, embedding_id: int
Expand All @@ -322,7 +325,7 @@ def get_embedding_source(
(int(embedding_id),),
)
dataset, source, offsets = cursor.fetchall()[0]
offsets = self.deserialize_embedding(offsets)
offsets = deserialize_embedding(offsets, self.embedding_dtype)
return interface.EmbeddingSource(dataset, str(source), offsets)

def get_embeddings(
Expand All @@ -339,7 +342,9 @@ def get_embeddings(
).fetchall()
result_ids = np.array(tuple(int(c[0]) for c in results))
embeddings = np.array(
tuple(self.deserialize_embedding(c[1]) for c in results)
tuple(
deserialize_embedding(c[1], self.embedding_dtype) for c in results
)
)
return result_ids, embeddings

Expand All @@ -362,7 +367,6 @@ def get_embeddings_by_source(
A list of embedding IDs matching the indicated source parameters.
"""
cursor = self._get_cursor()
source_id = self._get_source_id(dataset_name, source_id, insert=False)
if source_id is None:
query = (
'SELECT id, offsets FROM hoplite_embeddings '
Expand All @@ -373,15 +377,16 @@ def get_embeddings_by_source(
)
cursor.execute(query, (dataset_name,))
else:
source_idx = self._get_source_id(dataset_name, source_id, insert=False)
query = (
'SELECT id, offsets FROM hoplite_embeddings '
'WHERE hoplite_embeddings.source_idx = ?;'
)
cursor.execute(query, (source_id,))
cursor.execute(query, (source_idx,))
result_pairs = cursor.fetchall()
outputs = []
for idx, offsets_bytes in result_pairs:
got_offsets = self.deserialize_embedding(offsets_bytes)
got_offsets = deserialize_embedding(offsets_bytes, self.embedding_dtype)
if offsets is not None and not np.array_equal(got_offsets, offsets):
continue
outputs.append(idx)
Expand Down Expand Up @@ -495,3 +500,18 @@ def print_table_values(self, table_name):
# Print each row as a comma-separated string
for row in rows:
print(', '.join(str(value) for value in row))


def serialize_embedding(
    embedding: np.ndarray, embedding_dtype: type[Any]
) -> bytes:
  """Encode an embedding as raw little-endian bytes of the target dtype."""
  little_endian = np.dtype(embedding_dtype).newbyteorder('<')
  converted = embedding.astype(little_endian)
  return converted.tobytes()


def deserialize_embedding(
    serialized_embedding: bytes, embedding_dtype: type[Any]
) -> np.ndarray:
  """Decode little-endian bytes back into a 1-D numpy array of the dtype."""
  little_endian = np.dtype(embedding_dtype).newbyteorder('<')
  return np.frombuffer(serialized_embedding, dtype=little_endian)
23 changes: 20 additions & 3 deletions chirp/projects/hoplite/tests/hoplite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ def test_graph_db_interface(self, db_type, thread_split):
self.assertEmpty(edges)
self.assertEqual(db.count_edges(), 499)

with self.subTest('test_embedding_sources'):
source = db.get_embedding_source(idxes[1])
# The embeddings are given one of three randomly selected dataset names.
embs = db.get_embeddings_by_source(source.dataset_name, None, None)
self.assertGreater(embs.shape[0], db.count_embeddings() / 6)
# For an unknown dataset name, we should get no embeddings.
embs = db.get_embeddings_by_source('fake_name', None, None)
self.assertEqual(embs.shape[0], 0)
# Source ids are approximately unique.
embs = db.get_embeddings_by_source(
source.dataset_name, source.source_id, None
)
self.assertLen(embs, 1)
# For an unknown source id, we should get no embeddings.
embs = db.get_embeddings_by_source(source.dataset_name, 'fake_id', None)
self.assertEqual(embs.shape[0], 0)

db.drop_all_edges()
self.assertEqual(db.count_edges(), 0)

Expand Down Expand Up @@ -134,15 +151,15 @@ def test_labels_db_interface(self, db_type):
)

with self.subTest('get_embeddings_by_label'):
# When both label_type and source are unspecified, we should get all
# When both label_type and provenance are unspecified, we should get all
# unique IDs with the target label. Id's 0 and 1 both have some kind of
# 'hawgoo' label.
got = db.get_embeddings_by_label('hawgoo', None, None)
self.assertSequenceEqual(sorted(got), sorted([ids[0], ids[1]]))

with self.subTest('get_embeddings_by_label_type'):
# Now we should get the ID's for all POSITIVE 'hawgoo' labels, regardless
# of source.
# of provenance.
got = db.get_embeddings_by_label(
'hawgoo', interface.LabelType.POSITIVE, None
)
Expand All @@ -154,7 +171,7 @@ def test_labels_db_interface(self, db_type):
)
self.assertEqual(got.shape[0], 0)

with self.subTest('get_embeddings_by_label_source'):
with self.subTest('get_embeddings_by_label_provenance'):
# There is only one hawgoo labeled by a human.
got = db.get_embeddings_by_label('hawgoo', None, 'human')
self.assertSequenceEqual(got, [ids[0]])
Expand Down

0 comments on commit b3c43bf

Please sign in to comment.