diff --git a/PYTHON-5676-scope.md b/PYTHON-5676-scope.md deleted file mode 100644 index 56a7fc0da..000000000 --- a/PYTHON-5676-scope.md +++ /dev/null @@ -1,183 +0,0 @@ -# PYTHON-5676: Consolidate Command Execution Logic — Scope Document - -## Context - -PyMongo has accumulated several distinct code paths for executing commands against the server. When cross-cutting logic (e.g., backpressure, CSOT, monitoring) needs to be added or changed, engineers must find and update multiple paths. The goal is a single central path for all database operations, making future changes local to one place. - -PYTHON-1357 (Refactor Cursor and CommandCursor) is now Closed — the prerequisite work is done. - ---- - -## Current Command Execution Paths - -All paths are duplicated across `pymongo/asynchronous/` (source of truth) and `pymongo/synchronous/` (auto-generated via `tools/synchro.sh`). - -### Path 1 — Standard commands (most operations) -``` -Collection._command() / Database.command() / conn.command() (direct) - → Connection.command() [pool.py:344] — session, write concern, server API, reauth - → network.command() [network.py:61] — encode OP_MSG, APM, logging, CSOT, encryption, send/recv -``` -Used by: insert, update, delete, aggregate, distinct, find_one, createIndex, dropCollection, db.command(), etc. - -### Path 2 — Cursor operations (find / getMore) -``` -Cursor._send_message() - → MongoClient._run_operation() [mongo_client.py:1911] - → Server.run_operation() [server.py:138] — send/recv + FULL APM/logging (≈80 lines, near-identical to network.command()) -``` -**Bypasses `network.command()` entirely.** Has its own APM/monitoring and logging code. - -### Path 3 — Acknowledged bulk writes (collection-level) -``` -_Bulk._execute_batch() - → _Bulk.write_command() [bulk.py:244] — APM/logging wrapper (≈80 lines) - → Connection.write_command() [pool.py:480] — pre-encoded bytes, raw send+recv, no APM -``` -**Bypasses both `Connection.command()` and `network.command()`.** Note: for *encrypted* bulk writes, this path already uses `Connection.command()` — the non-encrypted path is the outlier. - -### Path 4 — Unacknowledged bulk writes -``` -_Bulk._execute_batch_unack() - → _Bulk.unack_write() [bulk.py:329] — APM/logging wrapper (≈70 lines) - → Connection.unack_write() [pool.py:469] — raw send only, no recv -``` -**Bypasses `Connection.command()` and `network.command()`.** - -### Path 5 — Client-level bulk writes (cross-collection) -`client_bulk.py` mirrors Paths 3 and 4 with its own `write_command()` and `unack_write()` wrappers — another full copy of the APM/logging boilerplate. - ---- - -## Key Findings - -**1. APM/logging code is copy-pasted 5+ times.** -`Server.run_operation()`, `_Bulk.write_command()`, `_Bulk.unack_write()`, `_ClientBulk.write_command()`, `_ClientBulk.unack_write()` all contain the same ~70-80 line block of `_COMMAND_LOGGER.isEnabledFor(DEBUG)` + `_debug_log(...)` + `listeners.publish_command_start/success/failure(...)`. The reference implementation is `network.command()`. - -**2. Dead legacy OP_QUERY code in `Server.run_operation()`.** -`MIN_SUPPORTED_WIRE_VERSION = 8` (MongoDB 4.2+, `common.py`). OP_MSG was introduced at wire version 6. The `use_cmd = False` branch (legacy OP_QUERY path) in `run_operation()` can never be reached for any currently supported server. The `_Query.use_command()` check at `message.py:1622` confirms: for wire version ≥ 8, it is unconditionally True. - -**3. `_op_msg()` already handles bulk write Type 1 sections.** -`message._op_msg()` (line 394) detects insert/update/delete commands, pops the documents field, and encodes them as a Type 1 section. The non-encrypted bulk write path bypasses this and does its own single-pass encode+batch-size-check via `_do_batched_op_msg`. Unifying would require separating batch determination from encoding (currently combined for performance). - -**4. Encrypted bulk writes already use `Connection.command()`.** -In `_Bulk._execute_batch()` (bulk.py:444), the `if self.is_encrypted:` branch calls `bwc.conn.command()`. The non-encrypted `else` branch goes through `Connection.write_command()`. These should be the same path. - -**5. Async is generated from async source.** -All `pymongo/synchronous/*.py` files are auto-generated from `pymongo/asynchronous/*.py` via `tools/synchro.sh`. All code changes must be made in the `asynchronous/` files only. - ---- - -## Proposed Consolidation: Phased Approach - -### Phase 1 — Remove dead OP_QUERY code from `run_operation()` *(Low risk)* - -**What:** Delete the `use_cmd = False` branches from `Server.run_operation()`. Remove the conditional `if use_cmd: ... else: ...` blocks for `user_fields`/`legacy_response` and response-building. - -**Files:** `pymongo/asynchronous/server.py`, `pymongo/asynchronous/cursor.py` (remove the dead `else` branches referencing `_OpReply`) - -**Why now:** This is purely dead code removal — no behavior change. It simplifies Phase 3 significantly and is independently safe. - -**Verification:** Full test suite; specifically `test/test_cursor.py` and `test/test_unified_format.py`. - ---- - -### Phase 2 — Extract shared APM/logging helpers *(Low risk)* - -**What:** Extract the duplicated APM/logging block into shared helper functions in `pymongo/helpers_shared.py` (or a new `pymongo/command_helpers.py`). All five duplicate sites call the helpers instead of inlining the code. - -```python -# No sync/async split needed — these functions contain no I/O -def _log_and_publish_command_started(client, listeners, cmd, dbname, request_id, conn): - ... - - -def _log_and_publish_command_succeeded( - client, listeners, cmd, dbname, request_id, conn, reply, duration -): - ... - - -def _log_and_publish_command_failed( - client, listeners, cmd, dbname, request_id, conn, failure, duration -): - ... -``` - -**Files:** `pymongo/helpers_shared.py` (or new `pymongo/command_helpers.py`), `pymongo/asynchronous/network.py`, `pymongo/asynchronous/server.py`, `pymongo/asynchronous/bulk.py`, `pymongo/asynchronous/client_bulk.py` - -**Why this matters:** This is the direct solution to the stated problem. When future cross-cutting logic needs to be added to "command execution", there is one place to add it. - -**Note on APM operationId:** `_BulkWriteContext` passes `op_id` (not `request_id`) as the `operationId` in APM events. This is spec-required behavior that must be preserved in the helpers. - -**Verification:** `test/test_command_monitoring.py`, `test/test_monitoring.py`, `test/test_unified_format.py`. - ---- - -### Phase 3 — Unify `Server.run_operation()` with `network.command()` *(Medium risk)* - -**What:** Refactor `network.command()` into a two-layer structure: - -- `_network_command_core()` — performs all work (encode, APM, send, recv) and returns `(response_doc, raw_reply, request_id, duration)` -- `command()` — thin wrapper returning just `response_doc` (existing callers unchanged) - -`Server.run_operation()` (after Phase 1+2) calls `_network_command_core()` for the actual transport, then wraps the result in `Response`/`PinnedResponse`. - -**Key complexity:** `RawBatchCursor._unpack_response()` (`cursor.py:1196`) calls `response.raw_response(cursor_id)` on the raw `_OpMsg` object — it needs the raw reply, not just the decoded dict. The `_network_command_core()` design satisfies this by exposing `raw_reply`. The `@_handle_reauth` decorator on `run_operation()` must also be preserved. - -**Files:** `pymongo/asynchronous/network.py`, `pymongo/asynchronous/server.py` - -**Result:** Cursor operations (find/getMore) fully share the command execution path with all other operations. Single place to add transport-level logic. - -**Verification:** `test/test_cursor.py`, exhaust cursor tests, `test/test_encryption.py`, `test/test_csot.py`. - ---- - -### Phase 4 — Unify non-encrypted bulk write path with `Connection.command()` *(Higher risk, defer)* - -**What:** Change `_Bulk._execute_batch()` to use `Connection.command()` for non-encrypted bulk writes, matching the already-unified encrypted path. Requires `_BulkWriteContext.batch_command()` to return a command dict instead of pre-encoded bytes. - -**The challenge:** `_do_batched_op_msg` performs batch-size determination and encoding in a single pass (including a C extension: `_cmessage._batched_op_msg`). Separating batch determination from encoding requires a two-pass approach, adding encoding overhead. **Needs benchmarking before committing.** - -`_EncryptedBulkWriteContext.batch_command()` encodes to bytes and then deserializes back via `_inflate_bson` — aligning the non-encrypted path with this approach may be the right model. - -**Result after Phase 4:** `Connection.write_command()`, `Connection.unack_write()`, `_Bulk.write_command()`, `_Bulk.unack_write()`, `_ClientBulk.write_command()`, and `_ClientBulk.unack_write()` have no callers and can be removed. - -**Gate:** Bulk write benchmarks before/after. Suggested regression threshold: ≤2% throughput degradation. - ---- - -## Definition of Done - -- All database operations route through `Connection.command()` → `network.command()` for actual transport -- APM/logging code exists in one location -- `Connection.write_command()`, `Connection.unack_write()`, `_Bulk.write_command()`, `_Bulk.unack_write()`, `_ClientBulk.write_command()`, `_ClientBulk.unack_write()` are removed (or reduced to thin wrappers) -- No performance regression (bulk write benchmarks) -- No behavioral regression (full spec test suite passes) - ---- - -## Risk Summary - -| Phase | Risk | Effort | Value | -|-------|------|--------|-------| -| 1 — Remove dead OP_QUERY code | Low | ~1 day | Medium (simplification) | -| 2 — Extract APM helpers | Low | ~2 days | **High** (solves the stated problem) | -| 3 — Unify run_operation with network.command | Medium | ~3 days | High (true path consolidation) | -| 4 — Unify bulk write bytes path | Higher | ~1 week + benchmarking | Medium (removes last bypass) | - -Phases 1 and 2 can be done together in one PR. Phase 3 should be a separate PR. Phase 4 should be gated on benchmarking. - ---- - -## Critical Files - -| File | Role | -|------|------| -| `pymongo/asynchronous/network.py:61` | Central `command()` — reference implementation | -| `pymongo/asynchronous/server.py:138` | `run_operation()` — cursor path bypass | -| `pymongo/asynchronous/pool.py:344,469,480` | `Connection.command`, `unack_write`, `write_command` | -| `pymongo/asynchronous/bulk.py:244,329,444` | `_Bulk.write_command`, `unack_write`, `_execute_batch` | -| `pymongo/asynchronous/client_bulk.py:229,320` | `_ClientBulk.write_command`, `unack_write` | -| `pymongo/message.py:394,695,1622` | `_op_msg`, `_BulkWriteContext.batch_command`, `_Query.use_command` | -| `pymongo/common.py` | `MIN_SUPPORTED_WIRE_VERSION = 8` | diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index e1d2e3e6a..a7696af44 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -19,6 +19,7 @@ import datetime from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -41,10 +42,6 @@ from pymongo.errors import ( ) from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - async_receive_message, - async_sendall, -) if TYPE_CHECKING: from bson import CodecOptions @@ -60,6 +57,166 @@ if TYPE_CHECKING: _IS_SYNC = False +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + + +async def _network_command_core( + conn: AsyncConnection, + dbname: str, + spec: MutableMapping[str, Any], + request_id: int, + msg: Optional[bytes], + max_doc_size: int, + codec_options: CodecOptions[_DocumentType], + session: Optional[AsyncClientSession], + client: Optional[AsyncMongoClient[Any]], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + unacknowledged: bool = False, + more_to_come: bool = False, + unpack_res: Optional[Callable[..., list[_DocumentOut]]] = None, + cursor_id: Optional[int] = None, + orig: Optional[MutableMapping[str, Any]] = None, + speculative_hello: bool = False, +) -> tuple[list[_DocumentOut], Optional[_OpMsg], datetime.timedelta]: + """Send/receive a command and return (docs, raw_reply, duration). + + Handles APM logging, send/receive, unpacking, response processing, + and decryption. Both the standard command path and the cursor + (find/getMore) path go through this function. + """ + publish = listeners is not None and listeners.enabled_for_commands + name = next(iter(spec)) + reply: Optional[_OpMsg] = None + docs: list[_DocumentOut] = [] + + if client is not None: + _log_command_started(client, conn, spec, dbname, request_id, request_id) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig if orig is not None else spec, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = await conn.receive_message(None) + else: + assert msg is not None + await conn.send_message(msg, max_doc_size) + if unacknowledged: + # Unacknowledged write: fake a successful command response. + docs = [{"ok": 1}] # type: ignore[list-item] + else: + reply = await conn.receive_message(request_id) + + if reply is not None: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=False, + user_fields=_CURSOR_DOC_FIELDS, + ) + else: + docs = list( + reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + ) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + if client is not None: + _log_command_failed( + client, + conn, + spec, + dbname, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + duration = datetime.datetime.now() - start + if client is not None: + _log_command_succeeded( + client, + conn, + spec, + dbname, + request_id, + request_id, + docs[0], + duration, + speculative_hello, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + docs[0], + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + # Decrypt response. + if client and client._encrypter and reply is not None: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + decrypt_fields = _CURSOR_DOC_FIELDS if unpack_res is not None else user_fields + docs = list(_decode_all_selective(decrypted, codec_options, decrypt_fields)) + + return docs, reply, duration + async def command( conn: AsyncConnection, @@ -159,114 +316,30 @@ async def command( request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) + max_doc_size = 0 if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - _log_command_started(client, conn, spec, dbname, request_id, request_id) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - try: - await async_sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - _log_command_failed( - client, - conn, - spec, - dbname, - request_id, - request_id, - failure, - duration, - isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - _log_command_succeeded( - client, - conn, - spec, - dbname, - request_id, - request_id, - response_doc, - duration, - "speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] + docs, _reply, _duration = await _network_command_core( + conn=conn, + dbname=dbname, + spec=spec, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=codec_options, + session=session, + client=client, + listeners=listeners, + address=address, + start=start, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + user_fields=user_fields, + unacknowledged=unacknowledged, + orig=orig, + speculative_hello=speculative_hello, + ) + return cast("_DocumentType", docs[0]) diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 3e37dae24..3f881de2f 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,21 +26,14 @@ from typing import ( Union, ) -from bson import _decode_all_selective from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.command_helpers import ( - _log_command_failed, - _log_command_started, - _log_command_succeeded, -) -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response +from pymongo.asynchronous.network import _network_command_core from pymongo.logger import ( _SDAM_LOGGER, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _Query +from pymongo.message import _GetMore, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -58,8 +51,6 @@ if TYPE_CHECKING: _IS_SYNC = False -_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} - class Server: def __init__( @@ -161,7 +152,6 @@ class Server: :param client: An AsyncMongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() operation.use_command(conn) @@ -169,101 +159,38 @@ class Server: cmd, dbn = await self.operation_to_command(operation, conn, True) if more_to_come: request_id = 0 + msg = None + max_doc_size = 0 else: - message = operation.get_message(read_preference, conn, True) - request_id, data, max_doc_size = self._split_message(message) + op_message = operation.get_message(read_preference, conn, True) + request_id, msg, max_doc_size = self._split_message(op_message) - _log_command_started(client, conn, cmd, dbn, request_id, request_id) + if listeners.enabled_for_commands and "$db" not in cmd: + cmd["$db"] = dbn - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = await conn.receive_message(None) - else: - await conn.send_message(data, max_doc_size) - reply = await conn.receive_message(request_id) - - # Unpack and check for command errors. - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=False, - user_fields=_CURSOR_DOC_FIELDS, - ) - first = docs[0] - await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - _log_command_failed( - client, - conn, - cmd, - dbn, - request_id, - request_id, - failure, - duration, - isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response format. - res = docs[0] - _log_command_succeeded(client, conn, cmd, dbn, request_id, request_id, res, duration) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, _CURSOR_DOC_FIELDS) + docs, reply, duration = await _network_command_core( + conn=conn, + dbname=dbn, + spec=cmd, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=operation.codec_options, + session=operation.session, # type: ignore[arg-type] + client=client, + listeners=listeners, + address=conn.address, + start=start, + more_to_come=more_to_come, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + ) response: Response - + client = operation.client # type: ignore[assignment] if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() - more_to_come = reply.more_to_come + more_to_come = reply.more_to_come # type: ignore[union-attr] if operation.conn_mgr: operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse( diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 3ab3664e7..93a5d0c6e 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -19,6 +19,7 @@ import datetime from typing import ( TYPE_CHECKING, Any, + Callable, Mapping, MutableMapping, Optional, @@ -41,10 +42,6 @@ from pymongo.errors import ( ) from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate -from pymongo.network_layer import ( - receive_message, - sendall, -) if TYPE_CHECKING: from bson import CodecOptions @@ -60,6 +57,166 @@ if TYPE_CHECKING: _IS_SYNC = True +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + + +def _network_command_core( + conn: Connection, + dbname: str, + spec: MutableMapping[str, Any], + request_id: int, + msg: Optional[bytes], + max_doc_size: int, + codec_options: CodecOptions[_DocumentType], + session: Optional[ClientSession], + client: Optional[MongoClient[Any]], + listeners: Optional[_EventListeners], + address: Optional[_Address], + start: datetime.datetime, + check: bool = True, + allowable_errors: Optional[Sequence[Union[str, int]]] = None, + parse_write_concern_error: bool = False, + user_fields: Optional[Mapping[str, Any]] = None, + unacknowledged: bool = False, + more_to_come: bool = False, + unpack_res: Optional[Callable[..., list[_DocumentOut]]] = None, + cursor_id: Optional[int] = None, + orig: Optional[MutableMapping[str, Any]] = None, + speculative_hello: bool = False, +) -> tuple[list[_DocumentOut], Optional[_OpMsg], datetime.timedelta]: + """Send/receive a command and return (docs, raw_reply, duration). + + Handles APM logging, send/receive, unpacking, response processing, + and decryption. Both the standard command path and the cursor + (find/getMore) path go through this function. + """ + publish = listeners is not None and listeners.enabled_for_commands + name = next(iter(spec)) + reply: Optional[_OpMsg] = None + docs: list[_DocumentOut] = [] + + if client is not None: + _log_command_started(client, conn, spec, dbname, request_id, request_id) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig if orig is not None else spec, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = conn.receive_message(None) + else: + assert msg is not None + conn.send_message(msg, max_doc_size) + if unacknowledged: + # Unacknowledged write: fake a successful command response. + docs = [{"ok": 1}] # type: ignore[list-item] + else: + reply = conn.receive_message(request_id) + + if reply is not None: + conn.more_to_come = reply.more_to_come + if unpack_res is not None: + docs = unpack_res( + reply, + cursor_id, + codec_options, + legacy_response=False, + user_fields=_CURSOR_DOC_FIELDS, + ) + else: + docs = list( + reply.unpack_response(codec_options=codec_options, user_fields=user_fields) + ) + response_doc = docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + if client is not None: + _log_command_failed( + client, + conn, + spec, + dbname, + request_id, + request_id, + failure, + duration, + isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + duration = datetime.datetime.now() - start + if client is not None: + _log_command_succeeded( + client, + conn, + spec, + dbname, + request_id, + request_id, + docs[0], + duration, + speculative_hello, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + docs[0], + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + # Decrypt response. + if client and client._encrypter and reply is not None: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + decrypt_fields = _CURSOR_DOC_FIELDS if unpack_res is not None else user_fields + docs = list(_decode_all_selective(decrypted, codec_options, decrypt_fields)) + + return docs, reply, duration + def command( conn: Connection, @@ -159,114 +316,30 @@ def command( request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) + max_doc_size = 0 if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - _log_command_started(client, conn, spec, dbname, request_id, request_id) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - try: - sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) - - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - _log_command_failed( - client, - conn, - spec, - dbname, - request_id, - request_id, - failure, - duration, - isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - _log_command_succeeded( - client, - conn, - spec, - dbname, - request_id, - request_id, - response_doc, - duration, - "speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - if client and client._encrypter and reply: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) - - return response_doc # type: ignore[return-value] + docs, _reply, _duration = _network_command_core( + conn=conn, + dbname=dbname, + spec=spec, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=codec_options, + session=session, + client=client, + listeners=listeners, + address=address, + start=start, + check=check, + allowable_errors=allowable_errors, + parse_write_concern_error=parse_write_concern_error, + user_fields=user_fields, + unacknowledged=unacknowledged, + orig=orig, + speculative_hello=speculative_hello, + ) + return cast("_DocumentType", docs[0]) diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 9d033df67..212b5e5b4 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -26,22 +26,15 @@ from typing import ( Union, ) -from bson import _decode_all_selective -from pymongo.command_helpers import ( - _log_command_failed, - _log_command_started, - _log_command_succeeded, -) -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( _SDAM_LOGGER, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _convert_exception, _GetMore, _Query +from pymongo.message import _GetMore, _Query from pymongo.response import PinnedResponse, Response from pymongo.synchronous.helpers import _handle_reauth +from pymongo.synchronous.network import _network_command_core if TYPE_CHECKING: from queue import Queue @@ -58,8 +51,6 @@ if TYPE_CHECKING: _IS_SYNC = True -_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} - class Server: def __init__( @@ -161,7 +152,6 @@ class Server: :param client: A MongoClient instance. """ assert listeners is not None - publish = listeners.enabled_for_commands start = datetime.now() operation.use_command(conn) @@ -169,101 +159,38 @@ class Server: cmd, dbn = self.operation_to_command(operation, conn, True) if more_to_come: request_id = 0 + msg = None + max_doc_size = 0 else: - message = operation.get_message(read_preference, conn, True) - request_id, data, max_doc_size = self._split_message(message) + op_message = operation.get_message(read_preference, conn, True) + request_id, msg, max_doc_size = self._split_message(op_message) - _log_command_started(client, conn, cmd, dbn, request_id, request_id) + if listeners.enabled_for_commands and "$db" not in cmd: + cmd["$db"] = dbn - if publish: - if "$db" not in cmd: - cmd["$db"] = dbn - assert listeners is not None - listeners.publish_command_start( - cmd, - dbn, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = conn.receive_message(None) - else: - conn.send_message(data, max_doc_size) - reply = conn.receive_message(request_id) - - # Unpack and check for command errors. - docs = unpack_res( - reply, - operation.cursor_id, - operation.codec_options, - legacy_response=False, - user_fields=_CURSOR_DOC_FIELDS, - ) - first = docs[0] - operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] - _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] - except Exception as exc: - duration = datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - _log_command_failed( - client, - conn, - cmd, - dbn, - request_id, - request_id, - failure, - duration, - isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - listeners.publish_command_failure( - duration, - failure, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - raise - duration = datetime.now() - start - # Must publish in find / getMore / explain command response format. - res = docs[0] - _log_command_succeeded(client, conn, cmd, dbn, request_id, request_id, res, duration) - if publish: - assert listeners is not None - listeners.publish_command_success( - duration, - res, - operation.name, - request_id, - conn.address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbn, - ) - - # Decrypt response. - client = operation.client # type: ignore[assignment] - if client and client._encrypter: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - docs = _decode_all_selective(decrypted, operation.codec_options, _CURSOR_DOC_FIELDS) + docs, reply, duration = _network_command_core( + conn=conn, + dbname=dbn, + spec=cmd, + request_id=request_id, + msg=msg, + max_doc_size=max_doc_size, + codec_options=operation.codec_options, + session=operation.session, # type: ignore[arg-type] + client=client, + listeners=listeners, + address=conn.address, + start=start, + more_to_come=more_to_come, + unpack_res=unpack_res, + cursor_id=operation.cursor_id, + ) response: Response - + client = operation.client # type: ignore[assignment] if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() - more_to_come = reply.more_to_come + more_to_come = reply.more_to_come # type: ignore[union-attr] if operation.conn_mgr: operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse(