Skip to content

Commit

Permalink
Add option to avoid inserting duplicate labels.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 680740739
  • Loading branch information
sdenton4 authored and copybara-github committed Sep 30, 2024
1 parent 10bf156 commit 6a836b8
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
8 changes: 7 additions & 1 deletion chirp/projects/hoplite/in_mem_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,18 @@ def commit(self) -> None:
"""Commit any pending transactions to the database."""
pass

def insert_label(
    self, label: interface.Label, skip_duplicates: bool = False
) -> bool:
  """Add a label to the in-memory label store.

  Args:
    label: The label to insert. Its `type` and `provenance` must be set.
    skip_duplicates: If True, the label is not inserted when an equal label is
      already returned by `get_labels` for the same embedding id.

  Returns:
    True if the label was appended, False if it was skipped as a duplicate.

  Raises:
    ValueError: If `label.type` or `label.provenance` is None.
  """
  if label.type is None:
    raise ValueError('label type must be set')
  if label.provenance is None:
    # The message says "source", but the field being checked is `provenance`.
    raise ValueError('label source must be set')
  if skip_duplicates and label in self.get_labels(label.embedding_id):
    return False

  self.labels[label.embedding_id].append(label)
  return True

def get_embeddings_by_label(
self,
Expand Down
16 changes: 14 additions & 2 deletions chirp/projects/hoplite/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,20 @@ def get_degree_bound(self) -> int:
return -1

@abc.abstractmethod
def insert_label(self, label: Label, skip_duplicates: bool = False) -> bool:
  """Add a label to the db.

  Args:
    label: The label to insert.
    skip_duplicates: If True, and the label already exists, return False.
      Otherwise, the label is inserted regardless of duplicates.

  Returns:
    True if the label was inserted, False if it was a duplicate and
    skip_duplicates was True.

  Raises:
    ValueError: If the label type or provenance is not set.
  """

@abc.abstractmethod
def embedding_dimension(self) -> int:
Expand Down
8 changes: 7 additions & 1 deletion chirp/projects/hoplite/sqlite_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,11 +407,16 @@ def get_edges(self, embedding_id: int) -> np.ndarray:
return np.array([])
return self.deserialize_edges(edge_bytes[0][0])

def insert_label(self, label: interface.Label):
def insert_label(
self, label: interface.Label, skip_duplicates: bool = False
) -> bool:
if label.type is None:
raise ValueError('label type must be set')
if label.provenance is None:
raise ValueError('label source must be set')
if skip_duplicates and label in self.get_labels(label.embedding_id):
return False

cursor = self._get_cursor()
cursor.execute(
"""
Expand All @@ -424,6 +429,7 @@ def insert_label(self, label: interface.Label):
label.provenance,
),
)
return True

def get_embeddings_by_label(
self,
Expand Down
7 changes: 7 additions & 0 deletions chirp/projects/hoplite/tests/hoplite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,13 @@ def test_labels_db_interface(self, db_type):
with self.subTest('count_classes'):
self.assertEqual(db.count_classes(), 2)

with self.subTest('duplicate_labels'):
dupe_label = interface.Label(
ids[0], 'hawgoo', interface.LabelType.POSITIVE, 'human'
)
self.assertFalse(db.insert_label(dupe_label, skip_duplicates=True))
self.assertTrue(db.insert_label(dupe_label, skip_duplicates=False))

def test_brute_search_impl_agreement(self):
rng = np.random.default_rng(42)
in_mem_db = test_utils.make_db(
Expand Down

0 comments on commit 6a836b8

Please sign in to comment.