Skip to content

Commit

Permalink
PYTHON-4667 Handle $clusterTime from error responses in client Bulk W…
Browse files Browse the repository at this point in the history
…rite (#1822)
  • Loading branch information
blink1073 authored Sep 5, 2024
1 parent e27b428 commit 4d48130
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 7 deletions.
5 changes: 4 additions & 1 deletion pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ async def write_command(
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
await client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
Expand Down Expand Up @@ -308,6 +309,9 @@ async def write_command(

if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
await client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
finally:
bwc.start_time = datetime.datetime.now()
Expand Down Expand Up @@ -449,7 +453,6 @@ async def _execute_batch(
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
await client._process_response(result, bwc.session) # type: ignore[arg-type]

return result, to_send # type: ignore[return-value]

Expand Down
8 changes: 7 additions & 1 deletion pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ async def write_command(
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
# Process the response from the server.
await self.client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
Expand Down Expand Up @@ -312,6 +314,11 @@ async def write_command(
bwc._fail(request_id, failure, duration)
# Top-level error will be embedded in ClientBulkWriteException.
reply = {"error": exc}
# Process the response from the server.
if isinstance(exc, OperationFailure):
await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
else:
await self.client._process_response({}, bwc.session) # type: ignore[arg-type]
finally:
bwc.start_time = datetime.datetime.now()
return reply # type: ignore[return-value]
Expand Down Expand Up @@ -431,7 +438,6 @@ async def _execute_batch(
result = await self.write_command(
bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
) # type: ignore[arg-type]
await self.client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send_ops, to_send_ns # type: ignore[return-value]

async def _process_results_cursor(
Expand Down
5 changes: 4 additions & 1 deletion pymongo/synchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def write_command(
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
Expand Down Expand Up @@ -308,6 +309,9 @@ def write_command(

if bwc.publish:
bwc._fail(request_id, failure, duration)
# Process the response from the server.
if isinstance(exc, (NotPrimaryError, OperationFailure)):
client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
raise
finally:
bwc.start_time = datetime.datetime.now()
Expand Down Expand Up @@ -449,7 +453,6 @@ def _execute_batch(
else:
request_id, msg, to_send = bwc.batch_command(cmd, ops)
result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type]
client._process_response(result, bwc.session) # type: ignore[arg-type]

return result, to_send # type: ignore[return-value]

Expand Down
8 changes: 7 additions & 1 deletion pymongo/synchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def write_command(
)
if bwc.publish:
bwc._succeed(request_id, reply, duration) # type: ignore[arg-type]
# Process the response from the server.
self.client._process_response(reply, bwc.session) # type: ignore[arg-type]
except Exception as exc:
duration = datetime.datetime.now() - bwc.start_time
if isinstance(exc, (NotPrimaryError, OperationFailure)):
Expand Down Expand Up @@ -312,6 +314,11 @@ def write_command(
bwc._fail(request_id, failure, duration)
# Top-level error will be embedded in ClientBulkWriteException.
reply = {"error": exc}
# Process the response from the server.
if isinstance(exc, OperationFailure):
self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type]
else:
self.client._process_response({}, bwc.session) # type: ignore[arg-type]
finally:
bwc.start_time = datetime.datetime.now()
return reply # type: ignore[return-value]
Expand Down Expand Up @@ -429,7 +436,6 @@ def _execute_batch(
"""Executes a batch of bulkWrite server commands (ack)."""
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
result = self.write_command(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
self.client._process_response(result, bwc.session) # type: ignore[arg-type]
return result, to_send_ops, to_send_ns # type: ignore[return-value]

def _process_results_cursor(
Expand Down
31 changes: 28 additions & 3 deletions test/mockupdb/test_cluster_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,22 @@

from bson import Timestamp
from pymongo import DeleteMany, InsertOne, MongoClient, UpdateOne
from pymongo.errors import OperationFailure

pytestmark = pytest.mark.mockupdb


class TestClusterTime(unittest.TestCase):
def cluster_time_conversation(self, callback, replies):
def cluster_time_conversation(self, callback, replies, max_wire_version=6):
cluster_time = Timestamp(0, 0)
server = MockupDB()

# First test all commands include $clusterTime with wire version 6.
# First test all commands include $clusterTime with max_wire_version.
_ = server.autoresponds(
"ismaster",
{
"minWireVersion": 0,
"maxWireVersion": 6,
"maxWireVersion": max_wire_version,
"$clusterTime": {"clusterTime": cluster_time},
},
)
Expand Down Expand Up @@ -166,6 +167,30 @@ def test_monitor(self):
request.reply(reply)
client.close()

def test_collection_bulk_error(self):
def callback(client: MongoClient[dict]) -> None:
with self.assertRaises(OperationFailure):
client.db.collection.bulk_write([InsertOne({}), InsertOne({})])

self.cluster_time_conversation(
callback,
[{"ok": 0, "errmsg": "mock error"}],
)

def test_client_bulk_error(self):
def callback(client: MongoClient[dict]) -> None:
with self.assertRaises(OperationFailure):
client.bulk_write(
[
InsertOne({}, namespace="db.collection"),
InsertOne({}, namespace="db.collection"),
]
)

self.cluster_time_conversation(
callback, [{"ok": 0, "errmsg": "mock error"}], max_wire_version=25
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 4d48130

Please sign in to comment.