From c9d480a588360959d62d1dc5b7382c1f87c3c7c3 Mon Sep 17 00:00:00 2001 From: birchkwok Date: Mon, 15 Apr 2024 12:40:15 +0800 Subject: [PATCH] Separate parameter validation logic. --- min_vec/execution_layer/matrix_serializer.py | 89 +------------------- 1 file changed, 1 insertion(+), 88 deletions(-) diff --git a/min_vec/execution_layer/matrix_serializer.py b/min_vec/execution_layer/matrix_serializer.py index a6ec8de..6bda6ab 100644 --- a/min_vec/execution_layer/matrix_serializer.py +++ b/min_vec/execution_layer/matrix_serializer.py @@ -6,7 +6,6 @@ from spinesUtils.asserts import raise_if from spinesUtils.logging import Logger -from min_vec.computational_layer.engines import to_normalize from min_vec.data_structures.filter import IDFilter from min_vec.utils.utils import io_checker from min_vec.configs.config import config @@ -31,7 +30,6 @@ def __init__( chunk_size: int = 1_000_000, index_mode: str = 'IVF-FLAT', dtypes: str = 'float32', - reindex_if_conflict: bool = False, scaler_bits=None ) -> None: """ @@ -49,8 +47,6 @@ def __init__( Options are 'FLAT' or 'IVF-FLAT'. Default is 'IVF-FLAT'. dtypes (str): The data type of the vectors. Default is 'float32'. Options are 'float16', 'float32' or 'float64'. - reindex_if_conflict (bool): Whether to reindex the database if there is an index mode conflict. - Default is False. scaler_bits (int): The number of bits for scalar quantization. Default is None. Options are 8, 16, 32. If None, scalar quantization will not be used. The 8 bits for uint8, 16 bits for uint16, 32 bits for uint32. @@ -58,21 +54,7 @@ def __init__( Raises: ValueError: If `chunk_size` is less than or equal to 1. """ - raise_if(ValueError, not isinstance(dim, int), 'dim must be int') - raise_if(ValueError, not str(database_path).endswith('mvdb'), 'database_path must end with .mvdb') - raise_if(ValueError, not isinstance(chunk_size, int) or chunk_size <= 1, - 'chunk_size must be int and greater than 1') - raise_if(ValueError, distance not in ('cosine', 'L2'), 'distance must be "cosine" or "L2"') - raise_if(ValueError, index_mode not in ('FLAT', 'IVF-FLAT'), 'index_mode must be "FLAT" or "IVF-FLAT"') - raise_if(ValueError, dtypes not in ('float16', 'float32', 'float64'), - 'dtypes must be "float16", "float32" or "float64') - raise_if(ValueError, not isinstance(n_clusters, int) or n_clusters <= 0, - 'n_clusters must be int and greater than 0') - raise_if(ValueError, not isinstance(reindex_if_conflict, bool), 'reindex_if_conflict must be bool') - raise_if(ValueError, scaler_bits not in (8, 16, 32, None), 'sq_bits must be 8, 16, 32 or None') - self.last_commit_time = None - self.reindex_if_conflict = reindex_if_conflict # set commit flag, if the flag is True, the database will not be saved self.COMMIT_FLAG = True # set flag for scalar quantization, if the flag is True, the database will be rescanned for scalar quantization @@ -107,6 +89,7 @@ def __init__( # set scalar quantization bits self.scaler_bits = scaler_bits if scaler_bits is not None else None + self.scaler = None if self.scaler_bits is not None: self._initialize_scalar_quantization() @@ -123,9 +106,6 @@ def __init__( self.indices = [] self.fields = [] - # check initialize params - self._check_initialize_params(dtypes, reindex_if_conflict) - self._initialize_fields_mapper() self._initialize_ann_model() self._initialize_id_filter() @@ -161,73 +141,6 @@ def _write_params(self, dtypes): self.storage_worker.write_file_attributes(attrs) - def _check_initialize_params(self, dtypes, reindex_if_conflict): - """check initialize params""" - write_params_flag = False - if self.database_path.exists(): - self.logger.info('Database file exists, reading parameters...') - - attrs = self.storage_worker.read_file_attributes() - dim = attrs.get('dim', self.dim) - chunk_size = attrs.get('chunk_size', self.chunk_size) - distance = attrs.get('distance', self.distance) - file_dtypes = attrs.get('dtypes', dtypes) - old_index_mode = attrs.get('index_mode', None) - - if dim != self.dim: - self.logger.warning( - f'* dim={dim} in the file is not equal to the dim={self.dim} in the parameters, ' - f'the parameter dim will be covered by the dim in the file.') - self.dim = dim - write_params_flag = True - - if chunk_size != self.chunk_size: - self.logger.warning( - f'* chunk_size={chunk_size} in the file is not ' - f'equal to the chunk_size={self.chunk_size}' - f'in the parameters, the parameter chunk_size will be covered by the chunk_size ' - f'in the file.' - ) - self.chunk_size = chunk_size - write_params_flag = True - - if distance != self.distance: - self.logger.warning(f'* distance=\'{distance}\' in the file is not ' - f'equal to the distance=\'{self.distance}\' ' - f'in the parameters, the parameter distance will be covered by the distance ' - f'in the file.') - self.distance = distance - write_params_flag = True - - if file_dtypes != dtypes: - self.logger.warning( - f'* dtypes=\'{file_dtypes}\' in the file is not equal to the dtypes=\'{dtypes}\' ' - f'in the parameters, the parameter dtypes will be covered by the dtypes ' - f'in the file.') - self.dtypes = self._dtypes_map[attrs.get('dtypes', dtypes)] - write_params_flag = True - - if old_index_mode != self.index_mode and self.index_mode == 'IVF-FLAT': - if reindex_if_conflict: - self.logger.warning( - f'* index_mode=\'{old_index_mode}\' in the file is not equal to the index_mode=' - f'\'{self.index_mode}\' in the parameters, if you really want to change the ' - f'index_mode to \'IVF-FLAT\', you need to run `commit()` function after initializing ' - f'the database.') - - self.COMMIT_FLAG = False - else: - self.logger.warning( - f'* index_mode=\'{old_index_mode}\' in the file is not equal to the index_mode=' - f'\'{self.index_mode}\' in the parameters, if you really want to change the ' - f'index_mode to \'IVF-FLAT\', you need to set `reindex_if_conflict=True` first ' - f'and run `commit()` function after initializing the database.') - - self.logger.info('Reading parameters done.') - - if write_params_flag or not self.database_path.exists(): - self._write_params(dtypes) - def _initialize_parent_path(self, database_path): """make directory if not exist""" self.database_path_parent = Path(database_path).parent.absolute() / Path(