diff --git a/chirp/projects/hoplite/index.py b/chirp/projects/hoplite/index.py index 4775503e..7812cf99 100644 --- a/chirp/projects/hoplite/index.py +++ b/chirp/projects/hoplite/index.py @@ -42,18 +42,7 @@ def from_db( cls, db: interface.GraphSearchDBInterface, score_fn_name: str = 'dot' ) -> 'HopliteSearchIndex': """Create a VamanaSearchIndex from a GraphSearchDBInterface impl.""" - # TODO(tomdenton): Use an enum for metric_name. - if score_fn_name in ('mip', 'dot'): - # mip == Max Inner Prouct - score_fn = score_functions.numpy_dot - elif score_fn_name in ('jax_mip', 'jax_dot'): - score_fn = score_functions.get_jax_dot() - elif score_fn_name == 'cosine': - score_fn = score_functions.numpy_cos - elif score_fn_name == 'euclidean': - score_fn = score_functions.numpy_euclidean - else: - raise ValueError(f'Unknown metric name: {score_fn_name}') + score_fn = score_functions.get_score_fn(score_fn_name) return cls(db, score_fn=score_fn) def initialize_index(self, out_degree: int, seed: int = 42) -> None: diff --git a/chirp/projects/hoplite/score_functions.py b/chirp/projects/hoplite/score_functions.py index acd5f2e0..43c444b7 100644 --- a/chirp/projects/hoplite/score_functions.py +++ b/chirp/projects/hoplite/score_functions.py @@ -27,8 +27,8 @@ def get_score_fn( score_fn = numpy_dot elif name == 'cos': score_fn = numpy_cos - elif name == 'euclidean': - score_fn = numpy_euclidean + elif name == 'neg_euclidean': + score_fn = numpy_neg_euclidean else: raise ValueError('Unknown score function: ', name) @@ -38,7 +38,8 @@ def get_score_fn( bias_fn = score_fn if target_score is not None: - targeted_fn = lambda x, y: np.abs(bias_fn(x, y) - target_score) + # We want 'up is good', so take the negative absolute value. + targeted_fn = lambda x, y: -np.abs(bias_fn(x, y) - target_score) else: targeted_fn = bias_fn @@ -67,18 +68,18 @@ def numpy_cos(data: np.ndarray, query: np.ndarray) -> np.ndarray: return np.dot(unit_data, unit_query) -def numpy_euclidean(data: np.ndarray, query: np.ndarray) -> np.ndarray: - """Numpy L2 distance allowing multiple queries.""" +def numpy_neg_euclidean(data: np.ndarray, query: np.ndarray) -> np.ndarray: + """Negative L2 distance allowing multiple queries.""" data_norms = np.linalg.norm(data, axis=-1) if len(query.shape) > 1: query_norms = np.linalg.norm(query, axis=-1) dot_products = np.tensordot(data, query, axes=(-1, -1)) pairs = data_norms[:, np.newaxis] + query_norms[np.newaxis, :] - return pairs - 2 * dot_products + return -pairs + 2 * dot_products query_norm = np.linalg.norm(query) dot_products = np.dot(data, query) - return data_norms - 2 * dot_products + query_norm + return -data_norms + 2 * dot_products + query_norm def get_jax_dot():