Skip to content

Commit

Permalink
Extend hoplite with methods to match MergedDataset.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658434161
  • Loading branch information
Chirp Team authored and copybara-github committed Aug 1, 2024
1 parent 26830db commit b2260be
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 0 deletions.
16 changes: 16 additions & 0 deletions chirp/projects/hoplite/hoplite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,22 @@ def test_labels_db_interface(self, db_type):
got = db.get_labels(ids[0])
self.assertLen(got, 3)

with self.subTest('get_classes'):
got = db.get_classes()
self.assertSequenceEqual(got, ['hawgoo', 'rewbla'])

with self.subTest('get_class_counts'):
# 2 positive labels for 'hawgoo' ignoring provenance, 0 for 'rewbla'.
got = db.get_class_counts(interface.LabelType.POSITIVE)
self.assertDictEqual(got, {'hawgoo': 2, 'rewbla': 0})

# 1 negative label for 'rewbla', 0 for 'hawgoo'.
got = db.get_class_counts(interface.LabelType.NEGATIVE)
self.assertDictEqual(got, {'hawgoo': 0, 'rewbla': 1})

with self.subTest('count_classes'):
self.assertEqual(db.count_classes(), 2)

def test_brute_search_impl_agreement(self):
rng = np.random.default_rng(42)
in_mem_db = self._make_db('in_mem', 1000, rng)
Expand Down
27 changes: 27 additions & 0 deletions chirp/projects/hoplite/in_mem_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,30 @@ def get_embeddings_by_label(

def get_labels(self, embedding_id: int) -> Sequence[interface.Label]:
  """Look up the labels stored for a single embedding.

  Args:
    embedding_id: Key into this database's `labels` mapping.

  Returns:
    The value stored in `self.labels` under `embedding_id`, returned as-is
    (no copy is made).
  """
  stored_labels = self.labels[embedding_id]
  return stored_labels

def get_classes(self) -> Sequence[str]:
  """Collect every distinct label string attached to any embedding.

  Returns:
    A tuple of the distinct label strings, in sorted order.
  """
  distinct = {
      entry.label
      for per_embedding in self.labels.values()
      for entry in per_embedding
  }
  return tuple(sorted(distinct))

def get_class_counts(
    self, label_type: interface.LabelType = interface.LabelType.POSITIVE
) -> dict[str, int]:
  """Tally, per class, how many embeddings carry a label of the given type.

  An embedding contributes at most one count per class, even when the same
  label is attached to it several times with different provenances.

  Args:
    label_type: Type of label to count. Defaults to positive labels.

  Returns:
    A `collections.defaultdict(int)` keyed by label string. Every label seen
    in the database gets a key, with a count of zero when no label of the
    requested type matched.
  """
  wanted = label_type.value
  totals = collections.defaultdict(int)
  for embedding_labels in self.labels.values():
    # Labels already counted for this embedding; later duplicates (which can
    # only differ by provenance or type) must not be counted again.
    already_counted = set()
    for entry in embedding_labels:
      name = entry.label
      if name in already_counted:
        continue
      is_match = entry.type.value == wanted
      # Adding zero still creates the key, so classes without any matching
      # label show up in the result.
      totals[name] += 1 if is_match else 0
      if is_match:
        already_counted.add(name)
  return totals
20 changes: 20 additions & 0 deletions chirp/projects/hoplite/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,22 @@ def get_embeddings_by_label(
def get_labels(self, embedding_id: int) -> Sequence[Label]:
  """Get all labels for the indicated embedding_id.

  Args:
    embedding_id: ID of the embedding whose labels are requested.

  Returns:
    A sequence of the Label objects for that embedding.
  """

@abc.abstractmethod
def get_classes(self) -> Sequence[str]:
  """Get all distinct classes (label strings) in the database.

  Returns:
    A sequence containing each distinct label string once.
  """

@abc.abstractmethod
def get_class_counts(
    self, label_type: LabelType = LabelType.POSITIVE
) -> dict[str, int]:
  """Count the number of occurrences of each class in the database.

  Classes with zero matching occurrences are still included in the result.

  Args:
    label_type: Type of label to count. By default, counts positive labels.

  Returns:
    A dict mapping each class (label string) to its number of occurrences
    with the given label type.
  """

# Composite methods

def get_one_embedding_id(self) -> int:
Expand All @@ -230,6 +246,10 @@ def count_edges(self) -> int:
ct += self.get_edges(idx).shape[0]
return ct

def count_classes(self) -> int:
  """Return how many distinct classes `get_classes` reports."""
  distinct_classes = self.get_classes()
  return len(distinct_classes)

def get_embeddings(
self, embedding_ids: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
Expand Down
27 changes: 27 additions & 0 deletions chirp/projects/hoplite/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""SQLite Implementation of a searchable embeddings database."""

import collections
from collections.abc import Sequence
import dataclasses
import json
import sqlite3
Expand Down Expand Up @@ -444,6 +446,31 @@ def get_labels(self, embedding_id: int) -> tuple[interface.Label, ...]:
for r in results
)

def get_classes(self) -> Sequence[str]:
  """List the distinct label strings stored in the `hoplite_labels` table.

  Returns:
    A tuple of distinct label strings, ordered by the database's
    `ORDER BY label` clause.
  """
  query = 'SELECT DISTINCT label FROM hoplite_labels ORDER BY label;'
  cursor = self._get_cursor()
  cursor.execute(query)
  # Each result row has a single column: the label string.
  return tuple(row[0] for row in cursor.fetchall())

def get_class_counts(
    self, label_type: interface.LabelType = interface.LabelType.POSITIVE
) -> dict[str, int]:
  """Tally, per class, how many embeddings carry a label of the given type.

  Args:
    label_type: Type of label to count. Defaults to positive labels.

  Returns:
    A `collections.defaultdict(int)` keyed by label string. Every label in
    the `hoplite_labels` table gets a key, with a count of zero when it has
    no rows of the requested type.
  """
  wanted = label_type.value
  cursor = self._get_cursor()
  # The inner SELECT DISTINCT collapses rows that differ only by provenance,
  # so one embedding is counted at most once per (label, type) pair.
  cursor.execute("""
SELECT label, type, COUNT(*)
FROM (
SELECT DISTINCT embedding_id, label, type FROM hoplite_labels)
GROUP BY label, type;
""")
  counts = collections.defaultdict(int)
  for row in cursor.fetchall():
    label, type_value, num_embeddings = row[0], row[1], row[2]
    if type_value == wanted:
      counts[label] = num_embeddings
    else:
      # Adding zero creates the key without disturbing a count that an
      # earlier matching row may already have stored.
      counts[label] += 0
  return counts

def print_table_values(self, table_name):
"""Prints all values from the specified table in the SQLite database."""
cursor = self._get_cursor()
Expand Down

0 comments on commit b2260be

Please sign in to comment.