Skip to content

Commit

Permalink
Separate parameter validation logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
BirchKwok committed Apr 15, 2024
1 parent f36d024 commit c9d480a
Showing 1 changed file with 1 addition and 88 deletions.
89 changes: 1 addition & 88 deletions min_vec/execution_layer/matrix_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from spinesUtils.asserts import raise_if
from spinesUtils.logging import Logger

from min_vec.computational_layer.engines import to_normalize
from min_vec.data_structures.filter import IDFilter
from min_vec.utils.utils import io_checker
from min_vec.configs.config import config
Expand All @@ -31,7 +30,6 @@ def __init__(
chunk_size: int = 1_000_000,
index_mode: str = 'IVF-FLAT',
dtypes: str = 'float32',
reindex_if_conflict: bool = False,
scaler_bits=None
) -> None:
"""
Expand All @@ -49,30 +47,14 @@ def __init__(
Options are 'FLAT' or 'IVF-FLAT'. Default is 'IVF-FLAT'.
dtypes (str): The data type of the vectors. Default is 'float32'.
Options are 'float16', 'float32' or 'float64'.
reindex_if_conflict (bool): Whether to reindex the database if there is an index mode conflict.
Default is False.
scaler_bits (int): The number of bits for scalar quantization. Default is None.
Options are 8, 16, 32. If None, scalar quantization will not be used.
The 8 bits for uint8, 16 bits for uint16, 32 bits for uint32.
Raises:
ValueError: If `chunk_size` is less than or equal to 1.
"""
raise_if(ValueError, not isinstance(dim, int), 'dim must be int')
raise_if(ValueError, not str(database_path).endswith('mvdb'), 'database_path must end with .mvdb')
raise_if(ValueError, not isinstance(chunk_size, int) or chunk_size <= 1,
'chunk_size must be int and greater than 1')
raise_if(ValueError, distance not in ('cosine', 'L2'), 'distance must be "cosine" or "L2"')
raise_if(ValueError, index_mode not in ('FLAT', 'IVF-FLAT'), 'index_mode must be "FLAT" or "IVF-FLAT"')
raise_if(ValueError, dtypes not in ('float16', 'float32', 'float64'),
'dtypes must be "float16", "float32" or "float64')
raise_if(ValueError, not isinstance(n_clusters, int) or n_clusters <= 0,
'n_clusters must be int and greater than 0')
raise_if(ValueError, not isinstance(reindex_if_conflict, bool), 'reindex_if_conflict must be bool')
raise_if(ValueError, scaler_bits not in (8, 16, 32, None), 'sq_bits must be 8, 16, 32 or None')

self.last_commit_time = None
self.reindex_if_conflict = reindex_if_conflict
# set commit flag, if the flag is True, the database will not be saved
self.COMMIT_FLAG = True
# set flag for scalar quantization, if the flag is True, the database will be rescanned for scalar quantization
Expand Down Expand Up @@ -107,6 +89,7 @@ def __init__(

# set scalar quantization bits
self.scaler_bits = scaler_bits if scaler_bits is not None else None
self.scaler = None
if self.scaler_bits is not None:
self._initialize_scalar_quantization()

Expand All @@ -123,9 +106,6 @@ def __init__(
self.indices = []
self.fields = []

# check initialize params
self._check_initialize_params(dtypes, reindex_if_conflict)

self._initialize_fields_mapper()
self._initialize_ann_model()
self._initialize_id_filter()
Expand Down Expand Up @@ -161,73 +141,6 @@ def _write_params(self, dtypes):

self.storage_worker.write_file_attributes(attrs)

def _check_initialize_params(self, dtypes, reindex_if_conflict):
"""check initialize params"""
write_params_flag = False
if self.database_path.exists():
self.logger.info('Database file exists, reading parameters...')

attrs = self.storage_worker.read_file_attributes()
dim = attrs.get('dim', self.dim)
chunk_size = attrs.get('chunk_size', self.chunk_size)
distance = attrs.get('distance', self.distance)
file_dtypes = attrs.get('dtypes', dtypes)
old_index_mode = attrs.get('index_mode', None)

if dim != self.dim:
self.logger.warning(
f'* dim={dim} in the file is not equal to the dim={self.dim} in the parameters, '
f'the parameter dim will be covered by the dim in the file.')
self.dim = dim
write_params_flag = True

if chunk_size != self.chunk_size:
self.logger.warning(
f'* chunk_size={chunk_size} in the file is not '
f'equal to the chunk_size={self.chunk_size}'
f'in the parameters, the parameter chunk_size will be covered by the chunk_size '
f'in the file.'
)
self.chunk_size = chunk_size
write_params_flag = True

if distance != self.distance:
self.logger.warning(f'* distance=\'{distance}\' in the file is not '
f'equal to the distance=\'{self.distance}\' '
f'in the parameters, the parameter distance will be covered by the distance '
f'in the file.')
self.distance = distance
write_params_flag = True

if file_dtypes != dtypes:
self.logger.warning(
f'* dtypes=\'{file_dtypes}\' in the file is not equal to the dtypes=\'{dtypes}\' '
f'in the parameters, the parameter dtypes will be covered by the dtypes '
f'in the file.')
self.dtypes = self._dtypes_map[attrs.get('dtypes', dtypes)]
write_params_flag = True

if old_index_mode != self.index_mode and self.index_mode == 'IVF-FLAT':
if reindex_if_conflict:
self.logger.warning(
f'* index_mode=\'{old_index_mode}\' in the file is not equal to the index_mode='
f'\'{self.index_mode}\' in the parameters, if you really want to change the '
f'index_mode to \'IVF-FLAT\', you need to run `commit()` function after initializing '
f'the database.')

self.COMMIT_FLAG = False
else:
self.logger.warning(
f'* index_mode=\'{old_index_mode}\' in the file is not equal to the index_mode='
f'\'{self.index_mode}\' in the parameters, if you really want to change the '
f'index_mode to \'IVF-FLAT\', you need to set `reindex_if_conflict=True` first '
f'and run `commit()` function after initializing the database.')

self.logger.info('Reading parameters done.')

if write_params_flag or not self.database_path.exists():
self._write_params(dtypes)

def _initialize_parent_path(self, database_path):
"""make directory if not exist"""
self.database_path_parent = Path(database_path).parent.absolute() / Path(
Expand Down

0 comments on commit c9d480a

Please sign in to comment.