add strict mode to collection creation
generall committed Jan 3, 2025
1 parent e3ce46e commit 050bed6
Showing 7 changed files with 67 additions and 32 deletions.
10 changes: 9 additions & 1 deletion qdrant_client/async_qdrant_client.py
@@ -102,7 +102,7 @@ def __init__(
super().__init__(parser=self._inference_inspector.parser, **kwargs)
self._init_options = {
key: value
for (key, value) in locals().items()
for key, value in locals().items()
if key not in ("self", "__class__", "kwargs")
}
self._init_options.update(deepcopy(kwargs))
@@ -2100,6 +2100,7 @@ async def update_collection(
quantization_config: Optional[types.QuantizationConfigDiff] = None,
timeout: Optional[int] = None,
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
strict_mode_config: Optional[types.StrictModeConfig] = None,
**kwargs: Any,
) -> bool:
"""Update parameters of the collection
@@ -2115,6 +2116,7 @@
Wait for operation commit timeout in seconds.
If timeout is reached - request will return with service error.
sparse_vectors_config: Override for sparse vector-specific configuration
strict_mode_config: Override for strict mode configuration
Returns:
Operation result
"""
@@ -2134,6 +2136,7 @@
quantization_config=quantization_config,
timeout=timeout,
sparse_vectors_config=sparse_vectors_config,
strict_mode_config=strict_mode_config,
**kwargs,
)

@@ -2172,6 +2175,7 @@ async def create_collection(
quantization_config: Optional[types.QuantizationConfig] = None,
init_from: Optional[types.InitFrom] = None,
timeout: Optional[int] = None,
strict_mode_config: Optional[types.StrictModeConfig] = None,
**kwargs: Any,
) -> bool:
"""Create empty collection with given parameters
@@ -2217,6 +2221,7 @@
timeout:
Wait for operation commit timeout in seconds.
If timeout is reached - request will return with service error.
strict_mode_config: Configure limitations for the collection, such as max size, rate limits, etc.
Returns:
Operation result
@@ -2256,6 +2261,7 @@ async def recreate_collection(
quantization_config: Optional[types.QuantizationConfig] = None,
init_from: Optional[types.InitFrom] = None,
timeout: Optional[int] = None,
strict_mode_config: Optional[types.StrictModeConfig] = None,
**kwargs: Any,
) -> bool:
"""Delete and create empty collection with given parameters
@@ -2301,6 +2307,7 @@
timeout:
Wait for operation commit timeout in seconds.
If timeout is reached - request will return with service error.
strict_mode_config: Configure limitations for the collection, such as max size, rate limits, etc.
Returns:
Operation result
@@ -2326,6 +2333,7 @@ async def recreate_collection(
init_from=init_from,
timeout=timeout,
sparse_vectors_config=sparse_vectors_config,
strict_mode_config=strict_mode_config,
**kwargs,
)

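For orientation, a minimal sketch of how the new parameter might be used from the async client. The StrictModeConfig field names (enabled, max_query_limit, upsert_max_batchsize) are assumptions based on the REST models of this release; they are not shown in this diff:

# Sketch: create a collection with strict mode limits (field names assumed).
import asyncio

from qdrant_client import AsyncQdrantClient, models


async def main() -> None:
    client = AsyncQdrantClient(url="http://localhost:6333")  # hypothetical endpoint
    await client.create_collection(
        collection_name="docs",  # placeholder name
        vectors_config=models.VectorParams(size=384, distance=models.Distance.COSINE),
        strict_mode_config=models.StrictModeConfig(
            enabled=True,  # assumed field: turn strict mode on
            max_query_limit=100,  # assumed field: cap the limit of read requests
            upsert_max_batchsize=64,  # assumed field: cap points per upsert batch
        ),
    )

asyncio.run(main())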
13 changes: 6 additions & 7 deletions qdrant_client/async_qdrant_fastembed.py
@@ -344,7 +344,7 @@ def _embed_documents(
parallel: Optional[int] = None,
) -> Iterable[tuple[str, list[float]]]:
embedding_model = self._get_or_init_model(model_name=embedding_model_name)
(documents_a, documents_b) = tee(documents, 2)
documents_a, documents_b = tee(documents, 2)
if embed_type == "passage":
vectors_iter = embedding_model.passage_embed(
documents_a, batch_size=batch_size, parallel=parallel
@@ -456,7 +456,7 @@ def _points_iterator(
yield models.PointStruct(id=idx, payload=payload, vector=point_vector)

def _validate_collection_info(self, collection_info: models.CollectionInfo) -> None:
(embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
vector_field_name = self.get_vector_field_name()
assert isinstance(
collection_info.config.params.vectors, dict
@@ -502,7 +502,7 @@ def get_fastembed_vector_params(
Configuration for `vectors_config` argument in `create_collection` method.
"""
vector_field_name = self.get_vector_field_name()
(embeddings_size, distance) = self._get_model_params(model_name=self.embedding_model_name)
embeddings_size, distance = self._get_model_params(model_name=self.embedding_model_name)
return {
vector_field_name: models.VectorParams(
size=embeddings_size,
@@ -687,7 +687,7 @@ async def query(
with_payload=True,
**kwargs,
)
(dense_request_response, sparse_request_response) = await self.search_batch(
dense_request_response, sparse_request_response = await self.search_batch(
collection_name=collection_name, requests=[dense_request, sparse_request]
)
return self._scored_points_to_query_responses(
@@ -764,7 +764,7 @@ async def query_batch(
sparse_responses = responses[len(query_texts) :]
responses = [
reciprocal_rank_fusion([dense_response, sparse_response], limit=limit)
for (dense_response, sparse_response) in zip(dense_responses, sparse_responses)
for dense_response, sparse_response in zip(dense_responses, sparse_responses)
]
return [self._scored_points_to_query_responses(response) for response in responses]

@@ -925,8 +925,7 @@ def _embed_raw_data(
return self._embed_image(data)
elif isinstance(data, dict):
return {
key: self._embed_raw_data(value, is_query=is_query)
for (key, value) in data.items()
key: self._embed_raw_data(value, is_query=is_query) for key, value in data.items()
}
elif isinstance(data, list):
if data and isinstance(data[0], float):
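The changes in this file mostly drop redundant parentheses around tuple unpacking. The tee(documents, 2) call in _embed_documents duplicates a possibly single-pass iterable so each branch of the embed-type check gets its own independent iterator; a minimal standalone sketch of that pattern:

# Sketch: itertools.tee yields two independent iterators over one source.
from itertools import tee

documents = iter(["first doc", "second doc"])  # single-pass iterator
documents_a, documents_b = tee(documents, 2)

print(list(documents_a))  # ['first doc', 'second doc']
print(list(documents_b))  # ['first doc', 'second doc'] - still intact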
33 changes: 21 additions & 12 deletions qdrant_client/async_qdrant_remote.py
@@ -91,7 +91,7 @@ def __init__(
if url.startswith("localhost"):
url = f"//{url}"
parsed_url: Url = parse_url(url)
(self._host, self._port) = (parsed_url.host, parsed_url.port)
self._host, self._port = (parsed_url.host, parsed_url.port)
if parsed_url.scheme:
self._https = parsed_url.scheme == "https"
self._scheme = parsed_url.scheme
@@ -198,7 +198,7 @@ async def close(self, grpc_grace: Optional[float] = None, **kwargs: Any) -> None
@staticmethod
def _parse_url(url: str) -> tuple[Optional[str], str, Optional[int], Optional[str]]:
parse_result: Url = parse_url(url)
(scheme, host, port, prefix) = (
scheme, host, port, prefix = (
parse_result.scheme,
parse_result.host,
parse_result.port,
@@ -1708,7 +1708,7 @@ async def delete_vectors(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1730,7 +1730,7 @@ async def delete_vectors(
assert grpc_result is not None, "Delete vectors returned None result"
return GrpcToRest.convert_update_result(grpc_result)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
return (
await self.openapi_client.points_api.delete_vectors(
collection_name=collection_name,
@@ -1925,7 +1925,7 @@ async def delete(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
points_selector
)
shard_key_selector = shard_key_selector or opt_shard_key_selector
@@ -1974,7 +1974,7 @@ async def set_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -1997,7 +1997,7 @@ async def set_payload(
).result
)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
result: Optional[types.UpdateResult] = (
await self.openapi_client.points_api.set_payload(
collection_name=collection_name,
@@ -2026,7 +2026,7 @@ async def overwrite_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -2048,7 +2048,7 @@ async def overwrite_payload(
).result
)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
result: Optional[types.UpdateResult] = (
await self.openapi_client.points_api.overwrite_payload(
collection_name=collection_name,
@@ -2076,7 +2076,7 @@ async def delete_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(points)
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(points)
shard_key_selector = shard_key_selector or opt_shard_key_selector
if isinstance(ordering, models.WriteOrdering):
ordering = RestToGrpc.convert_write_ordering(ordering)
@@ -2098,7 +2098,7 @@ async def delete_payload(
).result
)
else:
(_points, _filter) = self._try_argument_to_rest_points_and_filter(points)
_points, _filter = self._try_argument_to_rest_points_and_filter(points)
result: Optional[types.UpdateResult] = (
await self.openapi_client.points_api.delete_payload(
collection_name=collection_name,
@@ -2122,7 +2122,7 @@ async def clear_payload(
**kwargs: Any,
) -> types.UpdateResult:
if self._prefer_grpc:
(points_selector, opt_shard_key_selector) = self._try_argument_to_grpc_selector(
points_selector, opt_shard_key_selector = self._try_argument_to_grpc_selector(
points_selector
)
shard_key_selector = shard_key_selector or opt_shard_key_selector
@@ -2333,6 +2333,7 @@ async def update_collection(
quantization_config: Optional[types.QuantizationConfigDiff] = None,
timeout: Optional[int] = None,
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
strict_mode_config: Optional[types.StrictModeConfig] = None,
**kwargs: Any,
) -> bool:
if self._prefer_grpc:
@@ -2352,6 +2353,8 @@
sparse_vectors_config = RestToGrpc.convert_sparse_vector_config(
sparse_vectors_config
)
if isinstance(strict_mode_config, models.StrictModeConfig):
strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config)
return (
await self.grpc_collections.Update(
grpc.UpdateCollection(
@@ -2362,6 +2365,7 @@
hnsw_config=hnsw_config,
quantization_config=quantization_config,
sparse_vectors_config=sparse_vectors_config,
strict_mode_config=strict_mode_config,
timeout=timeout,
),
timeout=timeout if timeout is not None else self._timeout,
@@ -2426,6 +2430,7 @@ async def create_collection(
timeout: Optional[int] = None,
sparse_vectors_config: Optional[Mapping[str, types.SparseVectorParams]] = None,
sharding_method: Optional[types.ShardingMethod] = None,
strict_mode_config: Optional[types.StrictModeConfig] = None,
**kwargs: Any,
) -> bool:
if init_from is not None:
@@ -2449,6 +2454,8 @@
)
if isinstance(sharding_method, models.ShardingMethod):
sharding_method = RestToGrpc.convert_sharding_method(sharding_method)
if isinstance(strict_mode_config, models.StrictModeConfig):
strict_mode_config = RestToGrpc.convert_strict_mode_config(strict_mode_config)
create_collection = grpc.CreateCollection(
collection_name=collection_name,
hnsw_config=hnsw_config,
@@ -2464,6 +2471,7 @@
quantization_config=quantization_config,
sparse_vectors_config=sparse_vectors_config,
sharding_method=sharding_method,
strict_mode_config=strict_mode_config,
)
return (
await self.grpc_collections.Create(create_collection, timeout=self._timeout)
@@ -2491,6 +2499,7 @@ async def create_collection(
init_from=init_from,
sparse_vectors=sparse_vectors_config,
sharding_method=sharding_method,
strict_mode_config=strict_mode_config,
)
result: Optional[bool] = (
await self.http.collections_api.create_collection(
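A sketch of exercising the updated remote path. With prefer_grpc=True the client converts the REST-style model through RestToGrpc.convert_strict_mode_config internally, as the hunks above show, so callers pass the same object over either transport (endpoint and field names are assumptions):

# Sketch: the same REST-style StrictModeConfig is accepted over both
# transports; with prefer_grpc=True it is converted to gRPC internally.
import asyncio

from qdrant_client import AsyncQdrantClient, models


async def tighten_limits() -> None:
    client = AsyncQdrantClient(url="http://localhost:6333", prefer_grpc=True)  # hypothetical endpoint
    await client.update_collection(
        collection_name="docs",  # placeholder name
        strict_mode_config=models.StrictModeConfig(
            enabled=True,  # assumed field
            max_query_limit=50,  # assumed field: lower the per-query cap
        ),
    )

asyncio.run(tighten_limits())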
1 change: 1 addition & 0 deletions qdrant_client/conversions/common_types.py
@@ -117,6 +117,7 @@ def get_args_subscribed(tp): # type: ignore
Document: TypeAlias = rest.Document
Image: TypeAlias = rest.Image
InferenceObject: TypeAlias = rest.InferenceObject
StrictModeConfig: TypeAlias = rest.StrictModeConfig

SearchRequest = Union[rest.SearchRequest, grpc.SearchPoints]
RecommendRequest = Union[rest.RecommendRequest, grpc.RecommendPoints]
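The alias makes the new model reachable through the shared types module used across the client implementations; a tiny sketch (import path taken from this diff, field name assumed):

# Sketch: StrictModeConfig is now exposed as a shared type alias.
from qdrant_client.conversions import common_types as types

config: types.StrictModeConfig = types.StrictModeConfig(enabled=True)  # assumed field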
22 changes: 11 additions & 11 deletions qdrant_client/local/async_qdrant_local.py
@@ -142,7 +142,7 @@ def _save(self) -> None:
{
"collections": {
collection_name: to_dict(collection.config)
for (collection_name, collection) in self.collections.items()
for collection_name, collection in self.collections.items()
},
"aliases": self.aliases,
}
@@ -387,7 +387,7 @@ def _resolve_prefetch_input(
if prefetch.query is None:
return prefetch
prefetch = deepcopy(prefetch)
(query, mentioned_ids) = self._resolve_query_input(
query, mentioned_ids = self._resolve_query_input(
collection_name, prefetch.query, prefetch.using, prefetch.lookup_from
)
prefetch.query = query
@@ -413,7 +413,7 @@ async def query_points(
) -> types.QueryResponse:
collection = self._get_collection(collection_name)
if query is not None:
(query, mentioned_ids) = self._resolve_query_input(
query, mentioned_ids = self._resolve_query_input(
collection_name, query, using, lookup_from
)
query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
@@ -486,7 +486,7 @@ async def query_points_groups(
) -> types.GroupsResult:
collection = self._get_collection(collection_name)
if query is not None:
(query, mentioned_ids) = self._resolve_query_input(
query, mentioned_ids = self._resolve_query_input(
collection_name, query, using, lookup_from
)
query_filter = ignore_mentioned_ids_filter(query_filter, list(mentioned_ids))
@@ -846,7 +846,7 @@ async def get_collection_aliases(
return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
for (alias_name, name) in self.aliases.items()
for alias_name, name in self.aliases.items()
if name == collection_name
]
)
@@ -857,7 +857,7 @@ async def get_aliases(self, **kwargs: Any) -> types.CollectionsAliasesResponse:
return types.CollectionsAliasesResponse(
aliases=[
rest_models.AliasDescription(alias_name=alias_name, collection_name=name)
for (alias_name, name) in self.aliases.items()
for alias_name, name in self.aliases.items()
]
)

@@ -867,7 +867,7 @@ async def get_collections(self, **kwargs: Any) -> types.CollectionsResponse:
return types.CollectionsResponse(
collections=[
rest_models.CollectionDescription(name=name)
for (name, _) in self.collections.items()
for name, _ in self.collections.items()
]
)

@@ -908,7 +908,7 @@ async def delete_collection(self, collection_name: str, **kwargs: Any) -> bool:
del _collection
self.aliases = {
alias_name: name
for (alias_name, name) in self.aliases.items()
for alias_name, name in self.aliases.items()
if name != collection_name
}
collection_path = self._collection_path(collection_name)
@@ -949,12 +949,12 @@ async def create_collection(
self.collections[collection_name] = collection
if src_collection and from_collection_name:
batch_size = 100
(records, next_offset) = await self.scroll(
records, next_offset = await self.scroll(
from_collection_name, limit=2, with_vectors=True
)
self.upload_records(collection_name, records)
while next_offset is not None:
(records, next_offset) = await self.scroll(
records, next_offset = await self.scroll(
from_collection_name, offset=next_offset, limit=batch_size, with_vectors=True
)
self.upload_records(collection_name, records)
@@ -1030,7 +1030,7 @@ def uuid_generator() -> Generator[str, None, None]:
vector=(vector.tolist() if isinstance(vector, np.ndarray) else vector) or {},
payload=payload or {},
)
for (point_id, vector, payload) in zip(
for point_id, vector, payload in zip(
ids or uuid_generator(), iter(vectors), payload or itertools.cycle([{}])
)
]
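The init_from copy path above pages through the source collection with scroll, which returns a (records, next_offset) pair until the offset is exhausted; a minimal sketch of that pagination loop (collection name and processing step are placeholders):

# Sketch: scroll-based pagination, mirroring the local copy loop above.
from qdrant_client import AsyncQdrantClient


async def iterate_all(client: AsyncQdrantClient, src: str) -> None:
    batch_size = 100
    records, next_offset = await client.scroll(src, limit=batch_size, with_vectors=True)
    while True:
        for record in records:
            ...  # re-upload or otherwise process each record here
        if next_offset is None:
            break
        records, next_offset = await client.scroll(
            src, offset=next_offset, limit=batch_size, with_vectors=True
        )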
