From 2f9d24ade69faf7f4489d9677824c63cfc87b885 Mon Sep 17 00:00:00 2001 From: "A. Jesse Jiryu Davis" Date: Wed, 13 Aug 2014 15:35:17 -0400 Subject: [PATCH] PYTHON-525 Reimplement MongoClient to use Cluster. Replace MongoClient with an implementation that relies on Cluster and Server. The new MongoClient takes over MongoReplicaSetClient's responsibilities. Authentication, secondary-pinning, and Mongos high-availability are broken and will be reimplemented in a future commit. RS tests are temporarily disabled. --- doc/api/pymongo/mongo_client.rst | 1 - doc/changelog.rst | 29 +- gridfs/__init__.py | 27 +- gridfs/grid_file.py | 84 +-- pymongo/bulk.py | 3 +- pymongo/cluster.py | 65 +- pymongo/cluster_description.py | 29 +- pymongo/collection.py | 19 +- pymongo/command_cursor.py | 2 +- pymongo/common.py | 9 +- pymongo/cursor.py | 2 +- pymongo/mongo_client.py | 1020 +++++++++++------------------ pymongo/mongo_client_new.py | 127 ---- pymongo/monitor.py | 2 +- pymongo/pool.py | 28 +- pymongo/server.py | 94 ++- pymongo/settings.py | 12 - pymongo/thread_util.py | 30 - test/__init__.py | 57 +- test/pymongo_mocks.py | 107 +-- test/test_binary.py | 2 +- test/test_client.py | 455 ++++++------- test/test_client_new.py | 72 -- test/test_cluster.py | 66 +- test/test_collection.py | 5 +- test/test_common.py | 33 +- test/test_cursor.py | 2 +- test/test_database.py | 2 +- test/test_errors.py | 8 +- test/test_grid_file.py | 33 +- test/test_gridfs.py | 39 +- test/test_mongos_ha.py | 30 +- test/test_pooling.py | 28 +- test/test_read_preferences.py | 9 +- test/test_replica_set_client.py | 87 +-- test/test_replica_set_reconfig.py | 94 +-- test/test_son_manipulator.py | 2 +- test/test_ssl.py | 8 +- test/test_threads.py | 9 +- test/utils.py | 41 +- 40 files changed, 1301 insertions(+), 1471 deletions(-) delete mode 100644 pymongo/mongo_client_new.py delete mode 100644 test/test_client_new.py diff --git a/doc/api/pymongo/mongo_client.rst b/doc/api/pymongo/mongo_client.rst index a420ae340..66ae65069 100644 --- a/doc/api/pymongo/mongo_client.rst +++ b/doc/api/pymongo/mongo_client.rst @@ -22,7 +22,6 @@ .. autoattribute:: is_mongos .. autoattribute:: max_pool_size .. autoattribute:: nodes - .. autoattribute:: auto_start_request .. autoattribute:: document_class .. autoattribute:: tz_aware .. autoattribute:: max_bson_size diff --git a/doc/changelog.rst b/doc/changelog.rst index 4cfbd05c8..a1751edfc 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -25,8 +25,35 @@ see :doc:`PyMongo's Gevent documentation `. :class:`~pymongo.MongoClient` Changes ..................................... +:class:`~pymongo.mongo_client.MongoClient` is now the one and only +client class for a standalone server, mongos, or replica set. +It includes the functionality that had been split into +``MongoReplicaSetClient``: it can connect to a replica set, discover all its +members, and monitor the set for stepdowns, elections, and reconfigs. + +The obsolete ``MasterSlaveConnection`` class is removed. + +The :class:`~pymongo.mongo_client.MongoClient` constructor no +longer blocks while connecting to the server or servers, and it no +longer raises :class:`~pymongo.errors.ConnectionFailure` if they +are unavailable, nor :class:`~pymongo.errors.ConfigurationError` +if the user's credentials are wrong. Instead, the constructor +returns immediately and launches the connection process on +background threads. + +The ``connect`` option is added and ``auto_start_request`` is removed. + +In PyMongo 2.x, :class:`~pymongo.MongoClient` accepted a list of standalone +MongoDB servers and used the first it could connect to:: + + MongoClient(['host1.com:27017', 'host2.com:27017']) + +A list of multiple standalones is no longer supported; if multiple servers +are listed they must be members of the same replica set, or mongoses in the +same sharded cluster. + The second parameter to :meth:`~pymongo.MongoClient.close_cursor` is renamed -from ``_conn_id`` to ``address`` and is no longer optional. +from ``_conn_id`` to ``address``. :meth:`~pymongo.MongoClient.set_cursor_manager` is no longer deprecated. diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 9da271072..3aaacbd63 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -25,8 +25,7 @@ from gridfs.errors import (NoFile, from gridfs.grid_file import (GridIn, GridOut, GridOutCursor) -from pymongo import (MongoClient, - ASCENDING, +from pymongo import (ASCENDING, DESCENDING) from pymongo.database import Database @@ -34,18 +33,21 @@ from pymongo.database import Database class GridFS(object): """An instance of GridFS on top of a single Database. """ - def __init__(self, database, collection="fs", _connect=True): + def __init__(self, database, collection="fs"): """Create a new instance of :class:`GridFS`. Raises :class:`TypeError` if `database` is not an instance of :class:`~pymongo.database.Database`. + The `connect` parameter ensures that the underlying + :class:`~pymongo.mongo_client.MongoClient` is connected to a server, + and creates an index on the "chunks" collection if needed. + :Parameters: - `database`: database to use - `collection` (optional): root collection to use - - .. versionadded:: 1.6 - The `collection` parameter. + - `connect` (optional): whether to begin connecting the client in + the background .. mongodoc:: gridfs """ @@ -56,15 +58,10 @@ class GridFS(object): self.__collection = database[collection] self.__files = self.__collection.files self.__chunks = self.__collection.chunks - if _connect: - self.__ensure_index_files_id() def __is_secondary(self): client = self.__database.connection - - # Connect the client, so we know if it's connected to the primary. - client._ensure_connected() - return isinstance(client, MongoClient) and not client.is_primary + return not client._is_writable() def __ensure_index_files_id(self): if not self.__is_secondary(): @@ -156,7 +153,11 @@ class GridFS(object): .. versionadded:: 1.6 """ - return GridOut(self.__collection, file_id) + gout = GridOut(self.__collection, file_id) + + # Raise NoFile now, instead of on first attribute access. + gout._ensure_file() + return gout def get_version(self, filename=None, version=-1, **kwargs): """Get a file from GridFS by ``"filename"`` or metadata fields. diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index fe8b9ff61..6e3aa9808 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -49,10 +49,9 @@ NEWLN = b"\n" DEFAULT_CHUNK_SIZE = 255 * 1024 -def _create_property(field_name, docstring, - read_only=False, closed_only=False): - """Helper for creating properties to read/write to files. - """ +def _grid_in_property(field_name, docstring, read_only=False, + closed_only=False): + """Create a GridIn property.""" def getter(self): if closed_only and not self._closed: raise AttributeError("can only get %r on a closed file" % @@ -70,7 +69,7 @@ def _create_property(field_name, docstring, self._file[field_name] = value if read_only: - docstring = docstring + "\n\nThis attribute is read-only." + docstring += "\n\nThis attribute is read-only." elif closed_only: docstring = "%s\n\n%s" % (docstring, "This attribute is read-only and " "can only be read after :meth:`close` " @@ -81,6 +80,20 @@ def _create_property(field_name, docstring, return property(getter, doc=docstring) +def _grid_out_property(field_name, docstring): + """Create a GridOut property.""" + def getter(self): + self._ensure_file() + + # Protect against PHP-237 + if field_name == 'length': + return self._file.get(field_name, 0) + return self._file.get(field_name, None) + + docstring += "\n\nThis attribute is read-only." + return property(getter, doc=docstring) + + class GridIn(object): """Class to write data to GridFS. """ @@ -177,19 +190,19 @@ class GridIn(object): """ return self._closed - _id = _create_property("_id", "The ``'_id'`` value for this file.", + _id = _grid_in_property("_id", "The ``'_id'`` value for this file.", read_only=True) - filename = _create_property("filename", "Name of this file.") - name = _create_property("filename", "Alias for `filename`.") - content_type = _create_property("contentType", "Mime-type for this file.") - length = _create_property("length", "Length (in bytes) of this file.", + filename = _grid_in_property("filename", "Name of this file.") + name = _grid_in_property("filename", "Alias for `filename`.") + content_type = _grid_in_property("contentType", "Mime-type for this file.") + length = _grid_in_property("length", "Length (in bytes) of this file.", closed_only=True) - chunk_size = _create_property("chunkSize", "Chunk size for this file.", + chunk_size = _grid_in_property("chunkSize", "Chunk size for this file.", read_only=True) - upload_date = _create_property("uploadDate", + upload_date = _grid_in_property("uploadDate", "Date that this file was uploaded.", closed_only=True) - md5 = _create_property("md5", "MD5 of the contents of this file " + md5 = _grid_in_property("md5", "MD5 of the contents of this file " "(generated on the server).", closed_only=True) @@ -370,8 +383,7 @@ class GridIn(object): class GridOut(object): """Class to read data out of GridFS. """ - def __init__(self, root_collection, file_id=None, file_document=None, - _connect=True): + def __init__(self, root_collection, file_id=None, file_document=None): """Read a file from GridFS Application developers should generally not need to @@ -385,11 +397,13 @@ class GridOut(object): :Parameters: - `root_collection`: root collection to read from - - `file_id`: value of ``"_id"`` for the file to read - - `file_document`: file document from `root_collection.files` + - `file_id` (optional): value of ``"_id"`` for the file to read + - `file_document` (optional): file document from + `root_collection.files` - .. versionadded:: 1.9 - The `file_document` parameter. + .. versionchanged:: 3.0 + Creating a GridOut does not immediately retrieve the file metadata + from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, Collection): raise TypeError("root_collection must be an " @@ -401,27 +415,19 @@ class GridOut(object): self.__buffer = EMPTY self.__position = 0 self._file = file_document - if _connect: - self._ensure_file() - _id = _create_property("_id", "The ``'_id'`` value for this file.", True) - filename = _create_property("filename", "Name of this file.", True) - name = _create_property("filename", "Alias for `filename`.", True) - content_type = _create_property("contentType", "Mime-type for this file.", - True) - length = _create_property("length", "Length (in bytes) of this file.", - True) - chunk_size = _create_property("chunkSize", "Chunk size for this file.", - True) - upload_date = _create_property("uploadDate", - "Date that this file was first uploaded.", - True) - aliases = _create_property("aliases", "List of aliases for this file.", - True) - metadata = _create_property("metadata", "Metadata attached to this file.", - True) - md5 = _create_property("md5", "MD5 of the contents of this file " - "(generated on the server).", True) + _id = _grid_out_property("_id", "The ``'_id'`` value for this file.") + filename = _grid_out_property("filename", "Name of this file.") + name = _grid_out_property("filename", "Alias for `filename`.") + content_type = _grid_out_property("contentType", "Mime-type for this file.") + length = _grid_out_property("length", "Length (in bytes) of this file.") + chunk_size = _grid_out_property("chunkSize", "Chunk size for this file.") + upload_date = _grid_out_property("uploadDate", + "Date that this file was first uploaded.") + aliases = _grid_out_property("aliases", "List of aliases for this file.") + metadata = _grid_out_property("metadata", "Metadata attached to this file.") + md5 = _grid_out_property("md5", "MD5 of the contents of this file " + "(generated on the server).") def _ensure_file(self): if not self._file: diff --git a/pymongo/bulk.py b/pymongo/bulk.py index d2830cca6..cc6fb4e2d 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -423,7 +423,6 @@ class _Bulk(object): 'only be executed once.') self.executed = True client = self.collection.database.connection - client._ensure_connected(sync=True) write_concern = write_concern or self.collection.write_concern if self.ordered: @@ -433,7 +432,7 @@ class _Bulk(object): if write_concern.get('w') == 0: self.execute_no_results(generator) - elif client.max_wire_version > 1: + elif client._writable_max_wire_version() > 1: return self.execute_command(generator, write_concern) else: return self.execute_legacy(generator, write_concern) diff --git a/pymongo/cluster.py b/pymongo/cluster.py index 5291a25ac..afbc580e7 100644 --- a/pymongo/cluster.py +++ b/pymongo/cluster.py @@ -24,6 +24,8 @@ from pymongo.cluster_description import (updated_cluster_description, ClusterDescription) from pymongo.errors import AutoReconnect from pymongo.server import Server +from pymongo.server_selectors import (secondary_server_selector, + writable_server_selector) class Cluster(object): @@ -49,35 +51,30 @@ class Cluster(object): with self._lock: self._ensure_opened() - def select_servers(self, selector, server_wait_time=None): + def select_servers(self, selector, + server_wait_time=common.SERVER_WAIT_TIME): """Return a list of Servers matching selector, or time out. :Parameters: - `selector`: function that takes a list of Servers and returns a subset of them. - `server_wait_time` (optional): maximum seconds to wait. If not - provided, the initial ClusterSettings' value is used. + provided, the default value common.SERVER_WAIT_TIME is used. Raises exc:`AutoReconnect` after `server_wait_time` if no matching servers are found. """ - if server_wait_time is not None: - wait_time = server_wait_time - else: - wait_time = self._settings.server_wait_time - with self._lock: self._description.check_compatible() - # TODO: use settings.server_wait_time. # TODO: use monotonic time if available. now = time.time() - end_time = now + wait_time + end_time = now + server_wait_time server_descriptions = self._apply_selector(selector) while not server_descriptions: # No suitable servers. - if wait_time == 0 or now > end_time: + if server_wait_time == 0 or now > end_time: # TODO: more error diagnostics. E.g., if state is # ReplicaSet but every server is Unknown, and the host list # is non-empty, and doesn't intersect with settings.seeds, @@ -96,13 +93,14 @@ class Cluster(object): # came after our most recent selector() call, since we've # held the lock until now. self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + self._description.check_compatible() now = time.time() server_descriptions = self._apply_selector(selector) return [self.get_server_by_address(sd.address) for sd in server_descriptions] - def select_server(self, selector, server_wait_time=None): + def select_server(self, selector, server_wait_time=common.SERVER_WAIT_TIME): """Like select_servers, but choose a random server if several match.""" return random.choice(self.select_servers(selector, server_wait_time)) @@ -130,6 +128,33 @@ class Cluster(object): def has_server(self, address): return address in self._servers + def get_primary(self): + """Return primary's address or None.""" + # Implemented here in Cluster instead of MongoClient, so it can lock. + with self._lock: + cluster_type = self._description.cluster_type + if cluster_type != CLUSTER_TYPE.ReplicaSetWithPrimary: + return None + + description = writable_server_selector( + self._description.known_servers)[0] + + return description.address + + def get_secondaries(self): + """Return set of secondary addresses.""" + # Implemented here in Cluster instead of MongoClient, so it can lock. + with self._lock: + cluster_type = self._description.cluster_type + if cluster_type not in (CLUSTER_TYPE.ReplicaSetWithPrimary, + CLUSTER_TYPE.ReplicaSetNoPrimary): + return [] + + descriptions = secondary_server_selector( + self._description.known_servers) + + return set([d.address for d in descriptions]) + def close(self): # TODO. raise NotImplementedError @@ -146,8 +171,22 @@ class Cluster(object): if server: server.pool.reset() + def reset_server(self, address): + with self._lock: + server = self._servers.get(address) + if server: + server.pool.reset() + + # Mark this server Unknown. + self._description = self._description.reset_server(address) + self._update_servers() + def reset(self): - """Reset all pools and rediscover all servers.""" + """Reset all pools and disconnect from all servers. + + The cluster reconnects on demand, or after common.HEARTBEAT_FREQUENCY + seconds. + """ with self._lock: for server in self._servers.values(): server.pool.reset() @@ -156,8 +195,6 @@ class Cluster(object): self._description = self._description.reset() self._update_servers() - self._request_check_all() - @property def description(self): return self._description diff --git a/pymongo/cluster_description.py b/pymongo/cluster_description.py index d7a65ef14..543fdce33 100644 --- a/pymongo/cluster_description.py +++ b/pymongo/cluster_description.py @@ -73,6 +73,33 @@ class ClusterDescription(object): common.MIN_SUPPORTED_WIRE_VERSION, common.MAX_SUPPORTED_WIRE_VERSION)) + def reset_server(self, address): + """A copy of this description, with one server marked Unknown.""" + sds = self.server_descriptions() + + # The default ServerDescription's type is Unknown. + sds[address] = ServerDescription(address) + + if self._cluster_type == CLUSTER_TYPE.ReplicaSetWithPrimary: + cluster_type = _check_has_primary(sds) + else: + cluster_type = self._cluster_type + + return ClusterDescription(cluster_type, sds, self._set_name) + + def reset(self): + """A copy of this description, with all servers marked Unknown.""" + if self._cluster_type == CLUSTER_TYPE.ReplicaSetWithPrimary: + cluster_type = CLUSTER_TYPE.ReplicaSetNoPrimary + else: + cluster_type = self._cluster_type + + # The default ServerDescription's type is Unknown. + sds = dict((address, ServerDescription(address)) + for address in self._server_descriptions) + + return ClusterDescription(cluster_type, sds, self._set_name) + def server_descriptions(self): """Dict of (address, ServerDescription).""" return self._server_descriptions.copy() @@ -165,7 +192,7 @@ def updated_cluster_description(cluster_description, server_description): cluster_type = _check_has_primary(sds) elif server_type == SERVER_TYPE.RSPrimary: - cluster_type = _update_rs_from_primary( + cluster_type, set_name = _update_rs_from_primary( sds, set_name, server_description) elif server_type in ( diff --git a/pymongo/collection.py b/pymongo/collection.py index 666c8c994..422a54610 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -374,11 +374,6 @@ class Collection(common.BaseObject): .. mongodoc:: insert """ client = self.database.connection - # Batch inserts require us to know the connected primary's - # max_bson_size, max_message_size, and max_write_batch_size. - # We have to be connected to the primary to know that. - client._ensure_connected(True) - docs = doc_or_docs return_one = False if isinstance(docs, collections.MutableMapping): @@ -410,7 +405,7 @@ class Collection(common.BaseObject): concern = kwargs or self.write_concern safe = concern.get("w") != 0 - if client.max_wire_version > 1 and safe: + if client._writable_max_wire_version() > 1 and safe: # Insert command command = SON([('insert', self.name), ('ordered', not continue_on_error)]) @@ -539,10 +534,6 @@ class Collection(common.BaseObject): if not isinstance(upsert, bool): raise TypeError("upsert must be an instance of bool") - client = self.database.connection - # Need to connect to know the wire version, and may want to connect - # before applying SON manipulators. - client._ensure_connected(True) if manipulate: document = self.__database._fix_incoming(document, self) @@ -559,7 +550,8 @@ class Collection(common.BaseObject): if first.startswith('$'): check_keys = False - if client.max_wire_version > 1 and safe: + client = self.database.connection + if client._writable_max_wire_version() > 1 and safe: # Update command command = SON([('update', self.name)]) if concern: @@ -686,10 +678,7 @@ class Collection(common.BaseObject): safe = concern.get("w") != 0 client = self.database.connection - - # Need to connect to know the wire version. - client._ensure_connected(True) - if client.max_wire_version > 1 and safe: + if client._writable_max_wire_version() > 1 and safe: # Delete command command = SON([('delete', self.name)]) if concern: diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index d8f16a522..874bc6efb 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -92,7 +92,7 @@ class CommandCursor(object): client = self.__collection.database.connection try: response = client._send_message_with_response( - msg, _connection_to_use=self.__conn_id) + msg, address=self.__conn_id) except AutoReconnect: # Don't try to send kill cursors on another socket # or to another server. It can cause a _pinValue diff --git a/pymongo/common.py b/pymongo/common.py index 74e3f0bff..7b34c7b7a 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -31,7 +31,7 @@ from bson.py3compat import string_type, integer_types # Defaults until we connect to a server and get updated limits. MAX_BSON_SIZE = 16 * (1024 ** 2) -MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE # TODO: remove. +MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE MIN_WIRE_VERSION = 0 MAX_WIRE_VERSION = 0 MAX_WRITE_BATCH_SIZE = 1000 @@ -43,6 +43,12 @@ MAX_SUPPORTED_WIRE_VERSION = 2 # Frequency to call ismaster on servers, in seconds. HEARTBEAT_FREQUENCY = 10 +# How long to wait, in seconds, for a suitable server to be found before +# aborting an operation. For example, if the client attempts an insert +# during a replica set election, SERVER_WAIT_TIME governs the longest it +# is willing to wait for a new primary to be found. +SERVER_WAIT_TIME = 5 + # Spec requires at least 10ms between ismaster calls. MIN_HEARTBEAT_INTERVAL = 0.01 @@ -279,7 +285,6 @@ VALIDATORS = { 'readpreferencetags': validate_read_preference_tags, 'latencythresholdms': validate_positive_float, 'secondaryacceptablelatencyms': validate_positive_float, - 'auto_start_request': validate_boolean, 'authmechanism': validate_auth_mechanism, 'authsource': validate_string, 'gssapiservicename': validate_string, diff --git a/pymongo/cursor.py b/pymongo/cursor.py index cfb8a787c..359c0cb97 100644 --- a/pymongo/cursor.py +++ b/pymongo/cursor.py @@ -864,7 +864,7 @@ class Cursor(object): "exhaust": self.__exhaust, } if self.__connection_id is not None: - kwargs["_connection_to_use"] = self.__connection_id + kwargs["address"] = self.__connection_id try: response = client._send_message_with_response(message, **kwargs) diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 68b124c3e..0b4cb54dc 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -33,15 +33,12 @@ access: """ import datetime -import random import socket -import threading -import time import warnings from bson.py3compat import (integer_types, - itervalues, string_type) +from bson.son import SON from pymongo import (auth, common, database, @@ -51,19 +48,18 @@ from pymongo import (auth, thread_util, uri_parser) from pymongo.client_options import ClientOptions +from pymongo.cluster_description import CLUSTER_TYPE from pymongo.cursor_manager import CursorManager -from pymongo.errors import (AutoReconnect, - ConfigurationError, +from pymongo.cluster import Cluster +from pymongo.errors import (ConfigurationError, ConnectionFailure, - DocumentTooLarge, - DuplicateKeyError, - InvalidURI, - OperationFailure) -from pymongo.member import Member + InvalidURI, AutoReconnect, OperationFailure, + DuplicateKeyError, InvalidOperation) from pymongo.read_preferences import ReadPreference -from pymongo.response import Response, ExhaustResponse - -EMPTY = b"" +from pymongo.server_selectors import (any_server_selector, + writable_server_selector) +from pymongo.server_type import SERVER_TYPE +from pymongo.settings import ClusterSettings def _partition_node(node): @@ -81,31 +77,27 @@ def _partition_node(node): class MongoClient(common.BaseObject): - """Connection to MongoDB. - """ - HOST = "localhost" PORT = 27017 - def __init__(self, host=None, port=None, max_pool_size=100, - document_class=dict, tz_aware=False, _connect=True, - **kwargs): - """Create a new connection to a single MongoDB instance at *host:port*. + def __init__( + self, + host=None, + port=None, + max_pool_size=100, + document_class=dict, + tz_aware=False, + connect=True, + **kwargs): + """Client for a MongoDB instance, a replica set, or a set of mongoses. - The resultant client object has connection-pooling built - in. It also performs auto-reconnection when necessary. If an - operation fails because of a connection error, - :class:`~pymongo.errors.ConnectionFailure` is raised. If - auto-reconnection will be performed, - :class:`~pymongo.errors.AutoReconnect` will be - raised. Application code should handle this exception - (recognizing that the operation failed) and then continue to + The client object is thread-safe and has connection-pooling built in. + If an operation fails because of a network error, + :class:`~pymongo.errors.ConnectionFailure` is raised and the client + reconnects in the background. Application code should handle this + exception (recognizing that the operation failed) and then continue to execute. - Raises :class:`TypeError` if port is not an instance of - ``int``. Raises :class:`~pymongo.errors.ConnectionFailure` if - the connection cannot be made. - The `host` parameter can be a full `mongodb URI `_, in addition to a simple hostname. It can also be a list of hostnames or @@ -133,6 +125,9 @@ class MongoClient(common.BaseObject): :class:`~datetime.datetime` instances returned as values in a document by this :class:`MongoClient` will be timezone aware (otherwise they will be naive) + - `connect` (optional): if ``True`` (the default), immediately + begin connecting to MongoDB in the background. Otherwise connect + on the first operation. | **Other optional parameters can be passed as keyword arguments:** @@ -151,10 +146,6 @@ class MongoClient(common.BaseObject): - `socketKeepAlive`: (boolean) Whether to send periodic keep-alive packets on connected sockets. Defaults to ``False`` (do not send keep-alive packets). - - `auto_start_request`: If ``True``, each thread that accesses - this :class:`MongoClient` has a socket allocated to it for the - thread's lifetime. This ensures consistent reads, even if you - read after an unacknowledged write. Defaults to ``False`` | **Write Concern options:** @@ -182,16 +173,14 @@ class MongoClient(common.BaseObject): | **Replica set keyword arguments for connecting with a replica set - either directly or via a mongos:** - | (ignored by standalone mongod instances) - `replicaSet`: (string) The name of the replica set to connect to. - The driver will verify that the replica set it connects to matches + The driver will verify that all servers it connects to match this name. Implies that the hosts specified are a seed list and the - driver should attempt to find all members of the set. *Ignored by - mongos*. + driver should attempt to find all members of the set. - `read_preference`: The read preference for this client. If - connecting to a secondary then a read preference mode *other* than - PRIMARY is required - otherwise all queries will throw + connecting directly to a secondary then a read preference mode + *other* than PRIMARY is required - otherwise all queries will throw :class:`~pymongo.errors.AutoReconnect` "not master". See :class:`~pymongo.read_preferences.ReadPreference` for all available read preference options. @@ -217,13 +206,35 @@ class MongoClient(common.BaseObject): certificates passed from the other end of the connection. Implies ``ssl=True``. - .. seealso:: :meth:`end_request` - .. mongodoc:: connections - .. versionchanged:: 2.5 - Added additional ssl options - .. versionadded:: 2.4 + .. versionchanged:: 3.0 + :class:`~pymongo.mongo_client.MongoClient` is now the one and only + client class for a standalone server, mongos, or replica set. + It includes the functionality that had been split into + :class:`~pymongo.mongo_client.MongoReplicaSetClient`: it can connect + to a replica set, discover all its members, and monitor the set for + stepdowns, elections, and reconfigs. + + The :class:`~pymongo.mongo_client.MongoClient` constructor no + longer blocks while connecting to the server or servers, and it no + longer raises :class:`~pymongo.errors.ConnectionFailure` if they + are unavailable, nor :class:`~pymongo.errors.ConfigurationError` + if the user's credentials are wrong. Instead, the constructor + returns immediately and launches the connection process on + background threads. + + In PyMongo 2.x, :class:`~pymongo.MongoClient` accepted a list of + standalone MongoDB servers and used the first it could connect to:: + + MongoClient(['host1.com:27017', 'host2.com:27017']) + + A list of multiple standalones is no longer supported; if multiple + servers are listed they must be members of the same replica set, or + mongoses in the same sharded cluster. + + The ``connect`` option is added and ``auto_start_request`` is + removed. """ if host is None: host = self.HOST @@ -257,42 +268,23 @@ class MongoClient(common.BaseObject): if not seeds: raise ConfigurationError("need to specify at least one host") - # Seeds are only used before first connection attempt; nodes are then - # used for any reconnects. Nodes are set to all replica set members - # if connecting to a replica set (besides arbiters), or to all - # available mongoses from the seed list, or to the one standalone - # mongod. - self.__seeds = frozenset(seeds) - self.__nodes = frozenset() - self.__member = None # TODO: Rename to __server. - - # _pool_class and _event_class are for deep customization of PyMongo, - # e.g. Motor. SHOULD NOT BE USED BY THIRD-PARTY DEVELOPERS. - self.__pool_class = kwargs.pop('_pool_class', pool.Pool) - self.__event_class = kwargs.pop('_event_class', threading.Event) + # _pool_class, _monitor_class, and _condition_class are for deep + # customization of PyMongo, e.g. Motor. + pool_class = kwargs.pop('_pool_class', None) + monitor_class = kwargs.pop('_monitor_class', None) + condition_class = kwargs.pop('_condition_class', None) kwargs['max_pool_size'] = max_pool_size opts.update(kwargs) - options = ClientOptions(username, password, dbase, opts) + self.__options = options = ClientOptions( + username, password, dbase, opts) self.__default_database_name = dbase - - self.__cursor_manager = CursorManager(self) - - self.__repl = options.replica_set_name - self.__direct = len(seeds) == 1 and not self.__repl - - self.__pool_opts = options.pool_options - - self.__connecting = False - self.__connecting_lock = threading.Lock() - - self.__future_member = None + self.__cursor_manager = None self.__document_class = document_class self.__tz_aware = common.validate_boolean('tz_aware', tz_aware) - self.__auto_start_request = opts.get('auto_start_request', False) - # cache of existing indexes used by ensure_index ops + # Cache of existing indexes used by ensure_index ops. self.__index_cache = {} self.__auth_credentials = {} @@ -300,24 +292,22 @@ class MongoClient(common.BaseObject): options.uuid_subtype, options.write_concern.document) - if _connect: - try: - self._ensure_connected(True) - except AutoReconnect as e: - # ConnectionFailure makes more sense here than AutoReconnect - raise ConnectionFailure(str(e)) + self.__request_counter = thread_util.Counter() - if username: - credentials = options.credentials - try: - self._cache_credentials( - credentials.source, credentials, _connect) - except OperationFailure as exc: - raise ConfigurationError(str(exc)) + cluster_settings = ClusterSettings( + seeds=seeds, + set_name=options.replica_set_name, + pool_class=pool_class, + pool_options=options.pool_options, + monitor_class=monitor_class, + condition_class=condition_class) + + self._cluster = Cluster(cluster_settings) + if connect: + self._cluster.open() def _cached(self, dbname, coll, index): - """Test if `index` is cached. - """ + """Test if `index` is cached.""" cache = self.__index_cache now = datetime.datetime.utcnow() return (dbname in cache and @@ -326,8 +316,7 @@ class MongoClient(common.BaseObject): now < cache[dbname][coll][index]) def _cache_index(self, database, collection, index, cache_for): - """Add an index to the index cache for ensure_index operations. - """ + """Add an index to the index cache for ensure_index operations.""" now = datetime.datetime.utcnow() expire = datetime.timedelta(seconds=cache_for) + now @@ -368,103 +357,91 @@ class MongoClient(common.BaseObject): if index_name in self.__index_cache[database_name][collection_name]: del self.__index_cache[database_name][collection_name][index_name] - def _cache_credentials(self, source, credentials, connect=True): - """Add credentials to the database authentication cache - for automatic login when a socket is created. If `connect` is True, - verify the credentials on the server first. + def _server_property(self, attr_name, default=None): + """An attribute of the current server's description. + + Returns "default" while there is no current server, primary, or mongos. + + Not threadsafe if used multiple times in a single method, since + the server may change. In such cases, store a local reference to a + ServerDescription first, then use its properties. """ - if source in self.__auth_credentials: - # Nothing to do if we already have these credentials. - if credentials == self.__auth_credentials[source]: - return - raise OperationFailure('Another user is already authenticated ' - 'to this database. You must logout first.') + try: + server = self._cluster.select_server( + writable_server_selector, server_wait_time=0) - if connect: - member = self.__ensure_member() - with self.__socket(member) as sock_info: - # Since __check_auth was called in __socket - # there is no need to call it here. - auth.authenticate(credentials, sock_info, self.__simple_command) - sock_info.authset.add(credentials) - - self.__auth_credentials[source] = credentials - - def _purge_credentials(self, source): - """Purge credentials from the database authentication cache. - """ - if source in self.__auth_credentials: - del self.__auth_credentials[source] - - def __create_pool(self, pair): - return self.__pool_class(pair, self.__pool_opts) - - def __check_auth(self, sock_info): - """Authenticate using cached database credentials. - """ - if self.__auth_credentials or sock_info.authset: - cached = set(itervalues(self.__auth_credentials)) - - authset = sock_info.authset.copy() - - # Logout any credentials that no longer exist in the cache. - for credentials in authset - cached: - self.__simple_command(sock_info, credentials[1], {'logout': 1}) - sock_info.authset.discard(credentials) - - for credentials in cached - authset: - auth.authenticate(credentials, - sock_info, self.__simple_command) - sock_info.authset.add(credentials) - - def __member_property(self, attr_name, default=None): - member = self.__member - if member: - return getattr(member, attr_name) - - return default + return getattr(server.description, attr_name) + except ConnectionFailure: + return default @property def host(self): - """Current connected host. + """Hostname of the standalone, primary, or mongos currently in use. - .. versionchanged:: 1.3 - ``host`` is now a property rather than a method. + .. warning:: An application that accesses :attr:`host` and :attr:`port` + is vulnerable to a race condition, if the client switches to a + new primary or mongos in between. Use :attr:`address` instead. """ - member = self.__member - if member: - return member.host[0] - - return None + address = self.address + return address[0] if address else None @property def port(self): - """Current connected port. + """Port of the standalone, primary, or mongos currently in use. - .. versionchanged:: 1.3 - ``port`` is now a property rather than a method. + .. warning:: An application that accesses :attr:`host` and :attr:`port` + is vulnerable to a race condition, if the client switches to a + new primary or mongos in between. Use :attr:`address` instead. """ - member = self.__member - if member: - return member.host[1] + address = self.address + return address[1] if address else None + + @property + def address(self): + """(host, port) of the current standalone, primary, or mongos, or None. + + .. versionadded:: 3.0 + """ + return self._server_property('address') + + @property + def primary(self): + """The (host, port) of the current primary of the replica set. + + Returns None if there is no primary. + + .. versionadded:: 3.0 + MongoClient gained this property in version 3.0 when + MongoReplicaSetClient's functionality was merged in. + """ + return self._cluster.get_primary() + + @property + def secondaries(self): + """The secondary members known to this client. + + A sequence of (host, port) pairs. + + .. versionadded:: 3.0 + MongoClient gained this property in version 3.0 when + MongoReplicaSetClient's functionality was merged in. + """ + return self._cluster.get_secondaries() - return None @property def is_primary(self): - """If this instance is connected to a standalone, a replica set - primary, or the master of a master-slave set. + """If the current server can accept writes. - .. versionadded:: 2.3 + True if the current server is a standalone, mongos, or the primary of + a replica set. """ - return self.__member_property('is_primary', False) + return self._server_property('is_writable', False) @property def is_mongos(self): - """If this instance is connected to mongos. - - .. versionadded:: 2.3 + """If this client is connected to mongos. """ - return self.__member_property('is_mongos', False) + return self._server_property('server_type') == SERVER_TYPE.Mongos @property def max_pool_size(self): @@ -474,74 +451,60 @@ class MongoClient(common.BaseObject): a socket to be returned to the pool. If ``waitQueueTimeoutMS`` is set, a blocked operation will raise :exc:`~pymongo.errors.ConnectionFailure` after a timeout. By default ``waitQueueTimeoutMS`` is not set. - - .. warning:: SIGNIFICANT BEHAVIOR CHANGE in 2.6. Previously, this - parameter would limit only the idle sockets the pool would hold - onto, not the number of open sockets. The default has also changed - to 100. - - .. versionchanged:: 2.6 - .. versionadded:: 1.11 """ - return self.__pool_opts.max_pool_size + return self.__options.pool_options.max_pool_size @property def nodes(self): - """List of all known nodes. + """List of all connected servers. Nodes are either specified when this instance was created, or discovered through the replica set discovery mechanism. - - .. versionadded:: 1.8 """ - return self.__nodes + description = self._cluster.description + return frozenset(s.address for s in description.known_servers) @property - def auto_start_request(self): - """Is auto_start_request enabled? - """ - return self.__auto_start_request + def document_class(self): + """Default class to use for documents returned from this client. - def get_document_class(self): + .. versionchanged:: 3.0 + Now read-only. + """ return self.__document_class - def set_document_class(self, klass): - self.__document_class = klass + def get_document_class(self): + """Default class to use for documents returned from this client. - document_class = property(get_document_class, set_document_class, - doc="""Default class to use for documents - returned from this client. + Deprecated; use the document_class property instead. + """ + warnings.warn('get_document_class() is deprecated, use the' + ' document_class property', + DeprecationWarning, stacklevel=2) - .. versionadded:: 1.7 - """) + return self.__document_class @property def tz_aware(self): """Does this client return timezone-aware datetimes? - - .. versionadded:: 1.8 """ return self.__tz_aware @property def max_bson_size(self): - """Return the maximum size BSON object the connected server - accepts in bytes. Defaults to 16MB if not connected to a - server. + """The largest BSON object the connected server accepts in bytes. - .. versionadded:: 1.10 + Defaults to 16MB if not connected to a server. """ - return self.__member_property('max_bson_size', common.MAX_BSON_SIZE) + return self._server_property('max_bson_size', common.MAX_BSON_SIZE) @property def max_message_size(self): - """Return the maximum message size the connected server - accepts in bytes. Defaults to 32MB if not connected to a - server. + """The largest message the connected server accepts in bytes. - .. versionadded:: 2.6 + Defaults to 32MB if not connected to a server. """ - return self.__member_property( + return self._server_property( 'max_message_size', common.MAX_MESSAGE_SIZE) @property @@ -549,10 +512,8 @@ class MongoClient(common.BaseObject): """The minWireVersion reported by the server. Returns ``0`` when connected to server versions prior to MongoDB 2.6. - - .. versionadded:: 2.7 """ - return self.__member_property( + return self._server_property( 'min_wire_version', common.MIN_WIRE_VERSION) @property @@ -560,10 +521,8 @@ class MongoClient(common.BaseObject): """The maxWireVersion reported by the server. Returns ``0`` when connected to server versions prior to MongoDB 2.6. - - .. versionadded:: 2.7 """ - return self.__member_property( + return self._server_property( 'max_wire_version', common.MAX_WIRE_VERSION) @property @@ -572,290 +531,44 @@ class MongoClient(common.BaseObject): Returns a default value when connected to server versions prior to MongoDB 2.6. - - .. versionadded:: 2.7 """ - return self.__member_property( + return self._server_property( 'max_write_batch_size', common.MAX_WRITE_BATCH_SIZE) - def __simple_command(self, sock_info, dbname, spec): - """Send a command to the server. + def _writable_max_wire_version(self): + """Connect to a writable server and get its max wire protocol version. + + Can raise ConnectionFailure. """ - rqst_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec) - start = time.time() - sock_info.send_message(msg) - response = sock_info.receive_message(1, rqst_id) + cluster = self._get_cluster() # Starts monitors if necessary. + server = cluster.select_server(writable_server_selector) + return server.description.max_wire_version - end = time.time() - response = helpers._unpack_response(response)['data'][0] - msg = "command %r failed: %%s" % spec - helpers._check_command_response(response, None, msg) - return response, end - start - - def __try_node(self, node): - """Try to connect to this node and see if it works for our connection - type. Returns a Member and set of hosts (including this one). Doesn't - modify state. - - :Parameters: - - `node`: The (host, port) pair to try. + def _is_writable(self): + """Attempt to connect to a writable server, or return False. """ - # Call 'ismaster' directly so we can get a response time. - connection_pool = self.__create_pool(node) - with connection_pool.get_socket() as sock_info: - response, res_time = self.__simple_command(sock_info, - 'admin', - {'ismaster': 1}) - - member = Member( - node, - connection_pool, - response, - res_time) - - nodes = frozenset([node]) - - # Replica Set? - if not self.__direct: - # Check that this host is part of the given replica set. - if self.__repl and member.set_name != self.__repl: - raise ConfigurationError("%s:%d is not a member of " - "replica set %s" - % (node[0], node[1], self.__repl)) - - if "hosts" in response: - nodes = frozenset([ - _partition_node(h) for h in response["hosts"]]) - - if member.is_primary: - return member, nodes - - elif "primary" in response: - # Shortcut: a secondary usually tells us who the primary is. - candidate = _partition_node(response["primary"]) - return self.__try_node(candidate) - - # Explain why we aren't using this connection. - raise AutoReconnect('%s:%d is not primary or master' % node) - - # Direct connection - if member.is_arbiter and not self.__direct: - raise ConfigurationError("%s:%d is an arbiter" % node) - - return member, nodes - - def __pick_nearest(self, candidates): - """Return the 'nearest' Member instance based on response time. - - Doesn't modify state. - """ - latency = self.read_preference.latency_threshold_ms - # Only used for mongos high availability, ping_time is in seconds. - fastest = min([ - member.ping_time for member in candidates]) - - near_candidates = [ - member for member in candidates - if member.ping_time - fastest < latency / 1000.0] - - return random.choice(near_candidates) - - def __ensure_member(self): - """Connect and return a Member instance, or raise AutoReconnect.""" - # If `connecting` is False, no thread is in __find_node(), - # and `future_member` is resolved. `member` may be None if the - # last __find_node() attempt failed, otherwise it is in `nodes`. - # - # If `connecting` is True, a thread is in __find_node(), - # `member` is None, and `future_member` is pending. - # - # To violate these invariants temporarily, acquire the lock. - # Note that disconnect() interacts with this method. - self.__connecting_lock.acquire() - if self.__member: - member = self.__member - self.__connecting_lock.release() - return member - - elif self.__connecting: - # A thread is in __find_node(). Wait. - future = self.__future_member - self.__connecting_lock.release() - return future.result() - - else: - self.__connecting = True - future = self.__future_member = thread_util.Future( - self.__event_class) - - self.__connecting_lock.release() - - member = None - nodes = None - exc = None - - try: - try: - member, nodes = self.__find_node() - return member - except Exception as e: - exc = e - raise - finally: - # We're either returning a Member or raising an error. - # Propagate either outcome to waiting threads. - self.__connecting_lock.acquire() - self.__member = member - self.__connecting = False - - # If we discovered a set of nodes, use them from now on; - # otherwise we're raising an error. Stick with the last - # known good set of nodes. - if nodes: - self.__nodes = nodes - - if member: - # Unblock waiting threads. - future.set_result(member) - else: - # Raise exception in waiting threads. - future.set_exception(exc) - - self.__connecting_lock.release() - - def __find_node(self): - """Find a server suitable for our connection type. - - Returns a Member and a set of nodes. Doesn't modify state. - - If only one host was supplied to __init__ see if we can connect - to it. Don't check if the host is a master/primary so we can make - a direct connection to read from a secondary or send commands to - an arbiter. - - If more than one host was supplied treat them as a seed list for - connecting to a replica set or to support high availability for - mongos. If connecting to a replica set try to find the primary, - and set `nodes` to list of all members. - - If a mongos seed list was provided find the "nearest" mongos and - return it, setting `nodes` to all mongoses in the seed list that - are up. - - Otherwise we iterate through the list trying to find a host we can - send write operations to. - """ - assert not self.__member, \ - "__find_node unexpectedly running with a non-null Member" - - errors = [] - mongos_candidates = [] - candidates = self.__nodes or self.__seeds - chosen_member = None - discovered_nodes = None - - for candidate in candidates: - try: - member, nodes = self.__try_node(candidate) - if member.is_mongos and not self.__direct: - mongos_candidates.append(member) - - # We intend to find all the mongoses; keep trying nodes. - continue - elif len(mongos_candidates): - raise ConfigurationError("Seed list cannot contain a mix " - "of mongod and mongos instances.") - - # We've found a suitable node. - chosen_member = member - discovered_nodes = nodes - break - except (OperationFailure, ConfigurationError, ValueError): - # The server is available but something failed, e.g. auth, - # wrong replica set name, or incompatible wire protocol. - raise - except Exception as why: - errors.append(str(why)) - - if len(mongos_candidates): - # If we have a mongos seed list, pick the "nearest" member. - chosen_member = self.__pick_nearest(mongos_candidates) - mongoses = frozenset(m.host for m in mongos_candidates) - - # The first time, __nodes is empty and mongoses becomes nodes. - return chosen_member, self.__nodes or mongoses - - if not chosen_member: - # Couldn't find a suitable host. - raise AutoReconnect(', '.join(errors)) - - return chosen_member, discovered_nodes - - def __socket(self, member): - """Get a SocketInfo or raise AutoReconnect. - - Calls disconnect() on error. - """ - connection_pool = member.pool + cluster = self._get_cluster() # Starts monitors if necessary. try: - if self.auto_start_request and not connection_pool.in_request(): - connection_pool.start_request() - - sock_info = connection_pool.get_socket() - except socket.error as why: - self.disconnect() - - # Check if a unix domain socket - host, port = member.host - if host.endswith('.sock'): - host_details = "%s:" % host - else: - host_details = "%s:%d:" % (host, port) - raise AutoReconnect("could not connect to " - "%s %s" % (host_details, str(why))) - try: - self.__check_auth(sock_info) - except OperationFailure: - connection_pool.maybe_return_socket(sock_info) - raise - return sock_info - - def _ensure_connected(self, sync=False): - """Ensure this client instance is connected to a mongod/s. - """ - self.__ensure_member() + cluster.select_server(writable_server_selector) + return True + except ConnectionFailure: + return False def disconnect(self): """Disconnect from MongoDB. Disconnecting will close all underlying sockets in the connection - pool. If this instance is used again it will be automatically - re-opened. Care should be taken to make sure that :meth:`disconnect` - is not called in the middle of a sequence of operations in which - ordering is important. This could lead to unexpected results. - - .. seealso:: :meth:`end_request` - .. versionadded:: 1.3 + pools. If this instance is used again it will be automatically + re-opened. """ - self.__connecting_lock.acquire() - member, self.__member = self.__member, None - self.__connecting_lock.release() - - # Close sockets promptly. - if member: - member.pool.reset() + self._cluster.reset() def close(self): """Alias for :meth:`disconnect` Disconnecting will close all underlying sockets in the connection - pool. If this instance is used again it will be automatically - re-opened. Care should be taken to make sure that :meth:`disconnect` - is not called in the middle of a sequence of operations in which - ordering is important. This could lead to unexpected results. - - .. seealso:: :meth:`end_request` - .. versionadded:: 2.1 + pools. If this instance is used again it will be automatically + re-opened. """ self.disconnect() @@ -863,12 +576,11 @@ class MongoClient(common.BaseObject): """Return ``False`` if there has been an error communicating with the server, else ``True``. - This method attempts to check the status of the server with minimal I/O. - The current thread retrieves a socket from the pool (its - request socket if it's in a request, or a random idle socket if it's not - in a request) and checks whether calling `select`_ on it raises an - error. If there are currently no idle sockets, :meth:`alive` will - attempt to actually connect to the server. + This method attempts to check the status of the server (the standalone, + replica set primary, or the mongos currently in use) with minimal I/O. + Retrieves a socket from the pool and checks whether calling `select`_ + on it raises an error. If there are currently no idle sockets, + :meth:`alive` attempts to actually connect to the server. A more certain way to determine server availability is:: @@ -879,17 +591,17 @@ class MongoClient(common.BaseObject): # In the common case, a socket is available and was used recently, so # calling select() on it is a reasonable attempt to see if the OS has # reported an error. - self.__connecting_lock.acquire() - member = self.__member - self.__connecting_lock.release() - if not member: + try: + # TODO: Mongos pinning. + server = self._cluster.select_server( + writable_server_selector, + server_wait_time=0) + + with server.pool.get_socket() as sock_info: + return not pool._closed(sock_info.sock) + + except (socket.error, ConnectionFailure): return False - else: - try: - with member.pool.get_socket() as sock_info: - return not pool._closed(sock_info.sock) - except (socket.error, ConnectionFailure): - return False def set_cursor_manager(self, manager_class): """Set this client's cursor manager. @@ -913,6 +625,15 @@ class MongoClient(common.BaseObject): self.__cursor_manager = manager + def _get_cluster(self): + """Get the internal :class:`~pymongo.cluster.Cluster` object. + + If this client was created with "connect=False", calling _get_cluster + launches the connection process in the background. + """ + self._cluster.open() + return self._cluster + def __check_response_to_last_error(self, response, is_command): """Check a response to a lastError message for errors. @@ -926,6 +647,8 @@ class MongoClient(common.BaseObject): assert response["number_returned"] == 1 result = response["data"][0] + # Raises AutoReconnect if "not master" or "node is recovering", + # OperationFailure for all other errors. helpers._check_command_response(result, self.disconnect) # write commands - skip getLastError checking @@ -954,30 +677,10 @@ class MongoClient(common.BaseObject): raise DuplicateKeyError(details["err"], code, result) raise OperationFailure(details["err"], code, result) - def __check_bson_size(self, message): - """Make sure the message doesn't include BSON documents larger - than the connected server will accept. - - :Parameters: - - `message`: message to check - """ - if len(message) == 3: - (request_id, data, max_doc_size) = message - if max_doc_size > self.max_bson_size: - raise DocumentTooLarge("BSON document too large (%d bytes)" - " - the connected server supports" - " BSON document sizes up to %d" - " bytes." % - (max_doc_size, self.max_bson_size)) - return (request_id, data) - else: - # get_more and kill_cursors messages - # don't include BSON documents. - return message - - def _send_message(self, message, - with_last_error=False, command=False, check_primary=True): - """Say something to Mongo. + def _send_message( + self, message, with_last_error=False, command=False, + check_primary=True, address=None): + """Send a message to MongoDB, optionally returning response as a dict. Raises ConnectionFailure if the message cannot be sent. Raises OperationFailure if `with_last_error` is ``True`` and the @@ -986,84 +689,106 @@ class MongoClient(common.BaseObject): is ``False``. :Parameters: - - `message`: message to send - - `with_last_error`: check getLastError status after sending the - message - - `check_primary`: don't try to write to a non-primary; see - kill_cursors for an exception to this rule + - `message`: (request_id, data). + - `with_last_error` (optional): check getLastError status after + sending the message. + - `check_primary` (optional): don't try to write to a non-primary; + see kill_cursors for an exception to this rule. + - `command` (optional): True for a write command. + - `address` (optional): Optional address when sending a getMore or + killCursors to a specific server. """ - member = self.__ensure_member() - if check_primary and not with_last_error and not self.is_primary: - # The write won't succeed, bail as if we'd done a getLastError + cluster = self._get_cluster() + if address: + assert not check_primary, "Can't use check_primary with address" + server = cluster.get_server_by_address(address) + if not server: + raise AutoReconnect('server %s:%d no longer available' + % address) + else: + server = cluster.select_server(writable_server_selector) + + is_writable = server.description.is_writable + if check_primary and not with_last_error and not is_writable: + # When directly connected to a single server, we select it even + # if it isn't writable. The write won't succeed, so bail as if + # we'd done a getLastError. raise AutoReconnect("not master") - with self.__socket(member) as sock_info: - try: - (request_id, data) = self.__check_bson_size(message) - sock_info.send_message(data) - # Safe mode. We pack the message together with a lastError - # message and send both. We then get the response (to the - # lastError) and raise OperationFailure if it is an error - # response. - rv = None - if with_last_error: - response = sock_info.receive_message(1, request_id) - rv = self.__check_response_to_last_error(response, command) + if self.in_request() and not server.in_request(): + server.start_request() - return rv - except OperationFailure: - raise - except (ConnectionFailure, socket.error) as e: - self.disconnect() - raise AutoReconnect(str(e)) + if with_last_error or command: + response = self._reset_on_error( + server, + server.send_message_with_response, message) - def __send_and_receive(self, message, sock_info): - """Send a message on the given socket and return the response data. - """ - (request_id, data) = self.__check_bson_size(message) - sock_info.send_message(data) - return sock_info.receive_message(1, request_id) + # Disconnects and raises ConnectionFailure if "not master" + # or "recovering". + return self.__check_response_to_last_error(response.data, + command) + else: + # Send the message. No response. + self._reset_on_error(server, server.send_message, message) - def _send_message_with_response(self, message, **kwargs): - """Send a message to MongoDB and return a Response object. + def _send_message_with_response( + self, message, read_preference=None, exhaust=False, address=None): + """Send a message to MongoDB and return a Response. :Parameters: - - `message`: (request_id, data) pair making up the message to send + - `message`: (request_id, data, max_doc_size) or (request_id, data). + - `read_preference` (optional): A ReadPreference. + - `exhaust` (optional): If True, the socket used stays checked out. + It is returned along with its Pool in the Response. """ - member = self.__ensure_member() - sock_info = self.__socket(member) + cluster = self._get_cluster() + if address: + server = cluster.get_server_by_address(address) + if not server: + raise AutoReconnect('server %s:%d no longer available' + % address) + else: + if read_preference: + selector = read_preference.select_servers + else: + selector = writable_server_selector - # For exhaust queries, tell the socket not to check itself back in. - exhaust = kwargs.get('exhaust') - sock_info.exhaust(exhaust) - with sock_info: - try: - data = self.__send_and_receive(message, sock_info) - if exhaust: - return ExhaustResponse( - data=data, - address=member.host, - socket_info=sock_info, - pool=member.pool) - else: - return Response( - data=data, - address=member.host) + server = cluster.select_server(selector) - except (ConnectionFailure, socket.error) as e: + if self.in_request() and not server.in_request(): + server.start_request() + + return self._reset_on_error( + server, + server.send_message_with_response, message, exhaust) + + def _reset_on_error(self, server, fn, *args, **kwargs): + """Execute an operation. Reset the pool on network error. + + Returns fn()'s return value on success. On error, clears the server's + pool and marks the server Unknown or, if the server is the primary, + resets all pools and servers. + + Re-raises any exception thrown by fn(). + """ + try: + return fn(*args, **kwargs) + except ConnectionFailure: + if server.description.is_writable: self.disconnect() - raise AutoReconnect(str(e)) + else: + self._cluster.reset_server(server.address) + + raise def start_request(self): """Ensure the current thread always uses the same socket until it calls :meth:`end_request`. This ensures consistent reads, even if you read after an unacknowledged write. - In Python 2.6 and above, or in Python 2.5 with - "from __future__ import with_statement", :meth:`start_request` can be - used as a context manager: + :meth:`start_request` can be used as a context manager: - >>> client = pymongo.MongoClient(auto_start_request=False) + >>> client = pymongo.MongoClient() >>> db = client.test >>> _id = db.test_collection.insert({}) >>> with client.start_request(): @@ -1075,59 +800,65 @@ class MongoClient(common.BaseObject): If a thread calls start_request multiple times, an equal number of calls to :meth:`end_request` is required to end the request. - - .. versionchanged:: 2.4 - Now counts the number of calls to start_request and doesn't end - request until an equal number of calls to end_request. - - .. versionadded:: 2.2 - The :class:`~pymongo.pool.Request` return value. - :meth:`start_request` previously returned None """ - member = self.__ensure_member() - member.pool.start_request() + # TODO: Remove implicit threadlocal requests, use explicit requests. + # TODO: Start / end replica set member pinning? + if 1 == self.__request_counter.inc(): + # Start requests on all existing pools. New pools created while + # this thread is in a request will have start_request() called + # lazily. These greedy calls are to make PyMongo 2.x's request + # tests pass. + try: + servers = self._cluster.select_servers(any_server_selector, + server_wait_time=0) + + for s in servers: + s.start_request() + except AutoReconnect: + # No servers available. + pass + return pool.Request(self) def in_request(self): """True if this thread is in a request, meaning it has a socket reserved for its exclusive use. """ - member = self.__member # Don't try to connect if disconnected. - return member and member.pool.in_request() + return bool(self.__request_counter.get()) def end_request(self): """Undo :meth:`start_request`. If :meth:`end_request` is called as many times as :meth:`start_request`, the request is over and this thread's connection returns to the pool. Extra calls to :meth:`end_request` have no effect. - - Ending a request allows the :class:`~socket.socket` that has - been reserved for this thread by :meth:`start_request` to be returned to - the pool. Other threads will then be able to re-use that - :class:`~socket.socket`. If your application uses many threads, or has - long-running threads that infrequently perform MongoDB operations, then - judicious use of this method can lead to performance gains. Care should - be taken, however, to make sure that :meth:`end_request` is not called - in the middle of a sequence of operations in which ordering is - important. This could lead to unexpected results. """ - member = self.__member # Don't try to connect if disconnected. - if member: - member.pool.end_request() + if 0 == self.__request_counter.dec(): + try: + servers = self._cluster.select_servers(any_server_selector, + server_wait_time=0) + + for s in servers: + s.end_request() + except ConnectionFailure: + # No servers, we've disconnected. + pass def __eq__(self, other): if isinstance(other, self.__class__): - return self.host == other.host and self.port == other.port + return self.address == other.address return NotImplemented def __ne__(self, other): return not self == other def __repr__(self): - if len(self.__nodes) == 1: - return "MongoClient(%r, %r)" % (self.host, self.port) + server_descriptions = self._cluster.description.server_descriptions() + if len(server_descriptions) == 1: + description, = server_descriptions.values() + return "MongoClient(%r, %r)" % description.address else: - return "MongoClient(%r)" % ["%s:%d" % n for n in self.__nodes] + return "MongoClient(%r)" % [ + "%s:%d" % address for address in server_descriptions] def __getattr__(self, name): """Get a database by name. @@ -1151,7 +882,7 @@ class MongoClient(common.BaseObject): """ return self.__getattr__(name) - def close_cursor(self, cursor_id, address): + def close_cursor(self, cursor_id, address=None): """Close a single database cursor. Raises :class:`TypeError` if `cursor_id` is not an instance of @@ -1160,15 +891,36 @@ class MongoClient(common.BaseObject): :Parameters: - `cursor_id`: id of cursor to close - - `address`: (host, port) pair of the cursor's server + - `address` (optional): (host, port) pair of the cursor's server + + .. versionchanged:: 3.0 + Added ``address`` parameter. """ if not isinstance(cursor_id, integer_types): raise TypeError("cursor_id must be an instance of (int, long)") - self.__cursor_manager.close(cursor_id) + # TODO: update this, pass address to cursor_manager.close(). + # PyMongo 2.x introduced a configurable CursorManager which sends + # OP_KILLCURSORS to the server. The API doesn't handle a multi-server + # cluster, where we must pass the address of the server that receives + # the message. Support CursorManager for backwards compatibility, but + # only for single servers. + if self.__cursor_manager: + cluster_type = self._cluster.description.cluster_type + if cluster_type not in (CLUSTER_TYPE.Single, CLUSTER_TYPE.Sharded): + raise InvalidOperation( + "Can't use custom CursorManager with cluster type %s" % + CLUSTER_TYPE._fields[cluster_type]) + + self.__cursor_manager.close(cursor_id) + else: + return self._send_message( + message.kill_cursors([cursor_id]), + check_primary=False, + address=address) def kill_cursors(self, cursor_ids): - """Send a kill cursors message with the given ids. + """Send a kill cursors message with the given ids to the primary. Raises :class:`TypeError` if `cursor_ids` is not an instance of ``list``. @@ -1182,23 +934,22 @@ class MongoClient(common.BaseObject): message.kill_cursors(cursor_ids), check_primary=False) def server_info(self): - """Get information about the MongoDB server we're connected to. - """ + """Get information about the MongoDB server we're connected to.""" return self.admin.command("buildinfo", read_preference=ReadPreference.PRIMARY) def database_names(self): - """Get a list of the names of all databases on the connected server. - """ + """Get a list of the names of all databases on the connected server.""" return [db["name"] for db in self.admin.command("listDatabases", - read_preference=ReadPreference.PRIMARY)["databases"]] + read_preference=ReadPreference.PRIMARY)["databases"]] def drop_database(self, name_or_database): """Drop a database. Raises :class:`TypeError` if `name_or_database` is not an instance of - :class:`basestring` (:class:`str` in python 3) or Database. + :class:`basestring` (:class:`str` in python 3) or + :class:`~pymongo.database.Database`. :Parameters: - `name_or_database`: the name of a database to drop, or a @@ -1238,11 +989,6 @@ class MongoClient(common.BaseObject): - `from_host` (optional): host name to copy from - `username` (optional): username for source database - `password` (optional): password for source database - - .. note:: Specifying `username` and `password` requires server - version **>= 1.3.3+**. - - .. versionadded:: 1.5 """ if not isinstance(from_name, string_type): raise TypeError("from_name must be an " @@ -1253,27 +999,31 @@ class MongoClient(common.BaseObject): database._check_name(to_name) - command = {"fromdb": from_name, "todb": to_name} + command = SON([ + ("copydb", 1), ("fromdb", from_name), ("todb", to_name)]) if from_host is not None: command["fromhost"] = from_host - try: - self.start_request() + # _get_cluster() starts connecting, if we initialized with connect=False. + server = self._get_cluster().select_server( + writable_server_selector) + if self.in_request() and not server.in_request(): + server.start_request() + + with server.pool.get_socket() as sock: if username is not None: - nonce = self.admin.command("copydbgetnonce", - read_preference=ReadPreference.PRIMARY, - fromhost=from_host)["nonce"] + get_nonce_cmd = SON([("copydbgetnonce", 1), + ("fromhost", from_host)]) + + nonce = sock.command("admin", get_nonce_cmd)["nonce"] + command["username"] = username command["nonce"] = nonce command["key"] = auth._auth_key(nonce, username, password) - return self.admin.command("copydb", - read_preference=ReadPreference.PRIMARY, - **command) - finally: - self.end_request() + return sock.command("admin", command) def get_default_database(self): """Get the database named in the MongoDB connection URI. @@ -1296,8 +1046,6 @@ class MongoClient(common.BaseObject): """Is this server locked? While locked, all write operations are blocked, although read operations may still be allowed. Use :meth:`unlock` to unlock. - - .. versionadded:: 2.0 """ ops = self.admin.current_op() return bool(ops.get('fsyncLock', 0)) @@ -1317,16 +1065,12 @@ class MongoClient(common.BaseObject): .. warning:: MongoDB does not support the `async` option on Windows and will raise an exception on that platform. - - .. versionadded:: 2.0 """ self.admin.command("fsync", read_preference=ReadPreference.PRIMARY, **kwargs) def unlock(self): """Unlock a previously locked server. - - .. versionadded:: 2.0 """ self.admin['$cmd'].sys.unlock.find_one() diff --git a/pymongo/mongo_client_new.py b/pymongo/mongo_client_new.py deleted file mode 100644 index b7144264b..000000000 --- a/pymongo/mongo_client_new.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2009-2014 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""TODO: help string.""" - -import random -import threading - -from bson.py3compat import (string_type) -from pymongo import (database, - monitor, - pool, - uri_parser, ReadPreference) -from pymongo.cluster import Cluster -from pymongo.cluster_description import CLUSTER_TYPE -from pymongo.errors import (ConfigurationError) -from pymongo.server_selectors import (writable_server_selector, - secondary_server_selector) -from pymongo.settings import ClusterSettings - - -class MongoClientNew(object): - """Connection to one or more MongoDB servers. - """ - - def __init__( - self, - host='localhost', - port=27017, - replicaSet=None, - ): - """TODO: docstring""" - if isinstance(host, string_type): - host = [host] - - if not isinstance(port, int): - raise TypeError("port must be an instance of int") - - seeds = set() - - for entity in host: - seeds.update(uri_parser.split_hosts(entity, port)) - if not seeds: - raise ConfigurationError("need to specify at least one host") - - cluster_settings = ClusterSettings( - seeds=seeds, - set_name=replicaSet, - pool_class=pool.Pool, - monitor_class=monitor.Monitor, - condition_class=threading.Condition) - - # TODO: parse URI, socket timeouts, ssl args, auth, - # pool_class, document_class, pool options, condition_class, - # default database. - - self._cluster = Cluster(cluster_settings) - self._cluster.open() - - # TODO: these are here to fake the old MongoClient's API for the sake - # of existing Database, Collection, and Cursor. - self.read_preference = ReadPreference.PRIMARY - self.uuid_subtype = 4 - self.write_concern = {} - self.document_class = dict - self.tz_aware = False - - @property - def is_mongos(self): - return self._cluster.description.cluster_type == CLUSTER_TYPE.Mongos - - # TODO: Remove. Database, Collection, etc. should use Cluster. - def _send_message_with_response( - self, - msg, - read_preference=ReadPreference.PRIMARY, - exhaust=False): - """Send a message to MongoDB and return a Response object. - - :Parameters: - - `msg`: (request_id, data) pair making up the message to send. - - `read_preference`: A ReadPreference. - - `exhaust`: True for an exhaust cursor's initial query. - """ - request_id, data, max_doc_size = msg - - # TODO: real read preferences. - if read_preference == ReadPreference.PRIMARY: - servers = self._cluster.select_servers(writable_server_selector) - else: - servers = self._cluster.select_servers(secondary_server_selector) - - server = random.choice(servers) - return server.send_message_with_response(data, request_id, exhaust) - - def __getattr__(self, name): - """Get a database by name. - - Raises :class:`~pymongo.errors.InvalidName` if an invalid - database name is used. - - :Parameters: - - `name`: the name of the database to get - """ - return database.Database(self, name) - - def __getitem__(self, name): - """Get a database by name. - - Raises :class:`~pymongo.errors.InvalidName` if an invalid - database name is used. - - :Parameters: - - `name`: the name of the database to get - """ - return self.__getattr__(name) diff --git a/pymongo/monitor.py b/pymongo/monitor.py index 8e6081723..178a3d45a 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -79,7 +79,7 @@ class Monitor(threading.Thread): self.close() else: start = time.time() # TODO: monotonic. - self._event.wait(self._settings.heartbeat_frequency) + self._event.wait(common.HEARTBEAT_FREQUENCY) self._event.clear() wait_time = time.time() - start if wait_time < common.MIN_HEARTBEAT_INTERVAL: diff --git a/pymongo/pool.py b/pymongo/pool.py index 07803bfb4..546f9056b 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -20,9 +20,15 @@ import time import threading import weakref -from pymongo import thread_util +from bson.py3compat import u +from pymongo import helpers, message, thread_util from pymongo.errors import ConnectionFailure +# If the first getaddrinfo call of this interpreter's life is on a thread, +# while the main thread holds the import lock, getaddrinfo deadlocks trying +# to import the IDNA codec. Import it here, where presumably we're on the +# main thread, to avoid the deadlock. See PYTHON-607. +u('foo').encode('idna') try: from ssl import match_hostname, CertificateError @@ -152,8 +158,24 @@ class SocketInfo(object): # Are we being used by an exhaust cursor? self._exhaust = False + def command(self, dbname, spec): + """Execute a command over the socket, or raise socket.error. + + :Parameters: + - `dbname`: name of the database on which to run the command + - `spec`: a command document as a dict, SON, or mapping object + """ + # TODO: command should already be encoded. + request_id, msg, _ = message.query(0, dbname + '.$cmd', 0, -1, spec) + self.send_message(msg) + response = self.receive_message(1, request_id) + unpacked = helpers._unpack_response(response)['data'][0] + msg = "command %r failed: %%s" % spec + helpers._check_command_response(unpacked, None, msg) + return unpacked + def send_message(self, message): - """Send a raw BSON message. + """Send a raw BSON message or raise socket.error. If a network exception is raised, the socket is closed. """ @@ -164,7 +186,7 @@ class SocketInfo(object): raise def receive_message(self, operation, request_id): - """Receive a raw BSON message. + """Receive a raw BSON message or raise socket.error. If any exception is raised, the socket is closed. """ diff --git a/pymongo/server.py b/pymongo/server.py index b238875a5..d00584fa9 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -14,6 +14,9 @@ """Communicate with one MongoDB server in a cluster.""" +import socket + +from pymongo.errors import AutoReconnect, DocumentTooLarge from pymongo.server_type import SERVER_TYPE from pymongo.response import Response, ExhaustResponse @@ -38,29 +41,61 @@ class Server(object): """Check the server's state soon.""" self._monitor.request_check() - def send_message_with_response(self, message, request_id, exhaust): - """Send a message to MongoDB and return a Response object. + def send_message(self, message): + """Send an unacknowledged message to MongoDB. + + Can raise ConnectionFailure. :Parameters: - - `message`: BSON bytes. + - `message`: (request_id, data). - `request_id`: A number. - - `exhaust`: If True, the socket used stays checked out. It is - returned along with its Pool in the Response. """ - with self._pool.get_socket() as sock_info: - sock_info.exhaust(exhaust) - sock_info.send_message(message) - response_data = sock_info.receive_message(1, request_id) - if exhaust: - return ExhaustResponse( - data=response_data, - address=self._description.address, - socket_info=sock_info, - pool=self._pool) - else: - return Response( - data=response_data, - address=self._description.address) + request_id, data = self._check_bson_size(message) + try: + with self._pool.get_socket() as sock_info: + sock_info.send_message(data) + except socket.error as exc: + raise AutoReconnect(str(exc)) + + def send_message_with_response(self, message, exhaust=False): + """Send a message to MongoDB and return a Response object. + + Can raise ConnectionFailure. + + :Parameters: + - `message`: (request_id, data, max_doc_size) or (request_id, data). + - `request_id`: A number. + - `exhaust` (optional): If True, the socket used stays checked out. + It is returned along with its Pool in the Response. + """ + request_id, data = self._check_bson_size(message) + try: + with self._pool.get_socket() as sock_info: + sock_info.exhaust(exhaust) + sock_info.send_message(data) + response_data = sock_info.receive_message(1, request_id) + if exhaust: + return ExhaustResponse( + data=response_data, + address=self._description.address, + socket_info=sock_info, + pool=self._pool) + else: + return Response( + data=response_data, + address=self._description.address) + except socket.error as exc: + raise AutoReconnect(str(exc)) + + def start_request(self): + # TODO: Remove implicit threadlocal requests, use explicit requests. + self.pool.start_request() + + def in_request(self): + return self.pool.in_request() + + def end_request(self): + self.pool.end_request() @property def description(self): @@ -75,6 +110,27 @@ class Server(object): def pool(self): return self._pool + def _check_bson_size(self, message): + """Make sure the message doesn't include BSON documents larger + than the server will accept. + + :Parameters: + - `message`: (request_id, data, max_doc_size) or (request_id, data) + + Returns request_id, data. + """ + if len(message) == 3: + request_id, data, max_doc_size = message + if max_doc_size > self.description.max_bson_size: + raise DocumentTooLarge( + "BSON document too large (%d bytes) - the connected server" + "supports BSON document sizes up to %d bytes." % + (max_doc_size, self.description.max_bson_size)) + return request_id, data + else: + # get_more and kill_cursors messages don't include BSON documents. + return message + def __str__(self): d = self._description return '' % ( diff --git a/pymongo/settings.py b/pymongo/settings.py index bcb53fadf..17895579f 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -27,12 +27,10 @@ class ClusterSettings(object): self, seeds=None, set_name=None, - server_wait_time=None, pool_class=None, pool_options=None, monitor_class=monitor.Monitor, condition_class=threading.Condition, - heartbeat_frequency=common.HEARTBEAT_FREQUENCY, ): """Represent MongoClient's configuration. @@ -40,12 +38,10 @@ class ClusterSettings(object): """ self._seeds = seeds or [('localhost', 27017)] self._set_name = set_name - self._server_wait_time = server_wait_time or 5 # Seconds. self._pool_class = pool_class or pool.Pool self._pool_options = pool_options or PoolOptions() self._monitor_class = monitor_class or monitor.Monitor self._condition_class = condition_class or threading.Condition - self._heartbeat_frequency = heartbeat_frequency self._direct = (len(self._seeds) == 1 and not set_name) @property @@ -57,10 +53,6 @@ class ClusterSettings(object): def set_name(self): return self._set_name - @property - def server_wait_time(self): - return self._server_wait_time - @property def pool_class(self): return self._pool_class @@ -77,10 +69,6 @@ class ClusterSettings(object): def condition_class(self): return self._condition_class - @property - def heartbeat_frequency(self): - return self._heartbeat_frequency - @property def direct(self): """Connect directly to a single server, or use a set of servers? diff --git a/pymongo/thread_util.py b/pymongo/thread_util.py index f99e10615..ec005a14a 100644 --- a/pymongo/thread_util.py +++ b/pymongo/thread_util.py @@ -118,36 +118,6 @@ class Counter(object): return self._counters.get(self.ident.get(), 0) -class Future(object): - """Minimal backport of concurrent.futures.Future. - - event_class makes this Future adaptable for async clients. - """ - def __init__(self, event_class): - self._event = event_class() - self._result = None - self._exception = None - - def set_result(self, result): - self._result = result - self._event.set() - - def set_exception(self, exc): - if hasattr(exc, 'with_traceback'): - # Python 3: avoid potential reference cycle. - self._exception = exc.with_traceback(None) - else: - self._exception = exc - self._event.set() - - def result(self): - self._event.wait() - if self._exception: - raise self._exception - else: - return self._result - - ### Begin backport from CPython 3.2 for timeout support for Semaphore.acquire class Semaphore: diff --git a/test/__init__.py b/test/__init__.py index f7dd16df4..371676e29 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -30,6 +30,7 @@ from functools import wraps import pymongo from bson.py3compat import _unicode +from pymongo import common from test.version import Version # hostnames retrieved by MongoReplicaSetClient from isMaster will be of unicode @@ -64,7 +65,9 @@ class ClientContext(object): self.test_commands_enabled = False self.is_mongos = False try: - self.client = pymongo.MongoClient(host, port) + client = pymongo.MongoClient(host, port) + client.admin.command('ismaster') # Can we connect? + self.client = client except pymongo.errors.ConnectionFailure: self.client = None else: @@ -211,6 +214,58 @@ class IntegrationTest(unittest.TestCase): pass +class client_knobs(object): + def __init__(self, heartbeat_frequency=None, server_wait_time=None): + self.heartbeat_frequency = heartbeat_frequency + self.server_wait_time = server_wait_time + + self.old_heartbeat_frequency = None + self.old_server_wait_time = None + + def enable(self): + self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY + self.old_server_wait_time = common.SERVER_WAIT_TIME + if self.heartbeat_frequency is not None: + common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency + + if self.server_wait_time is not None: + common.SERVER_WAIT_TIME = self.server_wait_time + + def __enter__(self): + self.enable() + + def disable(self): + common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency + common.SERVER_WAIT_TIME = self.old_server_wait_time + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disable() + + +class MockClientTest(unittest.TestCase): + """Base class for TestCases that use MockClient. + + This class is *not* an IntegrationTest: if properly written, MockClient + tests do not require a running server. + + The class temporarily overrides HEARTBEAT_FREQUENCY and SERVER_WAIT_TIME + to speed up tests. + """ + + def setUp(self): + super(MockClientTest, self).setUp() + + self.client_knobs = client_knobs( + heartbeat_frequency=0.001, + server_wait_time=0.1) + + self.client_knobs.enable() + + def tearDown(self): + self.client_knobs.disable() + super(MockClientTest, self).tearDown() + + def connection_string(seeds=[pair]): if client_context.auth_enabled: return "mongodb://%s:%s@%s" % (db_user, db_pwd, ','.join(seeds)) diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 69d4e02b9..b64d13bfb 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -15,13 +15,16 @@ """Tools for mocking parts of PyMongo to test other parts.""" import socket +from functools import partial from pymongo import common -from pymongo import MongoClient, MongoReplicaSetClient +from pymongo import MongoClient +from pymongo.ismaster import IsMaster +from pymongo.monitor import Monitor from pymongo.pool import Pool, PoolOptions +from pymongo.server_description import ServerDescription from test import host as default_host, port as default_port -from test.utils import my_partial class MockPool(Pool): @@ -51,9 +54,45 @@ class MockPool(Pool): return sock_info -class MockClientBase(object): - def __init__(self, standalones, members, mongoses, config): - """standalones, etc., are like ['a:1', 'b:2']""" +class MockMonitor(Monitor): + def __init__( + self, + client, + server_description, + cluster, + pool, + cluster_settings): + # MockMonitor gets a 'client' arg, regular monitors don't. + self.client = client + self.mock_address = server_description.address + + # Actually connect to the default server. + Monitor.__init__( + self, + ServerDescription((default_host, default_port)), + cluster, + pool, + cluster_settings) + + def _check_once(self): + try: + response = self.client.mock_is_master('%s:%d' % self.mock_address) + return ServerDescription(self.mock_address, IsMaster(response)) + except socket.error: + return None + + +class MockClient(MongoClient): + def __init__( + self, standalones, members, mongoses, ismaster_hosts=None, + *args, **kwargs): + """A MongoClient connected to the default server, with a mock cluster. + + standalones, members, mongoses determine the configuration of the + cluster. They are formatted like ['a:1', 'b:2']. ismaster_hosts + provides an alternative host list for the server's mocked ismaster + response; see test_connect_with_internal_ips. + """ self.mock_standalones = standalones[:] self.mock_members = members[:] @@ -62,8 +101,8 @@ class MockClientBase(object): else: self.mock_primary = None - if config is not None: - self.mock_ismaster_hosts = config + if ismaster_hosts is not None: + self.mock_ismaster_hosts = ismaster_hosts else: self.mock_ismaster_hosts = members[:] @@ -78,6 +117,11 @@ class MockClientBase(object): # Hostname -> max write batch size self.mock_max_write_batch_sizes = {} + kwargs['_pool_class'] = partial(MockPool, self) + kwargs['_monitor_class'] = partial(MockMonitor, self) + + super(MockClient, self).__init__(*args, **kwargs) + def kill_host(self, host): """Host is like 'a:1'.""" self.mock_down_hosts.append(host) @@ -106,6 +150,7 @@ class MockClientBase(object): if host in self.mock_standalones: return { + 'ok': 1, 'ismaster': True, 'minWireVersion': min_wire_version, 'maxWireVersion': max_wire_version, @@ -116,6 +161,7 @@ class MockClientBase(object): # Simulate a replica set member. response = { + 'ok': 1, 'ismaster': ismaster, 'secondary': not ismaster, 'setName': 'rs', @@ -131,6 +177,7 @@ class MockClientBase(object): if host in self.mock_mongoses: return { + 'ok': 1, 'ismaster': True, 'minWireVersion': min_wire_version, 'maxWireVersion': max_wire_version, @@ -140,49 +187,3 @@ class MockClientBase(object): # In test_internal_ips(), we try to connect to a host listed # in ismaster['hosts'] but not publicly accessible. raise socket.error('Unknown host: %s' % host) - - def simple_command(self, sock_info, dbname, spec): - # __simple_command is also used for authentication, but in this - # test it's only used for ismaster. - assert spec == {'ismaster': 1} - response = self.mock_is_master( - '%s:%s' % (sock_info.mock_host, sock_info.mock_port)) - - ping_time = 10 - return response, ping_time - - -class MockClient(MockClientBase, MongoClient): - def __init__( - self, standalones, members, mongoses, ismaster_hosts=None, - *args, **kwargs - ): - MockClientBase.__init__( - self, standalones, members, mongoses, ismaster_hosts) - - kwargs['_pool_class'] = my_partial(MockPool, self) - MongoClient.__init__(self, *args, **kwargs) - - def _MongoClient__simple_command(self, sock_info, dbname, spec): - return self.simple_command(sock_info, dbname, spec) - - -class MockReplicaSetClient(MockClientBase, MongoReplicaSetClient): - def __init__( - self, standalones, members, mongoses, ismaster_hosts=None, - *args, **kwargs - ): - MockClientBase.__init__( - self, standalones, members, mongoses, ismaster_hosts) - - kwargs['_pool_class'] = my_partial(MockPool, self) - MongoReplicaSetClient.__init__(self, *args, **kwargs) - - def _MongoReplicaSetClient__is_master(self, host): - response = self.mock_is_master('%s:%s' % host) - connection_pool = MockPool(self, host) - ping_time = 10 - return response, connection_pool, ping_time - - def _MongoReplicaSetClient__simple_command(self, sock_info, dbname, spec): - return self.simple_command(sock_info, dbname, spec) diff --git a/test/test_binary.py b/test/test_binary.py index 8a70f96e1..2e3f8a7df 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -244,7 +244,7 @@ class TestBinary(unittest.TestCase): def test_uri_to_uuid(self): uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" - client = MongoClient(uri, _connect=False) + client = MongoClient(uri, connect=False) self.assertEqual(client.pymongo_test.test.uuid_subtype, CSHARP_LEGACY) @client_context.require_connection diff --git a/test/test_client.py b/test/test_client.py index 41a7f9ddc..f1e083999 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -14,12 +14,14 @@ """Test the mongo_client module.""" +import contextlib import datetime import os import threading import socket import sys import time +import warnings sys.path[0:0] = [""] @@ -34,8 +36,11 @@ from pymongo.errors import (AutoReconnect, ConfigurationError, ConnectionFailure, InvalidName, - OperationFailure) + OperationFailure, + CursorNotFound) +from pymongo.server_selectors import writable_server_selector from test import (client_context, + client_knobs, connection_string, host, pair, @@ -44,7 +49,8 @@ from test import (client_context, unittest, IntegrationTest, db_pwd, - db_user) + db_user, + MockClientTest) from test.pymongo_mocks import MockClient from test.utils import (assertRaisesExactly, delay, @@ -54,9 +60,10 @@ from test.utils import (assertRaisesExactly, TestRequestMixin, _TestLazyConnectMixin, lazy_client_trial, - NTHREADS, get_pool, - one) + one, + connected, + wait_until) class ClientUnitTest(unittest.TestCase, TestRequestMixin): @@ -64,7 +71,7 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin): @classmethod def setUpClass(cls): - cls.client = MongoClient(host, port, _connect=False) + cls.client = MongoClient(host, port, connect=False) def test_types(self): self.assertRaises(TypeError, MongoClient, 1) @@ -97,18 +104,18 @@ class ClientUnitTest(unittest.TestCase, TestRequestMixin): self.assertRaises(TypeError, iterate) def test_get_default_database(self): - c = MongoClient("mongodb://%s:%d/foo" % (host, port), _connect=False) + c = MongoClient("mongodb://%s:%d/foo" % (host, port), connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_get_default_database_error(self): # URI with no database. - c = MongoClient("mongodb://%s:%d/" % (host, port), _connect=False) + c = MongoClient("mongodb://%s:%d/" % (host, port), connect=False) self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. uri = "mongodb://%s:%d/foo?authSource=src" % (host, port) - c = MongoClient(uri, _connect=False) + c = MongoClient(uri, connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) @@ -120,19 +127,22 @@ class TestClient(IntegrationTest, TestRequestMixin): cls.client = client_context.client def test_constants(self): - MongoClient.HOST = host - MongoClient.PORT = port - self.assertTrue(MongoClient()) - + # Set bad defaults. MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 - assertRaisesExactly( - ConnectionFailure, MongoClient, connectTimeoutMS=600) - self.assertTrue(MongoClient(host, port)) + with client_knobs(server_wait_time=0.01): + with self.assertRaises(AutoReconnect): + connected(MongoClient()) + # Override the defaults. No error. + connected(MongoClient(host, port)) + + # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port - self.assertTrue(MongoClient()) + + # No error. + connected(MongoClient()) def assertIsInstance(self, obj, cls, msg=None): """Backport from Python 2.7.""" @@ -141,13 +151,12 @@ class TestClient(IntegrationTest, TestRequestMixin): self.fail(self._formatMessage(msg, standardMsg)) def test_init_disconnected(self): - c = MongoClient(host, port, _connect=False) + c = MongoClient(host, port, connect=False) self.assertIsInstance(c.is_primary, bool) self.assertIsInstance(c.is_mongos, bool) self.assertIsInstance(c.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - self.assertIsInstance(c.auto_start_request, bool) self.assertEqual(dict, c.get_document_class()) self.assertIsInstance(c.tz_aware, bool) self.assertIsInstance(c.max_bson_size, int) @@ -168,36 +177,29 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertTrue(c.min_wire_version >= 0) bad_host = "somedomainthatdoesntexist.org" - c = MongoClient(bad_host, port, connectTimeoutMS=1, _connect=False) + with client_knobs(server_wait_time=0.01): + c = MongoClient(bad_host, port) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = MongoClient(uri, connectTimeoutMS=1, _connect=False) + with client_knobs(server_wait_time=0.01): + c = MongoClient(uri) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) - def test_connect(self): - # Check that the exception is a ConnectionFailure, not a subclass like - # AutoReconnect - assertRaisesExactly( - ConnectionFailure, MongoClient, - "somedomainthatdoesntexist.org", connectTimeoutMS=600) - - assertRaisesExactly( - ConnectionFailure, MongoClient, host, 123456789) - - self.assertTrue(MongoClient(host, port)) - def test_equality(self): + c = connected(MongoClient(host, port)) + # ClientContext.client is constructed as MongoClient(host, port) - self.assertEqual(self.client, MongoClient(host, port)) + self.assertEqual(self.client, c) + # Explicitly test inequality - self.assertFalse(self.client != MongoClient(host, port)) + self.assertFalse(self.client != c) def test_host_w_port(self): - self.assertTrue(MongoClient("%s:%d" % (host, port))) - assertRaisesExactly( - ConnectionFailure, MongoClient, "%s:1234567" % (host,), port) + with client_knobs(server_wait_time=0.01): + with self.assertRaises(AutoReconnect): + connected(MongoClient("%s:1234567")) def test_repr(self): # Making host a str avoids the 'u' prefix in Python 2, so the repr is @@ -326,8 +328,9 @@ class TestClient(IntegrationTest, TestRequestMixin): coll.count() def test_from_uri(self): - self.assertEqual(self.client, - MongoClient("mongodb://%s:%d" % (host, port))) + self.assertEqual( + self.client, + connected(MongoClient("mongodb://%s:%d" % (host, port)))) @client_context.require_auth def test_auth_from_uri(self): @@ -356,12 +359,12 @@ class TestClient(IntegrationTest, TestRequestMixin): # Auth with lazy connection. MongoClient( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), - _connect=False).pymongo_test.test.find_one() + connect=False).pymongo_test.test.find_one() # Wrong password. bad_client = MongoClient( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), - _connect=False) + connect=False) self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one) @@ -375,7 +378,7 @@ class TestClient(IntegrationTest, TestRequestMixin): def test_lazy_auth_raises_operation_failure(self): lazy_client = MongoClient( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), - _connect=False) + connect=False) assertRaisesExactly( OperationFailure, lazy_client.test.collection.find_one) @@ -392,7 +395,8 @@ class TestClient(IntegrationTest, TestRequestMixin): if not os.access(mongodb_socket, os.R_OK): raise SkipTest("Socket file is not accessible") - self.assertTrue(MongoClient("mongodb://%s" % mongodb_socket)) + # No error. + connected(MongoClient("mongodb://%s" % mongodb_socket)) client = MongoClient("mongodb://%s" % mongodb_socket) client.pymongo_test.test.save({"dummy": "object"}) @@ -402,8 +406,9 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertTrue("pymongo_test" in dbs) # Confirm it fails with a missing socket - self.assertRaises(ConnectionFailure, MongoClient, - "mongodb:///tmp/none-existent.sock") + self.assertRaises( + ConnectionFailure, + connected, MongoClient("mongodb:///tmp/non-existent.sock")) def test_fork(self): # Test using a client before and after a fork. @@ -473,15 +478,6 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertTrue(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) - c.document_class = SON - - try: - self.assertEqual(SON, c.document_class) - self.assertTrue(isinstance(db.test.find_one(), SON)) - self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) - finally: - c.document_class = dict - c = get_client(pair, document_class=SON) db = c.pymongo_test @@ -489,11 +485,13 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertTrue(isinstance(db.test.find_one(), SON)) self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) - c.document_class = dict + # document_class is read-only in PyMongo 3.0. + with self.assertRaises(AttributeError): + c.document_class = dict - self.assertEqual(dict, c.document_class) - self.assertTrue(isinstance(db.test.find_one(), dict)) - self.assertFalse(isinstance(db.test.find_one(), SON)) + with warnings.catch_warnings(): + warnings.simplefilter("error", DeprecationWarning) + self.assertRaises(DeprecationWarning, c.get_document_class) def test_timeouts(self): client = MongoClient(host, port, connectTimeoutMS=10500) @@ -503,10 +501,10 @@ class TestClient(IntegrationTest, TestRequestMixin): def test_socket_timeout_ms_validation(self): c = get_client(pair, socketTimeoutMS=10 * 1000) - self.assertEqual(10, c._MongoClient__pool_opts.socket_timeout) + self.assertEqual(10, get_pool(c).opts.socket_timeout) - c = get_client(pair, socketTimeoutMS=None) - self.assertEqual(None, c._MongoClient__pool_opts.socket_timeout) + c = connected(get_client(pair, socketTimeoutMS=None)) + self.assertEqual(None, get_pool(c).opts.socket_timeout) self.assertRaises(ConfigurationError, get_client, pair, socketTimeoutMS=0) @@ -564,24 +562,20 @@ class TestClient(IntegrationTest, TestRequestMixin): naive.pymongo_test.test.find_one()["x"]) def test_ipv6(self): - try: - MongoClient("[::1]") - except: - # Either mongod was started without --ipv6 - # or the OS doesn't support it (or both). - raise SkipTest("No IPv6") - - # Try a few simple things - MongoClient("mongodb://[::1]:%d" % (port,)) - MongoClient("mongodb://[::1]:%d/?w=0" % (port,)) - MongoClient("[::1]:%d,localhost:%d" % (port, port)) + with client_knobs(server_wait_time=0.01): + try: + connected(MongoClient("[::1]")) + except: + # Either mongod was started without --ipv6 + # or the OS doesn't support it (or both). + raise SkipTest("No IPv6") if client_context.auth_enabled: auth_str = "%s:%s@" % (db_user, db_pwd) else: auth_str = "" - uri = "mongodb://%slocalhost:%d,[::1]:%d" % (auth_str, port, port) + uri = "mongodb://%s[::1]:%d" % (auth_str, port) client = MongoClient(uri) client.pymongo_test.test.save({"dummy": u("object")}) client.pymongo_test_bernie.test.save({"dummy": u("object")}) @@ -616,9 +610,7 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertFalse(locked) def test_contextlib(self): - import contextlib - - client = get_client(pair, auto_start_request=False) + client = get_client(pair) client.pymongo_test.drop_collection("test") client.pymongo_test.test.insert({"foo": "bar"}) @@ -628,11 +620,12 @@ class TestClient(IntegrationTest, TestRequestMixin): with contextlib.closing(client): self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) - self.assertEqual(None, client._MongoClient__member) + self.assertEqual(1, len(get_pool(client).sockets)) + self.assertEqual(0, len(get_pool(client).sockets)) with self.client as client: self.assertEqual("bar", client.pymongo_test.test.find_one()["foo"]) - self.assertEqual(None, client._MongoClient__member) + self.assertEqual(0, len(get_pool(client).sockets)) def test_with_start_request(self): pool = get_pool(self.client) @@ -667,37 +660,7 @@ class TestClient(IntegrationTest, TestRequestMixin): self.assertNoRequest(pool) self.assertDifferentSock(pool) - def test_auto_start_request(self): - for bad_horrible_value in (None, 5, 'hi!'): - self.assertRaises( - (TypeError, ConfigurationError), - lambda: get_client(pair, auto_start_request=bad_horrible_value) - ) - - # auto_start_request should default to False - self.assertFalse(self.client.auto_start_request) - - client = get_client(pair, auto_start_request=True) - self.assertTrue(client.auto_start_request) - - # Assure we acquire a request socket. - client.pymongo_test.test.find_one() - self.assertTrue(client.in_request()) - pool = get_pool(client) - self.assertRequestSocket(pool) - self.assertSameSock(pool) - - client.end_request() - self.assertNoRequest(pool) - self.assertDifferentSock(pool) - - # Trigger auto_start_request - client.pymongo_test.test.find_one() - self.assertRequestSocket(pool) - self.assertSameSock(pool) - def test_nested_request(self): - # auto_start_request is False pool = get_pool(self.client) self.assertFalse(self.client.in_request()) @@ -830,7 +793,8 @@ class TestClient(IntegrationTest, TestRequestMixin): def test_operation_failure_with_request(self): # Ensure MongoClient doesn't close socket after it gets an error # response to getLastError. PYTHON-395. - c = get_client(pair, auto_start_request=True) + c = get_client(pair) + c.start_request() pool = get_pool(c) # Pool reserves a socket for this thread. @@ -850,98 +814,42 @@ class TestClient(IntegrationTest, TestRequestMixin): def test_alive(self): self.assertTrue(self.client.alive()) - client = MongoClient('doesnt exist', _connect=False) + client = MongoClient('doesnt exist', connect=False) self.assertFalse(client.alive()) - def test_wire_version(self): - c = MockClient( - standalones=[], - members=['a:1', 'b:2', 'c:3'], - mongoses=[], - host='b:2', # Pass a secondary. - replicaSet='rs', - _connect=False) + def test_kill_cursors(self): + self.collection = self.client.pymongo_test.test + self.collection.remove() + + # Ensure two batches. + self.collection.insert({'_id': i} for i in range(200)) - c.set_wire_version_range('a:1', 1, 5) - c.db.command('ismaster') # Connect. - self.assertEqual(c.min_wire_version, 1) - self.assertEqual(c.max_wire_version, 5) - - c.set_wire_version_range('a:1', 10, 11) - c.disconnect() - self.assertRaises(ConfigurationError, c.db.collection.find_one) - - def test_max_wire_version(self): - c = MockClient( - standalones=[], - members=['a:1', 'b:2', 'c:3'], - mongoses=[], - host='b:2', # Pass a secondary. - replicaSet='rs', - _connect=False) - - c.set_max_write_batch_size('a:1', 1) - c.set_max_write_batch_size('b:2', 2) - - # Starts with default max batch size. - self.assertEqual(1000, c.max_write_batch_size) - c.db.command('ismaster') # Connect. - # Uses primary's max batch size. - self.assertEqual(c.max_write_batch_size, 1) - - # b becomes primary. - c.mock_primary = 'b:2' - c.disconnect() - self.assertEqual(1000, c.max_write_batch_size) - c.db.command('ismaster') # Connect. - self.assertEqual(c.max_write_batch_size, 2) - - def test_wire_version_mongos_ha(self): - c = MockClient( - standalones=[], - members=[], - mongoses=['a:1', 'b:2', 'c:3'], - host='a:1,b:2,c:3', - _connect=False) - - c.set_wire_version_range('a:1', 2, 5) - c.set_wire_version_range('b:2', 2, 2) - c.set_wire_version_range('c:3', 1, 1) - c.db.command('ismaster') # Connect. - - # Which member did we use? - used_host = '%s:%s' % (c.host, c.port) - expected_min, expected_max = c.mock_wire_versions[used_host] - self.assertEqual(expected_min, c.min_wire_version) - self.assertEqual(expected_max, c.max_wire_version) - - c.set_wire_version_range('a:1', 0, 0) - c.set_wire_version_range('b:2', 0, 0) - c.set_wire_version_range('c:3', 0, 0) - c.disconnect() - c.db.command('ismaster') - used_host = '%s:%s' % (c.host, c.port) - expected_min, expected_max = c.mock_wire_versions[used_host] - self.assertEqual(expected_min, c.min_wire_version) - self.assertEqual(expected_max, c.max_wire_version) + cursor = self.collection.find() + next(cursor) + self.client.kill_cursors([cursor.cursor_id]) + + with self.assertRaises(CursorNotFound): + list(cursor) @client_context.require_replica_set def test_replica_set(self): name = client_context.setname - MongoClient(host, port, replicaSet=name) # No error. + connected(MongoClient(host, port, replicaSet=name)) # No error. - self.assertRaises( - ConfigurationError, - MongoClient, host, port, replicaSet='bad' + name) + with client_knobs(server_wait_time=0.01): + client = MongoClient(host, port, replicaSet='bad' + name) + + with self.assertRaises(AutoReconnect): + connected(client) def test_lazy_connect_w0(self): - client = get_client(connection_string(), _connect=False) + client = get_client(connection_string(), connect=False) client.pymongo_test.test.insert({}, w=0) - client = get_client(connection_string(), _connect=False) + client = get_client(connection_string(), connect=False) client.pymongo_test.test.update({}, {'$set': {'x': 1}}, w=0) - client = get_client(connection_string(), _connect=False) + client = get_client(connection_string(), connect=False) client.pymongo_test.test.remove(w=0) @client_context.require_no_mongos @@ -953,6 +861,9 @@ class TestClient(IntegrationTest, TestRequestMixin): pool = get_pool(client) pool._check_interval_seconds = None # Never check. + # Ensure a socket. + connected(client) + # Cause a network error. sock_info = one(pool.sockets) sock_info.sock.close() @@ -993,6 +904,90 @@ class TestClient(IntegrationTest, TestRequestMixin): c.test.collection.find_one() +class TestClientProperties(MockClientTest): + + @client_context.require_connection + def test_wire_version(self): + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='b:2', # Pass a secondary. + replicaSet='rs', + connect=False) + + c.set_wire_version_range('a:1', 1, 5) + c._get_cluster().select_servers(writable_server_selector) # Connect. + self.assertEqual(c.min_wire_version, 1) + self.assertEqual(c.max_wire_version, 5) + + c.set_wire_version_range('a:1', 10, 11) + c.disconnect() + c._get_cluster() + self.assertRaises(ConfigurationError, c.db.collection.find_one) + + def test_max_wire_version(self): + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='b:2', # Pass a secondary. + replicaSet='rs', + connect=False) + + c.set_max_write_batch_size('a:1', 1) + c.set_max_write_batch_size('b:2', 2) + + # Starts with default max batch size. + self.assertEqual(1000, c.max_write_batch_size) + c._get_cluster() + wait_until(lambda: len(c.nodes) == 3, 'connect') + + # Uses primary's max batch size. + self.assertEqual(c.max_write_batch_size, 1) + + # b becomes primary. + c.mock_primary = 'b:2' + c.disconnect() + self.assertEqual(1000, c.max_write_batch_size) + + c._get_cluster() + wait_until(lambda: len(c.nodes) == 3, 'connect') + self.assertEqual(c.max_write_batch_size, 2) + + def test_wire_version_mongos_ha(self): + # TODO: Reimplement Mongos HA with PyMongo 3's MongoClient. + raise SkipTest('Mongos HA must be reimplemented in PyMongo 3') + + c = MockClient( + standalones=[], + members=[], + mongoses=['a:1', 'b:2', 'c:3'], + host='a:1,b:2,c:3', + connect=False) + + c.set_wire_version_range('a:1', 2, 5) + c.set_wire_version_range('b:2', 2, 2) + c.set_wire_version_range('c:3', 1, 1) + c.db.command('ismaster') # Connect. + + # Which member did we use? + used_host = '%s:%s' % (c.host, c.port) + expected_min, expected_max = c.mock_wire_versions[used_host] + self.assertEqual(expected_min, c.min_wire_version) + self.assertEqual(expected_max, c.max_wire_version) + + c.set_wire_version_range('a:1', 0, 0) + c.set_wire_version_range('b:2', 0, 0) + c.set_wire_version_range('c:3', 0, 0) + c.disconnect() + c.db.command('ismaster') + used_host = '%s:%s' % (c.host, c.port) + expected_min, expected_max = c.mock_wire_versions[used_host] + self.assertEqual(expected_min, c.min_wire_version) + self.assertEqual(expected_max, c.max_wire_version) + + class TestClientLazyConnect(IntegrationTest, _TestLazyConnectMixin): def _get_client(self, **kwargs): @@ -1019,63 +1014,40 @@ class TestClientLazyConnectBadSeeds(IntegrationTest): client = collection.database.connection self.assertEqual(0, len(client.nodes)) - lazy_client_trial(reset, connect, test, self._get_client) + with client_knobs(server_wait_time=0.01): + lazy_client_trial(reset, connect, test, self._get_client) -class TestClientLazyConnectOneGoodSeed( - IntegrationTest, - _TestLazyConnectMixin): - - def _get_client(self, **kwargs): - kwargs.setdefault('connectTimeoutMS', 100) - - # Assume there are no open mongods listening on a.com, b.com, .... - bad_seeds = ['%s.com' % chr(ord('a') + i) for i in range(10)] - seeds = bad_seeds + [pair] - - # MongoClient puts the seeds in a set before iterating, so order is - # undefined. - return get_client(connection_string(seeds=seeds), **kwargs) - - def test_insert(self): - def reset(collection): - collection.drop() - - def insert(collection, dummy): - collection.insert({}) - - def test(collection): - self.assertEqual(NTHREADS, collection.count()) - - lazy_client_trial(reset, insert, test, self._get_client) - - -class TestMongoClientFailover(IntegrationTest): +class TestMongoClientFailover(MockClientTest): def test_discover_primary(self): - c = MockClient( - standalones=[], - members=['a:1', 'b:2', 'c:3'], - mongoses=[], - host='b:2', # Pass a secondary. - replicaSet='rs') + # Disable background refresh. + with client_knobs(heartbeat_frequency=9999999): + c = MockClient( + standalones=[], + members=['a:1', 'b:2', 'c:3'], + mongoses=[], + host='b:2', # Pass a secondary. + replicaSet='rs') - self.assertEqual('a', c.host) - self.assertEqual(1, c.port) - self.assertEqual(3, len(c.nodes)) + wait_until(lambda: len(c.nodes) == 3, 'connect') + self.assertEqual('a', c.host) + self.assertEqual(1, c.port) - # Fail over. - c.kill_host('a:1') - c.mock_primary = 'b:2' + # Fail over. + c.kill_host('a:1') + c.mock_primary = 'b:2' - # Force reconnect. - c.disconnect() - c.db.command('ismaster') - self.assertEqual('b', c.host) - self.assertEqual(2, c.port) + c.disconnect() + self.assertEqual(0, len(c.nodes)) - # a:1 is still in nodes. - self.assertEqual(3, len(c.nodes)) + c._get_cluster().select_servers(writable_server_selector) + self.assertEqual('b', c.host) + self.assertEqual(2, c.port) + + # a:1 not longer in nodes. + self.assertLess(len(c.nodes), 3) + wait_until(lambda: len(c.nodes) == 2, 'discover node "c"') def test_reconnect(self): # Verify the node list isn't forgotten during a network failure. @@ -1086,6 +1058,8 @@ class TestMongoClientFailover(IntegrationTest): host='b:2', # Pass a secondary. replicaSet='rs') + wait_until(lambda: len(c.nodes) == 3, 'connect') + # Total failure. c.kill_host('a:1') c.kill_host('b:2') @@ -1094,12 +1068,11 @@ class TestMongoClientFailover(IntegrationTest): # MongoClient discovers it's alone. self.assertRaises(AutoReconnect, c.db.collection.find_one) - # But it remembers its node list. - self.assertEqual(3, len(c.nodes)) - - # So it can reconnect. + # But it can reconnect. c.revive_host('a:1') - c.db.command('ismaster') + c._get_cluster().select_servers(writable_server_selector) + self.assertEqual('a', c.host) + self.assertEqual(1, c.port) if __name__ == "__main__": diff --git a/test/test_client_new.py b/test/test_client_new.py deleted file mode 100644 index dbb37755b..000000000 --- a/test/test_client_new.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2009-2014 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Test the mongo_client module.""" - -import sys - -sys.path[0:0] = [""] - -from pymongo import ReadPreference -from pymongo.mongo_client_new import MongoClientNew -from test import host, port, unittest, IntegrationTest, client_context, SkipTest -from test.utils import one - - -class TestClientNew(IntegrationTest): - def test_buildinfo(self): - c = MongoClientNew(host, port) - assert 'version' in c.admin.command('buildinfo') - - def test_ismaster(self): - c = MongoClientNew(['%s:%d' % (host, port)]) - response = c.admin.command('ismaster') - self.assertTrue('ismaster' in response) - - @client_context.require_replica_set - def test_ismaster_primary(self): - primary_pair = client_context.rs_client.primary - c = MongoClientNew(['%s:%d' % primary_pair], - replicaSet=client_context.setname) - - response = c.admin.command('ismaster') - self.assertTrue(response['ismaster']) - - @client_context.require_replica_set - def test_ismaster_secondary(self): - secondary_pair = one(client_context.rs_client.secondaries) - c = MongoClientNew(['%s:%d' % secondary_pair], - replicaSet=client_context.setname) - - # 'ismaster' does not obey read preference, so hack it. - secondary_response = c.admin['$cmd'].find_one( - {'ismaster': 1}, - read_preference=ReadPreference.SECONDARY) - - # Make sure we really executed ismaster on a secondary. - self.assertTrue(secondary_response.get('secondary')) - - def test_find(self): - # TODO: re-enable once MongoClientNew can do auth. - if client_context.auth_enabled: - raise SkipTest("MongoClientNew can't yet do auth") - - client_context.client.pymongo_test.test_collection.insert({'_id': 1}) - c = MongoClientNew(host, port) - docs = list(c.pymongo_test.test_collection.find()) - self.assertEqual([{'_id': 1}], docs) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/test_cluster.py b/test/test_cluster.py index caaafbfe0..5d0c1e10d 100644 --- a/test/test_cluster.py +++ b/test/test_cluster.py @@ -27,8 +27,7 @@ from pymongo.server_type import SERVER_TYPE from pymongo.cluster import Cluster from pymongo.cluster_description import CLUSTER_TYPE from pymongo.errors import (ConfigurationError, - ConnectionFailure, - InvalidOperation) + ConnectionFailure) from pymongo.ismaster import IsMaster from pymongo.monitor import Monitor from pymongo.read_preferences import MovingAverage @@ -36,7 +35,7 @@ from pymongo.server_description import ServerDescription from pymongo.server_selectors import (any_server_selector, writable_server_selector) from pymongo.settings import ClusterSettings -from test import unittest +from test import unittest, client_knobs class MockSocketInfo(object): @@ -95,8 +94,7 @@ def create_mock_cluster(seeds=None, set_name=None, monitor_class=MockMonitor): partitioned_seeds, set_name=set_name, pool_class=MockPool, - monitor_class=monitor_class, - heartbeat_frequency=99999999) + monitor_class=monitor_class) c = Cluster(cluster_settings) c.open() @@ -122,7 +120,20 @@ def get_type(cluster, hostname): return description.server_type -class TestSingleServerCluster(unittest.TestCase): +class ClusterTest(unittest.TestCase): + """Disables periodic monitoring, to make tests deterministic.""" + + def setUp(self): + super(ClusterTest, self).setUp() + self.client_knobs = client_knobs(heartbeat_frequency=9999999) + self.client_knobs.enable() + + def tearDown(self): + self.client_knobs.disable() + super(ClusterTest, self).tearDown() + + +class TestSingleServerCluster(ClusterTest): def test_direct_connection(self): for server_type, ismaster_response in [ (SERVER_TYPE.RSPrimary, { @@ -206,7 +217,7 @@ class TestSingleServerCluster(unittest.TestCase): self.assertEqual(2, s.description.round_trip_time) -class TestMultiServerCluster(unittest.TestCase): +class TestMultiServerCluster(ClusterTest): def test_unexpected_host(self): # Received ismaster response from host not in cluster. # E.g., a race where the host is removed before it responds. @@ -390,6 +401,45 @@ class TestMultiServerCluster(unittest.TestCase): self.assertEqual(CLUSTER_TYPE.ReplicaSetWithPrimary, c.description.cluster_type) + def test_reset_server(self): + c = create_mock_cluster(set_name='rs') + got_ismaster(c, ('a', 27017), { + 'ok': 1, + 'ismaster': True, + 'setName': 'rs', + 'hosts': ['a', 'b']}) + + got_ismaster(c, ('b', 27017), { + 'ok': 1, + 'ismaster': False, + 'secondary': True, + 'setName': 'rs', + 'hosts': ['a', 'b']}) + + c.reset_server(('a', 27017)) + self.assertEqual(SERVER_TYPE.Unknown, get_type(c, 'a')) + self.assertEqual(SERVER_TYPE.RSSecondary, get_type(c, 'b')) + self.assertEqual('rs', c.description.set_name) + self.assertEqual(CLUSTER_TYPE.ReplicaSetNoPrimary, + c.description.cluster_type) + + got_ismaster(c, ('a', 27017), { + 'ok': 1, + 'ismaster': True, + 'setName': 'rs', + 'hosts': ['a', 'b']}) + + self.assertEqual(SERVER_TYPE.RSPrimary, get_type(c, 'a')) + self.assertEqual(CLUSTER_TYPE.ReplicaSetWithPrimary, + c.description.cluster_type) + + c.reset_server(('b', 27017)) + self.assertEqual(SERVER_TYPE.RSPrimary, get_type(c, 'a')) + self.assertEqual(SERVER_TYPE.Unknown, get_type(c, 'b')) + self.assertEqual('rs', c.description.set_name) + self.assertEqual(CLUSTER_TYPE.ReplicaSetWithPrimary, + c.description.cluster_type) + def test_discover_set_name_from_primary(self): # Discovering a replica set without the setName supplied by the user # is not yet supported by MongoClient, but Cluster can do it. @@ -620,7 +670,7 @@ class TestMultiServerCluster(unittest.TestCase): self.assertEqual(2, write_batch_size()) -class TestClusterErrors(unittest.TestCase): +class TestClusterErrors(ClusterTest): # Errors when calling ismaster. def test_pool_reset(self): diff --git a/test/test_collection.py b/test/test_collection.py index 15b088625..55bb7dec7 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -58,7 +58,7 @@ class TestCollectionNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient(host, port, _connect=False) + client = MongoClient(host, port, connect=False) cls.db = client.pymongo_test def test_collection(self): @@ -1372,7 +1372,6 @@ class TestCollection(IntegrationTest): def test_aggregation_cursor(self): db = self.db if client_context.setname: - db = client_context.rs_client[db.name] # Test that getMore messages are sent to the right server. db.read_preference = ReadPreference.SECONDARY @@ -1399,7 +1398,6 @@ class TestCollection(IntegrationTest): db = self.db db.drop_collection("test") if client_context.setname: - db = client_context.rs_client[db.name] # Test that getMore messages are sent to the right server. db.read_preference = ReadPreference.SECONDARY @@ -1717,7 +1715,6 @@ class TestCollection(IntegrationTest): client = get_client(pair, max_pool_size=1) socks = get_pool(client).sockets - self.assertEqual(1, len(socks)) # Make sure the socket is returned after exhaustion. cur = client[self.db.name].test.find(exhaust=True) diff --git a/test/test_common.py b/test/test_common.py index a27588709..f6c15d6fd 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -24,10 +24,9 @@ from bson.code import Code from bson.objectid import ObjectId from bson.son import SON from pymongo.mongo_client import MongoClient -from pymongo.mongo_replica_set_client import MongoReplicaSetClient from pymongo.errors import ConfigurationError, OperationFailure from test import client_context, pair, unittest -from test.utils import get_client, get_rs_client +from test.utils import get_client, connected @client_context.require_connection @@ -232,33 +231,11 @@ class TestCommon(unittest.TestCase): self.assertTrue(coll.insert(doc)) # Equality tests - self.assertEqual(m, MongoClient("mongodb://%s/?w=0" % (pair,))) - self.assertFalse(m != MongoClient("mongodb://%s/?w=0" % (pair,))) + self.assertEqual(m, + connected(MongoClient("mongodb://%s/?w=0" % (pair,)))) - @client_context.require_replica_set - def test_mongo_replica_set_client(self): - setname = client_context.setname - m = get_rs_client(pair, replicaSet=setname, w=0) - coll = m.pymongo_test.write_concern_test - coll.drop() - doc = {"_id": ObjectId()} - coll.insert(doc) - self.assertTrue(coll.insert(doc, w=0)) - self.assertTrue(coll.insert(doc)) - self.assertRaises(OperationFailure, coll.insert, doc, w=1) - - m = client_context.rs_client - coll = m.pymongo_test.write_concern_test - self.assertTrue(coll.insert(doc, w=0)) - self.assertRaises(OperationFailure, coll.insert, doc) - self.assertRaises(OperationFailure, coll.insert, doc, w=1) - - m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s" % (pair, setname)) - coll = m.pymongo_test.write_concern_test - self.assertRaises(OperationFailure, coll.insert, doc) - m = MongoReplicaSetClient("mongodb://%s/?replicaSet=%s;w=0" % (pair, setname)) - coll = m.pymongo_test.write_concern_test - self.assertTrue(coll.insert(doc)) + self.assertFalse(m != + connected(MongoClient("mongodb://%s/?w=0" % (pair,)))) if __name__ == "__main__": diff --git a/test/test_cursor.py b/test/test_cursor.py index 5ca371e79..910ea9389 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -46,7 +46,7 @@ class TestCursorNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient(host, port, _connect=False) + client = MongoClient(host, port, connect=False) cls.db = client.test def test_deepcopy_cursor_littered_with_regexes(self): diff --git a/test/test_database.py b/test/test_database.py index e69762dd5..23ea2809b 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -64,7 +64,7 @@ class TestDatabaseNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - cls.client = MongoClient(host, port, _connect=False) + cls.client = MongoClient(host, port, connect=False) def test_name(self): self.assertRaises(TypeError, Database, self.client, 4) diff --git a/test/test_errors.py b/test/test_errors.py index 5a8931d0a..3ac1edfa3 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -15,17 +15,21 @@ """Test the errors module.""" import sys + sys.path[0:0] = [""] from pymongo import MongoClient from pymongo.errors import PyMongoError -from test import unittest +from test import unittest, client_knobs +from test.utils import connected class TestErrors(unittest.TestCase): def test_base_exception(self): - self.assertRaises(PyMongoError, MongoClient, port=0) + # connected(MongoClient(...)) with a bad port raises AutoReconnect. + with client_knobs(server_wait_time=0.01): + self.assertRaises(PyMongoError, connected, MongoClient(port=0)) if __name__ == '__main__': diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 42fa60c0a..691615c23 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -35,14 +35,20 @@ from gridfs.errors import (NoFile, UnsupportedAPI) from pymongo import MongoClient from pymongo.errors import ConnectionFailure -from test import client_context, qcheck, unittest, host, port, IntegrationTest +from test import (client_context, + client_knobs, + IntegrationTest, + host, + port, + unittest, + qcheck) class TestGridFileNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient(host, port, _connect=False) + client = MongoClient(host, port, connect=False) cls.db = client.pymongo_test def test_grid_file(self): @@ -222,7 +228,9 @@ class TestGridFile(IntegrationTest): def test_grid_out_default_opts(self): self.assertRaises(TypeError, GridOut, "foo") - self.assertRaises(NoFile, GridOut, self.db.fs, 5) + gout = GridOut(self.db.fs, 5) + with self.assertRaises(NoFile): + gout.name a = GridIn(self.db.fs) a.close() @@ -280,7 +288,9 @@ class TestGridFile(IntegrationTest): three = GridOut(self.db.fs, 5, file_document=self.db.fs.files.find_one()) self.assertEqual(b"foo bar", three.read()) - self.assertRaises(NoFile, GridOut, self.db.fs, file_document={}) + four = GridOut(self.db.fs, file_document={}) + with self.assertRaises(NoFile): + four.name def test_write_file_like(self): one = GridIn(self.db.fs) @@ -588,23 +598,24 @@ Bye""")) def test_grid_out_lazy_connect(self): fs = self.db.fs - outfile = GridOut(fs, file_id=-1, _connect=False) + outfile = GridOut(fs, file_id=-1) self.assertRaises(NoFile, outfile.read) self.assertRaises(NoFile, getattr, outfile, 'filename') infile = GridIn(fs, filename=1) infile.close() - outfile = GridOut(fs, infile._id, _connect=False) + outfile = GridOut(fs, infile._id) outfile.read() outfile.filename def test_grid_in_lazy_connect(self): - client = MongoClient('badhost', _connect=False) - fs = client.db.fs - infile = GridIn(fs, file_id=-1, chunk_size=1) - self.assertRaises(ConnectionFailure, infile.write, b'data goes here') - self.assertRaises(ConnectionFailure, infile.close) + with client_knobs(server_wait_time=0.01): + client = MongoClient('badhost', connect=False) + fs = client.db.fs + infile = GridIn(fs, file_id=-1, chunk_size=1) + self.assertRaises(ConnectionFailure, infile.write, b'data') + self.assertRaises(ConnectionFailure, infile.close) if __name__ == "__main__": diff --git a/test/test_gridfs.py b/test/test_gridfs.py index fde86dc87..d2cc0736a 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -33,6 +33,7 @@ from bson.py3compat import u, StringIO, string_type from gridfs.errors import (FileExists, NoFile) from test import (client_context, + client_knobs, connection_string, unittest, host, @@ -77,7 +78,7 @@ class TestGridfsNoConnect(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient(host, port, _connect=False) + client = MongoClient(host, port, connect=False) cls.db = client.pymongo_test def test_gridfs(self): @@ -146,6 +147,24 @@ class TestGridfs(IntegrationTest): self.assertEqual(255 * 1024, raw["chunkSize"]) self.assertTrue(isinstance(raw["md5"], string_type)) + def test_delete_ensures_index(self): + chunks = self.db.fs.chunks + + # setUp has dropped collections. + self.assertFalse(chunks.index_information()) + + self.fs.delete(file_id=1) + + # delete() has ensured an index on (files_id, n). + # index_information() is like: + # { + # '_id_': {'key': [('_id', 1)]}, + # 'files_id_1_n_1': {'key': [('files_id', 1), ('n', 1)]} + # } + self.assertTrue(any( + info.get('key') == [('files_id', 1), ('n', 1)] + for info in chunks.index_information().values())) + def test_alt_collection(self): oid = self.alt.put(b"hello world") self.assertEqual(b"hello world", self.alt.get(oid).read()) @@ -375,13 +394,15 @@ class TestGridfs(IntegrationTest): self.assertFalse(self.db.connection.in_request()) def test_gridfs_lazy_connect(self): - client = MongoClient('badhost', _connect=False) - db = client.db - self.assertRaises(ConnectionFailure, gridfs.GridFS, db) + with client_knobs(server_wait_time=0.01): + client = MongoClient('badhost', connect=False) + db = client.db + gfs = gridfs.GridFS(db) + self.assertRaises(ConnectionFailure, gfs.list) - fs = gridfs.GridFS(db, _connect=False) - f = fs.new_file() # Still no connection. - self.assertRaises(ConnectionFailure, f.close) + fs = gridfs.GridFS(db) + f = fs.new_file() # Still no connection. + self.assertRaises(ConnectionFailure, f.close) def test_gridfs_find(self): self.fs.put(b"test2", filename="two") @@ -448,10 +469,10 @@ class TestGridfsReplicaSet(TestReplicaSetClientBase): client = MongoClient( connection_string(seeds=[secondary_pair]), read_preference=ReadPreference.SECONDARY, - _connect=False) + connect=False) # Still no connection. - fs = gridfs.GridFS(client.test_gridfs_secondary_lazy, _connect=False) + fs = gridfs.GridFS(client.test_gridfs_secondary_lazy) # Connects, doesn't create index. self.assertRaises(NoFile, fs.get_last_version) diff --git a/test/test_mongos_ha.py b/test/test_mongos_ha.py index 3b9f86a8a..3862b4e35 100644 --- a/test/test_mongos_ha.py +++ b/test/test_mongos_ha.py @@ -20,8 +20,9 @@ import threading sys.path[0:0] = [""] from pymongo.errors import AutoReconnect -from test import unittest, client_context +from test import unittest, client_context, SkipTest, MockClientTest from test.pymongo_mocks import MockClient +from test.utils import connected, wait_until @client_context.require_connection @@ -53,18 +54,22 @@ def do_simple_op(client, nthreads): assert t.passed -class TestMongosHA(unittest.TestCase): - def mock_client(self, connect): +class TestMongosHA(MockClientTest): + + def mock_client(self): return MockClient( standalones=[], members=[], mongoses=['a:1', 'b:2', 'c:3'], host='a:1,b:2,c:3', - _connect=connect) + connect=False) def test_lazy_connect(self): + # TODO: Reimplement Mongos HA with PyMongo 3's MongoClient. + raise SkipTest('Mongos HA must be reimplemented with 3.0 MongoClient') + nthreads = 10 - client = self.mock_client(False) + client = self.mock_client() self.assertEqual(0, len(client.nodes)) # Trigger initial connection. @@ -73,15 +78,24 @@ class TestMongosHA(unittest.TestCase): def test_reconnect(self): nthreads = 10 - client = self.mock_client(True) - self.assertEqual(3, len(client.nodes)) + client = connected(self.mock_client()) + + # connected() ensures we've contacted at least one mongos. Wait for + # all of them. + wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') # Trigger reconnect. client.disconnect() do_simple_op(client, nthreads) - self.assertEqual(3, len(client.nodes)) + + wait_until(lambda: len(client.nodes) == 3, + 'reconnect to all mongoses') def test_failover(self): + # TODO: PyMongo 3's MongoClient currently picks a new Mongos at random + # for each operation (besides getMore). Need to "pin". + raise SkipTest('Mongos HA must be reimplemented with 3.0 MongoClient') + nthreads = 1 # ['1:1', '2:2', '3:3', ...] diff --git a/test/test_pooling.py b/test/test_pooling.py index 75a9d1d90..2a1bbf3f0 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -24,6 +24,7 @@ import time from pymongo import MongoClient from pymongo.errors import ConfigurationError, ConnectionFailure, \ ExceededMaxWaiters +from pymongo.server_selectors import writable_server_selector sys.path[0:0] = [""] @@ -231,9 +232,16 @@ class _TestPoolingBase(unittest.TestCase): return Pool(pair, PoolOptions(*args, **kwargs)) def assert_no_request(self): - self.assertTrue( - self.c._MongoClient__member is None or - NO_REQUEST == get_pool(self.c)._get_request_state()) + try: + server = self.c._cluster.select_server( + writable_server_selector, + server_wait_time=0) + + self.assertEqual(NO_REQUEST, server.pool._get_request_state()) + except ConnectionFailure: + # Success: we're asserting that we're not in a request, but there's + # no pool at all so the assertion is true. + pass def assert_request_without_socket(self): self.assertEqual(NO_SOCKET_YET, get_pool(self.c)._get_request_state()) @@ -244,9 +252,16 @@ class _TestPoolingBase(unittest.TestCase): def assert_pool_size(self, pool_size): if pool_size == 0: - self.assertTrue( - self.c._MongoClient__member is None - or not get_pool(self.c).sockets) + try: + server = self.c._cluster.select_server( + writable_server_selector, + server_wait_time=0) + + self.assertEqual(0, len(server.pool.sockets)) + except ConnectionFailure: + # Success: we're asserting that pool size is 0, and there's no + # pool at all so the assertion is true. + pass else: self.assertEqual(pool_size, len(get_pool(self.c).sockets)) @@ -634,7 +649,6 @@ class TestPooling(_TestPoolingBase): def loop(pipe): c = get_client() - self.assertEqual(1, len(get_pool(c).sockets)) c.pymongo_test.test.find_one() self.assertEqual(1, len(get_pool(c).sockets)) pipe.send(one(get_pool(c).sockets).sock.getsockname()) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index d71cfaf67..4f97b02c7 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -41,6 +41,7 @@ from test import (client_context, IntegrationTest, db_user, db_pwd) +from test.utils import connected from test.version import Version @@ -551,8 +552,8 @@ class TestMongosConnection(IntegrationTest): None, [{}] ): # Create a client e.g. with read_preference=NEAREST - c = get_client(host, port, - read_preference=mode(tag_sets=tag_sets)) + c = connected(get_client( + host, port, read_preference=mode(tag_sets=tag_sets))) self.assertEqual(is_mongos, c.is_mongos) cursor = c.pymongo_test.test.find() @@ -591,8 +592,8 @@ class TestMongosConnection(IntegrationTest): [{'dc': 'la'}, {'dc': 'sf'}], [{'dc': 'la'}, {'dc': 'sf'}, {}], ): - c = get_client(host, port, - read_preference=mode(tag_sets=tag_sets)) + c = connected(get_client( + host, port, read_preference=mode(tag_sets=tag_sets))) self.assertEqual(is_mongos, c.is_mongos) cursor = c.pymongo_test.test.find() diff --git a/test/test_replica_set_client.py b/test/test_replica_set_client.py index d91dfb1ed..b6506536a 100644 --- a/test/test_replica_set_client.py +++ b/test/test_replica_set_client.py @@ -47,13 +47,14 @@ from test import (client_context, SkipTest, unittest, db_pwd, - db_user) -from test.pymongo_mocks import MockReplicaSetClient + db_user, + MockClientTest) +from test.pymongo_mocks import MockClient from test.utils import ( delay, assertReadFrom, assertReadFromAll, read_from_which_host, remove_all_users, assertRaisesExactly, TestRequestMixin, one, server_started_with_auth, pools_from_rs_client, get_pool, - get_rs_client, _TestLazyConnectMixin) + get_rs_client, _TestLazyConnectMixin, connected, wait_until) from test.version import Version @@ -77,6 +78,7 @@ class TestReplicaSetClientBase(unittest.TestCase): @classmethod @client_context.require_replica_set def setUpClass(cls): + raise SkipTest('Replica set tests must be updated for 3.0 MongoClient') cls.name = client_context.setname ismaster = client_context.ismaster cls.w = client_context.w @@ -120,7 +122,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.fail(self._formatMessage(msg, standardMsg)) def test_init_disconnected(self): - c = self._get_client(_connect=False) + c = self._get_client(connect=False) self.assertIsInstance(c.is_mongos, bool) self.assertIsInstance(c.max_pool_size, int) @@ -146,28 +148,28 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertEqual(c.max_wire_version, 0) self.assertTrue(c.min_wire_version >= 0) - c = self._get_client(_connect=False) + c = self._get_client(connect=False) c.pymongo_test.test.update({}, {}) # Auto-connect for write. self.assertTrue(c.primary) - c = self._get_client(_connect=False) + c = self._get_client(connect=False) c.pymongo_test.test.insert({}) # Auto-connect for write. self.assertTrue(c.primary) - c = self._get_client(_connect=False) + c = self._get_client(connect=False) c.pymongo_test.test.remove({}) # Auto-connect for write. self.assertTrue(c.primary) c = MongoReplicaSetClient( "somedomainthatdoesntexist.org", replicaSet="rs", - connectTimeoutMS=1, _connect=False) + connectTimeoutMS=1, connect=False) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_init_disconnected_with_auth_failure(self): c = MongoReplicaSetClient( "mongodb://user:pass@somedomainthatdoesntexist", replicaSet="rs", - connectTimeoutMS=1, _connect=False) + connectTimeoutMS=1, connect=False) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) @@ -183,14 +185,14 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): uri = "mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s" % ( host[0], host[1], self.name) - authenticated_client = MongoReplicaSetClient(uri, _connect=False) + authenticated_client = MongoReplicaSetClient(uri, connect=False) authenticated_client.pymongo_test.test.find_one() # Wrong password. bad_uri = "mongodb://user:wrong@%s:%d/pymongo_test?replicaSet=%s" % ( host[0], host[1], self.name) - bad_client = MongoReplicaSetClient(bad_uri, _connect=False) + bad_client = MongoReplicaSetClient(bad_uri, connect=False) self.assertRaises( OperationFailure, bad_client.pymongo_test.test.find_one) @@ -310,7 +312,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): lazy_client = MongoReplicaSetClient( "mongodb://user:wrong@%s/pymongo_test" % pair, replicaSet=self.name, - _connect=False) + connect=False) assertRaisesExactly( OperationFailure, lazy_client.test.collection.find_one) @@ -453,7 +455,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): uri = "mongodb://%s:%d/foo?replicaSet=%s" % ( host[0], host[1], self.name) - c = MongoReplicaSetClient(uri, _connect=False) + c = MongoReplicaSetClient(uri, connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_get_default_database_error(self): @@ -462,7 +464,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): uri = "mongodb://%s:%d/?replicaSet=%s" % ( host[0], host[1], self.name) - c = MongoReplicaSetClient(uri, _connect=False) + c = MongoReplicaSetClient(uri, connect=False) self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): @@ -471,7 +473,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): uri = "mongodb://%s:%d/foo?replicaSet=%s&authSource=src" % ( host[0], host[1], self.name) - c = MongoReplicaSetClient(uri, _connect=False) + c = MongoReplicaSetClient(uri, connect=False) self.assertEqual(Database(c, 'foo'), c.get_default_database()) def test_iteration(self): @@ -1062,7 +1064,7 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): self.assertTrue(client.alive()) client = MongoReplicaSetClient( - 'doesnt exist', replicaSet='rs', _connect=False) + 'doesnt exist', replicaSet='rs', connect=False) self.assertFalse(client.alive()) @@ -1115,17 +1117,17 @@ class TestReplicaSetClient(TestReplicaSetClientBase, TestRequestMixin): c.test.collection.find_one() -class TestReplicaSetWireVersion(unittest.TestCase): +class TestReplicaSetWireVersion(MockClientTest): @client_context.require_connection def test_wire_version(self): - c = MockReplicaSetClient( + c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1', replicaSet='rs', - _connect=False) + connect=False) c.set_wire_version_range('a:1', 1, 5) c.set_wire_version_range('b:2', 0, 1) @@ -1135,16 +1137,22 @@ class TestReplicaSetWireVersion(unittest.TestCase): self.assertEqual(c.max_wire_version, 5) c.set_wire_version_range('a:1', 2, 2) - c.refresh() - self.assertEqual(c.min_wire_version, 2) + wait_until(lambda: c.min_wire_version == 2, 'update min_wire_version') self.assertEqual(c.max_wire_version, 2) # A secondary doesn't overlap with us. c.set_wire_version_range('b:2', 5, 6) - # refresh() raises, as do all following operations. - self.assertRaises(ConfigurationError, c.refresh) - self.assertRaises(ConfigurationError, c.db.collection.find_one) + def raises_configuration_error(): + try: + c.db.collection.find_one() + return False + except ConfigurationError: + return True + + wait_until(raises_configuration_error, + 'notice we are incompatible with server') + self.assertRaises(ConfigurationError, c.db.collection.insert, {}) @@ -1159,7 +1167,7 @@ class TestReplicaSetClientLazyConnect( def test_read_mode_secondary(self): client = MongoReplicaSetClient( - connection_string(), replicaSet=self.name, _connect=False, + connection_string(), replicaSet=self.name, connect=False, read_preference=ReadPreference.SECONDARY) # No error. @@ -1188,34 +1196,35 @@ class TestReplicaSetClientLazyConnectBadSeeds( return client -class TestReplicaSetClientInternalIPs(unittest.TestCase): +class TestReplicaSetClientInternalIPs(MockClientTest): @client_context.require_connection def test_connect_with_internal_ips(self): # Client is passed an IP it can reach, 'a:1', but the RS config # only contains unreachable IPs like 'internal-ip'. PYTHON-608. assertRaisesExactly( - ConnectionFailure, - MockReplicaSetClient, - standalones=[], - members=['a:1'], - mongoses=[], - ismaster_hosts=['internal-ip:27017'], - host='a:1', - replicaSet='rs') + AutoReconnect, + connected, + MockClient( + standalones=[], + members=['a:1'], + mongoses=[], + ismaster_hosts=['internal-ip:27017'], + host='a:1', + replicaSet='rs')) -class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase): +class TestReplicaSetClientMaxWriteBatchSize(MockClientTest): @client_context.require_connection def test_max_write_batch_size(self): - c = MockReplicaSetClient( + c = MockClient( standalones=[], members=['a:1', 'b:2'], mongoses=[], host='a:1', replicaSet='rs', - _connect=False) + connect=False) c.set_max_write_batch_size('a:1', 1) c.set_max_write_batch_size('b:2', 2) @@ -1229,8 +1238,8 @@ class TestReplicaSetClientMaxWriteBatchSize(unittest.TestCase): # b becomes primary. c.mock_primary = 'b:2' - c.refresh() - self.assertEqual(c.max_write_batch_size, 2) + wait_until(lambda: c.max_write_batch_size == 2, + 'update max_write_batch_size') if __name__ == "__main__": diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 7e76ac64a..66b4ba7bf 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -18,10 +18,11 @@ import sys sys.path[0:0] = [""] -from pymongo.errors import ConfigurationError, ConnectionFailure +from pymongo.errors import ConnectionFailure, AutoReconnect from pymongo import ReadPreference -from test import unittest, client_context -from test.pymongo_mocks import MockClient, MockReplicaSetClient +from test import unittest, client_context, client_knobs, MockClientTest +from test.pymongo_mocks import MockClient +from test.utils import wait_until @client_context.require_connection @@ -29,7 +30,7 @@ def setUpModule(): pass -class TestSecondaryBecomesStandalone(unittest.TestCase): +class TestSecondaryBecomesStandalone(MockClientTest): # An administrator removes a secondary from a 3-node set and # brings it back up as standalone, without updating the other # members' config. Verify we don't continue using it. @@ -42,7 +43,7 @@ class TestSecondaryBecomesStandalone(unittest.TestCase): replicaSet='rs') # MongoClient connects to primary by default. - self.assertEqual('a', c.host) + wait_until(lambda: c.host == 'a', 'connect to primary') self.assertEqual(1, c.port) # C is brought up as a standalone. @@ -56,79 +57,82 @@ class TestSecondaryBecomesStandalone(unittest.TestCase): # Force reconnect. c.disconnect() - try: + with self.assertRaises(AutoReconnect): c.db.command('ismaster') - except ConfigurationError as e: - self.assertTrue('not a member of replica set' in str(e)) - else: - self.fail("MongoClient didn't raise AutoReconnect") self.assertEqual(None, c.host) self.assertEqual(None, c.port) def test_replica_set_client(self): - c = MockReplicaSetClient( + c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1,b:2,c:3', replicaSet='rs') - self.assertTrue(('b', 2) in c.secondaries) - self.assertTrue(('c', 3) in c.secondaries) + wait_until(lambda: ('b', 2) in c.secondaries, + 'discover host "b"') + + wait_until(lambda: ('c', 3) in c.secondaries, + 'discover host "c"') # C is brought up as a standalone. c.mock_members.remove('c:3') c.mock_standalones.append('c:3') - c.refresh() + + wait_until(lambda: set([('b', 2)]) == c.secondaries, + 'update the list of secondaries') self.assertEqual(('a', 1), c.primary) - self.assertEqual(set([('b', 2)]), c.secondaries) -class TestSecondaryRemoved(unittest.TestCase): +class TestSecondaryRemoved(MockClientTest): # An administrator removes a secondary from a 3-node set *without* # restarting it as standalone. def test_replica_set_client(self): - c = MockReplicaSetClient( + c = MockClient( standalones=[], members=['a:1', 'b:2', 'c:3'], mongoses=[], host='a:1,b:2,c:3', replicaSet='rs') - self.assertTrue(('b', 2) in c.secondaries) - self.assertTrue(('c', 3) in c.secondaries) + wait_until(lambda: ('b', 2) in c.secondaries, 'discover host "b"') + wait_until(lambda: ('c', 3) in c.secondaries, 'discover host "c"') # C is removed. c.mock_ismaster_hosts.remove('c:3') - c.refresh() + wait_until(lambda: set([('b', 2)]) == c.secondaries, + 'update list of secondaries') self.assertEqual(('a', 1), c.primary) - self.assertEqual(set([('b', 2)]), c.secondaries) -class TestSocketError(unittest.TestCase): +class TestSocketError(MockClientTest): def test_socket_error_marks_member_down(self): - c = MockReplicaSetClient( - standalones=[], - members=['a:1', 'b:2'], - mongoses=[], - host='a:1', - replicaSet='rs') + # Disable background refresh. + with client_knobs(heartbeat_frequency=9999999): + c = MockClient( + standalones=[], + members=['a:1', 'b:2'], + mongoses=[], + host='a:1', + replicaSet='rs') - self.assertEqual(2, len(c._MongoReplicaSetClient__rs_state.members)) + wait_until(lambda: len(c.nodes) == 2, 'discover both nodes') - # b now raises socket.error. - c.mock_down_hosts.append('b:2') - self.assertRaises( - ConnectionFailure, - c.db.collection.find_one, read_preference=ReadPreference.SECONDARY) + # b now raises socket.error. + c.mock_down_hosts.append('b:2') + self.assertRaises( + ConnectionFailure, + c.db.collection.find_one, + read_preference=ReadPreference.SECONDARY) - self.assertEqual(1, len(c._MongoReplicaSetClient__rs_state.members)) + self.assertEqual(1, len(c.nodes)) -class TestSecondaryAdded(unittest.TestCase): +class TestSecondaryAdded(MockClientTest): def test_client(self): c = MockClient( standalones=[], @@ -137,6 +141,8 @@ class TestSecondaryAdded(unittest.TestCase): host='a:1', replicaSet='rs') + wait_until(lambda: len(c.nodes) == 2, 'discover both nodes') + # MongoClient connects to primary by default. self.assertEqual('a', c.host) self.assertEqual(1, c.port) @@ -151,26 +157,30 @@ class TestSecondaryAdded(unittest.TestCase): self.assertEqual('a', c.host) self.assertEqual(1, c.port) - self.assertEqual(set([('a', 1), ('b', 2), ('c', 3)]), c.nodes) + + wait_until(lambda: set([('a', 1), ('b', 2), ('c', 3)]) == c.nodes, + 'reconnect to both secondaries') def test_replica_set_client(self): - c = MockReplicaSetClient( + c = MockClient( standalones=[], members=['a:1', 'b:2'], mongoses=[], host='a:1', replicaSet='rs') - self.assertEqual(('a', 1), c.primary) - self.assertEqual(set([('b', 2)]), c.secondaries) + wait_until(lambda: ('a', 1) == c.primary, 'discover the primary') + wait_until(lambda: set([('b', 2)]) == c.secondaries, + 'discover the secondary') # C is added. c.mock_members.append('c:3') c.mock_ismaster_hosts.append('c:3') - c.refresh() + + wait_until(lambda: set([('b', 2), ('c', 3)]) == c.secondaries, + 'discover the new secondary') self.assertEqual(('a', 1), c.primary) - self.assertEqual(set([('b', 2), ('c', 3)]), c.secondaries) if __name__ == "__main__": diff --git a/test/test_son_manipulator.py b/test/test_son_manipulator.py index a3bcdc99f..42953cd30 100644 --- a/test/test_son_manipulator.py +++ b/test/test_son_manipulator.py @@ -31,7 +31,7 @@ class TestSONManipulator(unittest.TestCase): @classmethod def setUpClass(cls): - client = MongoClient(host, port, _connect=False) + client = MongoClient(host, port, connect=False) cls.db = client.pymongo_test def test_basic(self): diff --git a/test/test_ssl.py b/test/test_ssl.py index 29479c2e4..41170143a 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -38,7 +38,7 @@ from pymongo.errors import (ConfigurationError, OperationFailure) from pymongo.ssl_support import HAVE_SSL from test import host, pair, port, SkipTest, unittest -from test.utils import server_started_with_auth, remove_all_users +from test.utils import server_started_with_auth, remove_all_users, connected from test.version import Version @@ -86,7 +86,7 @@ if HAVE_SSL: # Is MongoDB configured for SSL? try: - MongoClient(host, port, connectTimeoutMS=100, ssl=True) + connected(MongoClient(host, port, connectTimeoutMS=100, ssl=True)) SIMPLE_SSL = True except ConnectionFailure: pass @@ -94,8 +94,8 @@ if HAVE_SSL: # Is MongoDB configured with server.pem, ca.pem, and crl.pem from # mongodb jstests/lib? try: - ssl_client = MongoClient(host, port, connectTimeoutMS=100, ssl=True, - ssl_certfile=CLIENT_PEM) + ssl_client = connected(MongoClient(host, port, connectTimeoutMS=100, + ssl=True, ssl_certfile=CLIENT_PEM)) CERT_SSL = True except ConnectionFailure: pass diff --git a/test/test_threads.py b/test/test_threads.py index 3610dc997..afd6374ee 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -140,9 +140,11 @@ class FindPauseFind(RendezvousThread): def before_rendezvous(self): # acquire a socket + client = self.collection.database.connection + client.start_request() list(self.collection.find()) - pool = get_pool(self.collection.database.connection) + pool = get_pool(client) socket_info = pool._get_request_state() assert isinstance(socket_info, SocketInfo) self.request_sock = socket_info.sock @@ -241,7 +243,7 @@ class BaseTestThreads(object): # PYTHON-345, we need to make sure that threads' request sockets are # closed by disconnect(). # - # 1. Create a client with auto_start_request=True + # 1. Create a client and start a request. # 2. Start N threads and do a find() in each to get a request socket # 3. Pause all threads # 4. In the main thread close all sockets, including threads' request @@ -252,7 +254,8 @@ class BaseTestThreads(object): # # If we've fixed PYTHON-345, then only one AutoReconnect is raised, # and all the threads get new request sockets. - cx = get_client(pair, auto_start_request=True) + cx = get_client(pair) + cx.start_request() collection = cx.db.pymongo_test # acquire a request socket for the main thread diff --git a/test/utils.py b/test/utils.py index 47aeeda42..0398ffc5e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -19,10 +19,12 @@ import os import struct import sys import threading +import time from pymongo import MongoClient, MongoReplicaSetClient from pymongo.errors import AutoReconnect from pymongo.pool import NO_REQUEST, NO_SOCKET_YET, SocketInfo +from pymongo.server_selectors import writable_server_selector from test import (client_context, db_user, db_pwd) @@ -31,7 +33,7 @@ from test.version import Version def get_client(*args, **kwargs): client = MongoClient(*args, **kwargs) - if client_context.auth_enabled and kwargs.get("_connect", True): + if client_context.auth_enabled and kwargs.get("connect", True): client.admin.authenticate(db_user, db_pwd) return client @@ -147,6 +149,27 @@ def joinall(threads): t.join(300) assert not t.isAlive(), "Thread %s hung" % t +def connected(client): + """Convenience to wait for a newly-constructed client to connect.""" + client.admin.command('ismaster') # Force connection. + return client + +def wait_until(predicate, success_description): + """Wait up to 10 seconds for predicate to be True. + + E.g.: + + wait_until(lambda: c.primary == ('a', 1), + 'connect to the primary') + + If the lambda-expression isn't true after 10 seconds, we raise + AssertionError("Didn't ever connect to the primary"). + """ + start = time.time() + while not predicate(): + if time.time() - start > 10: + raise AssertionError("Didn't ever %s" % success_description) + def is_mongos(client): res = client.admin.command('ismaster') return res.get('msg', '') == 'isdbgrid' @@ -343,13 +366,9 @@ def assertReadFromAll(testcase, rsc, members, *args, **kwargs): testcase.assertEqual(members, used) def get_pool(client): - if isinstance(client, MongoClient): - return client._MongoClient__member.pool - elif isinstance(client, MongoReplicaSetClient): - rs_state = client._MongoReplicaSetClient__rs_state - return rs_state.primary_member.pool - else: - raise TypeError(str(client)) + cluster = client._get_cluster() + server = cluster.select_server(writable_server_selector) + return server.pool def pools_from_rs_client(client): """Get Pool instances from a MongoReplicaSetClient. @@ -451,7 +470,7 @@ def lazy_client_trial(reset, target, test, get_client): try: for i in range(NTRIALS): reset(collection) - lazy_client = get_client(_connect=False) + lazy_client = get_client(connect=False) lazy_collection = lazy_client.pymongo_test.test run_threads(lazy_collection, target) test(lazy_collection) @@ -469,7 +488,7 @@ class _TestLazyConnectMixin(object): Inherit from this class and from unittest.TestCase, and override _get_client(self, **kwargs), for testing a lazily-connecting - client, i.e. a client initialized with _connect=False. + client, i.e. a client initialized with connect=False. """ NTRIALS = 5 NTHREADS = 10 @@ -544,7 +563,7 @@ class _TestLazyConnectMixin(object): def test_max_bson_size(self): # Client should have sane defaults before connecting, and should update # its configuration once connected. - c = self._get_client(_connect=False) + c = self._get_client(connect=False) self.assertEqual(16 * (1024 ** 2), c.max_bson_size) self.assertEqual(2 * c.max_bson_size, c.max_message_size)