diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py new file mode 100644 index 000000000..a815cbc8a --- /dev/null +++ b/pymongo/mongo_client.py @@ -0,0 +1,22 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Re-import of synchronous MongoClient API for compatibility.""" +from __future__ import annotations + +from pymongo.synchronous.mongo_client import * # noqa: F403 +from pymongo.synchronous.mongo_client import __doc__ as original_doc + +__doc__ = original_doc +__all__ = ["MongoClient"] # noqa: F405 diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index f2076b087..cec78463b 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -34,6 +34,7 @@ from __future__ import annotations import contextlib import os +import warnings import weakref from collections import defaultdict from typing import ( @@ -42,8 +43,8 @@ from typing import ( Callable, ContextManager, FrozenSet, + Generator, Generic, - Iterator, Mapping, MutableMapping, NoReturn, @@ -57,23 +58,12 @@ from typing import ( from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import ( - _csot, - client_session, - common, - database, - helpers, - message, - periodic_executor, - uri_parser, -) -from pymongo.change_stream import ChangeStream, ClusterChangeStream +from pymongo import _csot, common, 
helpers_shared, uri_parser from pymongo.client_options import ClientOptions -from pymongo.client_session import _EmptyServerSession -from pymongo.command_cursor import CommandCursor from pymongo.errors import ( AutoReconnect, BulkWriteError, + ClientBulkWriteException, ConfigurationError, ConnectionFailure, InvalidOperation, @@ -86,13 +76,28 @@ from pymongo.errors import ( ) from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason -from pymongo.operations import _Op +from pymongo.operations import ( + DeleteMany, + DeleteOne, + InsertOne, + ReplaceOne, + UpdateMany, + UpdateOne, + _Op, +) from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.results import ClientBulkWriteResult from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.settings import TopologySettings -from pymongo.topology import Topology, _ErrorContext +from pymongo.synchronous import client_session, database, periodic_executor +from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream +from pymongo.synchronous.client_bulk import _ClientBulk +from pymongo.synchronous.client_session import _EmptyServerSession +from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.settings import TopologySettings +from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription from pymongo.typings import ( ClusterTime, @@ -111,44 +116,40 @@ from pymongo.uri_parser import ( from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern if TYPE_CHECKING: - import sys from types import TracebackType from bson.objectid import ObjectId - from pymongo.bulk import _Bulk - from pymongo.client_session import 
ClientSession, _ServerSession - from pymongo.cursor import _ConnectionManager - from pymongo.database import Database - from pymongo.message import _CursorAddress, _GetMore, _Query - from pymongo.pool import Connection from pymongo.read_concern import ReadConcern from pymongo.response import Response - from pymongo.server import Server from pymongo.server_selectors import Selection + from pymongo.synchronous.bulk import _Bulk + from pymongo.synchronous.client_session import ClientSession, _ServerSession + from pymongo.synchronous.cursor import _ConnectionManager + from pymongo.synchronous.pool import Connection + from pymongo.synchronous.server import Server - if sys.version_info[:2] >= (3, 9): - from collections.abc import Generator - else: - # Deprecated since version 3.9: collections.abc.Generator now supports []. - from typing import Generator T = TypeVar("T") _WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T] -_ReadCall = Callable[[Optional["ClientSession"], "Server", "Connection", _ServerMode], T] +_ReadCall = Callable[ + [Optional["ClientSession"], "Server", "Connection", _ServerMode], + T, +] + +_IS_SYNC = True + +_WriteOp = Union[ + InsertOne, + DeleteOne, + DeleteMany, + ReplaceOne, + UpdateOne, + UpdateMany, +] class MongoClient(common.BaseObject, Generic[_DocumentType]): - """ - A client-side representation of a MongoDB cluster. - - Instances can represent either a standalone MongoDB server, a replica - set, or a sharded cluster. Instances of this class are responsible for - maintaining up-to-date state of the cluster, and possibly cache - resources related to this, including background threads for monitoring, - and connection pools. - """ - HOST = "localhost" PORT = 27017 # Define order to retrieve options from ClientOptions for __repr__. 
@@ -261,9 +262,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): :class:`~datetime.datetime` instances returned as values in a document by this :class:`MongoClient` will be timezone aware (otherwise they will be naive) - :param connect: if ``True`` (the default), immediately + :param connect: If ``True`` (the default), immediately begin connecting to MongoDB in the background. Otherwise connect - on the first operation. + on the first operation. The default value is ``False`` when + running in a Function-as-a-service environment. :param type_registry: instance of :class:`~bson.codec_options.TypeRegistry` to enable encoding and decoding of custom types. @@ -719,9 +721,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): .. versionchanged:: 4.7 Deprecated parameter ``wTimeoutMS``, use :meth:`~pymongo.timeout`. + + .. versionchanged:: 4.9 + The default value of ``connect`` is changed to ``False`` when running in a + Function-as-a-service environment. """ doc_class = document_class or dict - self.__init_kwargs: dict[str, Any] = { + self._init_kwargs: dict[str, Any] = { "host": host, "port": port, "document_class": doc_class, @@ -802,7 +808,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if tz_aware is None: tz_aware = opts.get("tz_aware", False) if connect is None: - connect = opts.get("connect", True) + # Default to connect=True unless on a FaaS system, which might use fork. + from pymongo.pool_options import _is_faas + + connect = opts.get("connect", not _is_faas()) keyword_opts["tz_aware"] = tz_aware keyword_opts["connect"] = connect @@ -829,11 +838,11 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # Username and password passed as kwargs override user info in URI. 
username = opts.get("username", username) password = opts.get("password", password) - self.__options = options = ClientOptions(username, password, dbase, opts) + self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) - self.__default_database_name = dbase - self.__lock = _create_lock() - self.__kill_cursors_queue: list = [] + self._default_database_name = dbase + self._lock = _create_lock() + self._kill_cursors_queue: list = [] self._event_listeners = options.pool_options._event_listeners super().__init__( @@ -862,23 +871,29 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): server_monitoring_mode=options.server_monitoring_mode, ) + self._opened = False + self._closed = False self._init_background() - if connect: - self._get_topology() + if _IS_SYNC and connect: + self._get_topology() # type: ignore[unused-coroutine] self._encrypter = None - if self.__options.auto_encryption_opts: - from pymongo.encryption import _Encrypter + if self._options.auto_encryption_opts: + from pymongo.synchronous.encryption import _Encrypter - self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts) - self._timeout = self.__options.timeout + self._encrypter = _Encrypter(self, self._options.auto_encryption_opts) + self._timeout = self._options.timeout if _HAS_REGISTER_AT_FORK: # Add this client to the list of weakly referenced items. # This will be used later if we fork. MongoClient._clients[self._topology._topology_id] = self + def _connect(self) -> None: + """Explicitly connect to MongoDB synchronously instead of on the first operation.""" + self._get_topology() + def _init_background(self, old_pid: Optional[int] = None) -> None: self._topology = Topology(self._topology_settings) # Seed the topology with the old one's pid so we can detect clients @@ -903,31 +918,22 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # this closure. When the client is freed, stop the executor soon. 
self_ref: Any = weakref.ref(self, executor.close) self._kill_cursors_executor = executor + self._opened = False + + def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]: + return self._options.load_balanced and not (session and session.in_transaction) def _after_fork(self) -> None: """Resets topology in a child after successfully forking.""" self._init_background(self._topology._pid) + # Reset the session pool to avoid duplicate sessions in the child process. + self._topology._session_pool.reset() def _duplicate(self, **kwargs: Any) -> MongoClient: - args = self.__init_kwargs.copy() + args = self._init_kwargs.copy() args.update(kwargs) return MongoClient(**args) - def _server_property(self, attr_name: str) -> Any: - """An attribute of the current server's description. - - If the client is not connected, this will block until a connection is - established or raise ServerSelectionTimeoutError if no server is - available. - - Not threadsafe if used multiple times in a single method, since - the server may change. In such cases, store a local reference to a - ServerDescription first, then use its properties. - """ - server = self._get_topology().select_server(writable_server_selector, _Op.TEST) - - return getattr(server.description, attr_name) - def watch( self, pipeline: Optional[_Pipeline] = None, @@ -1040,7 +1046,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): .. _change streams specification: https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.md """ - return ClusterChangeStream( + change_stream = ClusterChangeStream( self.admin, pipeline, full_document, @@ -1056,6 +1062,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): show_expanded_events=show_expanded_events, ) + change_stream._initialize_cursor() + return change_stream + @property def topology_description(self) -> TopologyDescription: """The description of the connected MongoDB deployment. 
@@ -1077,6 +1086,346 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """ return self._topology.description + @property + def nodes(self) -> FrozenSet[_Address]: + """Set of all currently connected servers. + + .. warning:: When connected to a replica set the value of :attr:`nodes` + can change over time as :class:`MongoClient`'s view of the replica + set changes. :attr:`nodes` can also be an empty set when + :class:`MongoClient` is first instantiated and hasn't yet connected + to any servers, or a network partition causes it to lose connection + to all servers. + """ + description = self._topology.description + return frozenset(s.address for s in description.known_servers) + + @property + def options(self) -> ClientOptions: + """The configuration options for this client. + + :return: An instance of :class:`~pymongo.client_options.ClientOptions`. + + .. versionadded:: 4.0 + """ + return self._options + + def __eq__(self, other: Any) -> bool: + if isinstance(other, self.__class__): + return self._topology == other._topology + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __hash__(self) -> int: + return hash(self._topology) + + def _repr_helper(self) -> str: + def option_repr(option: str, value: Any) -> str: + """Fix options whose __repr__ isn't usable in a constructor.""" + if option == "document_class": + if value is dict: + return "document_class=dict" + else: + return f"document_class={value.__module__}.{value.__name__}" + if option in common.TIMEOUT_OPTIONS and value is not None: + return f"{option}={int(value * 1000)}" + + return f"{option}={value!r}" + + # Host first... + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] + ] + # ... then everything in self._constructor_args... + options.extend( + option_repr(key, self._options._options[key]) for key in self._constructor_args + ) + # ... 
then everything else. + options.extend( + option_repr(key, self._options._options[key]) + for key in self._options._options + if key not in set(self._constructor_args) and key != "username" and key != "password" + ) + return ", ".join(options) + + def __repr__(self) -> str: + return f"{type(self).__name__}({self._repr_helper()})" + + def __getattr__(self, name: str) -> database.Database[_DocumentType]: + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :param name: the name of the database to get + """ + if name.startswith("_"): + raise AttributeError( + f"{type(self).__name__} has no attribute {name!r}. To access the {name}" + f" database, use client[{name!r}]." + ) + return self.__getitem__(name) + + def __getitem__(self, name: str) -> database.Database[_DocumentType]: + """Get a database by name. + + Raises :class:`~pymongo.errors.InvalidName` if an invalid + database name is used. + + :param name: the name of the database to get + """ + return database.Database(self, name) + + def __del__(self) -> None: + """Check that this MongoClient has been closed and issue a warning if not.""" + try: + if not self._closed: + warnings.warn( + ( + f"Unclosed {type(self).__name__} opened at:\n{self._topology_settings._stack}" + f"Call {type(self).__name__}.close() to safely shut down your client and free up resources." 
+ ), + ResourceWarning, + stacklevel=2, + source=self, + ) + except AttributeError: + pass + + def _close_cursor_soon( + self, + cursor_id: int, + address: Optional[_CursorAddress], + conn_mgr: Optional[_ConnectionManager] = None, + ) -> None: + """Request that a cursor and/or connection be cleaned up soon.""" + self._kill_cursors_queue.append((address, cursor_id, conn_mgr)) + + def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: + server_session = _EmptyServerSession() + opts = client_session.SessionOptions(**kwargs) + return client_session.ClientSession(self, server_session, opts, implicit) + + def start_session( + self, + causal_consistency: Optional[bool] = None, + default_transaction_options: Optional[client_session.TransactionOptions] = None, + snapshot: Optional[bool] = False, + ) -> client_session.ClientSession: + """Start a logical session. + + This method takes the same parameters as + :class:`~pymongo.client_session.SessionOptions`. See the + :mod:`~pymongo.client_session` module for details and examples. + + A :class:`~pymongo.client_session.ClientSession` may only be used with + the MongoClient that started it. :class:`ClientSession` instances are + **not thread-safe or fork-safe**. They can only be used by one thread + or process at a time. A single :class:`ClientSession` cannot be used + to run multiple operations concurrently. + + :return: An instance of :class:`~pymongo.client_session.ClientSession`. + + .. versionadded:: 3.6 + """ + return self._start_session( + False, + causal_consistency=causal_consistency, + default_transaction_options=default_transaction_options, + snapshot=snapshot, + ) + + def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: + """If provided session is None, lend a temporary session.""" + if session: + return session + + try: + # Don't make implicit sessions causally consistent. Applications + # should always opt-in. 
+ return self._start_session(True, causal_consistency=False) + except (ConfigurationError, InvalidOperation): + # Sessions not supported. + return None + + def _send_cluster_time( + self, command: MutableMapping[str, Any], session: Optional[ClientSession] + ) -> None: + topology_time = self._topology.max_cluster_time() + session_time = session.cluster_time if session else None + if topology_time and session_time: + if topology_time["clusterTime"] > session_time["clusterTime"]: + cluster_time: Optional[ClusterTime] = topology_time + else: + cluster_time = session_time + else: + cluster_time = topology_time or session_time + if cluster_time: + command["$clusterTime"] = cluster_time + + def get_default_database( + self, + default: Optional[str] = None, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> database.Database[_DocumentType]: + """Get the database named in the MongoDB connection URI. + + >>> uri = 'mongodb://host/my_database' + >>> client = MongoClient(uri) + >>> db = client.get_default_database() + >>> assert db.name == 'my_database' + >>> db = client.get_database() + >>> assert db.name == 'my_database' + + Useful in scripts where you want to choose which database to use + based only on the URI in a configuration file. + + :param default: the database name to use if no database name + was provided in the URI. + :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`MongoClient` is + used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. 
If ``None`` (the + default) the :attr:`write_concern` of this :class:`MongoClient` is + used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`MongoClient` is + used. + :param comment: A user-provided comment to attach to this + command. + + .. versionchanged:: 4.1 + Added ``comment`` parameter. + + .. versionchanged:: 3.8 + Undeprecated. Added the ``default``, ``codec_options``, + ``read_preference``, ``write_concern`` and ``read_concern`` + parameters. + + .. versionchanged:: 3.5 + Deprecated, use :meth:`get_database` instead. + """ + if self._default_database_name is None and default is None: + raise ConfigurationError("No default database name defined or provided.") + + name = cast(str, self._default_database_name or default) + return database.Database( + self, name, codec_options, read_preference, write_concern, read_concern + ) + + def get_database( + self, + name: Optional[str] = None, + codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, + read_preference: Optional[_ServerMode] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ) -> database.Database[_DocumentType]: + """Get a :class:`~pymongo.database.Database` with the given name and + options. + + Useful for creating a :class:`~pymongo.database.Database` with + different codec options, read preference, and/or write concern from + this :class:`MongoClient`. + + >>> client.read_preference + Primary() + >>> db1 = client.test + >>> db1.read_preference + Primary() + >>> from pymongo import ReadPreference + >>> db2 = client.get_database( + ... 'test', read_preference=ReadPreference.SECONDARY) + >>> db2.read_preference + Secondary(tag_sets=None) + + :param name: The name of the database - a string. If ``None`` + (the default) the database named in the MongoDB connection URI is + returned. 
+ :param codec_options: An instance of + :class:`~bson.codec_options.CodecOptions`. If ``None`` (the + default) the :attr:`codec_options` of this :class:`MongoClient` is + used. + :param read_preference: The read preference to use. If + ``None`` (the default) the :attr:`read_preference` of this + :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` + for options. + :param write_concern: An instance of + :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the + default) the :attr:`write_concern` of this :class:`MongoClient` is + used. + :param read_concern: An instance of + :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the + default) the :attr:`read_concern` of this :class:`MongoClient` is + used. + + .. versionchanged:: 3.5 + The `name` parameter is now optional, defaulting to the database + named in the MongoDB connection URI. + """ + if name is None: + if self._default_database_name is None: + raise ConfigurationError("No default database defined") + name = self._default_database_name + + return database.Database( + self, name, codec_options, read_preference, write_concern, read_concern + ) + + def _database_default_options(self, name: str) -> database.Database: + """Get a Database instance with the default settings.""" + return self.get_database( + name, + codec_options=DEFAULT_CODEC_OPTIONS, + read_preference=ReadPreference.PRIMARY, + write_concern=DEFAULT_WRITE_CONCERN, + ) + + def __enter__(self) -> MongoClient[_DocumentType]: + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.close() + + # See PYTHON-3084. + __iter__ = None + + def __next__(self) -> NoReturn: + raise TypeError("'MongoClient' object is not iterable") + + next = __next__ + + def _server_property(self, attr_name: str) -> Any: + """An attribute of the current server's description. 
+ + If the client is not connected, this will block until a connection is + established or raise ServerSelectionTimeoutError if no server is + available. + + Not threadsafe if used multiple times in a single method, since + the server may change. In such cases, store a local reference to a + ServerDescription first, then use its properties. + """ + server = (self._get_topology()).select_server(writable_server_selector, _Op.TEST) + + return getattr(server.description, attr_name) + @property def address(self) -> Optional[tuple[str, int]]: """(host, port) of the current standalone, primary, or mongos, or None. @@ -1164,30 +1513,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): """ return self._server_property("server_type") == SERVER_TYPE.Mongos - @property - def nodes(self) -> FrozenSet[_Address]: - """Set of all currently connected servers. - - .. warning:: When connected to a replica set the value of :attr:`nodes` - can change over time as :class:`MongoClient`'s view of the replica - set changes. :attr:`nodes` can also be an empty set when - :class:`MongoClient` is first instantiated and hasn't yet connected - to any servers, or a network partition causes it to lose connection - to all servers. - """ - description = self._topology.description - return frozenset(s.address for s in description.known_servers) - - @property - def options(self) -> ClientOptions: - """The configuration options for this client. - - :return: An instance of :class:`~pymongo.client_options.ClientOptions`. - - .. versionadded:: 4.0 - """ - return self.__options - def _end_sessions(self, session_ids: list[_ServerSession]) -> None: """Send endSessions command(s) with the given session ids.""" try: @@ -1236,6 +1561,11 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() + self._closed = True + + if not _IS_SYNC: + # Add support for contextlib.closing. 
+ close = close def _get_topology(self) -> Topology: """Get the internal :class:`~pymongo.topology.Topology` object. @@ -1243,13 +1573,17 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ - self._topology.open() - with self.__lock: - self._kill_cursors_executor.open() + if not self._opened: + self._topology.open() + with self._lock: + self._kill_cursors_executor.open() + self._opened = True return self._topology @contextlib.contextmanager - def _checkout(self, server: Server, session: Optional[ClientSession]) -> Iterator[Connection]: + def _checkout( + self, server: Server, session: Optional[ClientSession] + ) -> Generator[Connection, None]: in_txn = session and session.in_transaction with _MongoClientErrorHandler(self, server, session) as err_handler: # Reuse the pinned connection, if it exists. @@ -1291,12 +1625,12 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): ) -> Server: """Select a server to run an operation on this client. - :param server_selector: The server selector to use if the session is + :Parameters: + - `server_selector`: The server selector to use if the session is not pinned and no address is given. - :param session: The ClientSession for the next operation, or None. May + - `session`: The ClientSession for the next operation, or None. May be pinned to a mongos server address. - :param operation: The name of the operation that the server is being selected for. - :param address: Address when sending a message + - `address` (optional): Address when sending a message to a specific server, used for getMore. 
""" try: @@ -1336,7 +1670,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): @contextlib.contextmanager def _conn_from_server( self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] - ) -> Iterator[tuple[Connection, _ServerMode]]: + ) -> Generator[tuple[Connection, _ServerMode], None]: assert read_preference is not None, "read_preference must not be None" # Get a connection for a server matching the read preference, and yield # conn with the effective read preference. The Server Selection @@ -1344,9 +1678,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # always send primaryPreferred when directly connected to a repl set # member. # Thread safe: if the type is single it cannot change. - topology = self._get_topology() - single = topology.description.topology_type == TOPOLOGY_TYPE.Single - + # NOTE: We already opened the Topology when selecting a server so there's no need + # to call _get_topology() again. + single = self._topology.description.topology_type == TOPOLOGY_TYPE.Single with self._checkout(server, session) as conn: if single: if conn.is_repl and not (session and session.in_transaction): @@ -1365,13 +1699,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): operation: str, ) -> ContextManager[tuple[Connection, _ServerMode]]: assert read_preference is not None, "read_preference must not be None" - _ = self._get_topology() server = self._select_server(read_preference, session, operation) return self._conn_from_server(read_preference, server, session) - def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]: - return self.__options.load_balanced and not (session and session.in_transaction) - @_csot.apply def _run_operation( self, @@ -1389,13 +1719,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if operation.conn_mgr: server = self._select_server( operation.read_preference, - operation.session, + operation.session, # type: 
ignore[arg-type] operation.name, address=address, ) - with operation.conn_mgr.lock: - with _MongoClientErrorHandler(self, server, operation.session) as err_handler: + with operation.conn_mgr._alock: + with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( operation.conn_mgr.conn, @@ -1425,9 +1755,9 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): return self._retryable_read( _cmd, operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] address=address, - retryable=isinstance(operation, message._Query), + retryable=isinstance(operation, _Query), operation=operation.name, ) @@ -1436,7 +1766,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): retryable: bool, func: _WriteCall[T], session: Optional[ClientSession], - bulk: Optional[_Bulk], + bulk: Optional[Union[_Bulk, _ClientBulk]], operation: str, operation_id: Optional[int] = None, ) -> T: @@ -1466,7 +1796,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self, func: _WriteCall[T] | _ReadCall[T], session: Optional[ClientSession], - bulk: Optional[_Bulk], + bulk: Optional[Union[_Bulk, _ClientBulk]], operation: str, is_read: bool = False, address: Optional[_Address] = None, @@ -1549,7 +1879,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): func: _WriteCall[T], session: Optional[ClientSession], operation: str, - bulk: Optional[_Bulk] = None, + bulk: Optional[Union[_Bulk, _ClientBulk]] = None, operation_id: Optional[int] = None, ) -> T: """Execute an operation with consecutive retries if possible @@ -1568,127 +1898,63 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): with self._tmp_session(session) as s: return self._retry_with_session(retryable, func, s, bulk, operation, operation_id) - def __eq__(self, other: Any) -> bool: - if isinstance(other, self.__class__): - return 
self._topology == other._topology - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __hash__(self) -> int: - return hash(self._topology) - - def _repr_helper(self) -> str: - def option_repr(option: str, value: Any) -> str: - """Fix options whose __repr__ isn't usable in a constructor.""" - if option == "document_class": - if value is dict: - return "document_class=dict" - else: - return f"document_class={value.__module__}.{value.__name__}" - if option in common.TIMEOUT_OPTIONS and value is not None: - return f"{option}={int(value * 1000)}" - - return f"{option}={value!r}" - - # Host first... - options = [ - "host=%r" - % [ - "%s:%d" % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds - ] - ] - # ... then everything in self._constructor_args... - options.extend( - option_repr(key, self.__options._options[key]) for key in self._constructor_args - ) - # ... then everything else. - options.extend( - option_repr(key, self.__options._options[key]) - for key in self.__options._options - if key not in set(self._constructor_args) and key != "username" and key != "password" - ) - return ", ".join(options) - - def __repr__(self) -> str: - return f"MongoClient({self._repr_helper()})" - - def __getattr__(self, name: str) -> database.Database[_DocumentType]: - """Get a database by name. - - Raises :class:`~pymongo.errors.InvalidName` if an invalid - database name is used. - - :param name: the name of the database to get - """ - if name.startswith("_"): - raise AttributeError( - f"MongoClient has no attribute {name!r}. To access the {name}" - f" database, use client[{name!r}]." - ) - return self.__getitem__(name) - - def __getitem__(self, name: str) -> database.Database[_DocumentType]: - """Get a database by name. - - Raises :class:`~pymongo.errors.InvalidName` if an invalid - database name is used. 
- - :param name: the name of the database to get - """ - return database.Database(self, name) - - def _cleanup_cursor( + def _cleanup_cursor_no_lock( self, - locks_allowed: bool, cursor_id: int, address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, session: Optional[ClientSession], explicit_session: bool, ) -> None: - """Cleanup a cursor from cursor.close() or __del__. + """Cleanup a cursor from __del__ without locking. + + This method handles cleanup for Cursors/CommandCursors including any + pinned connection attached at the time the cursor + was garbage collected. + + :param cursor_id: The cursor id which may be 0. + :param address: The _CursorAddress. + :param conn_mgr: The _ConnectionManager for the pinned connection or None. + """ + # The cursor will be closed later in a different session. + if cursor_id or conn_mgr: + self._close_cursor_soon(cursor_id, address, conn_mgr) + if session and not explicit_session: + session._end_implicit_session() + + def _cleanup_cursor_lock( + self, + cursor_id: int, + address: Optional[_CursorAddress], + conn_mgr: _ConnectionManager, + session: Optional[ClientSession], + explicit_session: bool, + ) -> None: + """Cleanup a cursor from cursor.close() using a lock. This method handles cleanup for Cursors/CommandCursors including any pinned connection or implicit session attached at the time the cursor was closed or garbage collected. - :param locks_allowed: True if we are allowed to acquire locks. :param cursor_id: The cursor id which may be 0. :param address: The _CursorAddress. :param conn_mgr: The _ConnectionManager for the pinned connection or None. :param session: The cursor's session. :param explicit_session: True if the session was passed explicitly. """ - if locks_allowed: - if cursor_id: - if conn_mgr and conn_mgr.more_to_come: - # If this is an exhaust cursor and we haven't completely - # exhausted the result set we *must* close the socket - # to stop the server from sending more data. 
- assert conn_mgr.conn is not None - conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) - else: - self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) - if conn_mgr: - conn_mgr.close() - else: - # The cursor will be closed later in a different session. - if cursor_id or conn_mgr: - self._close_cursor_soon(cursor_id, address, conn_mgr) + if cursor_id: + if conn_mgr and conn_mgr.more_to_come: + # If this is an exhaust cursor and we haven't completely + # exhausted the result set we *must* close the socket + # to stop the server from sending more data. + assert conn_mgr.conn is not None + conn_mgr.conn.close_conn(ConnectionClosedReason.ERROR) + else: + self._close_cursor_now(cursor_id, address, session=session, conn_mgr=conn_mgr) + if conn_mgr: + conn_mgr.close() if session and not explicit_session: - session._end_session(lock=locks_allowed) - - def _close_cursor_soon( - self, - cursor_id: int, - address: Optional[_CursorAddress], - conn_mgr: Optional[_ConnectionManager] = None, - ) -> None: - """Request that a cursor and/or connection be cleaned up soon.""" - self.__kill_cursors_queue.append((address, cursor_id, conn_mgr)) + session._end_implicit_session() def _close_cursor_now( self, @@ -1706,7 +1972,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): try: if conn_mgr: - with conn_mgr.lock: + with conn_mgr._alock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None @@ -1757,7 +2023,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # Other threads or the GC may append to the queue concurrently. 
while True: try: - address, cursor_id, conn_mgr = self.__kill_cursors_queue.pop() + address, cursor_id, conn_mgr = self._kill_cursors_queue.pop() except IndexError: break @@ -1768,14 +2034,14 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): for address, cursor_id, conn_mgr in pinned_cursors: try: - self._cleanup_cursor(True, cursor_id, address, conn_mgr, None, False) + self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it # can be caught in _process_periodic_tasks raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # Don't re-open topology if it's closed and there's no pending cursors. if address_to_cursor_ids: @@ -1787,7 +2053,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # This method is run periodically by a background thread. def _process_periodic_tasks(self) -> None: @@ -1801,62 +2067,15 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): if isinstance(exc, InvalidOperation) and self._topology._closed: return else: - helpers._handle_exception() - - def __start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: - server_session = _EmptyServerSession() - opts = client_session.SessionOptions(**kwargs) - return client_session.ClientSession(self, server_session, opts, implicit) - - def start_session( - self, - causal_consistency: Optional[bool] = None, - default_transaction_options: Optional[client_session.TransactionOptions] = None, - snapshot: Optional[bool] = False, - ) -> client_session.ClientSession: - """Start a logical session. - - This method takes the same parameters as - :class:`~pymongo.client_session.SessionOptions`. 
See the - :mod:`~pymongo.client_session` module for details and examples. - - A :class:`~pymongo.client_session.ClientSession` may only be used with - the MongoClient that started it. :class:`ClientSession` instances are - **not thread-safe or fork-safe**. They can only be used by one thread - or process at a time. A single :class:`ClientSession` cannot be used - to run multiple operations concurrently. - - :return: An instance of :class:`~pymongo.client_session.ClientSession`. - - .. versionadded:: 3.6 - """ - return self.__start_session( - False, - causal_consistency=causal_consistency, - default_transaction_options=default_transaction_options, - snapshot=snapshot, - ) + helpers_shared._handle_exception() def _return_server_session( - self, server_session: Union[_ServerSession, _EmptyServerSession], lock: bool + self, server_session: Union[_ServerSession, _EmptyServerSession] ) -> None: """Internal: return a _ServerSession to the pool.""" if isinstance(server_session, _EmptyServerSession): return None - return self._topology.return_server_session(server_session, lock) - - def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: - """If provided session is None, lend a temporary session.""" - if session: - return session - - try: - # Don't make implicit sessions causally consistent. Applications - # should always opt-in. - return self.__start_session(True, causal_consistency=False) - except (ConfigurationError, InvalidOperation): - # Sessions not supported. 
- return None + return self._topology.return_server_session(server_session) @contextlib.contextmanager def _tmp_session( @@ -1888,21 +2107,6 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): else: yield None - def _send_cluster_time( - self, command: MutableMapping[str, Any], session: Optional[ClientSession] - ) -> None: - topology_time = self._topology.max_cluster_time() - session_time = session.cluster_time if session else None - if topology_time and session_time: - if topology_time["clusterTime"] > session_time["clusterTime"]: - cluster_time: Optional[ClusterTime] = topology_time - else: - cluster_time = session_time - else: - cluster_time = topology_time or session_time - if cluster_time: - command["$clusterTime"] = cluster_time - def _process_response(self, reply: Mapping[str, Any], session: Optional[ClientSession]) -> None: self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: @@ -1924,6 +2128,26 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): ), ) + def _list_databases( + self, + session: Optional[client_session.ClientSession] = None, + comment: Optional[Any] = None, + **kwargs: Any, + ) -> CommandCursor[dict[str, Any]]: + cmd = {"listDatabases": 1} + cmd.update(kwargs) + if comment is not None: + cmd["comment"] = comment + admin = self._database_default_options("admin") + res = admin._retryable_read_command(cmd, session=session, operation=_Op.LIST_DATABASES) + # listDatabases doesn't return a cursor (yet). Fake one. + cursor = { + "id": 0, + "firstBatch": res["databases"], + "ns": "admin.$cmd", + } + return CommandCursor(admin["$cmd"], cursor, None, comment=comment) + def list_databases( self, session: Optional[client_session.ClientSession] = None, @@ -1947,19 +2171,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): .. 
versionadded:: 3.6 """ - cmd = {"listDatabases": 1} - cmd.update(kwargs) - if comment is not None: - cmd["comment"] = comment - admin = self._database_default_options("admin") - res = admin._retryable_read_command(cmd, session=session, operation=_Op.LIST_DATABASES) - # listDatabases doesn't return a cursor (yet). Fake one. - cursor = { - "id": 0, - "firstBatch": res["databases"], - "ns": "admin.$cmd", - } - return CommandCursor(admin["$cmd"], cursor, None, comment=comment) + return self._list_databases(session, comment, **kwargs) def list_database_names( self, @@ -1978,7 +2190,8 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): .. versionadded:: 3.6 """ - return [doc["name"] for doc in self.list_databases(session, nameOnly=True, comment=comment)] + res = self._list_databases(session, nameOnly=True, comment=comment) + return [doc["name"] for doc in res] @_csot.apply def drop_database( @@ -2031,152 +2244,136 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): session=session, ) - def get_default_database( + @_csot.apply + def bulk_write( self, - default: Optional[str] = None, - codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, - read_preference: Optional[_ServerMode] = None, + models: Sequence[_WriteOp[_DocumentType]], + session: Optional[ClientSession] = None, + ordered: bool = True, + verbose_results: bool = False, + bypass_document_validation: Optional[bool] = None, + comment: Optional[Any] = None, + let: Optional[Mapping] = None, write_concern: Optional[WriteConcern] = None, - read_concern: Optional[ReadConcern] = None, - ) -> database.Database[_DocumentType]: - """Get the database named in the MongoDB connection URI. + ) -> ClientBulkWriteResult: + """Send a batch of write operations, potentially across multiple namespaces, to the server. 
- >>> uri = 'mongodb://host/my_database' - >>> client = MongoClient(uri) - >>> db = client.get_default_database() - >>> assert db.name == 'my_database' - >>> db = client.get_database() - >>> assert db.name == 'my_database' + Requests are passed as a list of write operation instances ( + :class:`~pymongo.operations.InsertOne`, + :class:`~pymongo.operations.UpdateOne`, + :class:`~pymongo.operations.UpdateMany`, + :class:`~pymongo.operations.ReplaceOne`, + :class:`~pymongo.operations.DeleteOne`, or + :class:`~pymongo.operations.DeleteMany`). - Useful in scripts where you want to choose which database to use - based only on the URI in a configuration file. + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'x': 1, '_id': ObjectId('54f62e60fba5226811f634ef')} + {'x': 1, '_id': ObjectId('54f62e60fba5226811f634f0')} + ... + >>> for doc in db.coll.find({}): + ... print(doc) + ... + {'x': 2, '_id': ObjectId('507f1f77bcf86cd799439011')} + ... + >>> # DeleteMany, UpdateOne, and UpdateMany are also available. + >>> from pymongo import InsertOne, DeleteOne, ReplaceOne + >>> models = [InsertOne(namespace="db.test", document={'y': 1}), + ... DeleteOne(namespace="db.test", filter={'x': 1}), + ... InsertOne(namespace="db.coll", document={'y': 2}), + ... ReplaceOne(namespace="db.test", filter={'w': 1}, replacement={'z': 1}, upsert=True)] + >>> result = client.bulk_write(models=models) + >>> result.inserted_count + 2 + >>> result.deleted_count + 1 + >>> result.modified_count + 0 + >>> result.upserted_count + 1 + >>> for doc in db.test.find({}): + ... print(doc) + ... + {'x': 1, '_id': ObjectId('54f62e60fba5226811f634f0')} + {'y': 1, '_id': ObjectId('54f62ee2fba5226811f634f1')} + {'z': 1, '_id': ObjectId('54f62ee28891e756a6e1abd5')} + ... + >>> for doc in db.coll.find({}): + ... print(doc) + ... 
+ {'x': 2, '_id': ObjectId('507f1f77bcf86cd799439011')} + {'y': 2, '_id': ObjectId('507f1f77bcf86cd799439012')} - :param default: the database name to use if no database name - was provided in the URI. - :param codec_options: An instance of - :class:`~bson.codec_options.CodecOptions`. If ``None`` (the - default) the :attr:`codec_options` of this :class:`MongoClient` is - used. - :param read_preference: The read preference to use. If - ``None`` (the default) the :attr:`read_preference` of this - :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` - for options. - :param write_concern: An instance of - :class:`~pymongo.write_concern.WriteConcern`. If ``None`` (the - default) the :attr:`write_concern` of this :class:`MongoClient` is - used. - :param read_concern: An instance of - :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the - default) the :attr:`read_concern` of this :class:`MongoClient` is - used. - :param comment: A user-provided comment to attach to this + :param models: A list of write operation instances. + :param session: (optional) An instance of + :class:`~pymongo.client_session.ClientSession`. + :param ordered: If ``True`` (the default), requests will be + performed on the server serially, in the order provided. If an error + occurs all remaining operations are aborted. If ``False``, requests + will still be performed on the server serially, in the order provided, + but all operations will be attempted even if any errors occur. + :param verbose_results: If ``True``, detailed results for each + successful operation will be included in the returned + :class:`~pymongo.results.ClientBulkWriteResult`. Default is ``False``. + :param bypass_document_validation: (optional) If ``True``, allows the + write to opt out of document-level validation. Default is ``False``. + :param comment: (optional) A user-provided comment to attach to this command. + :param let: (optional) Map of parameter names and values.
Values must be + constant or closed expressions that do not reference document + fields. Parameters can then be accessed as variables in an + aggregate expression context (e.g. "$$var"). + :param write_concern: (optional) The write concern to use for this bulk write. - .. versionchanged:: 4.1 - Added ``comment`` parameter. + :return: An instance of :class:`~pymongo.results.ClientBulkWriteResult`. - .. versionchanged:: 3.8 - Undeprecated. Added the ``default``, ``codec_options``, - ``read_preference``, ``write_concern`` and ``read_concern`` - parameters. + .. seealso:: For more info, see :doc:`/examples/client_bulk`. - .. versionchanged:: 3.5 - Deprecated, use :meth:`get_database` instead. + .. seealso:: :ref:`writes-and-ids` + + .. note:: requires MongoDB server version 8.0+. + + .. versionadded:: 4.9 """ - if self.__default_database_name is None and default is None: - raise ConfigurationError("No default database name defined or provided.") + if self._options.auto_encryption_opts: + raise InvalidOperation( + "MongoClient.bulk_write does not currently support automatic encryption" + ) - name = cast(str, self.__default_database_name or default) - return database.Database( - self, name, codec_options, read_preference, write_concern, read_concern + if session and session.in_transaction: + # Inherit the transaction write concern. + if write_concern: + raise InvalidOperation("Cannot set write concern after starting a transaction") + write_concern = session._transaction.opts.write_concern # type: ignore[union-attr] + else: + # Inherit the client's write concern if none is provided. 
+ if not write_concern: + write_concern = self.write_concern + + common.validate_list("models", models) + + blk = _ClientBulk( + self, + write_concern=write_concern, # type: ignore[arg-type] + ordered=ordered, + bypass_document_validation=bypass_document_validation, + comment=comment, + let=let, + verbose_results=verbose_results, ) + for model in models: + try: + model._add_to_client_bulk(blk) + except AttributeError: + raise TypeError(f"{model!r} is not a valid request") from None - def get_database( - self, - name: Optional[str] = None, - codec_options: Optional[CodecOptions[_DocumentTypeArg]] = None, - read_preference: Optional[_ServerMode] = None, - write_concern: Optional[WriteConcern] = None, - read_concern: Optional[ReadConcern] = None, - ) -> database.Database[_DocumentType]: - """Get a :class:`~pymongo.database.Database` with the given name and - options. - - Useful for creating a :class:`~pymongo.database.Database` with - different codec options, read preference, and/or write concern from - this :class:`MongoClient`. - - >>> client.read_preference - Primary() - >>> db1 = client.test - >>> db1.read_preference - Primary() - >>> from pymongo import ReadPreference - >>> db2 = client.get_database( - ... 'test', read_preference=ReadPreference.SECONDARY) - >>> db2.read_preference - Secondary(tag_sets=None) - - :param name: The name of the database - a string. If ``None`` - (the default) the database named in the MongoDB connection URI is - returned. - :param codec_options: An instance of - :class:`~bson.codec_options.CodecOptions`. If ``None`` (the - default) the :attr:`codec_options` of this :class:`MongoClient` is - used. - :param read_preference: The read preference to use. If - ``None`` (the default) the :attr:`read_preference` of this - :class:`MongoClient` is used. See :mod:`~pymongo.read_preferences` - for options. - :param write_concern: An instance of - :class:`~pymongo.write_concern.WriteConcern`. 
If ``None`` (the - default) the :attr:`write_concern` of this :class:`MongoClient` is - used. - :param read_concern: An instance of - :class:`~pymongo.read_concern.ReadConcern`. If ``None`` (the - default) the :attr:`read_concern` of this :class:`MongoClient` is - used. - - .. versionchanged:: 3.5 - The `name` parameter is now optional, defaulting to the database - named in the MongoDB connection URI. - """ - if name is None: - if self.__default_database_name is None: - raise ConfigurationError("No default database defined") - name = self.__default_database_name - - return database.Database( - self, name, codec_options, read_preference, write_concern, read_concern - ) - - def _database_default_options(self, name: str) -> Database: - """Get a Database instance with the default settings.""" - return self.get_database( - name, - codec_options=DEFAULT_CODEC_OPTIONS, - read_preference=ReadPreference.PRIMARY, - write_concern=DEFAULT_WRITE_CONCERN, - ) - - def __enter__(self) -> MongoClient[_DocumentType]: - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.close() - - # See PYTHON-3084. - __iter__ = None - - def __next__(self) -> NoReturn: - raise TypeError("'MongoClient' object is not iterable") - - next = __next__ + return blk.execute(session, _Op.BULK_WRITE) def _retryable_error_doc(exc: PyMongoError) -> Optional[Mapping[str, Any]]: """Return the server response from PyMongo exception or None.""" - if isinstance(exc, BulkWriteError): + if isinstance(exc, (BulkWriteError, ClientBulkWriteException)): # Check the last writeConcernError to determine if this # BulkWriteError is retryable. wces = exc.details["writeConcernErrors"] @@ -2206,15 +2403,19 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong # Do not consult writeConcernError for pre-4.4 mongos. 
if isinstance(exc, WriteConcernError) and is_mongos: pass - elif code in helpers._RETRYABLE_ERROR_CODES: + elif code in helpers_shared._RETRYABLE_ERROR_CODES: exc._add_error_label("RetryableWriteError") # Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is # handled above. - if isinstance(exc, ConnectionFailure) and not isinstance( - exc, (NotPrimaryError, WaitQueueTimeoutError) + if isinstance(exc, ClientBulkWriteException): + exc_to_check = exc.error + else: + exc_to_check = exc + if isinstance(exc_to_check, ConnectionFailure) and not isinstance( + exc_to_check, (NotPrimaryError, WaitQueueTimeoutError) ): - exc._add_error_label("RetryableWriteError") + exc_to_check._add_error_label("RetryableWriteError") class _MongoClientErrorHandler: @@ -2232,6 +2433,9 @@ class _MongoClientErrorHandler: ) def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): + if not isinstance(client, MongoClient): + raise TypeError(f"MongoClient required but given {type(client)}") + self.client = client self.server_address = server.description.address self.session = session @@ -2259,6 +2463,8 @@ class _MongoClientErrorHandler: return self.handled = True if self.session: + if isinstance(exc_val, ClientBulkWriteException): + exc_val = exc_val.error if isinstance(exc_val, ConnectionFailure): if self.session.in_transaction: exc_val._add_error_label("TransientTransactionError") @@ -2270,7 +2476,7 @@ class _MongoClientErrorHandler: ): self.session._unpin() err_ctx = _ErrorContext( - exc_val, + exc_val, # type: ignore[arg-type] self.max_wire_version, self.sock_generation, self.completed_handshake, @@ -2297,7 +2503,7 @@ class _ClientConnectionRetryable(Generic[T]): self, mongo_client: MongoClient, func: _WriteCall[T] | _ReadCall[T], - bulk: Optional[_Bulk], + bulk: Optional[Union[_Bulk, _ClientBulk]], operation: str, is_read: bool = False, session: Optional[ClientSession] = None, @@ -2362,7 +2568,7 @@ class 
_ClientConnectionRetryable(Generic[T]): exc_code = getattr(exc, "code", None) if self._is_not_eligible_for_retry() or ( isinstance(exc, OperationFailure) - and exc_code not in helpers._RETRYABLE_ERROR_CODES + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES ): raise self._retrying = True @@ -2374,7 +2580,12 @@ class _ClientConnectionRetryable(Generic[T]): if not self._is_read: if not self._retryable: raise - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") + if isinstance(exc, ClientBulkWriteException) and exc.error: + retryable_write_error_exc = isinstance( + exc.error, PyMongoError + ) and exc.error.has_error_label("RetryableWriteError") + else: + retryable_write_error_exc = exc.has_error_label("RetryableWriteError") if retryable_write_error_exc: assert self._session self._session._unpin()