diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 5a5dc7fa2..b584dc8ad 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -16,7 +16,6 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, @@ -31,17 +30,14 @@ from typing import ( from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( async_receive_message, async_sendall, ) +from pymongo.telemetry import command_telemetry +from pymongo.tracing import add_cursor_id if TYPE_CHECKING: from bson import CodecOptions @@ -159,140 +155,71 @@ async def command( if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - try: - await async_sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake 
a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = await async_receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) + with command_telemetry( + command_name=name, + database_name=dbname, + spec=spec, + address=address if address else conn.address, + driver_connection_id=conn.id, + server_connection_id=conn.server_connection_id, + publish_event=publish, + start_time=start, + client=client, + listeners=listeners, + request_id=request_id, + service_id=conn.service_id, + ) as telemetry: + try: + await async_sendall(conn.conn.get_conn, msg) + if use_op_msg and unacknowledged: + # Unacknowledged, fake a successful command response. + reply = None + response_doc: _DocumentOut = {"ok": 1} + else: + reply = await async_receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + unpacked_docs = reply.unpack_response( + codec_options=codec_options, user_fields=user_fields + ) - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - 
databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, + response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + telemetry.publish_failed(exc) + raise + + # Add cursor_id to span if present in response + if telemetry.span is not None and 
isinstance(response_doc, dict): + cursor_info = response_doc.get("cursor") + if cursor_info and isinstance(cursor_info, dict): + cursor_id = cursor_info.get("id", 0) + if cursor_id: + add_cursor_id(telemetry.span, cursor_id) + + # Publish command succeeded event + telemetry.publish_succeeded( + reply=response_doc, speculative_hello=speculative_hello, - database_name=dbname, + speculative_authenticate="speculativeAuthenticate" in orig, ) - if client and client._encrypter and reply: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) + if client and client._encrypter and reply: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + response_doc = cast( + "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] + ) - return response_doc # type: ignore[return-value] + return response_doc # type: ignore[return-value] diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 7d9bca4d5..497535cb4 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,7 +16,6 @@ from __future__ import annotations import datetime -import logging from typing import ( TYPE_CHECKING, Any, @@ -31,17 +30,14 @@ from typing import ( from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message from pymongo.compression_support import _NO_COMPRESSION -from pymongo.errors import ( - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( receive_message, sendall, ) +from pymongo.telemetry import command_telemetry +from pymongo.tracing import add_cursor_id if TYPE_CHECKING: from bson import CodecOptions @@ -159,140 +155,71 @@ def command( if 
max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=client._topology_settings._topology_id, - command=spec, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - try: - sendall(conn.conn.get_conn, msg) - if use_op_msg and unacknowledged: - # Unacknowledged, fake a successful command response. - reply = None - response_doc: _DocumentOut = {"ok": 1} - else: - reply = receive_message(conn, request_id) - conn.more_to_come = reply.more_to_come - unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields - ) + with command_telemetry( + command_name=name, + database_name=dbname, + spec=spec, + address=address if address else conn.address, + driver_connection_id=conn.id, + server_connection_id=conn.server_connection_id, + publish_event=publish, + start_time=start, + client=client, + listeners=listeners, + request_id=request_id, + service_id=conn.service_id, + ) as telemetry: + try: + sendall(conn.conn.get_conn, msg) + if use_op_msg and unacknowledged: + # Unacknowledged, fake a successful command response. 
+ reply = None + response_doc: _DocumentOut = {"ok": 1} + else: + reply = receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + unpacked_docs = reply.unpack_response( + codec_options=codec_options, user_fields=user_fields + ) - response_doc = unpacked_docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - ) - except Exception as exc: - duration = datetime.datetime.now() - start - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = message._convert_exception(exc) - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - failure=failure, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = datetime.datetime.now() - start - if client is not None: - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=client._topology_settings._topology_id, - durationMS=duration, - 
reply=response_doc, - commandName=next(iter(spec)), - databaseName=dbname, - requestId=request_id, - operationId=request_id, - driverConnectionId=conn.id, - serverConnectionId=conn.server_connection_id, - serverHost=conn.address[0], - serverPort=conn.address[1], - serviceId=conn.service_id, - speculative_authenticate="speculativeAuthenticate" in orig, - ) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - response_doc, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, + response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + telemetry.publish_failed(exc) + raise + + # Add cursor_id to span if present in response + if telemetry.span is not None and isinstance(response_doc, dict): + cursor_info = response_doc.get("cursor") + if cursor_info and isinstance(cursor_info, dict): + cursor_id = cursor_info.get("id", 0) + if cursor_id: + add_cursor_id(telemetry.span, cursor_id) + + # Publish command succeeded event + telemetry.publish_succeeded( + reply=response_doc, speculative_hello=speculative_hello, - database_name=dbname, + speculative_authenticate="speculativeAuthenticate" in orig, ) - if client and client._encrypter and reply: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = cast( - "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] - ) + if client and client._encrypter and reply: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + response_doc = cast( + "_DocumentOut", _decode_all_selective(decrypted, 
codec_options, user_fields)[0] + ) - return response_doc # type: ignore[return-value] + return response_doc # type: ignore[return-value] diff --git a/pymongo/telemetry.py b/pymongo/telemetry.py new file mode 100644 index 000000000..e909c0899 --- /dev/null +++ b/pymongo/telemetry.py @@ -0,0 +1,339 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified telemetry support for PyMongo. + +Supports telemetry using standardized logging, event publishing, and OpenTelemetry. + +To enable OpenTelemetry tracing, set the environment variable: + OTEL_PYTHON_INSTRUMENTATION_MONGODB_ENABLED=true + +.. 
versionadded:: 4.x +""" +from __future__ import annotations + +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any, Mapping, Optional + +from pymongo import message +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.monitoring import _EventListeners +from pymongo.tracing import ( + _build_query_summary, + _extract_collection_name, + _get_tracer, + _is_sensitive_command, +) + +try: + from opentelemetry import trace + from opentelemetry.trace import Span, SpanKind, Status, StatusCode + + _HAS_OPENTELEMETRY = True +except ImportError: + _HAS_OPENTELEMETRY = False + trace = None # type: ignore[assignment] + Span = None # type: ignore[assignment, misc] + SpanKind = None # type: ignore[assignment, misc] + Status = None # type: ignore[assignment, misc] + StatusCode = None # type: ignore[assignment, misc] + +if TYPE_CHECKING: + from pymongo.typings import _Address, _AgnosticMongoClient, _DocumentOut + + +class _CommandTelemetry: + """Manages telemetry for MongoDB commands, including logging, event publishing, and OpenTelemetry spans. 
+ + This class is a context manager that handles the full lifecycle of command telemetry: + - On entry (__enter__): Sets up OpenTelemetry span and publishes the started event + - On exit (__exit__): Cleans up the span context (caller handles success/failure publishing) + """ + + __slots__ = ( + "_command_name", + "_database_name", + "_spec", + "_driver_connection_id", + "_server_connection_id", + "_publish_event", + "_start_time", + "_address", + "_listeners", + "_client", + "_request_id", + "_service_id", + "_span", + "_span_context", + ) + + def __init__( + self, + command_name: str, + database_name: str, + spec: Mapping[str, Any], + driver_connection_id: int, + server_connection_id: Optional[int], + publish_event: bool, + start_time: datetime, + address: Optional[_Address], + listeners: Optional[_EventListeners], + client: Optional[_AgnosticMongoClient], + request_id: Optional[int], + service_id: Optional[Any], + ): + self._command_name = command_name + self._database_name = database_name + self._spec = spec + self._driver_connection_id = driver_connection_id + self._server_connection_id = server_connection_id + self._publish_event = publish_event + self._start_time = start_time + self._address = address + self._listeners = listeners + self._client = client + self._request_id = request_id + self._service_id = service_id + self._span: Optional[Span] = None + self._span_context: Optional[Any] = None + + def __enter__(self) -> _CommandTelemetry: + """Enter the telemetry context: set up span and publish started event.""" + self._setup_span() + self.publish_started() + return self + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[Any], + ) -> None: + """Exit the telemetry context: clean up span context.""" + if self._span_context is not None: + self._span_context.__exit__(exc_type, exc_val, exc_tb) + + def _setup_span(self) -> None: + """Set up OpenTelemetry span if tracing is enabled and command is not 
sensitive.""" + tracer = _get_tracer() + + if tracer is None or _is_sensitive_command(self._command_name): + return + + collection_name = _extract_collection_name(self._spec) + query_summary = _build_query_summary( + self._command_name, self._database_name, collection_name + ) + + self._span_context = tracer.start_as_current_span( + name=self._command_name, + kind=SpanKind.CLIENT, + ) + self._span = self._span_context.__enter__() + + # Set span attributes + self._span.set_attribute("db.system", "mongodb") + self._span.set_attribute("db.namespace", self._database_name) + self._span.set_attribute("db.command.name", self._command_name) + self._span.set_attribute("db.query.summary", query_summary) + if self._address: + self._span.set_attribute("server.address", self._address[0]) + self._span.set_attribute("server.port", self._address[1]) + self._span.set_attribute("network.transport", "tcp") + self._span.set_attribute("db.mongodb.driver_connection_id", self._driver_connection_id) + + if collection_name: + self._span.set_attribute("db.collection.name", collection_name) + if self._server_connection_id is not None: + self._span.set_attribute("db.mongodb.server_connection_id", self._server_connection_id) + + @property + def span(self) -> Optional[Span]: + """Return the OpenTelemetry span, or None if tracing is disabled.""" + return self._span + + def publish_started(self) -> None: + """Publish command started event and log.""" + if self._client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=self._client._topology_settings._topology_id, + command=self._spec, + commandName=next(iter(self._spec)), + databaseName=self._database_name, + requestId=self._request_id, + operationId=self._request_id, + driverConnectionId=self._driver_connection_id, + serverConnectionId=self._server_connection_id, + serverHost=self._address[0] if self._address else None, + serverPort=self._address[1] 
if self._address else None, + serviceId=self._service_id, + ) + if self._publish_event: + assert self._listeners is not None + assert self._address is not None + self._listeners.publish_command_start( + self._spec, + self._database_name, + self._request_id, + self._address, + self._server_connection_id, + service_id=self._service_id, + ) + + def publish_succeeded( + self, + reply: _DocumentOut, + speculative_hello: bool = False, + speculative_authenticate: bool = False, + ) -> None: + """Publish command succeeded event and log.""" + duration = datetime.now() - self._start_time + if self._client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=self._client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(self._spec)), + databaseName=self._database_name, + requestId=self._request_id, + operationId=self._request_id, + driverConnectionId=self._driver_connection_id, + serverConnectionId=self._server_connection_id, + serverHost=self._address[0] if self._address else None, + serverPort=self._address[1] if self._address else None, + serviceId=self._service_id, + speculative_authenticate=speculative_authenticate, + ) + if self._publish_event: + assert self._listeners is not None + assert self._address is not None + self._listeners.publish_command_success( + duration, + reply, + self._command_name, + self._request_id, + self._address, + self._server_connection_id, + service_id=self._service_id, + speculative_hello=speculative_hello, + database_name=self._database_name, + ) + + def publish_failed(self, exc: Exception) -> None: + """Publish command failed event and log.""" + duration = datetime.now() - self._start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + + if self._span is not None: + 
error_code = getattr(exc, "code", None) + self._span.record_exception(exc) + self._span.set_status(Status(StatusCode.ERROR, str(exc))) + + if error_code is not None: + self._span.set_attribute("db.response.status_code", str(error_code)) + if self._client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=self._client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(self._spec)), + databaseName=self._database_name, + requestId=self._request_id, + operationId=self._request_id, + driverConnectionId=self._driver_connection_id, + serverConnectionId=self._server_connection_id, + serverHost=self._address[0] if self._address else None, + serverPort=self._address[1] if self._address else None, + serviceId=self._service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if self._publish_event: + assert self._listeners is not None + assert self._address is not None + self._listeners.publish_command_failure( + duration, + failure, + self._command_name, + self._request_id, + self._address, + self._server_connection_id, + service_id=self._service_id, + database_name=self._database_name, + ) + + +def command_telemetry( + command_name: str, + database_name: str, + spec: Mapping[str, Any], + driver_connection_id: int, + server_connection_id: Optional[int], + publish_event: bool, + start_time: datetime, + address: Optional[_Address] = None, + listeners: Optional[_EventListeners] = None, + client: Optional[_AgnosticMongoClient] = None, + request_id: Optional[int] = None, + service_id: Optional[Any] = None, +) -> _CommandTelemetry: + """Create a _CommandTelemetry context manager for command telemetry. + + Returns a _CommandTelemetry instance that should be used as a context manager. 
+ The context manager automatically: + - Sets up OpenTelemetry span if tracing is enabled and command is not sensitive + - Publishes the started event on entry + - Cleans up the span context on exit + + The caller is responsible for calling publish_succeeded() on successful completion + and publish_failed() if an exception occurs. + + Example usage:: + + with command_telemetry(...) as telemetry: + try: + # execute command + result = execute_command() + except Exception as exc: + telemetry.publish_failed(exc) + raise + telemetry.publish_succeeded(result) + """ + return _CommandTelemetry( + command_name=command_name, + database_name=database_name, + spec=spec, + driver_connection_id=driver_connection_id, + server_connection_id=server_connection_id, + publish_event=publish_event, + start_time=start_time, + address=address, + listeners=listeners, + client=client, + request_id=request_id, + service_id=service_id, + ) diff --git a/pymongo/tracing.py b/pymongo/tracing.py new file mode 100644 index 000000000..3c89a69ee --- /dev/null +++ b/pymongo/tracing.py @@ -0,0 +1,132 @@ +# Copyright 2026-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OpenTelemetry tracing support for PyMongo. + +This module provides optional OpenTelemetry tracing for MongoDB commands. +Tracing is disabled by default and requires the opentelemetry-api package. + +To enable tracing, set the environment variable: + OTEL_PYTHON_INSTRUMENTATION_MONGODB_ENABLED=true + +.. 
versionadded:: 4.x +""" +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any, Mapping, Optional + +from pymongo.logger import _SENSITIVE_COMMANDS + +try: + from opentelemetry import trace + from opentelemetry.trace import Span, SpanKind, Status, StatusCode + + _HAS_OPENTELEMETRY = True +except ImportError: + _HAS_OPENTELEMETRY = False + trace = None # type: ignore[assignment] + Span = None # type: ignore[assignment, misc] + SpanKind = None # type: ignore[assignment, misc] + Status = None # type: ignore[assignment, misc] + StatusCode = None # type: ignore[assignment, misc] + +if TYPE_CHECKING: + from opentelemetry.trace import Tracer + +# Environment variable names +_OTEL_ENABLED_ENV = "OTEL_PYTHON_INSTRUMENTATION_MONGODB_ENABLED" + + +def _is_tracing_enabled() -> bool: + """Check if tracing is enabled via environment variable.""" + if not _HAS_OPENTELEMETRY: + return False + value = os.environ.get(_OTEL_ENABLED_ENV, "").lower() + return value in ("1", "true") + + +def _get_tracer() -> Optional[Tracer]: + """Get the PyMongo tracer instance.""" + if not _HAS_OPENTELEMETRY or not _is_tracing_enabled(): + return None + from pymongo._version import __version__ + + return trace.get_tracer("PyMongo", __version__) + + +def _is_sensitive_command(command_name: str) -> bool: + """Check if a command is sensitive and should not be traced.""" + return command_name.lower() in _SENSITIVE_COMMANDS + + +def _build_query_summary( + command_name: str, + database_name: str, + collection_name: Optional[str], +) -> str: + """Build the db.query.summary attribute value.""" + if collection_name: + return f"{command_name} {database_name}.{collection_name}" + return f"{command_name} {database_name}" + + +def _extract_collection_name(spec: Mapping[str, Any]) -> Optional[str]: + """Extract collection name from command spec if applicable.""" + if not spec: + return None + cmd_name = next(iter(spec)).lower() + # Commands where the first value is the collection 
name + if cmd_name in ( + "insert", + "update", + "delete", + "find", + "aggregate", + "findandmodify", + "count", + "distinct", + "create", + "drop", + "createindexes", + "dropindexes", + "listindexes", + ): + value = spec.get(next(iter(spec))) + if isinstance(value, str): + return value + return None + + +def record_command_exception( + span: Optional[Span], + exception: BaseException, + error_code: Optional[int] = None, +) -> None: + """Record an exception on a command span.""" + if span is None or not _HAS_OPENTELEMETRY: + return + + span.record_exception(exception) + span.set_status(Status(StatusCode.ERROR, str(exception))) + + if error_code is not None: + span.set_attribute("db.response.status_code", str(error_code)) + + +def add_cursor_id(span: Optional[Span], cursor_id: int) -> None: + """Add cursor ID attribute to span if present.""" + if span is None or cursor_id == 0: + return + span.set_attribute("db.mongodb.cursor_id", cursor_id) diff --git a/pyproject.toml b/pyproject.toml index 9b3287834..fab9719b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ docs = ["requirements/docs.txt"] encryption = ["requirements/encryption.txt"] gssapi = ["requirements/gssapi.txt"] ocsp = ["requirements/ocsp.txt"] +opentelemetry = ["requirements/opentelemetry.txt"] snappy = ["requirements/snappy.txt"] test = ["requirements/test.txt"] zstd = ["requirements/zstd.txt"] diff --git a/uv.lock b/uv.lock index 78d0cc213..c8eb15473 100644 --- a/uv.lock +++ b/uv.lock @@ -1424,6 +1424,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = 
"typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + [[package]] name = "packaging" version = "25.0" @@ -1545,6 +1558,9 @@ ocsp = [ { name = "requests" }, { name = "service-identity" }, ] +opentelemetry = [ + { name = "opentelemetry-api" }, +] snappy = [ { name = "python-snappy" }, ] @@ -1591,6 +1607,7 @@ requires-dist = [ { name = "dnspython", specifier = ">=2.6.1,<3.0.0" }, { name = "furo", marker = "extra == 'docs'", specifier = "==2025.12.19" }, { name = "importlib-metadata", marker = "python_full_version < '3.13' and extra == 'test'", specifier = ">=7.0" }, + { name = "opentelemetry-api", marker = "extra == 'opentelemetry'", specifier = ">=1.20.0" }, { name = "pykerberos", marker = "os_name != 'nt' and extra == 'gssapi'", specifier = ">=1.2.4" }, { name = "pymongo-auth-aws", marker = "extra == 'aws'", specifier = ">=1.1.0,<2.0.0" }, { name = "pymongo-auth-aws", marker = "extra == 'encryption'", specifier = ">=1.1.0,<2.0.0" }, @@ -1608,7 +1625,7 @@ requires-dist = [ { name = "sphinxcontrib-shellcheck", marker = "extra == 'docs'", specifier = ">=1,<2" }, { name = "winkerberos", marker = "os_name == 'nt' and extra == 'gssapi'", specifier = ">=0.5.0" }, ] -provides-extras = ["aws", "docs", "encryption", "gssapi", "ocsp", "snappy", "test", "zstd"] +provides-extras = ["aws", "docs", "encryption", "gssapi", "ocsp", "opentelemetry", "snappy", "test", "zstd"] [package.metadata.requires-dev] 
coverage = [{ name = "coverage", extras = ["toml"], specifier = ">=5,<=7.10.7" }]