Remove certain immutable parameters to rationalize them.
BirchKwok committed Apr 17, 2024
1 parent 8ba7042 commit d6f0a43
Showing 1 changed file with 49 additions and 15 deletions.
64 changes: 49 additions & 15 deletions min_vec/api/api.py
@@ -28,8 +28,7 @@ class MinVectorDB:
"""

@ParametersValidator(
update_configs=['dim', 'database_path', 'n_clusters', 'chunk_size', 'index_mode', 'dtypes', 'scaler_bits',
'distance'],
update_configs=['dim', 'n_clusters', 'chunk_size', 'index_mode', 'dtypes', 'scaler_bits'],
logger=logger
)
@ParameterTypeAssert({
@@ -48,7 +47,7 @@ class MinVectorDB:
def __init__(
self, dim: int, database_path: str, n_clusters: int = 16, chunk_size: int = 100_000,
distance: str = 'cosine', index_mode: str = 'IVF-FLAT', dtypes: str = 'float32',
use_cache: bool = True, scaler_bits: int = 8, n_threads: int = None,
use_cache: bool = True, scaler_bits: Union[int, None] = 8, n_threads: Union[int, None] = 10,
warm_up: bool = False
) -> None:
"""
@@ -69,9 +68,9 @@ def __init__(
scaler_bits (int): The number of bits for scalar quantization.
Options are 8, 16, or 32. The default is None, which means no scalar quantization.
The 8 for 8-bit, 16 for 16-bit, and 32 for 32-bit.
n_threads (int): The number of threads to use for parallel processing. Default is None, which means using
twice the number of CPU cores.
n_threads (int): The number of threads to use for parallel processing. Default is 10.
warm_up (bool): Whether to warm up the database. Default is False.
.. versionadded:: 0.2.6
Raises:
ValueError: If `chunk_size` is less than or equal to 1.
@@ -102,7 +101,6 @@ def __init__(
database_path=database_path,
n_clusters=n_clusters,
chunk_size=chunk_size,
distance=distance,
index_mode=index_mode,
logger=logger,
dtypes=dtypes,
@@ -121,7 +119,8 @@

self._query = Query(
matrix_serializer=self._matrix_serializer,
n_threads=n_threads if n_threads else min(32, os.cpu_count() + 4)
n_threads=n_threads if n_threads else min(32, os.cpu_count() + 4),
distance=distance
)

self._query.query.clear_cache()
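
As a usage sketch that is not part of the diff: assuming the package is importable as `min_vec` and exposes `MinVectorDB` from this module, the constructor after this commit could be called as below; the dimension and path values are hypothetical, and `distance` appears to serve only as the default metric used when `query()` is called without an override, since it is no longer listed among the immutable configs.

# A hypothetical construction example reflecting this commit; the import path is
# assumed from this file's location and the dim/path values are made up.
from min_vec.api.api import MinVectorDB

db = MinVectorDB(
    dim=128,                   # hypothetical vector dimensionality
    database_path='demo_db',   # hypothetical storage path
    n_clusters=16,
    chunk_size=100_000,
    distance='cosine',         # default metric; no longer tracked in update_configs
    index_mode='IVF-FLAT',
    dtypes='float32',
    scaler_bits=8,
    n_threads=10,              # the new default value per this commit
    warm_up=False,
)
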
@@ -171,17 +170,23 @@ def commit(self):
self._matrix_serializer.commit()

@unavailable_if_deleted
def query(self, vector: np.ndarray, k: Union[int, str] = 12, *, fields: List = None, subset_indices: List = None,
def query(self, vector: np.ndarray, k: int = 12, *,
fields: Union[List, None] = None,
subset_indices: Union[List, None] = None,
distance: Union[str, None] = None,
return_similarity: bool = True):
"""
Query the database for the vectors most similar to the given vector in batches.
Parameters:
vector (np.ndarray): The query vector.
k (int or str): The number of nearest vectors to return. if be 'all', return all vectors.
k (int): The number of nearest vectors to return.
fields (list, optional): The target of the vector.
subset_indices (list, optional): The subset of indices to query.
distance (str): The distance metric to use for the query.
.. versionadded:: 0.2.7
return_similarity (bool): Whether to return the similarity scores. Default is True.
.. versionadded:: 0.2.5
Returns:
Tuple: The indices and similarity scores of the top k nearest vectors.
@@ -191,26 +196,55 @@ def query(self, vector: np.ndarray, k: Union[int, str] = 12, *, fields: List = N
"""
import datetime

logger.debug(f'Query vector: {vector.tolist()}')
logger.debug(f'Query k: {k}')
logger.debug(f'Query fields: {fields}')
logger.debug(f'Query subset_indices: {subset_indices}')

raise_if(TypeError, not isinstance(k, int) and not (isinstance(k, str) and k != 'all'),
'k must be int or "all".')
raise_if(ValueError, k <= 0, 'k must be greater than 0.')
raise_if(ValueError, not isinstance(fields, list) and fields is not None,
'fields must be list or None.')
raise_if(ValueError, not isinstance(subset_indices, list) and subset_indices is not None,
'subset_indices must be list or None.')
raise_if(ValueError, vector is None, 'vector must be not None.')
raise_if(ValueError, len(vector) != self._matrix_serializer.shape[1],
'vector must be same dim with database.')
raise_if(ValueError, not isinstance(vector, np.ndarray), 'vector must be np.ndarray.')
raise_if(ValueError, vector.ndim != 1, 'vector must be 1d array.')
raise_if(ValueError, not isinstance(return_similarity, bool), 'return_similarity must be bool.')
raise_if(ValueError, distance is not None and distance not in ['cosine', 'L2'],
'distance must be "cosine" or "L2" or None.')

if self._matrix_serializer.shape[0] == 0:
raise ValueError('database is empty.')

if k > self._matrix_serializer.shape[0]:
k = self._matrix_serializer.shape[0]

self._most_recent_query_report = {}

self._timer.start()
if self._use_cache:
res = self._query.query(vector=vector, k=k, fields=fields,
subset_indices=subset_indices, index_mode=self._matrix_serializer.index_mode,
distance=self._distance, return_similarity=return_similarity)
distance=distance, return_similarity=return_similarity)
else:
res = self._query.query(vector=vector, k=k, fields=fields,
subset_indices=subset_indices, index_mode=self._matrix_serializer.index_mode,
now_time=datetime.datetime.now().timestamp(), distance=self._distance,
now_time=datetime.datetime.now().strftime('%Y%m%d%H%M%S%f'),
distance=distance,
return_similarity=return_similarity)

time_cost = self._timer.last_timestamp_diff()
self._most_recent_query_report['Database shape'] = self.shape
self._most_recent_query_report['Query time'] = f"{time_cost :>.5f} s"
self._most_recent_query_report['Database Shape'] = self.shape
self._most_recent_query_report['Query Time'] = f"{time_cost :>.5f} s"
self._most_recent_query_report['Query Distance'] = self._distance if distance is None else distance
self._most_recent_query_report['Query K'] = k
self._most_recent_query_report[f'Top {k} results index'] = res[0]
self._most_recent_query_report[f'Top {k} Results Index'] = res[0]
if return_similarity:
self._most_recent_query_report[f'Top {k} results similarity'] = np.array([round(i, 6) for i in res[1]])
self._most_recent_query_report[f'Top {k} Results Similarity'] = np.array([round(i, 6) for i in res[1]])

return res
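
A query sketch, also not part of the diff, assuming the `db` instance from the sketch above has already had vectors added and committed (querying an empty database raises ValueError): per this commit `k` must be a positive int, and the metric can be overridden per call with `distance`, which must be 'cosine', 'L2', or None.

# Hypothetical query usage after this commit; the query vector must be a 1-D
# np.ndarray whose length matches the database dim.
import numpy as np

query_vec = np.random.random(128).astype(np.float32)  # matches dim=128 above
indices, scores = db.query(
    query_vec,
    k=10,                    # 'all' is no longer accepted; k is clamped to the database size
    distance='L2',           # per-query override added in 0.2.7; None falls back to the constructor metric
    return_similarity=True,  # similarity scores are returned alongside the indices
)
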

