diff --git a/min_vec/api/api.py b/min_vec/api/api.py index 28e96bc..3836c57 100644 --- a/min_vec/api/api.py +++ b/min_vec/api/api.py @@ -28,8 +28,7 @@ class MinVectorDB: """ @ParametersValidator( - update_configs=['dim', 'database_path', 'n_clusters', 'chunk_size', 'index_mode', 'dtypes', 'scaler_bits', - 'distance'], + update_configs=['dim', 'n_clusters', 'chunk_size', 'index_mode', 'dtypes', 'scaler_bits'], logger=logger ) @ParameterTypeAssert({ @@ -48,7 +47,7 @@ class MinVectorDB: def __init__( self, dim: int, database_path: str, n_clusters: int = 16, chunk_size: int = 100_000, distance: str = 'cosine', index_mode: str = 'IVF-FLAT', dtypes: str = 'float32', - use_cache: bool = True, scaler_bits: int = 8, n_threads: int = None, + use_cache: bool = True, scaler_bits: Union[int, None] = 8, n_threads: Union[int, None] = 10, warm_up: bool = False ) -> None: """ @@ -69,9 +68,9 @@ def __init__( scaler_bits (int): The number of bits for scalar quantization. Options are 8, 16, or 32. The default is None, which means no scalar quantization. The 8 for 8-bit, 16 for 16-bit, and 32 for 32-bit. - n_threads (int): The number of threads to use for parallel processing. Default is None, which means using - twice the number of CPU cores. + n_threads (int): The number of threads to use for parallel processing. Default is 10. warm_up (bool): Whether to warm up the database. Default is False. + .. versionadded:: 0.2.6 Raises: ValueError: If `chunk_size` is less than or equal to 1. @@ -102,7 +101,6 @@ def __init__( database_path=database_path, n_clusters=n_clusters, chunk_size=chunk_size, - distance=distance, index_mode=index_mode, logger=logger, dtypes=dtypes, @@ -121,7 +119,8 @@ def __init__( self._query = Query( matrix_serializer=self._matrix_serializer, - n_threads=n_threads if n_threads else min(32, os.cpu_count() + 4) + n_threads=n_threads if n_threads else min(32, os.cpu_count() + 4), + distance=distance ) self._query.query.clear_cache() @@ -171,17 +170,23 @@ def commit(self): self._matrix_serializer.commit() @unavailable_if_deleted - def query(self, vector: np.ndarray, k: Union[int, str] = 12, *, fields: List = None, subset_indices: List = None, + def query(self, vector: np.ndarray, k: int = 12, *, + fields: Union[List, None] = None, + subset_indices: Union[List, None] = None, + distance: Union[str, None] = None, return_similarity: bool = True): """ Query the database for the vectors most similar to the given vector in batches. Parameters: vector (np.ndarray): The query vector. - k (int or str): The number of nearest vectors to return. if be 'all', return all vectors. + k (int): The number of nearest vectors to return. fields (list, optional): The target of the vector. subset_indices (list, optional): The subset of indices to query. + distance (str): The distance metric to use for the query. + .. versionadded:: 0.2.7 return_similarity (bool): Whether to return the similarity scores.Default is True. + .. versionadded:: 0.2.5 Returns: Tuple: The indices and similarity scores of the top k nearest vectors. @@ -191,26 +196,55 @@ def query(self, vector: np.ndarray, k: Union[int, str] = 12, *, fields: List = N """ import datetime + logger.debug(f'Query vector: {vector.tolist()}') + logger.debug(f'Query k: {k}') + logger.debug(f'Query fields: {fields}') + logger.debug(f'Query subset_indices: {subset_indices}') + + raise_if(TypeError, not isinstance(k, int) and not (isinstance(k, str) and k != 'all'), + 'k must be int or "all".') + raise_if(ValueError, k <= 0, 'k must be greater than 0.') + raise_if(ValueError, not isinstance(fields, list) and fields is not None, + 'fields must be list or None.') + raise_if(ValueError, not isinstance(subset_indices, list) and subset_indices is not None, + 'subset_indices must be list or None.') + raise_if(ValueError, vector is None, 'vector must be not None.') + raise_if(ValueError, len(vector) != self._matrix_serializer.shape[1], + 'vector must be same dim with database.') + raise_if(ValueError, not isinstance(vector, np.ndarray), 'vector must be np.ndarray.') + raise_if(ValueError, vector.ndim != 1, 'vector must be 1d array.') + raise_if(ValueError, not isinstance(return_similarity, bool), 'return_similarity must be bool.') + raise_if(ValueError, distance is not None and distance not in ['cosine', 'L2'], + 'distance must be "cosine" or "L2" or None.') + + if self._matrix_serializer.shape[0] == 0: + raise ValueError('database is empty.') + + if k > self._matrix_serializer.shape[0]: + k = self._matrix_serializer.shape[0] + self._most_recent_query_report = {} self._timer.start() if self._use_cache: res = self._query.query(vector=vector, k=k, fields=fields, subset_indices=subset_indices, index_mode=self._matrix_serializer.index_mode, - distance=self._distance, return_similarity=return_similarity) + distance=distance, return_similarity=return_similarity) else: res = self._query.query(vector=vector, k=k, fields=fields, subset_indices=subset_indices, index_mode=self._matrix_serializer.index_mode, - now_time=datetime.datetime.now().timestamp(), distance=self._distance, + now_time=datetime.datetime.now().strftime('%Y%m%d%H%M%S%f'), + distance=distance, return_similarity=return_similarity) time_cost = self._timer.last_timestamp_diff() - self._most_recent_query_report['Database shape'] = self.shape - self._most_recent_query_report['Query time'] = f"{time_cost :>.5f} s" + self._most_recent_query_report['Database Shape'] = self.shape + self._most_recent_query_report['Query Time'] = f"{time_cost :>.5f} s" + self._most_recent_query_report['Query Distance'] = self._distance if distance is None else distance self._most_recent_query_report['Query K'] = k - self._most_recent_query_report[f'Top {k} results index'] = res[0] + self._most_recent_query_report[f'Top {k} Results Index'] = res[0] if return_similarity: - self._most_recent_query_report[f'Top {k} results similarity'] = np.array([round(i, 6) for i in res[1]]) + self._most_recent_query_report[f'Top {k} Results Similarity'] = np.array([round(i, 6) for i in res[1]]) return res