Skip to content

Commit

Permalink
clean up bulk support
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Sep 30, 2024
1 parent f7de5cd commit 523875b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 24 deletions.
13 changes: 13 additions & 0 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.uses_sort = False
self.is_retryable = True
self.retrying = False
self.started_retryable_write = False
Expand Down Expand Up @@ -144,6 +145,7 @@ def add_update(
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
Expand All @@ -159,6 +161,9 @@ def add_update(
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
Expand All @@ -171,6 +176,7 @@ def add_replace(
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
Expand All @@ -181,6 +187,9 @@ def add_replace(
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
self.ops.append((_UPDATE, cmd))

def add_delete(
Expand Down Expand Up @@ -699,6 +708,10 @@ async def execute_no_results(
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
if unack and self.uses_sort and conn.max_wire_version < 25:
raise ConfigurationError(
"Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands."
)
# Cannot have both unacknowledged writes and bypass document validation.
if self.bypass_doc_val:
raise OperationFailure(
Expand Down
9 changes: 9 additions & 0 deletions pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.uses_sort = False

self.is_retryable = self.client.options.retry_writes
self.retrying = False
Expand Down Expand Up @@ -148,6 +149,7 @@ def add_update(
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
Expand All @@ -169,6 +171,9 @@ def add_update(
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
Expand All @@ -184,6 +189,7 @@ def add_replace(
upsert: Optional[bool] = None,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
Expand All @@ -202,6 +208,9 @@ def add_replace(
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
self.ops.append(("replace", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
Expand Down
22 changes: 9 additions & 13 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,6 @@ async def bulk_write(
session: Optional[AsyncClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
sort: Optional[Mapping] = None,
) -> BulkWriteResult:
"""Send a batch of write operations to the server.
Expand Down Expand Up @@ -739,8 +738,6 @@ async def bulk_write(
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
:param sort: Specify which document the operation updates if the query matches
multiple documents. The first document matched by the sort order will be updated.
:return: An instance of :class:`~pymongo.results.BulkWriteResult`.
Expand All @@ -749,9 +746,6 @@ async def bulk_write(
.. note:: `bypass_document_validation` requires server version
**>= 3.2**
.. versionchanged:: 4.10
Added ``sort`` parameter.
.. versionchanged:: 4.1
Added ``comment`` parameter.
Added ``let`` parameter.
Expand All @@ -766,9 +760,7 @@ async def bulk_write(
"""
common.validate_list("requests", requests)

blk = _AsyncBulk(
self, ordered, bypass_document_validation, comment=comment, let=let, sort=sort
)
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
for request in requests:
try:
request._add_to_bulk(blk)
Expand Down Expand Up @@ -1012,15 +1004,19 @@ async def _update(
if not isinstance(hint, str):
hint = helpers_shared._index_document(hint)
update_doc["hint"] = hint
if sort is not None:
if not acknowledged and conn.max_wire_version < 25:
raise ConfigurationError(
"Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands."
)
common.validate_is_mapping("sort", sort)
update_doc["sort"] = sort

command = {"update": self.name, "ordered": ordered, "updates": [update_doc]}
if let is not None:
common.validate_is_mapping("let", let)
command["let"] = let

if sort is not None:
common.validate_is_mapping("sort", sort)
command["sort"] = sort

if comment is not None:
command["comment"] = comment
# Update command.
Expand Down
13 changes: 13 additions & 0 deletions pymongo/synchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.uses_sort = False
self.is_retryable = True
self.retrying = False
self.started_retryable_write = False
Expand Down Expand Up @@ -144,6 +145,7 @@ def add_update(
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
Expand All @@ -159,6 +161,9 @@ def add_update(
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
Expand All @@ -171,6 +176,7 @@ def add_replace(
upsert: bool = False,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
Expand All @@ -181,6 +187,9 @@ def add_replace(
if hint is not None:
self.uses_hint_update = True
cmd["hint"] = hint
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
self.ops.append((_UPDATE, cmd))

def add_delete(
Expand Down Expand Up @@ -697,6 +706,10 @@ def execute_no_results(
raise ConfigurationError(
"Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands."
)
if unack and self.uses_sort and conn.max_wire_version < 25:
raise ConfigurationError(
"Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands."
)
# Cannot have both unacknowledged writes and bypass document validation.
if self.bypass_doc_val:
raise OperationFailure(
Expand Down
9 changes: 9 additions & 0 deletions pymongo/synchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
self.uses_array_filters = False
self.uses_hint_update = False
self.uses_hint_delete = False
self.uses_sort = False

self.is_retryable = self.client.options.retry_writes
self.retrying = False
Expand Down Expand Up @@ -148,6 +149,7 @@ def add_update(
collation: Optional[Mapping[str, Any]] = None,
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
Expand All @@ -169,6 +171,9 @@ def add_update(
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
Expand All @@ -184,6 +189,7 @@ def add_replace(
upsert: Optional[bool] = None,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
Expand All @@ -202,6 +208,9 @@ def add_replace(
if collation is not None:
self.uses_collation = True
cmd["collation"] = collation
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort
self.ops.append(("replace", cmd))
self.namespaces.append(namespace)
self.total_ops += 1
Expand Down
20 changes: 9 additions & 11 deletions pymongo/synchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,6 @@ def bulk_write(
session: Optional[ClientSession] = None,
comment: Optional[Any] = None,
let: Optional[Mapping] = None,
sort: Optional[Mapping] = None,
) -> BulkWriteResult:
"""Send a batch of write operations to the server.
Expand Down Expand Up @@ -738,8 +737,6 @@ def bulk_write(
constant or closed expressions that do not reference document
fields. Parameters can then be accessed as variables in an
aggregate expression context (e.g. "$$var").
:param sort: Specify which document the operation updates if the query matches
multiple documents. The first document matched by the sort order will be updated.
:return: An instance of :class:`~pymongo.results.BulkWriteResult`.
Expand All @@ -748,9 +745,6 @@ def bulk_write(
.. note:: `bypass_document_validation` requires server version
**>= 3.2**
.. versionchanged:: 4.10
Added ``sort`` parameter.
.. versionchanged:: 4.1
Added ``comment`` parameter.
Added ``let`` parameter.
Expand All @@ -765,7 +759,7 @@ def bulk_write(
"""
common.validate_list("requests", requests)

blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let, sort=sort)
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
for request in requests:
try:
request._add_to_bulk(blk)
Expand Down Expand Up @@ -1009,15 +1003,19 @@ def _update(
if not isinstance(hint, str):
hint = helpers_shared._index_document(hint)
update_doc["hint"] = hint
if sort is not None:
if not acknowledged and conn.max_wire_version < 25:
raise ConfigurationError(
"Must be connected to MongoDB 8.0+ to use sort on unacknowledged update commands."
)
common.validate_is_mapping("sort", sort)
update_doc["sort"] = sort

command = {"update": self.name, "ordered": ordered, "updates": [update_doc]}
if let is not None:
common.validate_is_mapping("let", let)
command["let"] = let

if sort is not None:
common.validate_is_mapping("sort", sort)
command["sort"] = sort

if comment is not None:
command["comment"] = comment
# Update command.
Expand Down

0 comments on commit 523875b

Please sign in to comment.